123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
class TreeNode:
    """One node of a binary tree, carrying height and subtree-size fields.

    The constructor stores ``val`` as the node's key, leaves both child
    links (``left``/``right``) empty, and starts ``height`` and ``size``
    at 1, i.e. the values for a node with no children.
    """

    def __init__(self, val) -> None:
        self.val = val
        # A new node has no children yet.
        self.left = self.right = None
        # Both counters describe a one-node subtree.
        self.height = self.size = 1

    def __repr__(self) -> str:
        # Rendered as "<key>(<subtree size>)".
        return f'{self.val}({self.size})'
class AVL_Tree(object):
    """Static helpers for an AVL tree augmented with subtree sizes.

    A tree is represented by its root node (``None`` for the empty tree).
    Nodes are expected to expose ``val``, ``left``, ``right``, ``height``
    and ``size`` attributes.  Helpers that restructure the tree return the
    new root of the subtree they were given, so callers must re-assign it.

    The ``size`` field makes rank queries possible: ``getitem`` returns the
    key at a given position of the in-order sequence and ``index`` counts
    the keys smaller than a value.
    """

    @staticmethod
    def insert(root, key):
        """Insert ``key`` into the subtree at ``root`` and return its new root.

        Keys equal to a node's value descend into that node's right subtree.
        After the recursive insertion the node's ``height`` and ``size`` are
        refreshed and the subtree is rebalanced if needed.
        """
        if root is None:
            return TreeNode(key)
        if key < root.val:
            root.left = AVL_Tree.insert(root.left, key)
        else:
            root.right = AVL_Tree.insert(root.right, key)
        AVL_Tree._refresh(root)
        return AVL_Tree._rebalance(root)

    @staticmethod
    def _refresh(node):
        """Recompute ``node.height`` and ``node.size`` from its children."""
        left, right = node.left, node.right
        node.height = 1 + max(AVL_Tree.getHeight(left),
                              AVL_Tree.getHeight(right))
        node.size = 1 + AVL_Tree.getSize(left) + AVL_Tree.getSize(right)

    @staticmethod
    def _rebalance(node):
        """Rotate ``node`` if its balance factor is outside [-1, 1].

        The rotation is chosen from the balance factor of the heavy child,
        not by comparing the inserted key with the child's value.  Equal
        keys are inserted to the right, so a key equal to ``node.left.val``
        ends up in ``node.left.right``; a key comparison would treat that as
        a left-left case, apply a single rotation and leave the subtree
        unbalanced.

        Returns the root of the (possibly rotated) subtree.
        """
        balance = AVL_Tree.getBalance(node)
        if balance > 1:
            # Left-heavy.  If the left child leans right, straighten it first.
            if AVL_Tree.getBalance(node.left) < 0:
                node.left = AVL_Tree.leftRotate(node.left)
            return AVL_Tree.rightRotate(node)
        if balance < -1:
            # Right-heavy.  If the right child leans left, straighten it first.
            if AVL_Tree.getBalance(node.right) > 0:
                node.right = AVL_Tree.rightRotate(node.right)
            return AVL_Tree.leftRotate(node)
        return node

    @staticmethod
    def leftRotate(z):
        """Rotate left around ``z`` and return the new subtree root.

        ``z.right`` becomes the root, ``z`` becomes its left child and the
        former left subtree of ``z.right`` is re-attached as ``z.right``.
        ``z`` must have a right child.
        """
        y = z.right
        moved = y.left
        y.left = z
        z.right = moved
        # z is now below y, so it has to be refreshed first.
        AVL_Tree._refresh(z)
        AVL_Tree._refresh(y)
        return y

    @staticmethod
    def rightRotate(z):
        """Rotate right around ``z`` and return the new subtree root.

        ``z.left`` becomes the root, ``z`` becomes its right child and the
        former right subtree of ``z.left`` is re-attached as ``z.left``.
        ``z`` must have a left child.
        """
        y = z.left
        moved = y.right
        y.right = z
        z.left = moved
        # z is now below y, so it has to be refreshed first.
        AVL_Tree._refresh(z)
        AVL_Tree._refresh(y)
        return y

    @staticmethod
    def getHeight(root):
        """Return the stored height of ``root``, or 0 for an empty subtree."""
        if root is None:
            return 0
        return root.height

    @staticmethod
    def getSize(root):
        """Return the stored node count of ``root``, or 0 if it is empty."""
        if root is None:
            return 0
        return root.size

    @staticmethod
    def getBalance(root):
        """Return height(left) - height(right); 0 for an empty subtree."""
        if root is None:
            return 0
        return AVL_Tree.getHeight(root.left) - AVL_Tree.getHeight(root.right)

    @staticmethod
    def preOrder(root, depth=0):
        """Print the subtree in pre-order, one node per line.

        Each line is indented by ``depth`` spaces, and ``depth`` grows by one
        per level, so the nesting of the output mirrors the tree shape.
        """
        if root is None:
            return
        print("{0}{1}".format(' ' * depth, root))
        AVL_Tree.preOrder(root.left, depth + 1)
        AVL_Tree.preOrder(root.right, depth + 1)

    @staticmethod
    def getitem(root, idx):
        """Return the key at 0-based position ``idx`` of the in-order sequence.

        Raises:
            IndexError: if ``idx`` is negative or not smaller than the number
                of nodes under ``root`` (which includes an empty tree).
        """
        if not 0 <= idx < AVL_Tree.getSize(root):
            raise IndexError("AVL_Tree index out of range")
        node = root
        while True:
            left_count = AVL_Tree.getSize(node.left)
            if idx == left_count:
                return node.val
            if idx < left_count:
                node = node.left
            else:
                # Skip the left subtree and the node itself.
                idx -= left_count + 1
                node = node.right

    @staticmethod
    def index(root, val):
        """Return how many keys under ``root`` are strictly less than ``val``.

        This is the position of the first occurrence of ``val`` in the
        in-order sequence when it is present, and the position where it
        would appear otherwise.  An empty tree yields 0.
        """
        smaller = 0
        node = root
        while node is not None:
            if node.val < val:
                # The node and its whole left subtree precede ``val``.
                smaller += AVL_Tree.getSize(node.left) + 1
                node = node.right
            else:
                node = node.left
        return smaller
# Demo driver: build a tree from a fixed key sequence (duplicates included),
# dump its structure, then read every key back by position.
root = None
z = [
    6, 21, -27, 17, -20, 3, 1, -2, 10, 2,
    23, 15, -3, 1, 9, 19, -9, -24, -30, -26,
    -13, 23, 2, -10, 20, 0, 27, 24, -28, 26,
    0, -29, -16, 0, 12, -28, 7, 1, 22, -23,
    20, -22, -11, 7, -10, -5, 27, 27, 0, 19,
    -9, 28, -2, 6, 23, -9, -9, 1, 8, -15,
]
for i in z:
    # insert() returns the (possibly new) root, so it must be re-bound.
    root = AVL_Tree.insert(root, i)

# Indented pre-order dump of the resulting tree.
AVL_Tree.preOrder(root)

# Print the key stored at each in-order position 0 .. size-1.
for i in range(root.size):
    print(AVL_Tree.getitem(root, i))
|