Browsing collections with foldLeft and foldRight


11 Jul 2022

Some time ago, I was asked to develop a small algorithm for my project. It involved going through all the frames of a video sequence to determine which manually applied label occupied the most frames within that sequence. This exercise allowed me to (re)discover a function that I rarely use on a daily basis: foldLeft (and its counterpart, foldRight).

These methods traverse a list, applying an operation to the current element and the result accumulated so far, and they can return a value of a completely different type from the elements of the list. They deserve a closer look.

Let’s first look at the signature of this method:

def foldLeft[B](z: B)(op: (B, A) => B): B

We see several interesting things here:

  1. The method returns an instance of type B, which is the output type.

  2. foldLeft has two parameter lists. The first one contains a single parameter z of the aforementioned type B. This parameter is important because it is the starting point of the iteration: it initialises it and collects the results. We will refer to it as the accumulator (some people call it a tally, since it records the intermediate values of the iteration over all input values).

  3. The second parameter list contains op (operation), a function that takes two parameters (B, the output type, and A, the type of the elements of the list) and returns a new value of the output type B.

To summarise, this function will go through all the elements of the list [A] and will collect the result in the accumulator [B]. At the end, we end up with [B].
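To make the mechanics concrete, here is a minimal sketch of what such a fold does on a List. The name myFoldLeft is hypothetical and written only for illustration; it is not the standard library's implementation.

```scala
// Hypothetical helper, written only to illustrate how foldLeft threads the accumulator
def myFoldLeft[A, B](list: List[A])(z: B)(op: (B, A) => B): B =
  list match {
    // Nothing left to visit: the accumulator is the final result
    case Nil => z
    // Combine the accumulator with the current element, then continue with the rest
    case head :: tail => myFoldLeft(tail)(op(z, head))(op)
  }

// Same input and output type (Int => Int)
val total = myFoldLeft(List(1, 2, 3, 4))(0)(_ + _)
// Different output type (Int => String)
val joined = myFoldLeft(List(1, 2, 3, 4))("")((acc, n) => acc + n.toString)

println(total)  // 10
println(joined) // 1234
```

The initial value z is returned untouched when the list is empty, which is why it must already have the output type B.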

Let’s start with a simple example. We will add each number in a list to the running total to get the sum of the list.

FoldLeft decomposition
val numbers = List(1, 2, 3, 4)
numbers.foldLeft[Int](z = 0)(op = (accumulator, number) => accumulator + number)
// Simplified writing
numbers.foldLeft(0)(_ + _)
// Result => Out: Int = 10

In reality, for this simple case we can use the reduce method, which takes no initial value and returns the same type as the elements of the list.

val numbers = List(1, 2, 3, 4)
numbers.reduce(_ + _)
// This is equivalent to :
numbers.sum
// Result => Out: Int = 10
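One practical difference is worth keeping in mind: since reduce has no initial value, it has nothing to return for an empty collection. The short sketch below compares the behaviours; the value names are only illustrative.

```scala
import scala.util.Try

val empty = List.empty[Int]

// foldLeft falls back on its initial value when there is nothing to iterate over
val foldedEmpty = empty.foldLeft(0)(_ + _)

// reduce has no initial value, so it throws on an empty collection
val reducedEmpty = Try(empty.reduce(_ + _))

// reduceOption makes the empty case explicit instead of throwing
val safeReduce = empty.reduceOption(_ + _)

println(foldedEmpty)            // 0
println(reducedEmpty.isFailure) // true
println(safeReduce)             // None
```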

Let’s move on to more concrete cases.

In the following example, I don’t want to sum the numbers in the list, but to obtain a new list where each element is the sum of the corresponding number in the list "numbers" and the previous element of the new list (in other words, a running sum).

We will go through the list "List(1, 2, 3, 4)" and at each iteration do this operation:

  1. empty accumulator and current number = 1. Result: accumulator = List(1)

  2. accumulator (1) and current number = 2. Result: accumulator = List(1, 2+1)

  3. accumulator (1, 3) and current number = 3. Result: accumulator = List(1, 3, 3+3)

  4. accumulator (1, 3, 6) and current number = 4. Result: accumulator = List(1, 3, 6, 4+6)

Final result: newList = List(1, 3, 6, 10)

One possible implementation relies on pattern matching, with a dedicated case for the first iteration. Since we have nothing to add at the beginning, we simply fill the list with the first number. If we don’t handle this case, we start with an empty list, which fails at runtime because accumulator.last is called on an empty list.

This implementation has some performance issues which we will discuss at the end of the article.

val numbers = List(1, 2, 3, 4)
val newList: Seq[Int] = numbers.foldLeft[Seq[Int]](Seq.empty) {
    case (accumulator, number) if accumulator.isEmpty => accumulator :+ number
    case (accumulator, number) => accumulator :+ (accumulator.last + number)
}
println(newList)
// Result => List(1, 3, 6, 10)
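As a side note, the standard library also offers scanLeft, which returns every intermediate accumulator of a fold. The sketch below is an alternative to the hand-written version above, not a replacement for the explanation; dropping the first element removes the initial value 0.

```scala
val numbers = List(1, 2, 3, 4)

// scanLeft keeps each intermediate accumulator, starting with the initial value
val withInitial = numbers.scanLeft(0)(_ + _)

// The initial value is not part of the expected result, so we drop it
val runningSums = withInitial.tail

println(withInitial) // List(0, 1, 3, 6, 10)
println(runningSums) // List(1, 3, 6, 10)
```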

Working with different input and output types

Let’s take this a step further. Let’s imagine that our output type is different from our input type, for example to add computed data to an entity to be persisted in a database from an object coming from an API.

In this case, the user would enter a name and date of birth in a GUI and the system would calculate the age and age difference with its predecessor before saving it.

The calculation of the interval is not really interesting in itself, but it will allow us to make some calculations and to observe that foldLeft and foldRight do not systematically give the same results, even with the same input data (they are guaranteed to return the same result when the function op is both commutative and associative).
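Before looking at the users, a quick sketch on plain integers shows the effect of this condition: addition is commutative and associative, subtraction is neither.

```scala
val numbers = List(1, 2, 3, 4)

// Addition: both directions agree
val sumLeft  = numbers.foldLeft(0)(_ + _)  // ((((0 + 1) + 2) + 3) + 4)
val sumRight = numbers.foldRight(0)(_ + _) // (1 + (2 + (3 + (4 + 0))))

// Subtraction: the grouping changes the result
val diffLeft  = numbers.foldLeft(0)(_ - _)  // ((((0 - 1) - 2) - 3) - 4)
val diffRight = numbers.foldRight(0)(_ - _) // (1 - (2 - (3 - (4 - 0))))

println((sumLeft, sumRight))   // (10,10)
println((diffLeft, diffRight)) // (-10,-2)
```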

A possible implementation of this user age computation could be as follows:

import java.time.LocalDate
case class UserApi(name: String, birthYear: Int)
case class UserData(name: String, birthYear: Int, age: Int, deltaWithPrecedent: Int)
val user1 = UserApi("Marc", 1982)
val user2 = UserApi("Pierre", 1995)
val user3 = UserApi("Marie", 1987)
val user4 = UserApi("Lydia", 1987)
val user5 = UserApi("Sophie", 1990)
val userList = Seq(user1, user2, user3, user4, user5)
private def computeAge(birthYear: Int) = LocalDate.now.getYear - birthYear
private def computeDeltaWithPrecedent(birthYear: Int, precedentBirthYear: Int) = birthYear - precedentBirthYear
def computeUserDatas(users: Seq[UserApi]): Seq[UserData] =
    users
      .sortBy(user => (user.birthYear, user.name)) // Sort first by "birthYear", then by "name"
      .foldLeft[Seq[UserData]](Seq.empty) { (acc, user) =>
        val userDataList =
          if (acc.isEmpty)
            acc :+ UserData(
              user.name,
              user.birthYear,
              computeAge(user.birthYear),
              0
            )
          else
            acc :+ UserData(
              user.name,
              user.birthYear,
              computeAge(user.birthYear),
              computeDeltaWithPrecedent(user.birthYear, acc.last.birthYear)
            )
        userDataList
      }
computeUserDatas(userList).foreach(println)
/* Result => Each interval is calculated in relation to the lower year
  UserData(Marc,1982,40,0)
  UserData(Lydia,1987,35,5)
  UserData(Marie,1987,35,0)
  UserData(Sophie,1990,32,3)
  UserData(Pierre,1995,27,5)
*/

Reversing the traversal with foldRight

If we now use a foldRight on our list of UserApi, we can traverse the list from the end to the beginning.

In this case, the interval is calculated not between the current value and its previous one on the left, but between the current value and its previous one on the right. The result of the interval between the dates of birth will therefore be different.
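The change of direction, and the swapped position of the accumulator in op, can be observed on a small list of strings. This is only an illustrative sketch; the value names are made up for the occasion.

```scala
val letters = List("a", "b", "c")

// foldLeft: the accumulator is the first argument and grows from the left
val nestedLeft = letters.foldLeft("z")((acc, letter) => s"($acc$letter)")

// foldRight: the accumulator is the second argument and grows from the right
val nestedRight = letters.foldRight("z")((letter, acc) => s"($letter$acc)")

// Appending each visited element to the accumulator reveals the traversal order
val visitedLeft  = letters.foldLeft(Vector.empty[String])((acc, letter) => acc :+ letter)
val visitedRight = letters.foldRight(Vector.empty[String])((letter, acc) => acc :+ letter)

println(nestedLeft)   // (((za)b)c)
println(nestedRight)  // (a(b(cz)))
println(visitedLeft)  // Vector(a, b, c)
println(visitedRight) // Vector(c, b, a)
```

Because foldRight visits the last element first, appending to the accumulator yields the elements in reverse order.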

In the following example, I use a slightly more concise script and have reorganised the code by integrating the two private methods into the main method.

import java.time.LocalDate
case class UserApi(name: String, birthYear: Int)
case class UserData(name: String, birthYear: Int, age: Int, deltaWithPrecedent: Int)
val user1 = UserApi("Marc", 1982)
val user2 = UserApi("Pierre", 1995)
val user3 = UserApi("Marie", 1987)
val user4 = UserApi("Lydia", 1987)
val user5 = UserApi("Sophie", 1990)
val userList = Seq(user1, user2, user3, user4, user5)
def computeUserDatas(users: Seq[UserApi]): Seq[UserData] =
  users
    .sortBy(user => (user.birthYear, user.name))
    // The pair (current value, accumulator) is inverted with respect to the foldLeft
    .foldRight[Seq[UserData]](Seq.empty) { (user, acc) =>
      def computeAge(birthYear: Int) = LocalDate.now.getYear - birthYear
      // The direction of the operation must be reversed to avoid negative results, or use (birthYear - precedentBirthYear).abs
      def computeDeltaWithPrecedent(birthYear: Int, precedentBirthYear: Int) = precedentBirthYear - birthYear
      if (acc.isEmpty)
        acc :+ UserData(
          user.name,
          user.birthYear,
          computeAge(user.birthYear),
          0
        ) else acc :+ UserData(
        user.name,
        user.birthYear,
        computeAge(user.birthYear),
        computeDeltaWithPrecedent(user.birthYear, acc.last.birthYear)
      )
    }
computeUserDatas(userList).foreach(println)
/* Result => (each interval is calculated in relation to the year above)
UserData(Pierre,1995,27,0)
UserData(Sophie,1990,32,5)
UserData(Marie,1987,35,3)
UserData(Lydia,1987,35,0)
UserData(Marc,1982,40,5)
*/

Handling an exception with Either and Cats

Finally, here is a more complex example of exception handling, first with an Either, then with the Cats library.

Let’s imagine that we are managing a team (Team) made up of players who can take on different statuses over time. Let’s imagine we have an endpoint that allows us to delete players by giving them the status Deleted unless a player has the status Enrolled (entered in a competition for example, in which case, deleting them would cause some problems).

For some reason (actually, for the very good reason that it serves my example), we save the whole list or nothing at all. So the idea here is to stop processing and return an exception in a Left if a Player with Enrolled status is found in the list, which is the case here.

import scala.concurrent.{ Await, ExecutionContextExecutor, Future }
import scala.concurrent.duration.DurationInt
implicit val executor: ExecutionContextExecutor = scala.concurrent.ExecutionContext.global
sealed trait PlayerStatus
object PlayerStatus {
  case object Available extends PlayerStatus
  case object Enrolled extends PlayerStatus
  case object Resting extends PlayerStatus
  case object Deleted extends PlayerStatus
}
case class Player(name: String, currentStatus: PlayerStatus) {
  def updateStatus(
    status: PlayerStatus
  ): Either[Exception, Player] =
    if (currentStatus == PlayerStatus.Enrolled) Left(new IllegalArgumentException(s"status is $currentStatus"))
    else Right(copy(currentStatus = status))
}
case class Team(players: Seq[Player])
val team = Team(
  Seq(
    Player("player1", PlayerStatus.Available),
    Player("player2", PlayerStatus.Resting),
    Player("player3", PlayerStatus.Enrolled) // The status that causes the interruption
  )
)
val resultEither: Future[Either[IllegalArgumentException, Seq[Player]]] =
  for {
    updatedPlayers <- Future.successful {
      team.players
        .map(_.updateStatus(PlayerStatus.Deleted))
        .foldLeft[Either[Exception, Seq[Player]]](Right(Seq.empty[Player])) { (acc, current) =>
          acc.flatMap { players =>
            current.map(_ +: players)
          }
        }
        .left
        .map(error => new IllegalArgumentException(s"Unable to delete the player due to ${error.getMessage}"))
    }
  } yield updatedPlayers
Await.result(resultEither, 1.second)
/* Result =>
Left(java.lang.IllegalArgumentException: Unable to delete the player due to status is Enrolled)
*/

Some details:

acc.flatMap { players =>
            current.map(_ +: players)
          }

The flatMap gives access to the sequence of Player held in the Right of the accumulator’s Either, and returns an Either[Exception, Seq[Player]] instead of an Either[Exception, Either[Exception, Seq[Player]]].

.left
.map(error => ...

The left.map only transforms the error, and only when the result is a Left; a Right passes through untouched. Because flatMap ignores its function once the accumulator is a Left, the first error is kept as is: the fold still visits the remaining elements, but none of them can change the result.
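The same pattern can be reproduced on a smaller scale. In the sketch below, parse and parseAll are hypothetical helpers, unrelated to the Player model, that keep the first error met while folding over a list of strings.

```scala
import scala.util.Try

// Hypothetical helper: turns a string into an Int, or into an error message
def parse(input: String): Either[String, Int] =
  Try(input.toInt).toEither.left.map(_ => s"'$input' is not a number")

// Hypothetical helper: accumulates the parsed values, or keeps the first error
def parseAll(inputs: List[String]): Either[String, List[Int]] =
  inputs.foldLeft[Either[String, List[Int]]](Right(Nil)) { (acc, input) =>
    // Once acc is a Left, flatMap does not run this function anymore
    acc.flatMap(values => parse(input).map(value => values :+ value))
  }

println(parseAll(List("1", "2", "3"))) // Right(List(1, 2, 3))
println(parseAll(List("1", "x", "y"))) // Left('x' is not a number)
```

Here parse is called inside the flatMap, so the elements following the first error are visited by the fold but never parsed.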

With the Cats library, the same player deletion can be written as follows:

import cats.data.{EitherT, Validated}
import cats.implicits._
import scala.concurrent.duration.DurationInt
import scala.concurrent.{Await, ExecutionContextExecutor, Future}
implicit val executor: ExecutionContextExecutor = scala.concurrent.ExecutionContext.global
sealed trait PlayerStatus
object PlayerStatus {
  case object Available extends PlayerStatus
  case object Enrolled extends PlayerStatus
  case object Resting extends PlayerStatus
  case object Deleted extends PlayerStatus
}
case class Player(name: String, currentStatus: PlayerStatus) {
  def updateStatus(
      status: PlayerStatus
  ): Validated[Exception, Player] =
    if (currentStatus == PlayerStatus.Enrolled)
      Validated.invalid[Exception, Player](new IllegalArgumentException(s"status is $currentStatus"))
    else Validated.valid[Exception, Player](copy(currentStatus = status))
}
case class Team(players: Seq[Player])
val team = Team(
  Seq(
    Player("player1", PlayerStatus.Available),
    Player("player2", PlayerStatus.Resting),
    Player("player3", PlayerStatus.Enrolled) // The status that causes the interruption
  )
)
val resultEitherT: EitherT[Future, IllegalArgumentException, Seq[Player]] =
  for {
    updatedPlayers <- EitherT.fromEither[Future] {
      team.players
        .map(_.updateStatus(PlayerStatus.Deleted))
        .foldLeft[Validated[Exception, Seq[Player]]](Validated.Valid(Seq.empty[Player])) { (acc, current) =>
          acc.andThen { players =>
            current.map(_ +: players)
          }
        }
        .leftMap(error => new IllegalArgumentException(s"Unable to delete the player due to ${error.getMessage}"))
        .toEither
    }
  } yield updatedPlayers
Await.result(resultEitherT.value, 1.second)
/* Result =>
Left(java.lang.IllegalArgumentException: Unable to delete the player due to status is Enrolled)
*/

Attentive readers may have noticed this piece of code:

acc.andThen { players =>
            current.map(_ +: players)
          }

For Validated, andThen is the Cats counterpart of Either’s flatMap, while leftMap is the Cats version of left.map(…). Finally, we wrap our block with EitherT.fromEither[Future] { { … }.toEither } to change the type from Validated to EitherT. Note that we have used EitherT and Validated, two Cats-specific types.

Performance

If you remember, I proposed this implementation at the beginning of the article:

val numbers = List(1, 2, 3, 4)
val newList: Seq[Int] = numbers.foldLeft[Seq[Int]](Seq.empty) {
    case (accumulator, number) if accumulator.isEmpty => accumulator :+ number
    case (accumulator, number) => accumulator :+ (accumulator.last + number)
}
println(newList)
// Result => List(1, 3, 6, 10)

In reality, we initialized the accumulator with Seq.empty[Int]. Seq is a generic interface whose default implementation is List, so the accumulator is a List at runtime. The performance problem is that this List is involved at every step of the fold:

  1. when retrieving the last element of the list (accumulator.last)

  2. when adding the new element to the end of the list (accumulator :+ …)

List is a singly linked list, so accumulator.last has to walk from the head through every element until it reaches the last one.

According to the Scala documentation on collection performance, appending to a List is also a linear operation: the longer the list, the longer it takes. Repeating both operations for each element makes the whole fold quadratic.

To solve this problem while remaining generic, we can rewrite our method as follows:

List(1,2,3,4).foldLeft(Seq.empty[Int]) {
  case (Nil, element) => Seq(element)
  case (accumulator, element) => (accumulator.head + element) +: accumulator
}.reverse

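This way, reading the head of the accumulator and prepending to it are both constant-time (fast) operations. As this implementation builds the result in reverse order, a .reverse at the end restores it. Note that foldRight is not a shortcut here: it would start adding from the end of the list and produce different sums, and its generic definition in the standard library is itself just a foldLeft over the reversed collection.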

def foldRight[B](z: B)(op: (A, B) => B): B = reversed.foldLeft(z)((b, a) => op(a, b))
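Another way to avoid the linear-time operations discussed in this section, offered here as a sketch rather than as the approach retained in this article, is to accumulate into a Vector, whose append and last operations are effectively constant time according to the Scala collection performance documentation. No final reverse is needed in that case.

```scala
val numbers = List(1, 2, 3, 4)

// Vector supports efficient access to its last element and efficient appends
val runningSums: Vector[Int] = numbers.foldLeft(Vector.empty[Int]) { (acc, number) =>
  // lastOption covers the first iteration, where the accumulator is still empty
  acc :+ (acc.lastOption.getOrElse(0) + number)
}

println(runningSums) // Vector(1, 3, 6, 10)
```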

Conclusion

As demonstrated in this article, foldLeft and foldRight are very powerful methods that can be considered the Swiss Army knife of the Scala collection library: they operate on a collection of elements of some type A and can generate a value of the same type A or of a completely different type B.

I hope this article has enlightened you on how to use them simply and convinced you to use and even abuse them!

Thanks to Vincent and Eric for their accurate review and contribution.
