-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrbtree.py
More file actions
319 lines (271 loc) · 8.85 KB
/
rbtree.py
File metadata and controls
319 lines (271 loc) · 8.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
"""
红黑树的性质:
- 每一个节点要么是黑色, 要么是红色.
- 根节点是黑色的.
- 叶子节点是黑色的.
- 如果一个节点是红色的, 那么它的子节点必须是黑色的.
- 从一个节点到NULL节点的所有路径必须包含相同数量的黑色节点.
"""
import graphviz
# Node colour tags; they are also what RBNode.__str__ shows after the value.
RED = 'R'
BLACK = 'B'

# When False, every call to the module-level print() below is a silent no-op.
DEBUG = False

# Keep a handle on the builtin before the name is rebound just below.
_print = print


def print(*args, **kwargs):
    """Debug-gated replacement for the builtin ``print``.

    Forwards all positional and keyword arguments (``sep``, ``end``,
    ``file``, ``flush``) to the builtin when ``DEBUG`` is true, and does
    nothing otherwise. Always returns ``None``.
    """
    if DEBUG:
        _print(*args, **kwargs)
class RBNode:
    """One node of a red-black tree.

    Holds a value, a colour tag and links to the left child, right child
    and parent. Every link is ``None`` when the corresponding relative
    does not exist.
    """

    def __init__(self):
        self.value: int = None
        # Nodes are created red.
        self.color = RED
        self.left: "RBNode" = None
        self.right: "RBNode" = None
        self.parent: "RBNode" = None

    def get_parent(self) -> "RBNode":
        """Return the parent node, or ``None`` if this node has no parent."""
        return self.parent

    def get_sibling(self) -> "RBNode":
        """Return the parent's other child.

        Returns ``None`` when this node has no parent or the parent has no
        other child.
        """
        p = self.get_parent()
        if p is None:
            return None
        return p.right if self is p.left else p.left

    def get_grandparent(self) -> "RBNode":
        """Return the parent's parent, or ``None`` if either is missing."""
        p = self.get_parent()
        if p is None:
            return None
        return p.get_parent()

    def get_uncle(self) -> "RBNode":
        """Return the parent's sibling.

        Returns ``None`` when this node has no parent, when the parent has
        no parent, or when the parent has no sibling.
        """
        p = self.get_parent()
        if p is None:
            # A parentless node has no uncle; mirror get_sibling() and
            # get_grandparent() instead of failing on None.
            return None
        return p.get_sibling()

    def __str__(self):
        return f'{self.value}.{self.color}'
class RBTree:
    """Red-black tree built from ``RBNode`` objects.

    ``self.root`` is the top node, or ``None`` for an empty tree. Values
    are ordered with ``<``; a value that is not smaller than a node is
    placed in that node's right subtree, so duplicates are allowed.
    """

    def __init__(self):
        self.root = None

    def find(self, value) -> "RBNode":
        """Return the first node whose value equals ``value``, or ``None``."""
        curr = self.root
        while curr is not None and curr.value != value:
            curr = curr.left if value < curr.value else curr.right
        return curr

    def add(self, value: int):
        """Insert ``value`` as a new red node and rebalance the tree."""
        new_node = RBNode()
        new_node.value = value
        new_node.color = RED
        curr = self.root
        # Plain binary-search-tree descent; stop at the node that will
        # become the parent of the new node.
        while curr is not None:
            if new_node.value < curr.value:
                if curr.left is None:
                    curr.left = new_node
                    break
                curr = curr.left
            else:
                if curr.right is None:
                    curr.right = new_node
                    break
                curr = curr.right
        # curr is None only when the tree was empty.
        new_node.parent = curr
        self._insert(new_node)

    def _insert(self, node: "RBNode") -> None:
        """Restore the red-black properties after ``node`` was linked in red."""
        print(f'insert node: {node.value}')
        parent = node.get_parent()
        # case 1: node is the root -> paint it black.
        if parent is None:
            print('insert case 1')
            node.color = BLACK
            self.root = node
            return
        # case 2: black parent -> nothing is violated.
        if parent.color == BLACK:
            print('insert case 2')
            return
        uncle = node.get_uncle()
        grandparent = node.get_grandparent()
        # case 3: red parent and red uncle -> recolour and continue the
        # repair from the grandparent.
        if uncle is not None and uncle.color == RED:
            print('insert case 3')
            parent.color = uncle.color = BLACK
            grandparent.color = RED
            return self._insert(grandparent)
        # case 4: red parent, black (or missing) uncle -> rotate.
        print('insert case 4')
        grandparent_was_root = grandparent is self.root
        if parent is grandparent.left:
            if node is parent.right:
                # Inner grandchild: rotate it to the outside first.
                self._rotate_left(parent)
                node, parent = parent, node
            self._rotate_right(grandparent)
        else:
            if node is parent.left:
                self._rotate_right(parent)
                node, parent = parent, node
            self._rotate_left(grandparent)
        parent.color = BLACK
        grandparent.color = RED
        if grandparent_was_root:
            # self.root was already moved by the rotation helper.
            print(f'grandparent {grandparent} is root and it\'s rotated, set parent `{parent}`to be new root')

    def remove(self, value):
        """Remove one node holding ``value``; do nothing if it is absent.

        When the matching node has a left subtree, the largest value of
        that subtree is copied into it and that predecessor node is the
        one physically unlinked.
        """
        node = self.find(value)
        if node is None:
            print(f'remove: can\'t find value `{value}`')
            return
        m = self.getmax(node.left)
        if m is not None:
            node.value = m.value
            node = m
        parent = node.get_parent()
        # simple case 1: a red node is simply detached from its parent.
        if node.color == RED:
            if node is parent.left:
                parent.left = None
            else:
                parent.right = None
            return
        # simple case 2: a black node with a child is replaced by that
        # child, which is painted black.
        child = node.left if node.left is not None else node.right
        if child is not None:
            child.parent = parent
            child.color = BLACK
        else:
            # Black node without children: rebalance while the node is
            # still linked in, then unlink it below.
            self._remove(node)
        # node may be self.root
        if node is self.root:
            self.root = child
        elif node is parent.left:
            parent.left = child
        else:
            parent.right = child

    def _remove(self, node: "RBNode"):
        """Rebalance before the black, childless ``node`` is unlinked."""
        parent = node.get_parent()
        # case 1: node is the root -> nothing to repair.
        if parent is None:
            return
        sibling = node.get_sibling()
        # case 2: red sibling -> swap colours with the parent and rotate so
        # that node gets a black sibling. The helper keeps self.root valid
        # when parent was the root.
        if sibling.color == RED:
            parent.color = RED
            sibling.color = BLACK
            if node is parent.left:
                self._rotate_left(parent)
            else:
                self._rotate_right(parent)
            # The rotation gave node a new sibling; its parent is unchanged.
            sibling = node.get_sibling()
        nephews_black = (self._is_black(sibling.left)
                         and self._is_black(sibling.right))
        # case 3: parent, sibling and the sibling's children are all black
        # -> paint the sibling red and repair one level up.
        if parent.color == BLACK and sibling.color == BLACK and nephews_black:
            sibling.color = RED
            self._remove(parent)
            return
        # case 4: red parent, black sibling with black children -> swap the
        # colours of parent and sibling. The children must be checked here:
        # a red one has to be handled by cases 5/6 instead.
        if parent.color == RED and sibling.color == BLACK and nephews_black:
            sibling.color = RED
            parent.color = BLACK
            return
        # case 5: the sibling's red child is on the side nearer to node ->
        # rotate it to the far side.
        if (node is parent.left and not self._is_black(sibling.left)
                and self._is_black(sibling.right)):
            sibling.color = RED
            sibling.left.color = BLACK
            self._rotate_right(sibling)
        elif (node is parent.right and not self._is_black(sibling.right)
                and self._is_black(sibling.left)):
            sibling.color = RED
            sibling.right.color = BLACK
            self._rotate_left(sibling)
        # case 6: the sibling's far child is red -> rotate at the parent.
        sibling = node.get_sibling()
        sibling.color = parent.color
        parent.color = BLACK
        if node is parent.left:
            sibling.right.color = BLACK
            self._rotate_left(parent)
        else:
            sibling.left.color = BLACK
            self._rotate_right(parent)

    @staticmethod
    def _is_black(node) -> bool:
        """Return True for a missing child or a node coloured black."""
        return node is None or node.color == BLACK

    def _rotate_left(self, node):
        """Left-rotate at ``node`` and move ``self.root`` if ``node`` was it."""
        self.rotateleft(node)
        if node is self.root and node.parent is not None:
            self.root = node.parent

    def _rotate_right(self, node):
        """Right-rotate at ``node`` and move ``self.root`` if ``node`` was it."""
        self.rotateright(node)
        if node is self.root and node.parent is not None:
            self.root = node.parent

    def getmin(self, root):
        """Return the leftmost node under ``root`` (``None`` if ``root`` is)."""
        if root is None:
            return None
        while root.left is not None:
            root = root.left
        return root

    def getmax(self, root):
        """Return the rightmost node under ``root`` (``None`` if ``root`` is)."""
        if root is None:
            return None
        while root.right is not None:
            root = root.right
        return root

    @staticmethod
    def rotateleft(node: "RBNode"):
        """Left-rotate at ``node``: its right child takes its place.

        Does nothing when ``node`` has no right child. Child and parent
        links are rewired; the tree's ``root`` attribute is not touched.
        """
        child = node.right
        if child is None:
            return
        node.right, child.left = child.left, node
        if node.right is not None:
            node.right.parent = node
        if node.parent is not None:
            if node is node.parent.left:
                node.parent.left = child
            else:
                node.parent.right = child
        # The right-hand side is evaluated first, so child receives node's
        # old parent.
        node.parent, child.parent = child, node.parent

    @staticmethod
    def rotateright(node):
        """Right-rotate at ``node``: its left child takes its place.

        Does nothing when ``node`` has no left child. Child and parent
        links are rewired; the tree's ``root`` attribute is not touched.
        """
        child = node.left
        if child is None:
            return
        node.left, child.right = child.right, node
        if node.left is not None:
            node.left.parent = node
        if node.parent is not None:
            if node is node.parent.left:
                node.parent.left = child
            else:
                node.parent.right = child
        node.parent, child.parent = child, node.parent

    def display(self):
        """Print the tree level by level, one line per level.

        Missing children are shown as ``N.B``. Output goes through the
        module-level ``print``.
        """
        level = [self.root]
        while level:
            labels = []
            next_level = []
            for n in level:
                if n is None:
                    labels.append(f'N.{BLACK}')
                else:
                    labels.append(str(n))
                    next_level.append(n.left)
                    next_level.append(n.right)
            print(', '.join(labels))
            level = next_level
        print()

    def display2(self):
        """Render the tree with graphviz to ``./rbtree.dot`` (PNG format).

        Graph node names are the string form of each value.
        """
        graph = graphviz.Digraph(name='rbtree', format='png', node_attr={'fontcolor': 'white', 'shape': 'circle', 'fixedsize': 'True'})

        def dis(node):
            # Adds node and its subtree to the graph; returns the graph
            # name of node so that the caller can draw the edge to it.
            if node is None:
                return
            name = label = str(node.value)
            color = 'black' if node.color == BLACK else 'red'
            graph.node(name, label, fillcolor=color, pencolor=color, bgcolor=color, color=color, style='filled')
            if node.left:
                left_name = dis(node.left)
                graph.edge(name, left_name)
            if node.right:
                right_name = dis(node.right)
                graph.edge(name, right_name)
            return name

        dis(self.root)
        graph.render('./rbtree.dot')
if __name__ == '__main__':
    # Demo: fill a tree with 1..n, render it once with graphviz, then
    # remove a few values and dump the tree after each removal.
    rbtree = RBTree()
    n = 20
    for i in range(1, n + 1):
        rbtree.add(i)
    rbtree.display2()
    for removed in (3, 5, 8, 9):
        rbtree.remove(removed)
        rbtree.display()