Saturday 12 July 2014

Optimise Your Own Fuckin' Tail Calls

I've heard a few people moaning that Java 8 still doesn't optimise tail calls. Not having programmed in a functional language since Moscow ML at University, I didn't really give a shit, but thought I should find out what the deal is.

A tail call is a method call that is the final action of the method making it. A tail-recursive method is one whose final action is a call to itself.

Now, we all know recursion is mathematically beautiful and elegant for people who write programs that don't do anything. For the rest of us, using recursion for anything non-trivial means you're going to run out of stack at some point.

For example, NASA's JPL C Coding Standard disallows the use of recursion because:-

"The presence of statically verifiable loop bounds and the absence of recursion prevent runaway code, and help to secure predictable performance for all tasks. The absence of recursion also simplifies the task of deriving reliable bounds on stack use. The two rules combined secure a strictly acyclic function call graph and control-flow structure, which in turn enhances the capabilities for static checking tools to catch a broad range of coding defects."

Languages which support tail call optimisation replace tail call recursion with a loop so you don't have to worry about runaway stack usage. Think about that - the compiler tries to get rid of recursion for you, because it's inefficient and prone to causing defects. How elegant is your recursive algorithm really?

Also, in what other situation would you expect the compiler to so drastically change your code for you? Yeah I know you can expect the Java JIT compiler to perform optimisations, but tail call elimination is a step too far for me. While you're optimising my recursive calls, why don't you just go ahead and write the rest of my algorithms for me? - GET OFF MY LAWN!

A good example of converting a recursive algorithm to a while loop is available here: http://c2.com/cgi/wiki?TailCallOptimization and I've converted it to Java below.

I've used BigInteger in the code to avoid integer overflow, because I want to demonstrate stack overflow with the default JVM settings.
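To see why a primitive won't do, here's a rough sketch of the same factorial on a long. The class name LongFactorialOverflow is mine, purely for illustration. It uses Math.multiplyExact, which throws an ArithmeticException instead of silently wrapping around, and it gives up at 21!, nowhere near deep enough to trouble the stack:

```java
final class LongFactorialOverflow {

  // Iterative factorial on a long. Math.multiplyExact throws
  // ArithmeticException if the product no longer fits in 64 bits.
  static long factorial(final int n) {
    long accumulator = 1L;
    for (int i = 2; i <= n; i++) {
      accumulator = Math.multiplyExact(accumulator, (long) i);
    }
    return accumulator;
  }

  public static void main(final String[] args) {
    System.out.println(factorial(20)); // prints 2432902008176640000
    try {
      factorial(21);
    } catch (final ArithmeticException e) {
      System.out.println("long overflows at 21!");
    }
  }
}
```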

First, a recursive method which is NOT tail-recursive. The final action in the method is not the call to itself but a multiplication: the result of the recursive call still has to be multiplied by n, so every call has to keep its stack frame until the one below it returns:

    private static BigInteger notTailRecursionFactorial(final BigInteger n) {

        if (n.compareTo(TWO) < 0) {
            return ONE;
        } else {
            return n.multiply(notTailRecursionFactorial(n.subtract(ONE)));
        }
    }

Adding an accumulator parameter makes the method tail-recursive, and a candidate for tail call optimisation, because the multiplication now happens before the recursive call and nothing is left to do after it:

    private static BigInteger tailRecursionFactorial(final BigInteger n, final BigInteger accumulator) {

        if (n.compareTo(TWO) < 0) {
            return accumulator;
        } else {
            return tailRecursionFactorial(n.subtract(ONE), n.multiply(accumulator));
        }
    }

The recursive method converted to a while loop:

    private static BigInteger tailRecursionEliminationFactorial(BigInteger n, BigInteger accumulator) {

        while (n.compareTo(TWO) >= 0) {
            accumulator = accumulator.multiply(n);
            n = n.subtract(ONE);
        }

        return accumulator;
    }

A further tidy-up of the loop version: the accumulator becomes a local variable, so callers no longer have to pass in ONE:

    private static BigInteger tailRecursionEliminationFactorialOptimised(BigInteger n) {

        BigInteger accumulator = ONE;

        while (n.compareTo(TWO) >= 0) {
            accumulator = accumulator.multiply(n);
            n = n.subtract(ONE);
        }

        return accumulator;
    }

The full class looks like this:

TailCallOptimisation.java

package org.adrianwalker;

import java.math.BigInteger;

public class TailCallOptimisation {

  private static final BigInteger ONE = new BigInteger("1");
  private static final BigInteger TWO = new BigInteger("2");
  private static final BigInteger ONE_HUNDRED_THOUSAND = new BigInteger("100000");

  public static void main(final String[] args) {

    System.out.println("\nnotTailRecursionFactorial:");

    try {
      System.out.println(notTailRecursionFactorial(ONE_HUNDRED_THOUSAND));
    } catch (final StackOverflowError e) {
      System.out.println("Stack Overflow Error");
    }

    System.out.println("\ntailRecursionFactorial:");

    try {
      System.out.println(tailRecursionFactorial(ONE_HUNDRED_THOUSAND, ONE));
    } catch (final StackOverflowError e) {
      System.out.println("Stack Overflow Error");
    }

    System.out.println("\ntailRecursionEliminationFactorial:");

    System.out.println(tailRecursionEliminationFactorial(ONE_HUNDRED_THOUSAND, ONE));

    System.out.println("\ntailRecursionEliminationFactorialOptimised:");

    System.out.println(tailRecursionEliminationFactorialOptimised(ONE_HUNDRED_THOUSAND));
  }

  private static BigInteger notTailRecursionFactorial(final BigInteger n) {

    if (n.compareTo(TWO) < 0) {
      return ONE;
    } else {
      return n.multiply(notTailRecursionFactorial(n.subtract(ONE)));
    }
  }

  private static BigInteger tailRecursionFactorial(final BigInteger n, final BigInteger accumulator) {

    if (n.compareTo(TWO) < 0) {
      return accumulator;
    } else {
      return tailRecursionFactorial(n.subtract(ONE), n.multiply(accumulator));
    }
  }

  private static BigInteger tailRecursionEliminationFactorial(BigInteger n, BigInteger accumulator) {

    while (n.compareTo(TWO) >= 0) {
      accumulator = accumulator.multiply(n);
      n = n.subtract(ONE);
    }

    return accumulator;
  }

  private static BigInteger tailRecursionEliminationFactorialOptimised(BigInteger n) {

    BigInteger accumulator = ONE;

    while (n.compareTo(TWO) >= 0) {
      accumulator = accumulator.multiply(n);
      n = n.subtract(ONE);
    }

    return accumulator;
  }
}

And the output on my machine is:

notTailRecursionFactorial:
Stack Overflow Error

tailRecursionFactorial:
Stack Overflow Error

tailRecursionEliminationFactorial:
282422940796034787429342157802453551847749492609... (loads more)

tailRecursionEliminationFactorialOptimised:
282422940796034787429342157802453551847749492609... (loads more)

The recursive methods run out of stack way before getting anywhere near an answer for 100,000!.
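If you want a rough idea of how far the recursion gets on your own machine, something like the sketch below will tell you. StackDepthProbe and its methods are names I've made up for illustration. The figure it prints depends on the JVM, its stack size settings and the JIT, and the frames here are smaller than the ones the factorial methods use, so treat it as a ballpark and nothing more:

```java
final class StackDepthProbe {

  private static int depth;

  // Recurses with no base case, counting frames until the stack runs out.
  private static void recurse() {
    depth++;
    recurse();
  }

  // Returns the number of calls made before StackOverflowError was thrown.
  static int maxDepth() {
    depth = 0;
    try {
      recurse();
    } catch (final StackOverflowError e) {
      // expected: this is how the probe finds out it has hit the limit
    }
    return depth;
  }

  public static void main(final String[] args) {
    System.out.println("Calls before stack overflow: " + maxDepth());
  }
}
```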

Bottom line - optimise your own fuckin' tail calls.

Update - 24/08/14

A clever chap called Dr Rowan Davies got in touch about this post. Here's some food for thought:-

Just a small comment that your page "Optimise Your Own Fuckin' Tail Calls" is missing the point.

Basically, you've only considered the very simplest situation, a single recursive function that corresponds to a loop. Functional compilers like Scala do exactly the transformation to loops you've shown, so no one is complaining about that kind of tail call on the JVM.

It is the much more powerful uses of tail calls for things that don't correspond to loops. E.g., in F# on .NET (which supports tail calls) there is really nice support for asynchronous programming that depends on tail calls to avoid the stack increasing when you swap between different asynchronous handlers and lightweight software threads. The corresponding code in Scala can't do that, the JVM just doesn't support it - and I challenge you to convert such asynchronous code using tail calls to Java code that doesn't chew up stack as it goes.

And, optimized tail-calls are not rewriting your code - they are exactly doing the tail call, just in a way that doesn't keep crap on your stack that clearly can never be needed in the future. Basically rejecting tail call optimization is demanding that the JVM keep this unneeded crap on the stack. Why do you insist that the JVM keep crap around? It's equivalent to insisting that for loops chew up stack space as they go through iterations - it really makes no sense at all to require that.

And this one as well:-

Your view is actually pretty common. But, it really is unnatural to make the implementation of tail calls consume stack when they don't have to. It's not so much about mathematical elegance, it's about whether certain powerful and natural ways of programming explode the stack or not. You're expecting a function call to always be implemented in a certain way that is really non-optimal sometimes. It's not really an optimization, it's about not doing something stupid that keeps potentially huge amounts of unnecessary data on the stack. gcc has done this basically forever. Indeed, many years ago I checked the machine code output and the C compiler cc in UNIX System V and it clearly supported efficient tail calls back in 1989.

The JVM doesn't include efficient tail calls largely because the JVM security model isn't so compatible with them, and changing the model would change the assumptions existing code could depend on. This is basically a bureaucratic reason, the requirement to support earlier JVM programs that depend on the lack of efficient tail calls. But, many hackers and language implementers will celebrate if it is allowed sometime, otherwise those hackers and languages will eventually move on to VMs that do support what they need.

Cheers for the comments, Doc.
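For what it's worth, the usual do-it-yourself workaround on the JVM for tail calls that don't collapse into a single loop is a trampoline: each method returns a description of the next call instead of making it, and one loop runs those steps in turn. Below is a minimal sketch, assuming Java 8 lambdas, using mutually recursive isEven/isOdd as a stand-in. TrampolineSketch, Step, done, more and run are all names I've invented for illustration, not a library API, and this is nowhere near the asynchronous case described above. It trades stack frames for heap allocations, one small object per step.

```java
import java.util.function.Supplier;

final class TrampolineSketch {

  // One step of a computation: either a finished result or a deferred next step.
  static final class Step<T> {

    private final T result;
    private final Supplier<Step<T>> next;

    private Step(final T result, final Supplier<Step<T>> next) {
      this.result = result;
      this.next = next;
    }

    static <T> Step<T> done(final T result) {
      return new Step<T>(result, null);
    }

    static <T> Step<T> more(final Supplier<Step<T>> next) {
      return new Step<T>(null, next);
    }

    // Runs the deferred steps in a loop, so the stack depth stays constant.
    T run() {
      Step<T> step = this;
      while (step.next != null) {
        step = step.next.get();
      }
      return step.result;
    }
  }

  // Mutual recursion in tail position: each method hands back the next call
  // as a Step instead of making the call itself.
  static Step<Boolean> isEven(final int n) {
    if (n == 0) {
      return Step.done(Boolean.TRUE);
    }
    return Step.more(() -> isOdd(n - 1));
  }

  static Step<Boolean> isOdd(final int n) {
    if (n == 0) {
      return Step.done(Boolean.FALSE);
    }
    return Step.more(() -> isEven(n - 1));
  }

  public static void main(final String[] args) {
    System.out.println(isEven(1_000_000).run()); // prints true
  }
}
```

Calling isEven and isOdd directly on each other a million deep would blow the stack with default settings. Here each call returns straight away and the while loop in run() does the bouncing.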