Skip to content

Commit a224cf2

Browse files
Allow string labels in Bin primitive
- remove integer cast for bin labels during deserialization
- update Bin typing to support non-integer labels
- add serialization/transform tests for string labels

Implements #96
1 parent e9a7ada commit a224cf2

2 files changed

Lines changed: 23 additions & 8 deletions

File tree

src/harmonization_framework/primitives/bin_primitive.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
11
from .base import PrimitiveOperation, support_iterable
2-
from typing import List, Tuple
2+
from typing import Any, List, Tuple
33

44
class _IntervalNode:
5-
def __init__(self, label: int, lower: int, upper: int, left=None, right=None):
5+
def __init__(self, label: Any, lower: int, upper: int, left=None, right=None):
66
self.label = label
77
self.lower = lower
88
self.upper = upper
@@ -15,7 +15,7 @@ class Bin(PrimitiveOperation):
1515
Assign values into histogram bins.
1616
Performs a range query using an interval tree. Bins must not overlap.
1717
"""
18-
def __init__(self, bins: List[Tuple[int, Tuple[int]]]):
18+
def __init__(self, bins: List[Tuple[Any, Tuple[int, int]]]):
1919
self.bins = self._validate_bins(bins)
2020
self._tree = self._build_tree(self.bins, 0, len(self.bins)-1)
2121

@@ -35,13 +35,13 @@ def to_dict(self):
3535
return output
3636

3737
@support_iterable
38-
def transform(self, value: int) -> int:
38+
def transform(self, value: int) -> Any:
3939
transformed = self._query(value, self._tree)
4040
if transformed is None:
4141
print(f"Warning: value={value} does not belong to a bin.")
4242
return transformed
4343

44-
def _query(self, value: int, node: _IntervalNode) -> int:
44+
def _query(self, value: int, node: _IntervalNode) -> Any:
4545
if node is None:
4646
return None
4747
if value < node.lower:
@@ -52,7 +52,7 @@ def _query(self, value: int, node: _IntervalNode) -> int:
5252
return node.label
5353
return None
5454

55-
def _build_tree(self, bins: List[Tuple[int, Tuple[int]]], left: int, right: int):
55+
def _build_tree(self, bins: List[Tuple[Any, Tuple[int, int]]], left: int, right: int):
5656
if left > right:
5757
return None
5858

@@ -68,7 +68,7 @@ def _build_tree(self, bins: List[Tuple[int, Tuple[int]]], left: int, right: int)
6868
)
6969
return node
7070

71-
def _validate_bins(self, bins: List[Tuple[int, Tuple[int]]]) -> List[Tuple[int, Tuple[int]]]:
71+
def _validate_bins(self, bins: List[Tuple[Any, Tuple[int, int]]]) -> List[Tuple[Any, Tuple[int, int]]]:
7272
normalized = []
7373
for label, (start, end) in bins:
7474
if start > end:
@@ -91,7 +91,7 @@ def _validate_bins(self, bins: List[Tuple[int, Tuple[int]]]) -> List[Tuple[int,
9191
@classmethod
9292
def from_serialization(cls, serialization):
9393
bins = [
94-
(int(interval["label"]), (int(interval["start"]), int(interval["end"])))
94+
(interval["label"], (int(interval["start"]), int(interval["end"])))
9595
for interval in serialization["bins"]
9696
]
9797
return Bin(bins)

tests/test_primitives_serialization.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,21 @@ def test_bin_serialization_and_transform():
4242
assert primitive.transform([3, 12]) == [low_label, high_label]
4343

4444

45+
def test_bin_string_labels_roundtrip_and_transform():
46+
primitive = Bin([("child", (0, 12)), ("adult", (13, 120))])
47+
payload = primitive.to_dict()
48+
49+
assert payload["bins"] == [
50+
{"label": "child", "start": 0, "end": 12},
51+
{"label": "adult", "start": 13, "end": 120},
52+
]
53+
54+
roundtrip = Bin.from_serialization(payload)
55+
assert roundtrip.to_dict() == payload
56+
assert roundtrip.transform(7) == "child"
57+
assert roundtrip.transform(34) == "adult"
58+
59+
4560
def test_bin_rejects_overlapping_bins():
4661
with pytest.raises(ValueError, match="Overlapping bins detected"):
4762
Bin([(1, (0, 10)), (2, (10, 20))])

0 commit comments

Comments
 (0)