Skip to content

Commit b3c5bde

Browse files
Formatting
1 parent 784c9b6 commit b3c5bde

1 file changed

Lines changed: 71 additions & 22 deletions

File tree

src/rules.py

Lines changed: 71 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,19 @@ def validate(self):
163163
)
164164
case "filter_contains":
165165
if n_args != 2 and n_args != 3:
166-
raise RuleError("filter_contains(column, value, contains=True) expects 2 or 3 args")
166+
raise RuleError(
167+
"filter_contains(column, value, contains=True) expects 2 or 3 args"
168+
)
167169
case "filter_compare":
168170
if n_args != 3:
169-
raise RuleError("filter_compare(column, op, threshold) expects 3 args")
171+
raise RuleError(
172+
"filter_compare(column, op, threshold) expects 3 args"
173+
)
170174
case _:
171-
raise RuleError(f"Unable to parse function in rules. Function: {self.value}")
175+
raise RuleError(
176+
f"Unable to parse function in rules. Function: {self.value}"
177+
)
178+
172179

173180
@dataclass(frozen=True)
174181
class Pipe(Expr):
@@ -202,7 +209,7 @@ def _as_int(e: Expr) -> int:
202209

203210

204211
def _as_bool(e: Expr) -> bool:
205-
if isinstance(e,(Name, String)):
212+
if isinstance(e, (Name, String)):
206213
name = e.value.lower()
207214
if name in {"true", "1", "yes", "t"}:
208215
return True
@@ -256,13 +263,15 @@ def call(self, items):
256263
name = str(items[0])
257264
args = tuple(items[1:])
258265
return Call(name, args)
259-
266+
260267
def pipe_(self, items):
261268
left, right = items
269+
262270
def to_list(x):
263271
if isinstance(x, PipeChain):
264272
return list(x.calls)
265273
return [x]
274+
266275
calls = tuple(to_list(left) + to_list(right))
267276
return PipeChain(calls)
268277

@@ -532,7 +541,7 @@ def eval_bool(self, expr: Expr) -> np.ndarray:
532541
raise
533542
self._memo[expr] = out
534543
return out
535-
544+
536545
if isinstance(expr, PipeChain):
537546
out_or_df = self.annotations
538547
kwargs = {"df": self.annotations, "masks": []}
@@ -580,7 +589,9 @@ def eval_cycle(self, expr: Steps) -> np.ndarray:
580589
self._memo_list[expr] = mat
581590
return mat
582591

583-
def eval_call(self, call: Call, df: pl.DataFrame = None, masks: list[pl.Expr] = None) -> np.ndarray:
592+
def eval_call(
593+
self, call: Call, df: pl.DataFrame = None, masks: list[pl.Expr] = None
594+
) -> np.ndarray:
584595
fn = call.value
585596
args = call.args
586597
kwargs = {}
@@ -605,18 +616,33 @@ def eval_call(self, call: Call, df: pl.DataFrame = None, masks: list[pl.Expr] =
605616
val_thr=_as_float(args[2]),
606617
count_op=_as_str(args[3]),
607618
count_thr=_as_float(args[4]),
608-
**kwargs
619+
**kwargs,
609620
)
610621
case "column_sum_values":
611622
return self.column_count_values(
612-
col=_as_str(args[0]), op=_as_str(args[1]), thr=_as_float(args[2]), **kwargs
623+
col=_as_str(args[0]),
624+
op=_as_str(args[1]),
625+
thr=_as_float(args[2]),
626+
**kwargs,
613627
)
614628
case "filter_contains":
615629
if len(args) == 2:
616-
return self.filter_contains(col=_as_str(args[0]), val=_as_str(args[1]), **kwargs)
617-
return self.filter_contains(col=_as_str(args[0]), val=_as_str(args[1]), contains=_as_bool(args[2]), **kwargs)
630+
return self.filter_contains(
631+
col=_as_str(args[0]), val=_as_str(args[1]), **kwargs
632+
)
633+
return self.filter_contains(
634+
col=_as_str(args[0]),
635+
val=_as_str(args[1]),
636+
contains=_as_bool(args[2]),
637+
**kwargs,
638+
)
618639
case "filter_compare":
619-
return self.filter_compare(col=_as_str(args[0]), op=_as_str(args[1]), thr=_as_float(args[2]), **kwargs)
640+
return self.filter_compare(
641+
col=_as_str(args[0]),
642+
op=_as_str(args[1]),
643+
thr=_as_float(args[2]),
644+
**kwargs,
645+
)
620646
case _:
621647
raise RuleError(f"Unable to parse function in rules. Function: {fn}")
622648

@@ -626,26 +652,35 @@ def _sort_df_to_ordered_df(self, df):
626652
.sort("_order")
627653
.drop("_order")
628654
)
629-
655+
630656
def eval_filter_dec(func):
631657
"""Decorator to evaluate filter functions with optional df and masks
632-
Applies masks to the specified column before calling the function."""
658+
Applies masks to the specified column before calling the function."""
659+
633660
@functools.wraps(func)
634-
def wrapper(self, col, *args, df: pl.DataFrame = None, masks: list[pl.Expr] = None, **kwargs):
661+
def wrapper(
662+
self,
663+
col,
664+
*args,
665+
df: pl.DataFrame = None,
666+
masks: list[pl.Expr] = None,
667+
**kwargs,
668+
):
635669
df = self.annotations if df is None else df
636670
masks = masks or []
637671
if masks:
638672
df = df.with_columns(
639673
pl.when(pl.lit(True).and_(*masks))
640674
.then(pl.col(col))
641675
.otherwise(pl.lit(None))
642-
)
676+
)
643677
return func(self, col, *args, df=df, **kwargs)
678+
644679
return wrapper
645680

646681
# Call functions
647682
@staticmethod
648-
def not_(x: np.ndarray = None) -> np.ndarray|pl.Expr:
683+
def not_(x: np.ndarray = None) -> np.ndarray | pl.Expr:
649684
if isinstance(x, list) and len(x) > 0 or isinstance(x, pl.Expr):
650685
if isinstance(x, pl.Expr):
651686
x = [x]
@@ -671,7 +706,15 @@ def at_least(k: int, x: np.ndarray) -> np.ndarray:
671706

672707
@eval_filter_dec
673708
def column_count_values(
674-
self, col: str, val_op: str, val_thr: float, count_op: str, count_thr: float, *, df: pl.DataFrame) -> np.ndarray:
709+
self,
710+
col: str,
711+
val_op: str,
712+
val_thr: float,
713+
count_op: str,
714+
count_thr: float,
715+
*,
716+
df: pl.DataFrame,
717+
) -> np.ndarray:
675718
df = self.annotations if df is None else df
676719

677720
if col not in df.columns:
@@ -703,7 +746,9 @@ def column_count_values(
703746
return df.select(pl.col(col)).to_series().to_numpy()
704747

705748
@eval_filter_dec
706-
def column_sum_values(self, col: str, op: str, thr: float, *, df: pl.DataFrame = None) -> np.ndarray:
749+
def column_sum_values(
750+
self, col: str, op: str, thr: float, *, df: pl.DataFrame = None
751+
) -> np.ndarray:
707752
df = self.annotations if df is None else df
708753

709754
if col not in df.columns:
@@ -719,13 +764,17 @@ def column_sum_values(self, col: str, op: str, thr: float, *, df: pl.DataFrame =
719764
df = self._sort_df_to_ordered_df(df)
720765
return df.select(pl.col(col)).to_series().to_numpy()
721766

722-
def filter_contains(self, col: str, val: str, df: pl.DataFrame = None, **kwargs) -> pl.DataFrame:
767+
def filter_contains(
768+
self, col: str, val: str, df: pl.DataFrame = None, **kwargs
769+
) -> pl.DataFrame:
723770
df = self.annotations if df is None else df
724771
if col not in df:
725772
raise RuleError(f"Missing column '{col}' for filter_contains()")
726773
return pl.col(col).str.contains(val)
727-
728-
def filter_compare(self, col: str, op: str, thr: float, df: pl.DataFrame = None, **kwargs) -> pl.DataFrame:
774+
775+
def filter_compare(
776+
self, col: str, op: str, thr: float, df: pl.DataFrame = None, **kwargs
777+
) -> pl.DataFrame:
729778
df = self.annotations if df is None else df
730779
if col not in df:
731780
raise RuleError(f"Missing column '{col}' for filter_compare()")

0 commit comments

Comments
 (0)