diff --git a/src/adagio/cli/dynamic.py b/src/adagio/cli/dynamic.py index 65437d1..39a1206 100644 --- a/src/adagio/cli/dynamic.py +++ b/src/adagio/cli/dynamic.py @@ -233,6 +233,10 @@ def _is_required_param(spec: ParamSpec) -> bool: return bool(spec.required and spec.default is None) +def _is_collection_type(type_name: str) -> bool: + return type_name.startswith("List[") or type_name.startswith("Collection[") + + def build_dynamic_run( *, input_specs: list[InputSpec], @@ -453,7 +457,7 @@ def add_input_spec(spec: InputSpec) -> None: ident=ident, opt=opt, required=False, - py_type=str, + py_type=list[str] if _is_collection_type(spec.type) else str, help_text=_format_help_text( description=spec.description, ), @@ -573,4 +577,4 @@ def run( def _is_missing(value: Any) -> bool: - return value is None or value == "" + return value is None or value == "" or value == "" or value == [] or value == {} diff --git a/src/adagio/cli/main.py b/src/adagio/cli/main.py index 8df40ef..82bdad7 100644 --- a/src/adagio/cli/main.py +++ b/src/adagio/cli/main.py @@ -256,7 +256,7 @@ def _load_arguments_data(path: Path, _console: Console | None = None) -> dict[st def _is_missing(value: Any) -> bool: - return value is None or value == "" + return value is None or value == "" or value == "" or value == [] or value == {} if __name__ == "__main__": diff --git a/src/adagio/cli/runner.py b/src/adagio/cli/runner.py index 6783355..52c8a2c 100644 --- a/src/adagio/cli/runner.py +++ b/src/adagio/cli/runner.py @@ -100,7 +100,12 @@ def run_pipeline_from_kwargs( for ident, original in input_bindings: value = kwargs.get(ident) if value is not None: - arguments.inputs[original] = str(value) + if isinstance(value, list): + arguments.inputs[original] = [str(item) for item in value] + elif isinstance(value, dict): + arguments.inputs[original] = {str(key): str(item) for key, item in value.items()} + else: + arguments.inputs[original] = str(value) for ident, original in param_bindings: value = kwargs.get(ident) @@ -177,7 +182,7 @@ def run_pipeline_from_kwargs( def _is_missing(value: Any) -> bool: """Treat placeholders and null values as missing.""" - return value is None or value == "" + return value is None or value == "" or value == "" or value == [] or value == {} def _is_missing_output(value: Any) -> bool: diff --git a/src/adagio/cli/runtime.py b/src/adagio/cli/runtime.py index 8328035..5886e24 100644 --- a/src/adagio/cli/runtime.py +++ b/src/adagio/cli/runtime.py @@ -239,7 +239,7 @@ def _apply_named_arguments( raw_inputs = runtime_arguments.get("inputs", {}) if isinstance(raw_inputs, dict): for name, value in raw_inputs.items(): - arguments.inputs[name] = _resolve_input_path( + arguments.inputs[name] = _resolve_input_value( value, storage_root=storage_root ) @@ -283,7 +283,7 @@ def _apply_legacy_arguments( named_inputs = runtime_arguments.get("inputs", {}) if isinstance(named_inputs, dict): for name, value in named_inputs.items(): - arguments.inputs[name] = _resolve_input_path( + arguments.inputs[name] = _resolve_input_value( value, storage_root=storage_root ) @@ -312,6 +312,22 @@ def _resolve_input_path(value: Any, *, storage_root: str) -> str: return str(value) +def _resolve_input_value(value: Any, *, storage_root: str) -> Any: + if isinstance(value, list): + return [_resolve_input_value(item, storage_root=storage_root) for item in value] + if isinstance(value, dict): + path = value.get("path") + if path is not None: + return _normalize_path(path, storage_root=storage_root) + return { + str(key): _resolve_input_value(item, storage_root=storage_root) + for key, item in value.items() + } + if isinstance(value, str): + return _normalize_path(value, storage_root=storage_root) + return str(value) + + def _resolve_outputs(value: Any, *, storage_root: str) -> str | dict[str, str] | None: if value is None: return None @@ -345,7 +361,7 @@ def _outputs_need_default(outputs: str | dict[str, str]) -> bool: def _is_missing(value: Any) -> bool: - return value is None or value == "" or value == "" + return value is None or value == "" or value == "" or value == [] or value == {} def _validate_required_arguments( diff --git a/src/adagio/executors/path_utils.py b/src/adagio/executors/path_utils.py index 245b18e..d17d479 100644 --- a/src/adagio/executors/path_utils.py +++ b/src/adagio/executors/path_utils.py @@ -4,6 +4,8 @@ from .container_support import is_uri +InputSource = str | list[str] | dict[str, str] + def resolve_host_path(*, source: str, cwd: Path) -> str: if is_uri(source): @@ -14,6 +16,17 @@ def resolve_host_path(*, source: str, cwd: Path) -> str: return str((cwd / path).resolve()) +def resolve_host_input(*, source: InputSource, cwd: Path) -> InputSource: + if isinstance(source, list): + return [resolve_host_path(source=item, cwd=cwd) for item in source] + if isinstance(source, dict): + return { + key: resolve_host_path(source=value, cwd=cwd) + for key, value in source.items() + } + return resolve_host_path(source=source, cwd=cwd) + + def resolve_output_destination( *, output_name: str, diff --git a/src/adagio/executors/serial_runner.py b/src/adagio/executors/serial_runner.py index ddbd3ed..d4f06ff 100644 --- a/src/adagio/executors/serial_runner.py +++ b/src/adagio/executors/serial_runner.py @@ -12,8 +12,9 @@ from adagio.monitor.tty import RichMonitor from .cache_support import ExecutionCacheConfig +from .container_support import is_uri from .common import plan_execution_order, task_label -from .path_utils import resolve_host_path +from .path_utils import InputSource, resolve_host_input, resolve_host_path CONTAINER_SUBTASK_COUNT = 1 @@ -23,7 +24,7 @@ class SerialExecutionState: cwd: Path work_path: Path params: dict[str, t.Any] - scope: dict[str, str] + scope: dict[str, InputSource] cache_config: ExecutionCacheConfig | None saved_output_ids: set[str] = field(default_factory=set) save_output_started: bool = False @@ -64,7 +65,9 @@ def run_serial_pipeline( active_monitor.start_load_input() for input_def in sig.inputs: source = arguments.inputs[input_def.name] - state.scope[input_def.id] = resolve_host_path(source=source, cwd=state.cwd) + state.scope[input_def.id] = resolve_pipeline_input( + source=source, type_name=input_def.type, cwd=state.cwd + ) active_monitor.finish_load_input() execution_plan = plan_execution_order(tasks=tasks, scope=state.scope) @@ -133,3 +136,63 @@ def resolve_monitor(*, console: Console | None, monitor: Monitor | None) -> Moni if console is not None: return RichMonitor(console=console) return LogMonitor() + + +def resolve_pipeline_input( + *, source: InputSource, type_name: str, cwd: Path +) -> InputSource: + resolved = resolve_host_input(source=source, cwd=cwd) + if not is_collection_type(type_name): + return resolved + + if isinstance(resolved, str): + return expand_collection_input_source(resolved) + if isinstance(resolved, list): + if len(resolved) == 1: + return expand_collection_input_source(resolved[0]) + return resolved + return list(resolved.values()) + + +def is_collection_type(type_name: str) -> bool: + return type_name.startswith("List[") or type_name.startswith("Collection[") + + +def expand_collection_input_source(source: str) -> list[str]: + path = Path(source) + if ( + not is_uri(source) + and path.suffix.lower() in {".tsv", ".txt"} + and path.is_file() + ): + return read_collection_manifest(path) + return [source] + + +def read_collection_manifest(path: Path) -> list[str]: + rows = [ + line.rstrip("\n").split("\t") + for line in path.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + if not rows: + return [] + + header = [cell.strip().lower() for cell in rows[0]] + path_index = header.index("path") if "path" in header else None + data_rows = rows[1:] if path_index is not None else rows + + result: list[str] = [] + for row in data_rows: + if path_index is not None: + if path_index >= len(row): + continue + raw_path = row[path_index].strip() + elif len(row) >= 2: + raw_path = row[1].strip() + else: + raw_path = row[0].strip() + + if raw_path: + result.append(resolve_host_path(source=raw_path, cwd=path.parent)) + return result diff --git a/src/adagio/executors/task_environments.py b/src/adagio/executors/task_environments.py index 2654b4e..396ed28 100644 --- a/src/adagio/executors/task_environments.py +++ b/src/adagio/executors/task_environments.py @@ -91,13 +91,24 @@ def _execute_plugin_action( metadata_inputs: dict[str, str] = {} for name, src in task.inputs.items(): if src.kind == "archive": - archive_inputs[name] = state.scope[src.id] + value = state.scope[src.id] + if isinstance(value, list): + archive_collection_inputs[name] = value + elif isinstance(value, dict): + archive_collection_inputs[name] = list(value.values()) + else: + archive_inputs[name] = value elif src.kind == "archive-collection": - archive_collection_inputs[name] = [ - state.scope[item.id] for item in src.items - ] + archive_collection_inputs[name] = _flatten_collection_items( + [state.scope[item.id] for item in src.items] + ) elif src.kind == "metadata": - metadata_inputs[name] = state.scope[src.id] + value = state.scope[src.id] + if not isinstance(value, str): + raise TypeError( + f"Metadata input {name!r} must resolve to a single path." + ) + metadata_inputs[name] = value else: raise TypeError(f"Unsupported input kind: {src.kind!r}") @@ -165,6 +176,20 @@ def _execute_plugin_action( return result.reused +def _flatten_collection_items( + values: list[str | list[str] | dict[str, str]], +) -> list[str]: + result: list[str] = [] + for value in values: + if isinstance(value, list): + result.extend(value) + elif isinstance(value, dict): + result.extend(value.values()) + else: + result.append(value) + return result + + def _save_outputs( *, sig, diff --git a/src/adagio/io.py b/src/adagio/io.py index 7e76f38..cfcf5b3 100644 --- a/src/adagio/io.py +++ b/src/adagio/io.py @@ -13,6 +13,23 @@ def load_input(*, ctx, source: str): return [input] + +@lift_parsl(lambda fut: fut) +def load_input_collection(*, ctx, sources): + from qiime2.sdk import Artifact + from qiime2.sdk import PluginManager + + PluginManager() + + if isinstance(sources, dict): + sources = list(sources.values()) + elif isinstance(sources, str): + sources = [sources] + + with ctx.cache: + return [Artifact.load(source) for source in sources] + + @lift_parsl(ProxyMetadata) def load_metadata(*, ctx, source: str): from qiime2 import Artifact, Metadata @@ -38,4 +55,4 @@ def convert_metadata(*, ctx, metadata): if isinstance(metadata, qiime2.Artifact): metadata = metadata.view(qiime2.Metadata) - return metadata \ No newline at end of file + return metadata diff --git a/src/adagio/model/arguments.py b/src/adagio/model/arguments.py index cc8f01c..4d372dc 100644 --- a/src/adagio/model/arguments.py +++ b/src/adagio/model/arguments.py @@ -3,9 +3,11 @@ from .task import AllowableValue +InputValue = str | list[str] | dict[str, str] + class AdagioArguments(BaseModel): - inputs: dict[str, str] + inputs: dict[str, InputValue] parameters: dict[str, AllowableValue] outputs: str | dict[str, str] @@ -34,6 +36,6 @@ class AdagioArgumentsFile(BaseModel): """Represent arguments loaded from a JSON file.""" version: int = 1 - inputs: dict[str, str] = Field(default_factory=dict) + inputs: dict[str, InputValue] = Field(default_factory=dict) parameters: dict[str, AllowableValue] = Field(default_factory=dict) outputs: str | dict[str, str] | None = None diff --git a/src/adagio/model/pipeline.py b/src/adagio/model/pipeline.py index 17e83bc..4432646 100644 --- a/src/adagio/model/pipeline.py +++ b/src/adagio/model/pipeline.py @@ -65,7 +65,7 @@ def get_params(self, args: AdagioArguments): return lookup def load_inputs(self, ctx, arguments, scope): - from adagio.io import load_input, load_metadata + from adagio.io import load_input, load_input_collection, load_metadata for input in self.inputs: source = arguments.inputs[input.name] @@ -74,6 +74,10 @@ def load_inputs(self, ctx, arguments, scope): scope[input.id] = load_metadata(ctx=ctx, source=source) # IIFE for the dreaded for-loop in the parent closure problem. scope[input.id]._future_.add_done_callback((lambda str: (lambda x: print("DONE:", str)))(f'load_metadata({source!r})')) + elif _is_collection_type(input.type): + print("SCHEDULED:", f'load_input_collection({source!r})') + scope[input.id] = load_input_collection(ctx=ctx, sources=source) + scope[input.id]._future_.add_done_callback((lambda str: (lambda x: print("DONE:", str)))(f'load_input_collection({source!r})')) else: print("SCHEDULED:", f'load_input({source!r})') scope[input.id] = load_input(ctx=ctx, source=source) @@ -132,3 +136,7 @@ def _is_metadata_ast(ast: TypeAST) -> bool: if isinstance(ast, (TypeASTUnion, TypeASTIntersection)): return any(_is_metadata_ast(member) for member in ast.members) return False + + +def _is_collection_type(type_name: str) -> bool: + return type_name.startswith('List[') or type_name.startswith('Collection[') diff --git a/src/adagio/model/task.py b/src/adagio/model/task.py index f932d68..a9b3b5d 100644 --- a/src/adagio/model/task.py +++ b/src/adagio/model/task.py @@ -30,7 +30,9 @@ def exec(self, ctx, params, scope): if src.kind == 'archive': kwargs[name] = scope[src.id] elif src.kind == 'archive-collection': - kwargs[name] = [scope[item.id] for item in src.items] + kwargs[name] = _flatten_collection_values( + [scope[item.id] for item in src.items] + ) elif src.kind == 'metadata': # store for second pass in params metadata[name] = scope[src.id] @@ -133,3 +135,15 @@ def input_source_ids(value: TaskInputVal) -> list[str]: if value.kind == 'archive-collection': return [item.id for item in value.items] return [value.id] + + +def _flatten_collection_values(values: list[t.Any]) -> list[t.Any]: + result: list[t.Any] = [] + for value in values: + if isinstance(value, list): + result.extend(value) + elif isinstance(value, dict): + result.extend(value.values()) + else: + result.append(value) + return result diff --git a/tests/test_output_options.py b/tests/test_output_options.py index 229ab8e..a430143 100644 --- a/tests/test_output_options.py +++ b/tests/test_output_options.py @@ -57,6 +57,27 @@ def test_dynamic_run_adds_output_dir_and_per_output_options(self) -> None: self.assertIn("Denoised feature table.", output_help) self.assertIn("Overrides --output-dir", output_help) + def test_dynamic_run_uses_variadic_options_for_collection_inputs(self) -> None: + dynamic_run = build_dynamic_run( + input_specs=[ + Input( + id="00000000-0000-0000-0000-000000000001", + name="matrices", + required=True, + type="List[DistanceMatrix]", + description="Distance matrices.", + ) + ], + param_specs=[], + output_specs=[], + run_handler=lambda *args, **kwargs: None, + ) + + annotation = dynamic_run.__signature__.parameters["input_matrices"].annotation + value_type = typing.get_args(annotation)[0] + + self.assertIn(list[str], typing.get_args(value_type)) + def test_output_dir_is_a_command_option_and_required_pipeline_options_are_first( self, ) -> None: diff --git a/tests/test_serial_runner.py b/tests/test_serial_runner.py index 8ccad51..e99f911 100644 --- a/tests/test_serial_runner.py +++ b/tests/test_serial_runner.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from pathlib import Path -from adagio.executors.serial_runner import run_serial_pipeline +from adagio.executors.serial_runner import resolve_pipeline_input, run_serial_pipeline from adagio.executors.task_environments import _save_outputs from adagio.model.arguments import AdagioArguments from adagio.monitor.api import Monitor @@ -82,6 +82,47 @@ def finish_save_output(self) -> None: class SerialRunnerOutputTests(unittest.TestCase): + def test_collection_input_manifest_expands_to_paths(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + manifest = root / "matrices.tsv" + manifest.write_text( + "key\tpath\n1\tdm-a.qza\n2\tdata/dm-b.qza\n", + encoding="utf-8", + ) + + resolved = resolve_pipeline_input( + source=str(manifest), + type_name="List[DistanceMatrix]", + cwd=root, + ) + + self.assertEqual( + resolved, + [ + str((root / "dm-a.qza").resolve()), + str((root / "data" / "dm-b.qza").resolve()), + ], + ) + + def test_collection_input_list_resolves_each_path(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + + resolved = resolve_pipeline_input( + source=["dm-a.qza", "nested/dm-b.qza"], + type_name="List[DistanceMatrix]", + cwd=root, + ) + + self.assertEqual( + resolved, + [ + str((root / "dm-a.qza").resolve()), + str((root / "nested" / "dm-b.qza").resolve()), + ], + ) + def test_preserves_completed_output_when_later_task_fails(self) -> None: output_def = FakeOutputDef(id="out-1", name="result") pipeline = FakePipeline(