Continuation-Passing Style in Standard ML
14 Dec 2022
This post collects notes from Andrew W. Appel’s 1992 book, Compiling with Continuations [1].
A continuation is a function that expresses “what to do next” [2]. A more formal definition of a continuation is given by Ferrante and Allard (1996): it is a data structure which, in its simplest form, contains:
- A function (or pointer to a function),
- The arguments of this function,
- A pointer to the continuation that is the destination of the result(s) of applying the function to these arguments,
- An environment variable holding the current execution state.
Continuation-passing style (CPS) is a program notation with clean semantics that makes every aspect of control flow and data flow explicit. It breaks the program into a set of continuations. Consider the following program written in ML:
(* An ML program that computes the product of all primes
   less than or equal to a positive integer n. *)
fun prodprimes(n) =
  if n = 1
  then 1
  else
    if isprime(n)
    then n * prodprimes(n - 1) (* line 6 *)
    else prodprimes(n - 1) (* line 7 *)
A CPS version of the program looks like this:
(* Control points c, k, j, h are continuation functions.
   Data labels b, p, a, m, q, i are variables. *)
fun prodprimes(n, c) =
  if n = 1
  then c(1)
  else
    let
      fun k(b) =
        if b = true
        then
          let
            fun j(p) =
              let
                val a = n * p
              in
                c(a)
              end
            val m = n - 1
          in
            prodprimes(m, j)
          end
        else
          let (* line 21 *)
            fun h(q) = c(q)
            val i = n - 1
          in
            prodprimes(i, h)
          end (* line 26 *)
    in
      isprime(n, k)
    end
The names in the CPS version label control points and values that are left implicit in the original version. When the function isprime is called, it is passed a return address k, and it returns a Boolean value b. The first call to prodprimes (line 6 of the original) returns to a point j with an integer p, but the second call to prodprimes (line 7) returns to a point h with an integer q. The first computation of n - 1 is put in a temporary variable m, and the second one in i. In the CPS version, the function prodprimes is given a continuation c as one of its arguments, and when prodprimes has computed its result a, it continues by applying c to a. Note that h does nothing but pass its argument on to c, so we can perform tail-recursion elimination by rewriting lines 21-26 as:
let
  val i = n - 1
in
  prodprimes(i, c)
end
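As a quick sanity check, the direct and CPS versions can be run side by side. The sketch below is not from the book: isprime is a hypothetical trial-division test, isprimeK is a hypothetical CPS wrapper around it, and the CPS function is renamed prodprimesK so both versions can coexist. The temporaries m, a and i are inlined, and h is already replaced by c as described above.

```sml
(* Hypothetical primality test by trial division (an assumption, not from the book). *)
fun isprime n =
  let
    fun check d = d * d > n orelse (n mod d <> 0 andalso check (d + 1))
  in
    n >= 2 andalso check 2
  end

(* Hypothetical CPS wrapper: hand the Boolean result to the continuation k. *)
fun isprimeK (n, k) = k (isprime n)

(* Direct style, as in the original program. *)
fun prodprimes n =
  if n = 1 then 1
  else if isprime n then n * prodprimes (n - 1)
  else prodprimes (n - 1)

(* CPS: every call is a tail call, and results flow through continuations. *)
fun prodprimesK (n, c) =
  if n = 1 then c 1
  else
    let
      fun k b =
        if b
        then prodprimesK (n - 1, fn p => c (n * p))  (* plays the role of j *)
        else prodprimesK (n - 1, c)                  (* h eliminated, c passed on *)
    in
      isprimeK (n, k)
    end

val direct = prodprimes 10
(* The initial continuation is the identity: "nothing left to do". *)
val viaCps = prodprimesK (10, fn a => a)
```

Both values should be the product 2 * 3 * 5 * 7 of the primes up to 10.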
The CPS Datatype
The CPS language is meant to model the program executed by a von Neumann machine, which likes to do just one thing at a time, with all the arguments to an operation ready-at-hand in registers. Therefore, arguments to a function (or a primitive operator) are always atomic (variables or constants, but not subexpressions); a function application can never be an argument of another application. Each continuation expression takes zero or more atomic arguments, binds zero or more results, and continues via zero or more continuation expressions.
signature CPS = sig
eqtype var
(* All the different kinds of atomic arguments that can be provided to the CPS operators *)
datatype value = VAR of var
| LABEL of var
| INT of int
| REAL of string
| STRING of string
datatype accesspath = OFFp of int
| SELp of int * accesspath
datatype primop = * | + | - | div | ~
| ieql | ineq | < | <= | > | >= | rangechk
| ! | subscript | ordof
| := | unboxedassign | update | unboxedupdate | store
| makeref | makerefunboxed | alength | slength
| gethdlr | sethdlr
| boxed
| fadd | fsub | fdiv | fmul
| feql | fneq | fge | fgt | fle | flt
| rshift | lshift | orb | andb | xorb | notb
datatype cexp = RECORD of (value * accesspath) list * var * cexp
| SELECT of int * value * var * cexp
| OFFSET of int * value * var * cexp
| APP of value * value list
| FIX of (var * var list * cexp) list * cexp
| SWITCH of value * cexp list
| PRIMOP of primop * value list * var list * cexp list
end
The practical reason for expressing reals as strings is to make the CPS language independent of any particular machine representation. All function calls in CPS are “tail” calls; that is, they do not return to the calling function. This can be seen from the form of the APP continuation expression, which does not have any continuation expression as a subexpression.
Consider the following CPS ML code:
let
  fun f(x, k) = k(2 * x + 1)
  fun k1(i) =
    let
      fun k2(j) = r(i * j)
    in
      f(c + d, k2)
    end
in
  f(a + b, k1)
end
which is converted from the following ML code:
let
  fun f(x) = 2 * x + 1
in
  f(a + b) * f(c + d)
end
Here r is the “rest” of the computation, to which the original version was to have returned its answer. Translated into the CPS language, the code becomes:
FIX( [( f, [x, k], PRIMOP( *, [INT 2, VAR x], [u], [
PRIMOP( +, [VAR u, INT 1], [v], [
APP( VAR k, [VAR v] )] )] ) ),
( k1, [i], FIX( [( k2, [j], PRIMOP( *, [VAR i, VAR j], [w], [
APP( VAR r, [VAR w] )] ) )],
PRIMOP( +, [VAR c, VAR d], [m], [
APP( VAR f, [VAR m, VAR k2] )] ) ) )],
PRIMOP( +, [VAR a, VAR b], [n], [
APP( VAR f, [VAR n, VAR k1] )] ) )
Continuation functions are bound by ordinary FIX operators and passed as arguments to functions. The effect of evaluating FIX \((\overrightarrow{f}, E)\) is to define the functions in the list \(\overrightarrow{f}\) and then to evaluate \(E\), which (usually) calls one or more of the \(f_{i}\) by means of the APP operator, or passes one of the \(f_{i}\) as an argument to another function, or stores some of the \(f_{i}\) in a data structure.
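To see FIX and APP at work, here is a minimal sketch of an evaluator for a cut-down CPS language, run on the term above. It makes several simplifying assumptions that are not in the book: variables are strings, the only primops are the hypothetical constructors ADD and MUL, denotable values are DINT and DFUNC, there is no store, the final answer is an int, and the free variables a, b, c, d and r get arbitrary sample bindings.

```sml
type var = string
datatype value = VAR of var | INT of int
datatype primop = ADD | MUL
datatype cexp =
    APP of value * value list
  | FIX of (var * var list * cexp) list * cexp
  | PRIMOP of primop * value list * var list * cexp list

(* Denotable values; a function never returns to its caller, it yields the answer. *)
datatype dvalue = DINT of int | DFUNC of dvalue list -> int
exception Stuck

type env = var -> dvalue
fun bind (env: env, v: var, d: dvalue): env = fn w => if w = v then d else env w
fun bindn (env, v :: vl, d :: dl) = bindn (bind (env, v, d), vl, dl)
  | bindn (env, nil, nil) = env
  | bindn _ = raise Stuck

(* Arguments are atomic: a constant or a variable lookup, never a subexpression. *)
fun atom (env: env) (INT i) = DINT i
  | atom env (VAR v) = env v

fun eval (APP (f, args)) env =
      (case atom env f of
         DFUNC g => g (map (atom env) args)
       | _ => raise Stuck)
  | eval (PRIMOP (p, [x, y], [w], [e])) env =
      (case (atom env x, atom env y) of
         (DINT i, DINT j) =>
           eval e (bind (env, w, DINT (case p of ADD => i + j | MUL => i * j)))
       | _ => raise Stuck)
  | eval (PRIMOP _) env = raise Stuck
  | eval (FIX (fl, e)) env =
      let
        (* g r extends r with every function of this FIX; each body runs in
           g r again, which makes the definitions mutually recursive. *)
        fun h r (_, vl, b) = DFUNC (fn al => eval b (bindn (g r, vl, al)))
        and g r = bindn (r, map #1 fl, map (h r) fl)
      in
        eval e (g env)
      end

val term =
  FIX ([("f", ["x", "k"],
         PRIMOP (MUL, [INT 2, VAR "x"], ["u"],
           [PRIMOP (ADD, [VAR "u", INT 1], ["v"],
              [APP (VAR "k", [VAR "v"])])])),
        ("k1", ["i"],
         FIX ([("k2", ["j"],
                PRIMOP (MUL, [VAR "i", VAR "j"], ["w"],
                  [APP (VAR "r", [VAR "w"])]))],
              PRIMOP (ADD, [VAR "c", VAR "d"], ["m"],
                [APP (VAR "f", [VAR "m", VAR "k2"])])))],
       PRIMOP (ADD, [VAR "a", VAR "b"], ["n"],
         [APP (VAR "f", [VAR "n", VAR "k1"])]))

(* Sample bindings (assumptions): a = 1, b = 2, c = 3, d = 4, r = identity. *)
fun env0 "a" = DINT 1
  | env0 "b" = DINT 2
  | env0 "c" = DINT 3
  | env0 "d" = DINT 4
  | env0 "r" = DFUNC (fn [DINT x] => x | _ => raise Stuck)
  | env0 _ = raise Stuck

val result = eval term env0
```

With these bindings the source expression f(a + b) * f(c + d) is f(3) * f(7) = 7 * 15, so result should agree with that product.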
When one function \(f\) is statically nested inside another function \(g\), \(f\) may refer to variables bound in \(g\) (call them free variables of \(f\)). In programming languages with higher-order functions like ML, the inner function can even be called after the outer function has returned. The usual implementation technique uses a closure data structure: a record containing the free variables of the inner function as well as a pointer to its machine code. A pointer to this record is made available to the machine code while it executes so that the free variables are accessible. By putting the code-pointer at a fixed offset of the record, users of the function need not know the format of the record or even its size in order to call the function.
In “closure-passing style”, after a pass to gather free variable information, the converter traverses the CPS expression. At every FIX that binds an escaping function (some of its call sites are unknown because it is passed as an argument, stored into a data structure, or returned as a result of a function call), a RECORD is inserted to create an explicit closure record, and a new argument corresponding to the closure record is added to the argument list of each function. Free variables accessed within the body of the function are rewritten as SELECTs from this argument.
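The idea can be imitated at the source level. The following is only a typed approximation under stated assumptions, not the converter's actual output: the closure datatype, adderCode, adder and call are hypothetical names, the free variables are kept in an int list rather than a flat record, and the example function is fun adder n = fn x => x + n, whose inner function has the single free variable n.

```sml
(* A closure: the code pointer comes first, followed by the captured free
   variables. The code takes its own closure as an extra argument. *)
datatype closure = CLOSURE of (closure * int -> int) * int list

(* Body of the inner function: the free variable n is no longer captured
   implicitly, it is selected from the closure argument (cf. SELECT). *)
fun adderCode (CLOSURE (_, fvs), x) = x + List.nth (fvs, 0)

(* Building the closure corresponds to the inserted RECORD. *)
fun adder n = CLOSURE (adderCode, [n])

(* A caller only needs the code pointer at its fixed position; it does not
   need to know how many free variables follow. *)
fun call (c as CLOSURE (code, _), arg) = code (c, arg)

val add5 = adder 5
val twelve = call (add5, 7)
```

Because adderCode no longer refers to any variable bound outside its own parameter list, it could be lifted to top level, which is what makes the explicit closure record sufficient.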
The finite-register rule for the CPS is this:
- For compilation to a machine with \(k\) registers, no subexpression of the CPS may have more than \(k\) free variables.
- For compilation to a machine with \(k\) registers, no function of the CPS may have more than \(k\) formal parameters.
Semantics of the CPS
Operations that have side effects on memory are represented by means of a “store.” Each store value can be thought of as a mapping from locations (addresses) to denotable values.
functor CPSsemantics(
structure CPS: CPS
val minint: int
val maxint: int
val minreal: real
val maxreal: real
val string2real: string -> real
eqtype loc (* locations *)
val nextloc: loc -> loc
(* arbitrarily(a, b) evaluates either to a or to b *)
val arbitrarily: 'a * 'a -> 'a
type answer
(* denotable values *)
datatype dvalue = RECORD of dvalue list * int
| INT of int
| REAL of real
| FUNC of dvalue list -> (loc * (loc -> dvalue) * (loc -> int)) -> answer
| STRING of string
| BYTEARRAY of loc list
| ARRAY of loc list
| UARRAY of loc list
(* the "current exception handler" *)
val handler_ref : loc
(* exception for arithmetic overflow *)
val overflow_exn : dvalue
(* exception for dividing by zero *)
val div_exn : dvalue
) : sig
val eval: CPS.var list * CPS.cexp ->
dvalue list ->
(loc * (loc -> dvalue) * (loc -> int)) ->
answer
end =
struct (* beginning of the functor body *)
type store = loc * (loc -> dvalue) * (loc -> int)
fun fetch ((_, f, _): store) (l: loc) = f l
fun upd ((n, f, g): store, l: loc, v: dvalue) =
(n, fn i => if i = l then v else f i, g)
fun fetchi ((_, _, g): store) (l: loc) = g l
fun updi ((n, f, g): store, l: loc, v: int) =
(n, f, fn i => if i = l then v else g i)
exception Undefined
fun eq(RECORD(a, i), RECORD(b, j)) = arbitrarily((i = j) andalso eqlist(a, b), false)
| eq(INT i, INT j) = (i = j)
| eq(REAL a, REAL b) = arbitrarily((a = b), false)
| eq(STRING a, STRING b) = arbitrarily((a = b), false)
| eq(BYTEARRAY nil, BYTEARRAY nil) = true
| eq(BYTEARRAY(a::_), BYTEARRAY(b::_)) = (a = b)
| eq(ARRAY nil, ARRAY nil) = true
| eq(ARRAY(a::_), ARRAY(b::_)) = (a = b)
| eq(UARRAY nil, UARRAY nil) = true
| eq(UARRAY(a::_), UARRAY(b::_)) = (a = b)
| eq(FUNC a, FUNC b) = raise Undefined
| eq(INT i, _) = arbitrarily(false, i < 0 orelse i > 255)
| eq(_, INT i) = arbitrarily(false, i < 0 orelse i > 255)
| eq(_, _) = false
and eqlist(a::al, b::bl) = (eq(a, b) andalso eqlist(al, bl))
| eqlist(nil, nil) = true
fun do_raise exn s =
let
val FUNC f = fetch s handler_ref
in
f [exn] s
end
(* evaluates a number n, turns it into a dvalue list,
and hands it to its continuation argument c;
but if it overflows, raises an exception instead. *)
fun overflow(n: unit -> int, c: dvalue list -> store -> answer) =
if (n() >= minint andalso n() <= maxint)
handle Overflow => false
then c [INT(n())]
else do_raise overflow_exn
fun overflowr(n, c) =
if (n() >= minreal andalso n() <= maxreal)
handle Overflow => false
then c [REAL(n())]
else do_raise overflow_exn
fun evalprim(CPS.+: CPS.primop,
[INT i, INT j]: dvalue list,
[c]: (dvalue list -> store -> answer) list) = overflow(fn() => i + j, c)
| evalprim(CPS.-, [INT i, INT j], [c]) = overflow(fn() => i - j, c)
| evalprim(CPS.*, [INT i, INT j], [c]) = overflow(fn() => i * j, c)
| evalprim(CPS.div, [INT i, INT 0], [c]) = do_raise div_exn
| evalprim(CPS.div, [INT i, INT j], [c]) = overflow(fn() => i div j, c)
| evalprim(CPS.~, [INT i], [c]) = overflow(fn() => 0 - i, c)
| evalprim(CPS.<, [INT i, INT j], [t, f]) = if i < j then t[] else f[]
| evalprim(CPS.<=, [INT i, INT j], [t, f]) = if j < i then f[] else t[]
| evalprim(CPS.>, [INT i, INT j], [t, f]) = if j < i then t[] else f[]
| evalprim(CPS.>=, [INT i, INT j], [t, f]) = if i < j then f[] else t[]
| evalprim(CPS.ieql, [a, b], [t, f]) = if eq(a, b) then t[] else f[]
| evalprim(CPS.ineq, [a, b], [t, f]) = if eq(a, b) then f[] else t[]
| evalprim(CPS.rangechk, [INT i, INT j], [t, f]) =
if j < 0
then
if i < 0
then if i < j then t[] else f[]
else t[]
else
if i < 0
then f[]
else if i < j then t[]
else f[]
| evalprim(CPS.boxed, [INT i], [t, f]) =
if i < 0 orelse i > 255
then arbitrarily(t[], f[])
else f[]
| evalprim(CPS.boxed, [RECORD _], [t, f]) = t[]
| evalprim(CPS.boxed, [STRING _], [t, f]) = t[]
| evalprim(CPS.boxed, [ARRAY _], [t, f]) = t[]
| evalprim(CPS.boxed, [UARRAY _], [t, f]) = t[]
| evalprim(CPS.boxed, [BYTEARRAY _], [t, f]) = t[]
| evalprim(CPS.boxed, [FUNC _], [t, f]) = t[]
| evalprim(CPS.!, [a], [c]) = evalprim(CPS.subscript, [a, INT 0], [c])
| evalprim(CPS.subscript, [ARRAY a, INT n], [c]) =
(fn s => c [fetch s (nth(a, n))] s)
| evalprim(CPS.subscript, [UARRAY a, INT n], [c]) =
(fn s => c [INT(fetchi s (nth(a, n)))] s)
| evalprim(CPS.subscript, [RECORD(a, i), INT j], [c]) =
c [nth(a, i + j)]
| evalprim(CPS.ordof, [STRING a, INT i], [c]) =
c [INT(String.ordof(a, i))]
| evalprim(CPS.ordof, [BYTEARRAY a, INT i], [c]) =
(fn s => c [INT(fetchi s(nth(a, i)))] s)
| evalprim(CPS.:=, [a, v], [c]) = evalprim(CPS.update, [a, INT 0, v], [c])
| evalprim(CPS.update, [ARRAY a, INT n, v], [c]) =
(fn s => c [] (upd(s, nth(a, n), v)))
| evalprim(CPS.update, [UARRAY a, INT n, INT v], [c]) =
(fn s => c [] (updi(s, nth(a, n), v)))
| evalprim(CPS.unboxedassign, [a, v], [c]) = evalprim(CPS.unboxedupdate, [a, INT 0, v], [c])
| evalprim(CPS.unboxedupdate, [ARRAY a, INT n, INT v], [c]) =
(fn s => c [] (upd(s, nth(a, n), INT v)))
| evalprim(CPS.unboxedupdate, [UARRAY a, INT n, INT v], [c]) =
(fn s => c [] (updi(s, nth(a, n), v)))
| evalprim(CPS.store, [BYTEARRAY a, INT i, INT v], [c]) =
if v < 0 orelse v >= 256
then raise Undefined
else (fn s => c [] (updi(s, nth(a, i), v)))
| evalprim(CPS.makeref, [v], [c]) =
(fn (l, f, g) => c [ARRAY[l]] (upd((nextloc l, f, g), l, v)))
| evalprim(CPS.makerefunboxed, [INT v], [c]) =
(fn (l, f, g) => c [UARRAY[l]] (updi((nextloc l, f, g), l, v)))
| evalprim(CPS.alength, [ARRAY a], [c]) = c [INT(List.length a)]
| evalprim(CPS.alength, [UARRAY a], [c]) = c [INT(List.length a)]
| evalprim(CPS.slength, [BYTEARRAY a], [c]) = c[INT(List.length a)]
| evalprim(CPS.slength, [STRING a], [c]) = c [INT(String.size a)]
| evalprim(CPS.gethdlr, [], [c]) =
(fn s => c [fetch s handler_ref] s)
| evalprim(CPS.sethdlr, [h], [c]) =
(fn s => c [] (upd(s, handler_ref, h)))
| evalprim(CPS.fadd, [REAL a, REAL b], [c]) = overflowr(fn() => a + b, c)
| evalprim(CPS.fsub, [REAL a, REAL b], [c]) = overflowr(fn() => a - b, c)
| evalprim(CPS.fmul, [REAL a, REAL b], [c]) = overflowr(fn() => a * b, c)
| evalprim(CPS.fdiv, [REAL a, REAL 0.0], [c]) = do_raise div_exn
| evalprim(CPS.fdiv, [REAL a, REAL b], [c]) = overflowr(fn() => a / b, c)
| evalprim(CPS.feql, [REAL a, REAL b], [t, f]) = if a = b then t[] else f[]
| evalprim(CPS.fneq, [REAL a, REAL b], [t, f]) = if a = b then f[] else t[]
| evalprim(CPS.flt, [REAL i, REAL j], [t, f]) = if i < j then t[] else f[]
| evalprim(CPS.fle, [REAL i, REAL j], [t, f]) = if j < i then f[] else t[]
| evalprim(CPS.fgt, [REAL i, REAL j], [t, f]) = if j < i then t[] else f[]
| evalprim(CPS.fge, [REAL i, REAL j], [t, f]) = if i < j then f[] else t[]
(* environment is a mapping from CPS variables to denotable values *)
type env = CPS.var -> dvalue
fun V env (CPS.INT i) = INT i
| V env (CPS.REAL r) = REAL(string2real r)
| V env (CPS.STRING s) = STRING s
| V env (CPS.VAR v) = env v
| V env (CPS.LABEL v) = env v
fun bind(env: env, v: CPS.var, d) =
fn w => if v = w then d else env w
fun bindn(env, v::vl, d::dl) = bindn(bind(env, v, d), vl, dl)
| bindn(env, nil, nil) = env
fun F (x, CPS.OFFp 0) = x
| F (RECORD(l, i), CPS.OFFp j) = RECORD(l, i + j)
| F (RECORD(l, i), CPS.SELp(j, p)) = F(nth(l, i + j), p)
fun E (CPS.SELECT(i, v, w, e)) env =
let
val RECORD(l, j) = V env v
in
E e (bind(env, w, nth(l, i + j)))
end
| E (CPS.OFFSET(i, v, w, e)) env =
let
val RECORD(l, j) = V env v
in
E e (bind(env, w, RECORD(l, i + j)))
end
| E (CPS.APP(f, vl)) env =
let
val FUNC g = V env f
in
g (map (V env) vl)
end
| E (CPS.RECORD(vl, w, e)) env =
E e (bind(env, w, RECORD(map (fn (x, p) => F(V env x, p)) vl, 0)))
| E (CPS.SWITCH(v, el)) env =
let
val INT i = V env v
in
E (nth(el, i)) env
end
| E (CPS.PRIMOP(p, vl, wl, el)) env =
evalprim(p, map (V env) vl, map (fn e => fn al => E e (bindn(env, wl, al))) el)
| E (CPS.FIX(fl, e)) env =
let
fun h r1 (f, vl, b) =
FUNC(fn al => E b (bindn(g r1, vl, al)))
and
g r = bindn(r, map #1 fl, map (h r) fl)
in
E e (g env)
end
val env0 = fn x => raise Undefined
fun eval (vl, e) dl = E e (bindn(env0, vl, dl))
end
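The store-passing discipline used by fetch, upd and makeref can be tried in isolation. This is a simplified sketch under assumptions of its own: locations are plain ints, the store holds only integers (one map instead of the two in the functor), and empty, alloc, Unallocated and the sample values are hypothetical names introduced for illustration.

```sml
exception Unallocated

type loc = int
(* A store is the next unused location plus a mapping from locations to contents. *)
type store = loc * (loc -> int)

val empty: store = (0, fn _ => raise Unallocated)

fun fetch ((_, f): store) (l: loc) = f l

(* Updating never mutates: it returns a new store that shadows location l. *)
fun upd ((n, f): store, l: loc, v: int): store =
  (n, fn i => if i = l then v else f i)

(* Like makeref: take the next unused location, initialise it, and return
   it together with the extended store. *)
fun alloc ((n, f): store, v: int): loc * store =
  (n, upd ((n + 1, f), n, v))

val (l0, s1) = alloc (empty, 10)
val (l1, s2) = alloc (s1, 20)
val s3 = upd (s2, l0, 11)

val oldContents = fetch s2 l0  (* the earlier store is unchanged *)
val newContents = fetch s3 l0
val other = fetch s3 l1
```

Since every operation returns a fresh store, an earlier store such as s2 still describes memory as it was before the update, which is why the semantics threads the store explicitly through every continuation.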
Notes
1. The compiler design ideas are also documented in the article: Andrew W. Appel and Trevor Jim, “Continuation-Passing, Closure-Passing Style,” POPL’89: Proceedings of the 16th ACM SIGPLAN-SIGACT Symposium on Principles of Programming Languages, January 1989, Pages 293-302.
2. John C. Reynolds: “In the early history of continuations, basic concepts were independently discovered an extraordinary number of times. This was due less to poor communication among computer scientists than to the rich variety of settings in which continuations were found useful: They underlie a method of program transformation (into continuation-passing style), a style of definitional interpreter (defining one language by an interpreter written in another language), and a style of denotational semantics (in the sense of Scott and Strachey). In each of these settings, by representing ‘the meaning of the rest of the program’ as a function or procedure, continuations provide an elegant description of a variety of language constructs, including call by value and goto statements.”