Recursion and Trampolines in Scala

Recursion is beautiful. As an example, consider this perfectly acceptable definition of the functions even and odd in Scala, whose semantics you can guess:

def even(i: Int): Boolean = i match {
  case 0 => true
  case _ => odd(i - 1)
}

def odd(i: Int): Boolean = i match {
  case 0 => false
  case _ => even(i - 1)
}

We’ve defined even in a very natural way: a value is even when it’s zero, or when the value one less than it is odd. Odd is defined in a similar way.

Let’s ignore the problem this code has with negative integers, and the fact that an O(n) algorithm for computing even and odd is not ideal, and focus on a single problem. This code breaks for large n:

scala> even(5000)
res3: Boolean = true
scala> even(50000)
java.lang.StackOverflowError
at Trampolines$Stack$.odd(trampolines.scala:14)
at Trampolines$Stack$.even(trampolines.scala:9)
at Trampolines$Stack$.odd(trampolines.scala:14)
at Trampolines$Stack$.even(trampolines.scala:9)
at Trampolines$Stack$.odd(trampolines.scala:14)
at Trampolines$Stack$.even(trampolines.scala:9)
...
at Trampolines$Stack$.even(trampolines.scala:9)
at Trampolines$Stack$.odd(trampolines.scala:14)

In the remainder of this post we look at various strategies for preventing recursive programs from running out of stack space.

But what’s the problem?

The problem manifests itself in the stack trace: even calls odd, then odd calls even, which calls odd again. Each of these invocations causes the creation of a stack frame, eating away the space that’s reserved for them.

Keeping all the stack frames would not really be necessary: when even calls odd, that is the very last thing that the even function will ever do! When odd returns, even itself returns immediately with that same value.
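
To make "the very last thing" a bit more concrete, here is a small sketch that contrasts a call that is not in tail position with one that is. The functions sumTo and sumToAcc are made up for this illustration and are not part of the even/odd example:

```scala
// Not a tail call: after sumTo(n - 1) returns, this frame still has to add n,
// so the frame must be kept around until the inner call is finished.
def sumTo(n: Int): Long =
  if (n == 0) 0L else n + sumTo(n - 1)

// Tail call: whatever sumToAcc(n - 1, acc + n) returns is returned unchanged,
// so nothing in the current frame is needed once the call has been made.
def sumToAcc(n: Int, acc: Long): Long =
  if (n == 0) acc else sumToAcc(n - 1, acc + n)
```

Both compute the same sum; only the position of the recursive call differs. The calls from even to odd in our example are of the second kind.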

When we compute even(500), after a bunch of recursive calls, the 501st stack frame will be created for the invocation of even(0), which returns true. The 500 method calls that are left on the stack will then return that same value one after another, until the whole stack is unwound.

Some languages or runtimes are smart enough to figure out that if a calling function will simply return the value that a function it calls returns, it does not need to allocate a new stack frame for the callee, but can reuse the stack frame of the caller for that.

Unfortunately, the JVM does not support this.

And so Scala, as long as it sticks to JVM method calling semantics, cannot support this either.

A bit of light at the end of the tunnel

Luckily, there is one special case of tail calls which Scala treats differently, and that is the case of tail recursion. If a method does a tail call to itself, it’s called tail recursion, and Scala will rewrite the recursion into a loop that does not consume stack space.

Many recursive algorithms can be rewritten in a tail recursive way. For our even and odd problem, this is a possible solution:

def even(i: Int): Boolean = i match {
  case 0 => true
  case 1 => false
  case _ => even(i - 2)
}

def odd(i: Int): Boolean = !even(i)

We’ve rewritten even to be tail-recursive, and now Scala will optimize it for us:

scala> even(500000000)
res25: Boolean = true

It no longer blows up for large numbers.
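
Roughly speaking, the rewrite replaces the recursive call with a jump back to the top of the method. Below is a hand-written sketch of an equivalent loop. It is not the actual compiler output, and evenLoop is a name made up for this illustration:

```scala
// Hand-written loop that behaves like the tail-recursive even above.
// Instead of calling itself with i - 2, it updates i and goes around again,
// so the stack depth stays constant no matter how large the input is.
def evenLoop(start: Int): Boolean = {
  var i = start
  while (i != 0 && i != 1) {
    i -= 2 // corresponds to the recursive call even(i - 2)
  }
  i == 0 // case 0 => true, case 1 => false
}
```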

The @tailrec annotation is useful in these cases. It does not change how the function is compiled, but it does trigger a compile error if the function is not actually tail-recursive. This protects you from accidentally changing a function that you want to be tail-recursive into something that no longer is.

You can try this out by putting the @tailrec annotation on the first and the latest version of even that we’ve defined.
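
As a sketch of that experiment: the annotation lives in scala.annotation, and the tail-recursive version compiles with it. The first version is shown commented out, because the compiler rejects it once annotated; its recursive call goes to odd, not to even itself.

```scala
import scala.annotation.tailrec

// The latest version: the only recursive call is a self call in tail position,
// so @tailrec is satisfied and the method is compiled into a loop.
@tailrec
def even(i: Int): Boolean = i match {
  case 0 => true
  case 1 => false
  case _ => even(i - 2)
}

def odd(i: Int): Boolean = !even(i)

// The first version does not compile with the annotation, because even never
// calls itself directly:
//
// @tailrec
// def even(i: Int): Boolean = i match {
//   case 0 => true
//   case _ => odd(i - 1) // a tail call, but not a self call
// }
```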

Trampolines

Another way of solving the problem of consuming too much stack space is to move the computation from the stack to the heap.

Instead of letting the even function recursively call odd, we could also let it return a recipe to its caller that describes how to continue the computation. Let’s look at the following code:

sealed trait EvenOdd

case class Done(result: Boolean) extends EvenOdd
case class Even(value: Int) extends EvenOdd
case class Odd(value: Int) extends EvenOdd

def even(i: Int): EvenOdd = i match {
  case 0 => Done(true)
  case _ => Odd(i - 1)
}

def odd(i: Int): EvenOdd = i match {
  case 0 => Done(false)
  case _ => Even(i - 1)
}

Here we defined a little language, EvenOdd, that can express three things:

  • The computation is Done, and we have the result value

  • We need to compute whether a value is Even

  • We need to compute whether a value is Odd

Our even and odd functions no longer return a Boolean, but an expression in this little language.

Finally, we create an interpreter that recursively evaluates expressions in this language:

import scala.annotation.tailrec

// Using Scala's self-recursive tail call optimization
@tailrec
def run(evenOdd: EvenOdd): Boolean = evenOdd match {
  case Done(result) => result
  case Even(value) => run(even(value))
  case Odd(value) => run(odd(value))
}

Now we have expressed even and odd in their natural, mutually recursive way, and we still have a stack-safe interpreter:

scala> run(even(5000000))
res1: Boolean = true

The disadvantage is that this is significantly slower than the version that runs on the stack. Unfortunately, we can’t seem to have our cake and eat it too :(

This strategy is sometimes called trampolining, because instead of creating a big stack, we go up to even, then down to run, then up to odd, then down to run, then up to even, down to run, etcetera. The size of our stack keeps growing and shrinking by one frame for every step in the computation. This looks a lot like going up and down on a trampoline :)
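
To see the bounces, we can record every expression the interpreter passes through. The following is only a sketch: trace is a helper made up for this illustration, and the EvenOdd definitions are repeated so the snippet stands on its own.

```scala
import scala.annotation.tailrec

sealed trait EvenOdd
case class Done(result: Boolean) extends EvenOdd
case class Even(value: Int) extends EvenOdd
case class Odd(value: Int) extends EvenOdd

def even(i: Int): EvenOdd = i match {
  case 0 => Done(true)
  case _ => Odd(i - 1)
}

def odd(i: Int): EvenOdd = i match {
  case 0 => Done(false)
  case _ => Even(i - 1)
}

// Hypothetical helper: like run, but it collects every intermediate
// expression instead of only returning the final Boolean.
def trace(start: EvenOdd): List[EvenOdd] = {
  @tailrec
  def loop(current: EvenOdd, seen: List[EvenOdd]): List[EvenOdd] = current match {
    case Done(_)     => (current :: seen).reverse
    case Even(value) => loop(even(value), current :: seen)
    case Odd(value)  => loop(odd(value), current :: seen)
  }
  loop(start, Nil)
}

// trace(even(3)) == List(Odd(2), Even(1), Odd(0), Done(false))
```

Each element of the list is one landing on the trampoline: even or odd returns to the loop, which immediately bounces into the next call.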

Generalizing

There is no need to specialize our little language to computing even and odd. We can also make a little language that can express recursion in a general way:

sealed trait Computation[A]

class Continue[A](n: => Computation[A]) extends Computation[A] {
  lazy val next = n
}

case class Done[A](result: A) extends Computation[A]

def even(i: Int): Computation[Boolean] = i match {
  case 0 => Done(true)
  case _ => new Continue(odd(i - 1))
}

def odd(i: Int): Computation[Boolean] = i match {
  case 0 => Done(false)
  case _ => new Continue(even(i - 1))
}

@tailrec
def run[A](computation: Computation[A]): A = computation match {
  case Done(a) => a
  case c: Continue[A] => run(c.next)
}

Here our even and odd functions don’t return domain-specific values, but a general value that indicates whether the computation is done or whether more steps are needed. The latter wraps the next step as a by-name parameter, so it is not evaluated until the tail-recursive run function asks for it.

Note that our run function is no longer tied to computing even and odd; it can run any computation expressed in this way.
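
As a sketch of that generality, here is the same Computation type used for something unrelated to even and odd: counting the steps of a Collatz sequence. The collatz function is made up for this illustration; the other definitions are repeated from above so the snippet stands on its own.

```scala
import scala.annotation.tailrec

sealed trait Computation[A]

class Continue[A](n: => Computation[A]) extends Computation[A] {
  lazy val next = n
}

case class Done[A](result: A) extends Computation[A]

@tailrec
def run[A](computation: Computation[A]): A = computation match {
  case Done(a) => a
  case c: Continue[A] => run(c.next)
}

// Hypothetical example: count how many Collatz steps it takes to reach 1.
// Every recursive call is wrapped in Continue, so it is only evaluated by run.
def collatz(n: Long, steps: Int): Computation[Int] =
  if (n == 1) Done(steps)
  else if (n % 2 == 0) new Continue(collatz(n / 2, steps + 1))
  else new Continue(collatz(3 * n + 1, steps + 1))

// run(collatz(6, 0)) == 8   (6, 3, 10, 5, 16, 8, 4, 2, 1)
```

The run function is unchanged; only the functions that produce Computation values are specific to the problem.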

TailRec in the standard library

Something similar in spirit, but with a better implementation, is also available in the Scala standard library:

import scala.util.control.TailCalls.{ TailRec, done, tailcall }

def even(i: Int): TailRec[Boolean] = i match {
  case 0 => done(true)
  case _ => tailcall(odd(i - 1))
}

def odd(i: Int): TailRec[Boolean] = i match {
  case 0 => done(false)
  case _ => tailcall(even(i - 1))
}

even(3000).result
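
The input 3000 is small enough that even the plain stack version would have survived it. As a quick sketch, the same definitions can be run with inputs well beyond the point where the first version overflowed in the session above:

```scala
import scala.util.control.TailCalls.{ TailRec, done, tailcall }

def even(i: Int): TailRec[Boolean] = i match {
  case 0 => done(true)
  case _ => tailcall(odd(i - 1))
}

def odd(i: Int): TailRec[Boolean] = i match {
  case 0 => done(false)
  case _ => tailcall(even(i - 1))
}

// 50000 is the input that caused the StackOverflowError for the first version.
val fiftyThousandIsEven: Boolean = even(50000).result

// A much deeper mutual recursion, still evaluated without growing the stack.
val fiveMillionAndOneIsOdd: Boolean = odd(5000001).result
```

Nothing is computed until result is called; even and odd only build a description of the next step, just like in our own trampoline.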

Comparing performance

I compared the performance of these solutions with JMH, and these are the results:

[info] Benchmark                                       Mode  Cnt       Score      Error  Units
[info] Trampolines.GeneralTrampolineRunner.bench      thrpt   30   44916.024 ±  388.202  ops/s
[info] Trampolines.ScalaTrampolineRunner.bench        thrpt   30   52106.426 ±  408.242  ops/s
[info] Trampolines.SpecializedTrampolineRunner.bench  thrpt   30   94002.234 ± 1584.913  ops/s
[info] Trampolines.StackRunner.bench                  thrpt   30  358382.321 ± 6622.659  ops/s

As expected, the version that runs on the stack is the fastest. But remember that this is the version that breaks for a large number of recursions.

The specialized trampolining version, with the EvenOdd domain-specific language and a runner optimized for this particular problem, is about 4 times slower than the stack version.

The general trampoline version that we defined here is about 2 times slower than the specialized version, and about 8 times slower than the stack version.

The TailRec version from the Scala standard library is about 16% faster than our general trampoline, making it about 7 times slower than the stack version.
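
The ratios in this comparison can be recomputed from the throughput scores. The snippet below is just that arithmetic; it assumes that the ScalaTrampolineRunner row is the TailRec version, and the value names are made up for this illustration.

```scala
// Throughput scores in ops/s, copied from the JMH output above.
val stack       = 358382.321
val specialized = 94002.234
val tailRec     = 52106.426 // assumption: the ScalaTrampolineRunner row
val general     = 44916.024

// How many times higher is the throughput of `fast` than that of `slow`?
def ratio(fast: Double, slow: Double): Double = fast / slow

val stackVsSpecialized   = ratio(stack, specialized)   // roughly 3.8
val specializedVsGeneral = ratio(specialized, general) // roughly 2.1
val stackVsGeneral       = ratio(stack, general)       // roughly 8.0
val tailRecVsGeneral     = ratio(tailRec, general)     // roughly 1.16
val stackVsTailRec       = ratio(stack, tailRec)       // roughly 6.9
```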

Source code

The source code of the benchmarks, along with all the other code in this post, is available at https://github.com/eamelink/scala-trampolines