Binary Search Tree in Standard ML

In this post, I would like to describe a Standard ML implementation of a binary search tree (BST), in which every tree node n satisfies the condition that every value in n’s left sub-tree, if it exists, is less than that of n and every value in n’s right sub-tree, if it exists, is greater than that of n. Each non-empty tree node records its own element (or value) along with its two sub-trees. A client should be able to

  1. find the element of a node by key;
  2. insert elements into an existing tree;
  3. remove nodes from an existing tree;
  4. perform tree traversal using higher order functions.

The BST is kept as general as possible, so that elements and keys can be of any type.

Type Declarations

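In Standard ML (or SML for short), datatype declarations are used to introduce new types with value and/or type constructors. For example, order is a type for key comparisons declared with three nullary value constructors: a value of type order is either LESS, EQUAL, or GREATER. The Basis Library already provides this type, so the declaration below is shown only for reference; redeclaring it would shadow the built-in type and make library functions such as Int.compare incompatible with it.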

datatype order = LESS | EQUAL | GREATER

A value of type compare_function is a function that accepts a pair of keys and returns the relative order between them. A value of type to_key_function is a function that takes an element and returns its key. The declarations shown below are type bindings with parameterized type constructors.

type 'k compare_function = (('k * 'k) -> order)
type ('e, 'k) to_key_function = 'e -> 'k
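As a small illustration, here is how these two types might be instantiated. This is only a sketch: it relies on the Basis Library's own order type and Int.compare, and the names int_cmp, person, and person_name are hypothetical, chosen for this example rather than taken from the implementation in this post.

```sml
(* Same type bindings as above, repeated so the snippet stands alone.
 * order is the Basis Library's built-in type. *)
type 'k compare_function = ('k * 'k) -> order
type ('e, 'k) to_key_function = 'e -> 'k

(* Compare integer keys using the Basis Library. *)
val int_cmp : int compare_function = Int.compare

(* Hypothetical element type: a (name, age) pair keyed by name. *)
type person = string * int
val person_name : (person, string) to_key_function = fn (name, _) => name

val r1 = int_cmp (1, 2)            (* LESS *)
val r2 = person_name ("Ada", 36)   (* "Ada" *)
```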

One convenience of datatype in SML is that we can create a new type recursively. Here, a tree node is either empty (“doesn’t exist”, represented by a Nil value) or composed of an element, a left child, and a right child. A leaf is just a node that takes the form Sub (element, Nil, Nil). Quite naturally, it follows that a tree can be represented as a tuple combining the root node with a key comparison function and a key-extraction function. If the root node is Nil, then it is an empty tree. Since the to_key_function can recover the key from any element, there is no need to store keys separately in tree nodes.

datatype 'e node = Nil | Sub of 'e * ('e node) * ('e node)
type ('e, 'k) tree = ('e node * 'k compare_function * ('e, 'k) to_key_function)

fun create_empty(cmp : 'k compare_function, to_key : ('e, 'k) to_key_function) : ('e, 'k) tree =
  (Nil, cmp, to_key)

fun create_empty_simple(cmp : 'k compare_function) : ('k, 'k) tree =
  (Nil, cmp, (fn(v)=>v))

It is rather simple to compute the height of a binary tree, which is defined here as the number of nodes on the longest path from the root node to any leaf node, so that an empty tree has height 0 and a single leaf has height 1:

(* Returns the larger of x and y *)
fun max(x, y) = if x < y then y else x;

(* 'e node -> int *)
fun height_simple(Nil) = 0
  | height_simple(Sub (element, left_child, right_child)) =
      1 + max(height_simple(left_child), height_simple(right_child))

(* ('e, 'k) tree -> int *)
fun height(bst) =
  let
    val (root, _, _) = bst
  in
    (* The children are nodes, not trees, so recurse with height_simple *)
    height_simple(root)
  end
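To make the convention concrete, the following sketch computes the height of a small hand-built tree. It repeats the node datatype so that it stands alone, and uses Int.max from the Basis Library in place of the max helper; the value names are made up for this example.

```sml
datatype 'e node = Nil | Sub of 'e * ('e node) * ('e node)

fun height_simple Nil = 0
  | height_simple (Sub (_, l, r)) =
      1 + Int.max (height_simple l, height_simple r)

(*     2
 *    / \
 *   1   4
 *      /
 *     3      *)
val sample = Sub (2, Sub (1, Nil, Nil), Sub (4, Sub (3, Nil, Nil), Nil))

val h_empty = height_simple Nil       (* 0 *)
val h_sample = height_simple sample   (* 3: the path 2, 4, 3 has three nodes *)
```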

Function find()

fun find(t : ('e, 'k) tree, key : 'k) : 'e option =
  let
    val (root, cmp, to_key) = t
    fun find_helper(n : 'e node) : 'e option =
      case n of
        Nil => NONE
      | Sub (this_val, left_child, right_child) =>
          case cmp(key, to_key(this_val)) of
            GREATER => find_helper(right_child)
          | LESS => find_helper(left_child)
          | EQUAL => SOME this_val
  in
    find_helper(root)
  end

Function insert()

The insert() function returns a pair containing the new tree and an option holding the replaced element, if any. If the key of the new element matches an existing node in the tree, the old element is replaced. In a functional programming language like SML, values are created by initialization and bindings are immutable once declared. Hence, rather than modifying a tree in place, I chose to build new nodes along the path of the traversal, reusing the untouched sub-trees.

fun insert(t : ('e, 'k) tree, element : 'e) : (('e, 'k) tree * 'e option) =
  let
    val (root, cmp, to_key) = t
    val ret = ref NONE
    fun insert_helper(n : 'e node) : 'e node =
      case n of
        Nil => Sub (element, Nil, Nil)
      | Sub (this_val, left_child, right_child) =>
          case cmp(to_key(element), to_key(this_val)) of
            GREATER => Sub (this_val, left_child, insert_helper(right_child))
          | LESS => Sub(this_val, insert_helper(left_child), right_child)
          | EQUAL => (ret := SOME this_val; Sub(element, left_child, right_child))
  in
    ((insert_helper(root), cmp, to_key), !ret)
  end

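The insert_helper() function is written in a recursive manner: an empty node is replaced by a new leaf holding the element; otherwise the key of the element is compared with the key of the current node, and the insertion continues in the right sub-tree (GREATER) or in the left sub-tree (LESS), or the element of the current node is replaced (EQUAL).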
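One consequence of rebuilding nodes instead of mutating them is that the old tree stays valid after an insertion. The sketch below shows this on a deliberately simplified variant: it is specialised to int elements, compares with Int.compare, and ignores duplicates, so insert_int and to_list are illustrative names rather than the functions defined above.

```sml
datatype 'e node = Nil | Sub of 'e * ('e node) * ('e node)

(* Simplified insert for int elements; duplicates are ignored. *)
fun insert_int (Nil, x) = Sub (x, Nil, Nil)
  | insert_int (n as Sub (v, l, r), x) =
      case Int.compare (x, v) of
        LESS => Sub (v, insert_int (l, x), r)
      | GREATER => Sub (v, l, insert_int (r, x))
      | EQUAL => n

(* In-order list of elements. *)
fun to_list Nil = []
  | to_list (Sub (v, l, r)) = to_list l @ [v] @ to_list r

val t1 = foldl (fn (x, t) => insert_int (t, x)) Nil [5, 3, 8]
val t2 = insert_int (t1, 4)

val list1 = to_list t1   (* [3, 5, 8]: t1 is untouched by the second insert *)
val list2 = to_list t2   (* [3, 4, 5, 8] *)
```

Only the nodes on the path from the root to the new leaf are rebuilt; the right sub-tree holding 8 is shared between t1 and t2.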

Function remove()

The remove() function returns a pair containing the modified tree and an option type of the removed element. If the key provided matches an existing node in the tree, the node should be removed from the tree.

At first glance, the structure of the remove() function looks similar to that of the insert() function: we recursively construct a new tree that does not contain the node specified by the key. However, things become complicated once we reach the node to remove. Consider the following code:

fun remove(t : ('e, 'k) tree, key : 'k) : (('e, 'k) tree * 'e option) =
  let
    val (root, cmp, to_key) = t
    val removed = ref NONE
    fun remove_helper(n : 'e node) : 'e node =
      case n of
        Nil => n (* Do nothing if an empty node is passed in *)
      | Sub (this_val, left_child, right_child) =>
          case cmp(key, to_key(this_val)) of
            GREATER => Sub (this_val, left_child, remove_helper(right_child))
          | LESS => Sub (this_val, remove_helper(left_child), right_child)
          | EQUAL => (* Found the node to be removed *)
             (removed := SOME this_val;
              case (left_child, right_child) of
                (Nil, Nil) => Nil (* Simply clear the leaf *)
              | (Nil, only_right) => only_right (* Return right child *)
              | (only_left, Nil) => only_left (* Return left child *)
              | _ =>
                  let
                    (* Find the rightmost descendant of the left sub-tree *)
                    fun find_right(Sub (left_sub_val, left_sub_left, left_sub_right) : 'e node) : 'e node =
                      case left_sub_right of
                        Nil => Sub (left_sub_val, left_sub_left, left_sub_right)
                      | _ => find_right(left_sub_right)
                  in
                    let
                      val Sub (new_val, _, _) = find_right(left_child)
                    in
                      Sub (new_val, left_child, right_child)
                    end
                  end)
  in
    ((remove_helper(root), cmp, to_key), !removed)
  end

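In the EQUAL branch of the remove_helper function, we have to handle the sub-tree whose root needs to be removed. If the node is a leaf, it is simply replaced by Nil; if it has exactly one child, that child takes its place; if it has two children, its element is replaced by that of the rightmost descendant of the left sub-tree.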

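The code is actually problematic: in the final case, simply returning the node with the updated element leaves that element in the tree twice, once at the new root and once at its old position in the left sub-tree. We need to detach the descendant we found (the return value of find_right) by replacing it with its own left sub-tree, which is Nil when the descendant is a leaf.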

My solution is to perform another tree traversal using a helper function:

fun remove_duplicate(n : 'e node) : 'e node =
  case n of
    Nil => Nil
  | Sub (this_val, left_child, right_child) =>
      case right_child of
        Nil => left_child
      | Sub (right_child_val, right_child_left, right_child_right) =>
          Sub (this_val, left_child, remove_duplicate(right_child))

Then, for the final case, return:

Sub (new_val, remove_duplicate(left_child), right_child)
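An alternative that avoids the second traversal is to let a single helper return both the rightmost element and the sub-tree with that node detached. The following is a sketch of that idea, not the code used in this post: remove_max and remove_root are hypothetical names, and the example works directly on nodes without the comparison and key functions.

```sml
datatype 'e node = Nil | Sub of 'e * ('e node) * ('e node)

(* Detach the rightmost node: NONE for an empty tree, otherwise the
 * rightmost element together with the remaining sub-tree. *)
fun remove_max Nil = NONE
  | remove_max (Sub (v, l, Nil)) = SOME (v, l)
  | remove_max (Sub (v, l, r)) =
      (case remove_max r of
         SOME (m, r') => SOME (m, Sub (v, l, r'))
       | NONE => SOME (v, l))   (* unreachable: r is non-empty here *)

(* Remove the root of a sub-tree, promoting the rightmost element of
 * the left sub-tree when both children are present. *)
fun remove_root Nil = Nil
  | remove_root (Sub (_, Nil, r)) = r
  | remove_root (Sub (_, l, Nil)) = l
  | remove_root (Sub (_, l, r)) =
      (case remove_max l of
         SOME (m, l') => Sub (m, l', r)
       | NONE => r)             (* unreachable: l is non-empty here *)

fun to_list Nil = []
  | to_list (Sub (v, l, r)) = to_list l @ [v] @ to_list r

(*       5
 *      / \
 *     3   8
 *    / \
 *   2   4    *)
val t = Sub (5, Sub (3, Sub (2, Nil, Nil), Sub (4, Nil, Nil)), Sub (8, Nil, Nil))
val result = to_list (remove_root t)   (* [2, 3, 4, 8], with 4 as the new root *)
```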

Tree Traversal: fold

(* Return in-order element list *)
fun to_list_lnr(n : 'e node) : 'e list =
  case n of
    Nil => []
  | Sub (this_val, left_child, right_child) =>
      to_list_lnr(left_child) @ [this_val] @ to_list_lnr(right_child)


(* depth-first, in-order traversal
 * https://en.wikipedia.org/wiki/Tree_traversal#In-order_(LNR)
 *)
fun fold_lnr(f, init, t) = 
  let
    val (root, _, _) = t
  in
    List.foldl f init (to_list_lnr(root))
  end

(* depth-first, reverse in-order traversal
 * https://en.wikipedia.org/wiki/Tree_traversal#Reverse_in-order_(RNL)
 *)
fun fold_rnl(f, init, t) = 
  let
    val (root, _, _) = t
  in
    List.foldr f init (to_list_lnr(root))
  end

fun debug_message(element_to_string : 'e -> string, t : ('e,'k) tree) : string =
  let
    val (root, _, _) = t
  in
    String.concat(List.map element_to_string (to_list_lnr(root)))
  end
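The folds above first flatten the tree with to_list_lnr, whose use of @ copies intermediate lists. If that overhead matters, a fold can also walk the nodes directly. The sketch below is one possible way to do so; fold_node_lnr is a hypothetical name, and it operates on a bare node rather than on the tree tuple.

```sml
datatype 'e node = Nil | Sub of 'e * ('e node) * ('e node)

(* In-order fold directly over nodes: fold the left sub-tree, apply f
 * to the element, then fold the right sub-tree with the new accumulator. *)
fun fold_node_lnr f acc Nil = acc
  | fold_node_lnr f acc (Sub (v, l, r)) =
      fold_node_lnr f (f (v, fold_node_lnr f acc l)) r

val sample = Sub (2, Sub (1, Nil, Nil), Sub (3, Nil, Nil))

val total = fold_node_lnr (fn (x, s) => x + s) 0 sample        (* 6 *)
val visited = fold_node_lnr (fn (x, xs) => x :: xs) [] sample  (* [3, 2, 1] *)
```

The second result is reversed because each visited element is consed onto the front of the accumulator, which matches what List.foldl does on the in-order list.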

Balancing the BST

So far, our BST is not automatically height-balanced. A BST is height-balanced if, for every node, the heights of its left and right sub-trees differ by at most one. A BST that maintains this invariant is called an AVL tree. A convenient way to implement an AVL tree is to let each non-empty node store its own height, i.e., the number of nodes on the longest path from the node down to any leaf node (see note 1):

datatype 'e node = Nil | Sub of int * 'e * ('e node) * ('e node)

Hence, given an arbitrary node, its height information may be retrieved like this:

fun height_node(Nil) = 0
  | height_node(Sub (h, _, _, _)) = h;

Given an arbitrary node, we can check whether it can serve as the root of an AVL tree:

fun is_balanced_node(n : 'e node) : bool =
  case n of
    Nil => true
  | Sub (_, this_val, left_child, right_child) =>
      is_balanced_node(left_child) andalso
      is_balanced_node(right_child) andalso
      abs(height_node(left_child) - height_node(right_child)) <= 1

The balance factor of a node is the height of its left child minus the height of its right child. We will consult it whenever we modify a tree and need to rebalance it afterwards:

fun balance_factor(Nil) = 0
  | balance_factor(Sub (_, _, left_child, right_child)) = height_node(left_child) - height_node(right_child);
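As a quick sanity check, here is a sketch that evaluates the balance factor on a left-leaning chain. The heights are filled in by hand and the value names are invented for this example.

```sml
datatype 'e node = Nil | Sub of int * 'e * ('e node) * ('e node)

fun height_node Nil = 0
  | height_node (Sub (h, _, _, _)) = h

fun balance_factor Nil = 0
  | balance_factor (Sub (_, _, l, r)) = height_node l - height_node r

(* A left-leaning chain 3 -> 2 -> 1; the first field is the stored height. *)
val leaf1 = Sub (1, 1, Nil, Nil)
val chain2 = Sub (2, 2, leaf1, Nil)
val chain3 = Sub (3, 3, chain2, Nil)

val bf_leaf = balance_factor leaf1   (* 0 *)
val bf_mid = balance_factor chain2   (* 1 *)
val bf_top = balance_factor chain3   (* 2: this node needs rebalancing *)
```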

Function balance()

The correctness of our AVL tree operations depends on how we implement the balance() function. We only need to act when the absolute value of a node’s balance factor reaches two; when it is less than two, no rebalancing is needed.

(* Creates a new sub-tree with new height *)
fun new_root_node(element : 'e, left_sub : 'e node, right_sub : 'e node) : 'e node =
  let 
    val new_height = 1 + max(height_node(left_sub), height_node(right_sub))
  in 
    Sub (new_height, element, left_sub, right_sub)
  end

fun rotate_left(n : 'e node) : 'e node =
  case n of
    Sub (_, v, ls, Sub (_, rv, rls, rrs)) => 
      new_root_node(rv, new_root_node(v, ls, rls), rrs)
  | _ => n

fun rotate_right(n : 'e node) : 'e node =
  case n of
    Sub (_, v, Sub (_, lv, lls, lrs), rs) => 
      new_root_node(lv, lls, new_root_node(v, lrs, rs))
  | _ => n

fun balance(n as Sub (height, this_val, left_child, right_child) : 'e node) : 'e node =
  let
    val n_balance_factor = balance_factor(n)
    val l_balance_factor = balance_factor(left_child)
    val r_balance_factor = balance_factor(right_child)
  in
    case (n_balance_factor, l_balance_factor, r_balance_factor) of
      (2, ~1, _) =>
        (* First rotate_left then rotate_right *)
        rotate_right(new_root_node(this_val, rotate_left(left_child), right_child))
    | (2, _, _) => rotate_right(n)
    | (~2, _, 1) =>
        (* First rotate_right then rotate_left *)
        rotate_left(new_root_node(this_val, left_child, rotate_right(right_child)))
    | (~2, _, _) => rotate_left(n)
    | _ => n
  end

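The rotate_left() and rotate_right() functions are the building blocks of our code here. To balance a binary tree whose root node has a balance factor of 2 or ~2, it is useful to consider the following cases:

  1. the balance factor is 2 and the left child has a balance factor of ~1: rotate the left child to the left, then rotate the root to the right;
  2. the balance factor is 2 otherwise: rotate the root to the right;
  3. the balance factor is ~2 and the right child has a balance factor of 1: rotate the right child to the right, then rotate the root to the left;
  4. the balance factor is ~2 otherwise: rotate the root to the left.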
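As a concrete instance of the single right rotation, the sketch below builds the chain 3, 2, 1 leaning to the left and rotates it. The definitions are condensed copies of the ones in this section (with Int.max in place of max) so that the snippet stands alone; the value names are made up for the example.

```sml
datatype 'e node = Nil | Sub of int * 'e * ('e node) * ('e node)

fun height_node Nil = 0
  | height_node (Sub (h, _, _, _)) = h

(* Build a node and recompute its height from its children. *)
fun new_root_node (v, l, r) =
  Sub (1 + Int.max (height_node l, height_node r), v, l, r)

fun rotate_right (Sub (_, v, Sub (_, lv, lls, lrs), rs)) =
      new_root_node (lv, lls, new_root_node (v, lrs, rs))
  | rotate_right n = n

(* The root 3 has left child 2, which has left child 1. *)
val unbalanced =
  new_root_node (3, new_root_node (2, new_root_node (1, Nil, Nil), Nil), Nil)

(* After the rotation, 2 is the root with children 1 and 3. *)
val balanced = rotate_right unbalanced

val h_before = height_node unbalanced   (* 3 *)
val h_after = height_node balanced      (* 2 *)
```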

Our implementation for the insert() and remove() functions is then straightforward:

fun insert(t : ('e, 'k) tree, element : 'e) : (('e, 'k) tree * 'e option) =
  let
    val (root, cmp, to_key) = t
    val ret = ref NONE
    fun insert_helper(n : 'e node) : 'e node =
      case n of
        Nil => new_root_node(element, Nil, Nil)
      | Sub (height, this_val, left_child, right_child) =>
          case cmp(to_key(element), to_key(this_val)) of
            GREATER =>
              balance(new_root_node(this_val, left_child, insert_helper(right_child)))
          | LESS =>
              balance(new_root_node(this_val, insert_helper(left_child), right_child))
          | EQUAL => 
              (ret := SOME this_val; Sub (height, element, left_child, right_child))
  in
    ((insert_helper(root), cmp, to_key), !ret)
  end
	
fun find_left(n as Sub (_, right_sub_val, right_sub_left, right_sub_right) : 'e node) : ('e node * 'e) =
  case right_sub_left of
    Nil => (right_sub_right, right_sub_val) (* Found the leftmost descendant *)
  | _ =>
      let
        val (new_left, new_val) = find_left(right_sub_left)
      in
        (balance(new_root_node(right_sub_val, new_left, right_sub_right)), new_val)
      end
	
fun remove(t : ('e, 'k) tree, key : 'k) : (('e, 'k) tree * 'e option) =
  let
    val (root, cmp, to_key) = t
    val removed = ref NONE
    fun remove_helper(n : 'e node) : 'e node =
      case n of
        Nil => n
      | Sub (_, this_val, left_child, right_child) =>
          case cmp(key, to_key(this_val)) of
            GREATER =>
              balance(new_root_node(this_val, left_child, remove_helper(right_child)))
          | LESS =>
              balance(new_root_node(this_val, remove_helper(left_child), right_child))
          | EQUAL =>
             (removed := SOME this_val;
              case (left_child, right_child) of
                (_, Nil) => left_child
              | (Nil, _) => right_child
              | _ =>
                let
                  (* Find the leftmost descendant of the right sub-tree *)
                  val (new_right, new_val) = find_left(right_child)
                in
                  balance(new_root_node(new_val, left_child, new_right))
                end)
  in
    ((remove_helper(root), cmp, to_key), !removed)
  end
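One way to gain confidence in the rebalancing is to insert keys in ascending order, which degenerates into a linked list in an unbalanced BST, and then check the invariants. The sketch below does this with a condensed, int-only variant of the code in this section: insert_int ignores duplicates and does not report replaced elements, balance gains an extra Nil clause, and is_balanced and to_list are illustrative helpers.

```sml
datatype 'e node = Nil | Sub of int * 'e * ('e node) * ('e node)

fun height_node Nil = 0
  | height_node (Sub (h, _, _, _)) = h

fun balance_factor Nil = 0
  | balance_factor (Sub (_, _, l, r)) = height_node l - height_node r

fun new_root_node (v, l, r) =
  Sub (1 + Int.max (height_node l, height_node r), v, l, r)

fun rotate_left (Sub (_, v, ls, Sub (_, rv, rls, rrs))) =
      new_root_node (rv, new_root_node (v, ls, rls), rrs)
  | rotate_left n = n

fun rotate_right (Sub (_, v, Sub (_, lv, lls, lrs), rs)) =
      new_root_node (lv, lls, new_root_node (v, lrs, rs))
  | rotate_right n = n

fun balance Nil = Nil
  | balance (n as Sub (_, v, l, r)) =
      case (balance_factor n, balance_factor l, balance_factor r) of
        (2, ~1, _) => rotate_right (new_root_node (v, rotate_left l, r))
      | (2, _, _) => rotate_right n
      | (~2, _, 1) => rotate_left (new_root_node (v, l, rotate_right r))
      | (~2, _, _) => rotate_left n
      | _ => n

(* Simplified insert for int elements; duplicates are ignored. *)
fun insert_int (Nil, x) = new_root_node (x, Nil, Nil)
  | insert_int (n as Sub (_, v, l, r), x) =
      case Int.compare (x, v) of
        LESS => balance (new_root_node (v, insert_int (l, x), r))
      | GREATER => balance (new_root_node (v, l, insert_int (r, x)))
      | EQUAL => n

(* Every node must have a balance factor of ~1, 0, or 1. *)
fun is_balanced Nil = true
  | is_balanced (n as Sub (_, _, l, r)) =
      is_balanced l andalso is_balanced r andalso abs (balance_factor n) <= 1

fun to_list Nil = []
  | to_list (Sub (_, v, l, r)) = to_list l @ [v] @ to_list r

val t = foldl (fn (x, acc) => insert_int (acc, x)) Nil [1, 2, 3, 4, 5, 6, 7]

val ok = is_balanced t   (* true *)
val h = height_node t    (* 3, versus 7 without rebalancing *)
val xs = to_list t       (* [1, 2, 3, 4, 5, 6, 7] *)
```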

Our BST implementation can be easily modified in order to support a sorted dictionary data structure:

type (''k, 'v) dictionary = ((''k * 'v), ''k) BinarySearchTree.tree

which supports four basic operations. Thanks to the AVL balancing, get, put, and remove need only $O(\log N)$ comparisons, while entries has to visit every element:

val get : (''k, 'v) dictionary * ''k -> 'v option
val put : (''k, 'v) dictionary * ''k * 'v -> (''k, 'v) dictionary * 'v option
val remove : (''k, 'v) dictionary * ''k -> (''k, 'v) dictionary * 'v option
val entries : (''k, 'v) dictionary -> (''k * 'v) list

Here is the code:

structure SortedDictionary = DictionaryFn(struct
  type ''k compare_function = (''k * ''k) -> order
  type (''k, 'v) dictionary = ((''k * 'v), ''k) BinarySearchTree.tree
  type ''k create_parameter_type = ''k compare_function

  fun create(cmp : ''k compare_function) : (''k, 'v) dictionary =
    let
      fun to_key((key, value) : ''k * 'v) : ''k = key
    in
      BinarySearchTree.create_empty(cmp, to_key)
    end

  fun get(dict : (''k, 'v) dictionary, key : ''k) : 'v option =
    case BinarySearchTree.find(dict, key) of
      NONE => NONE
    | SOME(this_key, this_val) => SOME this_val

  fun put(dict : (''k, 'v) dictionary, key : ''k, value : 'v) : (''k, 'v) dictionary * 'v option =
    let
      val (this_dict, this_key_val) = BinarySearchTree.insert(dict, (key, value))
    in
      case this_key_val of
        NONE => (this_dict, NONE)
      | SOME(this_key, this_val) => (this_dict, SOME this_val)
    end

  fun remove(dict : (''k, 'v) dictionary, key : ''k) : (''k, 'v) dictionary * 'v option =
    let
      val (this_dict, this_key_val) = BinarySearchTree.remove(dict, key)
    in
      case this_key_val of
        NONE => (this_dict, NONE)
      | SOME(this_key, this_val) => (this_dict, SOME this_val)
    end

  fun entries(dict : (''k, 'v) dictionary) : (''k * 'v) list =
    BinarySearchTree.fold_rnl((fn (node, acc) => node::acc), [], dict)

end)

Notes

  1. Alternatively, we may want to implement a red-black tree, which is a self-balancing binary search tree with the color invariant:

    1. No red node has a red parent, and
    2. every path from the root node to a leaf node has the same number of black nodes.

    As shown above, a non-empty node may be implemented to take on the information about its own color. An implementation in OCaml is presented here