Automatic Differentiation Intuition Dump

Alec Jacobson

December 13, 2014


Every so often I re-read the Wikipedia page for automatic differentiation and get confused about forward and reverse accumulation. Both are neat, and each has appropriate and inappropriate applications. There are many tutorials online; here's my intuition anyway.

Forward accumulation

At each step of computation, maintain a derivative value. We seed each initial variable with derivative 1 or 0 according to whether we're differentiating with respect to it.

Augmenting numerical types with a "dual value" (X := x + x'ε, where ε*ε = 0) and overloading math operations is an easy way to implement this method.
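
For instance, here's a minimal Python sketch of such a dual type (the Dual class, the dsin helper, and the example function are made up for illustration, not taken from any library):

    import math

    class Dual:
        def __init__(self, val, dot=0.0):
            self.val = val   # primal value x
            self.dot = dot   # derivative value x'

        def __add__(self, other):
            other = other if isinstance(other, Dual) else Dual(other)
            return Dual(self.val + other.val, self.dot + other.dot)

        __radd__ = __add__

        def __mul__(self, other):
            other = other if isinstance(other, Dual) else Dual(other)
            # product rule: (x*y)' = x'*y + x*y'
            return Dual(self.val * other.val,
                        self.dot * other.val + self.val * other.dot)

        __rmul__ = __mul__

    def dsin(x):
        # elementary functions propagate the derivative via the chain rule
        return Dual(math.sin(x.val), math.cos(x.val) * x.dot)

    # differentiate f(x) = x*sin(x) + 2*x at x = 3: seed the input with dot = 1
    x = Dual(3.0, 1.0)
    y = x * dsin(x) + 2 * x
    print(y.val, y.dot)   # f(3) and f'(3) = sin(3) + 3*cos(3) + 2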

For f:R→Rⁿ and n>>1 this is ideal since we end up computing and storing just one derivative value at each computation step. If there are m computation variables then we track m derivatives in total: O(m) work and memory to get the n-long vector derivative.

For f:Rⁿ→R this is not ideal. To take a gradient we need to store n derivatives for each computation variable or sweep through the computation n times: O(mn) work.
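
For instance, with the Dual class sketched above (again, just an illustration), the gradient of f(x0, x1) = x0*x1 + x1 at (2, 5) takes one forward sweep per input, re-seeding a different input with dot = 1 each time:

    def f(x0, x1):
        return x0 * x1 + x1

    # one sweep per input: only the variable being differentiated gets dot = 1
    df_dx0 = f(Dual(2.0, 1.0), Dual(5.0, 0.0)).dot   # 5.0 (= x1)
    df_dx1 = f(Dual(2.0, 0.0), Dual(5.0, 1.0)).dot   # 3.0 (= x0 + 1)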

Backward (reverse) accumulation

At each step of computation, maintain the current step's gradient with respect to its inputs and pointers to its input variables. When retrieving derivatives, evaluate the outermost gradient and apply the chain rule recursively to its remembered arguments.

One can also implement this with a special numerical type and math operation overloading. This type should maintain the entire expression graph of the computation (or at least store the most immediate computation and live pointers to the previous computation variables of the same type), with gradients also provided by each mathematical operation. I suppose one way to implement this is to alter math operations so they augment their output with handles to functions that compute gradients. Traditionally compilers have been bad at optimizing the evaluation of such a stored expression graph, but I wonder if modern compilers with inline function optimization couldn't optimize this?
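
Here's a minimal Python sketch of that idea (the Var class and its backward method are my own illustrative names, not any particular library's API): each operation records its inputs along with the local partial derivatives, and the derivative extraction walks the remembered graph in reverse topological order, applying the chain rule as it goes.

    class Var:
        def __init__(self, val, parents=()):
            self.val = val          # primal value
            self.parents = parents  # pairs: (input Var, local partial derivative)
            self.grad = 0.0         # accumulated d(output)/d(self)

        def __add__(self, other):
            # d(x+y)/dx = 1, d(x+y)/dy = 1
            return Var(self.val + other.val, ((self, 1.0), (other, 1.0)))

        def __mul__(self, other):
            # d(x*y)/dx = y, d(x*y)/dy = x
            return Var(self.val * other.val, ((self, other.val), (other, self.val)))

        def backward(self):
            # order the graph so each node is visited exactly once
            order, seen = [], set()
            def topo(v):
                if id(v) not in seen:
                    seen.add(id(v))
                    for p, _ in v.parents:
                        topo(p)
                    order.append(v)
            topo(self)
            # push adjoints from the output back toward the inputs (chain rule)
            self.grad = 1.0
            for v in reversed(order):
                for p, local in v.parents:
                    p.grad += v.grad * local

    # gradient of f(x0, x1) = x0*x1 + x0 at (2, 5) in one backward sweep
    x0, x1 = Var(2.0), Var(5.0)
    y = x0 * x1 + x0
    y.backward()
    print(x0.grad, x1.grad)   # 6.0, 2.0

The reverse topological ordering is what lets a single sweep touch each computation variable only once; a naive recursion over the graph would revisit shared subexpressions.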

In any case, for f:Rⁿ→R and n>>1 this is ideal since a single derivative extraction traversal involving m computation variables will (re)visit each computation variable once: O(m).

For f:R→Rⁿ this is not ideal. At each computation variable we need to store n derivatives (one per output) and keep them around until evaluation: O(mn) memory and work. Forward accumulation, by contrast, just tracks one derivative per computation variable: O(m) memory and work.
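
To make the contrast concrete with the toy classes sketched above (all names mine), here is f(t) = (t², t³) differentiated both ways: forward mode delivers every output's derivative in one sweep, while reverse mode needs one backward sweep per output (or, equivalently, has to carry one adjoint per output at every variable):

    def f_vec(t):
        # works for both Dual and Var inputs since both overload *
        return (t * t, t * t * t)

    # forward: a single sweep yields the derivative of every output
    t = Dual(3.0, 1.0)
    print([y.dot for y in f_vec(t)])   # [6.0, 27.0]

    # reverse: one backward sweep per output
    grads = []
    for i in range(2):
        t = Var(3.0)
        f_vec(t)[i].backward()
        grads.append(t.grad)
    print(grads)                       # [6.0, 27.0]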