Lecture 8: Local Function Definitions and Tail Calls
So far we’ve developed a nice suite of basic features for our language:
The ability to perform arithmetic operations
The ability to perform logical operations and make decisions based on them
The ability to interact with the operating system in limited ways (errors, printing, random numbers, etc.)
However, the computational power of our language is fundamentally
limited: all of our programs use finite space and terminate in finite
time. Over the next few weeks we will substantially increase the power
of our snake languages by lifting both of these restrictions. We will
use functions and looping constructs to lift our
restriction of terminating programs and we will use the stack
and the heap to give us access to unbounded memory. Today we
will start this journey towards Turing-completeness by adding function
definitions and a restricted form of calls that can be compiled
directly to jumps. This increases the power of our language to that of
finite-state automata. (Technically, to implement finite-state automata
we would also need something like a readBool() or readInt() built-in that
reads the next character from the input string, so that we could take
arbitrarily large strings as input.)
1 Growing the language: local function definitions, tail calls
Reminder: Every time we enhance our source language, we need to consider several things:
Its impact on the concrete syntax of the language
Examples using the new enhancements, so we build intuition of them
Its impact on the abstract syntax and semantics of the language
Any new or changed transformations needed to process the new forms
Executable tests to confirm the enhancement works as intended
1.1 Concrete syntax
We’ll start with concrete syntax. A function call is a new form of expression that starts with a function name and takes zero or more comma-separated expressions as arguments.
‹expr›: ...
      | IDENTIFIER ( ‹exprs› )
      | IDENTIFIER ( )
‹exprs›: ‹expr›
       | ‹expr› , ‹exprs›
To account for function definitions we have a choice to make. The easiest to implement would be to simply have a sequence of top level function definitions followed by one main body expression. But we will do something more flexible: we will allow for local function definitions inside arbitrary expressions.
‹expr›: ...
      | ‹decls› in ‹expr›
‹decls›: ‹decl›
       | ‹decls› and ‹decl›
‹decl›: def IDENTIFIER ( ‹ids› ) : ‹expr›
      | def IDENTIFIER ( ) : ‹expr›
‹ids›: IDENTIFIER
     | IDENTIFIER , ‹ids›
Allowing for local function definitions like this is very convenient
for programming. For instance, we can use recursive functions in lieu
of looping constructs, like in this factorial example:
def fac(x):
  def loop(x, acc):
    if x == 0:
      acc
    else:
      loop(x - 1, acc * x)
  in
  loop(x, 1)
in
fac(10)
If we only had top-level function definitions, the programmer would have to lift loops (and nested loops) like this to the top-level, polluting the global namespace and making the program less clear overall.
Our new form also allows for mutually recursive function definitions. For instance, here is how we might implement a function that checks if the input is even or odd without having access to a modulo or div operator:
def even(x):
  def evn(n):
    if n == 0:
      true
    else:
      odd(n - 1)
  and
  def odd(n):
    if n == 0:
      false
    else:
      evn(n - 1)
  in
  if x >= 0:
    evn(x)
  else:
    evn(-1 * x)
in
even(24)
For our examples, let’s design max, which takes two numeric arguments and
returns the larger of the two.
def max(x,y):
  if x >= y: x else: y
end
max(17,31)
should evaluate to 31.
1.2 Abstract syntax for Calls
First, let’s cover the calling side of the language.
Do Now!
What should the semantics of f(e1, e2, ..., en) be? How should we represent this in our Exp data definition? What knock-on effects does it have for the transformation passes of our compiler?
The first thing we might notice is that attempting to call an unknown function
should be prohibited and, like the use of an unbound variable, this can be
caught during well-formedness checking.
Do Now!
What other programmer mistakes should we try to catch with well-formedness checking? What new mistakes are possible with function calls?
What should happen if a programmer tries to call max(1) or max(1, 2, 3)?
Certainly nothing good can happen at runtime if we allow this to occur.
Fortunately, we can track enough information to prevent this at
well-formedness time, too. Our function environment should keep track of
known function names and their arities. Then we can check every function
call expression and see whether it contains the correct number of actual
arguments.
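The arity check described above can be sketched as follows. This is a minimal illustration rather than the course's actual checker: the `Expr` and `WfError` types and the `check_calls` function are hypothetical names invented for this sketch, and the expression type is cut down to just numbers and calls.

```rust
use std::collections::HashMap;

/// A deliberately tiny expression type (hypothetical; not the real `Exp`).
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    Num(i64),
    Call(String, Vec<Expr>),
}

/// Well-formedness errors related to calls (hypothetical names).
#[derive(Debug, Clone, PartialEq)]
pub enum WfError {
    UndefinedFunction { name: String },
    ArityMismatch { name: String, expected: usize, found: usize },
}

/// Walk the expression and check every call against `funs`, the function
/// environment mapping each known function name to its arity.
pub fn check_calls(e: &Expr, funs: &HashMap<String, usize>, errs: &mut Vec<WfError>) {
    match e {
        Expr::Num(_) => {}
        Expr::Call(name, args) => {
            match funs.get(name) {
                None => errs.push(WfError::UndefinedFunction { name: name.clone() }),
                Some(&expected) if expected != args.len() => {
                    errs.push(WfError::ArityMismatch {
                        name: name.clone(),
                        expected,
                        found: args.len(),
                    })
                }
                Some(_) => {}
            }
            // Arguments may themselves contain calls, so check them too.
            for a in args {
                check_calls(a, funs, errs);
            }
        }
    }
}
```

With `max` registered at arity 2, checking `max(1)` or `max(1, 2, 3)` reports an `ArityMismatch`, and checking a call to an unregistered name reports `UndefinedFunction`.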
We need more examples:
Source | Output
To represent call expressions in our AST, we just need to keep track of the function name, the argument list, and any tag information:
enum Exp<Ann> {
    ...
    Call(String, Vec<Exp<Ann>>, Ann),
}
We need to consider how our expression evaluates, which in turn means considering how it should normalize into sequential form.
Do Now!
What are the design choices here?
Since Exp::Call expressions are compound, containing multiple subexpressions,
they probably should normalize similarly to how we normalize Prim2
expressions: the arguments should all be made immediate.
pub enum SeqExp<Ann> {
    ...
    Call(String, Vec<ImmExp>, Ann),
}
We have at least two possible designs here, for how to normalize these expressions: we can choose a left-to-right or right-to-left evaluation order for the arguments. For consistency with infix operators, we’ll choose a left-to-right ordering.
Do Now!
What tiny example program, using only the expressions we have so far, would demonstrate the difference between these two orderings?
Do Now!
Extend sequentialization to handle Exp::Call.
1.3 Defining our own functions
Now that our programs include function definitions and a main expression, our AST representation grows to match:
pub struct FunDecl<E, Ann> {
    pub name: String,
    pub parameters: Vec<String>,
    pub body: E,
    pub ann: Ann,
}
pub struct Prog<E, Ann> {
    pub funs: Vec<FunDecl<E, Ann>>,
    pub main: E,
    pub ann: Ann,
}
Here we abstract over annotations as well as the type of expressions.
This allows us to instantiate E with Exp<Ann> or SeqExp<Ann>, giving
Prog<Exp<Ann>, Ann> or Prog<SeqExp<Ann>, Ann>, to encode whether or not
the expressions are in sequential form.
1.4 Semantics
Do Now!
What new semantic concerns do we have with providing our own definitions?
As soon as we introduce a new form of definition into our language, we need to consider scoping concerns. One possibility is to declare that earlier definitions can be used by later ones, but not vice versa. This possibility is relatively easy to implement, but restrictive: it prevents us from having mutually-recursive functions. Fortunately, because all our functions are statically defined, supporting mutual recursion is not all that difficult; the only complication is getting the well-formedness checks to work out correctly.
Exercise
Do so.
Additionally, the bodies of function definitions need to consider scope as well.
def sum3(x, y, z):
  a + b + c
end
x + 5
This program refers to names that are not in scope: a, b and c are not in scope within sum3, and x is not in scope outside of it.
def f(x): x end
def f(y): y end
f(3)
Repeatedly defining functions of the same name should be problematic: which function is intended to be called?
def f(x, x): x end
f(3, 4)
Having multiple arguments of the same name should be problematic: which argument should be returned here?
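Both of these mistakes can be caught with a simple duplicate-name scan. The following is a minimal sketch under stated assumptions: `Decl`, `duplicates`, and `check_decls` are hypothetical names for this illustration, the declaration type carries only a name and parameter names, and errors are plain strings rather than whatever error type a real compiler would use.

```rust
use std::collections::HashSet;

/// Hypothetical, stripped-down declaration: just a name and parameter names.
pub struct Decl {
    pub name: String,
    pub params: Vec<String>,
}

/// Return each name that occurs more than once, reported once,
/// in the order in which the repeats are first encountered.
pub fn duplicates(names: &[String]) -> Vec<String> {
    let mut seen = HashSet::new();
    let mut dups = Vec::new();
    for n in names {
        // `insert` returns false when the name was already present.
        if !seen.insert(n.as_str()) && !dups.contains(n) {
            dups.push(n.clone());
        }
    }
    dups
}

/// Report functions defined more than once in one group of declarations,
/// and parameters repeated within a single declaration.
pub fn check_decls(decls: &[Decl]) -> Vec<String> {
    let mut errs = Vec::new();
    let fun_names: Vec<String> = decls.iter().map(|d| d.name.clone()).collect();
    for f in duplicates(&fun_names) {
        errs.push(format!("function {} defined more than once", f));
    }
    for d in decls {
        for p in duplicates(&d.params) {
            errs.push(format!("parameter {} repeated in {}", p, d.name));
        }
    }
    errs
}
```

Applied to the two programs above, the first yields a "defined more than once" error for f and the second a "repeated" error for x.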
1.5 Code Generation: Function Definitions
How should we compile function definitions? Well, each function
definition needs to have associated code that we can jump to in order to
execute the function, so each function will need to produce a label
followed by the code corresponding to the body. But we also need to
consider the arguments of the function. The simplest thing to
do is to extend our current treatment of local variables in let
bindings: arguments will just be new local variables stored on the
stack. Consider an expression like
let a = 1 in
def f(b,c,d):
  let x = 2 in
  def g(y,z):
    ...
  in
  ...
in
...
When we compile this code we need to determine where all of the local
variables go. For a first attempt we can extend our treatment of
let-bound variables: a goes in the first free stack slot [RSP - 8 * 1],
b goes in the next, [RSP - 8 * 2], and c, d, x, y, z go in slots
3, 4, 5, 6, 7 respectively.
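The slot assignment just described can be sketched with a small environment type. This is only an illustration: `Env`, `push`, `lookup`, and `operand` are hypothetical names, and the sketch assumes every variable in scope, whether let-bound or a function parameter, simply takes the next free slot.

```rust
/// Environment mapping in-scope variable names to stack slots.
/// (Hypothetical helper; slot n lives at [RSP - 8 * n].)
#[derive(Debug, Clone, Default)]
pub struct Env {
    vars: Vec<String>,
}

impl Env {
    /// Bind `name` to the next free slot and return that slot (1-based).
    pub fn push(&mut self, name: &str) -> usize {
        self.vars.push(name.to_string());
        self.vars.len()
    }

    /// Slot of the innermost binding of `name`, if it is in scope.
    pub fn lookup(&self, name: &str) -> Option<usize> {
        self.vars.iter().rposition(|v| v == name).map(|i| i + 1)
    }
}

/// Render a slot as the memory operand notation used in these notes.
pub fn operand(slot: usize) -> String {
    format!("[RSP - 8 * {}]", slot)
}
```

Pushing a, b, c, d, x, y, z in that order gives a slot 1 and z slot 7, matching the layout described above.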
Other than that, the code generation for function definitions is straightforward: each function definition
def f(x,y,z):
  e
end
can be simply compiled to a label followed by the code for the body
f:
  ;; code for e
We just need to make sure the labels we use for the function names are all unique, using techniques similar to those we use for the other labels we generate.
1.6 Code Generation: Tail Calls
Now what about the function call? Well, it turns out that in all
of the example programs so far, the function calls have had a very
special property: they are always the last thing that the
caller does. For instance, in our implementation of even,
the main expression ends by calling even(24), each
branch of even ends by calling evn, and the
functions evn and odd in each branch
either return a value or end with a call to the other. We refer to
these calls as tail calls because they are at the "tail end" of
the expression. These are very nicely behaved from an implementation
perspective: if the last thing an expression does is call another
function, then that expression should return whatever value that
function does, meaning that the function being called (the callee)
should use the same return address that the caller has on the
stack. Furthermore, all of the local variables of the caller on the
stack are no longer needed, so we can safely overwrite all of them
with the arguments to the callee and the callee can use the space for
its stack frame. For this reason, tail calls can be implemented
extremely efficiently as simple jmp instructions, after the
caller places the arguments where the callee expects them to be on the
stack.
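To build intuition for why a tail call is just an overwritten argument followed by a jump, here is a rough Rust rendering of the evn/odd functions from the earlier example. It is a sketch, not compiler output: the `State` enum and the loop are stand-ins for the labels and jmp instructions, and the single mutable variable n plays the role of the argument slot that each tail call overwrites.

```rust
/// Which function body we are "currently executing" (stand-in for a label).
#[derive(Clone, Copy)]
enum State {
    Evn,
    Odd,
}

/// Mirrors the source program: `even` tail-calls `evn` on the absolute
/// value of x, and `evn`/`odd` tail-call each other until n reaches 0.
pub fn even(x: i64) -> bool {
    let mut n = if x >= 0 { x } else { -x };
    let mut state = State::Evn;
    loop {
        match state {
            State::Evn => {
                if n == 0 {
                    return true;
                }
                n -= 1; // overwrite the argument in place
                state = State::Odd; // corresponds to `jmp odd`
            }
            State::Odd => {
                if n == 0 {
                    return false;
                }
                n -= 1; // overwrite the argument in place
                state = State::Evn; // corresponds to `jmp evn`
            }
        }
    }
}
```

However large the input, the loop uses a fixed amount of storage, which is the property that lets tail calls reuse the caller's stack frame.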
On the other hand, not every function call in a program is a tail call. For instance, consider this implementation of factorial:
def factorial(x):
  if x == 0:
    1
  else:
    x * factorial(x - 1)
in
factorial(6)
In this program, the call factorial(6) is a tail call,
because it’s the last thing the main expression does, but the
recursive call factorial(x - 1) is not a tail
call, because we have to do something else after it returns: multiply
the result by x. Non-tail calls will be compiled
differently, so we will cover them later. For today, we will only
support tail calls.
Regardless of how tail and non-tail calls are implemented, since they are compiled differently, we need to be able to precisely determine which calls are tail or not. To do this we generalize from just the calls to define when any sub-expression is in tail position:
The main expression of our program is in tail position.
The body of a function is in tail position.
If a let-binding is in tail position, then (a) its body is in tail position, but (b) the bindings themselves are not.
If a conditional is in tail position, then (a) its branches are in tail position, but (b) the condition itself is not.
The operands to an operator are never in tail position.
In the definitions below, comments mark each expression position as always, potentially, or never in tail position:
pub struct Prog<E, Ann> {
    pub funs: Vec<FunDecl<E, Ann>>,
    pub main: E, // always
    pub ann: Ann,
}
pub struct FunDecl<E, Ann> {
    pub name: String,
    pub parameters: Vec<String>,
    pub body: E, // always
    pub ann: Ann,
}
pub enum SeqExp<Ann> {
    Imm(ImmExp, Ann),                  // immediate: never
    Prim1(Prim1, ImmExp, Ann),         // operand: never
    Prim2(Prim2, ImmExp, ImmExp, Ann), // operands: never
    Let {
        var: String,
        bound_exp: Box<SeqExp<Ann>>, // never
        body: Box<SeqExp<Ann>>,      // potentially
        ann: Ann,
    },
    If {
        cond: ImmExp,          // never
        thn: Box<SeqExp<Ann>>, // potentially
        els: Box<SeqExp<Ann>>, // potentially
        ann: Ann,
    },
    Call(String, Vec<ImmExp>, Ann), // the call: potentially; its arguments: never
}
We can codify this, if we so choose, as a kind of tagging operation:
fn mark_tails<Ann>(e: &SeqExp<Ann>, is_tail: bool) -> SeqExp<bool> {
    match e {
        SeqExp::Imm(i, _) => SeqExp::Imm(i.clone(), is_tail),
        SeqExp::Prim1(op, i, _) => SeqExp::Prim1(*op, i.clone(), is_tail),
        ...
        SeqExp::Let {
            var,
            bound_exp,
            body,
            ..
        } => SeqExp::Let {
            var: var.clone(),
            // the bound expression is never in tail position
            bound_exp: Box::new(mark_tails(bound_exp, false)),
            // the body inherits the tail status of the whole let
            body: Box::new(mark_tails(body, is_tail)),
            ann: is_tail,
        },
        SeqExp::If { cond, thn, els, .. } => SeqExp::If {
            cond: cond.clone(),
            // both branches inherit the tail status of the whole conditional
            thn: Box::new(mark_tails(thn, is_tail)),
            els: Box::new(mark_tails(els, is_tail)),
            ann: is_tail,
        },
    }
}
In practice we probably don’t need to, and instead can just carry along a
boolean flag through our compile_with_env function that keeps track of our
tail-position status:
fn compile_with_env<'exp>(
    e: &'exp SeqExp<u32>,
    .., // other arguments
    is_tail: bool, // true when the expression is in tail position
) -> Vec<Instr> {
    match e {
        ...
        SeqExp::Let { var, bound_exp, body, ann } => {
            ...
            // the bound expression is never in tail position
            let bound_exp_is = compile_with_env(bound_exp, ..., false);
            ...
            // the body inherits the tail status of the whole let
            let body_is = compile_with_env(body, ..., is_tail);
        }
        SeqExp::Call(fun, args, ann) => {
            if is_tail {
                // generate a tail call
            } else {
                // generate a non-tail call
            }
        }
        ...
    }
}
Ok, now that we know for sure we have a function call in tail
position, how do we compile it? For a simple start, let’s say we are
compiling the following expression:
let x = 7 in
def f(a,b,c):
  if b:
    x * a
  else:
    c
in
let y = x * 2 in
let z = print(y) in
f(5, true, 13)
Following our scheme for storing local variables, x, y and z
will be placed at [RSP - 8 * 1], [RSP - 8 * 2] and [RSP - 8 * 3]. When we
call f(5,true,13), we want to transfer control to the
code where the body of f is implemented, but with the
arguments a,b,c having the correct values. So when we
call f, we need to remember that in the body of
f, we have access to 4 local variables:
x,a,b,c, which are placed in the first 4 stack slots.
So to compile the function call f(5,true,13), we simply
mov the arguments to the appropriate place on the stack and then
jmp to f:
mov [RSP - 8 * 2], 10 ;; a = 5
mov [RSP - 8 * 3], 0x7FF... ;; b = true
mov [RSP - 8 * 4], 26 ;; c = 13
jmp f ;; tail call f
Note here that a is at RSP - 8 * 2 because f expects x to be at RSP - 8 * 1.
This works for constants, but there is one tricky situation we should
be careful of more generally. Consider the following seemingly
innocuous change: change the call to f(5, true, y). Then
if we naïvely try to compile it the same way, something goes horribly
wrong:
mov [RSP - 8 * 2], 10 ;; a = 5
mov [RSP - 8 * 3], 0x7FF... ;; b = true
mov RAX, [RSP - 8 * 2] ;; load y
mov [RSP - 8 * 4], RAX ;; c = y
jmp f ;; tail call f
Do Now!
What went horribly wrong?
Exercise
Try to fix it.
Since a tail call re-uses the current stack frame in place, we
overwrote the value of y on the stack before we were
able to use it as an argument! The effect is that the call is
accidentally compiled as if it were f(5,true,5).
How can we avoid this happening? Well, the easiest way is for our
sequentialize function to always generate new
temporaries for the arguments. That way, whenever we load an
argument for the function call, it is stored in a stack slot
further from RSP than any slot that has been
overwritten before it. For instance, this program would instead be:
let x = 7 in
def f(a,b,c):
  if b:
    x * a
  else:
    c
in
let y = x * 2 in
let z = print(y) in
let a = 5 in
let b = true in
let c = y in
f(a, b, c)
in which case the call would be compiled as
mov RAX, [RSP - 8 * 4] ;; load a
mov [RSP - 8 * 2], RAX ;; a = a
mov RAX, [RSP - 8 * 5] ;; load b
mov [RSP - 8 * 3], RAX ;; b = b
mov RAX, [RSP - 8 * 6] ;; load c
mov [RSP - 8 * 4], RAX ;; c = c
jmp f ;; tail call f
A little wasteful to be sure, but the good news is that clearly we can
never run into the problem from before: all of these locals are stored
at addresses that have not been overwritten when they are used. In
this case, the variable a will be overwritten when we
store the value for c, but at that point it is OK
because we no longer need its value.
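The difference between the two schemes can be checked with a small simulation of the stack. This is a sketch under several assumptions: values are left untagged for readability, slot n of the `stack` slice stands for [RSP - 8 * n] (index 0 is unused), and `Val`, `Arg`, `naive_tail_call`, and `tail_call_via_temps` are hypothetical names rather than parts of the compiler.

```rust
/// Runtime values, left untagged for readability (an assumption of this sketch).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Val {
    Num(i64),
    Bool(bool),
    Junk, // an unused or uninitialized slot
}

/// A call argument: a constant, or a variable read from a stack slot.
#[derive(Debug, Clone, Copy)]
pub enum Arg {
    Const(Val),
    Slot(usize),
}

fn read(stack: &[Val], arg: Arg) -> Val {
    match arg {
        Arg::Const(v) => v,
        Arg::Slot(s) => stack[s],
    }
}

/// Naive scheme: write argument i straight into slot `base + 1 + i`, where
/// `base` is the number of slots the callee expects to be already filled.
/// A later argument may read a slot that an earlier write has clobbered.
pub fn naive_tail_call(stack: &mut [Val], base: usize, args: &[Arg]) {
    for (i, &arg) in args.iter().enumerate() {
        let v = read(stack, arg);
        stack[base + 1 + i] = v;
    }
}

/// Temporaries scheme: first copy every argument into a fresh slot just past
/// the caller's `frame_size` locals, then move the temporaries into place in
/// order. Assumes `frame_size >= base`, so no temporary is overwritten
/// before it has been read.
pub fn tail_call_via_temps(stack: &mut [Val], frame_size: usize, base: usize, args: &[Arg]) {
    for (i, &arg) in args.iter().enumerate() {
        let v = read(stack, arg);
        stack[frame_size + 1 + i] = v;
    }
    for i in 0..args.len() {
        let v = stack[frame_size + 1 + i];
        stack[base + 1 + i] = v;
    }
}
```

For the call f(5, true, y) with x, y, z in slots 1 to 3 and y holding 14, the naive version leaves 5 in the slot for c, while the version with temporaries leaves 14 there, as intended.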
Exercise
We could have avoided this problem by loading the arguments in a different order (i.e., loading the third argument first). Can you come up with an example program where this is not possible?
Exercise
Come up with an algorithm that minimizes the total number of extra temporaries needed to perform a tail call.