Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions dataframely/columns/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

from ._base import Check, Column
from ._registry import column_from_dict, register
from .list import _list_primary_key_check
from .list import List, _list_primary_key_check
from .struct import Struct

if sys.version_info >= (3, 11):
from typing import Self
Expand Down Expand Up @@ -117,9 +118,46 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
n_elements = n * math.prod(self.shape)
all_elements = self.inner.sample(generator, n_elements)

# For nested types (List, Array, Struct), we can't use reshape() directly
# because the inner type is not a scalar. Instead, we need to construct
# the nested structure manually.
if isinstance(self.inner, (List, Array, Struct)):
# Convert to a list and then group into arrays of the specified shape
all_elements_list = all_elements.to_list()

def build_nested_structure(elements: list, shape: tuple[int, ...]) -> list:
    """Group a flat list of elements into nested lists following ``shape``.

    The elements are consumed in row-major order: the first axis is cut into
    ``shape[0]`` consecutive chunks of ``math.prod(shape[1:])`` elements each,
    and every chunk is grouped recursively along the remaining axes.

    Args:
        elements: Flat list of the values to arrange. It is expected to hold
            ``math.prod(shape)`` items; this is not validated here.
        shape: Target shape with at least one dimension.

    Returns:
        ``elements`` itself for a one-dimensional shape, otherwise a list of
        ``shape[0]`` nested lists.
    """
    if len(shape) == 1:
        # Base case: a 1D array is just the flat list of its elements.
        return elements

    # Everything below the first axis has the same shape for every row, so
    # compute it (and the number of elements it spans) only once.
    inner_shape = shape[1:]
    row_size = math.prod(inner_shape)
    return [
        build_nested_structure(
            elements[row * row_size : (row + 1) * row_size], inner_shape
        )
        for row in range(shape[0])
    ]

# Build n arrays, each with the specified shape
elements_per_array = math.prod(self.shape)
nested_arrays = []
for i in range(n):
start = i * elements_per_array
end = start + elements_per_array
array_elements = all_elements_list[start:end]
nested_arrays.append(build_nested_structure(array_elements, self.shape))

result = pl.Series(nested_arrays, dtype=self.dtype)
else:
# For scalar types, use the original reshape approach
result = all_elements.reshape((n, *self.shape))

# Finally, apply a null mask
return generator._apply_null_mask(
all_elements.reshape((n, *self.shape)),
result,
null_probability=self._null_probability,
)

Expand Down
30 changes: 30 additions & 0 deletions tests/columns/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,36 @@ def test_sample_array(generator: Generator) -> None:
assert set(samples.arr.len()) == {2, None}


@pytest.mark.parametrize("arr_size,n_samples", [(1, 1), (2, 1), (3, 5), (1, 10)])
def test_sample_array_list(
    arr_size: int, n_samples: int, generator: Generator
) -> None:
    """Regression test: sample an array column whose elements are lists.

    Builds ``Array(List(Bool), arr_size)``, draws ``n_samples`` values through
    ``sample_and_validate`` and checks that exactly that many rows come back.
    """
    list_array = dy.Array(dy.List(dy.Bool()), arr_size)
    drawn = sample_and_validate(list_array, generator, n=n_samples)
    assert len(drawn) == n_samples


def test_sample_array_of_array(generator: Generator) -> None:
    """Sample an array column whose elements are themselves arrays.

    Draws 10 values of ``Array(Array(Bool, 2), 3)`` through
    ``sample_and_validate`` and checks that 10 rows come back.
    """
    expected_rows = 10
    array_of_arrays = dy.Array(dy.Array(dy.Bool(), 2), 3)
    drawn = sample_and_validate(array_of_arrays, generator, n=expected_rows)
    assert len(drawn) == expected_rows


def test_sample_array_of_struct(generator: Generator) -> None:
    """Sample an array column whose elements are structs.

    Draws 10 values of a width-2 array of ``{"x": Bool, "y": Integer}`` structs
    through ``sample_and_validate`` and checks that 10 rows come back.
    """
    expected_rows = 10
    element = dy.Struct({"x": dy.Bool(), "y": dy.Integer()})
    array_of_structs = dy.Array(element, 2)
    drawn = sample_and_validate(array_of_structs, generator, n=expected_rows)
    assert len(drawn) == expected_rows


def test_sample_struct(generator: Generator) -> None:
column = dy.Struct(
{"a": dy.String(regex="[abc]"), "b": dy.String(regex="[a-z]xx")}, nullable=True
Expand Down
Loading