TD7 - Binary Search Trees (BST)

0. Preliminaries

Below are all the prerequisite to run the exercise.

Also are appended the codes of the class Node on the fly.

In [1]:
import numpy as np
import sys
In [2]:
LEFT = 0
RIGHT= 1

class Node:
    def __init__(self, value, left_child=None, right_child=None):
        self.value = value
        self.children = [left_child, right_child]
        
    def insertion(self, value):
        direction = RIGHT if self.value < value else LEFT
        if self.children[direction] == None:
            self.children[direction] = Node(value)
        else:
            self.children[direction].insertion(value)   
            
    def insertion_random(self, value, direction=np.random.randint(0,2)):
        if self.children[direction] == None:
            self.children[direction] = Node(value)
        else:            
            self.children[direction].insertion_random(value,np.random.randint(0,2))

    def delete(self,value):
        if self == None:
            return

        if self.value > value:
            # This 'if' is a dirty trick to have access to node deletion one-step ahead of meeting the node
            # Another option is to pass the father of the current node as argument, so to delete the child from the father
            # [CAREFUL] "del self" only deletes a local pointer, not the node itself
            if (self.children[LEFT] != None and self.children[LEFT].value == value and self.children[LEFT].children[LEFT]==None and self.children[LEFT].children[RIGHT]==None):
                self.children[LEFT] = None
                return            
            self.children[LEFT].delete(value)

        if self.value < value:
            # same as above
            if (self.children[RIGHT] != None and self.children[RIGHT].value == value and self.children[RIGHT].children[LEFT]==None and self.children[RIGHT].children[RIGHT]==None):
                self.children[RIGHT] = None
                return 
            self.children[RIGHT].delete(value)

        if self.value == value:
            if self.children[LEFT] == None:
                self.value = self.children[RIGHT].value
                self.children = self.children[RIGHT].children
                return

            if self.children[RIGHT] == None:
                self.value = self.children[LEFT].value
                self.children = self.children[LEFT].children
                return

            # Finds the bottom-rightest value (i.e., the maximum) in the LEFT subtree
            current_node=self.children[LEFT]
            depth=0
            while current_node.children[RIGHT]!=None:
                previous_node = current_node
                current_node = current_node.children[RIGHT]
                depth+=1

            # Switch the node of the found maximum (bottom-rightmost of LEFT subtree) with the node to "delete"
            # and delete this bottom-rightmost node by removing it from its father (previous_node)
            self.value = current_node.value
            if depth == 0:
                self.children[LEFT] = None
            else:
                previous_node.children[RIGHT] = None
            return           

    def print_tree(self, level=0):
        print(" |---" * level,self.value)        
        for child in self.children:
            if child == None:
                print(" |---" * (level+1),"*")
            else:
                child.print_tree(level+1)           
In [3]:
treeSize = 10
values = np.random.randint(0,100,treeSize)

print("List of values: ",values)

tree_random = Node(values[0])

for v in values[1:]:
    tree_random.insertion_random(v)

print("\nRandom tree:\n")    
    
tree_random.print_tree()

print("\n\n")

tree_bst = Node(values[0])
for v in values[1:]:
    tree_bst.insertion(v)

print("BST tree:\n")    
    
tree_bst.print_tree()
List of values:  [79 73 10 50 44  7 22 69 82 31]

Random tree:

 79
 |--- 73
 |--- |--- 10
 |--- |--- |--- 50
 |--- |--- |--- |--- 69
 |--- |--- |--- |--- |--- *
 |--- |--- |--- |--- |--- *
 |--- |--- |--- |--- *
 |--- |--- |--- 82
 |--- |--- |--- |--- *
 |--- |--- |--- |--- *
 |--- |--- 44
 |--- |--- |--- 22
 |--- |--- |--- |--- *
 |--- |--- |--- |--- *
 |--- |--- |--- 7
 |--- |--- |--- |--- *
 |--- |--- |--- |--- 31
 |--- |--- |--- |--- |--- *
 |--- |--- |--- |--- |--- *
 |--- *



BST tree:

 79
 |--- 73
 |--- |--- 10
 |--- |--- |--- 7
 |--- |--- |--- |--- *
 |--- |--- |--- |--- *
 |--- |--- |--- 50
 |--- |--- |--- |--- 44
 |--- |--- |--- |--- |--- 22
 |--- |--- |--- |--- |--- |--- *
 |--- |--- |--- |--- |--- |--- 31
 |--- |--- |--- |--- |--- |--- |--- *
 |--- |--- |--- |--- |--- |--- |--- *
 |--- |--- |--- |--- |--- *
 |--- |--- |--- |--- 69
 |--- |--- |--- |--- |--- *
 |--- |--- |--- |--- |--- *
 |--- |--- *
 |--- 82
 |--- |--- *
 |--- |--- *

1. Checking that a binary tree is a BST

The idea is the verification algorithm is to go down the tree and check that "left child root value" < "root value" < "right child root value".

To keep track at the same time of the minimum and maximum value in the tree, these must be propagated down the tree when checking its validity.


Function: verif
INPUT : node
OUTPUT : boolean that checks if a tree rooted at node is a valid BST , (min,max) of the tree
In [4]:
def verif(self):
    if self == None:
        return True,sys.maxsize,-sys.maxsize
    
    if self.children[LEFT] == None and self.children[RIGHT] == None:
        return True,self.value,self.value
    
    if self.children[LEFT] == None:
        vr = verif(self.children[RIGHT])
        return vr[0] and (self.value<=vr[1]),min(self.value,vr[1]),max(self.value,vr[2])
    
    if self.children[RIGHT] == None:
        vl = verif(self.children[LEFT])
        return vl[0] and (self.value>=vl[2]),min(self.value,vl[1]),max(self.value,vl[2])
    
    vl = verif(self.children[LEFT])
    vr = verif(self.children[RIGHT])
    return vl[0] and vr[0] and (vl[2]<=self.value<=vr[1]),min(self.value,min(vl[1],vr[1])),max(self.value,max(vl[2],vr[2]))

verif=trace(verif)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-4-02a65c5aeb1c> in <module>
     18     return vl[0] and vr[0] and (vl[2]<=self.value<=vr[1]),min(self.value,min(vl[1],vr[1])),max(self.value,max(vl[2],vr[2]))
     19 
---> 20 verif=trace(verif)

NameError: name 'trace' is not defined
In [5]:
print("Verif for tree_random: ",verif(tree_random))

print("\n\n")

print("Verif for tree_bst: ",verif(tree_bst))
Verif for tree_random:  (False, 7, 82)



Verif for tree_bst:  (True, 7, 82)

2. Removal in a BST

There are two options: when the value to discard is found:

  • either replace it by the maximum value (bottom rightmost) of the left child tree (this is what we do here)
  • or replace it by the minimum value (bottom leftmost) of the right child tree

The cost can vary from $O(1)$ (actually $1$ operation) in the lucky case where the tree has an isolated child leaf and this value is to be removed. In the worst case, the full depth of the tree needs to be explored and this costs at most $O(n)$ if the tree is unluckily build always on the same side. Generally though, the average cost would be the depth of the tree which generically would be of order $O(\log n)$.

Function: delete
INPUT : BST,value
OUTPUT : no output but operation in place on the BST to remove the value from the tree (if found)
In [6]:
def delete(self,value):
    if self == None:
        return

    if self.value > value:
        # This 'if' is a dirty trick to have access to node deletion one-step ahead of meeting the node
        # Another option is to pass the father of the current node as argument, so to delete the child from the father
        # [CAREFUL] "del self" only deletes a local pointer, not the node itself
        if (self.children[LEFT] != None and self.children[LEFT].value == value and self.children[LEFT].children[LEFT]==None and self.children[LEFT].children[RIGHT]==None):
            self.children[LEFT] = None
            return            
        self.children[LEFT].delete(value)

    if self.value < value:
        # same as above
        if (self.children[RIGHT] != None and self.children[RIGHT].value == value and self.children[RIGHT].children[LEFT]==None and self.children[RIGHT].children[RIGHT]==None):
            self.children[RIGHT] = None
            return 
        self.children[RIGHT].delete(value)

    if self.value == value:
        if self.children[LEFT] == None:
            self.value = self.children[RIGHT].value
            self.children = self.children[RIGHT].children
            return

        if self.children[RIGHT] == None:
            self.value = self.children[LEFT].value
            self.children = self.children[LEFT].children
            return

        # Finds the bottom-rightest value (i.e., the maximum) in the LEFT subtree
        current_node=self.children[LEFT]
        depth=0
        while current_node.children[RIGHT]!=None:
            previous_node = current_node
            current_node = current_node.children[RIGHT]
            depth+=1

        # Switch the node of the found maximum (bottom-rightmost of LEFT subtree) with the node to "delete"
        # and delete this bottom-rightmost node by removing it from its father (previous_node)
        self.value = current_node.value
        if depth == 0:
            self.children[LEFT] = current_node.children[LEFT]
        else:
            previous_node.children[RIGHT] = None
        return   
In [12]:
treeSize = 10
values = np.random.randint(0,100,treeSize)

print(values)

tree_bst = Node(values[0])
for v in values[1:]:
    tree_bst.insertion(v)

print("Base tree:\n")    
    
tree_bst.print_tree()

index = np.random.randint(treeSize)


tree_bst.delete(values[index])
print("\nTree after deletion of values[",index,"]=",values[index],"\n")

tree_bst.print_tree()
[39 92 73 31 93  1 38 77 90 64]
Base tree:

 39
 |--- 31
 |--- |--- 1
 |--- |--- |--- *
 |--- |--- |--- *
 |--- |--- 38
 |--- |--- |--- *
 |--- |--- |--- *
 |--- 92
 |--- |--- 73
 |--- |--- |--- 64
 |--- |--- |--- |--- *
 |--- |--- |--- |--- *
 |--- |--- |--- 77
 |--- |--- |--- |--- *
 |--- |--- |--- |--- 90
 |--- |--- |--- |--- |--- *
 |--- |--- |--- |--- |--- *
 |--- |--- 93
 |--- |--- |--- *
 |--- |--- |--- *

Tree after deletion of values[ 0 ]= 39 

 38
 |--- 31
 |--- |--- 1
 |--- |--- |--- *
 |--- |--- |--- *
 |--- |--- *
 |--- 92
 |--- |--- 73
 |--- |--- |--- 64
 |--- |--- |--- |--- *
 |--- |--- |--- |--- *
 |--- |--- |--- 77
 |--- |--- |--- |--- *
 |--- |--- |--- |--- 90
 |--- |--- |--- |--- |--- *
 |--- |--- |--- |--- |--- *
 |--- |--- 93
 |--- |--- |--- *
 |--- |--- |--- *

EXTRAS

In [274]:
from functools import wraps

def trace(func):
    func_name = func.__name__
    separator = '|  '

    trace.recursion_depth = 0

    @wraps(func)
    def traced_func(*args, **kwargs):

        # repeat separator N times (where N is recursion depth)
        # `map(str, args)` prepares the iterable with str representation of positional arguments
        # `", ".join(map(str, args))` will generate comma-separated list of positional arguments
        # `"x"*5` will print `"xxxxx"` - so we can use multiplication operator to repeat separator
        print(f'{separator * trace.recursion_depth}|-- {func_name}({", ".join(map(str, args))})')
        # we're diving in
        trace.recursion_depth += 1
        result = func(*args, **kwargs)
        # going out of that level of recursion
        trace.recursion_depth -= 1
        # result is printed on the next level
        print(f'{separator * (trace.recursion_depth + 1)}|-- return {result}')

        return result

    return traced_func
In [ ]: