Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions dataframely/columns/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from ._base import Check, Column
from ._registry import register

# Character class matching exactly one sampled character. It is used to build
# the sampling regex when a string column does not provide its own regex; the
# sampling code appends a quantifier (`*` or `{min,max}`) to it.
DEFAULT_SAMPLING_REGEX = r"[0-9a-zA-Z]"


@register
class String(Column):
Expand Down Expand Up @@ -126,9 +128,9 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
str_max = f"{self.max_length}" if self.max_length is not None else ""
# NOTE: We generate single-byte unicode characters here as validation uses
# `len_bytes()`. Potentially we need to be more accurate at some point...
regex = f"[\x01-\x7a]{{{str_min},{str_max}}}"
regex = f"{DEFAULT_SAMPLING_REGEX}{{{str_min},{str_max}}}"
else:
regex = r"[\x01-\x7a]*"
regex = rf"{DEFAULT_SAMPLING_REGEX}*"

return generator.sample_string(
n,
Expand Down
17 changes: 17 additions & 0 deletions tests/columns/test_str.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Copyright (c) QuantCo 2025-2025
# SPDX-License-Identifier: BSD-3-Clause
import re

import pytest

import dataframely as dy
from dataframely.columns import Column
from dataframely.columns.string import DEFAULT_SAMPLING_REGEX
from dataframely.random import Generator
from dataframely.testing import ALL_COLUMN_TYPES


Expand Down Expand Up @@ -32,3 +35,17 @@ def test_string_representation_array() -> None:
def test_string_representation_struct() -> None:
    """A struct column renders as the lowercased ``Struct`` class name."""
    expected = dy.Struct.__name__.lower()
    struct_column = dy.Struct({"a": dy.String()})
    assert str(struct_column) == expected


@pytest.mark.parametrize("min_length", [None, 5, 10])
@pytest.mark.parametrize("max_length", [None, 20])
def test_string_sampling_without_regex(
    min_length: int | None, max_length: int | None
) -> None:
    """Sampled strings only use characters from ``DEFAULT_SAMPLING_REGEX``.

    Applies when the column is created without an explicit regex, for every
    combination of the parametrized length bounds.
    """
    column = dy.String(min_length=min_length, max_length=max_length)
    generator = Generator(seed=42)
    sample = column.sample(generator=generator, n=1000)

    # NOTE: `re.match` with a `*`-quantified pattern succeeds on *any* string
    # (it happily matches the empty prefix), which would make this assertion
    # vacuous. `fullmatch` requires every character to be in the allowed class.
    allowed = re.compile(f"{DEFAULT_SAMPLING_REGEX}*")
    assert all(allowed.fullmatch(value) for value in sample)
Loading