Skip to content

Commit 871ca40

Browse files
Improve quick_select with type hints, edge case handling, and documentation
1 parent 841e947 commit 871ca40

File tree

1 file changed

+70
-22
lines changed

1 file changed

+70
-22
lines changed

searches/quick_select.py

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,29 @@
11
"""
22
A Python implementation of the quick select algorithm, which is efficient for
33
calculating the value that would appear in the index of a list if it would be
4-
sorted, even if it is not already sorted
4+
sorted, even if it is not already sorted.
55
https://en.wikipedia.org/wiki/Quickselect
6+
7+
Time Complexity:
8+
Average: O(n)
9+
Worst: O(n^2)
610
"""
711

812
import random
13+
from typing import List, Optional, Tuple
914

1015

11-
def _partition(data: list, pivot) -> tuple:
16+
def _partition(data: List[int], pivot: int) -> Tuple[List[int], List[int], List[int]]:
1217
"""
13-
Three way partition the data into smaller, equal and greater lists,
14-
in relationship to the pivot
15-
:param data: The data to be sorted (a list)
16-
:param pivot: The value to partition the data on
17-
:return: Three list: smaller, equal and greater
18+
Partition the input list into three lists relative to a pivot.
19+
20+
Args:
21+
data (List[int]): Input list.
22+
pivot (int): Pivot value.
23+
24+
Returns:
25+
Tuple[List[int], List[int], List[int]]:
26+
Lists of elements less than, equal to, and greater than the pivot.
1827
"""
1928
less, equal, greater = [], [], []
2029
for element in data:
@@ -27,8 +36,23 @@ def _partition(data: list, pivot) -> tuple:
2736
return less, equal, greater
2837

2938

30-
def quick_select(items: list, index: int):
39+
def quick_select(items: List[int], index: int) -> Optional[int]:
3140
"""
41+
Return the element that would appear at the given index if the list
42+
were sorted, without fully sorting the list.
43+
44+
Args:
45+
items (List[int]): The unsorted input list.
46+
index (int): The zero-based target index in the sorted order.
47+
48+
Returns:
49+
Optional[int]: The element at the given sorted index, or None for
50+
invalid input.
51+
52+
Time Complexity:
53+
Average: O(n)
54+
Worst: O(n^2)
55+
3256
>>> quick_select([2, 4, 5, 7, 899, 54, 32], 5)
3357
54
3458
>>> quick_select([2, 4, 5, 7, 899, 54, 32], 1)
@@ -37,17 +61,18 @@ def quick_select(items: list, index: int):
3761
4
3862
>>> quick_select([3, 5, 7, 10, 2, 12], 3)
3963
7
64+
>>> quick_select([], 0) is None
65+
True
4066
"""
41-
# index = len(items) // 2 when trying to find the median
42-
# (value of index when items is sorted)
67+
if not items:
68+
return None
4369

44-
# invalid input
45-
if index >= len(items) or index < 0:
70+
if not 0 <= index < len(items):
4671
return None
4772

48-
pivot = items[random.randint(0, len(items) - 1)]
49-
count = 0
50-
smaller, equal, larger = _partition(items, pivot)
73+
pivot = random.choice(items)
74+
smaller, equal, greater = _partition(items, pivot)
75+
5176
count = len(equal)
5277
m = len(smaller)
5378

@@ -57,28 +82,51 @@ def quick_select(items: list, index: int):
5782
# must be in smaller
5883
elif m > index:
5984
return quick_select(smaller, index)
60-
# must be in larger
85+
# must be in greater
6186
else:
62-
return quick_select(larger, index - (m + count))
87+
return quick_select(greater, index - (m + count))
6388

6489

65-
def median(items: list):
90+
def median(items: List[int]) -> Optional[float]:
6691
"""
92+
Find the median of an unsorted list using Quickselect.
93+
6794
One common application of Quickselect is finding the median, which is
68-
the middle element (or average of the two middle elements) in a sorted dataset.
69-
It works efficiently on unsorted lists by partially sorting the data without
70-
fully sorting the entire list.
95+
the middle element (or average of the two middle elements) in a sorted
96+
dataset. It works efficiently on unsorted lists by partially sorting the
97+
data without fully sorting the entire list.
98+
99+
Args:
100+
items (List[int]): The input list.
101+
102+
Returns:
103+
Optional[float]: The median value, or None if the list is empty.
71104
72105
>>> median([3, 2, 2, 9, 9])
73106
3
74-
75107
>>> median([2, 2, 9, 9, 9, 3])
76108
6.0
109+
>>> median([]) is None
110+
True
77111
"""
112+
if not items:
113+
return None
114+
78115
mid, rem = divmod(len(items), 2)
79116
if rem != 0:
80117
return quick_select(items=items, index=mid)
81118
else:
82119
low_mid = quick_select(items=items, index=mid - 1)
83120
high_mid = quick_select(items=items, index=mid)
84121
return (low_mid + high_mid) / 2
122+
123+
124+
if __name__ == "__main__":
125+
assert quick_select([1, 2, 3], 1) == 2
126+
assert quick_select([], 0) is None
127+
assert quick_select([5, 4, 3, 2], 2) == 4
128+
assert quick_select([3, 5, 7, 10, 2, 12], 3) == 7
129+
assert median([3, 2, 2, 9, 9]) == 3
130+
assert median([2, 2, 9, 9, 9, 3]) == 6.0
131+
assert median([]) is None
132+
print("All assertions passed.")

0 commit comments

Comments
 (0)