diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d3e870..b10ab53 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,22 @@ using [PEP 440](https://packaging.python.org/en/latest/specifications/version-sp ## [Unreleased] +## 0.1.0a4 - 2026-05-01 + +### Added + +- Adds generated qapi metadata transformer actions so compatible artifacts can + be converted to metadata inside exported Adagio pipelines. +- Adds pipeline/runtime support for built-in metadata conversion steps and + archive collection bindings. + +### Fixed + +- Fixes optional pipeline inputs so omitted optional values are not treated as + required at runtime. +- Fixes dynamic run options so `--show-params` only controls help display and + does not affect which CLI options can be passed. + ## 0.1.0a3 - 2026-05-01 - Adds support for collections. Adagio pipelines with collections are now handled diff --git a/pyproject.toml b/pyproject.toml index 17e6a84..ca9a8f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "adagio-cli" -version = "0.1.0a3" +version = "0.1.0a4" description = "Command-line runner for Adagio pipeline files." readme = "README.md" requires-python = ">=3.10" diff --git a/src/adagio/cli/dynamic.py b/src/adagio/cli/dynamic.py index 39a1206..5428acd 100644 --- a/src/adagio/cli/dynamic.py +++ b/src/adagio/cli/dynamic.py @@ -242,6 +242,9 @@ def build_dynamic_run( input_specs: list[InputSpec], param_specs: list[ParamSpec], output_specs: list[OutputSpec], + visible_input_names: set[str] | None = None, + visible_param_names: set[str] | None = None, + visible_output_names: set[str] | None = None, argument_inputs: dict[str, Any] | None = None, argument_params: dict[str, Any] | None = None, run_handler: Callable[ @@ -261,6 +264,13 @@ def build_dynamic_run( ], ): """Build a dynamic run command from pipeline input, parameter, and output specs.""" + visible_input_names = ( + set(visible_input_names) if visible_input_names is not None else None + ) + visible_param_names = set(visible_param_names) if visible_param_names is not None else None + visible_output_names = ( + set(visible_output_names) if visible_output_names is not None else None + ) input_bindings: list[tuple[str, str]] = [] param_bindings: list[tuple[str, str]] = [] output_bindings: list[tuple[str, str]] = [] @@ -401,6 +411,7 @@ def add_dynamic_option( help_text: str, default: Any, group: Group | tuple[Group, ...], + show: bool = True, ) -> None: if opt in seen_opts: raise ValueError(f"Conflicting CLI option generated: {opt!r}.") @@ -414,6 +425,7 @@ def add_dynamic_option( group=group, help=help_text, required=required, + show=show, ), ] parameters.append( @@ -446,6 +458,7 @@ def add_input_spec(spec: InputSpec) -> None: type_text = spec.type opt = dynamic_opt(original, ParamType.INPUT) + show = visible_input_names is None or original in visible_input_names entry_metadata[opt] = { "type_label": _display_type_label( spec_type=type_text, type_hint=str, is_input=True @@ -463,6 +476,7 @@ def add_input_spec(spec: InputSpec) -> None: ), default=None, group=pipeline_group, + show=show, ) def add_param_spec(spec: ParamSpec) -> None: @@ -486,6 +500,7 @@ def add_param_spec(spec: ParamSpec) -> None: param_default = None param_type: Any = _resolve_param_type(spec.type, default) opt = dynamic_opt(original, ParamType.PARAM) + show = visible_param_names is None or original in visible_param_names if is_required: required_params.append(original) entry_metadata[opt] = { @@ -505,6 +520,7 @@ def add_param_spec(spec: ParamSpec) -> None: ), default=param_default, group=pipeline_group, + show=show, ) for spec in required_input_specs: @@ -526,6 +542,7 @@ def add_param_spec(spec: ParamSpec) -> None: seen_idents.add(ident) output_bindings.append((ident, original)) opt = dynamic_opt(original, ParamType.OUTPUT) + show = visible_output_names is None or original in visible_output_names entry_metadata[opt] = { "type_label": path_type_label(spec.type), "default": None, @@ -541,6 +558,7 @@ def add_param_spec(spec: ParamSpec) -> None: ), default=None, group=pipeline_group, + show=show, ) def run( diff --git a/src/adagio/cli/main.py b/src/adagio/cli/main.py index 82bdad7..8e23417 100644 --- a/src/adagio/cli/main.py +++ b/src/adagio/cli/main.py @@ -182,9 +182,12 @@ def run( ) dynamic_run = build_dynamic_run( - input_specs=visible_inputs, - param_specs=visible_params, - output_specs=visible_outputs, + input_specs=input_specs, + param_specs=param_specs, + output_specs=output_specs, + visible_input_names={spec.name for spec in visible_inputs}, + visible_param_names={spec.name for spec in visible_params}, + visible_output_names={spec.name for spec in visible_outputs}, argument_inputs=arguments_data.get("inputs", {}) if arguments_data else None, argument_params=arguments_data.get("parameters", {}) if arguments_data else None, run_handler=partial(run_pipeline_from_kwargs, console=console), diff --git a/src/adagio/executors/common.py b/src/adagio/executors/common.py index 5f5b1eb..65c795c 100644 --- a/src/adagio/executors/common.py +++ b/src/adagio/executors/common.py @@ -3,9 +3,15 @@ from adagio.model.task import input_source_ids -def plan_execution_order(*, tasks: list[t.Any], scope: dict[str, t.Any]) -> list[t.Any]: +def plan_execution_order( + *, + tasks: list[t.Any], + scope: dict[str, t.Any], + optional_missing_ids: set[str] | None = None, +) -> list[t.Any]: """Return a dependency-respecting serial execution plan.""" available_ids = set(scope.keys()) + optional_missing_ids = optional_missing_ids or set() remaining = list(tasks) planned: list[t.Any] = [] @@ -16,7 +22,7 @@ def plan_execution_order(*, tasks: list[t.Any], scope: dict[str, t.Any]) -> list source_id for src in task.inputs.values() for source_id in input_source_ids(src) - if source_id not in available_ids + if source_id not in available_ids and source_id not in optional_missing_ids ] if missing: continue @@ -35,6 +41,7 @@ def plan_execution_order(*, tasks: list[t.Any], scope: dict[str, t.Any]) -> list for src in task.inputs.values() for source_id in input_source_ids(src) if source_id not in available_ids + and source_id not in optional_missing_ids ) details.append(f"{task.id}: missing [{missing}]") raise RuntimeError("Unable to resolve task dependencies. " + "; ".join(details)) diff --git a/src/adagio/executors/serial_runner.py b/src/adagio/executors/serial_runner.py index d4f06ff..405d65f 100644 --- a/src/adagio/executors/serial_runner.py +++ b/src/adagio/executors/serial_runner.py @@ -26,6 +26,7 @@ class SerialExecutionState: params: dict[str, t.Any] scope: dict[str, InputSource] cache_config: ExecutionCacheConfig | None + missing_optional_ids: set[str] = field(default_factory=set) saved_output_ids: set[str] = field(default_factory=set) save_output_started: bool = False @@ -64,13 +65,21 @@ def run_serial_pipeline( active_monitor.start_load_input() for input_def in sig.inputs: - source = arguments.inputs[input_def.name] + source = arguments.inputs.get(input_def.name) + if _is_missing(source): + if not input_def.required: + state.missing_optional_ids.add(input_def.id) + continue state.scope[input_def.id] = resolve_pipeline_input( source=source, type_name=input_def.type, cwd=state.cwd ) active_monitor.finish_load_input() - execution_plan = plan_execution_order(tasks=tasks, scope=state.scope) + execution_plan = plan_execution_order( + tasks=tasks, + scope=state.scope, + optional_missing_ids=state.missing_optional_ids, + ) for task in execution_plan: active_monitor.queue_task( task_id=task.id, @@ -138,6 +147,10 @@ def resolve_monitor(*, console: Console | None, monitor: Monitor | None) -> Moni return LogMonitor() +def _is_missing(value: t.Any) -> bool: + return value is None or value == "" or value == "" or value == [] or value == {} + + def resolve_pipeline_input( *, source: InputSource, type_name: str, cwd: Path ) -> InputSource: diff --git a/src/adagio/executors/task_environments.py b/src/adagio/executors/task_environments.py index 396ed28..7905443 100644 --- a/src/adagio/executors/task_environments.py +++ b/src/adagio/executors/task_environments.py @@ -5,7 +5,7 @@ from rich.console import Console from adagio.model.arguments import AdagioArguments -from adagio.model.task import PluginActionTask, RootInputTask +from adagio.model.task import ConvertToMetadataTask, PluginActionTask, RootInputTask from adagio.monitor.api import Monitor from .base import ( @@ -60,9 +60,21 @@ def _resolve_task( if isinstance(task, RootInputTask): for name, src in task.inputs.items(): dst = task.outputs[name] + if src.id in state.missing_optional_ids: + state.missing_optional_ids.add(dst.id) + continue state.scope[dst.id] = state.scope[src.id] return False + if isinstance(task, ConvertToMetadataTask): + if task.inputs["data"].id in state.missing_optional_ids: + state.missing_optional_ids.add(task.outputs["metadata"].id) + return False + state.scope[task.outputs["metadata"].id] = state.scope[ + task.inputs["data"].id + ] + return False + if isinstance(task, PluginActionTask): return self._execute_plugin_action( task=task, @@ -91,6 +103,8 @@ def _execute_plugin_action( metadata_inputs: dict[str, str] = {} for name, src in task.inputs.items(): if src.kind == "archive": + if src.id in state.missing_optional_ids: + continue value = state.scope[src.id] if isinstance(value, list): archive_collection_inputs[name] = value @@ -99,10 +113,15 @@ def _execute_plugin_action( else: archive_inputs[name] = value elif src.kind == "archive-collection": - archive_collection_inputs[name] = _flatten_collection_items( - [state.scope[item.id] for item in src.items] + values = _present_collection_item_values( + items=src.items, + state=state, ) + if values: + archive_collection_inputs[name] = _flatten_collection_items(values) elif src.kind == "metadata": + if src.id in state.missing_optional_ids: + continue value = state.scope[src.id] if not isinstance(value, str): raise TypeError( @@ -190,6 +209,19 @@ def _flatten_collection_items( return result +def _present_collection_item_values( + *, + items, + state: SerialExecutionState, +) -> list[str | list[str] | dict[str, str]]: + values: list[str | list[str] | dict[str, str]] = [] + for item in items: + if item.id in state.missing_optional_ids: + continue + values.append(state.scope[item.id]) + return values + + def _save_outputs( *, sig, diff --git a/src/adagio/model/pipeline.py b/src/adagio/model/pipeline.py index 4432646..dd3fd53 100644 --- a/src/adagio/model/pipeline.py +++ b/src/adagio/model/pipeline.py @@ -1,8 +1,7 @@ import typing as t import os -import json -from pydantic import BaseModel, RootModel, model_validator, Field +from pydantic import BaseModel, RootModel, model_validator from .arguments import AdagioArguments @@ -68,7 +67,9 @@ def load_inputs(self, ctx, arguments, scope): from adagio.io import load_input, load_input_collection, load_metadata for input in self.inputs: - source = arguments.inputs[input.name] + source = arguments.inputs.get(input.name) + if _is_missing(source): + continue if _is_metadata_ast(input.ast): print("SCHEDULED:", f'load_metadata({source!r})') scope[input.id] = load_metadata(ctx=ctx, source=source) @@ -140,3 +141,7 @@ def _is_metadata_ast(ast: TypeAST) -> bool: def _is_collection_type(type_name: str) -> bool: return type_name.startswith('List[') or type_name.startswith('Collection[') + + +def _is_missing(value: t.Any) -> bool: + return value is None or value == "" or value == "" or value == [] or value == {} diff --git a/src/adagio/model/task.py b/src/adagio/model/task.py index a9b3b5d..ed1bfbf 100644 --- a/src/adagio/model/task.py +++ b/src/adagio/model/task.py @@ -5,9 +5,9 @@ class _BaseTask(BaseModel): id: str kind: str - inputs: dict[str, 'TaskInputVal'] - parameters: dict[str, 'LiteralVal | MetadataVal | PromotedVal'] - outputs: dict[str, 'OutputVal'] + inputs: dict[str, "TaskInputVal"] + parameters: dict[str, "LiteralVal | MetadataVal | PromotedVal"] + outputs: dict[str, "OutputVal"] def exec(self, ctx, params, scope): raise NotImplementedError @@ -15,7 +15,7 @@ def exec(self, ctx, params, scope): class PluginActionTask(_BaseTask): id: str - kind: t.Literal['plugin-action'] + kind: t.Literal["plugin-action"] name: str | None = None plugin: str action: str @@ -27,39 +27,43 @@ def exec(self, ctx, params, scope): kwargs = {} metadata = {} for name, src in self.inputs.items(): - if src.kind == 'archive': + if src.kind == "archive": + if src.id not in scope: + continue kwargs[name] = scope[src.id] - elif src.kind == 'archive-collection': + elif src.kind == "archive-collection": kwargs[name] = _flatten_collection_values( - [scope[item.id] for item in src.items] + [scope[item.id] for item in src.items if item.id in scope] ) - elif src.kind == 'metadata': + elif src.kind == "metadata": + if src.id not in scope: + continue # store for second pass in params metadata[name] = scope[src.id] else: - raise NotImplementedError('impossible') + raise NotImplementedError("impossible") for name, param in self.parameters.items(): - if param.kind == 'metadata': - if param.column.kind == 'literal': + if param.kind == "metadata": + if param.column.kind == "literal": col = param.value - elif param.column.kind == 'promoted': + elif param.column.kind == "promoted": col = params[param.column.id] else: - raise NotImplementedError('impossible') + raise NotImplementedError("impossible") source = metadata.pop(name) md = convert_metadata(ctx=ctx, metadata=source) kwargs[name] = md.get_column(col) - elif param.kind == 'literal': + elif param.kind == "literal": kwargs[name] = param.value - elif param.kind == 'promoted': + elif param.kind == "promoted": kwargs[name] = params[param.id] else: - raise NotImplementedError('impossible') + raise NotImplementedError("impossible") # any remaining metadata is used directly for name, value in metadata.items(): @@ -71,17 +75,29 @@ def exec(self, ctx, params, scope): class RootInputTask(_BaseTask): - kind: t.Literal['built-in'] - name: t.Literal['root-input'] + kind: t.Literal["built-in"] + name: t.Literal["root-input"] def exec(self, ctx, params, scope): for name, src in self.inputs.items(): dst = self.outputs[name] + if src.id in scope: + scope[dst.id] = scope[src.id] + + +class ConvertToMetadataTask(_BaseTask): + kind: t.Literal["built-in"] + name: t.Literal["convert-to-metadata"] + + def exec(self, ctx, params, scope): + src = self.inputs["data"] + dst = self.outputs["metadata"] + if src.id in scope: scope[dst.id] = scope[src.id] class InputVal(BaseModel): - kind: t.Literal['archive', 'metadata'] + kind: t.Literal["archive", "metadata"] id: str @@ -91,24 +107,24 @@ class ArchiveCollectionItemVal(BaseModel): class ArchiveCollectionInputVal(BaseModel): - kind: t.Literal['archive-collection'] - style: t.Literal['list'] + kind: t.Literal["archive-collection"] + style: t.Literal["list"] items: list[ArchiveCollectionItemVal] class OutputVal(BaseModel): - kind: t.Literal['archive'] + kind: t.Literal["archive"] id: str class PromotedVal(BaseModel): - kind: t.Literal['promoted'] + kind: t.Literal["promoted"] id: str class LiteralVal(BaseModel): - kind: t.Literal['literal'] - value: 'AllowableValue' + kind: t.Literal["literal"] + value: "AllowableValue" class LiteralStrVal(LiteralVal): @@ -116,7 +132,7 @@ class LiteralStrVal(LiteralVal): class MetadataVal(BaseModel): - kind: t.Literal['metadata'] + kind: t.Literal["metadata"] column: PromotedVal | LiteralStrVal @@ -124,15 +140,18 @@ class MetadataVal(BaseModel): Collection = list[Primitive] | dict[str, Primitive] AllowableValue = Primitive | Collection TaskInputVal = t.Annotated[ - t.Union[InputVal, ArchiveCollectionInputVal], - Field(discriminator='kind') + t.Union[InputVal, ArchiveCollectionInputVal], Field(discriminator="kind") +] +BuiltInTask = t.Annotated[ + t.Union[RootInputTask, ConvertToMetadataTask], Field(discriminator="name") +] +AdagioTask = t.Annotated[ + t.Union[PluginActionTask, BuiltInTask], Field(discriminator="kind") ] -AdagioTask = t.Annotated[t.Union[PluginActionTask, RootInputTask], - Field(discriminator='kind')] def input_source_ids(value: TaskInputVal) -> list[str]: - if value.kind == 'archive-collection': + if value.kind == "archive-collection": return [item.id for item in value.items] return [value.id] diff --git a/src/adagio/qapi/build.py b/src/adagio/qapi/build.py index 758efa2..7804a3a 100644 --- a/src/adagio/qapi/build.py +++ b/src/adagio/qapi/build.py @@ -4,6 +4,64 @@ DEFAULT_SCHEMA_VERSION = "0.1.0" PRIVATE_QIIME_ACTION_PREFIXES = ("_", "-") +ADAGIO_BUILTIN_PLUGIN = "adagio_builtin" +CONVERT_TO_METADATA_ACTION_ID = "convert_to_metadata" +CONVERT_TO_METADATA_ACTION_NAME = "convert-to-metadata" + + +def _metadata_ast() -> dict[str, Any]: + return { + "name": "Metadata", + "type": "expression", + "fields": [], + "builtin": True, + "predicate": None, + } + + +def _union_ast(members: list[dict[str, Any]]) -> dict[str, Any]: + return {"type": "union", "members": members} + + +def _build_convert_to_metadata_action( + source_types: Sequence[tuple[str, dict[str, Any]]], +) -> dict[str, Any] | None: + """Build the synthetic QAPI action for artifact-to-Metadata conversion.""" + if not source_types: + return None + + sorted_source_types = sorted(source_types, key=lambda item: item[0]) + type_names = [type_name for type_name, _ in sorted_source_types] + source_ast = _union_ast([ast for _, ast in sorted_source_types]) + metadata_ast = _metadata_ast() + + return { + "id": CONVERT_TO_METADATA_ACTION_ID, + "name": CONVERT_TO_METADATA_ACTION_NAME, + "description": ( + "Convert an artifact with a registered QIIME 2 metadata transformer " + "into metadata for downstream actions." + ), + "inputs": [ + { + "name": "data", + "type": " | ".join(type_names), + "ast": source_ast, + "required": True, + "description": ("Artifact that can be viewed as QIIME 2 Metadata."), + } + ], + "parameters": [], + "outputs": [ + { + "name": "metadata", + "type": "Metadata", + "ast": metadata_ast, + "description": "Metadata view of the input artifact.", + } + ], + "adagio_builtin": "metadata_transformer", + } def _private_qiime_action_id(action_key: object, action: Any) -> str | None: @@ -121,6 +179,23 @@ def add_metadata_flag(ast: dict[str, Any]) -> dict[str, Any]: return ast return ast + def iter_metadata_transformer_source_types() -> Iterator[ + tuple[str, dict[str, Any]] + ]: + to_type = transform.ModelType.from_view_type(qiime2.Metadata) + for _, artifact_class in sorted(plugin_manager.artifact_classes.items()): + try: + from_type = transform.ModelType.from_view_type(artifact_class.format) + if not from_type.has_transformation(to_type): + continue + semantic_type = artifact_class.semantic_type + yield ( + repr(semantic_type), + flatten_type_maps(semantic_type).to_ast(), + ) + except Exception: + continue + def optional_desc(value: Any) -> str | None: no_value = qiime2.core.type.signature.__NoValueMeta # type: ignore[attr-defined] return value if type(value) is not no_value else None @@ -196,6 +271,17 @@ def build_data_dict( ) qapi[plugin_name] = {"methods": methods_dict} + if requested_plugins is None: + convert_to_metadata = _build_convert_to_metadata_action( + list(iter_metadata_transformer_source_types()) + ) + if convert_to_metadata is not None: + qapi[ADAGIO_BUILTIN_PLUGIN] = { + "methods": { + CONVERT_TO_METADATA_ACTION_ID: convert_to_metadata, + } + } + return { "qiime_version": qiime2.__version__, "schema_version": schema_version, diff --git a/tests/test_output_options.py b/tests/test_output_options.py index a430143..0c6a10f 100644 --- a/tests/test_output_options.py +++ b/tests/test_output_options.py @@ -1,9 +1,14 @@ +import json +import tempfile import typing import unittest +from pathlib import Path +from unittest.mock import patch from adagio.app.parsers.pipeline import Input, Output, Parameter, parse_outputs from adagio.cli.args import ShowParamsMode from adagio.cli.dynamic import build_dynamic_run +from adagio.cli.main import main from adagio.cli.main import _filter_visible_specs from adagio.cli.runner import _apply_output_overrides @@ -154,6 +159,105 @@ def test_output_dir_is_a_command_option_and_required_pipeline_options_are_first( ], ) + def test_hidden_display_options_remain_registered_for_runtime(self) -> None: + dynamic_run = build_dynamic_run( + input_specs=[ + Input( + id="00000000-0000-0000-0000-000000000001", + name="seqs", + required=True, + type="SampleData[Sequences]", + description="Required sequences.", + ), + Input( + id="00000000-0000-0000-0000-000000000002", + name="tree", + required=False, + type="Phylogeny[Rooted]", + description="Optional tree.", + ), + ], + param_specs=[], + output_specs=[ + Output( + id="00000000-0000-0000-0000-000000000003", + name="table", + type="FeatureTable[Frequency]", + description="Output table.", + ) + ], + visible_input_names={"seqs"}, + visible_param_names=set(), + visible_output_names=set(), + run_handler=lambda *args, **kwargs: None, + ) + + params = dynamic_run.__signature__.parameters + + self.assertIn("input_tree", params) + self.assertIn("output_table", params) + input_tree_param = typing.get_args(params["input_tree"].annotation)[1] + output_table_param = typing.get_args(params["output_table"].annotation)[1] + input_seqs_param = typing.get_args(params["input_seqs"].annotation)[1] + self.assertFalse(input_tree_param.show) + self.assertFalse(output_table_param.show) + self.assertTrue(input_seqs_param.show) + + def test_main_accepts_hidden_optional_input_without_show_params_all(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + pipeline_path = root / "pipeline.adg" + pipeline_path.write_text( + json.dumps( + { + "signature": { + "inputs": [ + { + "id": "00000000-0000-0000-0000-000000000001", + "name": "seqs", + "required": True, + "type": "SampleData[Sequences]", + "description": "Required sequences.", + }, + { + "id": "00000000-0000-0000-0000-000000000002", + "name": "tree", + "required": False, + "type": "Phylogeny[Rooted]", + "description": "Optional tree.", + }, + ], + "parameters": [], + "outputs": [], + } + } + ), + encoding="utf-8", + ) + + with patch("adagio.cli.main.run_pipeline_from_kwargs") as run_mock: + with self.assertRaises(SystemExit) as exc: + main( + [ + "run", + "--pipeline", + str(pipeline_path), + "--cache-dir", + str(root / "cache"), + "--input-seqs", + "seqs.qza", + "--input-tree", + "tree.qza", + ] + ) + self.assertEqual(exc.exception.code, 0) + + run_mock.assert_called_once() + runtime_kwargs = run_mock.call_args.args[3] + input_bindings = run_mock.call_args.args[4] + self.assertEqual(runtime_kwargs["input_tree"], "tree.qza") + self.assertIn(("input_tree", "tree"), input_bindings) + def test_outputs_are_only_visible_in_all_mode(self) -> None: output_specs = [ Output( diff --git a/tests/test_qapi_build.py b/tests/test_qapi_build.py index 1910348..dd102f6 100644 --- a/tests/test_qapi_build.py +++ b/tests/test_qapi_build.py @@ -6,7 +6,14 @@ from rich.console import Console from adagio.cli import qapi as qapi_cli -from adagio.qapi.build import _iter_public_qiime_actions +from adagio.model.pipeline import AdagioPipeline +from adagio.model.task import ConvertToMetadataTask +from adagio.qapi.build import ( + CONVERT_TO_METADATA_ACTION_ID, + CONVERT_TO_METADATA_ACTION_NAME, + _build_convert_to_metadata_action, + _iter_public_qiime_actions, +) class QapiBuildTests(unittest.TestCase): @@ -89,6 +96,101 @@ def fake_generate_qapi_payload(*, on_skipped_private_action, **kwargs): self.assertIn("Skipped 1 private QIIME action", output.getvalue()) self.assertIn("example._private_action", output.getvalue()) + def test_build_convert_to_metadata_action_uses_union_input(self) -> None: + feature_table_ast = { + "name": "FeatureTable", + "type": "expression", + "fields": [ + { + "name": "Frequency", + "type": "expression", + "fields": [], + "builtin": False, + "predicate": None, + } + ], + "builtin": False, + "predicate": None, + } + alpha_ast = { + "name": "SampleData", + "type": "expression", + "fields": [ + { + "name": "AlphaDiversity", + "type": "expression", + "fields": [], + "builtin": False, + "predicate": None, + } + ], + "builtin": False, + "predicate": None, + } + + action = _build_convert_to_metadata_action( + [ + ("SampleData[AlphaDiversity]", alpha_ast), + ("FeatureTable[Frequency]", feature_table_ast), + ] + ) + + self.assertIsNotNone(action) + self.assertEqual(action["id"], CONVERT_TO_METADATA_ACTION_ID) + self.assertEqual(action["name"], CONVERT_TO_METADATA_ACTION_NAME) + self.assertEqual(action["inputs"][0]["name"], "data") + self.assertEqual( + action["inputs"][0]["type"], + "FeatureTable[Frequency] | SampleData[AlphaDiversity]", + ) + self.assertEqual(action["inputs"][0]["ast"]["type"], "union") + self.assertEqual(action["outputs"][0]["type"], "Metadata") + + def test_convert_to_metadata_task_aliases_input_to_output(self) -> None: + pipeline = AdagioPipeline.model_validate( + { + "type": "pipeline", + "signature": { + "inputs": [ + { + "id": "input-data", + "name": "data", + "type": "FeatureTable[Frequency]", + "ast": { + "name": "FeatureTable", + "type": "expression", + "fields": [], + "builtin": False, + "predicate": None, + }, + "required": True, + "description": None, + } + ], + "parameters": [], + "outputs": [], + }, + "graph": [ + { + "id": "convert-node", + "kind": "built-in", + "name": CONVERT_TO_METADATA_ACTION_NAME, + "inputs": {"data": {"kind": "archive", "id": "input-data"}}, + "parameters": {}, + "outputs": { + "metadata": {"kind": "archive", "id": "metadata-output"} + }, + } + ], + } + ) + + task = pipeline.graph[0] + self.assertIsInstance(task, ConvertToMetadataTask) + scope = {"input-data": "/tmp/table.qza"} + task.exec(ctx=None, params={}, scope=scope) + self.assertEqual(scope["metadata-output"], "/tmp/table.qza") + if __name__ == "__main__": unittest.main() diff --git a/tests/test_serial_runner.py b/tests/test_serial_runner.py index e99f911..3500adf 100644 --- a/tests/test_serial_runner.py +++ b/tests/test_serial_runner.py @@ -3,9 +3,13 @@ from dataclasses import dataclass, field from pathlib import Path +from adagio.executors.base import TaskEnvironmentSpec, TaskExecutionResult from adagio.executors.serial_runner import resolve_pipeline_input, run_serial_pipeline +from adagio.executors.serial_runner import SerialExecutionState +from adagio.executors.task_environments import TaskEnvironmentExecutor from adagio.executors.task_environments import _save_outputs from adagio.model.arguments import AdagioArguments +from adagio.model.task import InputVal, PluginActionTask from adagio.monitor.api import Monitor @@ -20,6 +24,14 @@ class FakeOutputDef: name: str +@dataclass(frozen=True) +class FakeInputDef: + id: str + name: str + type: str + required: bool + + @dataclass class FakeTask: id: str @@ -31,8 +43,10 @@ class FakeTask: class FakeSignature: - def __init__(self, outputs: list[FakeOutputDef]) -> None: - self.inputs: list[object] = [] + def __init__( + self, outputs: list[FakeOutputDef], inputs: list[FakeInputDef] | None = None + ) -> None: + self.inputs: list[FakeInputDef] = inputs or [] self.parameters: list[object] = [] self.outputs = outputs @@ -45,8 +59,14 @@ def get_params(self, arguments: AdagioArguments) -> dict[str, object]: class FakePipeline: - def __init__(self, *, tasks: list[FakeTask], outputs: list[FakeOutputDef]) -> None: - self.signature = FakeSignature(outputs) + def __init__( + self, + *, + tasks: list[FakeTask], + outputs: list[FakeOutputDef], + inputs: list[FakeInputDef] | None = None, + ) -> None: + self.signature = FakeSignature(outputs, inputs) self._tasks = tasks def validate_graph(self) -> None: @@ -81,6 +101,24 @@ def finish_save_output(self) -> None: self.save_finish_count += 1 +class RecordingResolver: + def resolve(self, *, task): # noqa: ANN001 + del task + return TaskEnvironmentSpec(kind="recording", reference="recording") + + +class RecordingLauncher: + kind = "recording" + + def __init__(self) -> None: + self.request = None + + def launch(self, *, environment, request, console=None): # noqa: ANN001 + del environment, console + self.request = request + return TaskExecutionResult(outputs={}) + + class SerialRunnerOutputTests(unittest.TestCase): def test_collection_input_manifest_expands_to_paths(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: @@ -123,6 +161,96 @@ def test_collection_input_list_resolves_each_path(self) -> None: ], ) + def test_omitted_optional_input_does_not_block_task_execution(self) -> None: + pipeline = FakePipeline( + inputs=[ + FakeInputDef( + id="required-input", + name="seqs", + type="SampleData[Sequences]", + required=True, + ), + FakeInputDef( + id="optional-input", + name="tree", + type="Phylogeny[Rooted]", + required=False, + ), + ], + tasks=[ + FakeTask( + id="task-1", + inputs={ + "seqs": InputVal(kind="archive", id="required-input"), + "tree": InputVal(kind="archive", id="optional-input"), + }, + outputs={}, + ) + ], + outputs=[], + ) + arguments = AdagioArguments( + inputs={"seqs": "seqs.qza", "tree": ""}, + parameters={}, + outputs={}, + ) + seen_scope: dict[str, object] = {} + seen_missing_optional_ids: set[str] = set() + + def resolve_task(task, state, console): # noqa: ANN001 + del task, console + seen_scope.update(state.scope) + seen_missing_optional_ids.update(state.missing_optional_ids) + return False + + run_serial_pipeline( + pipeline=pipeline, + arguments=arguments, + resolve_task=resolve_task, + finish_outputs=_save_outputs, + ) + + self.assertIn("required-input", seen_scope) + self.assertNotIn("optional-input", seen_scope) + self.assertIn("optional-input", seen_missing_optional_ids) + + def test_task_environment_executor_omits_missing_optional_inputs(self) -> None: + launcher = RecordingLauncher() + executor = TaskEnvironmentExecutor( + environment_resolver=RecordingResolver(), + launchers={launcher.kind: launcher}, + ) + task = PluginActionTask.model_validate( + { + "id": "task-1", + "kind": "plugin-action", + "plugin": "feature_table", + "action": "tabulate_seqs", + "inputs": { + "data": {"kind": "archive", "id": "data-input"}, + "taxonomy": {"kind": "archive", "id": "taxonomy-input"}, + }, + "parameters": {}, + "outputs": {}, + } + ) + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + state = SerialExecutionState( + cwd=root, + work_path=root, + params={}, + scope={"data-input": "data.qza"}, + cache_config=None, + missing_optional_ids={"taxonomy-input"}, + ) + + executor._resolve_task(task, state, None) + + assert launcher.request is not None + self.assertEqual(launcher.request.archive_inputs, {"data": "data.qza"}) + self.assertEqual(launcher.request.archive_collection_inputs, {}) + def test_preserves_completed_output_when_later_task_fails(self) -> None: output_def = FakeOutputDef(id="out-1", name="result") pipeline = FakePipeline( diff --git a/uv.lock b/uv.lock index 47b7e36..e7a4557 100644 --- a/uv.lock +++ b/uv.lock @@ -4,7 +4,7 @@ requires-python = ">=3.10" [[package]] name = "adagio-cli" -version = "0.1.0a3" +version = "0.1.0a4" source = { editable = "." } dependencies = [ { name = "cyclopts" },