Posted on 2022-09-25 11:15
Tail-call optimisation in Python using ast
Python does not have tail-call optimisation (TCO), and it likely never will.
Several packages exist that add TCO via decorators (e.g., tco).
Most of these packages use some lambda calculus to implement TCO.
In this post I want to explore an alternative way to achieve TCO, using the ast module to rewrite the decorated function.
Note: do not use this. Do not use any of this. It is fun to rewrite function definitions using ast, but you lose a lot of stack information along the way. If you think you need this, you might be better off just rewriting the function yourself.
What's tail-call optimisation?
Tail-call optimisation is an optimisation technique, commonly applied by compilers of functional languages, that avoids creating a new stack frame for a call in tail position. A call is in tail position when the calling function returns its result directly, with nothing left to do afterwards. Here we are interested in the case where that tail call is a recursive call to the function itself. If recursion occurs only in such tail calls, the function is tail-call recursive (TCR).
Such tail-call recursive functions can be rewritten as iterative functions. You wrap the body in an infinite loop, and you replace each recursive call with an update of the function's arguments to the values it would have been called with.
This is explained very well on the wiki.
Rewriting the function in this way avoids the recursive calls entirely. In Python, that also avoids the RecursionError raised once the recursion limit is exceeded.
On my system, that limit is 3,000 nested calls (CPython's default is 1,000):
```python
import sys

sys.getrecursionlimit()  # 3000 on my system
```
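The loop rewrite can be made concrete with a minimal sketch. It uses two hypothetical functions that sum the integers 1 to n. sum_to_tcr is tail-call recursive, and sum_to_iter is the same function rewritten by hand as an infinite loop. Only the recursive version runs into the recursion limit:

```python
import sys


def sum_to_tcr(n: int, acc: int = 0) -> int:
    # Tail-call recursive: the recursive call is the entire return value.
    if n == 0:
        return acc
    return sum_to_tcr(n - 1, acc + n)


def sum_to_iter(n: int, acc: int = 0) -> int:
    # The same function as an infinite loop: the tail call becomes a
    # simultaneous update of the arguments, after which the loop restarts.
    while True:
        if n == 0:
            return acc
        n, acc = n - 1, acc + n


deep = sys.getrecursionlimit() + 10
try:
    sum_to_tcr(deep)
except RecursionError:
    print("sum_to_tcr: RecursionError")
print("sum_to_iter:", sum_to_iter(deep))
```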
The following is not a tail-call recursive function:
```python
def factorial_not_tcr(n: int) -> int:
    if n <= 1:  # also covers 0! == 1
        return 1
    return n * factorial_not_tcr(n - 1)
```
This is because the recursive call is not the last thing the function does before returning. Once the call returns, the function still has to multiply its result by n.
This is clear when we inspect the disassembled bytecode:
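```python
import dis

# Note the multiplication (BINARY_OP on Python 3.11+, BINARY_MULTIPLY before
# that) between the recursive call and the return.
dis.dis(factorial_not_tcr)
```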
To make this function tail-call recursive, one could use an accumulator:
```python
def factorial_tcr(n: int, acc: int = 1) -> int:
    if n <= 1:
        return acc
    return factorial_tcr(n - 1, acc * n)


# Here the recursive call is directly followed by RETURN_VALUE.
dis.dis(factorial_tcr)
```
This function, factorial_tcr, is what we will work with.
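The same difference can also be checked programmatically. The minimal sketch below uses a hypothetical helper, opcode_after_last_call, that reports which bytecode instruction follows the last call in a function. The two functions are renamed copies of the ones above. The helper assumes call instructions have names starting with CALL, which holds for recent CPython versions (CALL_FUNCTION before 3.11, CALL from 3.11). Bytecode is an implementation detail, though, and changes between releases:

```python
import dis


def opcode_after_last_call(f) -> str:
    # Name of the instruction right after the last call in f's bytecode.
    instructions = list(dis.get_instructions(f))
    calls = [i for i, ins in enumerate(instructions)
             if ins.opname.startswith("CALL")]
    return instructions[calls[-1] + 1].opname


def fact_not_tcr(n: int) -> int:
    if n <= 1:
        return 1
    return n * fact_not_tcr(n - 1)


def fact_tcr(n: int, acc: int = 1) -> int:
    if n <= 1:
        return acc
    return fact_tcr(n - 1, acc * n)


# A multiplication: BINARY_OP on 3.11+, BINARY_MULTIPLY before that.
print(opcode_after_last_call(fact_not_tcr))
print(opcode_after_last_call(fact_tcr))  # RETURN_VALUE
```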
Rewriting tail-call recursive functions
Our goal is to design a decorator function that turns tail-call recursive functions into their iterative equivalent, and executes that iterative function. Such a decorator should only work on functions that are, in fact, tail-call recursive. As such, we need a way to programmatically determine whether a function is tail-call recursive. Let's first do that.
We will design something simple: an is_tcr function that takes a callable argument f and determines whether f is TCR by parsing and inspecting its abstract syntax tree (AST) representation.
In particular, we check that any recursive call to f occurs only in a return statement, and that the return statement does nothing else.
For this, we will use the ast module, and particularly the ast.NodeVisitor class.
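It can help to look at the ASTs of the two kinds of return statement to see what the visitor has to distinguish. In the small sketch below, the helper return_node_type and the toy function f are only for illustration. In the tail-recursive version the returned expression is the call itself. In the other version it is a multiplication (BinOp) with the call nested inside it:

```python
import ast

tail = ast.parse("def f(n, acc):\n    return f(n - 1, acc * n)\n")
not_tail = ast.parse("def f(n):\n    return n * f(n - 1)\n")


def return_node_type(tree: ast.Module) -> str:
    # Type of the expression returned by the last statement of the
    # first function defined in `tree`.
    function = tree.body[0]
    return type(function.body[-1].value).__name__


print(return_node_type(tail))      # Call
print(return_node_type(not_tail))  # BinOp

# ast.dump shows the full structure, including the nested recursive call.
print(ast.dump(not_tail.body[0].body[-1]))
```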
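```python
import ast
import inspect
import textwrap
from typing import Callable


def _is_call_to(node: ast.AST, fname: str) -> bool:
    # True for a plain call such as fname(...); calls like obj.fname(...)
    # have an ast.Attribute as .func, which has no .id attribute.
    return (isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == fname)


class TCRVisitor(ast.NodeVisitor):
    def __init__(self, fname: str):
        self._fname = fname

    def visit_Return(self, node):
        # From the docs:
        # > Note that child nodes of nodes that have a custom visitor method
        # > won't be visited unless the visitor calls generic_visit() or visits
        # > them itself.
        if _is_call_to(node.value, self._fname):
            # The tail call itself is fine, but its arguments must not
            # recurse, as in f(f(n - 1)).
            for child in node.value.args + node.value.keywords:
                self.visit(child)
        else:
            self.generic_visit(node)

    def visit_Call(self, node):
        # Any recursive call that reaches this method is not a tail call.
        if _is_call_to(node, self._fname):
            raise TypeError("Expected a tail-call recursive function; this"
                            " AST does not appear to be one.")
        # Keep looking: a recursive call may be nested in another call.
        self.generic_visit(node)


def is_tcr(f: Callable) -> bool:
    # Dedent so that nested functions and methods can be parsed as well.
    tree = ast.parse(textwrap.dedent(inspect.getsource(f)))
    visitor = TCRVisitor(f.__name__)
    try:
        visitor.visit(tree)
        return True
    except TypeError:
        return False
```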
Let's check if this works:
```python
assert not is_tcr(factorial_not_tcr)
assert is_tcr(factorial_tcr)
```
Good.
We are now ready to start rewriting TCR functions.
For this we will use the ast.NodeTransformer class, since we now need to modify the tree rather than just inspect it.
In particular, we need to do two things given the AST of a function f:
- Wrap the function body in an infinite loop;
- Replace every return statement that makes a tail-recursive call with an update of the function's arguments, followed by continue. The continue jumps straight back to the top of the loop. This is needed when the tail call occurs in a branch, because execution would otherwise fall through to the statements after that branch.
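The two steps can be sketched by hand on a hypothetical tail-call recursive function, collatz_steps_tcr. It counts the steps the Collatz map takes to reach 1, and it has a tail call inside a branch. The loop version, collatz_steps_iter, is also hypothetical. If you drop the first continue in the loop version, an even n falls through to the odd-case update:

```python
def collatz_steps_tcr(n: int, steps: int = 0) -> int:
    if n == 1:
        return steps
    if n % 2 == 0:
        return collatz_steps_tcr(n // 2, steps + 1)
    return collatz_steps_tcr(3 * n + 1, steps + 1)


def collatz_steps_iter(n: int, steps: int = 0) -> int:
    # Step 1: wrap the original body in an infinite loop.
    while True:
        if n == 1:
            return steps
        if n % 2 == 0:
            # Step 2: the tail call becomes an update plus continue. Without
            # this continue, execution would fall through to the update
            # below and apply the odd case to the already-halved n.
            (n, steps) = (n // 2, steps + 1)
            continue
        # Last statement of the loop: this continue is redundant, but it is
        # what a mechanical rewrite of the final return produces.
        (n, steps) = (3 * n + 1, steps + 1)
        continue


print(collatz_steps_iter(27))  # 111
```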
```python
class TCRRewriter(ast.NodeTransformer):
    def __init__(self, fname: str):
        self._fname = fname
        self._args = []

    def visit_FunctionDef(self, node):
        if node.name == self._fname:
            self._args = node.args.args
            # Insert infinite loop. This is our iterative looping construct.
            # Note that orelse (the loop's else-block) must be a list.
            node.body = [ast.While(test=ast.Constant(True), body=node.body,
                                   orelse=[])]
            # Remove the decorator - we only need to rewrite once. The
            # decorator list holds AST nodes (here ast.Name), not strings.
            node.decorator_list = [
                d for d in node.decorator_list
                if not (isinstance(d, ast.Name) and d.id == "rewrite_tcr")
            ]
        self.generic_visit(node)
        return node

    def visit_Return(self, node):
        call = node.value
        if (isinstance(call, ast.Call) and isinstance(call.func, ast.Name)
                and call.func.id == self._fname):
            # This creates a single assignment, using tuple unpacking of
            # both the values and targets: (n, acc) = (n - 1, acc * n).
            targets = [ast.Name(id=arg.arg, ctx=ast.Store())
                       for arg in self._args]
            target = ast.Tuple(elts=targets, ctx=ast.Store())
            value = ast.Tuple(elts=call.args, ctx=ast.Load())
            assignment = ast.Assign(targets=[target], value=value)
            # Returning a list splices both statements into the enclosing
            # body. ast.Continue must be instantiated: the bare class is not
            # a node and would not produce a continue statement.
            return [assignment, ast.Continue()]
        return node
```
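The comment in visit_Return about a single assignment matters. All new argument values must be computed from the old ones before any of them is overwritten. The minimal sketch below uses a hypothetical helper, simultaneous_update, that builds the same kind of tuple assignment from AST nodes and unparses it. Depending on the Python version, ast.unparse may or may not put parentheses around the target tuple:

```python
import ast


def simultaneous_update(names: list[str], exprs: list[str]) -> str:
    # Build `names... = (exprs...)` from AST nodes and return its source.
    target = ast.Tuple(
        elts=[ast.Name(id=name, ctx=ast.Store()) for name in names],
        ctx=ast.Store())
    value = ast.Tuple(
        elts=[ast.parse(expr, mode="eval").body for expr in exprs],
        ctx=ast.Load())
    statement = ast.Assign(targets=[target], value=value)
    # ast.unparse reads line numbers, so fill them in for the new nodes.
    module = ast.Module(body=[statement], type_ignores=[])
    return ast.unparse(ast.fix_missing_locations(module))


source = simultaneous_update(["n", "acc"], ["n - 1", "acc * n"])
print(source)

# Tuple unpacking evaluates the whole right-hand side before assigning.
together = {"n": 5, "acc": 2}
exec(source, together)

# Updating one variable at a time uses the *new* n when updating acc.
one_by_one = {"n": 5, "acc": 2}
exec("n = n - 1\nacc = acc * n", one_by_one)

print(together["n"], together["acc"])      # 4 10
print(one_by_one["n"], one_by_one["acc"])  # 4 8
```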
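```python
import ast
import inspect
import textwrap
from functools import wraps


def rewrite_tcr(f):
    # Raise instead of assert: asserts are stripped when running python -O.
    if not is_tcr(f):
        raise TypeError(f"{f.__name__} is not tail-call recursive")
    # Rewrite once, when decorating, rather than on every call.
    tree = ast.parse(textwrap.dedent(inspect.getsource(f)))
    tree = TCRRewriter(f.__name__).visit(tree)
    # Add location (line numbers and column offsets) for newly generated
    # nodes, and then define the changed function in an explicit namespace.
    # Reading it back via locals() after exec() is unreliable (and fails on
    # Python 3.13+); f's globals keep the names f refers to resolvable.
    ast.fix_missing_locations(tree)
    namespace = {}
    exec(ast.unparse(tree), f.__globals__, namespace)
    iterative = namespace[f.__name__]

    @wraps(f)
    def wrapper(*args, **kwargs):
        # Call the new iterative function, and return its output.
        return iterative(*args, **kwargs)

    return wrapper
```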
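The exec step can be tried on its own. In the minimal sketch below, the source string is a hand-written stand-in for what the rewriter should produce for factorial_tcr. This is an assumption, not actual output of TCRRewriter. The string is executed into an explicit dictionary, and the new function is then retrieved from it. The second argument to exec is the globals dictionary; the third collects the definitions:

```python
import math

# Hand-written stand-in for the rewritten source of factorial_tcr.
source = """
def factorial_tcr(n, acc=1):
    while True:
        if n <= 1:
            return acc
        (n, acc) = (n - 1, acc * n)
        continue
"""

namespace = {}
# exec(source, globals, locals): the definition ends up in `namespace`.
exec(source, {}, namespace)
iterative_factorial = namespace["factorial_tcr"]

print(iterative_factorial(10))  # 3628800
print(iterative_factorial(5000) == math.factorial(5000))  # True
```

Passing an explicit dictionary keeps the definition out of the caller's namespace. It also means the code does not depend on how locals() behaves inside a function, which has changed between Python versions.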
Let's test if this works by calling the decorated factorial function with an argument that exceeds the recursion limit:
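```python
factorial = rewrite_tcr(factorial_tcr)
# 5000! has over 16,000 digits. Recent Python versions (3.11+) refuse by
# default to convert integers with more than 4,300 digits to strings, so
# rather than displaying the result, we look at its size in bits.
result = factorial(5000)
result.bit_length()
```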
And, just to be sure we did not make a mistake:
```python
import math

assert math.factorial(5000) == factorial(5000)
```
Conclusions
Above we designed a simple decorator, rewrite_tcr, that performs tail-call optimisation by rewriting the function definition using the ast module.
This decorator is by no means complete:
- Earlier, we had to rewrite factorial_not_tcr into an alternative TCR function, using an accumulator. We could also do this programmatically.
- The decorator does not detect cycles where f calls g, which in turn calls f again.
- The decorator does not support keyword or variable function arguments.
Feel free to patch any of these omissions!
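For the second omission, a starting point is a minimal sketch for spotting such cycles, using two hypothetical helpers, call_graph and mutually_recursive. It only considers top-level functions in a single source string and calls through plain names, so it is far from complete. Note that is_tcr would accept is_even below, because it never calls itself directly, even though it recurses through is_odd:

```python
import ast

SOURCE = """
def is_even(n):
    if n == 0:
        return True
    return is_odd(n - 1)


def is_odd(n):
    if n == 0:
        return False
    return is_even(n - 1)


def factorial_tcr(n, acc=1):
    if n <= 1:
        return acc
    return factorial_tcr(n - 1, acc * n)
"""


def call_graph(source: str) -> dict[str, set[str]]:
    # Map each top-level function to the top-level functions it calls by name.
    tree = ast.parse(source)
    functions = {node.name: node for node in tree.body
                 if isinstance(node, ast.FunctionDef)}
    return {
        name: {call.func.id for call in ast.walk(node)
               if isinstance(call, ast.Call)
               and isinstance(call.func, ast.Name)
               and call.func.id in functions}
        for name, node in functions.items()
    }


def mutually_recursive(graph: dict[str, set[str]], start: str) -> bool:
    # True if `start` can reach itself through at least one other function.
    stack = [name for name in graph[start] if name != start]
    seen = set()
    while stack:
        name = stack.pop()
        if name == start:
            return True
        if name not in seen:
            seen.add(name)
            stack.extend(graph[name])
    return False


graph = call_graph(SOURCE)
print(mutually_recursive(graph, "is_even"))        # True
print(mutually_recursive(graph, "factorial_tcr"))  # False
```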