Lecture 8: Local Function Definitions and Tail Calls
1 Growing the language: local function definitions, tail calls
1.1 Concrete syntax
1.2 Abstract syntax for Calls
1.3 Defining our own functions
1.4 Semantics
1.5 Code Generation: Function Definitions
1.6 Code Generation: Tail Calls

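So far we’ve developed a nice suite of basic features for our language: numbers and booleans, let-bound variables, unary and binary primitive operations, and conditionals.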

However, the computational power of our language is fundamentally limited: all of our programs use finite space and terminate in finite time. Over the next few weeks we will substantially increase the power of our snake languages by lifting both of these restrictions. We will use functions and looping constructs to lift the restriction to terminating programs, and we will use the stack and the heap to give us access to unbounded memory. Today we will start this journey towards Turing-completeness by adding function definitions and a restricted form of calls that can be compiled directly to jumps. This increases the power of our language to that of finite-state automata.¹

1 Growing the language: local function definitions, tail calls

Reminder: Every time we enhance our source language, we need to consider several things:

  1. Its impact on the concrete syntax of the language

  2. Examples using the new enhancements, so we build intuition for them

  3. Its impact on the abstract syntax and semantics of the language

  4. Any new or changed transformations needed to process the new forms

  5. Executable tests to confirm the enhancement works as intended

1.1 Concrete syntax

We’ll start with concrete syntax. A function call is a new form of expression that starts with a function name and takes zero or more comma-separated expressions as arguments.

‹expr›: ...
  | IDENTIFIER ( ‹exprs› )
  | IDENTIFIER ( )

‹exprs›: ‹expr›
  | ‹expr› , ‹exprs›

To account for function definitions we have a choice to make. The easiest to implement would be to simply have a sequence of top level function definitions followed by one main body expression. But we will do something more flexible: we will allow for local function definitions inside arbitrary expressions.

‹expr›: ...
  | ‹decls› in ‹expr›

‹decls›: ‹decl›
  | ‹decls› and ‹decl›

‹decl›: def IDENTIFIER ( ‹ids› ) : ‹expr›
  | def IDENTIFIER ( ) : ‹expr›

‹ids›: IDENTIFIER
  | IDENTIFIER , ‹ids›

Allowing for local function definitions like this is very convenient for programming. For instance, we can use recursive functions in lieu of looping constructs, like in this factorial example:

def fac(x):
  def loop(x, acc):
    if x == 0: acc
    else: loop(x - 1, acc * x)
  in
  loop(x, 1)
in
fac(10)

If we only had top-level function definitions, the programmer would have to lift loops (and nested loops) like this to the top-level, polluting the global namespace and making the program less clear overall.

Our new form also allows for mutually recursive function definitions. For instance, here is how we might implement a function that checks if the input is even or odd without having access to a modulo or div operator:

def even(x):
  def evn(n):
    if n == 0: true
    else: odd(n - 1)
  and
  def odd(n):
    if n == 0: false
    else: even(n - 1)
  in
  if x >= 0: evn(x)
  else: evn(-1 * x)
in
even(24)

For our examples, let’s design max, which takes two numeric arguments and returns the larger of the two.

def max(x, y):
  if x >= y: x
  else: y
in
max(17, 31)

should evaluate to 31.

1.2 Abstract syntax for Calls

First, let’s cover the calling side of the language.

Do Now!

What should the semantics of f(e1, e2, ..., en) be? How should we represent this in our Exp data definition? What knock-on effects does it have for the transformation passes of our compiler?

The first thing we might notice is that attempting to call an unknown function should be prohibited. This is analogous to the scope-checking we already do for variable names, and should be done at the same time. Indeed, we can generalize our scope-checking to a suite of well-formedness checks, which assert that the program we’re compiling is “put together right”. (Such static checks include static type-checking, which we are not yet doing; in fact, many popular languages today focus heavily on improving the precision and efficiency of their well-formedness checking as a way to improve programmer efficiency and program correctness.) Checking for undefined functions implies that we need something like an environment of known functions. We don’t yet know what that environment should contain, but at a minimum it needs to contain the names of the functions we support.

Do Now!

What other programmer mistakes should we try to catch with well-formedness checking? What new mistakes are possible with function calls?

What should happen if a programmer tries to call max(1) or max(1, 2, 3)? Certainly nothing good can happen at runtime if we allowed this to occur. Fortunately, we can track enough information to prevent this at well-formedness time, too. Our function environment should keep track of known function names and their arities. Then we can check every function call expression and see whether it contains the correct number of actual arguments.

We need more examples:

Source          Output
max(1)          Compile Error: expected 2 arguments, got 1
max(1, 2, 3)    Compile Error: expected 2 arguments, got 3
unknown(1, 2)   Compile Error: unknown function 'unknown'
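The following is a minimal sketch of the arity and unknown-function check. It assumes the function environment is a HashMap from function names to arities. The names FunEnv and check_call are hypothetical and do not come from any starter code. The sketch reproduces the messages in the table above and leaves the "Compile Error: " prefix to the caller.

```rust
use std::collections::HashMap;

// Hypothetical function environment: maps each known function name to its arity.
type FunEnv = HashMap<String, usize>;

// Check one call `name(arg1, ..., argN)` that has `num_args` actual arguments.
// On failure, returns the message without the "Compile Error: " prefix.
fn check_call(env: &FunEnv, name: &str, num_args: usize) -> Result<(), String> {
    match env.get(name) {
        None => Err(format!("unknown function '{}'", name)),
        Some(&arity) if arity != num_args => {
            Err(format!("expected {} arguments, got {}", arity, num_args))
        }
        Some(_) => Ok(()),
    }
}

fn main() {
    let mut env = FunEnv::new();
    env.insert("max".to_string(), 2);
    for &(name, n) in &[("max", 2), ("max", 1), ("max", 3), ("unknown", 2)] {
        match check_call(&env, name, n) {
            Ok(()) => println!("{}({} args): ok", name, n),
            Err(msg) => println!("Compile Error: {}", msg),
        }
    }
}
```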

To represent call expressions in our AST, we just need to keep track of the function name, the argument list, and any tag information:


enum Exp<Ann> {
  ...
  Call(String, Vec<Exp<Ann>>, Ann),
}

We need to consider how our expression evaluates, which in turn means considering how it should normalize into sequential form.

Do Now!

What are the design choices here?

Since Exp::Call expressions are compound, containing multiple subexpressions, they should probably normalize similarly to how we normalize Prim2 expressions: the arguments should all be made immediate.

pub enum SeqExp<Ann> {
  ...
  Call(String, Vec<ImmExp>, Ann),
}

We have at least two possible designs here, for how to normalize these expressions: we can choose a left-to-right or right-to-left evaluation order for the arguments. For consistency with infix operators, we’ll choose a left-to-right ordering.
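The Rust code below is one possible sketch of the left-to-right choice. It sequentializes a call by binding each argument, in order, to a fresh temporary. It uses cut-down stand-in ASTs rather than the full Exp and SeqExp definitions. The #tmpN naming scheme is an assumption: it relies on source identifiers never containing '#', so temporaries cannot clash with them. Binding even immediate arguments to temporaries is a deliberately conservative choice. It costs some stack space, but every call's arguments end up in fresh variables.

```rust
// Minimal stand-in ASTs (hypothetical; only what this sketch needs).
#[derive(Debug, Clone, PartialEq)]
enum ImmExp {
    Num(i64),
    Var(String),
}

#[derive(Debug, Clone)]
enum Exp {
    Num(i64),
    Var(String),
    Call(String, Vec<Exp>),
}

#[derive(Debug, Clone, PartialEq)]
enum SeqExp {
    Imm(ImmExp),
    Let { var: String, bound_exp: Box<SeqExp>, body: Box<SeqExp> },
    Call(String, Vec<ImmExp>),
}

// Generate a fresh temporary name (assumes source identifiers never contain '#').
fn fresh(counter: &mut u32) -> String {
    *counter += 1;
    format!("#tmp{}", *counter)
}

fn sequentialize(e: &Exp, counter: &mut u32) -> SeqExp {
    match e {
        Exp::Num(n) => SeqExp::Imm(ImmExp::Num(*n)),
        Exp::Var(x) => SeqExp::Imm(ImmExp::Var(x.clone())),
        Exp::Call(fun, args) => {
            // Sequentialize the arguments left to right, binding each one
            // (even an immediate) to its own fresh temporary.
            let mut bindings = Vec::new();
            let mut imms = Vec::new();
            for arg in args {
                let tmp = fresh(counter);
                bindings.push((tmp.clone(), sequentialize(arg, counter)));
                imms.push(ImmExp::Var(tmp));
            }
            // Wrap the call in lets from the inside out, so the first argument
            // ends up as the outermost (and therefore first-evaluated) binding.
            let mut result = SeqExp::Call(fun.clone(), imms);
            for (tmp, bound) in bindings.into_iter().rev() {
                result = SeqExp::Let {
                    var: tmp,
                    bound_exp: Box::new(bound),
                    body: Box::new(result),
                };
            }
            result
        }
    }
}

fn main() {
    let mut counter = 0;
    let call = Exp::Call("max".to_string(), vec![Exp::Num(17), Exp::Num(31)]);
    println!("{:?}", sequentialize(&call, &mut counter));
}
```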

Do Now!

What tiny example program, using only the expressions we have so far, would demonstrate the difference between these two orderings?

Do Now!

Extend sequentialization to handle Exp::Call.

1.3 Defining our own functions

Now that our programs include function definitions and a main expression, our AST representation grows to match:

pub struct FunDecl<E, Ann> {
    pub name: String,
    pub parameters: Vec<String>,
    pub body: E,
    pub ann: Ann,
}

pub struct Prog<E, Ann> {
    pub funs: Vec<FunDecl<E, Ann>>,
    pub main: E,
    pub ann: Ann,
}

Here we abstract over annotations as well as over the type of expressions. This allows us to instantiate E as Exp<Ann> or SeqExp<Ann>, encoding whether or not the expressions are in sequential form.
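To illustrate what abstracting over E buys us, here is a sketch of a hypothetical map_exprs helper. It converts a Prog<E, Ann> into a Prog<E2, Ann> by rewriting every function body and the main expression, keeping names, parameters and annotations unchanged. The one-variant Exp and SeqExp types are toy stand-ins. In a real compiler the closure would be a pass such as sequentialization.

```rust
// The two structs from above, with derives added for printing and comparison.
#[derive(Debug, PartialEq)]
pub struct FunDecl<E, Ann> {
    pub name: String,
    pub parameters: Vec<String>,
    pub body: E,
    pub ann: Ann,
}

#[derive(Debug, PartialEq)]
pub struct Prog<E, Ann> {
    pub funs: Vec<FunDecl<E, Ann>>,
    pub main: E,
    pub ann: Ann,
}

// Toy stand-ins for the two expression forms (hypothetical).
#[derive(Debug, PartialEq)]
pub enum Exp {
    Num(i64),
}

#[derive(Debug, PartialEq)]
pub enum SeqExp {
    Imm(i64),
}

// Hypothetical helper: apply `f` to every function body and to the main
// expression, changing the expression type from E to E2.
pub fn map_exprs<E, E2, Ann, F>(p: Prog<E, Ann>, mut f: F) -> Prog<E2, Ann>
where
    F: FnMut(E) -> E2,
{
    let funs: Vec<FunDecl<E2, Ann>> = p
        .funs
        .into_iter()
        .map(|d| FunDecl {
            name: d.name,
            parameters: d.parameters,
            body: f(d.body),
            ann: d.ann,
        })
        .collect();
    let main = f(p.main);
    Prog { funs, main, ann: p.ann }
}

fn main() {
    let surface: Prog<Exp, ()> = Prog {
        funs: vec![FunDecl {
            name: "max".to_string(),
            parameters: vec!["x".to_string(), "y".to_string()],
            body: Exp::Num(0),
            ann: (),
        }],
        main: Exp::Num(31),
        ann: (),
    };
    let seq: Prog<SeqExp, ()> = map_exprs(surface, |e: Exp| match e {
        Exp::Num(n) => SeqExp::Imm(n),
    });
    println!("{:?}", seq);
}
```

The same Prog and FunDecl definitions serve both stages. Only the closure knows how to turn one expression form into the other.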

1.4 Semantics

Do Now!

What new semantic concerns do we have with providing our own definitions?

As soon as we introduce a new form of definition into our language, we need to consider scoping concerns. One possibility is to declare that earlier definitions can be used by later ones, but not vice versa. This possibility is relatively easy to implement, but restrictive: it prevents us from having mutually-recursive functions. Fortunately, because all our functions are statically defined, supporting mutual recursion is not all that difficult; the only complication is getting the well-formedness checks to work out correctly.

Exercise

Do so.
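The key idea behind supporting mutual recursion can be sketched as follows. First add every function in a def ... and def ... group to the function environment, and only then check the bodies. The Decl summary (a name, its parameters, and the functions its body calls) and check_group are hypothetical simplifications. A real well-formedness pass would walk the body expressions instead. The example mirrors the evn/odd group from the even program, where even comes from the enclosing scope.

```rust
use std::collections::HashMap;

// Hypothetical summary of one declaration in a `def ... and def ...` group.
struct Decl {
    name: &'static str,
    params: Vec<&'static str>,
    calls: Vec<&'static str>, // functions called in the body
}

// Check one group of (possibly mutually recursive) declarations.
// Every name in the group is added to the environment *before* any body is
// checked, which is what lets an earlier function call a later one.
fn check_group(
    outer: &HashMap<String, usize>,
    group: &[Decl],
) -> Result<HashMap<String, usize>, String> {
    let mut env = outer.clone();
    for d in group {
        env.insert(d.name.to_string(), d.params.len());
    }
    for d in group {
        for callee in &d.calls {
            if !env.contains_key(*callee) {
                return Err(format!("unknown function '{}'", callee));
            }
        }
    }
    Ok(env)
}

fn main() {
    let mut outer = HashMap::new();
    outer.insert("even".to_string(), 1);
    let group = vec![
        Decl { name: "evn", params: vec!["n"], calls: vec!["odd"] },
        Decl { name: "odd", params: vec!["n"], calls: vec!["even"] },
    ];
    println!("{:?}", check_group(&outer, &group).map(|env| env.len()));
}
```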

Additionally, the bodies of function definitions need to consider scope as well.

def sum3(x, y, z):
  a + b + c
in
x + 5

This program refers to names that are not in scope: a, b and c are not in scope within sum3, and x is not in scope outside of it.

def f(x): x
and
def f(y): y
in
f(3)

Defining two functions with the same name in the same group should be problematic: which function is intended to be called?

def f(x, x): x
in
f(3, 4)

Having multiple arguments of the same name should be problematic: which argument should be returned here?
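Both of these mistakes come down to finding a repeated name in a list: the function names in one group, or the parameter names of one function. Below is a minimal sketch using a HashSet. The helper name first_duplicate is hypothetical.

```rust
use std::collections::HashSet;

// Return the first name that occurs more than once in `names`, if any.
// Works for the function names in one group and for the parameters of one function.
fn first_duplicate<'a>(names: &[&'a str]) -> Option<&'a str> {
    let mut seen = HashSet::new();
    for &name in names {
        // `insert` returns false when the name was already present.
        if !seen.insert(name) {
            return Some(name);
        }
    }
    None
}

fn main() {
    println!("{:?}", first_duplicate(&["f", "f"])); // duplicate function names
    println!("{:?}", first_duplicate(&["x", "x"])); // duplicate parameters
    println!("{:?}", first_duplicate(&["x", "y", "z"]));
}
```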

1.5 Code Generation: Function Definitions

How should we compile function definitions? Each function definition needs associated code that we can jump to in order to execute the function, so each function will produce a label followed by the code corresponding to its body. But we also need to consider the arguments of the function. The simplest thing to do is to extend our current treatment of local variables in let bindings: arguments will just be new local variables stored on the stack. For example, consider an expression like

let a = 1 in def f(b,c,d): let x = 2 in def g(y,z): ... in ... in ...

When we compile this code, we need to determine where all of the local variables go. For a first attempt we can extend our treatment of let-bound variables: a goes in the first free stack slot [RSP - 8 * 1], b goes in the next one [RSP - 8 * 2], and c, d, x, y, z go in slots 3, 4, 5, 6, 7 respectively.
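One way to realize this slot assignment is an environment that records the next free slot and is extended as we descend into nested scopes. This is a sketch with hypothetical names (Env, bind, operand). Because bind returns a new environment rather than mutating the old one, a sibling scope that starts from the same outer environment reuses the same slots.

```rust
use std::collections::HashMap;

// Hypothetical compile-time environment: variable name -> stack slot index.
// Slot i lives at [RSP - 8 * i]; `next` is the first free slot.
#[derive(Clone, Debug)]
struct Env {
    slots: HashMap<String, usize>,
    next: usize,
}

impl Env {
    fn new() -> Self {
        Env { slots: HashMap::new(), next: 1 }
    }

    // Return a new environment with `name` bound to the next free slot.
    // `self` is unchanged, so sibling scopes starting from it reuse slots.
    fn bind(&self, name: &str) -> Env {
        let mut env = self.clone();
        env.slots.insert(name.to_string(), env.next);
        env.next += 1;
        env
    }

    // The memory operand for a bound variable, e.g. "[RSP - 8 * 2]".
    fn operand(&self, name: &str) -> Option<String> {
        self.slots.get(name).map(|i| format!("[RSP - 8 * {}]", i))
    }
}

fn main() {
    // let a = 1 in def f(b, c, d): let x = 2 in def g(y, z): ...
    let env_a = Env::new().bind("a");
    let mut inner = env_a.clone();
    for v in &["b", "c", "d", "x", "y", "z"] {
        inner = inner.bind(v);
    }
    println!("z lives at {:?}", inner.operand("z"));
}
```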

Other than that, the code generation for function definitions is straightforward: each function definition

def f(x,y,z): e

can be simply compiled to a label followed by the code for the body

f:
    ;; code for e

We just need to make sure the labels we use for the function names are all unique, using the same techniques we use for the other labels we generate.
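Here is a minimal sketch of unique function labels. It assumes that, as with the other generated labels, each declaration carries a unique u32 tag after tagging. The fun_NAME_TAG format is a hypothetical choice: the source name keeps the assembly readable, and the tag guarantees uniqueness, for example when two different local functions are both named loop.

```rust
use std::collections::HashSet;

// Hypothetical labeling scheme: source name for readability, unique tag for uniqueness.
fn fun_label(name: &str, tag: u32) -> String {
    format!("fun_{}_{}", name, tag)
}

fn main() {
    // Two distinct local functions that happen to share the name `loop`.
    let decls: [(&str, u32); 3] = [("fac", 1), ("loop", 4), ("loop", 9)];
    let labels: Vec<String> = decls.iter().map(|&(name, tag)| fun_label(name, tag)).collect();
    let distinct: HashSet<&String> = labels.iter().collect();
    println!("{:?} (all distinct: {})", labels, distinct.len() == labels.len());
}
```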

1.6 Code Generation: Tail Calls

Now what about the function call? It turns out that in all of the example programs so far, the function calls have had a very special property: they are always the last thing that the caller does. For instance, in our implementation of even, the main expression ends by calling even(24), each branch of even ends by calling evn, and in each branch evn and odd either return a value or end with a call to another function (evn calls odd, and odd calls even). We refer to these calls as tail calls because they are at the "tail end" of the expression. These are very nicely behaved from an implementation perspective: if the last thing an expression does is call another function, then that expression should return whatever value that function does. This means that the function being called (the callee) should use the same return address that the caller has on the stack. Furthermore, all of the caller's local variables on the stack are no longer needed, so we can safely overwrite them with the arguments to the callee, and the callee can use that space for its stack frame. For this reason, tail calls can be implemented extremely efficiently as simple jmp instructions, after the caller places the arguments where the callee expects them to be on the stack.

On the other hand, not every function call in a program is a tail call, for instance consider this implementation of factorial:

def factorial(x):
  if x == 0: 1
  else: x * factorial(x - 1)
in
factorial(6)

In this function, the call factorial(6) is a tail call, because it’s the last thing the main expression does, but the recursive call factorial(x - 1) is not a tail call, because we have to do something else after it returns: multiply the result by x. Non-tail calls will be compiled differently, so we will cover them later. For today, we will only support tail calls.

Regardless of how tail and non-tail calls are implemented, since they are compiled differently, we need to be able to precisely determine which calls are tail or not. To do this we generalize from just the calls to define when any sub-expression is in tail position:

  1. The main expression of our program is in tail position.

  2. The body of a function is in tail position.

  3. If a let-binding is in tail position, then (a) its body is in tail position, but (b) the bindings themselves are not.

  4. If a conditional is in tail position, then (a) its branches are in tail position, but (b) the condition itself is not.

  5. The operands to an operator are never in tail position.

In the definitions below, a comment marks each position. Positions marked "always" are always in tail position. Positions marked "potentially" are in tail position whenever the enclosing expression is. Positions marked "never" are never in tail position:

pub struct Prog<E, Ann> {
    pub funs: Vec<FunDecl<E, Ann>>,
    pub main: E, // always
    pub ann: Ann,
}
pub struct FunDecl<E, Ann> {
    pub name: String,
    pub parameters: Vec<String>,
    pub body: E, // always
    pub ann: Ann,
}
pub enum SeqExp<Ann> {
    Imm(ImmExp, Ann),                  // operand: never
    Prim1(Prim1, ImmExp, Ann),         // operand: never
    Prim2(Prim2, ImmExp, ImmExp, Ann), // operands: never
    Let {
        var: String,
        bound_exp: Box<SeqExp<Ann>>, // never
        body: Box<SeqExp<Ann>>,      // potentially
        ann: Ann,
    },
    If {
        cond: ImmExp,          // never
        thn: Box<SeqExp<Ann>>, // potentially
        els: Box<SeqExp<Ann>>, // potentially
        ann: Ann,
    },
    Call(String, Vec<ImmExp>, Ann), // the call itself: potentially; arguments: never
}

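We can codify this, if we so choose, as a kind of tagging operation:

// Assumes Prim1 and Prim2 are Copy and ImmExp is Clone.
fn mark_tails<Ann>(e: &SeqExp<Ann>, is_tail: bool) -> SeqExp<bool> {
    match e {
        SeqExp::Imm(i, _) => SeqExp::Imm(i.clone(), is_tail),
        SeqExp::Prim1(op, i, _) => SeqExp::Prim1(*op, i.clone(), is_tail),
        SeqExp::Prim2(op, i1, i2, _) => {
            SeqExp::Prim2(*op, i1.clone(), i2.clone(), is_tail)
        }
        // The bound expression is never in tail position;
        // the body is in tail position exactly when the let is.
        SeqExp::Let {
            var,
            bound_exp,
            body,
            ..
        } => SeqExp::Let {
            var: var.clone(),
            bound_exp: Box::new(mark_tails(bound_exp, false)),
            body: Box::new(mark_tails(body, is_tail)),
            ann: is_tail,
        },
        // The condition is an immediate; both branches inherit is_tail.
        SeqExp::If { cond, thn, els, .. } => SeqExp::If {
            cond: cond.clone(),
            thn: Box::new(mark_tails(thn, is_tail)),
            els: Box::new(mark_tails(els, is_tail)),
            ann: is_tail,
        },
        // A call is a tail call exactly when it is in tail position.
        SeqExp::Call(fun, args, _) => SeqExp::Call(fun.clone(), args.clone(), is_tail),
    }
}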

In practice we probably don’t need to, and instead can just carry along a boolean flag through our compile_with_env function that keeps track of our tail-position status:

fn compile_with_env<'exp>(
    e: &'exp SeqExp<u32>,
    .., // other arguments
    is_tail: bool) // true when the expression is in tail position
    -> Vec<Instr> {
    match e {
      ...
      SeqExp::Let { var, bound_exp, body, ann } => {
        ...
        let bound_exp_is = compile_with_env(bound_exp, ..., false);
        ...
        let body_is = compile_with_env(body, ..., is_tail);
      }
      SeqExp::Call(fun, args, ann) => {
        if is_tail {
          // generate a tail call
        } else {
          // generate a non-tail call
        }
      }
      ...
    }
}

Ok, now that we know for sure we have a function call in tail position, how do we compile it? For a simple start, let’s say we are compiling the following expression:

let x = 7 in
def f(a, b, c):
  if b: x * a
  else: c
in
let y = x * 2 in
let z = print(y) in
f(5, true, 13)

Following our scheme for storing local variables, x, y and z will be placed at [RSP - 8 * 1], [RSP - 8 * 2] and [RSP - 8 * 3]. When we call f(5, true, 13), we want to transfer control to the code where the body of f is implemented, but with the arguments a, b, c having the correct values. So when we call f, we need to remember that in the body of f we have access to 4 local variables, x, a, b and c, which are placed in the first 4 stack slots. So to compile the function call f(5, true, 13), we simply mov the arguments to the appropriate places on the stack and then jmp to f:

mov [RSP - 8 * 2], 10       ;; a = 5
mov [RSP - 8 * 3], 0x7FF... ;; b = true
mov [RSP - 8 * 4], 26       ;; c = 13
jmp f                       ;; tail call f
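The naive scheme can be sketched as a small instruction emitter. Argument i is stored into slot first_param_slot + i, left to right, and then we jmp. The Arg type and naive_tail_call are hypothetical, and constants are assumed to be already encoded (so 5 becomes 10). The output strings follow the schematic assembly notation of these notes rather than exact NASM syntax. The example uses a hypothetical two-argument function g, because the encoding of true is elided above.

```rust
// Hypothetical description of where an argument's value comes from:
// an already-encoded constant, or one of the caller's stack slots.
#[derive(Clone, Copy)]
enum Arg {
    Const(i64),
    Slot(usize),
}

// Naive tail call: store argument i into slot `first_param_slot + i`,
// left to right, then jump to the callee's label.
fn naive_tail_call(label: &str, args: &[Arg], first_param_slot: usize) -> Vec<String> {
    let mut instrs = Vec::new();
    for (i, arg) in args.iter().enumerate() {
        let dst = first_param_slot + i;
        match arg {
            Arg::Const(v) => instrs.push(format!("mov [RSP - 8 * {}], {}", dst, v)),
            Arg::Slot(src) => {
                // Memory-to-memory moves go through RAX.
                instrs.push(format!("mov RAX, [RSP - 8 * {}]", src));
                instrs.push(format!("mov [RSP - 8 * {}], RAX", dst));
            }
        }
    }
    instrs.push(format!("jmp {}", label));
    instrs
}

fn main() {
    // Hypothetical call g(5, 13), with g's parameters starting at slot 2.
    for line in naive_tail_call("g", &[Arg::Const(10), Arg::Const(26)], 2) {
        println!("{}", line);
    }
}
```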

Note here that a is at RSP - 8 * 2 because f expects x to be at RSP - 8 * 1. This works for constants, but more generally there is one tricky situation we need to be careful about. Consider a seemingly innocuous change: replacing the call with f(5, true, y). If we naïvely compile it the same way, something horribly wrong happens:

mov [RSP - 8 * 2], 10       ;; a = 5
mov [RSP - 8 * 3], 0x7FF... ;; b = true
mov RAX, [RSP - 8 * 2]      ;; load y
mov [RSP - 8 * 4], RAX      ;; c = y
jmp f                       ;; tail call f

Do Now!

What went horribly wrong?

Exercise

Try to fix it.

Since a tail call re-uses the current stack frame in place, we overwrote the value of y on the stack before we were able to use it as an argument! The effect is that the call is accidentally compiled as if it were f(5,true,5).
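We can make this failure mode precise with a small checker. Under the naive left-to-right scheme, argument i is broken exactly when it is read from a slot in the half-open range [first_param_slot, first_param_slot + i), because earlier argument stores have already overwritten those slots. This is a sketch with a hypothetical Arg type. Constants carry no value here, since only where an argument is read from matters. The first check corresponds to f(5, true, y) with y in slot 2. The second corresponds to passing every argument from temporaries in slots 4, 5 and 6.

```rust
// Hypothetical argument description: a constant, or a read from a caller slot.
#[derive(Clone, Copy)]
enum Arg {
    Const,
    Slot(usize),
}

// Indices of arguments that the naive left-to-right scheme would read
// from a slot that an earlier argument store has already overwritten.
fn clobbered_args(args: &[Arg], first_param_slot: usize) -> Vec<usize> {
    let mut clobbered = Vec::new();
    for (i, arg) in args.iter().enumerate() {
        if let Arg::Slot(src) = *arg {
            // Arguments 0..i were already stored into slots
            // first_param_slot .. first_param_slot + i (half-open).
            if src >= first_param_slot && src < first_param_slot + i {
                clobbered.push(i);
            }
        }
    }
    clobbered
}

fn main() {
    // f(5, true, y): y lives in slot 2, which the store for `a` overwrites.
    println!("{:?}", clobbered_args(&[Arg::Const, Arg::Const, Arg::Slot(2)], 2));
    // Arguments passed from temporaries in slots 4, 5, 6.
    println!("{:?}", clobbered_args(&[Arg::Slot(4), Arg::Slot(5), Arg::Slot(6)], 2));
}
```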

How can we avoid this? The easiest way is for our sequentialize function to always generate new temporaries for the arguments. That way, whenever we load an argument for the function call, it will always be stored in a stack slot beyond those that have already been overwritten (further from RSP, at a lower address, since the stack grows downward). For instance, this program would instead be:

let x = 7 in
def f(a, b, c):
  if b: x * a
  else: c
in
let y = x * 2 in
let z = print(y) in
let a = 5 in
let b = true in
let c = y in
f(a, b, c)

in which case the call would be compiled as

mov RAX, [RSP - 8 * 4]      ;; load a
mov [RSP - 8 * 2], RAX      ;; a = a
mov RAX, [RSP - 8 * 5]      ;; load b
mov [RSP - 8 * 3], RAX      ;; b = b
mov RAX, [RSP - 8 * 6]      ;; load c
mov [RSP - 8 * 4], RAX      ;; c = c
jmp f                       ;; tail call f

A little wasteful to be sure, but the good news is that we clearly can never run into the problem from before: each of these temporaries is loaded before its slot could have been overwritten. In this case, the temporary a (in slot 4) is overwritten when we store the value for c, but by that point we have already loaded it and no longer need its value.

Exercise

We could have avoided this problem by loading the arguments in a different order (i.e., loading the third argument first). Can you come up with an example program where this is not possible?

Exercise

Come up with an algorithm that minimizes the total number of extra temporaries needed to perform a tail call.

¹ Note that technically, to implement finite-state automata we would need something like a readBool() or readInt() built-in that reads the next character from the input string, so that we could take arbitrarily large strings as input.