diff --git a/README.md b/README.md index c1fad43..b36e905 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,36 @@ async def main(): df = asyncio.run(main()) ``` +### Explicit Dependencies For Callables + +For plain callables, declare dependencies explicitly when one column needs another: + +```python +from chatan import call, dataset + +ds = dataset({ + "file_path": call(lambda: get_random_filepath()), + "file_content": call( + lambda ctx: get_file_content(ctx["file_path"]), + requires=["file_path"], + ), +}) +``` + +You can also use tuple syntax: `"file_content": (callable_fn, ["file_path"])`. + +If your callable argument names match column names, dependencies are inferred automatically: + +```python +def get_file_chunk(file_path): + return load_chunk(file_path) + +ds = dataset({ + "file_path": call(get_random_filepath), + "file_chunk": call(get_file_chunk), # infers dependency on file_path +}) +``` + ## Generator Options ### API-based Generators (included in base install) diff --git a/src/chatan/__init__.py b/src/chatan/__init__.py index defa38c..8e865b9 100644 --- a/src/chatan/__init__.py +++ b/src/chatan/__init__.py @@ -2,7 +2,7 @@ __version__ = "0.3.0" -from .dataset import dataset +from .dataset import call, dataset, depends_on from .evaluate import eval, evaluate from .generator import generator from .sampler import sample @@ -10,6 +10,8 @@ __all__ = [ "dataset", + "call", + "depends_on", "generator", "sample", "generate_with_viewer", diff --git a/src/chatan/dataset.py b/src/chatan/dataset.py index 31ecba9..4531f20 100644 --- a/src/chatan/dataset.py +++ b/src/chatan/dataset.py @@ -1,7 +1,8 @@ """Dataset creation and manipulation with async generation.""" import asyncio -from typing import Any, Dict, List, Optional, Union +import inspect +from typing import Any, Callable, Dict, Iterable, List, Optional, Union import pandas as pd from datasets import Dataset as HFDataset @@ -11,6 +12,8 @@ from .generator import GeneratorFunction from .sampler import 
SampleFunction +CONTEXT_ARG_NAMES = {"ctx", "context"} + class Dataset: """Async dataset generator with dependency-aware execution.""" @@ -156,7 +159,7 @@ async def _generate_column_value( await completion_events[dep].wait() # Generate the value - func = self.schema[column] + func = self._resolve_column_callable(self.schema[column]) if isinstance(func, GeneratorFunction): # Use async generator @@ -165,11 +168,8 @@ async def _generate_column_value( # Samplers are sync but fast value = func(row) elif callable(func): - # Check if it's an async callable - if asyncio.iscoroutinefunction(func): - value = await func(row) - else: - value = func(row) + result = _invoke_with_context(func, row) + value = await result if asyncio.iscoroutine(result) else result else: # Static value value = func @@ -188,19 +188,101 @@ def _build_dependency_graph(self) -> Dict[str, List[str]]: for column, func in self.schema.items(): deps = [] + explicit_deps = self._extract_explicit_dependencies(func) + if explicit_deps: + deps.extend(explicit_deps) + deps.extend(self._infer_signature_dependencies(func, column)) # Extract dependencies from generator functions if hasattr(func, "prompt_template"): import re template = getattr(func, "prompt_template", "") - deps = re.findall(r"\{(\w+)\}", template) + deps.extend(re.findall(r"\{(\w+)\}", template)) + + # Keep dependency order stable and remove duplicates. 
+ ordered_deps = [] + for dep in deps: + if dep not in ordered_deps: + ordered_deps.append(dep) # Only include dependencies that are in the schema - dependencies[column] = [dep for dep in deps if dep in self.schema] + dependencies[column] = [dep for dep in ordered_deps if dep in self.schema] return dependencies + @staticmethod + def _resolve_column_callable(func: Any) -> Any: + """Unwrap schema value to the executable callable/value.""" + if isinstance(func, DependentCallable): + return func.func + + if ( + isinstance(func, tuple) + and len(func) == 2 + and callable(func[0]) + ): + return func[0] + + return func + + @staticmethod + def _extract_explicit_dependencies(func: Any) -> List[str]: + """Extract explicit dependency declarations from schema value.""" + deps = [] + + if isinstance(func, DependentCallable): + deps = func.dependencies + elif ( + isinstance(func, tuple) + and len(func) == 2 + and callable(func[0]) + ): + deps = func[1] + elif hasattr(func, "dependencies"): + deps = getattr(func, "dependencies") + elif hasattr(func, "depends_on"): + deps = getattr(func, "depends_on") + + if deps is None: + return [] + if isinstance(deps, str): + return [deps] + if isinstance(deps, Iterable): + return [dep for dep in deps if isinstance(dep, str)] + return [] + + def _infer_signature_dependencies(self, func: Any, current_column: str) -> List[str]: + """Infer dependencies from callable parameter names.""" + target = self._resolve_column_callable(func) + if not callable(target): + return [] + if isinstance(target, (GeneratorFunction, SampleFunction)): + return [] + + try: + signature = inspect.signature(target) + except (TypeError, ValueError): + return [] + + inferred = [] + for param in signature.parameters.values(): + if param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + continue + if param.name in CONTEXT_ARG_NAMES: + continue + if param.default is not inspect.Parameter.empty: + continue + if param.name == current_column: 
+ continue + if param.name in self.schema: + inferred.append(param.name) + + return inferred + def _topological_sort(self, dependencies: Dict[str, List[str]]) -> List[str]: """Topologically sort columns by dependencies.""" visited = set() @@ -272,3 +354,106 @@ def dataset(schema: Union[Dict[str, Any], str], n: int = 100) -> Dataset: >>> df = asyncio.run(main()) """ return Dataset(schema, n) + + +class DependentCallable: + """Wrapper for callables with explicit column dependencies.""" + + def __init__(self, func: Callable[..., Any], dependencies: List[str]): + self.func = func + self.dependencies = dependencies + + def __call__(self, context: Dict[str, Any]) -> Any: + return _invoke_with_context(self.func, context) + + +def call( + func: Callable[..., Any], + *dependencies: str, + requires: Optional[List[str]] = None, + with_: Optional[List[str]] = None, + **kwargs, +) -> DependentCallable: + """Declare callable schema entries and optional explicit dependencies. + + Example: + schema = { + "file_path": call(lambda: random_path()), + "file_content": call( + lambda ctx: load(ctx["file_path"]), + requires=["file_path"], + ), + } + """ + with_deps = kwargs.pop("with", None) + if kwargs: + unexpected = ", ".join(kwargs.keys()) + raise TypeError(f"Unexpected keyword argument(s): {unexpected}") + + explicit = list(dependencies) + if requires: + if isinstance(requires, str): + explicit.append(requires) + else: + explicit.extend(requires) + if with_: + explicit.extend(with_) + if with_deps: + if isinstance(with_deps, str): + explicit.append(with_deps) + else: + explicit.extend(with_deps) + + return DependentCallable(func, explicit) + + +def depends_on(func: Callable[..., Any], *dependencies: str) -> DependentCallable: + """Backward-compatible alias for explicit callable dependencies.""" + return call(func, *dependencies) + + +def _invoke_with_context(func: Callable[..., Any], context: Dict[str, Any]) -> Any: + """Invoke callable using context-aware argument mapping.""" + 
try: + signature = inspect.signature(func) + except (TypeError, ValueError): + # Keep legacy behavior for callables without inspect metadata. + return func(context) + + params = list(signature.parameters.values()) + if not params: + return func() + + args = [] + kwargs = {} + missing_required = False + + for param in params: + if param.kind == inspect.Parameter.VAR_POSITIONAL: + continue + if param.kind == inspect.Parameter.VAR_KEYWORD: + continue + + if param.name in CONTEXT_ARG_NAMES: + value = context + has_value = True + else: + has_value = param.name in context + value = context.get(param.name) + + if not has_value and param.default is inspect.Parameter.empty: + missing_required = True + continue + if not has_value: + continue + + if param.kind == inspect.Parameter.POSITIONAL_ONLY: + args.append(value) + else: + kwargs[param.name] = value + + if not missing_required: + return func(*args, **kwargs) + + # Backward compatibility for legacy callables that expect `ctx`. + return func(context) diff --git a/tests/test_dataset_comprehensive.py b/tests/test_dataset_comprehensive.py index 843aaa5..c6ef37d 100644 --- a/tests/test_dataset_comprehensive.py +++ b/tests/test_dataset_comprehensive.py @@ -7,7 +7,7 @@ from unittest.mock import Mock from datasets import Dataset as HFDataset -from chatan.dataset import Dataset, dataset +from chatan.dataset import Dataset, call, dataset, depends_on from chatan.generator import GeneratorFunction, BaseGenerator from chatan.sampler import ChoiceSampler, UUIDSampler @@ -132,6 +132,41 @@ def test_dependency_outside_schema(self): # external_col should be filtered out assert dependencies["col2"] == ["col1"] + def test_explicit_callable_dependencies(self): + """Test explicit dependencies for callables.""" + schema = { + "col1": ChoiceSampler(["A"]), + "col2": call(lambda ctx: f"v:{ctx['col1']}", requires=["col1"]), + "col3": (lambda ctx: f"w:{ctx['col2']}", ["col2"]), + } + ds = Dataset(schema, n=2) + + dependencies = 
class TestCallableDependencyGraphInference:
    """Dependency-graph construction for callable schema entries."""

    def test_signature_inferred_callable_dependencies(self):
        """Test dependencies inferred from callable argument names."""

        def col2(col1):
            return f"v:{col1}"

        def col3(col2):
            return f"w:{col2}"

        graph_ds = Dataset(
            {"col1": ChoiceSampler(["A"]), "col2": call(col2), "col3": call(col3)},
            n=2,
        )
        graph = graph_ds._build_dependency_graph()
        assert graph["col1"] == []
        assert graph["col2"] == ["col1"]
        assert graph["col3"] == ["col2"]


@pytest.mark.asyncio
class TestCallableDependencyGeneration:
    """End-to-end generation with explicit and inferred callable deps."""

    async def test_callable_call_wrapper(self):
        """Test callable dependencies via call wrapper."""
        ds = Dataset(
            {
                "file_path": call(lambda: "src/main.ts"),
                "file_content": call(
                    lambda ctx: f"content:{ctx['file_path']}",
                    requires=["file_path"],
                ),
            },
            n=5,
        )
        frame = await ds.generate()
        assert len(frame) == 5
        assert all(frame["file_content"] == "content:src/main.ts")

    async def test_callable_call_wrapper_with_with_keyword_alias(self):
        """Test call wrapper supports 'with' keyword alias via kwargs expansion."""
        ds = Dataset(
            {
                "file_path": call(lambda: "src/alias.ts"),
                "file_content": call(
                    lambda ctx: f"content:{ctx['file_path']}",
                    **{"with": ["file_path"]},
                ),
            },
            n=3,
        )
        frame = await ds.generate()
        assert len(frame) == 3
        assert all(frame["file_content"] == "content:src/alias.ts")

    async def test_callable_call_wrapper_supports_requires_string(self):
        """Test call wrapper supports requires as a single string."""
        ds = Dataset(
            {
                "file_path": call(lambda: "src/single.ts"),
                "file_content": call(
                    lambda ctx: f"content:{ctx['file_path']}",
                    requires="file_path",
                ),
            },
            n=2,
        )
        frame = await ds.generate()
        assert len(frame) == 2
        assert all(frame["file_content"] == "content:src/single.ts")

    async def test_depends_on_backwards_compatible(self):
        """Test depends_on still works as alias."""
        ds = Dataset(
            {
                "file_path": lambda ctx: "src/legacy.ts",
                "file_content": depends_on(
                    lambda ctx: f"content:{ctx['file_path']}",
                    "file_path",
                ),
            },
            n=3,
        )
        frame = await ds.generate()
        assert len(frame) == 3
        assert all(frame["file_content"] == "content:src/legacy.ts")

    async def test_callable_call_wrapper_infers_dependencies_from_signature(self):
        """Test call wrapper infers dependencies from function signature names."""

        def file_path():
            return "src/inferred.ts"

        def file_content(file_path):
            return f"content:{file_path}"

        ds = Dataset(
            {"file_path": call(file_path), "file_content": call(file_content)},
            n=3,
        )
        frame = await ds.generate()
        assert len(frame) == 3
        assert all(frame["file_content"] == "content:src/inferred.ts")

    async def test_callable_tuple_dependency_spec(self):
        """Test callable dependencies via tuple spec."""
        ds = Dataset(
            {
                "file_path": lambda ctx: "src/app.ts",
                "file_content": (
                    lambda ctx: f"content:{ctx['file_path']}",
                    ["file_path"],
                ),
            },
            n=5,
        )
        frame = await ds.generate()
        assert len(frame) == 5
        assert all(frame["file_content"] == "content:src/app.ts")