diff --git a/plugboard/cli/process/__init__.py b/plugboard/cli/process/__init__.py index 518abd79..8ebaf475 100644 --- a/plugboard/cli/process/__init__.py +++ b/plugboard/cli/process/__init__.py @@ -13,7 +13,11 @@ from plugboard.diagram import MermaidDiagram from plugboard.process import Process, ProcessBuilder -from plugboard.schemas import ConfigSpec +from plugboard.schemas import ( + ConfigSpec, + ConnectorBuilderSpec, + StateBackendSpec, +) from plugboard.tune import Tuner from plugboard.utils import add_sys_path, run_coro_sync @@ -34,6 +38,65 @@ def _read_yaml(path: Path) -> ConfigSpec: return ConfigSpec.model_validate(data) +def _override_process_type(config: ConfigSpec, process_type: str) -> None: + """Override the process type in the config and ensure compatible connector and state. + + Args: + config: The configuration spec to modify + process_type: The process type to use ("local" or "ray") + """ + if process_type == "ray": + # Prepare updates for RayProcess + ray_updates: dict[str, _t.Any] = { + "type": "plugboard.process.RayProcess", + } + # Override connector builder to RayConnector if it's the default AsyncioConnector + if ( + config.plugboard.process.connector_builder.type + == "plugboard.connector.AsyncioConnector" + ): + ray_updates["connector_builder"] = ConnectorBuilderSpec( + type="plugboard.connector.RayConnector", + args=config.plugboard.process.connector_builder.args, + ) + # Override state backend to RayStateBackend if it's the default DictStateBackend + if config.plugboard.process.args.state.type == "plugboard.state.DictStateBackend": + new_state = StateBackendSpec( + type="plugboard.state.RayStateBackend", + args=config.plugboard.process.args.state.args, + ) + ray_updates["args"] = config.plugboard.process.args.model_copy( + update={"state": new_state} + ) + + # Apply all updates at once using model_copy + config.plugboard.process = config.plugboard.process.model_copy(update=ray_updates) + + elif process_type == "local": + # Prepare updates for LocalProcess + local_updates: dict[str, _t.Any] = { + "type": "plugboard.process.LocalProcess", + } + # Override connector builder to AsyncioConnector if it's RayConnector + if config.plugboard.process.connector_builder.type == "plugboard.connector.RayConnector": + local_updates["connector_builder"] = ConnectorBuilderSpec( + type="plugboard.connector.AsyncioConnector", + args=config.plugboard.process.connector_builder.args, + ) + # Override state backend to DictStateBackend if it's RayStateBackend + if config.plugboard.process.args.state.type == "plugboard.state.RayStateBackend": + new_state = StateBackendSpec( + type="plugboard.state.DictStateBackend", + args=config.plugboard.process.args.state.args, + ) + local_updates["args"] = config.plugboard.process.args.model_copy( + update={"state": new_state} + ) + + # Apply all updates at once using model_copy + config.plugboard.process = config.plugboard.process.model_copy(update=local_updates) + + def _build_process(config: ConfigSpec) -> Process: process = ProcessBuilder.build(config.plugboard.process) return process @@ -92,6 +155,16 @@ def run( help="Job ID for the process. If not provided, a random job ID will be generated.", ), ] = None, + process_type: Annotated[ + _t.Optional[str], + typer.Option( + "--process-type", + help=( + "Override the process type. " + "Options: 'local' for LocalProcess, 'ray' for RayProcess." + ), + ), + ] = None, ) -> None: """Run a Plugboard process.""" config_spec = _read_yaml(config) @@ -100,6 +173,16 @@ def run( # Override job ID in config file if set config_spec.plugboard.process.args.state.args.job_id = job_id + if process_type is not None: + # Validate and normalize process type + process_type_lower = process_type.lower() + if process_type_lower not in ["local", "ray"]: + stderr.print( + f"[red]Invalid process type: {process_type}. Must be 'local' or 'ray'.[/red]" + ) + raise typer.Exit(1) + _override_process_type(config_spec, process_type_lower) + with Progress( SpinnerColumn("arrow3"), TextColumn("[progress.description]{task.description}"), diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 76950f6d..d1d1a4e3 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -59,3 +59,58 @@ def test_cli_process_diagram() -> None: assert result.exit_code == 0 # Must output a Mermaid flowchart assert "flowchart" in result.stdout + + +@pytest.mark.asyncio +async def test_cli_process_run_with_local_override() -> None: + """Tests the process run command with --process-type local.""" + with patch("plugboard.cli.process.ProcessBuilder") as mock_process_builder: + mock_process = AsyncMock() + mock_process_builder.build.return_value = mock_process + result = runner.invoke( + app, + ["process", "run", "tests/data/minimal-process.yaml", "--process-type", "local"], + ) + # CLI must run without error + assert result.exit_code == 0 + assert "Process complete" in result.stdout + # Process must be built with LocalProcess type + mock_process_builder.build.assert_called_once() + call_args = mock_process_builder.build.call_args + process_spec = call_args[0][0] + assert process_spec.type == "plugboard.process.LocalProcess" + assert process_spec.connector_builder.type == "plugboard.connector.AsyncioConnector" + + +@pytest.mark.asyncio +async def test_cli_process_run_with_ray_override() -> None: + """Tests the process run command with --process-type ray.""" + with patch("plugboard.cli.process.ProcessBuilder") as mock_process_builder: + mock_process = AsyncMock() + mock_process_builder.build.return_value = mock_process + result = runner.invoke( + app, + ["process", "run", "tests/data/minimal-process.yaml", "--process-type", "ray"], + ) + # CLI must run without error + assert result.exit_code == 0 + assert "Process complete" in result.stdout + # Process must be built with RayProcess type + mock_process_builder.build.assert_called_once() + call_args = mock_process_builder.build.call_args + process_spec = call_args[0][0] + assert process_spec.type == "plugboard.process.RayProcess" + assert process_spec.connector_builder.type == "plugboard.connector.RayConnector" + assert process_spec.args.state.type == "plugboard.state.RayStateBackend" + + +@pytest.mark.asyncio +async def test_cli_process_run_with_invalid_process_type() -> None: + """Tests the process run command with invalid --process-type.""" + result = runner.invoke( + app, + ["process", "run", "tests/data/minimal-process.yaml", "--process-type", "invalid"], + ) + # CLI must exit with error + assert result.exit_code == 1 + assert "Invalid process type" in result.stdout