Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/adagio/cli/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,10 @@ def _is_required_param(spec: ParamSpec) -> bool:
return bool(spec.required and spec.default is None)


def _is_collection_type(type_name: str) -> bool:
return type_name.startswith("List[") or type_name.startswith("Collection[")


def build_dynamic_run(
*,
input_specs: list[InputSpec],
Expand Down Expand Up @@ -453,7 +457,7 @@ def add_input_spec(spec: InputSpec) -> None:
ident=ident,
opt=opt,
required=False,
py_type=str,
py_type=list[str] if _is_collection_type(spec.type) else str,
help_text=_format_help_text(
description=spec.description,
),
Expand Down Expand Up @@ -573,4 +577,4 @@ def run(


def _is_missing(value: Any) -> bool:
return value is None or value == "<fill me>"
return value is None or value == "" or value == "<fill me>" or value == [] or value == {}
2 changes: 1 addition & 1 deletion src/adagio/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def _load_arguments_data(path: Path, _console: Console | None = None) -> dict[st


def _is_missing(value: Any) -> bool:
return value is None or value == "<fill me>"
return value is None or value == "" or value == "<fill me>" or value == [] or value == {}


if __name__ == "__main__":
Expand Down
9 changes: 7 additions & 2 deletions src/adagio/cli/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,12 @@ def run_pipeline_from_kwargs(
for ident, original in input_bindings:
value = kwargs.get(ident)
if value is not None:
arguments.inputs[original] = str(value)
if isinstance(value, list):
arguments.inputs[original] = [str(item) for item in value]
elif isinstance(value, dict):
arguments.inputs[original] = {str(key): str(item) for key, item in value.items()}
else:
arguments.inputs[original] = str(value)

for ident, original in param_bindings:
value = kwargs.get(ident)
Expand Down Expand Up @@ -177,7 +182,7 @@ def run_pipeline_from_kwargs(

def _is_missing(value: Any) -> bool:
"""Treat placeholders and null values as missing."""
return value is None or value == "<fill me>"
return value is None or value == "" or value == "<fill me>" or value == [] or value == {}


def _is_missing_output(value: Any) -> bool:
Expand Down
22 changes: 19 additions & 3 deletions src/adagio/cli/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _apply_named_arguments(
raw_inputs = runtime_arguments.get("inputs", {})
if isinstance(raw_inputs, dict):
for name, value in raw_inputs.items():
arguments.inputs[name] = _resolve_input_path(
arguments.inputs[name] = _resolve_input_value(
value, storage_root=storage_root
)

Expand Down Expand Up @@ -283,7 +283,7 @@ def _apply_legacy_arguments(
named_inputs = runtime_arguments.get("inputs", {})
if isinstance(named_inputs, dict):
for name, value in named_inputs.items():
arguments.inputs[name] = _resolve_input_path(
arguments.inputs[name] = _resolve_input_value(
value, storage_root=storage_root
)

Expand Down Expand Up @@ -312,6 +312,22 @@ def _resolve_input_path(value: Any, *, storage_root: str) -> str:
return str(value)


def _resolve_input_value(value: Any, *, storage_root: str) -> Any:
if isinstance(value, list):
return [_resolve_input_value(item, storage_root=storage_root) for item in value]
if isinstance(value, dict):
path = value.get("path")
if path is not None:
return _normalize_path(path, storage_root=storage_root)
return {
str(key): _resolve_input_value(item, storage_root=storage_root)
for key, item in value.items()
}
if isinstance(value, str):
return _normalize_path(value, storage_root=storage_root)
return str(value)


def _resolve_outputs(value: Any, *, storage_root: str) -> str | dict[str, str] | None:
if value is None:
return None
Expand Down Expand Up @@ -345,7 +361,7 @@ def _outputs_need_default(outputs: str | dict[str, str]) -> bool:


def _is_missing(value: Any) -> bool:
return value is None or value == "" or value == "<fill me>"
return value is None or value == "" or value == "<fill me>" or value == [] or value == {}


def _validate_required_arguments(
Expand Down
13 changes: 13 additions & 0 deletions src/adagio/executors/path_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from .container_support import is_uri

# Accepted shapes for one input source: a single string, a list of strings,
# or a dict mapping string keys to strings.
InputSource = str | list[str] | dict[str, str]


def resolve_host_path(*, source: str, cwd: Path) -> str:
if is_uri(source):
Expand All @@ -14,6 +16,17 @@ def resolve_host_path(*, source: str, cwd: Path) -> str:
return str((cwd / path).resolve())


def resolve_host_input(*, source: InputSource, cwd: Path) -> InputSource:
    """Apply ``resolve_host_path`` to every string in *source*.

    The container shape is preserved: a list yields a list, a dict yields a
    dict with the same keys, and any other value is resolved as one source.
    """

    def to_host(item: str) -> str:
        return resolve_host_path(source=item, cwd=cwd)

    if isinstance(source, dict):
        resolved: dict[str, str] = {}
        for name, entry in source.items():
            resolved[name] = to_host(entry)
        return resolved

    if isinstance(source, list):
        return [to_host(entry) for entry in source]

    return to_host(source)


def resolve_output_destination(
*,
output_name: str,
Expand Down
69 changes: 66 additions & 3 deletions src/adagio/executors/serial_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from adagio.monitor.tty import RichMonitor

from .cache_support import ExecutionCacheConfig
from .container_support import is_uri
from .common import plan_execution_order, task_label
from .path_utils import resolve_host_path
from .path_utils import InputSource, resolve_host_input, resolve_host_path

CONTAINER_SUBTASK_COUNT = 1

Expand All @@ -23,7 +24,7 @@ class SerialExecutionState:
cwd: Path
work_path: Path
params: dict[str, t.Any]
scope: dict[str, str]
scope: dict[str, InputSource]
cache_config: ExecutionCacheConfig | None
saved_output_ids: set[str] = field(default_factory=set)
save_output_started: bool = False
Expand Down Expand Up @@ -64,7 +65,9 @@ def run_serial_pipeline(
active_monitor.start_load_input()
for input_def in sig.inputs:
source = arguments.inputs[input_def.name]
state.scope[input_def.id] = resolve_host_path(source=source, cwd=state.cwd)
state.scope[input_def.id] = resolve_pipeline_input(
source=source, type_name=input_def.type, cwd=state.cwd
)
active_monitor.finish_load_input()

execution_plan = plan_execution_order(tasks=tasks, scope=state.scope)
Expand Down Expand Up @@ -133,3 +136,63 @@ def resolve_monitor(*, console: Console | None, monitor: Monitor | None) -> Moni
if console is not None:
return RichMonitor(console=console)
return LogMonitor()


def resolve_pipeline_input(
    *, source: InputSource, type_name: str, cwd: Path
) -> InputSource:
    """Resolve a pipeline input, expanding collection-typed inputs to a list.

    The source is first resolved with ``resolve_host_input``. For a
    non-collection *type_name* that result is returned unchanged. For a
    collection type:

    * a single string, or a list holding exactly one string, goes through
      ``expand_collection_input_source``;
    * a list of any other length is returned as is;
    * anything else is treated as a mapping and its values are returned as
      a list.
    """
    host_value = resolve_host_input(source=source, cwd=cwd)
    if not is_collection_type(type_name):
        return host_value

    # A bare string behaves exactly like a one-element list from here on.
    if isinstance(host_value, str):
        host_value = [host_value]

    if isinstance(host_value, list):
        if len(host_value) != 1:
            return host_value
        (only_entry,) = host_value
        return expand_collection_input_source(only_entry)

    return list(host_value.values())


def is_collection_type(type_name: str) -> bool:
    """Report whether *type_name* starts with ``List[`` or ``Collection[``."""
    collection_prefixes = ("List[", "Collection[")
    return type_name.startswith(collection_prefixes)


def expand_collection_input_source(source: str) -> list[str]:
    """Expand *source* into a list of collection member sources.

    When *source* is not a URI, has a ``.tsv`` or ``.txt`` suffix (any case)
    and names an existing file, it is read with
    ``read_collection_manifest``. Otherwise the result is ``[source]``.
    """
    candidate = Path(source)

    # Guard clauses run in the same order as a short-circuit ``and`` chain,
    # so the filesystem is only consulted for non-URI .tsv/.txt names.
    if is_uri(source):
        return [source]
    if candidate.suffix.lower() not in {".tsv", ".txt"}:
        return [source]
    if not candidate.is_file():
        return [source]

    return read_collection_manifest(candidate)


def read_collection_manifest(path: Path) -> list[str]:
    """Read member sources from a tab-separated manifest file.

    Blank lines are ignored. If any cell of the first remaining row equals
    ``path`` (case-insensitive, surrounding whitespace ignored), that row is
    a header and the matching column supplies the values; rows too short to
    have that column are skipped. Without such a header every row is data:
    the second cell is used when a row has at least two cells, otherwise the
    first. Empty values are dropped.

    Each value is passed to ``resolve_host_path`` with the manifest's parent
    directory as ``cwd``.

    Args:
        path: Location of the manifest file.

    Returns:
        The resolved sources in file order; ``[]`` when the file has no
        non-blank lines or no usable values.
    """
    # "utf-8-sig" reads plain UTF-8 too, but also drops a leading byte-order
    # mark. With plain "utf-8" the BOM stays glued to the first cell, so a
    # leading "path" header is not recognized and is treated as data.
    text = path.read_text(encoding="utf-8-sig")

    # splitlines() already removes line terminators.
    rows = [line.split("\t") for line in text.splitlines() if line.strip()]
    if not rows:
        return []

    header = [cell.strip().lower() for cell in rows[0]]
    path_index = header.index("path") if "path" in header else None
    data_rows = rows[1:] if path_index is not None else rows

    result: list[str] = []
    for row in data_rows:
        if path_index is not None:
            if path_index >= len(row):
                # Row is too short to contain the path column.
                continue
            raw_path = row[path_index].strip()
        elif len(row) >= 2:
            # Headerless multi-column row: take the second column.
            raw_path = row[1].strip()
        else:
            raw_path = row[0].strip()

        if raw_path:
            result.append(resolve_host_path(source=raw_path, cwd=path.parent))
    return result
35 changes: 30 additions & 5 deletions src/adagio/executors/task_environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,24 @@ def _execute_plugin_action(
metadata_inputs: dict[str, str] = {}
for name, src in task.inputs.items():
if src.kind == "archive":
archive_inputs[name] = state.scope[src.id]
value = state.scope[src.id]
if isinstance(value, list):
archive_collection_inputs[name] = value
elif isinstance(value, dict):
archive_collection_inputs[name] = list(value.values())
else:
archive_inputs[name] = value
elif src.kind == "archive-collection":
archive_collection_inputs[name] = [
state.scope[item.id] for item in src.items
]
archive_collection_inputs[name] = _flatten_collection_items(
[state.scope[item.id] for item in src.items]
)
elif src.kind == "metadata":
metadata_inputs[name] = state.scope[src.id]
value = state.scope[src.id]
if not isinstance(value, str):
raise TypeError(
f"Metadata input {name!r} must resolve to a single path."
)
metadata_inputs[name] = value
else:
raise TypeError(f"Unsupported input kind: {src.kind!r}")

Expand Down Expand Up @@ -165,6 +176,20 @@ def _execute_plugin_action(
return result.reused


def _flatten_collection_items(
values: list[str | list[str] | dict[str, str]],
) -> list[str]:
result: list[str] = []
for value in values:
if isinstance(value, list):
result.extend(value)
elif isinstance(value, dict):
result.extend(value.values())
else:
result.append(value)
return result


def _save_outputs(
*,
sig,
Expand Down
19 changes: 18 additions & 1 deletion src/adagio/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,23 @@ def load_input(*, ctx, source: str):

return [input]


@lift_parsl(lambda fut: fut)
def load_input_collection(*, ctx, sources):
    """Load every source in *sources* as a QIIME 2 artifact.

    *sources* may be a single string, a dict (only its values are used) or
    any other iterable of sources. Loading runs inside ``ctx.cache`` and the
    artifacts are returned as a list in iteration order.
    """
    from qiime2.sdk import Artifact, PluginManager

    # NOTE(review): instantiated only for its side effect, presumably so
    # plugins are registered before artifacts are loaded - confirm.
    PluginManager()

    if isinstance(sources, str):
        sources = [sources]
    elif isinstance(sources, dict):
        sources = list(sources.values())

    with ctx.cache:
        artifacts = []
        for location in sources:
            artifacts.append(Artifact.load(location))
        return artifacts


@lift_parsl(ProxyMetadata)
def load_metadata(*, ctx, source: str):
from qiime2 import Artifact, Metadata
Expand All @@ -38,4 +55,4 @@ def convert_metadata(*, ctx, metadata):
if isinstance(metadata, qiime2.Artifact):
metadata = metadata.view(qiime2.Metadata)

return metadata
return metadata
6 changes: 4 additions & 2 deletions src/adagio/model/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

from .task import AllowableValue

# Allowed shapes for one entry of an ``inputs`` mapping: a single string, a
# list of strings, or a dict mapping string keys to strings.
InputValue = str | list[str] | dict[str, str]


class AdagioArguments(BaseModel):
inputs: dict[str, str]
inputs: dict[str, InputValue]
parameters: dict[str, AllowableValue]
outputs: str | dict[str, str]

Expand Down Expand Up @@ -34,6 +36,6 @@ class AdagioArgumentsFile(BaseModel):
"""Represent arguments loaded from a JSON file."""

version: int = 1
inputs: dict[str, str] = Field(default_factory=dict)
inputs: dict[str, InputValue] = Field(default_factory=dict)
parameters: dict[str, AllowableValue] = Field(default_factory=dict)
outputs: str | dict[str, str] | None = None
10 changes: 9 additions & 1 deletion src/adagio/model/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_params(self, args: AdagioArguments):
return lookup

def load_inputs(self, ctx, arguments, scope):
from adagio.io import load_input, load_metadata
from adagio.io import load_input, load_input_collection, load_metadata

for input in self.inputs:
source = arguments.inputs[input.name]
Expand All @@ -74,6 +74,10 @@ def load_inputs(self, ctx, arguments, scope):
scope[input.id] = load_metadata(ctx=ctx, source=source)
# IIFE for the dreaded for-loop in the parent closure problem.
scope[input.id]._future_.add_done_callback((lambda str: (lambda x: print("DONE:", str)))(f'load_metadata({source!r})'))
elif _is_collection_type(input.type):
print("SCHEDULED:", f'load_input_collection({source!r})')
scope[input.id] = load_input_collection(ctx=ctx, sources=source)
scope[input.id]._future_.add_done_callback((lambda str: (lambda x: print("DONE:", str)))(f'load_input_collection({source!r})'))
else:
print("SCHEDULED:", f'load_input({source!r})')
scope[input.id] = load_input(ctx=ctx, source=source)
Expand Down Expand Up @@ -132,3 +136,7 @@ def _is_metadata_ast(ast: TypeAST) -> bool:
if isinstance(ast, (TypeASTUnion, TypeASTIntersection)):
return any(_is_metadata_ast(member) for member in ast.members)
return False


def _is_collection_type(type_name: str) -> bool:
return type_name.startswith('List[') or type_name.startswith('Collection[')
16 changes: 15 additions & 1 deletion src/adagio/model/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def exec(self, ctx, params, scope):
if src.kind == 'archive':
kwargs[name] = scope[src.id]
elif src.kind == 'archive-collection':
kwargs[name] = [scope[item.id] for item in src.items]
kwargs[name] = _flatten_collection_values(
[scope[item.id] for item in src.items]
)
elif src.kind == 'metadata':
# store for second pass in params
metadata[name] = scope[src.id]
Expand Down Expand Up @@ -133,3 +135,15 @@ def input_source_ids(value: TaskInputVal) -> list[str]:
if value.kind == 'archive-collection':
return [item.id for item in value.items]
return [value.id]


def _flatten_collection_values(values: list[t.Any]) -> list[t.Any]:
result: list[t.Any] = []
for value in values:
if isinstance(value, list):
result.extend(value)
elif isinstance(value, dict):
result.extend(value.values())
else:
result.append(value)
return result
Loading
Loading