diff --git a/packages/data-designer-config/src/data_designer/config/__init__.py b/packages/data-designer-config/src/data_designer/config/__init__.py index eb385e15a..ad3ab2414 100644 --- a/packages/data-designer-config/src/data_designer/config/__init__.py +++ b/packages/data-designer-config/src/data_designer/config/__init__.py @@ -82,6 +82,7 @@ UniformSamplerParams, UUIDSamplerParams, ) + from data_designer.config.script_params import DataDesignerScriptParams # noqa: F401 from data_designer.config.seed import ( # noqa: F401 IndexRange, PartitionBlock, @@ -204,6 +205,8 @@ "PartitionBlock": (_MOD_SEED, "PartitionBlock"), "SamplingStrategy": (_MOD_SEED, "SamplingStrategy"), "SeedConfig": (_MOD_SEED, "SeedConfig"), + # script params + "DataDesignerScriptParams": (f"{_MOD_BASE}.script_params", "DataDesignerScriptParams"), # seed_source "DataFrameSeedSource": (f"{_MOD_BASE}.seed_source_dataframe", "DataFrameSeedSource"), "AgentRolloutFormat": (_MOD_SEED_SOURCE, "AgentRolloutFormat"), diff --git a/packages/data-designer-config/src/data_designer/config/script_params.py b/packages/data-designer-config/src/data_designer/config/script_params.py new file mode 100644 index 000000000..2c8d2689b --- /dev/null +++ b/packages/data-designer-config/src/data_designer/config/script_params.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class DataDesignerScriptParams: + """Runtime parameters forwarded to Python config workflows. + + Attributes: + argv: Raw workflow arguments passed after the CLI ``--`` separator. + """ + + argv: tuple[str, ...] 
= () diff --git a/packages/data-designer/src/data_designer/cli/commands/create.py b/packages/data-designer/src/data_designer/cli/commands/create.py index ea98222ea..6d7053bc9 100644 --- a/packages/data-designer/src/data_designer/cli/commands/create.py +++ b/packages/data-designer/src/data_designer/cli/commands/create.py @@ -3,21 +3,29 @@ from __future__ import annotations +from typing import Annotated + import click import typer +from data_designer.cli.commands.generation_args import resolve_generation_config_target from data_designer.cli.controllers.generation_controller import GenerationController from data_designer.config.utils.constants import DEFAULT_NUM_RECORDS from data_designer.interface.results import SUPPORTED_EXPORT_FORMATS def create_command( - config_source: str = typer.Argument( - help=( - "Path or URL to a config file (.yaml/.yml/.json), or a local Python module (.py)" - " that defines a load_config_builder() function." + workflow_args: Annotated[ + list[str] | None, + typer.Argument( + metavar="[CONFIG_SOURCE] [-- WORKFLOW_ARGS]", + help=( + "Path or URL to a config file (.yaml/.yml/.json), or a local Python module (.py)" + " that defines a load_config_builder() function. Extra arguments after '--' are forwarded to Python" + " workflows." 
@dataclass(frozen=True)
class GenerationConfigTarget:
    """Config source and workflow arguments resolved for create, preview, and validate."""

    # First positional CLI token, identifying the config to load.
    config_source: str
    # Every positional token after the config source, in order.
    workflow_args: tuple[str, ...]


def resolve_generation_config_target(raw_args: list[str] | None) -> GenerationConfigTarget:
    """Separate the config source from the trailing workflow arguments.

    Args:
        raw_args: Positional tokens collected by the CLI, or None when none
            were given.

    Returns:
        A GenerationConfigTarget whose ``config_source`` is the first token and
        whose ``workflow_args`` holds the remaining tokens as a tuple.

    Raises:
        click.UsageError: If no positional token was supplied.
    """
    tokens = tuple(raw_args or ())
    if len(tokens) == 0:
        raise click.UsageError("Missing argument 'CONFIG_SOURCE'.")

    return GenerationConfigTarget(
        config_source=tokens[0],
        workflow_args=tokens[1:],
    )
+ ), ), - ), + ] = None, num_records: int = typer.Option( DEFAULT_NUM_RECORDS, "--num-records", @@ -54,9 +62,11 @@ def preview_command( ), ) -> None: """Generate a preview dataset for fast iteration on your configuration.""" + target = resolve_generation_config_target(workflow_args) controller = GenerationController() controller.run_preview( - config_source=config_source, + config_source=target.config_source, + workflow_args=target.workflow_args, num_records=num_records, non_interactive=non_interactive, save_results=save_results, diff --git a/packages/data-designer/src/data_designer/cli/commands/validate.py b/packages/data-designer/src/data_designer/cli/commands/validate.py index 19d338816..62cf8d723 100644 --- a/packages/data-designer/src/data_designer/cli/commands/validate.py +++ b/packages/data-designer/src/data_designer/cli/commands/validate.py @@ -3,18 +3,26 @@ from __future__ import annotations +from typing import Annotated + import typer +from data_designer.cli.commands.generation_args import resolve_generation_config_target from data_designer.cli.controllers.generation_controller import GenerationController def validate_command( - config_source: str = typer.Argument( - help=( - "Path or URL to a config file (.yaml/.yml/.json), or a local Python module (.py)" - " that defines a load_config_builder() function." + workflow_args: Annotated[ + list[str] | None, + typer.Argument( + metavar="[CONFIG_SOURCE] [-- WORKFLOW_ARGS]", + help=( + "Path or URL to a config file (.yaml/.yml/.json), or a local Python module (.py)" + " that defines a load_config_builder() function. Extra arguments after '--' are forwarded to Python" + " workflows." + ), ), - ), + ] = None, ) -> None: """Validate a Data Designer configuration. 
@@ -31,5 +39,6 @@ def validate_command( # Validate a Python module data-designer validate my_config.py """ + target = resolve_generation_config_target(workflow_args) controller = GenerationController() - controller.run_validate(config_source=config_source) + controller.run_validate(config_source=target.config_source, workflow_args=target.workflow_args) diff --git a/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py b/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py index 39c45f5f5..1f0c03e97 100644 --- a/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py +++ b/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py @@ -12,9 +12,10 @@ import typer from data_designer.cli.ui import console, print_error, print_header, print_success, wait_for_navigation_key -from data_designer.cli.utils.config_loader import ConfigLoadError, load_config_builder +from data_designer.cli.utils.config_loader import ConfigLoadError, WorkflowHelpRequested, load_config_builder from data_designer.cli.utils.sample_records_pager import PAGER_FILENAME, create_sample_records_pager from data_designer.config.errors import InvalidConfigError +from data_designer.config.script_params import DataDesignerScriptParams from data_designer.config.utils.constants import DEFAULT_DISPLAY_WIDTH from data_designer.interface import DataDesigner from data_designer.logging import LOG_INDENT @@ -31,13 +32,14 @@ class GenerationController: def run_preview( self, - config_source: str, + config_source: str | None, num_records: int, non_interactive: bool, save_results: bool = False, artifact_path: str | None = None, theme: Literal["dark", "light"] = "dark", display_width: int = DEFAULT_DISPLAY_WIDTH, + workflow_args: tuple[str, ...] = (), ) -> None: """Load config, generate a preview dataset, and display the results. 
@@ -49,8 +51,9 @@ def run_preview( artifact_path: Directory to save results in, or None for ./artifacts. theme: Color theme for HTML output (dark or light). display_width: Maximum width of the rendered record output in characters. + workflow_args: Arguments forwarded to Python config workflows. """ - config_builder = self._load_config(config_source) + config_builder = self._load_config(config_source, workflow_args=workflow_args) print_header("Data Designer Preview") console.print(f" Config: [bold]{config_source}[/bold]") @@ -86,13 +89,18 @@ def run_preview( console.print() print_success(f"Preview complete — {total} record(s) generated") - def run_validate(self, config_source: str) -> None: + def run_validate( + self, + config_source: str, + workflow_args: tuple[str, ...] = (), + ) -> None: """Load config and validate it against the engine. Args: config_source: Path to a config file or Python module. + workflow_args: Arguments forwarded to Python config workflows. """ - config_builder = self._load_config(config_source) + config_builder = self._load_config(config_source, workflow_args=workflow_args) print_header("Data Designer Validate") console.print(f" Config: [bold]{config_source}[/bold]") @@ -112,11 +120,12 @@ def run_validate(self, config_source: str) -> None: def run_create( self, - config_source: str, + config_source: str | None, num_records: int, dataset_name: str, artifact_path: str | None, output_format: str | None = None, + workflow_args: tuple[str, ...] = (), ) -> None: """Load config, create a full dataset, and save results to disk. @@ -127,8 +136,9 @@ def run_create( artifact_path: Path where generated artifacts will be stored, or None for default. output_format: If set, export the dataset to a single file in this format after generation. One of 'jsonl', 'csv', 'parquet'. + workflow_args: Arguments forwarded to Python config workflows. 
""" - config_builder = self._load_config(config_source) + config_builder = self._load_config(config_source, workflow_args=workflow_args) resolved_artifact_path = Path(artifact_path) if artifact_path else Path.cwd() / "artifacts" @@ -174,11 +184,16 @@ def run_create( print_success(f"Dataset created — {actual_record_count} record(s) generated") console.print() - def _load_config(self, config_source: str) -> DataDesignerConfigBuilder: + def _load_config( + self, + config_source: str, + workflow_args: tuple[str, ...] = (), + ) -> DataDesignerConfigBuilder: """Load a config builder from the given source, exiting on failure. Args: config_source: Path to a config file or Python module. + workflow_args: Arguments forwarded to Python config workflows. Returns: A DataDesignerConfigBuilder instance. @@ -186,8 +201,11 @@ def _load_config(self, config_source: str) -> DataDesignerConfigBuilder: Raises: typer.Exit: If the config cannot be loaded. """ + script_params = DataDesignerScriptParams(argv=workflow_args) try: - return load_config_builder(config_source) + return load_config_builder(config_source, script_params=script_params) + except WorkflowHelpRequested as e: + raise typer.Exit(code=0) from e except ConfigLoadError as e: print_error(str(e)) raise typer.Exit(code=1) diff --git a/packages/data-designer/src/data_designer/cli/utils/config_loader.py b/packages/data-designer/src/data_designer/cli/utils/config_loader.py index 9fe37b9f1..0b1e11f92 100644 --- a/packages/data-designer/src/data_designer/cli/utils/config_loader.py +++ b/packages/data-designer/src/data_designer/cli/utils/config_loader.py @@ -4,11 +4,15 @@ from __future__ import annotations import importlib.util +import inspect import sys +from collections.abc import Callable from pathlib import Path +from typing import Any from urllib.parse import urlparse from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.script_params import DataDesignerScriptParams from 
data_designer.config.utils.io_helpers import VALID_CONFIG_FILE_EXTENSIONS, is_http_url @@ -16,13 +20,20 @@ class ConfigLoadError(Exception): """Raised when a configuration source cannot be loaded.""" +class WorkflowHelpRequested(Exception): + """Raised when a Python workflow prints help and exits successfully.""" + + PYTHON_EXTENSIONS = {".py"} ALL_SUPPORTED_EXTENSIONS = VALID_CONFIG_FILE_EXTENSIONS | PYTHON_EXTENSIONS USER_MODULE_FUNC_NAME = "load_config_builder" -def load_config_builder(config_source: str) -> DataDesignerConfigBuilder: +def load_config_builder( + config_source: str, + script_params: DataDesignerScriptParams | None = None, +) -> DataDesignerConfigBuilder: """Load a DataDesignerConfigBuilder from a file path or URL. Auto-detects the file type by extension: @@ -32,6 +43,7 @@ def load_config_builder(config_source: str) -> DataDesignerConfigBuilder: Args: config_source: Path or URL to the configuration file, or path to a Python module. + script_params: Optional runtime arguments for Python config workflows. Returns: A DataDesignerConfigBuilder instance. @@ -40,6 +52,7 @@ def load_config_builder(config_source: str) -> DataDesignerConfigBuilder: ConfigLoadError: If the file cannot be loaded or is invalid. """ if is_http_url(config_source): + _reject_script_params_for_static_source(config_source, script_params) return _load_from_config_url(config_source) path = Path(config_source) @@ -57,9 +70,10 @@ def load_config_builder(config_source: str) -> DataDesignerConfigBuilder: raise ConfigLoadError(f"Unsupported file extension '{suffix}'. 
Supported extensions: {supported}") if suffix in VALID_CONFIG_FILE_EXTENSIONS: + _reject_script_params_for_static_source(str(path), script_params) return _load_from_config_file(path) - return _load_from_python_module(path) + return _load_from_python_module(path, script_params) def _load_from_config_url(config_source: str) -> DataDesignerConfigBuilder: @@ -101,7 +115,10 @@ def _load_from_config_file(path: Path | str) -> DataDesignerConfigBuilder: raise ConfigLoadError(f"Failed to load config from '{path}': {e}") from e -def _load_from_python_module(path: Path) -> DataDesignerConfigBuilder: +def _load_from_python_module( + path: Path, + script_params: DataDesignerScriptParams | None = None, +) -> DataDesignerConfigBuilder: """Load a DataDesignerConfigBuilder from a Python module. The module must define a load_config_builder() function that returns @@ -109,6 +126,7 @@ def _load_from_python_module(path: Path) -> DataDesignerConfigBuilder: Args: path: Path to the Python module. + script_params: Optional runtime arguments for Python config workflows. Returns: A DataDesignerConfigBuilder instance. 
# Parameter kinds that can receive the single workflow-params argument.
_SUPPORTED_PARAMETER_KINDS = frozenset(
    {
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    }
)


def call_config_builder_function(
    func: Callable[..., Any],
    source_name: str,
    script_params: DataDesignerScriptParams | None = None,
) -> DataDesignerConfigBuilder:
    """Call a user-provided config builder function with a supported signature.

    Two signatures are supported: zero parameters, or exactly one parameter
    that receives the script params.

    Args:
        func: The user's config builder callable.
        source_name: Name of the source, used only in error messages.
        script_params: Runtime workflow arguments; defaults to empty params.

    Returns:
        The DataDesignerConfigBuilder returned by ``func``.

    Raises:
        WorkflowHelpRequested: If ``func`` exits via SystemExit with a
            successful status.
        ConfigLoadError: If the signature cannot be inspected or is
            unsupported, if arguments are supplied to a zero-parameter
            function, if ``func`` raises or exits unsuccessfully, or if it
            returns something other than a DataDesignerConfigBuilder.
    """
    params = script_params if script_params is not None else DataDesignerScriptParams()
    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError) as e:
        raise ConfigLoadError(f"Could not inspect '{USER_MODULE_FUNC_NAME}()' in '{source_name}': {e}") from e

    accepts_params = bool(signature.parameters)
    if accepts_params:
        _validate_params_signature(signature, source_name)
    elif params.argv:
        # A zero-parameter function has no way to receive the arguments;
        # fail loudly instead of silently dropping them.
        raise ConfigLoadError(
            f"'{USER_MODULE_FUNC_NAME}()' in '{source_name}' does not accept workflow arguments. "
            "Update it to accept a DataDesignerScriptParams parameter."
        )

    # How the call is spelled in error messages for the detected signature.
    call_label = f"{USER_MODULE_FUNC_NAME}({'params' if accepts_params else ''})"
    try:
        if accepts_params:
            config_builder = _call_params_aware_function(func, signature, params)
        else:
            config_builder = func()
    except SystemExit as e:
        # SystemExit is not an Exception subclass, so it needs its own clause.
        if _is_successful_system_exit(e):
            raise WorkflowHelpRequested from e
        raise ConfigLoadError(f"'{call_label}' in '{source_name}' exited with code {e.code}") from e
    except Exception as e:
        raise ConfigLoadError(f"Error calling '{call_label}' in '{source_name}': {e}") from e

    if not isinstance(config_builder, DataDesignerConfigBuilder):
        raise ConfigLoadError(
            f"'{USER_MODULE_FUNC_NAME}()' in '{source_name}' returned "
            f"{type(config_builder).__name__}, expected DataDesignerConfigBuilder"
        )

    return config_builder


def _validate_params_signature(signature: inspect.Signature, source_name: str) -> None:
    """Ensure the signature has exactly one parameter of a supported kind.

    Raises:
        ConfigLoadError: If there is not exactly one parameter, if it is a
            ``*args``/``**kwargs`` parameter, or if it is keyword-only and not
            named ``params``.
    """
    parameters = list(signature.parameters.values())
    if len(parameters) != 1 or parameters[0].kind not in _SUPPORTED_PARAMETER_KINDS:
        raise ConfigLoadError(
            f"Unsupported '{USER_MODULE_FUNC_NAME}()' signature in '{source_name}'. "
            "Expected zero arguments or one DataDesignerScriptParams parameter."
        )

    parameter = parameters[0]
    # Keyword-only parameters are passed as ``params=...``, so the name must match.
    if parameter.kind == inspect.Parameter.KEYWORD_ONLY and parameter.name != "params":
        raise ConfigLoadError(
            f"Unsupported '{USER_MODULE_FUNC_NAME}()' signature in '{source_name}'. "
            "Keyword-only workflow parameters must be named 'params'."
        )


def _call_params_aware_function(
    func: Callable[..., Any],
    signature: inspect.Signature,
    params: DataDesignerScriptParams,
) -> Any:
    """Call ``func`` with ``params``, by keyword if its parameter is keyword-only."""
    parameter = next(iter(signature.parameters.values()))
    if parameter.kind == inspect.Parameter.KEYWORD_ONLY:
        return func(params=params)
    return func(params)


def _reject_script_params_for_static_source(
    source_name: str,
    script_params: DataDesignerScriptParams | None,
) -> None:
    """Raise ConfigLoadError if workflow arguments were supplied for ``source_name``."""
    if script_params is not None and script_params.argv:
        raise ConfigLoadError(
            f"Workflow arguments are only supported for Python config modules, but '{source_name}' is not a "
            "local Python module."
        )


def _is_successful_system_exit(exc: SystemExit) -> bool:
    """Return True if ``exc`` carries an exit status the interpreter treats as success.

    Only ``None`` and integer zero count. Any other object, including ``0.0``
    and strings, is printed by the interpreter and yields exit status 1, so a
    plain ``== 0`` comparison is too permissive.
    """
    code = exc.code
    return code is None or (isinstance(code, int) and code == 0)
MagicMock) -> None: mock_ctrl_cls.return_value = mock_ctrl create_command( - config_source="config.py", + workflow_args=["config.py", "--seed-path", "seed.jsonl"], num_records=100, dataset_name="my_data", artifact_path="/custom/output", @@ -48,6 +49,7 @@ def test_create_command_passes_custom_options(mock_ctrl_cls: MagicMock) -> None: mock_ctrl.run_create.assert_called_once_with( config_source="config.py", + workflow_args=("--seed-path", "seed.jsonl"), num_records=100, dataset_name="my_data", artifact_path="/custom/output", @@ -62,11 +64,12 @@ def test_create_command_default_artifact_path_is_none(mock_ctrl_cls: MagicMock) mock_ctrl_cls.return_value = mock_ctrl create_command( - config_source="config.yaml", num_records=5, dataset_name="ds", artifact_path=None, output_format=None + workflow_args=["config.yaml"], num_records=5, dataset_name="ds", artifact_path=None, output_format=None ) mock_ctrl.run_create.assert_called_once_with( config_source="config.yaml", + workflow_args=(), num_records=5, dataset_name="ds", artifact_path=None, @@ -81,7 +84,7 @@ def test_create_command_passes_output_format(mock_ctrl_cls: MagicMock) -> None: mock_ctrl_cls.return_value = mock_ctrl create_command( - config_source="config.yaml", + workflow_args=["config.yaml"], num_records=10, dataset_name="dataset", artifact_path=None, @@ -90,6 +93,7 @@ def test_create_command_passes_output_format(mock_ctrl_cls: MagicMock) -> None: mock_ctrl.run_create.assert_called_once_with( config_source="config.yaml", + workflow_args=(), num_records=10, dataset_name="dataset", artifact_path=None, diff --git a/packages/data-designer/tests/cli/commands/test_preview_command.py b/packages/data-designer/tests/cli/commands/test_preview_command.py index d9420a094..d24865f59 100644 --- a/packages/data-designer/tests/cli/commands/test_preview_command.py +++ b/packages/data-designer/tests/cli/commands/test_preview_command.py @@ -20,7 +20,7 @@ [ pytest.param( { - "config_source": "config.yaml", + "workflow_args": 
["config.yaml"], "num_records": 5, "non_interactive": True, "save_results": False, @@ -32,7 +32,7 @@ ), pytest.param( { - "config_source": "config.yaml", + "workflow_args": ["config.yaml"], "num_records": 10, "non_interactive": False, "save_results": False, @@ -44,7 +44,7 @@ ), pytest.param( { - "config_source": "my_config.py", + "workflow_args": ["my_config.py", "--variant", "compact"], "num_records": 20, "non_interactive": True, "save_results": False, @@ -56,7 +56,7 @@ ), pytest.param( { - "config_source": "config.yaml", + "workflow_args": ["config.yaml"], "num_records": 5, "non_interactive": True, "save_results": True, @@ -68,7 +68,7 @@ ), pytest.param( { - "config_source": "config.yaml", + "workflow_args": ["config.yaml"], "num_records": 5, "non_interactive": True, "save_results": True, @@ -88,8 +88,14 @@ def test_preview_command_delegates_to_controller(mock_ctrl_cls: MagicMock, kwarg preview_command(**kwargs) + expected = { + **kwargs, + "config_source": kwargs["workflow_args"][0], + "workflow_args": tuple(kwargs["workflow_args"][1:]), + } + mock_ctrl_cls.assert_called_once() - mock_ctrl.run_preview.assert_called_once_with(**kwargs) + mock_ctrl.run_preview.assert_called_once_with(**expected) # --------------------------------------------------------------------------- diff --git a/packages/data-designer/tests/cli/commands/test_validate_command.py b/packages/data-designer/tests/cli/commands/test_validate_command.py index 2447c240d..194afb82f 100644 --- a/packages/data-designer/tests/cli/commands/test_validate_command.py +++ b/packages/data-designer/tests/cli/commands/test_validate_command.py @@ -18,10 +18,10 @@ def test_validate_command_delegates_to_controller(mock_ctrl_cls: MagicMock) -> N mock_ctrl = MagicMock() mock_ctrl_cls.return_value = mock_ctrl - validate_command(config_source="config.yaml") + validate_command(workflow_args=["config.yaml"]) mock_ctrl_cls.assert_called_once() - mock_ctrl.run_validate.assert_called_once_with(config_source="config.yaml") + 
mock_ctrl.run_validate.assert_called_once_with(config_source="config.yaml", workflow_args=()) @patch("data_designer.cli.commands.validate.GenerationController") @@ -30,6 +30,9 @@ def test_validate_command_passes_python_module_source(mock_ctrl_cls: MagicMock) mock_ctrl = MagicMock() mock_ctrl_cls.return_value = mock_ctrl - validate_command(config_source="my_config.py") + validate_command(workflow_args=["my_config.py", "--seed-path", "seed.jsonl"]) - mock_ctrl.run_validate.assert_called_once_with(config_source="my_config.py") + mock_ctrl.run_validate.assert_called_once_with( + config_source="my_config.py", + workflow_args=("--seed-path", "seed.jsonl"), + ) diff --git a/packages/data-designer/tests/cli/controllers/test_generation_controller.py b/packages/data-designer/tests/cli/controllers/test_generation_controller.py index 151f2cbb4..32251428a 100644 --- a/packages/data-designer/tests/cli/controllers/test_generation_controller.py +++ b/packages/data-designer/tests/cli/controllers/test_generation_controller.py @@ -10,9 +10,10 @@ import typer from data_designer.cli.controllers.generation_controller import GenerationController -from data_designer.cli.utils.config_loader import ConfigLoadError +from data_designer.cli.utils.config_loader import ConfigLoadError, WorkflowHelpRequested from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.errors import InvalidConfigError +from data_designer.config.script_params import DataDesignerScriptParams from data_designer.config.utils.constants import DEFAULT_DISPLAY_WIDTH _CTRL = "data_designer.cli.controllers.generation_controller" @@ -54,7 +55,7 @@ def test_run_preview_success(mock_load_config: MagicMock, mock_dd_cls: MagicMock controller = GenerationController() controller.run_preview(config_source="config.yaml", num_records=5, non_interactive=True) - mock_load_config.assert_called_once_with("config.yaml") + mock_load_config.assert_called_once_with("config.yaml", 
script_params=DataDesignerScriptParams()) mock_dd_cls.assert_called_once() mock_dd.preview.assert_called_once_with(mock_builder, num_records=5) @@ -88,6 +89,23 @@ def test_run_preview_config_load_error(mock_load_config: MagicMock) -> None: assert exc_info.value.exit_code == 1 +@patch(f"{_CTRL}.load_config_builder") +def test_run_preview_workflow_help_exits_successfully(mock_load_config: MagicMock) -> None: + """Test preview exits with code 0 when workflow help is requested.""" + mock_load_config.side_effect = WorkflowHelpRequested() + + controller = GenerationController() + with pytest.raises(typer.Exit) as exc_info: + controller.run_preview( + config_source="config.py", + workflow_args=("--help",), + num_records=10, + non_interactive=True, + ) + + assert exc_info.value.exit_code == 0 + + @patch(f"{_CTRL}.DataDesigner") @patch(f"{_CTRL}.load_config_builder") def test_run_preview_generation_fails(mock_load_config: MagicMock, mock_dd_cls: MagicMock) -> None: @@ -612,7 +630,7 @@ def test_run_validate_success(mock_load_config: MagicMock, mock_dd_cls: MagicMoc controller = GenerationController() controller.run_validate(config_source="config.yaml") - mock_load_config.assert_called_once_with("config.yaml") + mock_load_config.assert_called_once_with("config.yaml", script_params=DataDesignerScriptParams()) mock_dd_cls.assert_called_once() mock_dd.validate.assert_called_once_with(mock_builder) @@ -680,7 +698,7 @@ def test_run_create_success(mock_load_config: MagicMock, mock_dd_cls: MagicMock) controller = GenerationController() controller.run_create(config_source="config.yaml", num_records=10, dataset_name="dataset", artifact_path=None) - mock_load_config.assert_called_once_with("config.yaml") + mock_load_config.assert_called_once_with("config.yaml", script_params=DataDesignerScriptParams()) mock_dd_cls.assert_called_once_with(artifact_path=Path.cwd() / "artifacts") mock_dd.create.assert_called_once_with(mock_builder, num_records=10, dataset_name="dataset") diff --git 
a/packages/data-designer/tests/cli/test_main.py b/packages/data-designer/tests/cli/test_main.py index 15349d8e6..f002fe84e 100644 --- a/packages/data-designer/tests/cli/test_main.py +++ b/packages/data-designer/tests/cli/test_main.py @@ -81,8 +81,30 @@ def test_app_dispatches_lazy_create_command(mock_controller_cls: Mock) -> None: assert result.exit_code == 0 mock_controller.run_create.assert_called_once_with( config_source="config.yaml", + workflow_args=(), num_records=DEFAULT_NUM_RECORDS, dataset_name="dataset", artifact_path=None, output_format=None, ) + + +@patch("data_designer.cli.commands.preview.GenerationController") +def test_app_dispatches_lazy_preview_command_with_workflow_args(mock_controller_cls: Mock) -> None: + """The Typer app forwards workflow args after -- for Python config workflows.""" + mock_controller = Mock() + mock_controller_cls.return_value = mock_controller + + result = runner.invoke(app, ["preview", "config.py", "--", "--seed-path", "seed.jsonl"]) + + assert result.exit_code == 0 + mock_controller.run_preview.assert_called_once_with( + config_source="config.py", + workflow_args=("--seed-path", "seed.jsonl"), + num_records=DEFAULT_NUM_RECORDS, + non_interactive=False, + save_results=False, + artifact_path=None, + theme="dark", + display_width=110, + ) diff --git a/packages/data-designer/tests/cli/utils/test_config_loader.py b/packages/data-designer/tests/cli/utils/test_config_loader.py index e008290b0..bdf530406 100644 --- a/packages/data-designer/tests/cli/utils/test_config_loader.py +++ b/packages/data-designer/tests/cli/utils/test_config_loader.py @@ -10,9 +10,11 @@ from data_designer.cli.utils.config_loader import ( ConfigLoadError, + WorkflowHelpRequested, load_config_builder, ) from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.script_params import DataDesignerScriptParams @patch("data_designer.cli.utils.config_loader.DataDesignerConfigBuilder.from_config") @@ -97,7 +99,7 @@ def 
test_load_config_builder_from_python_module(tmp_path: Path) -> None: result = load_config_builder(str(py_file)) - mock_load_py.assert_called_once_with(py_file) + mock_load_py.assert_called_once_with(py_file, None) assert result is mock_builder @@ -206,6 +208,78 @@ def test_load_config_builder_python_module_sibling_import(tmp_path: Path) -> Non assert result._test_marker == "my_dataset" +def test_load_config_builder_python_module_receives_script_params(tmp_path: Path) -> None: + """Test that a params-aware Python config receives workflow arguments.""" + py_file = tmp_path / "params_config.py" + py_file.write_text( + "from data_designer.config.config_builder import DataDesignerConfigBuilder\n\n" + "def load_config_builder(params):\n" + " builder = DataDesignerConfigBuilder()\n" + " builder._test_argv = params.argv\n" + " return builder\n" + ) + + result = load_config_builder( + str(py_file), + script_params=DataDesignerScriptParams(argv=("--seed-path", "seed.jsonl")), + ) + + assert isinstance(result, DataDesignerConfigBuilder) + assert result._test_argv == ("--seed-path", "seed.jsonl") + + +def test_load_config_builder_python_module_preserves_argparse_help_exit(tmp_path: Path) -> None: + """Test that argparse --help exits cleanly instead of being treated as a load error.""" + py_file = tmp_path / "help_config.py" + py_file.write_text( + "import argparse\n" + "from data_designer.config.config_builder import DataDesignerConfigBuilder\n\n" + "def load_config_builder(params):\n" + " parser = argparse.ArgumentParser()\n" + " parser.add_argument('--seed-path')\n" + " parser.parse_args(list(params.argv))\n" + " return DataDesignerConfigBuilder()\n" + ) + + with pytest.raises(WorkflowHelpRequested): + load_config_builder(str(py_file), script_params=DataDesignerScriptParams(argv=("--help",))) + + +def test_load_config_builder_python_module_rejects_args_for_legacy_function(tmp_path: Path) -> None: + """Test that a no-arg Python config fails clearly when workflow args are 
supplied.""" + py_file = tmp_path / "legacy_config.py" + py_file.write_text( + "from data_designer.config.config_builder import DataDesignerConfigBuilder\n\n" + "def load_config_builder():\n" + " return DataDesignerConfigBuilder()\n" + ) + + with pytest.raises(ConfigLoadError, match="does not accept workflow arguments"): + load_config_builder(str(py_file), script_params=DataDesignerScriptParams(argv=("--seed-path", "seed.jsonl"))) + + +def test_load_config_builder_rejects_script_params_for_yaml(tmp_path: Path) -> None: + """Test that static YAML configs cannot receive workflow args.""" + yaml_file = tmp_path / "config.yaml" + yaml_file.write_text("data_designer:\n columns: []\n") + + with pytest.raises(ConfigLoadError, match="Workflow arguments are only supported"): + load_config_builder(str(yaml_file), script_params=DataDesignerScriptParams(argv=("--seed-path", "seed.jsonl"))) + + +def test_load_config_builder_python_module_rejects_unsupported_signature(tmp_path: Path) -> None: + """Test that Python config modules must use the supported workflow signature.""" + py_file = tmp_path / "bad_signature.py" + py_file.write_text( + "from data_designer.config.config_builder import DataDesignerConfigBuilder\n\n" + "def load_config_builder(first, second):\n" + " return DataDesignerConfigBuilder()\n" + ) + + with pytest.raises(ConfigLoadError, match="Unsupported 'load_config_builder\\(\\)' signature"): + load_config_builder(str(py_file)) + + def test_load_config_builder_python_module_cleans_sys_path(tmp_path: Path) -> None: """Test that the config's parent directory is removed from sys.path after loading.""" import sys