Sunday, April 28, 2013

Scala Tail Recursion

When you write a recursive function, every call to it pushes a new frame onto the call stack, and that frame stays there until the call returns. If the recursion goes deep enough, the stack runs out of space and you get a stack overflow error. For example, look at the following Scala code, which uses recursion to calculate the sum of all integers in a list.

  def sum(s: Seq[Int]): BigInt = {
    if (s.isEmpty) 0 else s.head + sum(s.tail)
  }
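Expanding a small call by hand shows where the stack frames come from. The sketch below repeats the same non-tail-recursive sum (with the Int-to-BigInt conversions written out explicitly; NaiveSum is just an illustrative wrapper name) and traces sum(List(1, 2, 3)) in comments:

```scala
object NaiveSum {
  // Equivalent to the definition above: the addition runs only after
  // the recursive call returns, so every call keeps its frame alive.
  def sum(s: Seq[Int]): BigInt = {
    if (s.isEmpty) BigInt(0) else BigInt(s.head) + sum(s.tail)
  }

  // sum(List(1, 2, 3))
  //   = 1 + sum(List(2, 3))             // 1 addition pending
  //   = 1 + (2 + sum(List(3)))          // 2 additions pending
  //   = 1 + (2 + (3 + sum(List())))     // 3 additions pending
  //   = 1 + (2 + (3 + 0))
  //   = 6
}
```

For a sequence of n elements, n additions are left pending at the deepest point, each one holding a stack frame, which is what overflows for a large input.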


When we call this method with a very large sequence (here, a Range), we get a stack overflow error.

Exception in thread "main" java.lang.StackOverflowError
    at scala.collection.AbstractTraversable.<init>(Traversable.scala:105)
    at scala.collection.AbstractIterable.<init>(Iterable.scala:54)
    at scala.collection.AbstractSeq.<init>(Seq.scala:40)
    at scala.collection.immutable.Range.<init>(Range.scala:44)
    at scala.collection.immutable.Range$Inclusive.<init>(Range.scala:330)
    at scala.collection.immutable.Range$Inclusive.copy(Range.scala:333)
    at scala.collection.immutable.Range.drop(Range.scala:170)
    at scala.collection.immutable.Range.tail(Range.scala:196)
    at scala.collection.immutable.Range.tail(Range.scala:44)
    ...


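In Scala, the compiler can turn a recursive call into a loop, so the stack doesn't grow, as long as the call is in tail position (that is, it is the last thing the method does). To have the compiler check that this actually happens, add the @tailrec annotation (from scala.annotation.tailrec) to the method definition.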

  import scala.annotation.tailrec
  @tailrec def sum(s: Seq[Int]): BigInt = { if (s.isEmpty) 0 else s.head + sum(s.tail) }

However, once we add the annotation and recompile, the compiler rejects the method with the following error:

could not optimize @tailrec annotated method sum: it contains a recursive call not in tail position


This error means the compiler cannot apply the optimization because of how our code is structured. Why is that? In the "else" branch of sum, the recursive call sum(s.tail) is evaluated first, and only then is its result added to the head of the list. The addition, not the recursive call, is the last step of the computation, so each call has to keep its stack frame until the inner call returns.

To make the method tail recursive, we need to restructure it so that the recursive call is the very last step. The usual technique is to add an accumulator parameter (p below) that carries the running total: each call adds the current head to p before recursing, so nothing is left to do after the recursive call returns. Here's the same algorithm after that refactoring:

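  import scala.annotation.tailrec

  // Wrapper object (the name is arbitrary) so that main can be run directly.
  object SumExample {
    // p is an accumulator holding the running total; the recursive
    // call is now the last step, so @tailrec is satisfied.
    @tailrec def sum(s: Seq[Int], p: BigInt): BigInt = {
      if (s.isEmpty) p else sum(s.tail, s.head + p)
    }

    def main(args: Array[String]): Unit = {
      val result = sum(1 to 1000000, 0)
      println(result)
    }
  }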
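Tracing the refactored version on a small input shows the difference. In this sketch (AccumulatorSum is an illustrative name, and p + s.head is used in place of s.head + p, which gives the same sum), every intermediate total is already computed when the next call starts:

```scala
import scala.annotation.tailrec

object AccumulatorSum {
  // p carries the running total, so the addition happens *before*
  // the recursive call instead of after it returns.
  @tailrec def sum(s: Seq[Int], p: BigInt): BigInt = {
    if (s.isEmpty) p else sum(s.tail, p + s.head)
  }

  // sum(List(1, 2, 3), 0)
  //   = sum(List(2, 3), 1)   // p = 0 + 1
  //   = sum(List(3), 3)      // p = 1 + 2
  //   = sum(List(), 6)       // p = 3 + 3
  //   = 6
  // Nothing remains to be done after each call, so no frame has to wait.
}
```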


Now, when we run the main method, it completes without a stack overflow and prints the expected sum:

500000500000
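Conceptually, the annotated sum(s, p) behaves like the hand-written loop below. This is only a rough equivalent to illustrate the idea, not the compiler's literal output (the real rewrite happens in the generated bytecode); LoopSum is an illustrative name. The parameters become local variables, and each "recursive call" becomes the next loop iteration:

```scala
object LoopSum {
  // Parameters of the recursive version become mutable locals.
  def sum(s0: Seq[Int], p0: BigInt): BigInt = {
    var s = s0
    var p = p0
    while (s.nonEmpty) {
      // Same updates the recursive call passed as its new arguments,
      // computed from the current values before moving on.
      p = p + BigInt(s.head)
      s = s.tail
    }
    p // the base case: return the accumulated total
  }
}
```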


Note that the Scala compiler applies this optimization on its own whenever it can: the call must be a self-recursive call in tail position, in a method that cannot be overridden (a method of an object, or a final, private, or local method). Even so, it's good practice to annotate any method where you expect the optimization, that is, where the last step of your recursive algorithm is the call to the function itself. That way, if the optimization can't be applied, you get a compile-time error instead of a surprise StackOverflowError at run time.
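One common way to keep the original one-argument signature while still getting the optimization is to hide the accumulator in a nested helper function and annotate the helper instead. A minimal sketch; the names HelperSum, loop, and acc are illustrative, not from the original:

```scala
import scala.annotation.tailrec

object HelperSum {
  // Public API keeps the original signature; callers never see the accumulator.
  def sum(s: Seq[Int]): BigInt = {
    // A local function cannot be overridden, so @tailrec applies to it;
    // compilation fails if the recursive call is not in tail position.
    @tailrec def loop(rest: Seq[Int], acc: BigInt): BigInt =
      if (rest.isEmpty) acc else loop(rest.tail, acc + BigInt(rest.head))

    loop(s, BigInt(0))
  }
}
```

Callers still write sum(1 to 1000000), just as with the first version, but the recursion no longer grows the stack.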
