diff --git a/dataframely/columns/list.py b/dataframely/columns/list.py
index 5c22d24..6c1c66a 100644
--- a/dataframely/columns/list.py
+++ b/dataframely/columns/list.py
@@ -131,9 +131,14 @@ def pyarrow_dtype(self) -> pa.DataType:
     def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
         # First, sample the number of items per list element
         # NOTE: We default to 32 for the upper bound as we need some kind of reasonable
-        # upper bound if none is set.
+        # upper bound if none is set. If min_length is greater than 32, we use
+        # min_length as the default upper bound instead.
+        min_len = self.min_length or 0
+        # NOTE: Compare against None explicitly: `max_length=0` is falsy, so an `or`
+        # fallback would silently replace it with the default upper bound.
+        max_len = max(32, min_len) if self.max_length is None else self.max_length
         element_lengths = generator.sample_int(
-            n, min=self.min_length or 0, max=(self.max_length or 32) + 1
+            n, min=min_len, max=max_len + 1
         )
 
         # Then, we can sample the inner elements in a flat series
diff --git a/tests/column_types/test_list.py b/tests/column_types/test_list.py
index e860b85..44bb858 100644
--- a/tests/column_types/test_list.py
+++ b/tests/column_types/test_list.py
@@ -1,6 +1,8 @@
 # Copyright (c) QuantCo 2025-2025
 # SPDX-License-Identifier: BSD-3-Clause
 
+from typing import cast
+
 import polars as pl
 import pytest
 
@@ -179,3 +181,14 @@ def test_inner_primary_key_struct(
     _, failure = schema.filter(df)
     assert failure.counts() == {"a|primary_key": failure_count}
     assert validation_mask(df, failure).to_list() == mask
+
+
+@pytest.mark.parametrize("min_length", [0, 10, 33, 100])
+def test_list_sampling_with_min_length(min_length: int) -> None:
+    """Test that sampling honors min_length, also above the default max of 32."""
+    schema = create_schema("test", {"a": dy.List(dy.Int64(), min_length=min_length)})
+    df = schema.sample(num_rows=10)
+    assert len(df) == 10
+    # Verify all lists have at least min_length elements
+    min_list_len = cast(int, df["a"].list.len().min())
+    assert min_list_len >= min_length