Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 58 additions & 19 deletions dataframely/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
import warnings
from abc import ABC
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Mapping, Sequence
from json import JSONDecodeError
from pathlib import Path
from typing import IO, Any, Literal, overload
Expand Down Expand Up @@ -177,7 +177,7 @@ def sample(
num_rows: int | None = None,
*,
overrides: (
Mapping[str, Iterable[Any]] | Sequence[Mapping[str, Any]] | None
Mapping[str, Sequence[Any] | Any] | Sequence[Mapping[str, Any]] | None
) = None,
generator: Generator | None = None,
) -> DataFrame[Self]:
Expand Down Expand Up @@ -234,33 +234,42 @@ def sample(
g = generator or Generator()

# Precondition: valid overrides. We put them into a data frame to remember which
# values have been used in the algorithm below.
if overrides:
# values have been used in the algorithm below. When the user passes a sequence
# of mappings, the mappings are not required to share the same keys. Hence, we
# have to remember that the data frame has "holes".
missing_override_indices: dict[str, pl.Series] = {}
if overrides is not None:
override_keys = (
set(overrides) if isinstance(overrides, Mapping) else set(overrides[0])
set(overrides)
if isinstance(overrides, Mapping)
else (
set.union(*[set(o.keys()) for o in overrides])
if len(overrides) > 0
else set()
)
)
if isinstance(overrides, Sequence):
# Check that overrides entries are consistent. Not necessary for mapping
# overrides as polars checks the series lists upon data frame construction.
inconsistent_override_keys = [
index
for index, current in enumerate(overrides)
if set(current) != override_keys
]
if len(inconsistent_override_keys) > 0:
raise ValueError(
"The `overrides` entries at the following indices "
"do not provide the same keys as the first entry: "
f"{inconsistent_override_keys}."
)

# Check that all override keys refer to valid columns
column_names = set(cls.column_names())
if not override_keys.issubset(column_names):
raise ValueError(
f"Values are provided for columns {override_keys - column_names} "
"which are not in the schema."
)

# Remember the "holes" of the inputs if overrides are provided as a sequence
if isinstance(overrides, Sequence):
for key in override_keys:
indices = [
i for i, override in enumerate(overrides) if key not in override
]
if len(indices) > 0:
missing_override_indices[key] = pl.Series(indices)

# NOTE: Even if the user-provided overrides have "holes", we can still just
# create the data frame. Polars will fill the missing values with nulls, which
# we replace later during sampling. If we were to replace them here already,
# we would not be able to resample these values.
values = pl.DataFrame(
overrides,
schema={
Expand Down Expand Up @@ -323,6 +332,7 @@ def sample(
used_values=values.slice(0, 0),
remaining_values=values,
override_expressions=override_expressions,
missing_value_indices=missing_override_indices,
)

sampling_rounds = 1
Expand Down Expand Up @@ -360,6 +370,7 @@ def sample(
used_values=used_values,
remaining_values=remaining_values,
override_expressions=override_expressions,
missing_value_indices=missing_override_indices,
)
sampling_rounds += 1

Expand Down Expand Up @@ -388,6 +399,7 @@ def _sample_filter(
used_values: pl.DataFrame,
remaining_values: pl.DataFrame,
override_expressions: list[pl.Expr],
missing_value_indices: dict[str, pl.Series],
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
"""Private method to sample a data frame with the schema including subsequent
filtering.
Expand All @@ -406,6 +418,33 @@ def _sample_filter(
}
)

# If we have missing value indices, we need to sample new values for the
# indices that overlap with indices in the remaining values and replace them
# in the sampled data frame.
for name, indices in missing_value_indices.items():
remapped_indices = (
indices.to_frame("idx")
.join(
remaining_values.select("__row_index__").with_row_index(
"__row_index_loop__"
),
left_on="idx",
right_on="__row_index__",
)
.select("__row_index_loop__")
.to_series()
)
if (num := len(remapped_indices)) > 0:
sampled_values = cls.columns()[name].sample(generator, num)
sampled = sampled.with_columns(
sampled[name]
# NOTE: We need to sort here as `scatter` requires sorted indices.
# Due to concatenations in `remaining_values`, the indices can go
# out of order.
.scatter(remapped_indices.sort(), sampled_values)
.alias(name)
)

combined_dataframe = pl.concat([previous_result, sampled])
# Pre-process columns before filtering.
combined_dataframe = combined_dataframe.with_columns(override_expressions)
Expand Down
13 changes: 13 additions & 0 deletions tests/collection/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,16 @@ def test_duplicate_column_inlined_for_sampling() -> None:
],
},
)


def test_sample_override_sequence_with_missing_keys() -> None:
    """Sample a collection whose `second` member overrides omit some keys.

    The three `second` entries provide `c`, nothing, and `b` respectively. The
    test checks the row count of both members and that the explicitly provided
    values end up in the rows they were given for.
    """
    second_overrides = [{"c": 2}, {}, {"b": 5}]
    collection = MyCollection.sample(
        overrides=[{"first": {"a": 1}, "second": second_overrides}]
    )

    first_frame = collection.first.collect()
    assert first_frame.height == 1
    assert collection.second is not None

    second_frame = collection.second.collect()
    assert second_frame.height == 3
    # Explicitly provided values must show up at the positions they were given.
    assert second_frame["c"].item(0) == 2
    assert second_frame["b"].item(2) == 5
43 changes: 26 additions & 17 deletions tests/schema/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import dataframely as dy
from dataframely.random import Generator
from dataframely.testing import create_schema


class MySimpleSchema(dy.Schema):
Expand Down Expand Up @@ -214,23 +215,6 @@ def test_sample_raises_superfluous_column_override() -> None:
SchemaWithIrrelevantColumnPreProcessing.sample(100)


def test_sample_with_inconsistent_overrides_keys_raises() -> None:
    """Expect a `ValueError` listing entries 1 and 2 when override keys differ."""
    inconsistent_overrides = [
        {"a": 1, "b": "one"},
        {"a": 2},
        {"b": 2},
    ]
    expected_message = (
        r"The `overrides` entries at the following indices do not provide "
        r"the same keys as the first entry: \[1, 2\]."
    )
    with pytest.raises(ValueError, match=expected_message):
        MySimpleSchema.sample(overrides=inconsistent_overrides)


@pytest.mark.parametrize(
"overrides,failed_column,failed_rule,failed_rows",
[
Expand All @@ -252,3 +236,28 @@ def test_sample_invalid_override_values_raises(
):
with dy.Config(max_sampling_iterations=100): # speed up the test
MyAdvancedSchema.sample(overrides=overrides)


def test_sample_empty_override_sequence() -> None:
    """An empty sequence of overrides yields a data frame without rows."""
    sampled = MySimpleSchema.sample(overrides=[])
    row_count = len(sampled)
    assert row_count == 0


def test_sample_override_sequence_with_missing_keys() -> None:
    """Override entries may provide different keys.

    The first entry only sets `a` and the second only sets `b`; the provided
    values are expected at cell (row 0, column 0) and (row 1, column 1).
    """
    partial_overrides = [{"a": 1}, {"b": "two"}]
    sampled = MySimpleSchema.sample(overrides=partial_overrides)
    assert sampled.item(0, 0) == 1
    assert sampled.item(1, 1) == "two"
    assert len(sampled) == 2


def test_sample_override_sequence_with_missing_keys_and_resampling() -> None:
    """Overrides with missing keys keep their values in every checked cell.

    250 entries fix the primary key `a` and two further entries fix only `b`, so
    `a` has to be sampled for the last two rows.
    NOTE(review): presumably this forces resampling because few `UInt8` values
    remain unused for the primary key; confirm against the sampling logic.
    """
    schema = create_schema("test", {"a": dy.UInt8(primary_key=True), "b": dy.String()})
    key_overrides = [{"a": i} for i in range(250)]
    text_overrides = [{"b": "two"}, {"b": "three"}]
    sampled = schema.sample(
        overrides=key_overrides + text_overrides,
        generator=Generator(seed=42),
    )
    assert len(sampled) == 252
    # The first 250 rows must carry exactly the overridden key values.
    assert all(sampled.item(row, 0) == row for row in range(250))
    assert sampled.item(250, 1) == "two"
    assert sampled.item(251, 1) == "three"
Loading