chenyihui 2 years ago
commit
acddff58d3
1 changed files with 140 additions and 0 deletions
  1. 140 0
      avl.py

+ 140 - 0
avl.py

@@ -0,0 +1,140 @@
+class TreeNode(object):
+    def __init__(self, val):
+        self.val = val
+        self.left = None
+        self.right = None
+        self.height = 1
+        self.size = 1
+
+    def __repr__(self) -> str:
+        return f'{self.val}({self.size})'
+
+
+class AVL_Tree(object):
+    @staticmethod
+    def insert(root, key):
+        if not root:
+            return TreeNode(key)
+        elif key < root.val:
+            root.left = AVL_Tree.insert(root.left, key)
+        else:
+            root.right = AVL_Tree.insert(root.right, key)
+
+        root.height = 1 + max(AVL_Tree.getHeight(root.left),
+                              AVL_Tree.getHeight(root.right))
+        root.size = 1 + \
+            AVL_Tree.getSize(root.left) + AVL_Tree.getSize(root.right)
+        balance = AVL_Tree.getBalance(root)
+
+        if balance > 1 and key <= root.left.val:
+            return AVL_Tree.rightRotate(root)
+
+        if balance < -1 and key >= root.right.val:
+            return AVL_Tree.leftRotate(root)
+
+        if balance > 1 and key >= root.left.val:
+            root.left = AVL_Tree.leftRotate(root.left)
+            return AVL_Tree.rightRotate(root)
+
+        if balance < -1 and key <= root.right.val:
+            root.right = AVL_Tree.rightRotate(root.right)
+            return AVL_Tree.leftRotate(root)
+
+        return root
+
+    @staticmethod
+    def leftRotate(z):
+        y = z.right
+        T2 = y.left
+
+        y.left = z
+        z.right = T2
+
+        z.height = 1 + max(AVL_Tree.getHeight(z.left),
+                           AVL_Tree.getHeight(z.right))
+        y.height = 1 + max(AVL_Tree.getHeight(y.left),
+                           AVL_Tree.getHeight(y.right))
+        ty = AVL_Tree.getSize(y.right)
+        tz = AVL_Tree.getSize(z.left)
+        z.size -= (ty + 1)
+        y.size += (tz + 1)
+        return y
+
+    @staticmethod
+    def rightRotate(z):
+        y = z.left
+        T3 = y.right
+
+        y.right = z
+        z.left = T3
+
+        z.height = 1 + max(AVL_Tree.getHeight(z.left),
+                           AVL_Tree.getHeight(z.right))
+        y.height = 1 + max(AVL_Tree.getHeight(y.left),
+                           AVL_Tree.getHeight(y.right))
+
+        ty = AVL_Tree.getSize(y.left)
+        tz = AVL_Tree.getSize(z.right)
+        z.size -= (ty + 1)
+        y.size += (tz + 1)
+        return y
+
+    @staticmethod
+    def getHeight(root):
+        if not root:
+            return 0
+        return root.height
+
+    @staticmethod
+    def getSize(root):
+        if not root:
+            return 0
+        return root.size
+
+    @staticmethod
+    def getBalance(root):
+        if not root:
+            return 0
+        return AVL_Tree.getHeight(root.left) - AVL_Tree.getHeight(root.right)
+
+    @staticmethod
+    def preOrder(root, depth=0):
+        if not root:
+            return
+        print("{0}{1}".format(' ' * depth, root))
+        AVL_Tree.preOrder(root.left, depth+1)
+        AVL_Tree.preOrder(root.right, depth+1)
+
+    @staticmethod
+    def getitem(root, idx):
+        #print(root, idx)
+        s = AVL_Tree.getSize(root.left)
+        if s == idx:
+            return root.val
+        elif idx > s:
+            return AVL_Tree.getitem(root.right, idx - s - 1)
+        else:
+            return AVL_Tree.getitem(root.left, idx)
+
+    @staticmethod
+    def index(root, val):
+        if not root:
+            return 0
+        if root.val < val:
+            return AVL_Tree.index(root.right, val) + AVL_Tree.getSize(root.left) + 1
+        else:
+            return AVL_Tree.index(root.left, val)
+# Driver program to test above function
+
+root = None
+
+
+z = [6,21,-27,17,-20,3,1,-2,10,2,23,15,-3,1,9,19,-9,-24,-30,-26,-13,23,2,-10,20,0,27,24,-28,26,0,-29,-16,0,12,-28,7,1,22,-23,20,-22,-11,7,-10,-5,27,27,0,19,-9,28,-2,6,23,-9,-9,1,8,-15]
+
+for i in z:
+    root = AVL_Tree.insert(root, i)
+
+AVL_Tree.preOrder(root)
+
+for i in range(root.size):
+    print(AVL_Tree.getitem(root, i))