diff --git a/dataframely/collection/_base.py b/dataframely/collection/_base.py index 0ff42d9..f1cd7af 100644 --- a/dataframely/collection/_base.py +++ b/dataframely/collection/_base.py @@ -65,6 +65,11 @@ def my_filter(self) -> pl.DataFrame: #: the collection's common primary key. Two members that share common column names #: may not both be inlined for sampling. inline_for_sampling: bool = False + #: Whether individual row failures in this member should be propagated to the + #: collection, i.e., cause the common primary key of the failures to be filtered + #: out from the entire collection. This setting is ignored if `ignored_in_filters` + #: is `True`. + propagate_row_failures: bool = False # --------------------------------------- UTILS -------------------------------------- # @@ -250,6 +255,7 @@ def _derive_member_info( is_optional=True, ignored_in_filters=collection_member.ignored_in_filters, inline_for_sampling=collection_member.inline_for_sampling, + propagate_row_failures=collection_member.propagate_row_failures, ) elif issubclass(origin, TypedLazyFrame): # Happy path: required member @@ -258,6 +264,7 @@ def _derive_member_info( is_optional=False, ignored_in_filters=collection_member.ignored_in_filters, inline_for_sampling=collection_member.inline_for_sampling, + propagate_row_failures=collection_member.propagate_row_failures, ) else: # Some other unknown annotation @@ -333,6 +340,16 @@ def non_ignored_members(cls) -> set[str]: if not member.ignored_in_filters } + @classmethod + def _failure_propagating_members(cls) -> set[str]: + """The names of all members of the collection that propagate individual row + failures to the collection.""" + return { + name + for name, member in cls.members().items() + if member.propagate_row_failures + } + @classmethod def common_primary_key(cls) -> list[str]: """The primary keys shared by non ignored members of the collection.""" diff --git a/dataframely/collection/collection.py b/dataframely/collection/collection.py index ee9c822..61d7e58 100644 --- a/dataframely/collection/collection.py +++ b/dataframely/collection/collection.py @@ -573,7 +573,8 @@ class HospitalInvoiceData(dy.Collection): # Once we've done that, we can apply the filters on this collection. To this end, # we iterate over all filters and store the filter results. filters = cls._filters() - if len(filters) > 0: + failure_propagating_members = cls._failure_propagating_members() + if len(filters) > 0 or len(failure_propagating_members) > 0: result_cls = cls._init(results) primary_key = cls.common_primary_key() @@ -586,6 +587,17 @@ class HospitalInvoiceData(dy.Collection): .lazy() ) + drop: dict[str, pl.LazyFrame] = {} + for failure_propagating_member in failure_propagating_members: + annotation_column = f"{failure_propagating_member}|failure_propagation" + drop[annotation_column] = ( + failures[failure_propagating_member] + ._lf.select(primary_key) + .unique() + .pipe(collect_if, eager) + .lazy() + ) + # Now we can iterate over the results and left-join onto each individual # filter to obtain independent boolean indicators of whether to keep the row for member_name, filtered in results.items(): @@ -601,15 +613,23 @@ class HospitalInvoiceData(dy.Collection): how="left", maintain_order="left", ).with_columns(pl.col(name).fill_null(False)) + for name, filter_drop in drop.items(): + lf_with_eval = lf_with_eval.join( + filter_drop.with_columns(pl.lit(False).alias(name)), + on=primary_key, + how="left", + maintain_order="left", + ).with_columns(pl.col(name).fill_null(True)) lf_with_eval = lf_with_eval.pipe(collect_if, eager).lazy() # Filtering `lf_with_eval` by the rows for which all joins # "succeeded", we can identify the rows that pass all the filters. We # keep these rows for the result. + all_filter_columns = list(keep.keys()) + list(drop.keys()) results[member_name] = lf_with_eval.filter( - pl.all_horizontal(keep.keys()) - ).drop(keep.keys()) + pl.all_horizontal(all_filter_columns) + ).drop(all_filter_columns) # Filtering `lf_with_eval` with the inverse condition, we find all # the problematic rows. We can build a single failure info object by @@ -623,7 +643,7 @@ class HospitalInvoiceData(dy.Collection): # failure = failures[member_name] filtered_failure = lf_with_eval.filter( - ~pl.all_horizontal(keep.keys()) + ~pl.all_horizontal(all_filter_columns) ).lazy() # If we cast previously, `failure` and `filtered_failure` have different @@ -660,7 +680,7 @@ class HospitalInvoiceData(dy.Collection): failures[member_name] = FailureInfo( lf=failure_lf, - rule_columns=failure._rule_columns + list(keep.keys()), + rule_columns=failure._rule_columns + all_filter_columns, schema=failure.schema, ) diff --git a/tests/collection/test_propagate_row_failures.py b/tests/collection/test_propagate_row_failures.py new file mode 100644 index 0000000..dafe80b --- /dev/null +++ b/tests/collection/test_propagate_row_failures.py @@ -0,0 +1,126 @@ +# Copyright (c) QuantCo 2025-2025 +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Annotated + +import polars as pl +import pytest + +import dataframely as dy + + +class MyTestSchema(dy.Schema): + shared = dy.UInt8(primary_key=True) + u = dy.UInt8() + v = dy.UInt8() + + +class MyTestSchema2(dy.Schema): + shared = dy.UInt8(primary_key=True) + id = dy.UInt8(primary_key=True) + x = dy.UInt8() + y = dy.UInt8() + + +class MyTestCollection(dy.Collection): + a: dy.LazyFrame[MyTestSchema] + b: Annotated[ + dy.LazyFrame[MyTestSchema2], + dy.CollectionMember(propagate_row_failures=True), + ] + + @dy.filter() + def x_greater_u(self) -> pl.LazyFrame: + return ( + self.a.join(self.b, on="shared") + .filter(pl.col("x") > pl.col("u")) + .unique("shared") + ) + + +@pytest.fixture() +def valid_a() -> pl.LazyFrame: + return pl.LazyFrame( + [ + {"shared": 1, "u": 10, "v": 5}, + {"shared": 2, "u": 20, "v": 15}, + {"shared": 3, "u": 30, "v": 25}, + ] + ) + + +@pytest.fixture() +def valid_b() -> pl.LazyFrame: + return pl.LazyFrame( + [ + {"shared": 1, "id": 1, "x": 15, "y": 50}, + {"shared": 2, "id": 1, "x": 25, "y": 60}, + {"shared": 2, "id": 2, "x": 25, "y": 70}, + {"shared": 3, "id": 1, "x": 5, "y": 70}, + ] + ) + + +@pytest.fixture() +def invalid_b() -> pl.LazyFrame: + return pl.LazyFrame( + [ + {"shared": 1, "id": 1, "x": 15, "y": 50}, + {"shared": 2, "id": 1, "x": 25, "y": 60}, + { + "shared": 2, + "id": 2, + "x": 25, + "y": None, + }, # invalid row, should be propagated + {"shared": 3, "id": 1, "x": 10, "y": 70}, # filtered out due to the filter + ] + ) + + +def test_collection_propagate_row_failures_meta() -> None: + assert MyTestCollection._failure_propagating_members() == {"b"} + + +def test_collection_propagate_row_failure_no_propagation( + valid_a: pl.LazyFrame, + valid_b: pl.LazyFrame, +) -> None: + success, failures = MyTestCollection.filter( + { + "a": valid_a, + "b": valid_b, + }, + cast=True, + ) + # Assert that only id 3 is filtered out (caused by the filter) + assert success.a.select("shared").collect().to_series().to_list() == [1, 2] + assert success.b.select("shared").collect().to_series().to_list() == [1, 2, 2] + # Assert that nothing is filtered out due to propagation + for member_name in MyTestCollection.members().keys(): + assert failures[member_name].counts()["x_greater_u"] == 1 + assert len(failures[member_name].counts()) == 1 + + +def test_collection_propagate_row_failure_with_propagation( + valid_a: pl.LazyFrame, + invalid_b: pl.LazyFrame, +) -> None: + success, failures = MyTestCollection.filter( + { + "a": valid_a, + "b": invalid_b, + }, + cast=True, + ) + # Assert that id 2 is also filtered out + assert success.a.select("shared").collect().to_series().to_list() == [1] + assert success.b.select("shared").collect().to_series().to_list() == [1] + # Assert that nothing is filtered out due to propagation + for member_name in MyTestCollection.members().keys(): + assert failures[member_name].counts()["x_greater_u"] == 1 + assert failures[member_name].counts()["b|failure_propagation"] == 1 + assert failures[member_name].counts()["x_greater_u"] == 1 + assert failures[member_name].counts()["b|failure_propagation"] == 1 + assert len(failures["a"].counts()) == 2 + assert len(failures["b"].counts()) == 3