Binary Search Tree in Standard ML
06 Oct 2022
In this post, I would like to describe a Standard ML implementation of a binary search tree (BST), in which every tree node n satisfies the condition that every value in n’s left sub-tree, if it exists, is less than that of n, and every value in n’s right sub-tree, if it exists, is greater than that of n. Each non-empty tree node records its own element (or value) along with its descendant(s). A client should be able to
- find the element of a node by key;
- insert elements into an existing tree;
- remove nodes from an existing tree;
- perform tree traversal using higher-order functions.
The BST is kept as general as possible, so that the elements and keys can be of any type.
Table of Contents
- Type Declarations
- Function find()
- Function insert()
- Function remove()
- Tree Traversal: fold
- Balancing the BST
- Notes
Type Declarations
In Standard ML (or SML for short), datatype declarations are used to introduce new types with value and/or type constructors. For example, order is a new type for key comparisons declared with three nullary value constructors. A value of type order is either LESS, EQUAL, or GREATER.
datatype order = LESS | EQUAL | GREATER
A value of type compare_function is a function that accepts a pair of keys and returns the relative order between them. A value of type to_key_function is a function that takes an element and returns its key. The declarations shown below are type bindings with parameterized type constructors.
type 'k compare_function = (('k * 'k) -> order)
type ('e, 'k) to_key_function = 'e -> 'k
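To make these two types concrete, here is a minimal sketch. It assumes the order type and Int.compare from the SML Basis Library instead of redeclaring order; the person record and the names int_cmp and person_key are made up for illustration.

```sml
(* Same type bindings as above; `order` here is the one from the Basis Library. *)
type 'k compare_function = ('k * 'k) -> order
type ('e, 'k) to_key_function = 'e -> 'k

(* Hypothetical element type: a record keyed by its integer id. *)
type person = {id : int, name : string}

val int_cmp : int compare_function = Int.compare
val person_key : (person, int) to_key_function = fn (p : person) => #id p

(* Compare the key of an element against a bare key: 3 versus 5. *)
val result = int_cmp (person_key {id = 3, name = "Ada"}, 5)
```

One thing to watch out for: a datatype declaration is generative, so redeclaring order at the top level creates a type distinct from the Basis one, and Int.compare would then no longer fit int compare_function.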
One convenience of datatype in SML is that we can create a new type recursively. Here, a tree node is either empty (“doesn’t exist”, represented by a Nil value) or composed of an element, a left child, and a right child. A leaf is just a node that takes the form of Sub (element, Nil, Nil). Quite naturally, it follows that a tree can be represented as a tuple combining the root node with a key comparison function and a key extraction function. If the root node is Nil, then it is an empty tree. Since we already have a to_key_function type, there is no need to store the key information in tree nodes.
datatype 'e node = Nil | Sub of 'e * ('e node) * ('e node)
type ('e, 'k) tree = ('e node * 'k compare_function * ('e, 'k) to_key_function)
fun create_empty(cmp : 'k compare_function, to_key : ('e, 'k) to_key_function) : ('e, 'k) tree =
  (Nil, cmp, to_key)
fun create_empty_simple(cmp : 'k compare_function) : ('k, 'k) tree =
  (Nil, cmp, (fn(v)=>v))
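As a quick illustration of how these pieces fit together, the sketch below builds a three-node tree by hand and wraps it into the tuple representation. It again assumes the Basis order type; the names small_tree and empty_by_length, and the choice of keying strings by their length, are only for demonstration.

```sml
datatype 'e node = Nil | Sub of 'e * 'e node * 'e node
type 'k compare_function = ('k * 'k) -> order
type ('e, 'k) to_key_function = 'e -> 'k
type ('e, 'k) tree = 'e node * 'k compare_function * ('e, 'k) to_key_function

(*     2
 *    / \
 *   1   3
 *)
val root = Sub (2, Sub (1, Nil, Nil), Sub (3, Nil, Nil))

(* Elements double as keys, so the key function is the identity. *)
val small_tree : (int, int) tree = (root, Int.compare, fn v => v)

(* An empty tree of strings keyed by their length (a made-up choice). *)
val empty_by_length : (string, int) tree = (Nil, Int.compare, String.size)
```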
It is rather simple to compute the height of a binary tree, defined here as the number of nodes on the longest path from the root node to any leaf node (so an empty tree has height 0):
(* Returns the larger of x and y (x if they are equal) *)
fun max(x, y) = if x < y then y else x;
(* 'e node -> int *)
fun height_simple(Nil) = 0
  | height_simple(Sub (element, left_child, right_child)) =
      1 + max(height_simple(left_child), height_simple(right_child))
(* ('e, 'k) tree -> int *)
fun height(bst) =
  let
    val (root, _, _) = bst
  in
    (* The children are nodes rather than trees, so recurse on nodes *)
    height_simple(root)
  end
Function find()
fun find(t : ('e, 'k) tree, key : 'k) : 'e option =
  let
    val (root, cmp, to_key) = t
    fun find_helper(n : 'e node) : 'e option =
      case n of
        Nil => NONE
      | Sub (this_val, left_child, right_child) =>
          case cmp(key, to_key(this_val)) of
            GREATER => find_helper(right_child)
          | LESS => find_helper(left_child)
          | EQUAL => SOME this_val
  in
    find_helper(root)
  end
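Here is a small usage sketch for find(). The function is repeated so that the snippet stands alone, with the tree type inlined; the sample tree of (id, name) pairs is invented for the example and the Basis order type is assumed.

```sml
datatype 'e node = Nil | Sub of 'e * 'e node * 'e node
type ('e, 'k) tree = 'e node * ('k * 'k -> order) * ('e -> 'k)

fun find (t : ('e, 'k) tree, key : 'k) : 'e option =
  let
    val (root, cmp, to_key) = t
    fun find_helper Nil = NONE
      | find_helper (Sub (this_val, left_child, right_child)) =
          (case cmp (key, to_key this_val) of
             GREATER => find_helper right_child
           | LESS => find_helper left_child
           | EQUAL => SOME this_val)
  in
    find_helper root
  end

(* Elements are (id, name) pairs; the key is the id. *)
val root = Sub ((2, "b"), Sub ((1, "a"), Nil, Nil), Sub ((3, "c"), Nil, Nil))
val t : (int * string, int) tree = (root, Int.compare, fn (id, _) => id)

val hit = find (t, 3)    (* SOME (3, "c") *)
val miss = find (t, 7)   (* NONE *)
```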
Function insert()
The insert() function returns a pair containing the new tree and an option holding the replaced element. If the key provided matches an existing node in the tree, the old element is replaced. In a functional programming language like SML, objects are created by means of initialization and bindings are immutable once declared. Hence, I chose to modify the contents of a tree by creating new nodes along the path of the tree traversal.
fun insert(t : ('e, 'k) tree, element : 'e) : (('e, 'k) tree * 'e option) =
  let
    val (root, cmp, to_key) = t
    val ret = ref NONE
    fun insert_helper(n : 'e node) : 'e node =
      case n of
        Nil => Sub (element, Nil, Nil)
      | Sub (this_val, left_child, right_child) =>
          case cmp(to_key(element), to_key(this_val)) of
            GREATER => Sub (this_val, left_child, insert_helper(right_child))
          | LESS => Sub (this_val, insert_helper(left_child), right_child)
            (* Record the old element, then put the new one in its place *)
          | EQUAL => (ret := SOME this_val; Sub (element, left_child, right_child))
  in
    ((insert_helper(root), cmp, to_key), !ret)
  end
The insert_helper() function is written in a recursive manner:
- When an empty node is passed into the insert_helper() function, it returns a leaf node with the new element and leaves the optional replaced element as NONE, which is the base case;
- Otherwise, it constructs a new tree by repeatedly calling itself until it
  - either finds a proper position to form a leaf node and leaves the optional replaced element as NONE,
  - or replaces the element of an internal node that has the same key and updates the optional replaced element to SOME this_val, i.e., the element that was stored there before.
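To see the "create nodes along the way" idea in isolation, here is a stripped-down sketch in which int elements act as their own keys. Unlike insert() above, it threads the replaced element through the return value instead of a ref cell; that is my variation for the example, and insert_int is a hypothetical name.

```sml
datatype 'e node = Nil | Sub of 'e * 'e node * 'e node

(* Returns the new root together with the element that was replaced, if any.
 * Only the nodes on the search path are rebuilt; everything else is shared. *)
fun insert_int (n : int node, x : int) : int node * int option =
  case n of
    Nil => (Sub (x, Nil, Nil), NONE)
  | Sub (v, l, r) =>
      (case Int.compare (x, v) of
         LESS =>
           let val (l', old) = insert_int (l, x) in (Sub (v, l', r), old) end
       | GREATER =>
           let val (r', old) = insert_int (r, x) in (Sub (v, l, r'), old) end
       | EQUAL => (Sub (x, l, r), SOME v))

val t0 = Sub (2, Sub (1, Nil, Nil), Nil)
val (t1, replaced1) = insert_int (t0, 3)   (* new key: replaced1 is NONE *)
val (t2, replaced2) = insert_int (t1, 3)   (* existing key: replaced2 is SOME 3 *)
```

After both calls, t0 still denotes the original two-node tree, which is the practical benefit of never mutating nodes in place.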
Function remove()
The remove() function returns a pair containing the modified tree and an option holding the removed element. If the key provided matches an existing node in the tree, the node should be removed from the tree.
At first glance, the structure of the remove() function can be similar to that of the insert() function, in which we recursively construct a new tree that does not contain the node specified by the key. However, things become complicated as we reach the node to remove. Consider the following code:
fun remove(t : ('e, 'k) tree, key : 'k) : (('e, 'k) tree * 'e option) =
  let
    val (root, cmp, to_key) = t
    val removed = ref NONE
    fun remove_helper(n : 'e node) : 'e node =
      case n of
        Nil => n (* Do nothing if an empty node is passed in *)
      | Sub (this_val, left_child, right_child) =>
          case cmp(key, to_key(this_val)) of
            GREATER => Sub (this_val, left_child, remove_helper(right_child))
          | LESS => Sub (this_val, remove_helper(left_child), right_child)
          | EQUAL => (* Found the node to be removed *)
              (removed := SOME this_val;
               case (left_child, right_child) of
                 (Nil, Nil) => Nil (* Simply clear the leaf *)
               | (Nil, only_right) => only_right (* Return right child *)
               | (only_left, Nil) => only_left (* Return left child *)
               | _ =>
                   let
                     (* Find the rightmost descendant of the left sub-tree *)
                     fun find_right(Sub (left_sub_val, left_sub_left, left_sub_right) : 'e node) : 'e node =
                       case left_sub_right of
                         Nil => Sub (left_sub_val, left_sub_left, left_sub_right)
                       | _ => find_right(left_sub_right)
                   in
                     let
                       val Sub (new_val, _, _) = find_right(left_child)
                     in
                       Sub (new_val, left_child, right_child)
                     end
                   end)
  in
    ((remove_helper(root), cmp, to_key), !removed)
  end
In the EQUAL branch of the remove_helper function, we have to handle the sub-tree whose root needs to be removed:
- If the current n is actually a leaf node, then return Nil so that the leaf node is removed and no further action is needed;
- If the current n only has a right child, then return this right child and no further action is needed;
- If the current n only has a left child, then return this left child and no further action is needed;
- If neither of the current n’s children is Nil, replace n with the rightmost descendant of n’s left child (or equivalently the leftmost descendant of n’s right child).
The code is actually problematic in that, for the final case, simply returning n with an updated element leaves a duplicate element inside the tree. We also need to detach the descendant we found (the return value of find_right) from the left sub-tree, replacing it with its own left child, which may be Nil.
My solution is to perform another tree traversal using a helper function:
fun remove_duplicate(n : 'e node) : 'e node =
  case n of
    Nil => Nil
  | Sub (this_val, left_child, right_child) =>
      case right_child of
        Nil => left_child
      | Sub (right_child_val, right_child_left, right_child_right) =>
          Sub (this_val, left_child, remove_duplicate(right_child))
Then, for the final case, return:
Sub (new_val, remove_duplicate(left_child), right_child)
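The two passes over the left sub-tree (one to find the rightmost element, one to detach it) can also be fused into a single pass. The sketch below is one possible way to do that, not the code used in this post; take_max and remove_root are hypothetical names.

```sml
datatype 'e node = Nil | Sub of 'e * 'e node * 'e node

(* Splits a non-empty sub-tree into its rightmost element and the
 * sub-tree that remains once that element is detached. *)
fun take_max Nil = NONE
  | take_max (Sub (v, l, Nil)) = SOME (v, l)
  | take_max (Sub (v, l, r)) =
      (case take_max r of
         SOME (m, r') => SOME (m, Sub (v, l, r'))
       | NONE => NONE)   (* unreachable: r is not Nil in this clause *)

(* Removes the element stored at the root of the given node. *)
fun remove_root (Sub (_, l, r)) =
      (case take_max l of
         SOME (m, l') => Sub (m, l', r)
       | NONE => r)      (* no left sub-tree: promote the right child *)
  | remove_root Nil = Nil

(*       5                 4
 *      / \               / \
 *     3   8     ==>     3   8
 *    / \               /
 *   2   4             2
 *)
val before = Sub (5, Sub (3, Sub (2, Nil, Nil), Sub (4, Nil, Nil)), Sub (8, Nil, Nil))
val after = remove_root before
```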
Tree Traversal: fold
(* Return in-order element list *)
fun to_list_lnr(n : 'e node) : 'e list =
  case n of
    Nil => []
  | Sub (this_val, left_child, right_child) =>
      to_list_lnr(left_child) @ [this_val] @ to_list_lnr(right_child)
(* depth-first, in-order traversal
 * https://en.wikipedia.org/wiki/Tree_traversal#In-order_(LNR)
 *)
fun fold_lnr(f, init, t) =
  let
    val (root, _, _) = t
  in
    List.foldl f init (to_list_lnr(root))
  end
(* depth-first, reverse in-order traversal
 * https://en.wikipedia.org/wiki/Tree_traversal#Reverse_in-order_(RNL)
 *)
fun fold_rnl(f, init, t) =
  let
    val (root, _, _) = t
  in
    List.foldr f init (to_list_lnr(root))
  end
fun debug_message(element_to_string : 'e -> string, t : ('e,'k) tree) : string =
  let
    val (root, _, _) = t
  in
    String.concat(List.map element_to_string (to_list_lnr(root)))
  end
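The folds above first flatten the tree with to_list_lnr, whose use of @ may copy the left list at every node. As an alternative sketch, the fold can visit the nodes directly. The name fold_node_lnr is hypothetical, and it works on a bare node rather than on the tree tuple.

```sml
datatype 'e node = Nil | Sub of 'e * 'e node * 'e node

(* In-order fold over nodes; f receives (element, accumulator),
 * matching the argument order used by List.foldl. *)
fun fold_node_lnr f acc Nil = acc
  | fold_node_lnr f acc (Sub (v, l, r)) =
      fold_node_lnr f (f (v, fold_node_lnr f acc l)) r

val root = Sub (2, Sub (1, Nil, Nil), Sub (3, Nil, Nil))

val sum = fold_node_lnr (op +) 0 root                   (* 6 *)
val count = fold_node_lnr (fn (_, n) => n + 1) 0 root   (* 3 *)
(* Consing while visiting in ascending order yields a descending list. *)
val descending = fold_node_lnr (op ::) [] root          (* [3, 2, 1] *)
```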
Balancing the BST
So far, our BST is not automatically height-balanced. A BST is height-balanced if, for every node, the heights of its left and right sub-trees differ by at most one. Such a BST is also called an AVL tree. A convenient way to implement an AVL tree is to let each non-empty node store its own height, i.e., the length of the longest path from the node to any leaf node [1]:
datatype 'e node = Nil | Sub of int * 'e * ('e node) * ('e node)
Hence, given an arbitrary node, its height information may be retrieved like this:
fun height_node(Nil) = 0
| height_node(Sub (h, _, _, _)) = h;
Given an arbitrary node, we can check if it suffices to act as the root of an AVL tree:
fun is_balanced_node(n : 'e node) : bool =
  case n of
    Nil => true
  | Sub (_, this_val, left_child, right_child) =>
      is_balanced_node(left_child) andalso
      is_balanced_node(right_child) andalso
      abs(height_node(left_child) - height_node(right_child)) <= 1
The balance factor of a node of tree t is the height difference between its two children. It is going to be used whenever we modify t and subsequently rebalance it:
fun balance_factor(Nil) = 0
| balance_factor(Sub (_, _, left_child, right_child)) = height_node(left_child) - height_node(right_child);
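As a small worked example, consider a left-leaning chain of three nodes. The heights are filled in by hand here, and the names n1, n2, and n3 are only for illustration.

```sml
datatype 'e node = Nil | Sub of int * 'e * 'e node * 'e node

fun height_node Nil = 0
  | height_node (Sub (h, _, _, _)) = h

fun balance_factor Nil = 0
  | balance_factor (Sub (_, _, l, r)) = height_node l - height_node r

(* The chain 3 -> 2 -> 1, each node being the left child of the previous one.
 * The first component of Sub is the stored height. *)
val n1 = Sub (1, 1, Nil, Nil)
val n2 = Sub (2, 2, n1, Nil)
val n3 = Sub (3, 3, n2, Nil)

(* Balance factors from the leaf up to the root. *)
val factors = map balance_factor [n1, n2, n3]   (* [0, 1, 2] *)
```

The root n3 has a balance factor of 2, so this chain is not a valid AVL tree until it is rebalanced.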
Function balance()
The correctness of our AVL tree operations depends on how we implement the balance() function. We only need to act when the absolute value of a node’s balance factor reaches two; as long as it stays below two, no rebalancing is needed.
(* Creates a new sub-tree with new height *)
fun new_root_node(element : 'e, left_sub : 'e node, right_sub : 'e node) : 'e node =
  let
    val new_height = 1 + max(height_node(left_sub), height_node(right_sub))
  in
    Sub (new_height, element, left_sub, right_sub)
  end
fun rotate_left(n : 'e node) : 'e node =
  case n of
    Sub (_, v, ls, Sub (_, rv, rls, rrs)) =>
      new_root_node(rv, new_root_node(v, ls, rls), rrs)
  | _ => n
fun rotate_right(n : 'e node) : 'e node =
  case n of
    (* The left child is also a four-field Sub, so its height must be matched too *)
    Sub (_, v, Sub (_, lv, lls, lrs), rs) =>
      new_root_node(lv, lls, new_root_node(v, lrs, rs))
  | _ => n
fun balance(n as Sub (height, this_val, left_child, right_child) : 'e node) : 'e node =
  let
    val n_balance_factor = balance_factor(n)
    val l_balance_factor = balance_factor(left_child)
    val r_balance_factor = balance_factor(right_child)
  in
    case (n_balance_factor, l_balance_factor, r_balance_factor) of
      (2, ~1, _) =>
        (* First rotate_left then rotate_right *)
        rotate_right(new_root_node(this_val, rotate_left(left_child), right_child))
    | (2, _, _) => rotate_right(n)
    | (~2, _, 1) =>
        (* First rotate_right then rotate_left *)
        rotate_left(new_root_node(this_val, left_child, rotate_right(right_child)))
    | (~2, _, _) => rotate_left(n)
    | _ => n
  end
The rotate_left() and rotate_right() functions are the building blocks of our code here. To balance a binary tree whose root node has a balance factor of 2 or ~2, it is useful to consider the following cases:
- balance_factor(root) = 2
  - balance_factor(left_child) = ~1: first perform rotate_left on left_child, then perform rotate_right on root.
  - Otherwise: simply perform rotate_right on root.
- balance_factor(root) = ~2
  - balance_factor(right_child) = 1: first perform rotate_right on right_child, then perform rotate_left on root.
  - Otherwise: simply perform rotate_left on root.
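To trace the first of these cases by hand, the sketch below builds the smallest tree with a root balance factor of 2 and a left child balance factor of ~1, then applies the double rotation. The helper mk is a hypothetical short name playing the role of new_root_node, and Int.max from the Basis Library stands in for max.

```sml
datatype 'e node = Nil | Sub of int * 'e * 'e node * 'e node

fun height_node Nil = 0
  | height_node (Sub (h, _, _, _)) = h

(* Builds a node whose stored height is derived from its children. *)
fun mk (v, l, r) = Sub (1 + Int.max (height_node l, height_node r), v, l, r)

fun rotate_left (Sub (_, v, ls, Sub (_, rv, rls, rrs))) = mk (rv, mk (v, ls, rls), rrs)
  | rotate_left n = n

fun rotate_right (Sub (_, v, Sub (_, lv, lls, lrs), rs)) = mk (lv, lls, mk (v, lrs, rs))
  | rotate_right n = n

(*   3            3          2
 *  /            /          / \
 * 1     ==>    2    ==>   1   3
 *  \          /
 *   2        1
 *)
val unbalanced = mk (3, mk (1, Nil, mk (2, Nil, Nil)), Nil)

val fixed =
  case unbalanced of
    Sub (_, v, l, r) => rotate_right (mk (v, rotate_left l, r))
  | Nil => Nil
```

The result has 2 at the root with stored height 2, and both children are leaves of height 1.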
Our implementation for the insert() and remove() functions is then straightforward:
fun insert(t : ('e, 'k) tree, element : 'e) : (('e, 'k) tree * 'e option) =
  let
    val (root, cmp, to_key) = t
    val ret = ref NONE
    fun insert_helper(n : 'e node) : 'e node =
      case n of
        Nil => new_root_node(element, Nil, Nil)
      | Sub (height, this_val, left_child, right_child) =>
          case cmp(to_key(element), to_key(this_val)) of
            GREATER =>
              balance(new_root_node(this_val, left_child, insert_helper(right_child)))
          | LESS =>
              balance(new_root_node(this_val, insert_helper(left_child), right_child))
          | EQUAL =>
              (* Record the old element and swap in the new one; the height is unchanged *)
              (ret := SOME this_val; Sub (height, element, left_child, right_child))
  in
    ((insert_helper(root), cmp, to_key), !ret)
  end
fun find_left(n as Sub (_, right_sub_val, right_sub_left, right_sub_right) : 'e node) : ('e node * 'e) =
  case right_sub_left of
    Nil => (right_sub_right, right_sub_val) (* Found the leftmost descendant *)
  | _ =>
      let
        val (new_left, new_val) = find_left(right_sub_left)
      in
        (balance(new_root_node(right_sub_val, new_left, right_sub_right)), new_val)
      end
fun remove(t : ('e, 'k) tree, key : 'k) : (('e, 'k) tree * 'e option) =
  let
    val (root, cmp, to_key) = t
    val removed = ref NONE
    fun remove_helper(n : 'e node) : 'e node =
      case n of
        Nil => n
      | Sub (_, this_val, left_child, right_child) =>
          case cmp(key, to_key(this_val)) of
            GREATER =>
              balance(new_root_node(this_val, left_child, remove_helper(right_child)))
          | LESS =>
              balance(new_root_node(this_val, remove_helper(left_child), right_child))
          | EQUAL =>
              (removed := SOME this_val;
               case (left_child, right_child) of
                 (_, Nil) => left_child
               | (Nil, _) => right_child
               | _ =>
                   let
                     (* Find the leftmost descendant of the right sub-tree *)
                     val (new_right, new_val) = find_left(right_child)
                   in
                     balance(new_root_node(new_val, left_child, new_right))
                   end)
  in
    ((remove_helper(root), cmp, to_key), !removed)
  end
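As a sanity check of the rebalancing, the following sketch inserts the keys 1 through 7 in ascending order, which would degenerate into a chain of height 7 without balancing. It is a simplified stand-in for the code above: int elements act as their own keys, duplicates are ignored, and mk, bf, insert_int, and to_list are hypothetical short names.

```sml
datatype 'e node = Nil | Sub of int * 'e * 'e node * 'e node

fun height_node Nil = 0
  | height_node (Sub (h, _, _, _)) = h

fun mk (v, l, r) = Sub (1 + Int.max (height_node l, height_node r), v, l, r)

fun bf Nil = 0
  | bf (Sub (_, _, l, r)) = height_node l - height_node r

fun rotate_left (Sub (_, v, ls, Sub (_, rv, rls, rrs))) = mk (rv, mk (v, ls, rls), rrs)
  | rotate_left n = n

fun rotate_right (Sub (_, v, Sub (_, lv, lls, lrs), rs)) = mk (lv, lls, mk (v, lrs, rs))
  | rotate_right n = n

fun balance Nil = Nil
  | balance (n as Sub (_, v, l, r)) =
      (case (bf n, bf l, bf r) of
         (2, ~1, _) => rotate_right (mk (v, rotate_left l, r))
       | (2, _, _) => rotate_right n
       | (~2, _, 1) => rotate_left (mk (v, l, rotate_right r))
       | (~2, _, _) => rotate_left n
       | _ => n)

(* Argument order (element, tree) so that it can be passed to List.foldl. *)
fun insert_int (x, Nil) = mk (x, Nil, Nil)
  | insert_int (x, n as Sub (_, v, l, r)) =
      (case Int.compare (x, v) of
         LESS => balance (mk (v, insert_int (x, l), r))
       | GREATER => balance (mk (v, l, insert_int (x, r)))
       | EQUAL => n)

fun to_list Nil = []
  | to_list (Sub (_, v, l, r)) = to_list l @ [v] @ to_list r

val t = List.foldl insert_int Nil [1, 2, 3, 4, 5, 6, 7]
val h = height_node t   (* 3: seven nodes fill three levels exactly *)
val xs = to_list t      (* [1, 2, 3, 4, 5, 6, 7] *)
```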
Our BST implementation can be easily modified in order to support a sorted dictionary data structure:
type (''k, 'v) dictionary = ((''k * 'v), ''k) BinarySearchTree.tree
which supports four basic operations. With the balanced tree above, get, put, and remove run in $O(\log N)$ time, while entries has to visit all $N$ entries:
val get : (''k, 'v) dictionary * ''k -> 'v option
val put : (''k, 'v) dictionary * ''k * 'v -> (''k, 'v) dictionary * 'v option
val remove : (''k, 'v) dictionary * ''k -> (''k, 'v) dictionary * 'v option
val entries : (''k, 'v) dictionary -> (''k * 'v) list
Here is the code:
structure SortedDictionary = DictionaryFn(struct
  type ''k compare_function = (''k * ''k) -> order
  type (''k, 'v) dictionary = ((''k * 'v), ''k) BinarySearchTree.tree
  type ''k create_parameter_type = ''k compare_function
  fun create(cmp : ''k compare_function) : (''k, 'v) dictionary =
    let
      fun to_key((key, value) : ''k * 'v) : ''k = key
    in
      BinarySearchTree.create_empty(cmp, to_key)
    end
  fun get(dict : (''k, 'v) dictionary, key : ''k) : 'v option =
    case BinarySearchTree.find(dict, key) of
      NONE => NONE
    | SOME(this_key, this_val) => SOME this_val
  fun put(dict : (''k, 'v) dictionary, key : ''k, value : 'v) : (''k, 'v) dictionary * 'v option =
    let
      val (this_dict, this_key_val) = BinarySearchTree.insert(dict, (key, value))
    in
      case this_key_val of
        NONE => (this_dict, NONE)
      | SOME(this_key, this_val) => (this_dict, SOME this_val)
    end
  fun remove(dict : (''k, 'v) dictionary, key : ''k) : (''k, 'v) dictionary * 'v option =
    let
      val (this_dict, this_key_val) = BinarySearchTree.remove(dict, key)
    in
      case this_key_val of
        NONE => (this_dict, NONE)
      | SOME(this_key, this_val) => (this_dict, SOME this_val)
    end
  fun entries(dict : (''k, 'v) dictionary) : (''k * 'v) list =
    BinarySearchTree.fold_rnl((fn (node, acc) => node::acc), [], dict)
end)
Notes
1. Alternatively, we may want to implement a red-black tree, which is a height-balanced binary search tree with the color invariant:
   - No red node has a red parent, and
   - every path from the root node to a leaf node has the same number of black nodes.
   As shown above, a non-empty node may be implemented to take on the information about its own color. An implementation in OCaml is presented here.
Last updated: 2022-10-07