diff --git a/dataframely/columns/string.py b/dataframely/columns/string.py index fadb4f2..7234d35 100644 --- a/dataframely/columns/string.py +++ b/dataframely/columns/string.py @@ -14,6 +14,8 @@ from ._base import Check, Column from ._registry import register +DEFAULT_SAMPLING_REGEX = r"[0-9a-zA-Z]" + @register class String(Column): @@ -126,9 +128,9 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: str_max = f"{self.max_length}" if self.max_length is not None else "" # NOTE: We generate single-byte unicode characters here as validation uses # `len_bytes()`. Potentially we need to be more accurate at some point... - regex = f"[\x01-\x7a]{{{str_min},{str_max}}}" + regex = f"{DEFAULT_SAMPLING_REGEX}{{{str_min},{str_max}}}" else: - regex = r"[\x01-\x7a]*" + regex = rf"{DEFAULT_SAMPLING_REGEX}*" return generator.sample_string( n, diff --git a/tests/columns/test_str.py b/tests/columns/test_str.py index 9520822..5288103 100644 --- a/tests/columns/test_str.py +++ b/tests/columns/test_str.py @@ -1,10 +1,13 @@ # Copyright (c) QuantCo 2025-2025 # SPDX-License-Identifier: BSD-3-Clause +import re import pytest import dataframely as dy from dataframely.columns import Column +from dataframely.columns.string import DEFAULT_SAMPLING_REGEX +from dataframely.random import Generator from dataframely.testing import ALL_COLUMN_TYPES @@ -32,3 +35,17 @@ def test_string_representation_array() -> None: def test_string_representation_struct() -> None: column = dy.Struct({"a": dy.String()}) assert str(column) == dy.Struct.__name__.lower() + + +@pytest.mark.parametrize("min_length", [None, 5, 10]) +@pytest.mark.parametrize("max_length", [None, 20]) +def test_string_sampling_without_regex( + min_length: int | None, max_length: int | None +) -> None: + # Check that if no regex is provided, the sampled strings only use + # characters from the DEFAULT_SAMPLING_REGEX. + column = dy.String(min_length=min_length, max_length=max_length) + generator = Generator(seed=42) + sample = column.sample(generator=generator, n=1000) + + assert all(re.match(f"{DEFAULT_SAMPLING_REGEX}*", value) for value in sample)