From b949dbde4c38c62c2415414d949f85c6803e4605 Mon Sep 17 00:00:00 2001 From: John Chase Date: Mon, 27 Apr 2026 20:35:43 -0700 Subject: [PATCH 1/3] Implements collections handling in acl --- src/adagio/cli/dynamic.py | 26 ++++++-- src/adagio/cli/main.py | 2 +- src/adagio/cli/runner.py | 17 ++++- src/adagio/cli/runtime.py | 28 ++++++++- src/adagio/executors/path_utils.py | 13 ++++ src/adagio/executors/serial_runner.py | 69 +++++++++++++++++++- src/adagio/executors/task_environments.py | 35 +++++++++-- src/adagio/io.py | 25 ++++++-- src/adagio/model/arguments.py | 25 ++++---- src/adagio/model/pipeline.py | 72 +++++++++++++-------- src/adagio/model/task.py | 76 ++++++++++++++--------- tests/test_output_options.py | 38 ++++++++++-- tests/test_serial_runner.py | 43 ++++++++++++- 13 files changed, 373 insertions(+), 96 deletions(-) diff --git a/src/adagio/cli/dynamic.py b/src/adagio/cli/dynamic.py index 65437d1..75ac059 100644 --- a/src/adagio/cli/dynamic.py +++ b/src/adagio/cli/dynamic.py @@ -79,7 +79,9 @@ def _pipeline_type_label(type_hint: Any) -> str: return "TEXT" -def _display_type_label(*, spec_type: str | None, type_hint: Any, is_input: bool) -> str: +def _display_type_label( + *, spec_type: str | None, type_hint: Any, is_input: bool +) -> str: if is_input: return path_type_label(spec_type) @@ -233,6 +235,10 @@ def _is_required_param(spec: ParamSpec) -> bool: return bool(spec.required and spec.default is None) +def _is_collection_type(type_name: str) -> bool: + return type_name.startswith("List[") or type_name.startswith("Collection[") + + def build_dynamic_run( *, input_specs: list[InputSpec], @@ -424,7 +430,9 @@ def add_dynamic_option( required_input_specs = [spec for spec in input_specs if spec.required] optional_input_specs = [spec for spec in input_specs if not spec.required] required_param_specs = [spec for spec in param_specs if _is_required_param(spec)] - optional_param_specs = [spec for spec in param_specs if not _is_required_param(spec)] + optional_param_specs = [ + spec for spec in param_specs if not _is_required_param(spec) + ] def add_input_spec(spec: InputSpec) -> None: original = spec.name @@ -453,7 +461,7 @@ def add_input_spec(spec: InputSpec) -> None: ident=ident, opt=opt, required=False, - py_type=str, + py_type=list[str] if _is_collection_type(spec.type) else str, help_text=_format_help_text( description=spec.description, ), @@ -476,7 +484,9 @@ def add_param_spec(spec: ParamSpec) -> None: argument_value = argument_params.get(original) has_argument_default = not _is_missing(argument_value) display_default = ( - default if default is not None else (argument_value if has_argument_default else None) + default + if default is not None + else (argument_value if has_argument_default else None) ) display_required = is_required and display_default is None param_default = None @@ -573,4 +583,10 @@ def run( def _is_missing(value: Any) -> bool: - return value is None or value == "" + return ( + value is None + or value == "" + or value == "" + or value == [] + or value == {} + ) diff --git a/src/adagio/cli/main.py b/src/adagio/cli/main.py index 8df40ef..82bdad7 100644 --- a/src/adagio/cli/main.py +++ b/src/adagio/cli/main.py @@ -256,7 +256,7 @@ def _load_arguments_data(path: Path, _console: Console | None = None) -> dict[st def _is_missing(value: Any) -> bool: - return value is None or value == "" + return value is None or value == "" or value == "" or value == [] or value == {} if __name__ == "__main__": diff --git a/src/adagio/cli/runner.py b/src/adagio/cli/runner.py index 6783355..deb94a9 100644 --- a/src/adagio/cli/runner.py +++ b/src/adagio/cli/runner.py @@ -100,7 +100,14 @@ def run_pipeline_from_kwargs( for ident, original in input_bindings: value = kwargs.get(ident) if value is not None: - arguments.inputs[original] = str(value) + if isinstance(value, list): + arguments.inputs[original] = [str(item) for item in value] + elif isinstance(value, dict): + arguments.inputs[original] = { + str(key): str(item) for key, item in value.items() + } + else: + arguments.inputs[original] = str(value) for ident, original in param_bindings: value = kwargs.get(ident) @@ -177,7 +184,13 @@ def run_pipeline_from_kwargs( def _is_missing(value: Any) -> bool: """Treat placeholders and null values as missing.""" - return value is None or value == "" + return ( + value is None + or value == "" + or value == "" + or value == [] + or value == {} + ) def _is_missing_output(value: Any) -> bool: diff --git a/src/adagio/cli/runtime.py b/src/adagio/cli/runtime.py index 8328035..59313e4 100644 --- a/src/adagio/cli/runtime.py +++ b/src/adagio/cli/runtime.py @@ -239,7 +239,7 @@ def _apply_named_arguments( raw_inputs = runtime_arguments.get("inputs", {}) if isinstance(raw_inputs, dict): for name, value in raw_inputs.items(): - arguments.inputs[name] = _resolve_input_path( + arguments.inputs[name] = _resolve_input_value( value, storage_root=storage_root ) @@ -283,7 +283,7 @@ def _apply_legacy_arguments( named_inputs = runtime_arguments.get("inputs", {}) if isinstance(named_inputs, dict): for name, value in named_inputs.items(): - arguments.inputs[name] = _resolve_input_path( + arguments.inputs[name] = _resolve_input_value( value, storage_root=storage_root ) @@ -312,6 +312,22 @@ def _resolve_input_path(value: Any, *, storage_root: str) -> str: return str(value) +def _resolve_input_value(value: Any, *, storage_root: str) -> Any: + if isinstance(value, list): + return [_resolve_input_value(item, storage_root=storage_root) for item in value] + if isinstance(value, dict): + path = value.get("path") + if path is not None: + return _normalize_path(path, storage_root=storage_root) + return { + str(key): _resolve_input_value(item, storage_root=storage_root) + for key, item in value.items() + } + if isinstance(value, str): + return _normalize_path(value, storage_root=storage_root) + return str(value) + + def _resolve_outputs(value: Any, *, storage_root: str) -> str | dict[str, str] | None: if value is None: return None @@ -345,7 +361,13 @@ def _outputs_need_default(outputs: str | dict[str, str]) -> bool: def _is_missing(value: Any) -> bool: - return value is None or value == "" or value == "" + return ( + value is None + or value == "" + or value == "" + or value == [] + or value == {} + ) def _validate_required_arguments( diff --git a/src/adagio/executors/path_utils.py b/src/adagio/executors/path_utils.py index 245b18e..d17d479 100644 --- a/src/adagio/executors/path_utils.py +++ b/src/adagio/executors/path_utils.py @@ -4,6 +4,8 @@ from .container_support import is_uri +InputSource = str | list[str] | dict[str, str] + def resolve_host_path(*, source: str, cwd: Path) -> str: if is_uri(source): @@ -14,6 +16,17 @@ def resolve_host_path(*, source: str, cwd: Path) -> str: return str((cwd / path).resolve()) +def resolve_host_input(*, source: InputSource, cwd: Path) -> InputSource: + if isinstance(source, list): + return [resolve_host_path(source=item, cwd=cwd) for item in source] + if isinstance(source, dict): + return { + key: resolve_host_path(source=value, cwd=cwd) + for key, value in source.items() + } + return resolve_host_path(source=source, cwd=cwd) + + def resolve_output_destination( *, output_name: str, diff --git a/src/adagio/executors/serial_runner.py b/src/adagio/executors/serial_runner.py index ddbd3ed..d4f06ff 100644 --- a/src/adagio/executors/serial_runner.py +++ b/src/adagio/executors/serial_runner.py @@ -12,8 +12,9 @@ from adagio.monitor.tty import RichMonitor from .cache_support import ExecutionCacheConfig +from .container_support import is_uri from .common import plan_execution_order, task_label -from .path_utils import resolve_host_path +from .path_utils import InputSource, resolve_host_input, resolve_host_path CONTAINER_SUBTASK_COUNT = 1 @@ -23,7 +24,7 @@ class SerialExecutionState: cwd: Path work_path: Path params: dict[str, t.Any] - scope: dict[str, str] + scope: dict[str, InputSource] cache_config: ExecutionCacheConfig | None saved_output_ids: set[str] = field(default_factory=set) save_output_started: bool = False @@ -64,7 +65,9 @@ def run_serial_pipeline( active_monitor.start_load_input() for input_def in sig.inputs: source = arguments.inputs[input_def.name] - state.scope[input_def.id] = resolve_host_path(source=source, cwd=state.cwd) + state.scope[input_def.id] = resolve_pipeline_input( + source=source, type_name=input_def.type, cwd=state.cwd + ) active_monitor.finish_load_input() execution_plan = plan_execution_order(tasks=tasks, scope=state.scope) @@ -133,3 +136,63 @@ def resolve_monitor(*, console: Console | None, monitor: Monitor | None) -> Moni if console is not None: return RichMonitor(console=console) return LogMonitor() + + +def resolve_pipeline_input( + *, source: InputSource, type_name: str, cwd: Path +) -> InputSource: + resolved = resolve_host_input(source=source, cwd=cwd) + if not is_collection_type(type_name): + return resolved + + if isinstance(resolved, str): + return expand_collection_input_source(resolved) + if isinstance(resolved, list): + if len(resolved) == 1: + return expand_collection_input_source(resolved[0]) + return resolved + return list(resolved.values()) + + +def is_collection_type(type_name: str) -> bool: + return type_name.startswith("List[") or type_name.startswith("Collection[") + + +def expand_collection_input_source(source: str) -> list[str]: + path = Path(source) + if ( + not is_uri(source) + and path.suffix.lower() in {".tsv", ".txt"} + and path.is_file() + ): + return read_collection_manifest(path) + return [source] + + +def read_collection_manifest(path: Path) -> list[str]: + rows = [ + line.rstrip("\n").split("\t") + for line in path.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + if not rows: + return [] + + header = [cell.strip().lower() for cell in rows[0]] + path_index = header.index("path") if "path" in header else None + data_rows = rows[1:] if path_index is not None else rows + + result: list[str] = [] + for row in data_rows: + if path_index is not None: + if path_index >= len(row): + continue + raw_path = row[path_index].strip() + elif len(row) >= 2: + raw_path = row[1].strip() + else: + raw_path = row[0].strip() + + if raw_path: + result.append(resolve_host_path(source=raw_path, cwd=path.parent)) + return result diff --git a/src/adagio/executors/task_environments.py b/src/adagio/executors/task_environments.py index 2654b4e..396ed28 100644 --- a/src/adagio/executors/task_environments.py +++ b/src/adagio/executors/task_environments.py @@ -91,13 +91,24 @@ def _execute_plugin_action( metadata_inputs: dict[str, str] = {} for name, src in task.inputs.items(): if src.kind == "archive": - archive_inputs[name] = state.scope[src.id] + value = state.scope[src.id] + if isinstance(value, list): + archive_collection_inputs[name] = value + elif isinstance(value, dict): + archive_collection_inputs[name] = list(value.values()) + else: + archive_inputs[name] = value elif src.kind == "archive-collection": - archive_collection_inputs[name] = [ - state.scope[item.id] for item in src.items - ] + archive_collection_inputs[name] = _flatten_collection_items( + [state.scope[item.id] for item in src.items] + ) elif src.kind == "metadata": - metadata_inputs[name] = state.scope[src.id] + value = state.scope[src.id] + if not isinstance(value, str): + raise TypeError( + f"Metadata input {name!r} must resolve to a single path." + ) + metadata_inputs[name] = value else: raise TypeError(f"Unsupported input kind: {src.kind!r}") @@ -165,6 +176,20 @@ def _execute_plugin_action( return result.reused +def _flatten_collection_items( + values: list[str | list[str] | dict[str, str]], +) -> list[str]: + result: list[str] = [] + for value in values: + if isinstance(value, list): + result.extend(value) + elif isinstance(value, dict): + result.extend(value.values()) + else: + result.append(value) + return result + + def _save_outputs( *, sig, diff --git a/src/adagio/io.py b/src/adagio/io.py index 7e76f38..4d8cdb6 100644 --- a/src/adagio/io.py +++ b/src/adagio/io.py @@ -1,11 +1,11 @@ - from adagio.execution.proxy import ProxyMetadata, lift_parsl, IndexedProxyArtifact @lift_parsl(lambda fut: IndexedProxyArtifact(fut, 0)) def load_input(*, ctx, source: str): - from qiime2.sdk import Results, Artifact + from qiime2.sdk import Artifact from qiime2.sdk import PluginManager + PluginManager() with ctx.cache: @@ -13,10 +13,28 @@ def load_input(*, ctx, source: str): return [input] + +@lift_parsl(lambda fut: fut) +def load_input_collection(*, ctx, sources): + from qiime2.sdk import Artifact + from qiime2.sdk import PluginManager + + PluginManager() + + if isinstance(sources, dict): + sources = list(sources.values()) + elif isinstance(sources, str): + sources = [sources] + + with ctx.cache: + return [Artifact.load(source) for source in sources] + + @lift_parsl(ProxyMetadata) def load_metadata(*, ctx, source: str): from qiime2 import Artifact, Metadata import zipfile + if zipfile.is_zipfile(source): metadata = Artifact.load(source).view(Metadata) else: @@ -25,7 +43,6 @@ def load_metadata(*, ctx, source: str): return metadata - @lift_parsl(lambda fut: fut) def save_output(*, ctx, output, destination): output.save(destination) @@ -38,4 +55,4 @@ def convert_metadata(*, ctx, metadata): if isinstance(metadata, qiime2.Artifact): metadata = metadata.view(qiime2.Metadata) - return metadata \ No newline at end of file + return metadata diff --git a/src/adagio/model/arguments.py b/src/adagio/model/arguments.py index cc8f01c..b991815 100644 --- a/src/adagio/model/arguments.py +++ b/src/adagio/model/arguments.py @@ -1,31 +1,34 @@ -import typing as t from pydantic import BaseModel, Field from .task import AllowableValue +InputValue = str | list[str] | dict[str, str] + class AdagioArguments(BaseModel): - inputs: dict[str, str] + inputs: dict[str, InputValue] parameters: dict[str, AllowableValue] outputs: str | dict[str, str] def __repr__(self): """Format arguments for display.""" - return '\n'.join([ - *self._format_repr_sect(self.inputs, 'inputs'), - *self._format_repr_sect(self.parameters, 'parameters'), - *self._format_repr_sect(self.outputs, 'outputs'), - ]) + return "\n".join( + [ + *self._format_repr_sect(self.inputs, "inputs"), + *self._format_repr_sect(self.parameters, "parameters"), + *self._format_repr_sect(self.outputs, "outputs"), + ] + ) def _format_repr_sect(self, section, name): """Format a single argument section.""" lines = [] if not section: - lines.append(f'{name}: {{}}') + lines.append(f"{name}: {{}}") else: - lines.append(f'{name}:') + lines.append(f"{name}:") for name, value in section.items(): - lines.append(f' {name}: {value!r}') + lines.append(f" {name}: {value!r}") return lines @@ -34,6 +37,6 @@ class AdagioArgumentsFile(BaseModel): """Represent arguments loaded from a JSON file.""" version: int = 1 - inputs: dict[str, str] = Field(default_factory=dict) + inputs: dict[str, InputValue] = Field(default_factory=dict) parameters: dict[str, AllowableValue] = Field(default_factory=dict) outputs: str | dict[str, str] | None = None diff --git a/src/adagio/model/pipeline.py b/src/adagio/model/pipeline.py index 17e83bc..7176616 100644 --- a/src/adagio/model/pipeline.py +++ b/src/adagio/model/pipeline.py @@ -1,8 +1,7 @@ import typing as t import os -import json -from pydantic import BaseModel, RootModel, model_validator, Field +from pydantic import BaseModel, RootModel, model_validator from .arguments import AdagioArguments @@ -11,53 +10,51 @@ class AdagioPipeline(BaseModel): - type: t.Literal['pipeline'] + type: t.Literal["pipeline"] # meta: 'AdagioPipelineMetadata' - signature: 'AdagioSignature' - graph: list['AdagioTask'] + signature: "AdagioSignature" + graph: list["AdagioTask"] def validate_graph(self): pass - def iter_tasks(self) -> t.Generator['AdagioTask', None, None]: + def iter_tasks(self) -> t.Generator["AdagioTask", None, None]: yield from self.graph - class AdagioPipelineMetadata(RootModel): root: dict[str, t.Any] - @model_validator(mode='before') + @model_validator(mode="before") def check_version(cls, data): - if 'version' not in data: + if "version" not in data: raise AssertionError('Missing "version" field.') class AdagioSignature(BaseModel): - inputs: 'list[_InputDef]' - parameters: 'list[_ParameterDef]' - outputs: 'list[_OutputDef]' + inputs: "list[_InputDef]" + parameters: "list[_ParameterDef]" + outputs: "list[_OutputDef]" def to_default_arguments(self): inputs = {} for input in self.inputs: - inputs[input.name] = '' + inputs[input.name] = "" params = {} for param in self.parameters: if param.required: - params[param.name] = '' + params[param.name] = "" else: params[param.name] = param.default outputs = {} for output in self.outputs: - outputs[output.name] = '' + outputs[output.name] = "" return AdagioArguments(inputs=inputs, parameters=params, outputs=outputs) def validate_arguments(self, args: AdagioArguments): return - def get_params(self, args: AdagioArguments): lookup = {} for param in self.parameters: @@ -65,20 +62,36 @@ def get_params(self, args: AdagioArguments): return lookup def load_inputs(self, ctx, arguments, scope): - from adagio.io import load_input, load_metadata + from adagio.io import load_input, load_input_collection, load_metadata for input in self.inputs: source = arguments.inputs[input.name] if _is_metadata_ast(input.ast): - print("SCHEDULED:", f'load_metadata({source!r})') + print("SCHEDULED:", f"load_metadata({source!r})") scope[input.id] = load_metadata(ctx=ctx, source=source) # IIFE for the dreaded for-loop in the parent closure problem. - scope[input.id]._future_.add_done_callback((lambda str: (lambda x: print("DONE:", str)))(f'load_metadata({source!r})')) + scope[input.id]._future_.add_done_callback( + (lambda str: lambda x: print("DONE:", str))( + f"load_metadata({source!r})" + ) + ) + elif _is_collection_type(input.type): + print("SCHEDULED:", f"load_input_collection({source!r})") + scope[input.id] = load_input_collection(ctx=ctx, sources=source) + scope[input.id]._future_.add_done_callback( + (lambda str: lambda x: print("DONE:", str))( + f"load_input_collection({source!r})" + ) + ) else: - print("SCHEDULED:", f'load_input({source!r})') + print("SCHEDULED:", f"load_input({source!r})") scope[input.id] = load_input(ctx=ctx, source=source) # IIFE for the dreaded for-loop in the parent closure problem. - scope[input.id]._future_.add_done_callback((lambda str: (lambda x: print("DONE:", str)))(f'load_input({source!r})')) + scope[input.id]._future_.add_done_callback( + (lambda str: lambda x: print("DONE:", str))( + f"load_input({source!r})" + ) + ) def save_outputs(self, ctx, arguments: AdagioArguments, scope): from adagio.io import save_output @@ -90,11 +103,15 @@ def save_outputs(self, ctx, arguments: AdagioArguments, scope): elif type(arguments.outputs) is dict: dest = arguments.outputs[output.name] else: - raise NotImplementedError('impossible') - print("SCHEDULED:", f'{output.name}.save({dest!r})') + raise NotImplementedError("impossible") + print("SCHEDULED:", f"{output.name}.save({dest!r})") future = save_output(ctx=ctx, output=scope[output.id], destination=dest) # IIFE for the dreaded for-loop in the parent closure problem. - future.add_done_callback((lambda str: (lambda x: print("DONE:", str)))(f'{output.name}.save({dest!r})')) + future.add_done_callback( + (lambda str: lambda x: print("DONE:", str))( + f"{output.name}.save({dest!r})" + ) + ) futures.append(future) for future in futures: @@ -104,7 +121,6 @@ def save_outputs(self, ctx, arguments: AdagioArguments, scope): pass - class _Def(BaseModel): id: str name: str @@ -119,7 +135,7 @@ class _InputDef(_Def): class _ParameterDef(_Def): required: bool - default: 'AllowableValue | None' = None + default: "AllowableValue | None" = None class _OutputDef(_Def): @@ -132,3 +148,7 @@ def _is_metadata_ast(ast: TypeAST) -> bool: if isinstance(ast, (TypeASTUnion, TypeASTIntersection)): return any(_is_metadata_ast(member) for member in ast.members) return False + + +def _is_collection_type(type_name: str) -> bool: + return type_name.startswith("List[") or type_name.startswith("Collection[") diff --git a/src/adagio/model/task.py b/src/adagio/model/task.py index f932d68..fb8d001 100644 --- a/src/adagio/model/task.py +++ b/src/adagio/model/task.py @@ -5,9 +5,9 @@ class _BaseTask(BaseModel): id: str kind: str - inputs: dict[str, 'TaskInputVal'] - parameters: dict[str, 'LiteralVal | MetadataVal | PromotedVal'] - outputs: dict[str, 'OutputVal'] + inputs: dict[str, "TaskInputVal"] + parameters: dict[str, "LiteralVal | MetadataVal | PromotedVal"] + outputs: dict[str, "OutputVal"] def exec(self, ctx, params, scope): raise NotImplementedError @@ -15,7 +15,7 @@ def exec(self, ctx, params, scope): class PluginActionTask(_BaseTask): id: str - kind: t.Literal['plugin-action'] + kind: t.Literal["plugin-action"] name: str | None = None plugin: str action: str @@ -27,37 +27,39 @@ def exec(self, ctx, params, scope): kwargs = {} metadata = {} for name, src in self.inputs.items(): - if src.kind == 'archive': + if src.kind == "archive": kwargs[name] = scope[src.id] - elif src.kind == 'archive-collection': - kwargs[name] = [scope[item.id] for item in src.items] - elif src.kind == 'metadata': + elif src.kind == "archive-collection": + kwargs[name] = _flatten_collection_values( + [scope[item.id] for item in src.items] + ) + elif src.kind == "metadata": # store for second pass in params metadata[name] = scope[src.id] else: - raise NotImplementedError('impossible') + raise NotImplementedError("impossible") for name, param in self.parameters.items(): - if param.kind == 'metadata': - if param.column.kind == 'literal': + if param.kind == "metadata": + if param.column.kind == "literal": col = param.value - elif param.column.kind == 'promoted': + elif param.column.kind == "promoted": col = params[param.column.id] else: - raise NotImplementedError('impossible') + raise NotImplementedError("impossible") source = metadata.pop(name) md = convert_metadata(ctx=ctx, metadata=source) kwargs[name] = md.get_column(col) - elif param.kind == 'literal': + elif param.kind == "literal": kwargs[name] = param.value - elif param.kind == 'promoted': + elif param.kind == "promoted": kwargs[name] = params[param.id] else: - raise NotImplementedError('impossible') + raise NotImplementedError("impossible") # any remaining metadata is used directly for name, value in metadata.items(): @@ -69,8 +71,8 @@ def exec(self, ctx, params, scope): class RootInputTask(_BaseTask): - kind: t.Literal['built-in'] - name: t.Literal['root-input'] + kind: t.Literal["built-in"] + name: t.Literal["root-input"] def exec(self, ctx, params, scope): for name, src in self.inputs.items(): @@ -79,7 +81,7 @@ def exec(self, ctx, params, scope): class InputVal(BaseModel): - kind: t.Literal['archive', 'metadata'] + kind: t.Literal["archive", "metadata"] id: str @@ -89,24 +91,24 @@ class ArchiveCollectionItemVal(BaseModel): class ArchiveCollectionInputVal(BaseModel): - kind: t.Literal['archive-collection'] - style: t.Literal['list'] + kind: t.Literal["archive-collection"] + style: t.Literal["list"] items: list[ArchiveCollectionItemVal] class OutputVal(BaseModel): - kind: t.Literal['archive'] + kind: t.Literal["archive"] id: str class PromotedVal(BaseModel): - kind: t.Literal['promoted'] + kind: t.Literal["promoted"] id: str class LiteralVal(BaseModel): - kind: t.Literal['literal'] - value: 'AllowableValue' + kind: t.Literal["literal"] + value: "AllowableValue" class LiteralStrVal(LiteralVal): @@ -114,7 +116,7 @@ class LiteralStrVal(LiteralVal): class MetadataVal(BaseModel): - kind: t.Literal['metadata'] + kind: t.Literal["metadata"] column: PromotedVal | LiteralStrVal @@ -122,14 +124,26 @@ class MetadataVal(BaseModel): Collection = list[Primitive] | dict[str, Primitive] AllowableValue = Primitive | Collection TaskInputVal = t.Annotated[ - t.Union[InputVal, ArchiveCollectionInputVal], - Field(discriminator='kind') + t.Union[InputVal, ArchiveCollectionInputVal], Field(discriminator="kind") +] +AdagioTask = t.Annotated[ + t.Union[PluginActionTask, RootInputTask], Field(discriminator="kind") ] -AdagioTask = t.Annotated[t.Union[PluginActionTask, RootInputTask], - Field(discriminator='kind')] def input_source_ids(value: TaskInputVal) -> list[str]: - if value.kind == 'archive-collection': + if value.kind == "archive-collection": return [item.id for item in value.items] return [value.id] + + +def _flatten_collection_values(values: list[t.Any]) -> list[t.Any]: + result: list[t.Any] = [] + for value in values: + if isinstance(value, list): + result.extend(value) + elif isinstance(value, dict): + result.extend(value.values()) + else: + result.append(value) + return result diff --git a/tests/test_output_options.py b/tests/test_output_options.py index 229ab8e..6bdfa40 100644 --- a/tests/test_output_options.py +++ b/tests/test_output_options.py @@ -48,8 +48,12 @@ def test_dynamic_run_adds_output_dir_and_per_output_options(self) -> None: self.assertIn("output_dir", dynamic_run.__signature__.parameters) self.assertIn("output_table", dynamic_run.__signature__.parameters) - output_dir_annotation = dynamic_run.__signature__.parameters["output_dir"].annotation - output_annotation = dynamic_run.__signature__.parameters["output_table"].annotation + output_dir_annotation = dynamic_run.__signature__.parameters[ + "output_dir" + ].annotation + output_annotation = dynamic_run.__signature__.parameters[ + "output_table" + ].annotation output_dir_help = typing.get_args(output_dir_annotation)[1].help output_help = typing.get_args(output_annotation)[1].help @@ -57,6 +61,27 @@ def test_dynamic_run_adds_output_dir_and_per_output_options(self) -> None: self.assertIn("Denoised feature table.", output_help) self.assertIn("Overrides --output-dir", output_help) + def test_dynamic_run_uses_variadic_options_for_collection_inputs(self) -> None: + dynamic_run = build_dynamic_run( + input_specs=[ + Input( + id="00000000-0000-0000-0000-000000000001", + name="matrices", + required=True, + type="List[DistanceMatrix]", + description="Distance matrices.", + ) + ], + param_specs=[], + output_specs=[], + run_handler=lambda *args, **kwargs: None, + ) + + annotation = dynamic_run.__signature__.parameters["input_matrices"].annotation + value_type = typing.get_args(annotation)[0] + + self.assertIn(list[str], typing.get_args(value_type)) + def test_output_dir_is_a_command_option_and_required_pipeline_options_are_first( self, ) -> None: @@ -106,7 +131,9 @@ def test_output_dir_is_a_command_option_and_required_pipeline_options_are_first( run_handler=lambda *args, **kwargs: None, ) - output_dir_annotation = dynamic_run.__signature__.parameters["output_dir"].annotation + output_dir_annotation = dynamic_run.__signature__.parameters[ + "output_dir" + ].annotation output_dir_group = typing.get_args(output_dir_annotation)[1].group self.assertEqual(output_dir_group[0]._name, "Command Options") @@ -171,7 +198,10 @@ def test_outputs_are_only_visible_in_all_mode(self) -> None: def test_output_dir_override_applies_to_all_outputs(self) -> None: resolved = _apply_output_overrides( - outputs={"table": "/tmp/from-file/table.qza", "stats": "/tmp/from-file/stats.qza"}, + outputs={ + "table": "/tmp/from-file/table.qza", + "stats": "/tmp/from-file/stats.qza", + }, output_names=["table", "stats"], output_dir="/tmp/all-outputs", output_overrides={"stats": "/tmp/custom/stats.qza"}, diff --git a/tests/test_serial_runner.py b/tests/test_serial_runner.py index 8ccad51..e99f911 100644 --- a/tests/test_serial_runner.py +++ b/tests/test_serial_runner.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from pathlib import Path -from adagio.executors.serial_runner import run_serial_pipeline +from adagio.executors.serial_runner import resolve_pipeline_input, run_serial_pipeline from adagio.executors.task_environments import _save_outputs from adagio.model.arguments import AdagioArguments from adagio.monitor.api import Monitor @@ -82,6 +82,47 @@ def finish_save_output(self) -> None: class SerialRunnerOutputTests(unittest.TestCase): + def test_collection_input_manifest_expands_to_paths(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + manifest = root / "matrices.tsv" + manifest.write_text( + "key\tpath\n1\tdm-a.qza\n2\tdata/dm-b.qza\n", + encoding="utf-8", + ) + + resolved = resolve_pipeline_input( + source=str(manifest), + type_name="List[DistanceMatrix]", + cwd=root, + ) + + self.assertEqual( + resolved, + [ + str((root / "dm-a.qza").resolve()), + str((root / "data" / "dm-b.qza").resolve()), + ], + ) + + def test_collection_input_list_resolves_each_path(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + + resolved = resolve_pipeline_input( + source=["dm-a.qza", "nested/dm-b.qza"], + type_name="List[DistanceMatrix]", + cwd=root, + ) + + self.assertEqual( + resolved, + [ + str((root / "dm-a.qza").resolve()), + str((root / "nested" / "dm-b.qza").resolve()), + ], + ) + def test_preserves_completed_output_when_later_task_fails(self) -> None: output_def = FakeOutputDef(id="out-1", name="result") pipeline = FakePipeline( From 638dfe8fba5e886753efc16828a4d8c149a3766e Mon Sep 17 00:00:00 2001 From: John Chase Date: Mon, 27 Apr 2026 20:41:56 -0700 Subject: [PATCH 2/3] Reverts formatting --- src/adagio/cli/dynamic.py | 20 ++-------- src/adagio/cli/runner.py | 12 +----- src/adagio/cli/runtime.py | 8 +--- src/adagio/io.py | 6 +-- src/adagio/model/arguments.py | 20 +++++----- src/adagio/model/pipeline.py | 70 +++++++++++++++-------------------- src/adagio/model/task.py | 60 +++++++++++++++--------------- tests/test_output_options.py | 17 ++------- 8 files changed, 83 insertions(+), 130 deletions(-) diff --git a/src/adagio/cli/dynamic.py b/src/adagio/cli/dynamic.py index 75ac059..39a1206 100644 --- a/src/adagio/cli/dynamic.py +++ b/src/adagio/cli/dynamic.py @@ -79,9 +79,7 @@ def _pipeline_type_label(type_hint: Any) -> str: return "TEXT" -def _display_type_label( - *, spec_type: str | None, type_hint: Any, is_input: bool -) -> str: +def _display_type_label(*, spec_type: str | None, type_hint: Any, is_input: bool) -> str: if is_input: return path_type_label(spec_type) @@ -430,9 +428,7 @@ def add_dynamic_option( required_input_specs = [spec for spec in input_specs if spec.required] optional_input_specs = [spec for spec in input_specs if not spec.required] required_param_specs = [spec for spec in param_specs if _is_required_param(spec)] - optional_param_specs = [ - spec for spec in param_specs if not _is_required_param(spec) - ] + optional_param_specs = [spec for spec in param_specs if not _is_required_param(spec)] def add_input_spec(spec: InputSpec) -> None: original = spec.name @@ -484,9 +480,7 @@ def add_param_spec(spec: ParamSpec) -> None: argument_value = argument_params.get(original) has_argument_default = not _is_missing(argument_value) display_default = ( - default - if default is not None - else (argument_value if has_argument_default else None) + default if default is not None else (argument_value if has_argument_default else None) ) display_required = is_required and display_default is None param_default = None @@ -583,10 +577,4 @@ def run( def _is_missing(value: Any) -> bool: - return ( - value is None - or value == "" - or value == "" - or value == [] - or value == {} - ) + return value is None or value == "" or value == "" or value == [] or value == {} diff --git a/src/adagio/cli/runner.py b/src/adagio/cli/runner.py index deb94a9..52c8a2c 100644 --- a/src/adagio/cli/runner.py +++ b/src/adagio/cli/runner.py @@ -103,9 +103,7 @@ def run_pipeline_from_kwargs( if isinstance(value, list): arguments.inputs[original] = [str(item) for item in value] elif isinstance(value, dict): - arguments.inputs[original] = { - str(key): str(item) for key, item in value.items() - } + arguments.inputs[original] = {str(key): str(item) for key, item in value.items()} else: arguments.inputs[original] = str(value) @@ -184,13 +182,7 @@ def run_pipeline_from_kwargs( def _is_missing(value: Any) -> bool: """Treat placeholders and null values as missing.""" - return ( - value is None - or value == "" - or value == "" - or value == [] - or value == {} - ) + return value is None or value == "" or value == "" or value == [] or value == {} def _is_missing_output(value: Any) -> bool: diff --git a/src/adagio/cli/runtime.py b/src/adagio/cli/runtime.py index 59313e4..5886e24 100644 --- a/src/adagio/cli/runtime.py +++ b/src/adagio/cli/runtime.py @@ -361,13 +361,7 @@ def _outputs_need_default(outputs: str | dict[str, str]) -> bool: def _is_missing(value: Any) -> bool: - return ( - value is None - or value == "" - or value == "" - or value == [] - or value == {} - ) + return value is None or value == "" or value == "" or value == [] or value == {} def _validate_required_arguments( diff --git a/src/adagio/io.py b/src/adagio/io.py index 4d8cdb6..cfcf5b3 100644 --- a/src/adagio/io.py +++ b/src/adagio/io.py @@ -1,11 +1,11 @@ + from adagio.execution.proxy import ProxyMetadata, lift_parsl, IndexedProxyArtifact @lift_parsl(lambda fut: IndexedProxyArtifact(fut, 0)) def load_input(*, ctx, source: str): - from qiime2.sdk import Artifact + from qiime2.sdk import Results, Artifact from qiime2.sdk import PluginManager - PluginManager() with ctx.cache: @@ -34,7 +34,6 @@ def load_input_collection(*, ctx, sources): def load_metadata(*, ctx, source: str): from qiime2 import Artifact, Metadata import zipfile - if zipfile.is_zipfile(source): metadata = Artifact.load(source).view(Metadata) else: @@ -43,6 +42,7 @@ def load_metadata(*, ctx, source: str): return metadata + @lift_parsl(lambda fut: fut) def save_output(*, ctx, output, destination): output.save(destination) diff --git a/src/adagio/model/arguments.py b/src/adagio/model/arguments.py index b991815..a4d5084 100644 --- a/src/adagio/model/arguments.py +++ b/src/adagio/model/arguments.py @@ -1,3 +1,5 @@ +import typing as t + from pydantic import BaseModel, Field from .task import AllowableValue @@ -12,23 +14,21 @@ class AdagioArguments(BaseModel): def __repr__(self): """Format arguments for display.""" - return "\n".join( - [ - *self._format_repr_sect(self.inputs, "inputs"), - *self._format_repr_sect(self.parameters, "parameters"), - *self._format_repr_sect(self.outputs, "outputs"), - ] - ) + return '\n'.join([ + *self._format_repr_sect(self.inputs, 'inputs'), + *self._format_repr_sect(self.parameters, 'parameters'), + *self._format_repr_sect(self.outputs, 'outputs'), + ]) def _format_repr_sect(self, section, name): """Format a single argument section.""" lines = [] if not section: - lines.append(f"{name}: {{}}") + lines.append(f'{name}: {{}}') else: - lines.append(f"{name}:") + lines.append(f'{name}:') for name, value in section.items(): - lines.append(f" {name}: {value!r}") + lines.append(f' {name}: {value!r}') return lines diff --git a/src/adagio/model/pipeline.py b/src/adagio/model/pipeline.py index 7176616..488b736 100644 --- a/src/adagio/model/pipeline.py +++ b/src/adagio/model/pipeline.py @@ -1,7 +1,8 @@ import typing as t import os +import json -from pydantic import BaseModel, RootModel, model_validator +from pydantic import BaseModel, RootModel, model_validator, Field from .arguments import AdagioArguments @@ -10,51 +11,53 @@ class AdagioPipeline(BaseModel): - type: t.Literal["pipeline"] + type: t.Literal['pipeline'] # meta: 'AdagioPipelineMetadata' - signature: "AdagioSignature" - graph: list["AdagioTask"] + signature: 'AdagioSignature' + graph: list['AdagioTask'] def validate_graph(self): pass - def iter_tasks(self) -> t.Generator["AdagioTask", None, None]: + def iter_tasks(self) -> t.Generator['AdagioTask', None, None]: yield from self.graph + class AdagioPipelineMetadata(RootModel): root: dict[str, t.Any] - @model_validator(mode="before") + @model_validator(mode='before') def check_version(cls, data): - if "version" not in data: + if 'version' not in data: raise AssertionError('Missing "version" field.') class AdagioSignature(BaseModel): - inputs: "list[_InputDef]" - parameters: "list[_ParameterDef]" - outputs: "list[_OutputDef]" + inputs: 'list[_InputDef]' + parameters: 'list[_ParameterDef]' + outputs: 'list[_OutputDef]' def to_default_arguments(self): inputs = {} for input in self.inputs: - inputs[input.name] = "" + inputs[input.name] = '' params = {} for param in self.parameters: if param.required: - params[param.name] = "" + params[param.name] = '' else: params[param.name] = param.default outputs = {} for output in self.outputs: - outputs[output.name] = "" + outputs[output.name] = '' return AdagioArguments(inputs=inputs, parameters=params, outputs=outputs) def validate_arguments(self, args: AdagioArguments): return + def get_params(self, args: AdagioArguments): lookup = {} for param in self.parameters: @@ -67,31 +70,19 @@ def load_inputs(self, ctx, arguments, scope): for input in self.inputs: source = arguments.inputs[input.name] if _is_metadata_ast(input.ast): - print("SCHEDULED:", f"load_metadata({source!r})") + print("SCHEDULED:", f'load_metadata({source!r})') scope[input.id] = load_metadata(ctx=ctx, source=source) # IIFE for the dreaded for-loop in the parent closure problem. - scope[input.id]._future_.add_done_callback( - (lambda str: lambda x: print("DONE:", str))( - f"load_metadata({source!r})" - ) - ) + scope[input.id]._future_.add_done_callback((lambda str: (lambda x: print("DONE:", str)))(f'load_metadata({source!r})')) elif _is_collection_type(input.type): - print("SCHEDULED:", f"load_input_collection({source!r})") + print("SCHEDULED:", f'load_input_collection({source!r})') scope[input.id] = load_input_collection(ctx=ctx, sources=source) - scope[input.id]._future_.add_done_callback( - (lambda str: lambda x: print("DONE:", str))( - f"load_input_collection({source!r})" - ) - ) + scope[input.id]._future_.add_done_callback((lambda str: (lambda x: print("DONE:", str)))(f'load_input_collection({source!r})')) else: - print("SCHEDULED:", f"load_input({source!r})") + print("SCHEDULED:", f'load_input({source!r})') scope[input.id] = load_input(ctx=ctx, source=source) # IIFE for the dreaded for-loop in the parent closure problem. - scope[input.id]._future_.add_done_callback( - (lambda str: lambda x: print("DONE:", str))( - f"load_input({source!r})" - ) - ) + scope[input.id]._future_.add_done_callback((lambda str: (lambda x: print("DONE:", str)))(f'load_input({source!r})')) def save_outputs(self, ctx, arguments: AdagioArguments, scope): from adagio.io import save_output @@ -103,15 +94,11 @@ def save_outputs(self, ctx, arguments: AdagioArguments, scope): elif type(arguments.outputs) is dict: dest = arguments.outputs[output.name] else: - raise NotImplementedError("impossible") - print("SCHEDULED:", f"{output.name}.save({dest!r})") + raise NotImplementedError('impossible') + print("SCHEDULED:", f'{output.name}.save({dest!r})') future = save_output(ctx=ctx, output=scope[output.id], destination=dest) # IIFE for the dreaded for-loop in the parent closure problem. - future.add_done_callback( - (lambda str: lambda x: print("DONE:", str))( - f"{output.name}.save({dest!r})" - ) - ) + future.add_done_callback((lambda str: (lambda x: print("DONE:", str)))(f'{output.name}.save({dest!r})')) futures.append(future) for future in futures: @@ -121,6 +108,7 @@ def save_outputs(self, ctx, arguments: AdagioArguments, scope): pass + class _Def(BaseModel): id: str name: str @@ -135,7 +123,7 @@ class _InputDef(_Def): class _ParameterDef(_Def): required: bool - default: "AllowableValue | None" = None + default: 'AllowableValue | None' = None class _OutputDef(_Def): @@ -144,11 +132,11 @@ class _OutputDef(_Def): def _is_metadata_ast(ast: TypeAST) -> bool: if isinstance(ast, TypeASTExpression): - return bool(ast.builtin and ast.name.startswith("Metadata")) + return bool(ast.builtin and ast.name.startswith('Metadata')) if isinstance(ast, (TypeASTUnion, TypeASTIntersection)): return any(_is_metadata_ast(member) for member in ast.members) return False def _is_collection_type(type_name: str) -> bool: - return type_name.startswith("List[") or type_name.startswith("Collection[") + return type_name.startswith('List[') or type_name.startswith('Collection[') diff --git a/src/adagio/model/task.py b/src/adagio/model/task.py index fb8d001..a9b3b5d 100644 --- a/src/adagio/model/task.py +++ b/src/adagio/model/task.py @@ -5,9 +5,9 @@ class _BaseTask(BaseModel): id: str kind: str - inputs: dict[str, "TaskInputVal"] - parameters: dict[str, "LiteralVal | MetadataVal | PromotedVal"] - outputs: dict[str, "OutputVal"] + inputs: dict[str, 'TaskInputVal'] + parameters: dict[str, 'LiteralVal | MetadataVal | PromotedVal'] + outputs: dict[str, 'OutputVal'] def exec(self, ctx, params, scope): raise NotImplementedError @@ -15,7 +15,7 @@ def exec(self, ctx, params, scope): class PluginActionTask(_BaseTask): id: str - kind: t.Literal["plugin-action"] + kind: t.Literal['plugin-action'] name: str | None = None plugin: str action: str @@ -27,39 +27,39 @@ def exec(self, ctx, params, scope): kwargs = {} metadata = {} for name, src in self.inputs.items(): - if src.kind == "archive": + if src.kind == 'archive': kwargs[name] = scope[src.id] - elif src.kind == "archive-collection": + elif src.kind == 'archive-collection': kwargs[name] = _flatten_collection_values( [scope[item.id] for item in src.items] ) - elif src.kind == "metadata": + elif src.kind == 'metadata': # store for second pass in params metadata[name] = scope[src.id] else: - raise NotImplementedError("impossible") + raise NotImplementedError('impossible') for name, param in self.parameters.items(): - if param.kind == "metadata": - if param.column.kind == "literal": + if param.kind == 'metadata': + if param.column.kind == 'literal': col = param.value - elif param.column.kind == "promoted": + elif param.column.kind == 'promoted': col = params[param.column.id] else: - raise NotImplementedError("impossible") + raise NotImplementedError('impossible') source = metadata.pop(name) md = convert_metadata(ctx=ctx, metadata=source) kwargs[name] = md.get_column(col) - elif param.kind == "literal": + elif param.kind == 'literal': kwargs[name] = param.value - elif param.kind == "promoted": + elif param.kind == 'promoted': kwargs[name] = params[param.id] else: - raise NotImplementedError("impossible") + raise NotImplementedError('impossible') # any remaining metadata is used directly for name, value in metadata.items(): @@ -71,8 +71,8 @@ def exec(self, ctx, params, scope): class RootInputTask(_BaseTask): - kind: t.Literal["built-in"] - name: t.Literal["root-input"] + kind: t.Literal['built-in'] + name: t.Literal['root-input'] def exec(self, ctx, params, scope): for name, src in self.inputs.items(): @@ -81,7 +81,7 @@ def exec(self, ctx, params, scope): class InputVal(BaseModel): - kind: t.Literal["archive", "metadata"] + kind: t.Literal['archive', 'metadata'] id: str @@ -91,24 +91,24 @@ class ArchiveCollectionItemVal(BaseModel): class ArchiveCollectionInputVal(BaseModel): - kind: t.Literal["archive-collection"] - style: t.Literal["list"] + kind: t.Literal['archive-collection'] + style: t.Literal['list'] items: list[ArchiveCollectionItemVal] class OutputVal(BaseModel): - kind: t.Literal["archive"] + kind: t.Literal['archive'] id: str class PromotedVal(BaseModel): - kind: t.Literal["promoted"] + kind: t.Literal['promoted'] id: str class LiteralVal(BaseModel): - kind: t.Literal["literal"] - value: "AllowableValue" + kind: t.Literal['literal'] + value: 'AllowableValue' class LiteralStrVal(LiteralVal): @@ -116,7 +116,7 @@ class LiteralStrVal(LiteralVal): class MetadataVal(BaseModel): - kind: t.Literal["metadata"] + kind: t.Literal['metadata'] column: PromotedVal | LiteralStrVal @@ -124,15 +124,15 @@ class MetadataVal(BaseModel): Collection = list[Primitive] | dict[str, Primitive] AllowableValue = Primitive | Collection TaskInputVal = t.Annotated[ - t.Union[InputVal, ArchiveCollectionInputVal], Field(discriminator="kind") -] -AdagioTask = t.Annotated[ - t.Union[PluginActionTask, RootInputTask], Field(discriminator="kind") + t.Union[InputVal, ArchiveCollectionInputVal], + Field(discriminator='kind') ] +AdagioTask = t.Annotated[t.Union[PluginActionTask, RootInputTask], + Field(discriminator='kind')] def input_source_ids(value: TaskInputVal) -> list[str]: - if value.kind == "archive-collection": + if value.kind == 'archive-collection': return [item.id for item in value.items] return [value.id] diff --git a/tests/test_output_options.py b/tests/test_output_options.py index 6bdfa40..a430143 100644 --- a/tests/test_output_options.py +++ b/tests/test_output_options.py @@ -48,12 +48,8 @@ def test_dynamic_run_adds_output_dir_and_per_output_options(self) -> None: self.assertIn("output_dir", dynamic_run.__signature__.parameters) self.assertIn("output_table", dynamic_run.__signature__.parameters) - output_dir_annotation = dynamic_run.__signature__.parameters[ - "output_dir" - ].annotation - output_annotation = dynamic_run.__signature__.parameters[ - "output_table" - ].annotation + output_dir_annotation = dynamic_run.__signature__.parameters["output_dir"].annotation + output_annotation = dynamic_run.__signature__.parameters["output_table"].annotation output_dir_help = typing.get_args(output_dir_annotation)[1].help output_help = typing.get_args(output_annotation)[1].help @@ -131,9 +127,7 @@ def test_output_dir_is_a_command_option_and_required_pipeline_options_are_first( run_handler=lambda *args, **kwargs: None, ) - output_dir_annotation = dynamic_run.__signature__.parameters[ - "output_dir" - ].annotation + output_dir_annotation = dynamic_run.__signature__.parameters["output_dir"].annotation output_dir_group = typing.get_args(output_dir_annotation)[1].group self.assertEqual(output_dir_group[0]._name, "Command Options") @@ -198,10 +192,7 @@ def test_outputs_are_only_visible_in_all_mode(self) -> None: def test_output_dir_override_applies_to_all_outputs(self) -> None: resolved = _apply_output_overrides( - outputs={ - "table": "/tmp/from-file/table.qza", - "stats": "/tmp/from-file/stats.qza", - }, + outputs={"table": "/tmp/from-file/table.qza", "stats": "/tmp/from-file/stats.qza"}, output_names=["table", "stats"], output_dir="/tmp/all-outputs", output_overrides={"stats": "/tmp/custom/stats.qza"}, From ecf25abbaf6ee5dab68cb625d48406fc9753b9ad Mon Sep 17 00:00:00 2001 From: John Chase Date: Mon, 27 Apr 2026 20:59:42 -0700 Subject: [PATCH 3/3] Updats --- src/adagio/model/arguments.py | 1 - src/adagio/model/pipeline.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/adagio/model/arguments.py b/src/adagio/model/arguments.py index a4d5084..4d372dc 100644 --- a/src/adagio/model/arguments.py +++ b/src/adagio/model/arguments.py @@ -1,5 +1,4 @@ import typing as t - from pydantic import BaseModel, Field from .task import AllowableValue diff --git a/src/adagio/model/pipeline.py b/src/adagio/model/pipeline.py index 488b736..4432646 100644 --- a/src/adagio/model/pipeline.py +++ b/src/adagio/model/pipeline.py @@ -132,7 +132,7 @@ class _OutputDef(_Def): def _is_metadata_ast(ast: TypeAST) -> bool: if isinstance(ast, TypeASTExpression): - return bool(ast.builtin and ast.name.startswith('Metadata')) + return bool(ast.builtin and ast.name.startswith("Metadata")) if isinstance(ast, (TypeASTUnion, TypeASTIntersection)): return any(_is_metadata_ast(member) for member in ast.members) return False