11from .base import PrimitiveOperation , support_iterable
2- from typing import List , Tuple
2+ from typing import Any , List , Tuple
33
44class _IntervalNode :
5- def __init__ (self , label : int , lower : int , upper : int , left = None , right = None ):
5+ def __init__ (self , label : Any , lower : int , upper : int , left = None , right = None ):
66 self .label = label
77 self .lower = lower
88 self .upper = upper
@@ -15,7 +15,7 @@ class Bin(PrimitiveOperation):
1515 Assign values into histogram bins.
1616 Performs a range query using an interval tree. Bins must not overlap.
1717 """
18- def __init__ (self , bins : List [Tuple [int , Tuple [int ]]]):
18+ def __init__ (self , bins : List [Tuple [Any , Tuple [int , int ]]]):
1919 self .bins = self ._validate_bins (bins )
2020 self ._tree = self ._build_tree (self .bins , 0 , len (self .bins )- 1 )
2121
@@ -35,13 +35,13 @@ def to_dict(self):
3535 return output
3636
3737 @support_iterable
38- def transform (self , value : int ) -> int :
38+ def transform (self , value : int ) -> Any :
3939 transformed = self ._query (value , self ._tree )
4040 if transformed is None :
4141 print (f"Warning: value={ value } does not belong to a bin." )
4242 return transformed
4343
44- def _query (self , value : int , node : _IntervalNode ) -> int :
44+ def _query (self , value : int , node : _IntervalNode ) -> Any :
4545 if node is None :
4646 return None
4747 if value < node .lower :
@@ -52,7 +52,7 @@ def _query(self, value: int, node: _IntervalNode) -> int:
5252 return node .label
5353 return None
5454
55- def _build_tree (self , bins : List [Tuple [int , Tuple [int ]]], left : int , right : int ):
55+ def _build_tree (self , bins : List [Tuple [Any , Tuple [int , int ]]], left : int , right : int ):
5656 if left > right :
5757 return None
5858
@@ -68,7 +68,7 @@ def _build_tree(self, bins: List[Tuple[int, Tuple[int]]], left: int, right: int)
6868 )
6969 return node
7070
71- def _validate_bins (self , bins : List [Tuple [int , Tuple [int ]]]) -> List [Tuple [int , Tuple [int ]]]:
71+ def _validate_bins (self , bins : List [Tuple [Any , Tuple [int , int ]]]) -> List [Tuple [Any , Tuple [int , int ]]]:
7272 normalized = []
7373 for label , (start , end ) in bins :
7474 if start > end :
@@ -91,7 +91,7 @@ def _validate_bins(self, bins: List[Tuple[int, Tuple[int]]]) -> List[Tuple[int,
9191 @classmethod
9292 def from_serialization (cls , serialization ):
9393 bins = [
94- (int ( interval ["label" ]) , (int (interval ["start" ]), int (interval ["end" ])))
94+ (interval ["label" ], (int (interval ["start" ]), int (interval ["end" ])))
9595 for interval in serialization ["bins" ]
9696 ]
9797 return Bin (bins )
0 commit comments