Skip to content

Commit f0ba549

Browse files
committed
feat: add depth parameter to the Python flatten function
1 parent f661d19 commit f0ba549

File tree

2 files changed

+50
-6
lines changed

2 files changed

+50
-6
lines changed
Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
from typing import Generator, TypeVar
2+
from math import inf
23

34
T = TypeVar("T")
4-
NestedList = T | list["NestedList"]
5+
NestedList = T | list["nested_list"]
56

67

7-
def flatten(self: NestedList) -> Generator[T, None, None]:
8-
for i in self:
9-
if isinstance(i, list):
10-
yield from flatten(i)
8+
def flatten(
9+
nested_list: NestedList, depth: int = inf
10+
) -> Generator[T, None, None]:
11+
"""Flatten a nested list up to the specified depth.
12+
13+
Args:
14+
nested_list: The nested list to flatten
15+
depth: The maximum depth to flatten. Default is infinity (flatten all levels).
16+
"""
17+
for item in nested_list if isinstance(nested_list, list) else [nested_list]:
18+
if isinstance(item, list) and depth > 0:
19+
yield from flatten(item, depth - 1)
1120
else:
12-
yield i
21+
yield item

workspaces/adventure-pack/goodies/python3/src/flatten/test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,38 @@ def test_flatten_random_types():
1414
flattened = list(flatten(l))
1515
expected = ["hi", -80, "c", 0.0, (3, 3), dictionary, 6, 5, 4, 3, 2, 1, 0]
1616
assert flattened == expected
17+
18+
19+
def test_flatten_with_depth_0():
20+
# Test with depth=0 (no flattening)
21+
input_ = [1, [2, [3, [4, 5]]]]
22+
assert list(flatten(input_, depth=0)) == input_
23+
24+
25+
def test_flatten_with_depth_2():
26+
# Test with depth=2
27+
input_ = [1, [2, [3, [4, 5]]]]
28+
expected = [1, 2, 3, [4, 5]]
29+
assert list(flatten(input_, depth=2)) == expected
30+
31+
32+
def test_flatten_with_depth_greater_than_needed():
33+
# Test with depth larger than needed
34+
input_ = [1, [2, [3, [4, 5]]]]
35+
expected = [1, 2, 3, 4, 5]
36+
assert list(flatten(input_, depth=10)) == expected
37+
38+
39+
def test_flatten_with_empty_lists():
40+
# Test with empty lists at different levels
41+
input_ = [[], [1, [2, [], [3, []]], 4], []]
42+
assert list(flatten(input_, depth=1)) == [1, [2, [], [3, []]], 4]
43+
assert list(flatten(input_, depth=2)) == [1, 2, [], [3, []], 4]
44+
assert list(flatten(input_)) == [1, 2, 3, 4]
45+
46+
47+
def test_flatten_non_list_input():
48+
# Test with non-list input
49+
assert list(flatten(42)) == [42]
50+
assert list(flatten("hello")) == ["hello"]
51+
assert list(flatten(None)) == [None]

0 commit comments

Comments
 (0)