""" File: binary_search_tree.py Created Time: 2022-12-20 Author: a16su (lpluls001@gmail.com) """ import sys, os.path as osp sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) from modules import * class BinarySearchTree: """二叉搜索树""" def __init__(self): """构造方法""" # 初始化空树 self._root = None def get_root(self) -> TreeNode | None: """获取二叉树根节点""" return self._root def search(self, num: int) -> TreeNode | None: """查找节点""" cur = self._root # 循环查找,越过叶节点后跳出 while cur is not None: # 目标节点在 cur 的右子树中 if cur.val < num: cur = cur.right # 目标节点在 cur 的左子树中 elif cur.val > num: cur = cur.left # 找到目标节点,跳出循环 else: break return cur def insert(self, num: int): """插入节点""" # 若树为空,则初始化根节点 if self._root is None: self._root = TreeNode(num) return # 循环查找,越过叶节点后跳出 cur, pre = self._root, None while cur is not None: # 找到重复节点,直接返回 if cur.val == num: return pre = cur # 插入位置在 cur 的右子树中 if cur.val < num: cur = cur.right # 插入位置在 cur 的左子树中 else: cur = cur.left # 插入节点 node = TreeNode(num) if pre.val < num: pre.right = node else: pre.left = node def remove(self, num: int): """删除节点""" # 若树为空,直接提前返回 if self._root is None: return # 循环查找,越过叶节点后跳出 cur, pre = self._root, None while cur is not None: # 找到待删除节点,跳出循环 if cur.val == num: break pre = cur # 待删除节点在 cur 的右子树中 if cur.val < num: cur = cur.right # 待删除节点在 cur 的左子树中 else: cur = cur.left # 若无待删除节点,则直接返回 if cur is None: return # 子节点数量 = 0 or 1 if cur.left is None or cur.right is None: # 当子节点数量 = 0 / 1 时, child = null / 该子节点 child = cur.left or cur.right # 删除节点 cur if cur != self._root: if pre.left == cur: pre.left = child else: pre.right = child else: # 若删除节点为根节点,则重新指定根节点 self._root = child # 子节点数量 = 2 else: # 获取中序遍历中 cur 的下一个节点 tmp: TreeNode = cur.right while tmp.left is not None: tmp = tmp.left # 递归删除节点 tmp self.remove(tmp.val) # 用 tmp 覆盖 cur cur.val = tmp.val """Driver Code""" if __name__ == "__main__": # 初始化二叉搜索树 bst = BinarySearchTree() nums = [8, 4, 12, 2, 6, 10, 14, 1, 3, 5, 7, 9, 11, 13, 15] # 请注意,不同的插入顺序会生成不同的二叉树,该序列可以生成一个完美二叉树 for num in nums: bst.insert(num) print("\n初始化的二叉树为\n") print_tree(bst.get_root()) # 查找节点 node = bst.search(7) print("\n查找到的节点对象为: {},节点值 = {}".format(node, node.val)) # 插入节点 bst.insert(16) print("\n插入节点 16 后,二叉树为\n") print_tree(bst.get_root()) # 删除节点 bst.remove(1) print("\n删除节点 1 后,二叉树为\n") print_tree(bst.get_root()) bst.remove(2) print("\n删除节点 2 后,二叉树为\n") print_tree(bst.get_root()) bst.remove(4) print("\n删除节点 4 后,二叉树为\n") print_tree(bst.get_root())