From abbfff36b601ab2adaa09f4d4b342889e043e580 Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Tue, 23 Dec 2025 11:07:48 +0100 Subject: [PATCH] feat: Allow missing keys in sampling overrides --- dataframely/schema.py | 77 +++++++++++++++++++++++++-------- tests/collection/test_sample.py | 13 ++++++ tests/schema/test_sample.py | 43 ++++++++++-------- 3 files changed, 97 insertions(+), 36 deletions(-) diff --git a/dataframely/schema.py b/dataframely/schema.py index f19981c2..a00582ed 100644 --- a/dataframely/schema.py +++ b/dataframely/schema.py @@ -7,7 +7,7 @@ import sys import warnings from abc import ABC -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Mapping, Sequence from json import JSONDecodeError from pathlib import Path from typing import IO, Any, Literal, overload @@ -177,7 +177,7 @@ def sample( num_rows: int | None = None, *, overrides: ( - Mapping[str, Iterable[Any]] | Sequence[Mapping[str, Any]] | None + Mapping[str, Sequence[Any] | Any] | Sequence[Mapping[str, Any]] | None ) = None, generator: Generator | None = None, ) -> DataFrame[Self]: @@ -234,26 +234,22 @@ def sample( g = generator or Generator() # Precondition: valid overrides. We put them into a data frame to remember which - # values have been used in the algorithm below. - if overrides: + # values have been used in the algorithm below. When the user passes a sequence + # of mappings, they do not require to have the same keys. Hence, we have to + # remember that the data frame has "holes". + missing_override_indices: dict[str, pl.Series] = {} + if overrides is not None: override_keys = ( - set(overrides) if isinstance(overrides, Mapping) else set(overrides[0]) + set(overrides) + if isinstance(overrides, Mapping) + else ( + set.union(*[set(o.keys()) for o in overrides]) + if len(overrides) > 0 + else set() + ) ) - if isinstance(overrides, Sequence): - # Check that overrides entries are consistent. Not necessary for mapping - # overrides as polars checks the series lists upon data frame construction. - inconsistent_override_keys = [ - index - for index, current in enumerate(overrides) - if set(current) != override_keys - ] - if len(inconsistent_override_keys) > 0: - raise ValueError( - "The `overrides` entries at the following indices " - "do not provide the same keys as the first entry: " - f"{inconsistent_override_keys}." - ) + # Check that all override keys refer to valid columns column_names = set(cls.column_names()) if not override_keys.issubset(column_names): raise ValueError( @@ -261,6 +257,19 @@ def sample( "which are not in the schema." ) + # Remember the "holes" of the inputs if overrides are provided as a sequence + if isinstance(overrides, Sequence): + for key in override_keys: + indices = [ + i for i, override in enumerate(overrides) if key not in override + ] + if len(indices) > 0: + missing_override_indices[key] = pl.Series(indices) + + # NOTE: Even if the user-provided overrides have "holes", we can still just + # create the data frame. Polars will fill the missing values with nulls, we + # will replace them later during sampling. If we were to already replace + # them here, we would not be able to resample these values. values = pl.DataFrame( overrides, schema={ @@ -323,6 +332,7 @@ def sample( used_values=values.slice(0, 0), remaining_values=values, override_expressions=override_expressions, + missing_value_indices=missing_override_indices, ) sampling_rounds = 1 @@ -360,6 +370,7 @@ def sample( used_values=used_values, remaining_values=remaining_values, override_expressions=override_expressions, + missing_value_indices=missing_override_indices, ) sampling_rounds += 1 @@ -388,6 +399,7 @@ def _sample_filter( used_values: pl.DataFrame, remaining_values: pl.DataFrame, override_expressions: list[pl.Expr], + missing_value_indices: dict[str, pl.Series], ) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: """Private method to sample a data frame with the schema including subsequent filtering. @@ -406,6 +418,33 @@ def _sample_filter( } ) + # If we have missing value indices, we need to sample new values for the + # indices that overlap with indices in the remaining values and replace them + # in the sampled data frame. + for name, indices in missing_value_indices.items(): + remapped_indices = ( + indices.to_frame("idx") + .join( + remaining_values.select("__row_index__").with_row_index( + "__row_index_loop__" + ), + left_on="idx", + right_on="__row_index__", + ) + .select("__row_index_loop__") + .to_series() + ) + if (num := len(remapped_indices)) > 0: + sampled_values = cls.columns()[name].sample(generator, num) + sampled = sampled.with_columns( + sampled[name] + # NOTE: We need to sort here as `scatter` requires sorted indices. + # Due to concatenations in `remaining_values`, the indices can go + # out of order. + .scatter(remapped_indices.sort(), sampled_values) + .alias(name) + ) + combined_dataframe = pl.concat([previous_result, sampled]) # Pre-process columns before filtering. combined_dataframe = combined_dataframe.with_columns(override_expressions) diff --git a/tests/collection/test_sample.py b/tests/collection/test_sample.py index be732ae6..6c8ca771 100644 --- a/tests/collection/test_sample.py +++ b/tests/collection/test_sample.py @@ -189,3 +189,16 @@ def test_duplicate_column_inlined_for_sampling() -> None: ], }, ) + + +def test_sample_override_sequence_with_missing_keys() -> None: + collection = MyCollection.sample( + overrides=[{"first": {"a": 1}, "second": [{"c": 2}, {}, {"b": 5}]}] + ) + assert collection.first.collect().height == 1 + assert collection.second is not None + + second = collection.second.collect() + assert second.height == 3 + assert second["c"].item(0) == 2 + assert second["b"].item(2) == 5 diff --git a/tests/schema/test_sample.py b/tests/schema/test_sample.py index 696130b1..5ba2a9bb 100644 --- a/tests/schema/test_sample.py +++ b/tests/schema/test_sample.py @@ -9,6 +9,7 @@ import dataframely as dy from dataframely.random import Generator +from dataframely.testing import create_schema class MySimpleSchema(dy.Schema): @@ -214,23 +215,6 @@ def test_sample_raises_superfluous_column_override() -> None: SchemaWithIrrelevantColumnPreProcessing.sample(100) -def test_sample_with_inconsistent_overrides_keys_raises() -> None: - with pytest.raises( - ValueError, - match=( - r"The `overrides` entries at the following indices do not provide " - r"the same keys as the first entry: \[1, 2\]." - ), - ): - MySimpleSchema.sample( - overrides=[ - {"a": 1, "b": "one"}, - {"a": 2}, - {"b": 2}, - ] - ) - - @pytest.mark.parametrize( "overrides,failed_column,failed_rule,failed_rows", [ @@ -252,3 +236,28 @@ def test_sample_invalid_override_values_raises( ): with dy.Config(max_sampling_iterations=100): # speed up the test MyAdvancedSchema.sample(overrides=overrides) + + +def test_sample_empty_override_sequence() -> None: + df = MySimpleSchema.sample(overrides=[]) + assert len(df) == 0 + + +def test_sample_override_sequence_with_missing_keys() -> None: + df = MySimpleSchema.sample(overrides=[{"a": 1}, {"b": "two"}]) + assert df.item(0, 0) == 1 + assert df.item(1, 1) == "two" + assert len(df) == 2 + + +def test_sample_override_sequence_with_missing_keys_and_resampling() -> None: + schema = create_schema("test", {"a": dy.UInt8(primary_key=True), "b": dy.String()}) + generator = Generator(seed=42) + df = schema.sample( + overrides=[{"a": i} for i in range(250)] + [{"b": "two"}, {"b": "three"}], + generator=generator, + ) + assert len(df) == 252 + assert all(df.item(i, 0) == i for i in range(250)) + assert df.item(250, 1) == "two" + assert df.item(251, 1) == "three"