Continuation-Passing Style in Standard ML

This post collects notes from Andrew W. Appel’s 1992 book, Compiling with Continuations [1].

A continuation is a function that expresses “what to do next” [2]. A more formal definition of a continuation is given by Ferrante and Allard (1996): it is a data structure which, in its simplest form, contains:

Continuation-passing style (CPS) is a program notation with clean semantics that makes every aspect of control flow and data flow explicit. It breaks the program into a set of continuations. Consider the following program written in ML:

(* An ML program that computes the product of all primes
   less than or equal to a positive integer n. *)
fun prodprimes(n) =
  if n = 1
  then 1
  else
    if isprime(n)
    then n * prodprimes(n - 1)  (* line 6 *)
    else prodprimes(n - 1)      (* line 7 *)

A CPS version of the program looks like this:

(* Control points c, k, j, h are continuation functions.
   Data labels b, p, a, m, q, i are variables. *)
fun prodprimes(n, c) =
  if n = 1
  then c(1)
  else
    let
      fun k(b) =
        if b = true
        then
          let
            fun j(p) =
              let
                val a = n * p
              in
                c(a)
              end
            val m = n - 1
          in
            prodprimes(m, j)
          end
        else
          let (* line 21 *)
            fun h(q) = c(q)
            val i = n - 1
          in
            prodprimes(i, h)
          end (* line 26 *)
    in
      isprime(n, k)
    end

The CPS version names every control point and every intermediate value of the original program. When the function isprime is called, it is passed a return address k, and it returns a Boolean value b. The first call to prodprimes (line 6 of the original) returns to a point j with an integer p, while the second call (line 7) returns to a point h with an integer q. The first computation of n - 1 is put in a temporary variable m, and the second in i. In the CPS version, prodprimes is given a continuation c as one of its arguments, and when it has computed its result a, it continues by applying c to a. Note that h does nothing but pass its argument on to c, so we can perform tail-recursion elimination by rewriting lines 21-26 as:

let
  val i = n - 1
in
  prodprimes(i, c)
end
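To try the transformed program, the following is a minimal runnable sketch. The book's listing leaves isprime undefined, so the trial-division isprime below is a hypothetical helper written only for this example, as is prodprimesDirect, which is used to compare against the direct-style program. The else branch already uses the tail-call form shown above.

```sml
(* Hypothetical helper, not from the book: trial division in CPS.
   The Boolean result is handed to the continuation k instead of
   being returned to the caller. *)
fun isprime (n, k) =
  let
    fun loop d =
      if d * d > n then k true
      else if n mod d = 0 then k false
      else loop (d + 1)
  in
    if n < 2 then k false else loop 2
  end

(* CPS prodprimes with the h continuation eliminated. *)
fun prodprimes (n, c) =
  if n = 1
  then c 1
  else
    let
      fun k b =
        if b
        then
          let
            fun j p = c (n * p)
          in
            prodprimes (n - 1, j)
          end
        else prodprimes (n - 1, c)   (* tail call: pass c along unchanged *)
    in
      isprime (n, k)
    end

(* Direct-style version for comparison; it reuses the CPS isprime by
   passing the identity function as the continuation. *)
fun prodprimesDirect n =
  if n = 1 then 1
  else if isprime (n, fn b => b) then n * prodprimesDirect (n - 1)
  else prodprimesDirect (n - 1)

(* The identity continuation plays the role of "return to top level". *)
val cpsResult = prodprimes (10, fn x => x)      (* 2 * 3 * 5 * 7 = 210 *)
val directResult = prodprimesDirect 10          (* 210 *)
```

Passing fn x => x as the initial continuation is only one choice; any function of type int -> 'a works, which is what makes the "rest of the computation" explicit.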

The CPS Datatype

The CPS language is meant to model the program executed by a von Neumann machine, which likes to do just one thing at a time, with all the arguments to an operation ready-at-hand in registers. Therefore, arguments to a function (or a primitive operator) are always atomic (variables or constants, but not subexpressions); a function application can never be an argument of another application. Each continuation expression takes zero or more atomic arguments, binds zero or more results, and continues via zero or more continuation expressions.

signature CPS = sig

eqtype var

(* All the different kinds of atomic arguments that can be provided to the CPS operators *)
datatype value = VAR of var
               | LABEL of var
               | INT of int
               | REAL of string
               | STRING of string

datatype accesspath = OFFp of int
                    | SELp of int * accesspath

datatype primop = * | + | - | div | ~
                | ieql | ineq | < | <= | > | >= | rangechk
                | ! | subscript | ordof
                | := | unboxedassign | update | unboxedupdate | store
                | makeref | makerefunboxed | alength | slength
                | gethdlr | sethdlr
                | boxed
                | fadd | fsub | fdiv | fmul
                | feql | fneq | fge | fgt | fle | flt
                | rshift | lshift | orb | andb | xorb | notb

datatype cexp = RECORD of (value * accesspath) list * var * cexp
              | SELECT of int * value * var * cexp
              | OFFSET of int * value * var * cexp
              | APP of value * value list
              | FIX of (var * var list * cexp) list * cexp
              | SWITCH of value * cexp list
              | PRIMOP of primop * value list * var list * cexp list

end

The practical reason for expressing reals as strings is to make the CPS language independent of any particular machine representation. All function calls in CPS are “tail” calls; that is, they do not return to the calling function. This can be seen from the form of the APP continuation expression, which does not have any continuation expression as a subexpression.

Consider the following CPS ML code:

let
  fun f(x, k) = k(2 * x + 1)
  fun k1(i) = 
    let
      fun k2(j) = r(i * j)
    in
      f(c + d, k2)
    end
in
  f(a + b, k1)
end

which is converted from the following ML code:

let
  fun f(x) = 2 * x + 1
in
  f(a + b) * f(c + d)
end

r is the “rest” of the computation, to which the original version was to have returned its answer. Translated into the CPS language, the code becomes:

FIX( [( f, [x, k], PRIMOP( *, [INT 2, VAR x], [u], [
                     PRIMOP( +, [VAR u, INT 1], [v], [
                       APP( VAR k, [VAR v] )] )] ) ),
      ( k1, [i], FIX( [( k2, [j], PRIMOP( *, [VAR i, VAR j], [w], [
                                    APP( VAR r, [VAR w] )] ) )],
                      PRIMOP( +, [VAR c, VAR d], [m], [
                        APP( VAR f, [VAR m, VAR k2] )] ) ) )],
      PRIMOP( +, [VAR a, VAR b], [n], [
        APP( VAR f, [VAR n, VAR k1] )] ) )
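The term above can be written as a value in ordinary Standard ML once a few assumptions are made. In the sketch below, var is taken to be string, the primops * and + are renamed MUL and ADD (hypothetical names chosen so the declaration compiles), and only the three cexp constructors used by the term are kept. The function fv is not from the book; it computes the free variables of a term, which shows what the term expects from its surroundings.

```sml
(* Cut-down CPS datatype, assumed for illustration only. *)
type var = string
datatype value = VAR of var | INT of int
datatype primop = MUL | ADD
datatype cexp
  = APP of value * value list
  | FIX of (var * var list * cexp) list * cexp
  | PRIMOP of primop * value list * var list * cexp list

val term =
  FIX ([("f", ["x", "k"],
         PRIMOP (MUL, [INT 2, VAR "x"], ["u"],
           [PRIMOP (ADD, [VAR "u", INT 1], ["v"],
              [APP (VAR "k", [VAR "v"])])])),
        ("k1", ["i"],
         FIX ([("k2", ["j"],
                PRIMOP (MUL, [VAR "i", VAR "j"], ["w"],
                  [APP (VAR "r", [VAR "w"])]))],
              PRIMOP (ADD, [VAR "c", VAR "d"], ["m"],
                [APP (VAR "f", [VAR "m", VAR "k2"])])))],
       PRIMOP (ADD, [VAR "a", VAR "b"], ["n"],
         [APP (VAR "f", [VAR "n", VAR "k1"])]))

(* Sets of variables represented as duplicate-free lists. *)
fun member (x: var, xs) = List.exists (fn y => x = y) xs
fun union (xs, ys) =
  foldr (fn (x, acc) => if member (x, acc) then acc else x :: acc) ys xs
fun minus (xs, ys) = List.filter (fn x => not (member (x, ys))) xs

(* Variables mentioned in a list of atomic arguments. *)
fun fvValues vs =
  foldr (fn (VAR v, acc) => union ([v], acc) | (INT _, acc) => acc) [] vs

(* Free variables of a continuation expression. *)
fun fv (APP (f, args)) = fvValues (f :: args)
  | fv (PRIMOP (_, args, results, conts)) =
      union (fvValues args,
             minus (foldr (fn (e, acc) => union (fv e, acc)) [] conts,
                    results))
  | fv (FIX (defs, body)) =
      let
        val names = map (fn (name, _, _) => name) defs
        fun fvDef (_, params, b) = minus (fv b, params)
        val inner = foldr (fn (d, acc) => union (fvDef d, acc)) (fv body) defs
      in
        minus (inner, names)   (* FIX-bound names are visible in all bodies *)
      end

val free = fv term   (* contains exactly a, b, c, d and r *)
```

The result confirms that the only names the term takes from its context are the four inputs and the continuation r.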

Continuation functions are bound by ordinary FIX operators and passed as arguments to functions. The effect of evaluating FIX\((\overrightarrow{f}, E)\) is to define the functions in the list \(\overrightarrow{f}\) and then to evaluate \(E\), which (usually) calls one or more of the \(f_{i}\) by means of the APP operator, or passes one of the \(f_{i}\) as an argument to another function, or stores some of the \(f_{i}\) in a data structure.

When one function \(f\) is statically nested inside another function \(g\), \(f\) may refer to variables bound in \(g\) (call them free variables of \(f\)). In programming languages with higher-order functions like ML, the inner function can even be called after the outer function has returned. The usual implementation technique uses a closure data structure: a record containing the free variables of the inner function as well as a pointer to its machine code. A pointer to this record is made available to the machine code while it executes so that the free variables are accessible. By putting the code-pointer at a fixed offset of the record, users of the function need not know the format of the record or even its size in order to call the function.

In “closure-passing style”, after a pass to gather free variable information, the converter traverses the CPS expression. At every FIX that binds an escaping function (some of its call sites are unknown because it is passed as an argument, stored into a data structure, or returned as a result of a function call), a RECORD is inserted to create an explicit closure record, and a new argument corresponding to the closure record is added to the argument list of each function. Free variables accessed within the body of the function are rewritten as SELECTs from this argument.
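The closure idea can be imitated by hand in ML itself. The sketch below is a simplification under stated assumptions: every closure has the same ML type (code from int to int, captured values in an int list), the names closure, call, adderCode and affineCode are hypothetical, and List.nth stands in for SELECT. A real converter works on CPS records rather than on an ML datatype.

```sml
(* A closure: the code pointer sits in a fixed position (first component),
   followed by the values of the free variables. The code receives the
   closure itself as an extra argument. *)
datatype closure = CLOSURE of (closure * int -> int) * int list

(* Calling a closure only needs the code pointer; the caller never looks
   at how many values were captured or what they mean. *)
fun call (cl as CLOSURE (code, _), arg) = code (cl, arg)

(* Source program:  fun adder n = fn x => x + n
   The inner function has one free variable, n, stored in slot 0. *)
fun adderCode (CLOSURE (_, env), x) = x + List.nth (env, 0)
fun adder n = CLOSURE (adderCode, [n])

(* Source program:  fun affine (a, b) = fn x => a * x + b
   Two free variables, stored in slots 0 and 1. *)
fun affineCode (CLOSURE (_, env), x) =
  List.nth (env, 0) * x + List.nth (env, 1)
fun affine (a, b) = CLOSURE (affineCode, [a, b])

val add5 = adder 5
val result = call (add5, 37)                                       (* 42 *)

(* Closures of different sizes are called in exactly the same way. *)
val results = map (fn cl => call (cl, 10)) [adder 5, affine (2, 1)] (* [15, 21] *)
```

Note that adderCode and affineCode have no free variables of their own, which is the point of the conversion: after it, every function is closed and can be compiled as a plain code address.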

The finite-register rule for the CPS is this:

Semantics of the CPS

Operations that have side effects on memory are represented by means of a “store.” Each store value can be thought of as a mapping from locations (addresses) to denotable values.
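A store of this kind can be modelled with plain functions. The sketch below is a reduced version written for illustration: locations and stored values are both assumed to be int, the store is a pair rather than the triple used in the semantics, and alloc and Unallocated are hypothetical names. Updating never mutates anything; it builds a new mapping that shadows one location.

```sml
(* Assumption: locations are ints and stored values are ints. *)
type loc = int
type store = loc * (loc -> int)   (* next unused location, contents *)

exception Unallocated
val empty : store = (0, fn _ => raise Unallocated)

(* Look up the value at location l. *)
fun fetch ((_, f): store) (l: loc) = f l

(* Return a new store that agrees with the old one except at l. *)
fun upd ((n, f): store, l: loc, v: int) : store =
  (n, fn i => if i = l then v else f i)

(* Allocate a fresh location holding v, advancing the next-location counter. *)
fun alloc ((n, f): store, v: int) : loc * store =
  (n, upd ((n + 1, f), n, v))

val (r, s1) = alloc (empty, 10)
val s2 = upd (s1, r, 11)
val old = fetch s1 r   (* 10: the earlier store is unchanged *)
val new = fetch s2 r   (* 11 *)
```

Because each operation returns a new store, the order of side effects is fixed by how stores are threaded from one continuation to the next.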

functor CPSsemantics(

  structure CPS: CPS

  val minint: int
  val maxint: int
  val minreal: real
  val maxreal: real
  val string2real: string -> real

  eqtype loc (* locations *)
  val nextloc: loc -> loc

  (* arbitrarily(a, b) evaluates either to a or to b *)
  val arbitrarily: 'a * 'a -> 'a

  type answer
  (* denotable values *)
  datatype dvalue = RECORD of dvalue list * int
                  | INT of int
                  | REAL of real
                  | FUNC of dvalue list -> (loc * (loc -> dvalue) * (loc -> int)) -> answer
                  | STRING of string
                  | BYTEARRAY of loc list
                  | ARRAY of loc list
                  | UARRAY of loc list
           
  (* the "current exception handler" *)
  val handler_ref : loc
  (* exception for arithmetic overflow *)
  val overflow_exn : dvalue
  (* exception for dividing by zero *)
  val div_exn : dvalue
) : sig 
      val eval: CPS.var list * CPS.cexp -> 
                dvalue list -> 
                (loc * (loc -> dvalue) * (loc -> int)) -> 
                answer
    end =

struct (* beginning of the functor body *)

  type store = loc * (loc -> dvalue) * (loc -> int)
  fun fetch ((_, f, _): store) (l: loc) = f l
  fun upd ((n, f, g): store, l: loc, v: dvalue) =
    (n, fn i => if i = l then v else f i, g)
  fun fetchi ((_, _, g): store) (l: loc) = g l
  fun updi ((n, f, g): store, l: loc, v: int) =
    (n, f, fn i => if i = l then v else g i)

  exception Undefined

  fun eq(RECORD(a, i), RECORD(b, j)) = arbitrarily((i = j) andalso eqlist(a, b), false)
    | eq(INT i, INT j) = (i = j)
    | eq(REAL a, REAL b) = arbitrarily((a = b), false)
    | eq(STRING a, STRING b) = arbitrarily((a = b), false)
    | eq(BYTEARRAY nil, BYTEARRAY nil) = true
    | eq(BYTEARRAY(a::_), BYTEARRAY(b::_)) = (a = b)
    | eq(ARRAY nil, ARRAY nil) = true
    | eq(ARRAY(a::_), ARRAY(b::_)) = (a = b)
    | eq(UARRAY nil, UARRAY nil) = true
    | eq(UARRAY(a::_), UARRAY(b::_)) = (a = b)
    | eq(FUNC a, FUNC b) = raise Undefined
    | eq(INT i, _) = arbitrarily(false, i < 0 orelse i > 255)
    | eq(_, INT i) = arbitrarily(false, i < 0 orelse i > 255)
    | eq(_, _) = false
  and eqlist(a::al, b::bl) = (eq(a, b) andalso eqlist(al, bl))
    | eqlist(nil, nil) = true

  fun do_raise exn s =
    let
      val FUNC f = fetch s handler_ref
    in
      f [exn] s
    end
  
  (* evaluates a number n, turns it into a dvalue list, 
     and hands it to its continuation argument c;
     but if it overflows, raises an exception instead. *)
  fun overflow(n: unit -> int, c: dvalue list -> store -> answer) =
    if (n() >= minint andalso n() <= maxint)
      handle Overflow => false
    then c [INT(n())]
    else do_raise overflow_exn
  
  fun overflowr(n, c) =
    if (n() >= minreal andalso n() <= maxreal)
      handle Overflow => false
    then c [REAL(n())]
    else do_raise overflow_exn
  
  fun evalprim(CPS.+: CPS.primop,
               [INT i, INT j]: dvalue list,
               [c]: (dvalue list -> store -> answer) list) = overflow(fn() => i + j, c)
    | evalprim(CPS.-, [INT i, INT j], [c]) = overflow(fn() => i - j, c)
    | evalprim(CPS.*, [INT i, INT j], [c]) = overflow(fn() => i * j, c)
    | evalprim(CPS.div, [INT i, INT 0], [c]) = do_raise div_exn
    | evalprim(CPS.div, [INT i, INT j], [c]) = overflow(fn() => i div j, c)
    | evalprim(CPS.~, [INT i], [c]) = overflow(fn() => 0 - i, c)
    | evalprim(CPS.<, [INT i, INT j], [t, f]) = if i < j then t[] else f[]
    | evalprim(CPS.<=, [INT i, INT j], [t, f]) = if j < i then f[] else t[]
    | evalprim(CPS.>, [INT i, INT j], [t, f]) = if j < i then t[] else f[]
    | evalprim(CPS.>=, [INT i, INT j], [t, f]) = if i < j then f[] else t[]
    | evalprim(CPS.ieql, [a, b], [t, f]) = if eq(a, b) then t[] else f[]
    | evalprim(CPS.ineq, [a, b], [t, f]) = if eq(a, b) then f[] else t[]
    | evalprim(CPS.rangechk, [INT i, INT j], [t, f]) =
        if j < 0
        then
          if i < 0
          then if i < j then t[] else f[]
          else t[]
        else
          if i < 0
          then f[]
          else if i < j then t[]
          else f[]
    | evalprim(CPS.boxed, [INT i], [t, f]) =
        if i < 0 orelse i > 255
        then arbitrarily(t[], f[])
        else f[]
    | evalprim(CPS.boxed, [RECORD _], [t, f]) = t[]
    | evalprim(CPS.boxed, [STRING _], [t, f]) = t[]
    | evalprim(CPS.boxed, [ARRAY _], [t, f]) = t[]
    | evalprim(CPS.boxed, [UARRAY _], [t, f]) = t[]
    | evalprim(CPS.boxed, [BYTEARRAY _], [t, f]) = t[]
    | evalprim(CPS.boxed, [FUNC _], [t, f]) = t[]
    | evalprim(CPS.!, [a], [c]) = evalprim(CPS.subscript, [a, INT 0], [c])
    | evalprim(CPS.subscript, [ARRAY a, INT n], [c]) =
        (fn s => c [fetch s (nth(a, n))] s)
    | evalprim(CPS.subscript, [UARRAY a, INT n], [c]) =
        (fn s => c [INT(fetchi s (nth(a, n)))] s)
    | evalprim(CPS.subscript, [RECORD(a, i), INT j], [c]) =
        c [nth(a, i + j)]
    | evalprim(CPS.ordof, [STRING a, INT i], [c]) =
        c [INT(String.ordof(a, i))]
    | evalprim(CPS.ordof, [BYTEARRAY a, INT i], [c]) =
        (fn s => c [INT(fetchi s(nth(a, i)))] s)
    | evalprim(CPS.:=, [a, v], [c]) = evalprim(CPS.update, [a, INT 0, v], [c])
    | evalprim(CPS.update, [ARRAY a, INT n, v], [c]) =
        (fn s => c [] (upd(s, nth(a, n), v)))
    | evalprim(CPS.update, [UARRAY a, INT n, INT v], [c]) =
        (fn s => c [] (updi(s, nth(a, n), v)))
    | evalprim(CPS.unboxedassign, [a, v], [c]) = evalprim(CPS.unboxedupdate, [a, INT 0, v], [c])
    | evalprim(CPS.unboxedupdate, [ARRAY a, INT n, INT v], [c]) =
        (fn s => c [] (upd(s, nth(a, n), INT v)))
    | evalprim(CPS.unboxedupdate, [UARRAY a, INT n, INT v], [c]) =
        (fn s => c [] (updi(s, nth(a, n), v)))
    | evalprim(CPS.store, [BYTEARRAY a, INT i, INT v], [c]) =
        if v < 0 orelse v >= 256
        then raise Undefined
        else (fn s => c [] (updi(s, nth(a, i), v)))
    | evalprim(CPS.makeref, [v], [c]) =
        (fn (l, f, g) => c [ARRAY[l]] (upd((nextloc l, f, g), l, v)))
    | evalprim(CPS.makerefunboxed, [INT v], [c]) =
        (fn (l, f, g) => c [UARRAY[l]] (updi((nextloc l, f, g), l, v)))
    | evalprim(CPS.alength, [ARRAY a], [c]) = c [INT(List.length a)]
    | evalprim(CPS.alength, [UARRAY a], [c]) = c [INT(List.length a)]
    | evalprim(CPS.slength, [BYTEARRAY a], [c]) = c[INT(List.length a)]
    | evalprim(CPS.slength, [STRING a], [c]) = c [INT(String.size a)]
    | evalprim(CPS.gethdlr, [], [c]) =
        (fn s => c [fetch s handler_ref] s)
    | evalprim(CPS.sethdlr, [h], [c]) =
        (fn s => c [] (upd(s, handler_ref, h)))
    | evalprim(CPS.fadd, [REAL a, REAL b], [c]) = overflowr(fn() => a + b, c)
    | evalprim(CPS.fsub, [REAL a, REAL b], [c]) = overflowr(fn() => a - b, c)
    | evalprim(CPS.fmul, [REAL a, REAL b], [c]) = overflowr(fn() => a * b, c)
    | evalprim(CPS.fdiv, [REAL a, REAL 0.0], [c]) = do_raise div_exn
    | evalprim(CPS.fdiv, [REAL a, REAL b], [c]) = overflowr(fn() => a / b, c)
    | evalprim(CPS.feql, [REAL a, REAL b], [t, f]) = if a = b then t[] else f[]
    | evalprim(CPS.fneq, [REAL a, REAL b], [t, f]) = if a = b then f[] else t[]
    | evalprim(CPS.flt, [REAL i, REAL j], [t, f]) = if i < j then t[] else f[]
    | evalprim(CPS.fle, [REAL i, REAL j], [t, f]) = if j < i then f[] else t[]
    | evalprim(CPS.fgt, [REAL i, REAL j], [t, f]) = if j < i then t[] else f[]
    | evalprim(CPS.fge, [REAL i, REAL j], [t, f]) = if i < j then f[] else t[]
  
  (* environment is a mapping from CPS variables to denotable values *)
  type env = CPS.var -> dvalue

  fun V env (CPS.INT i) = INT i
    | V env (CPS.REAL r) = REAL(string2real r)
    | V env (CPS.STRING s) = STRING s
    | V env (CPS.VAR v) = env v
    | V env (CPS.LABEL v) = env v

  fun bind(env: env, v: CPS.var, d) =
    fn w => if v = w then d else env w
  
  fun bindn(env, v::vl, d::dl) = bindn(bind(env, v, d), vl, dl)
    | bindn(env, nil, nil) = env
  
  fun F (x, CPS.OFFp 0) = x
    | F (RECORD(l, i), CPS.OFFp j) = RECORD(l, i + j)
    | F (RECORD(l, i), CPS.SELp(j, p)) = F(nth(l, i + j), p)
  
  fun E (CPS.SELECT(i, v, w, e)) env =
        let
          val RECORD(l, j) = V env v
        in
          E e (bind(env, w, nth(l, i + j)))
        end
    | E (CPS.OFFSET(i, v, w, e)) env =
        let
          val RECORD(l, j) = V env v
        in
          E e (bind(env, w, RECORD(l, i + j)))
        end
    | E (CPS.APP(f, vl)) env =
        let
          val FUNC g = V env f
        in
          g (map (V env) vl)
        end
    | E (CPS.RECORD(vl, w, e)) env = 
        E e (bind(env, w, RECORD(map (fn (x, p) => F(V env x, p)) vl, 0)))
    | E (CPS.SWITCH(v, el)) env =
        let
          val INT i = V env v
        in
          E (nth(el, i)) env
        end
    | E (CPS.PRIMOP(p, vl, wl, el)) env =
        evalprim(p, map (V env) vl, map (fn e => fn al => E e (bindn(env, wl, al))) el)
    | E (CPS.FIX(fl, e)) env =
        let
          fun h r1 (f, vl, b) =
            FUNC(fn al => E b (bindn(g r1, vl, al)))
          and
          g r = bindn(r, map #1 fl, map (h r) fl)
        in
          E e (g env)
        end
  
  val env0 = fn x => raise Undefined

  fun eval (vl, e) dl = E e (bindn(env0, vl, dl))

end
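The functor above depends on many parameters, so it is hard to run as it stands. The following is a much smaller evaluator in the same shape, offered as a sketch rather than as the book's semantics. It assumes var is string, keeps only APP, FIX and PRIMOP with two hypothetical primops MUL and ADD, fixes the answer type to int, and drops the store and overflow checks. It evaluates the CPS term for f(a + b) * f(c + d) with a = 1, b = 2, c = 3, d = 4.

```sml
type var = string
datatype value = VAR of var | INT of int
datatype primop = MUL | ADD
datatype cexp
  = APP of value * value list
  | FIX of (var * var list * cexp) list * cexp
  | PRIMOP of primop * value list * var list * cexp list

(* Denotable values; no store, and the answer type is simply int. *)
datatype dvalue = DINT of int | FUNC of dvalue list -> int

exception Undefined
type env = var -> dvalue
val env0 : env = fn _ => raise Undefined

fun bind (env: env, v: var, d: dvalue) : env =
  fn w => if v = w then d else env w
fun bindn (env, v :: vl, d :: dl) = bindn (bind (env, v, d), vl, dl)
  | bindn (env, nil, nil) = env
  | bindn _ = raise Undefined

fun V env (INT i) = DINT i
  | V env (VAR v) = env v

fun evalprim (MUL, [DINT i, DINT j], [c]) = c [DINT (i * j)]
  | evalprim (ADD, [DINT i, DINT j], [c]) = c [DINT (i + j)]
  | evalprim _ = raise Undefined

fun E (APP (f, vl)) env =
      (case V env f of
         FUNC g => g (map (V env) vl)
       | _ => raise Undefined)
  | E (PRIMOP (p, vl, wl, el)) env =
      evalprim (p, map (V env) vl,
                map (fn e => fn al => E e (bindn (env, wl, al))) el)
  | E (FIX (fl, e)) env =
      let
        (* h builds one function value; g extends an environment with all
           functions of this FIX, so the bodies can refer to each other. *)
        fun h r1 (_, vl, b) = FUNC (fn al => E b (bindn (g r1, vl, al)))
        and g r = bindn (r, map (fn (f, _, _) => f) fl, map (h r) fl)
      in
        E e (g env)
      end

val term =
  FIX ([("f", ["x", "k"],
         PRIMOP (MUL, [INT 2, VAR "x"], ["u"],
           [PRIMOP (ADD, [VAR "u", INT 1], ["v"],
              [APP (VAR "k", [VAR "v"])])])),
        ("k1", ["i"],
         FIX ([("k2", ["j"],
                PRIMOP (MUL, [VAR "i", VAR "j"], ["w"],
                  [APP (VAR "r", [VAR "w"])]))],
              PRIMOP (ADD, [VAR "c", VAR "d"], ["m"],
                [APP (VAR "f", [VAR "m", VAR "k2"])])))],
       PRIMOP (ADD, [VAR "a", VAR "b"], ["n"],
         [APP (VAR "f", [VAR "n", VAR "k1"])]))

(* r, the rest of the computation, just returns its argument as the answer. *)
val top = bindn (env0, ["a", "b", "c", "d", "r"],
                 [DINT 1, DINT 2, DINT 3, DINT 4,
                  FUNC (fn [DINT w] => w | _ => raise Undefined)])

val answer = E term top   (* f(3) * f(7) = 7 * 15 = 105 *)
```

Unlike the full semantics, this version replaces the non-exhaustive matches with explicit raise Undefined clauses so that it compiles without match warnings.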

Notes

  1. The compiler design ideas are also documented in the article: Andrew W. Appel and Trevor Jim, “Continuation-Passing, Closure-Passing Style,” POPL’89: Proceedings of the 16th ACM SIGPLAN-SIGACT Symposium on Principles of Programming Languages, January 1989, Pages 293-302. 

  2. John C. Reynolds: “In the early history of continuations, basic concepts were independently discovered an extraordinary number of times. This was due less to poor communication among computer scientists than to the rich variety of settings in which continuations were found useful: They underlie a method of program transformation (into continuation-passing style), a style of definitional interpreter (defining one language by an interpreter written in another language), and a style of denotational semantics (in the sense of Scott and Strachey). In each of these settings, by representing ‘the meaning of the rest of the program’ as a function or procedure, continuations provide an elegant description of a variety of language constructs, including call by value and goto statements.”