Fold the recursive function
Explicit recursion to Fold
The other day I was reviewing a pull request at work, and while scrolling through the changes my eyes fell on this bit of code (this is an anonymized version of it):
```scala
object RecursiveFunction2Fold {
  type ErrorMessage = String

  case class Filter[A](predicate: A => Boolean, filterError: String)

  def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] = {
    def recursiveFiltering(elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
      filters match {
        case Nil => Right(elements)
        case head :: remainingFilters =>
          filterOnce(elements, head) match {
            case Left(err) => Left(err)
            case Right(filtered) => recursiveFiltering(filtered, remainingFilters)
          }
      }

    recursiveFiltering(elements, filters)
  }

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil => Left(filter.filterError)
      case remaining => Right(remaining)
    }
}
```
You may notice that this function uses explicit recursion (the `recursiveFiltering` bit). There is a cool paper, Graham Hutton's "A tutorial on the universality and expressiveness of fold", about how `fold` really makes explicit recursion unnecessary: http://www.cs.nott.ac.uk/~pszgmh/fold.pdf
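To make the idea concrete before touching the real code, here is a small hypothetical example (not from the PR): summing a list with explicit recursion that threads an accumulator, and the same function written with `foldLeft`. The names `SumExample`, `sumRec` and `sumFold` are made up for illustration; the snippet assumes Scala 2.13 and runs as a worksheet or REPL script.

```scala
// Hypothetical example: the same function with explicit recursion and with foldLeft
object SumExample {
  // Explicit recursion threading an accumulator, like recursiveFiltering threads `elements`
  def sumRec(acc: Int, xs: List[Int]): Int = xs match {
    case Nil          => acc                      // base case: return the accumulator
    case head :: tail => sumRec(acc + head, tail) // recursive case: update it and recurse
  }

  // Fold: the starting accumulator (0) becomes z, the recursive step becomes op
  def sumFold(xs: List[Int]): Int = xs.foldLeft(0)((acc, x) => acc + x)
}

println(SumExample.sumRec(0, List(1, 2, 3))) // 6
println(SumExample.sumFold(List(1, 2, 3)))   // 6
```

The accumulator-passing recursion maps one to one onto `foldLeft`, which is the same mapping that applies to `recursiveFiltering`.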
So I asked my colleague whether we could make this function more concise with `fold`. Naturally, the office mayhem left us no time for it, and to keep the business running the function was shipped as is.
Well, that function did not escape my mind, so I want to share how to transform an explicitly recursive function into a fold-based one.
Discovering what the function does
Let's figure out what this recursive function is trying to do.

```scala
def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] = {
```

The signature of the outer function takes a sequence of elements and a sequence of filters as input, and returns either an error or a sequence of elements. At this point we may guess that each filter in the `filters` sequence is going to be applied to the given `elements`.
```scala
def recursiveFiltering(elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] = filters match {
```

The recursive function has the same signature, so nothing interesting there.

```scala
case Nil => Right(elements)
```

Now we look at what our sequence of `filters` actually contains: if it is empty, we return the current `elements` wrapped in a `Right` (on the first call, these are the original elements).
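```scala
case head :: remainingFilters => filterOnce(elements, head) match {
```

Otherwise we apply the first filter (`head`) with `filterOnce` and check its output.

```scala
case Left(err) => Left(err)
```

If the output is an error, we stop and return that same error.

```scala
case Right(filtered) => recursiveFiltering(filtered, remainingFilters) }
```

Otherwise we continue with the remaining filters, applied to the elements that survived this one.

Finally, `filterMany` starts the recursion with all the elements and all the filters:

```scala
}
recursiveFiltering(elements, filters)
}
```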
By the way, `filterOnce` does the actual filtering (through `filter.predicate`) on the sequence of `elements`:

```scala
def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
  elements.filter(filter.predicate) match { ...
```
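A quick illustration of `filterOnce` on its own, as a sketch using the full definition from the listing at the top of the post; the wrapper object `FilterOnceDemo` and the two filters `small` and `negative` are hypothetical. A filter that keeps something yields a `Right`, while one that keeps nothing yields a `Left` carrying the filter's error message.

```scala
// Same definitions as in the listing at the top of the post
object FilterOnceDemo {
  type ErrorMessage = String
  case class Filter[A](predicate: A => Boolean, filterError: String)

  // Filter, then turn an empty result into a Left with the filter's error message
  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil       => Left(filter.filterError)
      case remaining => Right(remaining)
    }
}

import FilterOnceDemo._

// Hypothetical filters for illustration
val small    = Filter((i: Int) => i < 10, "nothing below 10")
val negative = Filter((i: Int) => i < 0, "nothing below 0")

println(filterOnce(Seq(1, 20, 3), small))    // Right(List(1, 3))
println(filterOnce(Seq(1, 20, 3), negative)) // Left(nothing below 0)
```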
So, as far as we can tell, this function:
- tries to run all the filters against the input sequence (reducing the size of `elements`)
- has to fail when a filter produces an empty sequence of elements
Assert what we found out in the code
So we understood something about the function. Let's check that our understanding is right with some assertions in the code (i.e., lightweight tests), which we will keep using during the refactoring.
```scala
object RecursiveFunction2Fold {
  type ErrorMessage = String

  case class Filter[A](predicate: A => Boolean, filterError: String)

  def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] = {
    def recursiveFiltering(elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
      filters match {
        case Nil => Right(elements)
        case head :: remainingFilters =>
          filterOnce(elements, head) match {
            case Left(err) => Left(err)
            case Right(filtered) => recursiveFiltering(filtered, remainingFilters)
          }
      }

    recursiveFiltering(elements, filters)
  }

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil => Left(filter.filterError)
      case remaining => Right(remaining)
    }
}

// Let's test!
import RecursiveFunction2Fold._

val neverFailingFilter: Filter[Int] = Filter((i: Int) => true, "Unreachable")
val passingFilters: Seq[Filter[Int]] = Seq(neverFailingFilter)
val reducingFilters: Seq[Filter[Int]] =
  Seq(neverFailingFilter, Filter((i: Int) => i < 42, "Oh no, all the given integers are greater than 42!"))
val breakingFilters: Seq[Filter[Int]] =
  Seq(neverFailingFilter, Filter((i: Int) => i < -1, "Oh no, all the given integers are greater than -1!"))

val elements: Seq[Int] = Seq(1, 2, 3, 42, 5)

// tries to run all the filters against the input sequence (to reduce the elements size)
assert(filterMany(elements, passingFilters) == Right(List(1, 2, 3, 42, 5)))
assert(filterMany(elements, reducingFilters) == Right(List(1, 2, 3, 5)))

// has to fail when the filter produces an empty sequence of elements
assert(filterMany(elements, breakingFilters) == Left("Oh no, all the given integers are greater than -1!"))
```
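Two more assertions that may be worth keeping around during the refactoring, as a sketch: the case with no filters at all, and a check that filters after the first failing one are never run. The `failing` and `counting` filters and the `laterCalls` counter are hypothetical additions, not part of the original tests; the block repeats the object so it runs on its own.

```scala
// The recursive version from the post, unchanged
object RecursiveFunction2Fold {
  type ErrorMessage = String
  case class Filter[A](predicate: A => Boolean, filterError: String)

  def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] = {
    def recursiveFiltering(elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
      filters match {
        case Nil => Right(elements)
        case head :: remainingFilters =>
          filterOnce(elements, head) match {
            case Left(err)       => Left(err)
            case Right(filtered) => recursiveFiltering(filtered, remainingFilters)
          }
      }
    recursiveFiltering(elements, filters)
  }

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil       => Left(filter.filterError)
      case remaining => Right(remaining)
    }
}

import RecursiveFunction2Fold._

val elements: Seq[Int] = Seq(1, 2, 3, 42, 5)

// No filters at all: the base case returns the elements untouched
assert(filterMany(elements, Seq.empty[Filter[Int]]) == Right(elements))

// Hypothetical counting filter: records how often its predicate runs
var laterCalls = 0
val failing  = Filter((i: Int) => i < -1, "first failure")
val counting = Filter((i: Int) => { laterCalls += 1; true }, "never reported")

// The first failing filter wins, and the filters after it are never run
assert(filterMany(elements, Seq(failing, counting)) == Left("first failure"))
assert(laterCalls == 0)
```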
Refactor to fold it!
Now we are ready for our little refactoring exercise.
Let's review the signature of `foldLeft`:

```scala
def foldLeft[B](z: B)(op: (B, A) => B): B
```
We have an initial value `z`, which is also what we get back when the sequence we fold over is empty. Then we have an operation `op` that takes the accumulator of type `B` (the result type) and an element of the sequence of type `A`, and returns the new accumulator.
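To see how the accumulator flows, here is a hypothetical string-building fold that records each application of `op` (the values `letters`, `folded` and `foldedEmpty` are illustrative only):

```scala
// Hypothetical example: record how foldLeft combines the accumulator with each element
val letters = Seq("a", "b", "c")

// Each step wraps the previous accumulator, so the nesting shows the left-to-right order
val folded = letters.foldLeft("z")((acc, a) => s"($acc+$a)")
println(folded) // (((z+a)+b)+c)

// On an empty sequence op is never called and z comes back as is
val foldedEmpty = Seq.empty[String].foldLeft("z")((acc, a) => s"($acc+$a)")
println(foldedEmpty) // z
```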
Now let's take another look at the recursive function:
```scala
def recursiveFiltering(elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
  filters match {
    case Nil => Right(elements)
    case head :: remainingFilters =>
      filterOnce(elements, head) match {
        case Left(err) => Left(err)
        case Right(filtered) => recursiveFiltering(filtered, remainingFilters)
      }
  }

recursiveFiltering(elements, filters)
```
The first question for our refactoring is: what is the initial value `z`? We can answer it with another question: what is the base case of the explicit recursion? Here it is:

```scala
case Nil => Right(elements)
```
So let's start rewriting our function:
```scala
def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
  filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]])((acc, e) => ???)
```
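A small observation about this skeleton, as a sketch: since `op` is only called once per element, the half-written version already handles the base case. With an empty `filters` sequence, `foldLeft` returns `z` directly and the `???` placeholder is never evaluated (with at least one filter, it would throw `scala.NotImplementedError`). The wrapper object `FoldSkeleton` is just an illustrative name.

```scala
object FoldSkeleton {
  type ErrorMessage = String
  case class Filter[A](predicate: A => Boolean, filterError: String)

  // The skeleton from the post: z is decided, op is still a placeholder
  def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
    filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]])((acc, e) => ???)
}

import FoldSkeleton._

// No filters: foldLeft returns z without calling op, so ??? is never evaluated
println(filterMany(Seq(1, 2, 3), Seq.empty[Filter[Int]])) // Right(List(1, 2, 3))
```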
Next we need to look at the core of the recursive function: the recursive case.
```scala
case head :: remainingFilters =>
  filterOnce(elements, head) match {
    case Left(err) => Left(err)
    case Right(filtered) => recursiveFiltering(filtered, remainingFilters)
  }
```
Here we take the first filter, which is `e` in our fold; we have to apply it with `filterOnce` and handle both cases. One thing to remember is that our fold's accumulator stores the result of the filtering so far: the surviving elements in a `Right`, or the first error in a `Left`. Let's translate that into our fold version:
```scala
def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
  filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]])(
    (acc: Either[ErrorMessage, Seq[A]], e: Filter[A]) =>
      acc match {
        case Left(err) => Left(err) // the previous run produced an error, no need to apply any other filter
        case Right(filtered) => filterOnce(filtered, e) // filterOnce produces an Either itself
      }
  )
```
The idea is that we start the fold with an accumulator containing a `Right` with all the elements in it. When we pick the first filter we end up in the `case Right(filtered)` branch, and `filterOnce` produces the next `Either`.
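To watch the accumulator evolve, one option is `scanLeft`, which takes the same arguments as `foldLeft` but returns every intermediate accumulator. This is a sketch with three hypothetical filters (and an illustrative `FoldTrace` wrapper); the last element of the scan is what the fold returns.

```scala
object FoldTrace {
  type ErrorMessage = String
  case class Filter[A](predicate: A => Boolean, filterError: String)

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil       => Left(filter.filterError)
      case remaining => Right(remaining)
    }

  // Same z and op as the fold in the post, but scanLeft keeps every intermediate accumulator
  def trace[A](elements: Seq[A], filters: Seq[Filter[A]]): Seq[Either[ErrorMessage, Seq[A]]] =
    filters.scanLeft(Right(elements): Either[ErrorMessage, Seq[A]]) { (acc, e) =>
      acc match {
        case Left(err)       => Left(err)
        case Right(filtered) => filterOnce(filtered, e)
      }
    }
}

import FoldTrace._

// Hypothetical filters: the second one removes everything
val filters = Seq(
  Filter((i: Int) => i < 42, "nothing below 42"),
  Filter((i: Int) => i < -1, "nothing below -1"),
  Filter((i: Int) => i > 0, "nothing above 0")
)

trace(Seq(1, 2, 3, 42, 5), filters).foreach(println)
// Right(List(1, 2, 3, 42, 5))
// Right(List(1, 2, 3, 5))
// Left(nothing below -1)
// Left(nothing below -1)
```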
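We then repeat this for each filter. Once `filterOnce` produces a `Left`, no further filters are applied: unlike the recursive version, the fold still walks through the remaining filters, but the `case Left(err)` branch simply passes the error along.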
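A sketch to check that no filter is applied after a `Left`: the hypothetical `counting` filter increments `predicateCalls` whenever its predicate runs. The fold still calls the operation for both `counting` entries, but the `Left` branch never invokes `filterOnce`, so the counter stays at zero. `FoldVisits` is an illustrative name.

```scala
object FoldVisits {
  type ErrorMessage = String
  case class Filter[A](predicate: A => Boolean, filterError: String)

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil       => Left(filter.filterError)
      case remaining => Right(remaining)
    }

  // The pattern-matching fold from the post
  def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
    filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]])(
      (acc: Either[ErrorMessage, Seq[A]], e: Filter[A]) =>
        acc match {
          case Left(err)       => Left(err)
          case Right(filtered) => filterOnce(filtered, e)
        }
    )
}

import FoldVisits._

var predicateCalls = 0
val failing  = Filter((i: Int) => i < -1, "first failure")
// Hypothetical filter that counts how many times its predicate runs
val counting = Filter((i: Int) => { predicateCalls += 1; true }, "never reported")

println(filterMany(Seq(1, 2, 3), Seq(failing, counting, counting))) // Left(first failure)
println(predicateCalls)                                              // 0
```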
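Before testing this version, let's apply a functional refinement: this pattern matching can be replaced by a `flatMap`. Since Scala 2.12, `Either` is right-biased, so `flatMap` applies the function to a `Right` and returns a `Left` unchanged.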
```scala
def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
  filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]])(
    (acc: Either[ErrorMessage, Seq[A]], e: Filter[A]) => acc.flatMap(filterOnce(_, e))
  )
```
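To convince ourselves that `flatMap` really replaces the pattern match, here is a hypothetical `halve` function, standing in for `filterOnce`, run through both forms on a few inputs (assumes Scala 2.12 or later; `FlatMapDemo`, `viaMatch` and `viaFlatMap` are illustrative names):

```scala
// Requires Scala 2.12+, where Either is right-biased and has flatMap
object FlatMapDemo {
  // Hypothetical step that can fail, standing in for filterOnce
  def halve(n: Int): Either[String, Int] =
    if (n % 2 == 0) Right(n / 2) else Left(s"$n is odd")

  // The pattern match used in the fold body...
  def viaMatch(acc: Either[String, Int]): Either[String, Int] = acc match {
    case Left(err) => Left(err)
    case Right(n)  => halve(n)
  }

  // ...and its flatMap counterpart: a Left is returned untouched, a Right is fed to halve
  def viaFlatMap(acc: Either[String, Int]): Either[String, Int] = acc.flatMap(n => halve(n))
}

import FlatMapDemo._

val inputs: Seq[Either[String, Int]] = Seq(Right(8), Right(3), Left("earlier error"))
inputs.foreach(in => println(s"${viaMatch(in)} ${viaFlatMap(in)}"))
// Right(4) Right(4)
// Left(3 is odd) Left(3 is odd)
// Left(earlier error) Left(earlier error)
```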
So our final code looks like this:

```scala
object RecursiveFunction2Fold {
  type ErrorMessage = String

  case class Filter[A](predicate: A => Boolean, filterError: String)

  def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
    filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]])(
      (acc: Either[ErrorMessage, Seq[A]], e: Filter[A]) => acc.flatMap(filterOnce(_, e))
    )

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil => Left(filter.filterError)
      case remaining => Right(remaining)
    }
}
```
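A note on the type ascription `Right(elements): Either[ErrorMessage, Seq[A]]`: in Scala 2, `foldLeft`'s `B` is inferred from the first argument list alone, so without it `B` would be `Right[Nothing, Seq[A]]` and an operation returning an `Either` would not type-check. An equivalent option, sketched below, is to pass `B` as an explicit type argument; `TypeArgumentVariant` and `belowFortyTwo` are illustrative names.

```scala
object TypeArgumentVariant {
  type ErrorMessage = String
  case class Filter[A](predicate: A => Boolean, filterError: String)

  // Same fold as the final version, but B is fixed with an explicit type argument
  // instead of an ascription on the initial value
  def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
    filters.foldLeft[Either[ErrorMessage, Seq[A]]](Right(elements))((acc, e) => acc.flatMap(filterOnce(_, e)))

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil       => Left(filter.filterError)
      case remaining => Right(remaining)
    }
}

import TypeArgumentVariant._

val belowFortyTwo = Filter((i: Int) => i < 42, "nothing below 42")
println(filterMany(Seq(1, 2, 3, 42, 5), Seq(belowFortyTwo))) // Right(List(1, 2, 3, 5))
```

Both spellings produce the same accumulator type; which one reads better is mostly a matter of taste.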
Now it's time to test our refactoring!
```scala
object RecursiveFunction2Fold {
  type ErrorMessage = String

  case class Filter[A](predicate: A => Boolean, filterError: String)

  def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
    filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]])(
      (acc: Either[ErrorMessage, Seq[A]], e: Filter[A]) => acc.flatMap(filterOnce(_, e))
    )

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil => Left(filter.filterError)
      case remaining => Right(remaining)
    }
}

// Let's test!
import RecursiveFunction2Fold._

val neverFailingFilter: Filter[Int] = Filter((i: Int) => true, "Unreachable")
val passingFilters: Seq[Filter[Int]] = Seq(neverFailingFilter)
val reducingFilters: Seq[Filter[Int]] =
  Seq(neverFailingFilter, Filter((i: Int) => i < 42, "Oh no, all the given integers are greater than 42!"))
val breakingFilters: Seq[Filter[Int]] =
  Seq(neverFailingFilter, Filter((i: Int) => i < -1, "Oh no, all the given integers are greater than -1!"))

val elements: Seq[Int] = Seq(1, 2, 3, 42, 5)

// tries to run all the filters against the input sequence (to reduce the elements size)
assert(filterMany(elements, passingFilters) == Right(List(1, 2, 3, 42, 5)))
assert(filterMany(elements, reducingFilters) == Right(List(1, 2, 3, 5)))

// has to fail when the filter produces an empty sequence of elements
assert(filterMany(elements, breakingFilters) == Left("Oh no, all the given integers are greater than -1!"))
```
Done! The tests pass, so the refactoring looks successful. Note that I kept the number of tests to a minimum for readability: you may want more tests to feel confident that the refactoring was successful.
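If you want more confidence than a handful of examples, one option is a randomized comparison of the old and new implementations. This is a sketch using only `scala.util.Random` with a fixed seed; the `Equivalence` object, the "keep values below t" filters and the input sizes are arbitrary choices made for illustration, and a property-based testing library such as ScalaCheck would be the more thorough route.

```scala
import scala.util.Random

object Equivalence {
  type ErrorMessage = String
  case class Filter[A](predicate: A => Boolean, filterError: String)

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil       => Left(filter.filterError)
      case remaining => Right(remaining)
    }

  // The original, explicitly recursive version
  def recursive[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
    filters match {
      case Nil => Right(elements)
      case head :: remainingFilters =>
        filterOnce(elements, head) match {
          case Left(err)       => Left(err)
          case Right(filtered) => recursive(filtered, remainingFilters)
        }
    }

  // The fold-based version
  def folded[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
    filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]])((acc, e) => acc.flatMap(filterOnce(_, e)))

  // Runs both versions on `runs` random inputs and reports whether they always agree.
  // Inputs: 0 to 5 ints in [0, 20) and 0 to 3 "keep values below t" filters.
  def allAgree(seed: Long, runs: Int): Boolean = {
    val rnd = new Random(seed)
    (1 to runs).forall { _ =>
      val elements = List.fill(rnd.nextInt(6))(rnd.nextInt(20))
      val filters = List.fill(rnd.nextInt(4)) {
        val t = rnd.nextInt(20)
        Filter((i: Int) => i < t, s"nothing below $t")
      }
      recursive(elements, filters) == folded(elements, filters)
    }
  }
}

println(Equivalence.allAgree(seed = 42L, runs = 1000)) // true
```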
Happy coding!