# BinarySearchTree
# Language: Python
from typing import Optional
class Node:
    """A binary-tree node holding an int value and links to two children."""

    def __init__(self, value: int):
        self.value: int = value
        # Child links; None means that child is absent.
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None
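# Example tree whose node labels follow pre-order visit order
# (node, then its left subtree, then its right subtree):
#
#           0
#         /   \
#        1     4
#       / \   / \
#      2   3 5   6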
class BinarySearchTree:
    """Unbalanced binary search tree of ints.

    Ordering rule used by insert: a value smaller than a node's value goes
    into that node's left subtree; anything else (greater *or equal*) goes
    into its right subtree, so duplicate values are kept.

    The tree is never rebalanced, so costs below are stated in terms of the
    height h: h is about log2(n) for a well-balanced tree, but inserting
    values in sorted order produces a chain with h == n. The recursive
    helpers use one stack frame per level, so a tree taller than
    sys.getrecursionlimit() makes them raise RecursionError.
    """

    def __init__(self) -> None:
        # Root of the tree; None while the tree is empty.
        self.root: Optional["Node"] = None

    # Time O(h), space O(h) (recursion stack).
    def insert(self, value: int) -> None:
        """Insert *value* into the tree (duplicates are allowed)."""
        self.root = self._insert_recursive(self.root, value)

    # Time O(h), space O(h) (recursion stack).
    def _insert_recursive(self, node: Optional["Node"], value: int) -> "Node":
        """Insert *value* below *node* and return the root of that subtree.

        Returns a new Node when *node* is None; otherwise returns *node*
        itself after relinking the child it descended into.
        """
        if node is None:
            return Node(value)
        if value < node.value:
            node.left = self._insert_recursive(node.left, value)
        else:
            # Ties go right, per the ordering rule in the class docstring.
            node.right = self._insert_recursive(node.right, value)
        return node

    # Time O(h), space O(h) (recursion stack).
    def remove(self, value: int) -> None:
        """Remove one node holding *value*; do nothing if it is absent."""
        self.root = self._remove_recursive(self.root, value)

    # Time O(h), space O(h) (recursion stack).
    def _remove_recursive(
        self, node: Optional["Node"], value: int
    ) -> Optional["Node"]:
        """Remove *value* from the subtree at *node*; return its new root.

        A node with at most one child is unlinked and replaced by that child
        (or None). A node with two children is kept in place: it takes the
        value of its in-order predecessor (the maximum of its left subtree),
        and that predecessor value is then removed from the left subtree.
        """
        if node is None:
            return None
        if value < node.value:
            node.left = self._remove_recursive(node.left, value)
        elif value > node.value:
            node.right = self._remove_recursive(node.right, value)
        elif node.left is None:
            # No left child: splice in the right child (possibly None).
            return node.right
        elif node.right is None:
            # Only a left child: splice it in.
            return node.left
        else:
            predecessor = self._find_max(node.left)
            node.value = predecessor.value
            node.left = self._remove_recursive(node.left, predecessor.value)
            # Using the in-order successor instead (self._find_min(node.right),
            # then removing it from node.right) is also valid but yields a
            # different tree shape.
        return node

    # Time O(h), space O(1).
    def _find_max(self, node: "Node") -> "Node":
        """Return the rightmost (largest) node of the non-empty subtree at *node*."""
        current = node
        while current.right is not None:
            current = current.right
        return current

    # Time O(h), space O(1).
    def _find_min(self, node: "Node") -> "Node":
        """Return the leftmost (smallest) node of the non-empty subtree at *node*."""
        current = node
        while current.left is not None:
            current = current.left
        return current

    # Time O(h), space O(h) (recursion stack).
    def find(self, value: int) -> Optional["Node"]:
        """Return the first node holding *value* on the search path, or None."""
        return self._find_recursive(self.root, value)

    # Time O(h), space O(h) (recursion stack).
    def _find_recursive(
        self, node: Optional["Node"], value: int
    ) -> Optional["Node"]:
        """Search the subtree at *node* for *value*; return the node or None."""
        if node is None or node.value == value:
            return node
        if value < node.value:
            return self._find_recursive(node.left, value)
        return self._find_recursive(node.right, value)
import pytest
@pytest.fixture
def bst():
    """Provide a BinarySearchTree loaded with eight fixed values.

    The insertion order determines the tree's shape, which the tests
    navigate via explicit left/right paths.
    """
    result = BinarySearchTree()
    insertion_order = (15, 9, 23, 3, 12, 17, 1, 4)
    for item in insertion_order:
        result.insert(item)
    return result
# Shape of the tree built by the bst fixture:
#
#            15
#           /  \
#          9    23
#         / \   /
#        3  12 17
#       / \
#      1   4
def test_insert(bst):
    """Inserting 16 lands it as the left child of 17 (path 15 -> 23 -> 17)."""
    bst.insert(16)
    parent = bst.root.right.left  # the node holding 17
    assert parent.left.value == 16
def test_find(bst):
    """find returns the matching node for a present value and None otherwise."""
    hit = bst.find(12)
    assert hit.value == 12
    miss = bst.find(100)
    assert miss is None
def test_remove_leaf(bst):
    """Removing the leaf 4 detaches it from its parent, 3."""
    bst.remove(4)
    assert bst.find(4) is None
    parent = bst.root.left.left  # the node holding 3
    assert parent.right is None
def test_remove_with_one_child(bst):
    """Removing a node with exactly one child splices that child into its place.

    23 has only a left child (17), so after removing 23 the node holding 17
    must become the root's right child.
    """
    bst.remove(23)
    assert bst.find(23) is None
    assert bst.root.right.value == 17
    assert bst.root.right.left is None
    assert bst.root.right.right is None


def test_remove_with_two_leaf_children(bst):
    """Removing 3 (children 1 and 4) copies in its in-order predecessor, 1."""
    bst.remove(3)
    assert bst.find(3) is None
    assert bst.root.left.left.value == 1
    assert bst.root.left.left.left is None
    assert bst.root.left.left.right.value == 4
def test_remove_with_two_children(bst):
    """Removing 9 (children 3 and 12) replaces it with its in-order predecessor.

    4 is the largest value in 9's left subtree, so 4 is copied up and the
    old 4 leaf under 3 is removed. (A successor-based removal would instead
    copy up 12 and leave that node with no right child.)
    """
    bst.remove(9)
    assert bst.find(9) is None
    replacement = bst.root.left
    assert replacement.value == 4
    assert replacement.left.value == 3
    assert replacement.right.value == 12