We will start with a simple recursive function in Scheme that just adds up the first n positive integers:
(define sum
  (lambda (n)
    (if (= n 0)
        0
        (+ n (sum (- n 1))))))
(sum 5)
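Next, we rewrite sum in continuation-passing style (CPS). The function now takes an extra argument k, the continuation, which receives the result instead of the result being returned directly: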
(define sum-cps
  (lambda (n k)
    (if (= n 0)
        (k 0)
        (sum-cps (- n 1)
                 (lambda (value)
                   (k (+ n value)))))))
To run sum-cps, we need to provide the initial continuation (lambda (value) value):
(sum-cps 5 (lambda (value) value))
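For readers who think more readily in Python, the CPS definition can be sketched as follows. This is only an illustrative translation: the name sum_cps and the use of Python lambdas as continuations are assumptions of this sketch, and because it still recurses it remains subject to Python's recursion depth limit.

```python
def sum_cps(n, k):
    """Add up the integers 1..n and pass the total to the continuation k."""
    if n == 0:
        return k(0)
    # The new continuation closes over n and k, its two free variables.
    return sum_cps(n - 1, lambda value: k(n + value))


# The identity function plays the role of the initial continuation.
result = sum_cps(5, lambda value: value)
```

Passing a different initial continuation changes what happens to the final total, without touching sum_cps itself.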
For convenience, our top-level sum function will just call sum-cps with the initial continuation:
(define sum
  (lambda (n)
    (sum-cps n (lambda (value) value))))
(sum 5)
Notice that our continuations are represented as anonymous lambda functions. Here we are taking advantage of Scheme's ability to easily create and pass around higher-order functions. In the above code, there are two such continuation functions: (lambda (value) value) and (lambda (value) (k (+ n value))). In general, a lambda expression that represents a continuation may or may not contain free variables, depending on the context in which it arises. For example, the continuation (lambda (value) value) has no free variables, since value is bound by the lambda's formal parameter, while the continuation (lambda (value) (k (+ n value))) has two free variables: k and n.
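The same distinction between bound and free variables can be observed in Python, where a function's code object records the names it captures from an enclosing scope. The sketch below is illustrative only: make_continuations is a hypothetical helper, and it relies on CPython's `__code__.co_freevars` attribute.

```python
def make_continuations(n, k):
    """Build Python analogues of the two continuations discussed above."""
    identity = lambda value: value       # no free variables
    add_n = lambda value: k(n + value)   # free variables: k and n
    return identity, add_n


identity, add_n = make_continuations(5, lambda v: v)

# co_freevars lists the names a function captures from an enclosing scope.
free_of_identity = identity.__code__.co_freevars
free_of_add_n = add_n.__code__.co_freevars
```

These captured names are exactly the values that a data-structure representation of the continuation has to store explicitly.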
We now change the representation of continuations from Scheme functions to Scheme data structures, which will just be ordinary lists (also referred to as continuation records). Each lambda expression that represents a continuation will get replaced by a call of the form (make-cont <label> x y z ...), where <label> is a unique identifier for the continuation being replaced, and x y z ... are the free variables (if any) of the lambda expression. Here are the new definitions of sum and sum-cps, where the (lambda (value) value) continuation has been labeled <cont-1> and the (lambda (value) (k (+ n value))) continuation has been labeled <cont-2>:
(define sum
  (lambda (n)
    (sum-cps n (make-cont '<cont-1>))))

(define sum-cps
  (lambda (n k)
    (if (= n 0)
        (k 0)
        (sum-cps (- n 1) (make-cont '<cont-2> n k)))))
The make-cont function simply packages the <cont-N> label along with the values of any free variables x, y, z, etc. from the original lambda expression into a Scheme list of the form (continuation <cont-N> x y z ...).
(define make-cont
  (lambda args
    (cons 'continuation args)))
Here are some examples of continuation records:
(make-cont '<cont-1>)
(make-cont '<cont-2> 'free1 'free2)
(make-cont '<cont-42> 10 20 30 40)
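As a rough Python analogue, a continuation record can be modelled as a tuple whose first element is a tag. In this sketch the labels are plain strings and the explicit label parameter is a choice made here for readability; neither is taken from the Scheme code.

```python
def make_cont(label, *free_values):
    """Package a label and the saved free-variable values into one record."""
    return ("continuation", label) + free_values


record_1 = make_cont("<cont-1>")
record_2 = make_cont("<cont-2>", "free1", "free2")
record_42 = make_cont("<cont-42>", 10, 20, 30, 40)
```

Each record is ordinary data, so it can be stored, printed, or inspected like any other tuple.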
To make this code actually work, we need to define a new function apply-cont, which takes a continuation k represented in this new form (a Scheme list) along with a value, and does to value whatever the original functional representation of the continuation would have done, dispatching on the label stored in the record. Any free variables that are needed are retrieved from the fields of the continuation record.
We also need to change all applications of the form (k value) to (apply-cont k value), since k is no longer represented as a Scheme function. Here is the resulting code:
(define sum
  (lambda (n)
    (sum-cps n (make-cont '<cont-1>))))

(define sum-cps
  (lambda (n k)
    (if (= n 0)
        (apply-cont k 0)
        (sum-cps (- n 1) (make-cont '<cont-2> n k)))))

(define apply-cont
  (lambda (k value)
    ;; k is now a list of the form (continuation <label> x y z ...)
    (let ((label (cadr k))
          (fields (cddr k))) ;; saved values of free variables
      (cond
        ((eq? label '<cont-1>) value)
        ((eq? label '<cont-2>)
         (let ((n (car fields))
               (k (cadr fields)))
           (apply-cont k (+ n value))))
        (else (error "invalid continuation label"))))))
(sum 5)
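The label-dispatching apply-cont can be sketched in Python as well. This is a hedged illustration rather than a definitive translation: string labels, the tuple layout, and the name sum_to (chosen to avoid shadowing Python's built-in sum) are all assumptions of this sketch.

```python
def make_cont(label, *fields):
    return ("continuation", label) + fields


def apply_cont(k, value):
    """Dispatch on the record's label, as the cond expression does in Scheme."""
    label, fields = k[1], k[2:]
    if label == "<cont-1>":
        return value
    elif label == "<cont-2>":
        n, saved_k = fields  # saved values of the free variables n and k
        return apply_cont(saved_k, n + value)
    else:
        raise ValueError("invalid continuation label: %r" % (label,))


def sum_cps(n, k):
    if n == 0:
        return apply_cont(k, 0)
    return sum_cps(n - 1, make_cont("<cont-2>", n, k))


def sum_to(n):
    return sum_cps(n, make_cont("<cont-1>"))
```

Here sum_to(5) builds five nested <cont-2> records around the initial <cont-1> record and then unwinds them one at a time through apply_cont.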
This code works, but notice that whenever apply-cont receives a continuation k to apply to some value, it uses a cond to decide what to do, based on k's label. In this example, there are only two possible continuation labels, but larger programs could have many more. To make the code more efficient, we define separate functions corresponding to each label, and then simply apply the appropriate label function to value and the saved free variables. Instead of the labels being Scheme symbols, they are now Scheme functions that can be called directly.
(define apply-cont
  (lambda (k value)
    (let ((label (cadr k))
          (fields (cddr k)))
      (label value fields))))

(define <cont-1>
  (lambda (value fields)
    value))

(define <cont-2>
  (lambda (value fields)
    (let ((n (car fields))
          (k (cadr fields)))
      (apply-cont k (+ n value)))))

(define sum
  (lambda (n)
    (sum-cps n (make-cont <cont-1>)))) ;; removed ' from <cont-1>

(define sum-cps
  (lambda (n k)
    (if (= n 0)
        (apply-cont k 0)
        (sum-cps (- n 1) (make-cont <cont-2> n k))))) ;; removed ' from <cont-2>
(sum 5)
(make-cont <cont-1>)
(make-cont <cont-2> 'free1 'free2)
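In a Python sketch of this step, the label stored in each record is the function object itself, so apply_cont needs no chain of comparisons. The names cont_1, cont_2, and sum_to are illustrative assumptions of this sketch.

```python
def make_cont(label, *fields):
    return ("continuation", label) + fields


def apply_cont(k, value):
    """Call the label function stored in the record directly."""
    label, fields = k[1], k[2:]
    return label(value, fields)


def cont_1(value, fields):
    return value


def cont_2(value, fields):
    n, k = fields  # saved values of the free variables
    return apply_cont(k, n + value)


def sum_cps(n, k):
    if n == 0:
        return apply_cont(k, 0)
    return sum_cps(n - 1, make_cont(cont_2, n, k))


def sum_to(n):
    return sum_cps(n, make_cont(cont_1))
```

The cost of applying a continuation no longer depends on how many kinds of continuation the program defines.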
We have now changed the representation of continuations from Scheme functions to data structures, so our sum-cps function could in principle be implemented easily in languages that do not support higher-order functions. However, we are still implicitly relying on another feature of Scheme: it imposes no fixed limit on the depth of its recursion stack, which lets us call sum with arbitrarily large values of n.
Unfortunately, many other languages (including Python, Java, and C++) impose a maximum depth on the recursion stack, which will cause recursive programs to crash for sufficiently large input values. In order to implement our program in such a language, we need to further transform it to avoid growing the recursion stack. We can accomplish this by passing information to functions through a set of global registers rather than via function arguments. Each function call of the form (func arg1 arg2 arg3 ...) will be replaced by a call of the form (func), which in most other languages can be simulated by a simple "goto" instruction that does not grow the stack. The appropriate registers will be initialized to arg1, arg2, arg3, etc., prior to calling (func).
We first define a set of registers to be used for passing information to functions in place of formal parameters. The registers needed by a particular function are determined by that function's formal parameters. For example, here are the sum-cps function and the top-level sum function, which calls sum-cps with n and the initial <cont-1> continuation:
(define sum-cps
  (lambda (n k)
    (if (= n 0)
        (apply-cont k 0)
        (sum-cps (- n 1) (make-cont <cont-2> n k)))))

(define sum
  (lambda (n)
    (sum-cps n (make-cont <cont-1>))))
We will rewrite sum-cps as a function of no arguments, and create new registers n_reg and k_reg for passing the necessary information to sum-cps. In addition, wherever we call (sum-cps), we must first assign the appropriate values to the registers. We also must change all references to the formal parameters n and k in sum-cps to n_reg and k_reg, respectively:
(define n_reg 'undefined)
(define k_reg 'undefined)

(define sum-cps
  (lambda ()
    (if (= n_reg 0)
        (apply-cont k_reg 0)
        (begin
          ;; order of assignments matters!
          (set! k_reg (make-cont <cont-2> n_reg k_reg))
          (set! n_reg (- n_reg 1))
          (sum-cps)))))

(define sum
  (lambda (n)
    (set! k_reg (make-cont <cont-1>))
    (set! n_reg n)
    (sum-cps)))
(sum 5)
Notice that in the definition of sum-cps above, the order of the assignment statements matters. The n_reg assignment must happen after the k_reg assignment, otherwise the value of n_reg saved in the <cont-2> continuation record will be incorrect.
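The effect of the assignment order can be checked with a small Python experiment. The helpers build_records and saved_ns are hypothetical names introduced only for this sketch; they build the chain of continuation records in both orders and list the n values stored in each chain, outermost record first.

```python
INITIAL = ("continuation", "<cont-1>")


def build_records(n, decrement_first):
    """Simulate the register updates of sum-cps, in either assignment order."""
    n_reg, k_reg = n, INITIAL
    while n_reg != 0:
        if decrement_first:
            n_reg = n_reg - 1  # wrong: n_reg changes before it is saved
            k_reg = ("continuation", "<cont-2>", n_reg, k_reg)
        else:
            k_reg = ("continuation", "<cont-2>", n_reg, k_reg)
            n_reg = n_reg - 1
    return k_reg


def saved_ns(k):
    """Collect the n value saved in each <cont-2> record, outermost first."""
    saved = []
    while k[1] == "<cont-2>":
        saved.append(k[2])
        k = k[3]
    return saved


correct = saved_ns(build_records(3, decrement_first=False))  # [1, 2, 3]
wrong = saved_ns(build_records(3, decrement_first=True))     # [0, 1, 2]
```

With the correct order the saved values add up to 6, the expected sum for n = 3. With the reversed order every record holds a value that is one too small, so the total would come out as 3.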
We also need to transform apply-cont, <cont-1>, and <cont-2> in a similar way. Here are their current definitions:
(define apply-cont
  (lambda (k value)
    (let ((label (cadr k))
          (fields (cddr k)))
      (label value fields))))

(define <cont-1>
  (lambda (value fields)
    value))

(define <cont-2>
  (lambda (value fields)
    (let ((n (car fields))
          (k (cadr fields)))
      (apply-cont k (+ n value)))))
We will use the registers k_reg and value_reg to pass information to apply-cont. To pass information to the continuation label functions <cont-1> and <cont-2>, we will use the registers value_reg and fields_reg:
;; additional registers
(define value_reg 'undefined)
(define fields_reg 'undefined)

(define apply-cont
  (lambda ()
    (let ((label (cadr k_reg))
          (fields (cddr k_reg)))
      ;; set up value_reg and fields_reg before calling (label)
      (set! value_reg value_reg)
      (set! fields_reg fields)
      (label))))

(define <cont-1>
  (lambda ()
    value_reg))

(define <cont-2>
  (lambda ()
    (let ((n (car fields_reg))
          (k (cadr fields_reg)))
      ;; set up k_reg and value_reg before calling (apply-cont)
      (set! k_reg k)
      (set! value_reg (+ n value_reg))
      (apply-cont))))
We also need to rewrite the call to apply-cont that appears in the definition of sum-cps:
(define sum-cps
  (lambda ()
    (if (= n_reg 0)
        (begin
          ;; set up k_reg and value_reg before calling (apply-cont)
          (set! k_reg k_reg)
          (set! value_reg 0)
          (apply-cont))
        (begin
          (set! k_reg (make-cont <cont-2> n_reg k_reg))
          (set! n_reg (- n_reg 1))
          (sum-cps)))))
(sum 5)
Notice that the assignment statements (set! k_reg k_reg) in sum-cps and (set! value_reg value_reg) in apply-cont assign each register its own value and therefore have no effect, so we can simply remove them from the code:
(define sum-cps
  (lambda ()
    (if (= n_reg 0)
        (begin
          (set! value_reg 0)
          (apply-cont))
        (begin
          (set! k_reg (make-cont <cont-2> n_reg k_reg))
          (set! n_reg (- n_reg 1))
          (sum-cps)))))

(define apply-cont
  (lambda ()
    (let ((label (cadr k_reg))
          (fields (cddr k_reg)))
      (set! fields_reg fields)
      (label))))
(sum 5)
Although the above code works, it still generates an arbitrarily long chain of tail-recursive function calls of the form (sum-cps), (apply-cont), (label), etc., each of which essentially acts like a "goto" instruction. This is not a problem in Scheme, since no limit is imposed on the length of such call-chains. However, in other languages it could be a problem. Therefore we need to break the chain of function calls into single steps. This is accomplished through the use of a trampoline, which is essentially a while-loop that performs the computation one step at a time and avoids building up a chain of function calls.
The trampoline uses a special register called pc, which contains the next function to call on each loop cycle. Calling the function simply updates the registers appropriately for the next loop cycle. The pc register itself also gets updated on each cycle. This process continues until the pc register is set to #f (false), at which point the final result of the computation will be available in the register final_reg.
;; additional registers
(define pc 'undefined)
(define final_reg 'undefined)

;; equivalent to a while-loop
(define trampoline
  (lambda ()
    (if pc
        (begin
          (pc)
          (trampoline))
        final_reg)))
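To see the trampoline idea in isolation, here is a minimal Python sketch that drives a deliberately trivial one-function machine. The count_down step function, the count_reg register, and the run helper are hypothetical and exist only for this illustration. Each call to the step function performs one unit of work, records in pc what should run next, and returns immediately.

```python
# global registers for this toy machine
count_reg = None
pc = None
final_reg = None


def trampoline():
    """Repeatedly call whatever function pc holds until pc becomes falsy."""
    while pc:
        pc()
    return final_reg


def count_down():
    """One step: decrement count_reg, or stop the machine when it hits zero."""
    global count_reg, pc, final_reg
    if count_reg == 0:
        final_reg = "done"
        pc = None  # a falsy pc stops the loop
    else:
        count_reg = count_reg - 1
        pc = count_down  # schedule the next step instead of calling it


def run(n):
    global count_reg, pc
    count_reg = n
    pc = count_down
    return trampoline()


outcome = run(100000)
```

Because every step returns to the loop before the next one starts, the call stack never grows beyond a couple of frames, however many steps are taken.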
Instead of calling a function directly such as (sum-cps), we replace the function call with (set! pc sum-cps), which sets the pc register to the sum-cps function itself. The trampoline will then invoke it within the loop. All functions other than the trampoline simply execute if-statements and assignments, without ever calling another function directly. Here are the transformed versions of the other functions, showing the changes made to the code:
(define sum-cps
  (lambda ()
    (if (= n_reg 0)
        (begin
          (set! value_reg 0)
          (set! pc apply-cont)) ;; changed
        (begin
          (set! k_reg (make-cont <cont-2> n_reg k_reg))
          (set! n_reg (- n_reg 1))
          (set! pc sum-cps))))) ;; changed

(define apply-cont
  (lambda ()
    (let ((label (cadr k_reg))
          (fields (cddr k_reg)))
      (set! fields_reg fields)
      (set! pc label)))) ;; changed

(define <cont-1>
  (lambda ()
    (set! final_reg value_reg) ;; changed
    (set! pc #f))) ;; added

(define <cont-2>
  (lambda ()
    (let ((n (car fields_reg))
          (k (cadr fields_reg)))
      (set! k_reg k)
      (set! value_reg (+ n value_reg))
      (set! pc apply-cont)))) ;; changed
The top-level function sum initializes the registers and then starts the trampoline, which runs the computation to completion.
(define sum
  (lambda (n)
    (set! k_reg (make-cont <cont-1>))
    (set! n_reg n)
    (set! pc sum-cps)
    (trampoline)))
(sum 5)
The complete register machine code is given below:
;; global registers
(define n_reg 'undefined)
(define k_reg 'undefined)
(define value_reg 'undefined)
(define fields_reg 'undefined)
(define pc 'undefined)
(define final_reg 'undefined)

(define trampoline
  (lambda ()
    (if pc
        (begin
          (pc)
          (trampoline))
        final_reg)))

(define make-cont
  (lambda args
    (cons 'continuation args)))

(define apply-cont
  (lambda ()
    (let ((label (cadr k_reg))
          (fields (cddr k_reg)))
      (set! fields_reg fields)
      (set! pc label))))

(define <cont-1>
  (lambda ()
    (set! final_reg value_reg)
    (set! pc #f)))

(define <cont-2>
  (lambda ()
    (let ((n (car fields_reg))
          (k (cadr fields_reg)))
      (set! k_reg k)
      (set! value_reg (+ n value_reg))
      (set! pc apply-cont))))

(define sum-cps
  (lambda ()
    (if (= n_reg 0)
        (begin
          (set! value_reg 0)
          (set! pc apply-cont))
        (begin
          (set! k_reg (make-cont <cont-2> n_reg k_reg))
          (set! n_reg (- n_reg 1))
          (set! pc sum-cps)))))

;; top-level function
(define sum
  (lambda (n)
    (set! k_reg (make-cont <cont-1>))
    (set! n_reg n)
    (set! pc sum-cps)
    (trampoline)))
(sum 5)
(sum 5000)
If we tried to implement sum recursively in Python, it would fail with a RecursionError for values of n that exceed Python's recursion depth limit (1000 by default in CPython). For example:
%%python
def sum_recursive(n):
    if n == 0:
        return 0
    else:
        return n + sum_recursive(n - 1)
%%python
sum_recursive(5)
%%python
sum_recursive(5000)
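The failure can be observed without crashing the session by catching the exception. The sketch below assumes CPython, where the limit is reported by sys.getrecursionlimit(); try_sum is a hypothetical wrapper introduced only for this illustration.

```python
import sys


def sum_recursive(n):
    if n == 0:
        return 0
    return n + sum_recursive(n - 1)


def try_sum(n):
    """Return the sum, or None if Python's recursion limit is exceeded."""
    try:
        return sum_recursive(n)
    except RecursionError:
        return None


limit = sys.getrecursionlimit()
small = try_sum(5)              # well within the limit
too_deep = try_sum(limit + 100)  # needs more frames than the limit allows
```

Raising the limit with sys.setrecursionlimit only postpones the problem, which is why the register machine avoids deep recursion altogether.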
Fortunately, our Scheme register machine does not grow the recursion stack, and can be easily translated directly into Python. All we need to do is define Python versions of the car and cdr primitives, along with the cadr and cddr shorthands built from them:
%%python
def car(lst):
    return lst[0]

def cdr(lst):
    return lst[1:]

def cadr(lst):
    return car(cdr(lst))

def cddr(lst):
    return cdr(cdr(lst))
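As a quick check of how these helpers decode a continuation record, the following sketch repeats the four definitions so that it runs on its own, then applies them to a hand-built record. The record contents are made up for illustration.

```python
def car(lst):
    return lst[0]

def cdr(lst):
    return lst[1:]

def cadr(lst):
    return car(cdr(lst))

def cddr(lst):
    return cdr(cdr(lst))


# A hand-built record of the form (continuation <label> n k).
inner = ("continuation", "<cont-1>")
record = ("continuation", "<cont-2>", 5, inner)

label = cadr(record)   # second element: the label
fields = cddr(record)  # everything after the label: the saved free variables
```

On Python tuples, cadr(record) is equivalent to record[1] and cddr(record) to record[2:]. Slicing copies the remaining elements, which is harmless for records this small.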
Here is the register machine code translated into Python:
%%python
# global registers
n_reg = None
k_reg = None
value_reg = None
fields_reg = None
pc = None
final_reg = None

def trampoline():
    while pc:
        pc()
    return final_reg

def make_cont(*args):
    return ("continuation",) + args

def apply_cont():
    global fields_reg, pc
    label = cadr(k_reg)
    fields = cddr(k_reg)
    fields_reg = fields
    pc = label

def cont_1():
    global final_reg, pc
    final_reg = value_reg
    pc = False

def cont_2():
    global k_reg, value_reg, pc
    n = car(fields_reg)
    k = cadr(fields_reg)
    k_reg = k
    value_reg = n + value_reg
    pc = apply_cont

def sum_cps():
    global value_reg, pc, k_reg, n_reg
    if n_reg == 0:
        value_reg = 0
        pc = apply_cont
    else:
        k_reg = make_cont(cont_2, n_reg, k_reg)
        n_reg = n_reg - 1
        pc = sum_cps

# top-level function
def sum(n):
    global k_reg, n_reg, pc
    k_reg = make_cont(cont_1)
    n_reg = n
    pc = sum_cps
    return trampoline()
The Python version is no longer subject to the recursion stack depth limit. It is also much faster than the Scheme version:
%%python
sum(5)
%%python
sum(5000)
%%python
sum(100000)