Where parallels cross

Interesting bits of life

Fold the recursive function

Explicit recursion to Fold

The other day I was reviewing a pull request at work, and while scrolling through the changes my eyes fell on this bit of code (this is an anonymized version of it):

object RecursiveFunction2Fold {
  type ErrorMessage = String

  case class Filter[A](predicate: A => Boolean, filterError: String)

  def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]):Either[ErrorMessage, Seq[A]] = {
    def recursiveFiltering(elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] = filters match {
      case Nil => Right(elements)
      case head :: remainingFilters => filterOnce(elements, head) match {
        case Left(err) => Left(err)
        case Right(filtered) => recursiveFiltering(filtered, remainingFilters)
      }
    }
    recursiveFiltering(elements, filters)
  }

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] = elements.filter(filter.predicate) match {
    case Nil => Left(filter.filterError)
    case remaining => Right(remaining)
  }
}

You may notice that this function uses explicit recursion (the recursiveFiltering helper). There is a nice paper by Graham Hutton, "A tutorial on the universality and expressiveness of fold", on how fold can replace explicit recursion for a large class of functions: http://www.cs.nott.ac.uk/~pszgmh/fold.pdf
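To see the idea on something smaller than our filters, here is a minimal sketch of my own (not taken from the paper; SumExample, sumRec and sumFold are made-up names) of the same function written once with explicit recursion and once with foldLeft. The paper works with a right fold; here I use Scala's foldLeft, as the rest of this post does, which gives the same result for addition.

```scala
object SumExample {
  // Explicit recursion: a base case for the empty list,
  // and a recursive case that handles the head and recurses on the tail.
  def sumRec(xs: List[Int]): Int = xs match {
    case Nil          => 0
    case head :: tail => head + sumRec(tail)
  }

  // The same function as a fold: the base case becomes the starting value (0),
  // and the recursive case becomes the step function.
  def sumFold(xs: List[Int]): Int =
    xs.foldLeft(0)((acc, x) => acc + x)
}

import SumExample._

println(sumRec(List(1, 2, 3)))   // 6
println(sumFold(List(1, 2, 3)))  // 6
```

The recursion has not disappeared: it now lives inside foldLeft, written once in the standard library.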

So I asked my colleague whether we could make this function more concise with a fold. Naturally, the usual office mayhem left us no time for it, and to keep the business running the function was shipped as is.

Well, that function stuck in my mind, so I want to share how to transform an explicitly recursive function into a fold-based one.

Discovering what the function does

Let's figure out what this recursive function does.

def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]):Either[ErrorMessage, Seq[A]] = {

The outer function takes a sequence of elements and a sequence of filters, and returns either an error or a sequence of elements. At this point, a reasonable guess is that it applies each filter in filters to the given elements.

def recursiveFiltering(elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] = filters match {

The inner recursive function has the same signature, so nothing new here.

case Nil => Right(elements)

Now we look at the shape of the filters sequence: if it is empty, we return the current elements wrapped in a Right (on the first call, these are the original input elements).

case head :: remainingFilters => filterOnce(elements, head) match {

Otherwise, we apply the first filter (head) with filterOnce and check its result.

case Left(err) => Left(err)

If the result is an error, we stop and return that same error.

  case Right(filtered) => recursiveFiltering(filtered, remainingFilters)
}

Otherwise, we call recursiveFiltering again on the filtered elements with the remaining filters.

  }
  recursiveFiltering(elements, filters)
}

By the way, filterOnce does the filtering (through filter.predicate) on the sequence of elements:

def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] = elements.filter(filter.predicate) match {
...
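To see filterOnce on its own, here is a small sketch. Filter and filterOnce are copied from the post so the block runs by itself; FilterOnceDemo and the evens filter are made up for this demo. Note that case Nil compares the result with ==, so it also catches an empty Vector, not only an empty List.

```scala
object FilterOnceDemo {
  type ErrorMessage = String

  case class Filter[A](predicate: A => Boolean, filterError: String)

  // Copied from the post: keep the elements satisfying the predicate,
  // and turn an empty result into the filter's error message.
  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil       => Left(filter.filterError)
      case remaining => Right(remaining)
    }
}

import FilterOnceDemo._

// A hypothetical filter used only for this demo.
val evens = Filter[Int](_ % 2 == 0, "no even numbers")

println(filterOnce(Seq(1, 2, 3, 4), evens))  // Right(List(2, 4))
println(filterOnce(Seq(1, 3, 5), evens))     // Left(no even numbers)
println(filterOnce(Vector(1, 3), evens))     // Left(no even numbers)
```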

So, as far as we can tell this function:

  • runs all the filters, in order, against the input sequence (reducing the number of elements)
  • fails as soon as a filter produces an empty sequence of elements, returning that filter's error
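To make the recursion concrete, here is a sketch of what recursiveFiltering computes for exactly two filters, with the recursion unrolled by hand. UnrolledDemo, twoFilters and the two example filters are made-up names; Filter and filterOnce are copied from the post so the block runs on its own.

```scala
object UnrolledDemo {
  type ErrorMessage = String

  case class Filter[A](predicate: A => Boolean, filterError: String)

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil       => Left(filter.filterError)
      case remaining => Right(remaining)
    }

  // recursiveFiltering(elements, Seq(f1, f2)) with the recursion unrolled by hand:
  // each step feeds its surviving elements to the next filter,
  // and the first Left ends the computation.
  def twoFilters[A](elements: Seq[A], f1: Filter[A], f2: Filter[A]): Either[ErrorMessage, Seq[A]] =
    filterOnce(elements, f1) match {
      case Left(err) => Left(err)
      case Right(afterF1) =>
        filterOnce(afterF1, f2) match {
          case Left(err)      => Left(err)
          case Right(afterF2) => Right(afterF2)  // no filters left: the base case
        }
    }
}

import UnrolledDemo._

val keepAll = Filter[Int](_ => true, "unreachable")
val below42 = Filter[Int](_ < 42, "all the integers are at least 42")

println(twoFilters(Seq(1, 2, 3, 42, 5), keepAll, below42))  // Right(List(1, 2, 3, 5))
```

Every level has the same shape: take the previous result, and either pass the error along or apply the next filter. That repeated shape is what a fold will capture.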

Assert what we found out in the code

Now that we understand something about the function, let's check that our understanding is right with a few assertions in the code (i.e., lightweight tests), which we will keep using during the refactoring later.

// Let's test!

import RecursiveFunction2Fold._

val neverFailingFilter: Filter[Int] = Filter({i: Int => true}, "Unreachable")
val passingFilters: Seq[Filter[Int]] = Seq(neverFailingFilter)
val reducingFilters: Seq[Filter[Int]] = Seq(neverFailingFilter, Filter({i: Int => i < 42}, "Oh no, all the given integers are greater than 42!"))
val breakingFilters: Seq[Filter[Int]] = Seq(neverFailingFilter, Filter({i: Int => i < -1}, "Oh no, all the given integers are greater than -1!"))
val elements: Seq[Int] = Seq(1,2,3,42,5)

// tries to run all the filters against the input sequence (to reduce the number of elements)
assert(filterMany(elements, passingFilters) == Right(List(1, 2, 3, 42, 5)))
assert(filterMany(elements, reducingFilters) == Right(List(1, 2, 3, 5)))
// has to fail when the filter produces an empty sequence of elements
assert(filterMany(elements, breakingFilters) == Left("Oh no, all the given integers are greater than -1!"))

Refactor to fold it!

Now we are ready for our little refactoring exercise. Let's review the signature of foldLeft:

def foldLeft[B](z: B)(op: (B, A) => B): B

We have a starting value z, which is the initial accumulator and also the result when the sequence we fold over is empty. Then we have an operation op that takes the accumulator so far, of type B (the result type), and the next element of the sequence, of type A, and returns the new accumulator.
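To see where the recursion goes once we switch to a fold, here is a hand-written foldLeft on List. It is a simplified sketch of my own (myFoldLeft and FoldLeftSketch are made-up names), not the standard library's actual implementation, but it follows the same signature shape.

```scala
object FoldLeftSketch {
  // A hand-written foldLeft on List. The recursion that a fold hides lives here.
  @annotation.tailrec
  def myFoldLeft[A, B](xs: List[A], z: B)(op: (B, A) => B): B = xs match {
    case Nil          => z                                  // nothing left: the accumulator is the result
    case head :: tail => myFoldLeft(tail, op(z, head))(op)  // combine, then continue with the tail
  }
}

import FoldLeftSketch._

// The accumulator is threaded through the elements from left to right.
println(myFoldLeft(List("a", "b", "c"), "")((acc, s) => acc + s))           // abc
println(myFoldLeft(List(1, 2, 3), List.empty[Int])((acc, x) => x :: acc))  // List(3, 2, 1)
```

The base case returns z, and the recursive case builds the next accumulator with op: exactly the two pieces we will need to extract from recursiveFiltering.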

Now let's take another look at the recursive function:

def recursiveFiltering(elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] = filters match {
  case Nil => Right(elements)
  case head :: remainingFilters => filterOnce(elements, head) match {
    case Left(err) => Left(err)
    case Right(filtered) => recursiveFiltering(filtered, remainingFilters)
  }
}
recursiveFiltering(elements, filters)

The first question for our refactoring is: what is the starting value z? We can answer it with another question: what is the base case of the explicit recursion? Here it is:

case Nil => Right(elements)

So let's start rewriting our function:

def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]):Either[ErrorMessage, Seq[A]] =
  filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]])((acc, e) => ???)
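Incidentally, this half-written version already compiles and already handles the base case, because foldLeft returns z without ever calling the step function when there are no filters. A minimal sketch to check it (StepOneDemo is a made-up wrapper name; ??? throws scala.NotImplementedError if it is ever evaluated):

```scala
object StepOneDemo {
  type ErrorMessage = String

  case class Filter[A](predicate: A => Boolean, filterError: String)

  // The fold from this step, with the step function still left as ???.
  // The ascription on Right(elements) makes the accumulator type the full Either.
  def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
    filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]])((acc, e) => ???)
}

import StepOneDemo._

// No filters: foldLeft returns z without calling the step function.
println(filterMany(Seq(1, 2, 3), Seq.empty[Filter[Int]]))  // Right(List(1, 2, 3))

// One or more filters: the step function runs, so ??? throws.
println(scala.util.Try(filterMany(Seq(1, 2, 3), Seq(Filter[Int](_ => true, "unreachable")))).isFailure)  // true
```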

Next we need to look at the core of the recursive function: the recursive case.

case head :: remainingFilters => filterOnce(elements, head) match {
  case Left(err) => Left(err)
  case Right(filtered) => recursiveFiltering(filtered, remainingFilters)
}

Here we take the first filter, which corresponds to e in our fold; we apply it with filterOnce and handle both cases. One thing to remember is that our fold's accumulator holds the result so far: either an error, or the elements that survived the filters applied so far. Let's translate that into our fold version:

def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]):Either[ErrorMessage, Seq[A]] =
  filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]])((acc: Either[ErrorMessage, Seq[A]], e: Filter[A]) => acc match {
    case Left(err) => Left(err) // the previous run produced an error, no need to apply any other filter
    case Right(filtered) => filterOnce(filtered,e) // filterOnce produces a Either itself
  })

The idea is that we start the fold with an accumulator holding a Right with all the elements in it. With the first filter we end up in the case Right(filtered), and filterOnce produces a new Either. This repeats for each filter. Once filterOnce produces a Left, each remaining step simply passes that Left along: foldLeft still walks through the rest of the filters, but none of them is applied.
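To check that claim, here is the pattern-matching fold instrumented with two counters. The counters steps and applied, the ShortCircuitDemo object and the example filters are all made up for this demo; Filter and filterOnce are copied from the post.

```scala
object ShortCircuitDemo {
  type ErrorMessage = String

  case class Filter[A](predicate: A => Boolean, filterError: String)

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil       => Left(filter.filterError)
      case remaining => Right(remaining)
    }

  // Hypothetical counters, added only for this demo.
  var steps = 0    // how many times the fold's step function runs
  var applied = 0  // how many filters are actually applied

  // The pattern-matching fold from the post, instrumented with the counters.
  def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
    filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]]) { (acc, e) =>
      steps += 1
      acc match {
        case Left(err) => Left(err)  // an earlier filter failed: pass the error along
        case Right(filtered) =>
          applied += 1
          filterOnce(filtered, e)
      }
    }
}

import ShortCircuitDemo._

val failing = Filter[Int](_ < 0, "no negatives")
val keepAll = Filter[Int](_ => true, "unreachable")

val result = filterMany(Seq(1, 2, 3), Seq(failing, keepAll, keepAll))
println(result)              // Left(no negatives)
println(s"$steps $applied")  // 3 1
```

All three steps run, but only the first filter is applied. With a handful of filters this is negligible; if you needed to stop iterating altogether, explicit recursion (or a lazy traversal) would be the tool for that.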

Before testing this version, let's apply a small functional refinement: this pattern match is exactly what flatMap does on Either, which is right-biased since Scala 2.12, so we can replace it with a flatMap.

def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]):Either[ErrorMessage, Seq[A]] =
  filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]])((acc: Either[ErrorMessage, Seq[A]], e: Filter[A]) => acc.flatMap(filterOnce(_,e)))
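To convince ourselves that flatMap really does the same thing as the pattern match, here is a small side-by-side check, assuming Scala 2.12 or later. FlatMapDemo, viaMatch, viaFlatMap, Step and keepEvens are made-up names for this sketch.

```scala
object FlatMapDemo {
  type ErrorMessage = String
  type Step = Seq[Int] => Either[ErrorMessage, Seq[Int]]

  // The fold's step as a pattern match...
  def viaMatch(acc: Either[ErrorMessage, Seq[Int]], step: Step): Either[ErrorMessage, Seq[Int]] =
    acc match {
      case Left(err)       => Left(err)
      case Right(filtered) => step(filtered)
    }

  // ...and the same step with flatMap: a Left is returned unchanged,
  // a Right has its content passed to step.
  def viaFlatMap(acc: Either[ErrorMessage, Seq[Int]], step: Step): Either[ErrorMessage, Seq[Int]] =
    acc.flatMap(step)

  // A hypothetical step: keep the even numbers, or fail if there are none.
  val keepEvens: Step = xs =>
    xs.filter(_ % 2 == 0) match {
      case Nil   => Left("no even numbers")
      case evens => Right(evens)
    }
}

import FlatMapDemo._

val accumulators: Seq[Either[ErrorMessage, Seq[Int]]] =
  Seq(Left("earlier error"), Right(Seq(1, 2, 3, 4)), Right(Seq(1, 3)))

// Both versions agree on every kind of accumulator.
accumulators.foreach { acc =>
  println(s"${viaMatch(acc, keepEvens)} ${viaFlatMap(acc, keepEvens)}")
}
```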

So our final version looks like this:

object RecursiveFunction2Fold {
  type ErrorMessage = String

  case class Filter[A](predicate: A => Boolean, filterError: String)

  def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
    filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]])((acc: Either[ErrorMessage, Seq[A]], e: Filter[A]) => acc.flatMap(filterOnce(_, e)))

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] = elements.filter(filter.predicate) match {
    case Nil => Left(filter.filterError)
    case remaining => Right(remaining)
  }
}
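If you prefer, a purely stylistic variant is to pass the accumulator type as foldLeft's type parameter instead of ascribing Right(elements); the compiler then infers the lambda's parameter types. This is a matter of taste and, as far as I can tell, does not change the behavior. RecursiveFunction2FoldAlt is a made-up name so it does not clash with the object above.

```scala
object RecursiveFunction2FoldAlt {
  type ErrorMessage = String

  case class Filter[A](predicate: A => Boolean, filterError: String)

  // Same fold as above; the explicit type parameter replaces the ascription,
  // so acc and filter need no type annotations.
  def filterMany[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
    filters.foldLeft[Either[ErrorMessage, Seq[A]]](Right(elements)) { (acc, filter) =>
      acc.flatMap(filterOnce(_, filter))
    }

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil       => Left(filter.filterError)
      case remaining => Right(remaining)
    }
}

import RecursiveFunction2FoldAlt._

val result = filterMany(Seq(1, 2, 3, 42, 5), Seq(Filter[Int](_ < 42, "all the integers are at least 42")))
println(result)  // Right(List(1, 2, 3, 5))
```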

Now it's time to test our refactoring!

// Let's test!

import RecursiveFunction2Fold._

val neverFailingFilter: Filter[Int] = Filter({i: Int => true}, "Unreachable")
val passingFilters: Seq[Filter[Int]] = Seq(neverFailingFilter)
val reducingFilters: Seq[Filter[Int]] = Seq(neverFailingFilter, Filter({i: Int => i < 42}, "Oh no, all the given integers are greater than 42!"))
val breakingFilters: Seq[Filter[Int]] = Seq(neverFailingFilter, Filter({i: Int => i < -1}, "Oh no, all the given integers are greater than -1!"))
val elements: Seq[Int] = Seq(1,2,3,42,5)

// tries to run all the filters against the input sequence (to reduce the number of elements)
assert(filterMany(elements, passingFilters) == Right(List(1, 2, 3, 42, 5)))
assert(filterMany(elements, reducingFilters) == Right(List(1, 2, 3, 5)))
// has to fail when the filter produces an empty sequence of elements
assert(filterMany(elements, breakingFilters) == Left("Oh no, all the given integers are greater than -1!"))

Done! The tests pass, so the refactoring looks successful. Note that I kept the number of tests to a minimum for readability; you may want more tests to be confident that the refactoring really preserved the behavior.
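One way to add more tests is to run both versions side by side on a few extra cases and check that they agree. This is a sketch under my own naming (Versions, recursive, folded and the example filters are made up; the recursive helper from the post is inlined as a plain function). It also exercises one case the tests above do not cover: the recursive version matches the filters with head :: remainingFilters, which only matches a List, so a non-empty Vector of filters makes it throw a MatchError, while the fold version works with any Seq.

```scala
object Versions {
  type ErrorMessage = String

  case class Filter[A](predicate: A => Boolean, filterError: String)

  def filterOnce[A](elements: Seq[A], filter: Filter[A]): Either[ErrorMessage, Seq[A]] =
    elements.filter(filter.predicate) match {
      case Nil       => Left(filter.filterError)
      case remaining => Right(remaining)
    }

  // The recursive version from the post, with the inner helper inlined.
  def recursive[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] = filters match {
    case Nil => Right(elements)
    case head :: remainingFilters => filterOnce(elements, head) match {
      case Left(err)       => Left(err)
      case Right(filtered) => recursive(filtered, remainingFilters)
    }
  }

  // The fold version from the post.
  def folded[A](elements: Seq[A], filters: Seq[Filter[A]]): Either[ErrorMessage, Seq[A]] =
    filters.foldLeft(Right(elements): Either[ErrorMessage, Seq[A]])((acc, e) => acc.flatMap(filterOnce(_, e)))
}

import Versions._

val below42   = Filter[Int](_ < 42, "all the integers are at least 42")
val even      = Filter[Int](_ % 2 == 0, "no even numbers")
val rejectAll = Filter[Int](_ => false, "nothing passes")

// Extra cases: different filter orders, an empty input, no filters at all.
val inputs      = Seq(Seq(1, 2, 3, 42, 5), Seq.empty[Int], Seq(43, 44))
val filterLists = Seq(List(below42, even), List(even, below42), List(even, rejectAll, below42), Nil)

for (xs <- inputs; fs <- filterLists)
  assert(recursive(xs, fs) == folded(xs, fs), s"mismatch for $xs")

// A non-empty Vector of filters: only the fold version handles it.
val vectorFilters: Seq[Filter[Int]] = Vector(below42)
println(folded(Seq(1, 50), vectorFilters))                         // Right(List(1))
println(scala.util.Try(recursive(Seq(1, 50), vectorFilters)).isFailure)  // true
```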

Happy coding!
