forked from TheAlgorithms/Python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsegment_tree.py
More file actions
171 lines (142 loc) · 4.83 KB
/
segment_tree.py
File metadata and controls
171 lines (142 loc) · 4.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""Segment Tree Data Structure.

A Segment Tree is a binary tree built over an array in which every node
summarizes a contiguous range (segment) of that array: leaves hold single
elements and each internal node holds the combination of its two children.
It answers range queries over ``arr[left..right]`` (range sums with the
default combining operation, addition) and supports point updates.

Time Complexity:
- Build: O(n)
- Query: O(log n)
- Update: O(log n)

Space Complexity: O(n)
"""
from collections.abc import Callable
class SegmentTree:
"""Segment Tree implementation for range queries.
This implementation supports range sum queries and point updates.
Can be extended to support other operations like min/max queries.
Attributes:
tree: List storing the segment tree nodes
n: Size of the input array
operation: Function to combine two values (default: addition)
>>> st = SegmentTree([1, 3, 5, 7, 9, 11])
>>> st.query(1, 3)
15
>>> st.update(1, 10)
>>> st.query(1, 3)
22
>>> st.query(0, 5)
42
"""
def __init__(
self, arr: list[int], operation: Callable[[int, int], int] = lambda a, b: a + b
) -> None:
"""Initialize segment tree with array.
Args:
arr: Input array
operation: Function to combine two values (default: sum)
>>> st = SegmentTree([1, 2, 3, 4, 5])
>>> st.n
5
>>> len(st.tree)
20
"""
self.n = len(arr)
self.tree = [0] * (4 * self.n)
self.operation = operation
self._build(arr, 0, 0, self.n - 1)
def _build(self, arr: list[int], node: int, start: int, end: int) -> None:
"""Build segment tree recursively.
Args:
arr: Input array
node: Current node index
start: Start of current segment
end: End of current segment
"""
if start == end:
# Leaf node
self.tree[node] = arr[start]
else:
mid = (start + end) // 2
left_child = 2 * node + 1
right_child = 2 * node + 2
self._build(arr, left_child, start, mid)
self._build(arr, right_child, mid + 1, end)
self.tree[node] = self.operation(
self.tree[left_child], self.tree[right_child]
)
def query(self, left: int, right: int) -> int:
"""Query sum of elements in range [left, right].
Args:
left: Left boundary of query range
right: Right boundary of query range
Returns:
Sum of elements in range
>>> st = SegmentTree([1, 2, 3, 4, 5])
>>> st.query(0, 2)
6
>>> st.query(1, 3)
9
"""
return self._query(0, 0, self.n - 1, left, right)
def _query(self, node: int, start: int, end: int, left: int, right: int) -> int:
"""Recursive helper for range query.
Args:
node: Current node index
start: Start of current segment
end: End of current segment
left: Left boundary of query range
right: Right boundary of query range
Returns:
Query result for the range
"""
if right < start or left > end:
# No overlap
return 0
if left <= start and end <= right:
# Complete overlap
return self.tree[node]
# Partial overlap
mid = (start + end) // 2
left_child = 2 * node + 1
right_child = 2 * node + 2
left_sum = self._query(left_child, start, mid, left, right)
right_sum = self._query(right_child, mid + 1, end, left, right)
return self.operation(left_sum, right_sum)
def update(self, index: int, value: int) -> None:
"""Update value at given index.
Args:
index: Index to update
value: New value
>>> st = SegmentTree([1, 2, 3, 4, 5])
>>> st.query(0, 4)
15
>>> st.update(2, 10)
>>> st.query(0, 4)
22
"""
self._update(0, 0, self.n - 1, index, value)
def _update(self, node: int, start: int, end: int, index: int, value: int) -> None:
"""Recursive helper for point update.
Args:
node: Current node index
start: Start of current segment
end: End of current segment
index: Index to update
value: New value
"""
if start == end:
# Leaf node
self.tree[node] = value
else:
mid = (start + end) // 2
left_child = 2 * node + 1
right_child = 2 * node + 2
if index <= mid:
self._update(left_child, start, mid, index, value)
else:
self._update(right_child, mid + 1, end, index, value)
self.tree[node] = self.operation(
self.tree[left_child], self.tree[right_child]
)
if __name__ == "__main__":
    # Run the doctests embedded in this module's docstrings.
    from doctest import testmod

    testmod()