@@ -163,12 +163,19 @@ def validate(self):
163163 )
164164 case "filter_contains" :
165165 if n_args != 2 and n_args != 3 :
166- raise RuleError ("filter_contains(column, value, contains=True) expects 2 or 3 args" )
166+ raise RuleError (
167+ "filter_contains(column, value, contains=True) expects 2 or 3 args"
168+ )
167169 case "filter_compare" :
168170 if n_args != 3 :
169- raise RuleError ("filter_compare(column, op, threshold) expects 3 args" )
171+ raise RuleError (
172+ "filter_compare(column, op, threshold) expects 3 args"
173+ )
170174 case _:
171- raise RuleError (f"Unable to parse function in rules. Function: { self .value } " )
175+ raise RuleError (
176+ f"Unable to parse function in rules. Function: { self .value } "
177+ )
178+
172179
173180@dataclass (frozen = True )
174181class Pipe (Expr ):
@@ -202,7 +209,7 @@ def _as_int(e: Expr) -> int:
202209
203210
204211def _as_bool (e : Expr ) -> bool :
205- if isinstance (e ,(Name , String )):
212+ if isinstance (e , (Name , String )):
206213 name = e .value .lower ()
207214 if name in {"true" , "1" , "yes" , "t" }:
208215 return True
@@ -256,13 +263,15 @@ def call(self, items):
256263 name = str (items [0 ])
257264 args = tuple (items [1 :])
258265 return Call (name , args )
259-
266+
def pipe_(self, items):
    """Merge the two operands of a pipe into one flat ``PipeChain``.

    ``items`` must unpack into exactly two values (left, right). An operand
    that is already a ``PipeChain`` contributes each of its ``calls``; any
    other operand contributes itself as a single element. Left-hand elements
    come first in the resulting chain.
    """
    lhs, rhs = items
    flattened = []
    for operand in (lhs, rhs):
        if isinstance(operand, PipeChain):
            # Splice nested chains so the result never nests PipeChains.
            flattened.extend(operand.calls)
        else:
            flattened.append(operand)
    return PipeChain(tuple(flattened))
268277
@@ -532,7 +541,7 @@ def eval_bool(self, expr: Expr) -> np.ndarray:
532541 raise
533542 self ._memo [expr ] = out
534543 return out
535-
544+
536545 if isinstance (expr , PipeChain ):
537546 out_or_df = self .annotations
538547 kwargs = {"df" : self .annotations , "masks" : []}
@@ -580,7 +589,9 @@ def eval_cycle(self, expr: Steps) -> np.ndarray:
580589 self ._memo_list [expr ] = mat
581590 return mat
582591
583- def eval_call (self , call : Call , df : pl .DataFrame = None , masks : list [pl .Expr ] = None ) -> np .ndarray :
592+ def eval_call (
593+ self , call : Call , df : pl .DataFrame = None , masks : list [pl .Expr ] = None
594+ ) -> np .ndarray :
584595 fn = call .value
585596 args = call .args
586597 kwargs = {}
@@ -605,18 +616,33 @@ def eval_call(self, call: Call, df: pl.DataFrame = None, masks: list[pl.Expr] =
605616 val_thr = _as_float (args [2 ]),
606617 count_op = _as_str (args [3 ]),
607618 count_thr = _as_float (args [4 ]),
608- ** kwargs
619+ ** kwargs ,
609620 )
610621 case "column_sum_values" :
611622 return self .column_count_values (
612- col = _as_str (args [0 ]), op = _as_str (args [1 ]), thr = _as_float (args [2 ]), ** kwargs
623+ col = _as_str (args [0 ]),
624+ op = _as_str (args [1 ]),
625+ thr = _as_float (args [2 ]),
626+ ** kwargs ,
613627 )
614628 case "filter_contains" :
615629 if len (args ) == 2 :
616- return self .filter_contains (col = _as_str (args [0 ]), val = _as_str (args [1 ]), ** kwargs )
617- return self .filter_contains (col = _as_str (args [0 ]), val = _as_str (args [1 ]), contains = _as_bool (args [2 ]), ** kwargs )
630+ return self .filter_contains (
631+ col = _as_str (args [0 ]), val = _as_str (args [1 ]), ** kwargs
632+ )
633+ return self .filter_contains (
634+ col = _as_str (args [0 ]),
635+ val = _as_str (args [1 ]),
636+ contains = _as_bool (args [2 ]),
637+ ** kwargs ,
638+ )
618639 case "filter_compare" :
619- return self .filter_compare (col = _as_str (args [0 ]), op = _as_str (args [1 ]), thr = _as_float (args [2 ]), ** kwargs )
640+ return self .filter_compare (
641+ col = _as_str (args [0 ]),
642+ op = _as_str (args [1 ]),
643+ thr = _as_float (args [2 ]),
644+ ** kwargs ,
645+ )
620646 case _:
621647 raise RuleError (f"Unable to parse function in rules. Function: { fn } " )
622648
@@ -626,26 +652,35 @@ def _sort_df_to_ordered_df(self, df):
626652 .sort ("_order" )
627653 .drop ("_order" )
628654 )
629-
655+
def eval_filter_dec(func):
    """Wrap a filter method so it accepts optional ``df`` and ``masks`` keywords.

    The wrapper resolves ``df`` (falling back to ``self.annotations`` when it
    is None) and, if any masks are given, rewrites column ``col`` so that it
    keeps its value only where the AND of all masks holds and is null
    elsewhere. The wrapped function is then called with the resulting frame;
    ``masks`` itself is consumed here and not forwarded.
    """

    @functools.wraps(func)
    def wrapper(
        self,
        col,
        *args,
        df: pl.DataFrame = None,
        masks: list[pl.Expr] = None,
        **kwargs,
    ):
        frame = df if df is not None else self.annotations
        if not masks:
            # Nothing to mask (None or empty): pass the frame through untouched.
            return func(self, col, *args, df=frame, **kwargs)

        # Fold every mask into a single boolean expression.
        combined = pl.lit(True).and_(*masks)
        masked_col = pl.when(combined).then(pl.col(col)).otherwise(pl.lit(None))
        frame = frame.with_columns(masked_col)
        return func(self, col, *args, df=frame, **kwargs)

    return wrapper
645680
646681 # Call functions
647682 @staticmethod
648- def not_ (x : np .ndarray = None ) -> np .ndarray | pl .Expr :
683+ def not_ (x : np .ndarray = None ) -> np .ndarray | pl .Expr :
649684 if isinstance (x , list ) and len (x ) > 0 or isinstance (x , pl .Expr ):
650685 if isinstance (x , pl .Expr ):
651686 x = [x ]
@@ -671,7 +706,15 @@ def at_least(k: int, x: np.ndarray) -> np.ndarray:
671706
672707 @eval_filter_dec
673708 def column_count_values (
674- self , col : str , val_op : str , val_thr : float , count_op : str , count_thr : float , * , df : pl .DataFrame ) -> np .ndarray :
709+ self ,
710+ col : str ,
711+ val_op : str ,
712+ val_thr : float ,
713+ count_op : str ,
714+ count_thr : float ,
715+ * ,
716+ df : pl .DataFrame ,
717+ ) -> np .ndarray :
675718 df = self .annotations if df is None else df
676719
677720 if col not in df .columns :
@@ -703,7 +746,9 @@ def column_count_values(
703746 return df .select (pl .col (col )).to_series ().to_numpy ()
704747
705748 @eval_filter_dec
706- def column_sum_values (self , col : str , op : str , thr : float , * , df : pl .DataFrame = None ) -> np .ndarray :
749+ def column_sum_values (
750+ self , col : str , op : str , thr : float , * , df : pl .DataFrame = None
751+ ) -> np .ndarray :
707752 df = self .annotations if df is None else df
708753
709754 if col not in df .columns :
@@ -719,13 +764,17 @@ def column_sum_values(self, col: str, op: str, thr: float, *, df: pl.DataFrame =
719764 df = self ._sort_df_to_ordered_df (df )
720765 return df .select (pl .col (col )).to_series ().to_numpy ()
721766
def filter_contains(
    self,
    col: str,
    val: str,
    df: pl.DataFrame | None = None,
    *,
    contains: bool = True,
    **kwargs,
) -> pl.Expr:
    """Build a polars expression testing whether column ``col`` contains ``val``.

    Args:
        col: Name of the column to test; must be present in ``df``.
        val: Pattern handed to ``Expr.str.contains``.
        df: Frame used only to check that ``col`` exists; defaults to
            ``self.annotations`` when None.
        contains: When True (default) the expression is true where the column
            contains ``val``; when False the expression is negated.
        **kwargs: Extra keywords are accepted and ignored.

    Returns:
        An unevaluated ``pl.Expr`` (not a DataFrame).

    Raises:
        RuleError: If ``col`` is not a column of ``df``.
    """
    df = self.annotations if df is None else df
    if col not in df:
        raise RuleError(f"Missing column '{col}' for filter_contains()")
    matches = pl.col(col).str.contains(val)
    # ``contains`` used to be swallowed by **kwargs, so contains=False had no
    # effect; honour it by inverting the match expression.
    return matches if contains else ~matches
727-
728- def filter_compare (self , col : str , op : str , thr : float , df : pl .DataFrame = None , ** kwargs ) -> pl .DataFrame :
774+
775+ def filter_compare (
776+ self , col : str , op : str , thr : float , df : pl .DataFrame = None , ** kwargs
777+ ) -> pl .DataFrame :
729778 df = self .annotations if df is None else df
730779 if col not in df :
731780 raise RuleError (f"Missing column '{ col } ' for filter_compare()" )
0 commit comments