Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion plugboard/cli/process/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@

from plugboard.diagram import MermaidDiagram
from plugboard.process import Process, ProcessBuilder
from plugboard.schemas import ConfigSpec
from plugboard.schemas import (
ConfigSpec,
ConnectorBuilderSpec,
StateBackendSpec,
)
from plugboard.tune import Tuner
from plugboard.utils import add_sys_path, run_coro_sync

Expand All @@ -34,6 +38,65 @@ def _read_yaml(path: Path) -> ConfigSpec:
return ConfigSpec.model_validate(data)


def _override_process_type(config: ConfigSpec, process_type: str) -> None:
"""Override the process type in the config and ensure compatible connector and state.

Args:
config: The configuration spec to modify
process_type: The process type to use ("local" or "ray")
"""
if process_type == "ray":
# Prepare updates for RayProcess
ray_updates: dict[str, _t.Any] = {
"type": "plugboard.process.RayProcess",
}
# Override connector builder to RayConnector if it's the default AsyncioConnector
if (
config.plugboard.process.connector_builder.type
== "plugboard.connector.AsyncioConnector"
):
ray_updates["connector_builder"] = ConnectorBuilderSpec(
type="plugboard.connector.RayConnector",
args=config.plugboard.process.connector_builder.args,
)
# Override state backend to RayStateBackend if it's the default DictStateBackend
if config.plugboard.process.args.state.type == "plugboard.state.DictStateBackend":
new_state = StateBackendSpec(
type="plugboard.state.RayStateBackend",
args=config.plugboard.process.args.state.args,
)
ray_updates["args"] = config.plugboard.process.args.model_copy(
update={"state": new_state}
)

# Apply all updates at once using model_copy
config.plugboard.process = config.plugboard.process.model_copy(update=ray_updates)

elif process_type == "local":
# Prepare updates for LocalProcess
local_updates: dict[str, _t.Any] = {
"type": "plugboard.process.LocalProcess",
}
# Override connector builder to AsyncioConnector if it's RayConnector
if config.plugboard.process.connector_builder.type == "plugboard.connector.RayConnector":
local_updates["connector_builder"] = ConnectorBuilderSpec(
type="plugboard.connector.AsyncioConnector",
args=config.plugboard.process.connector_builder.args,
)
# Override state backend to DictStateBackend if it's RayStateBackend
if config.plugboard.process.args.state.type == "plugboard.state.RayStateBackend":
new_state = StateBackendSpec(
type="plugboard.state.DictStateBackend",
args=config.plugboard.process.args.state.args,
)
local_updates["args"] = config.plugboard.process.args.model_copy(
update={"state": new_state}
)

# Apply all updates at once using model_copy
config.plugboard.process = config.plugboard.process.model_copy(update=local_updates)


def _build_process(config: ConfigSpec) -> Process:
process = ProcessBuilder.build(config.plugboard.process)
return process
Expand Down Expand Up @@ -92,6 +155,16 @@ def run(
help="Job ID for the process. If not provided, a random job ID will be generated.",
),
] = None,
process_type: Annotated[
_t.Optional[str],
typer.Option(
"--process-type",
help=(
"Override the process type. "
"Options: 'local' for LocalProcess, 'ray' for RayProcess."
),
),
] = None,
) -> None:
"""Run a Plugboard process."""
config_spec = _read_yaml(config)
Expand All @@ -100,6 +173,16 @@ def run(
# Override job ID in config file if set
config_spec.plugboard.process.args.state.args.job_id = job_id

if process_type is not None:
# Validate and normalize process type
process_type_lower = process_type.lower()
if process_type_lower not in ["local", "ray"]:
stderr.print(
f"[red]Invalid process type: {process_type}. Must be 'local' or 'ray'.[/red]"
)
raise typer.Exit(1)
_override_process_type(config_spec, process_type_lower)

with Progress(
SpinnerColumn("arrow3"),
TextColumn("[progress.description]{task.description}"),
Expand Down
55 changes: 55 additions & 0 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,58 @@ def test_cli_process_diagram() -> None:
assert result.exit_code == 0
# Must output a Mermaid flowchart
assert "flowchart" in result.stdout


@pytest.mark.asyncio
async def test_cli_process_run_with_local_override() -> None:
"""Tests the process run command with --process-type local."""
with patch("plugboard.cli.process.ProcessBuilder") as mock_process_builder:
mock_process = AsyncMock()
mock_process_builder.build.return_value = mock_process
result = runner.invoke(
app,
["process", "run", "tests/data/minimal-process.yaml", "--process-type", "local"],
)
# CLI must run without error
assert result.exit_code == 0
assert "Process complete" in result.stdout
# Process must be built with LocalProcess type
mock_process_builder.build.assert_called_once()
call_args = mock_process_builder.build.call_args
process_spec = call_args[0][0]
assert process_spec.type == "plugboard.process.LocalProcess"
assert process_spec.connector_builder.type == "plugboard.connector.AsyncioConnector"


@pytest.mark.asyncio
async def test_cli_process_run_with_ray_override() -> None:
"""Tests the process run command with --process-type ray."""
with patch("plugboard.cli.process.ProcessBuilder") as mock_process_builder:
mock_process = AsyncMock()
mock_process_builder.build.return_value = mock_process
result = runner.invoke(
app,
["process", "run", "tests/data/minimal-process.yaml", "--process-type", "ray"],
)
# CLI must run without error
assert result.exit_code == 0
assert "Process complete" in result.stdout
# Process must be built with RayProcess type
mock_process_builder.build.assert_called_once()
call_args = mock_process_builder.build.call_args
process_spec = call_args[0][0]
assert process_spec.type == "plugboard.process.RayProcess"
assert process_spec.connector_builder.type == "plugboard.connector.RayConnector"
assert process_spec.args.state.type == "plugboard.state.RayStateBackend"


@pytest.mark.asyncio
async def test_cli_process_run_with_invalid_process_type() -> None:
"""Tests the process run command with invalid --process-type."""
result = runner.invoke(
app,
["process", "run", "tests/data/minimal-process.yaml", "--process-type", "invalid"],
)
# CLI must exit with error
assert result.exit_code == 1
assert "Invalid process type" in result.stdout