Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions dataframely/collection/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ def my_filter(self) -> pl.DataFrame:
#: the collection's common primary key. Two members that share common column names
#: may not both be inlined for sampling.
inline_for_sampling: bool = False
#: Whether individual row failures in this member should be propagated to the
#: collection, i.e., cause the common primary key of the failures to be filtered
#: out from the entire collection. This setting is ignored if `ignored_in_filters`
#: is `True`.
propagate_row_failures: bool = False


# --------------------------------------- UTILS -------------------------------------- #
Expand Down Expand Up @@ -250,6 +255,7 @@ def _derive_member_info(
is_optional=True,
ignored_in_filters=collection_member.ignored_in_filters,
inline_for_sampling=collection_member.inline_for_sampling,
propagate_row_failures=collection_member.propagate_row_failures,
)
elif issubclass(origin, TypedLazyFrame):
# Happy path: required member
Expand All @@ -258,6 +264,7 @@ def _derive_member_info(
is_optional=False,
ignored_in_filters=collection_member.ignored_in_filters,
inline_for_sampling=collection_member.inline_for_sampling,
propagate_row_failures=collection_member.propagate_row_failures,
)
else:
# Some other unknown annotation
Expand Down Expand Up @@ -333,6 +340,16 @@ def non_ignored_members(cls) -> set[str]:
if not member.ignored_in_filters
}

@classmethod
def _failure_propagating_members(cls) -> set[str]:
"""The names of all members of the collection that propagate individual row
failures to the collection."""
return {
name
for name, member in cls.members().items()
if member.propagate_row_failures
}

@classmethod
def common_primary_key(cls) -> list[str]:
"""The primary keys shared by non ignored members of the collection."""
Expand Down
30 changes: 25 additions & 5 deletions dataframely/collection/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,8 @@ class HospitalInvoiceData(dy.Collection):
# Once we've done that, we can apply the filters on this collection. To this end,
# we iterate over all filters and store the filter results.
filters = cls._filters()
if len(filters) > 0:
failure_propagating_members = cls._failure_propagating_members()
if len(filters) > 0 or len(failure_propagating_members) > 0:
result_cls = cls._init(results)
primary_key = cls.common_primary_key()

Expand All @@ -586,6 +587,17 @@ class HospitalInvoiceData(dy.Collection):
.lazy()
)

drop: dict[str, pl.LazyFrame] = {}
for failure_propagating_member in failure_propagating_members:
annotation_column = f"{failure_propagating_member}|failure_propagation"
drop[annotation_column] = (
failures[failure_propagating_member]
._lf.select(primary_key)
.unique()
.pipe(collect_if, eager)
.lazy()
)

# Now we can iterate over the results and left-join onto each individual
# filter to obtain independent boolean indicators of whether to keep the row
for member_name, filtered in results.items():
Expand All @@ -601,15 +613,23 @@ class HospitalInvoiceData(dy.Collection):
how="left",
maintain_order="left",
).with_columns(pl.col(name).fill_null(False))
for name, filter_drop in drop.items():
lf_with_eval = lf_with_eval.join(
filter_drop.with_columns(pl.lit(False).alias(name)),
on=primary_key,
how="left",
maintain_order="left",
).with_columns(pl.col(name).fill_null(True))

lf_with_eval = lf_with_eval.pipe(collect_if, eager).lazy()

# Filtering `lf_with_eval` by the rows for which all joins
# "succeeded", we can identify the rows that pass all the filters. We
# keep these rows for the result.
all_filter_columns = list(keep.keys()) + list(drop.keys())
results[member_name] = lf_with_eval.filter(
pl.all_horizontal(keep.keys())
).drop(keep.keys())
pl.all_horizontal(all_filter_columns)
).drop(all_filter_columns)

# Filtering `lf_with_eval` with the inverse condition, we find all
# the problematic rows. We can build a single failure info object by
Expand All @@ -623,7 +643,7 @@ class HospitalInvoiceData(dy.Collection):
#
failure = failures[member_name]
filtered_failure = lf_with_eval.filter(
~pl.all_horizontal(keep.keys())
~pl.all_horizontal(all_filter_columns)
).lazy()

# If we cast previously, `failure` and `filtered_failure` have different
Expand Down Expand Up @@ -660,7 +680,7 @@ class HospitalInvoiceData(dy.Collection):

failures[member_name] = FailureInfo(
lf=failure_lf,
rule_columns=failure._rule_columns + list(keep.keys()),
rule_columns=failure._rule_columns + all_filter_columns,
schema=failure.schema,
)

Expand Down
126 changes: 126 additions & 0 deletions tests/collection/test_propagate_row_failures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright (c) QuantCo 2025-2025
# SPDX-License-Identifier: BSD-3-Clause

from typing import Annotated

import polars as pl
import pytest

import dataframely as dy


class MyTestSchema(dy.Schema):
shared = dy.UInt8(primary_key=True)
u = dy.UInt8()
v = dy.UInt8()


class MyTestSchema2(dy.Schema):
shared = dy.UInt8(primary_key=True)
id = dy.UInt8(primary_key=True)
x = dy.UInt8()
y = dy.UInt8()


class MyTestCollection(dy.Collection):
a: dy.LazyFrame[MyTestSchema]
b: Annotated[
dy.LazyFrame[MyTestSchema2],
dy.CollectionMember(propagate_row_failures=True),
]

@dy.filter()
def x_greater_u(self) -> pl.LazyFrame:
return (
self.a.join(self.b, on="shared")
.filter(pl.col("x") > pl.col("u"))
.unique("shared")
)


@pytest.fixture()
def valid_a() -> pl.LazyFrame:
return pl.LazyFrame(
[
{"shared": 1, "u": 10, "v": 5},
{"shared": 2, "u": 20, "v": 15},
{"shared": 3, "u": 30, "v": 25},
]
)


@pytest.fixture()
def valid_b() -> pl.LazyFrame:
return pl.LazyFrame(
[
{"shared": 1, "id": 1, "x": 15, "y": 50},
{"shared": 2, "id": 1, "x": 25, "y": 60},
{"shared": 2, "id": 2, "x": 25, "y": 70},
{"shared": 3, "id": 1, "x": 5, "y": 70},
]
)


@pytest.fixture()
def invalid_b() -> pl.LazyFrame:
return pl.LazyFrame(
[
{"shared": 1, "id": 1, "x": 15, "y": 50},
{"shared": 2, "id": 1, "x": 25, "y": 60},
{
"shared": 2,
"id": 2,
"x": 25,
"y": None,
}, # invalid row, should be propagated
{"shared": 3, "id": 1, "x": 10, "y": 70}, # filtered out due to the filter
]
)


def test_collection_propagate_row_failures_meta() -> None:
assert MyTestCollection._failure_propagating_members() == {"b"}


def test_collection_propagate_row_failure_no_propagation(
valid_a: pl.LazyFrame,
valid_b: pl.LazyFrame,
) -> None:
success, failures = MyTestCollection.filter(
{
"a": valid_a,
"b": valid_b,
},
cast=True,
)
# Assert that only id 3 is filtered out (caused by the filter)
assert success.a.select("shared").collect().to_series().to_list() == [1, 2]
assert success.b.select("shared").collect().to_series().to_list() == [1, 2, 2]
# Assert that nothing is filtered out due to propagation
for member_name in MyTestCollection.members().keys():
assert failures[member_name].counts()["x_greater_u"] == 1
assert len(failures[member_name].counts()) == 1


def test_collection_propagate_row_failure_with_propagation(
valid_a: pl.LazyFrame,
invalid_b: pl.LazyFrame,
) -> None:
success, failures = MyTestCollection.filter(
{
"a": valid_a,
"b": invalid_b,
},
cast=True,
)
# Assert that id 2 is also filtered out
assert success.a.select("shared").collect().to_series().to_list() == [1]
assert success.b.select("shared").collect().to_series().to_list() == [1]
# Assert that nothing is filtered out due to propagation
for member_name in MyTestCollection.members().keys():
assert failures[member_name].counts()["x_greater_u"] == 1
assert failures[member_name].counts()["b|failure_propagation"] == 1
assert failures[member_name].counts()["x_greater_u"] == 1
assert failures[member_name].counts()["b|failure_propagation"] == 1
assert len(failures["a"].counts()) == 2
assert len(failures["b"].counts()) == 3
Loading