# avl.py — AVL tree whose nodes also cache subtree sizes (positional lookup by in-order index).
  1. class TreeNode(object):
  2. def __init__(self, val):
  3. self.val = val
  4. self.left = None
  5. self.right = None
  6. self.height = 1
  7. self.size = 1
  8. def __repr__(self) -> str:
  9. return f'{self.val}({self.size})'
  10. class AVL_Tree(object):
  11. @staticmethod
  12. def insert(root, key):
  13. if not root:
  14. return TreeNode(key)
  15. elif key < root.val:
  16. root.left = AVL_Tree.insert(root.left, key)
  17. else:
  18. root.right = AVL_Tree.insert(root.right, key)
  19. root.height = 1 + max(AVL_Tree.getHeight(root.left),
  20. AVL_Tree.getHeight(root.right))
  21. root.size = 1 + \
  22. AVL_Tree.getSize(root.left) + AVL_Tree.getSize(root.right)
  23. balance = AVL_Tree.getBalance(root)
  24. if balance > 1 and key <= root.left.val:
  25. return AVL_Tree.rightRotate(root)
  26. if balance < -1 and key >= root.right.val:
  27. return AVL_Tree.leftRotate(root)
  28. if balance > 1 and key >= root.left.val:
  29. root.left = AVL_Tree.leftRotate(root.left)
  30. return AVL_Tree.rightRotate(root)
  31. if balance < -1 and key <= root.right.val:
  32. root.right = AVL_Tree.rightRotate(root.right)
  33. return AVL_Tree.leftRotate(root)
  34. return root
  35. @staticmethod
  36. def leftRotate(z):
  37. y = z.right
  38. T2 = y.left
  39. y.left = z
  40. z.right = T2
  41. z.height = 1 + max(AVL_Tree.getHeight(z.left),
  42. AVL_Tree.getHeight(z.right))
  43. y.height = 1 + max(AVL_Tree.getHeight(y.left),
  44. AVL_Tree.getHeight(y.right))
  45. ty = AVL_Tree.getSize(y.right)
  46. tz = AVL_Tree.getSize(z.left)
  47. z.size -= (ty + 1)
  48. y.size += (tz + 1)
  49. return y
  50. @staticmethod
  51. def rightRotate(z):
  52. y = z.left
  53. T3 = y.right
  54. y.right = z
  55. z.left = T3
  56. z.height = 1 + max(AVL_Tree.getHeight(z.left),
  57. AVL_Tree.getHeight(z.right))
  58. y.height = 1 + max(AVL_Tree.getHeight(y.left),
  59. AVL_Tree.getHeight(y.right))
  60. ty = AVL_Tree.getSize(y.left)
  61. tz = AVL_Tree.getSize(z.right)
  62. z.size -= (ty + 1)
  63. y.size += (tz + 1)
  64. return y
  65. @staticmethod
  66. def getHeight(root):
  67. if not root:
  68. return 0
  69. return root.height
  70. @staticmethod
  71. def getSize(root):
  72. if not root:
  73. return 0
  74. return root.size
  75. @staticmethod
  76. def getBalance(root):
  77. if not root:
  78. return 0
  79. return AVL_Tree.getHeight(root.left) - AVL_Tree.getHeight(root.right)
  80. @staticmethod
  81. def preOrder(root, depth=0):
  82. if not root:
  83. return
  84. print("{0}{1}".format(' ' * depth, root))
  85. AVL_Tree.preOrder(root.left, depth+1)
  86. AVL_Tree.preOrder(root.right, depth+1)
  87. @staticmethod
  88. def getitem(root, idx):
  89. #print(root, idx)
  90. s = AVL_Tree.getSize(root.left)
  91. if s == idx:
  92. return root.val
  93. elif idx > s:
  94. return AVL_Tree.getitem(root.right, idx - s - 1)
  95. else:
  96. return AVL_Tree.getitem(root.left, idx)
  97. @staticmethod
  98. def index(root, val):
  99. if not root:
  100. return 0
  101. if root.val < val:
  102. return AVL_Tree.index(root.right, val) + AVL_Tree.getSize(root.left) + 1
  103. else:
  104. return AVL_Tree.index(root.left, val)
  105. # Driver program to test above function
  106. root = None
  107. z = [6,21,-27,17,-20,3,1,-2,10,2,23,15,-3,1,9,19,-9,-24,-30,-26,-13,23,2,-10,20,0,27,24,-28,26,0,-29,-16,0,12,-28,7,1,22,-23,20,-22,-11,7,-10,-5,27,27,0,19,-9,28,-2,6,23,-9,-9,1,8,-15]
  108. for i in z:
  109. root = AVL_Tree.insert(root, i)
  110. AVL_Tree.preOrder(root)
  111. for i in range(root.size):
  112. print(AVL_Tree.getitem(root, i))