Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/adagio/cli/qapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from cyclopts import App, Parameter
from rich.console import Console
from rich.markup import escape

from ..qapi import DEFAULT_SCHEMA_VERSION, generate_qapi_payload, submit_qapi_payload

Expand Down Expand Up @@ -36,7 +37,8 @@ def _print_submission_summary(response_body: object) -> None:
overwritten = [
operation["plugin_name"]
for operation in operations
if isinstance(operation, dict) and operation.get("action") == "overwrite"
if isinstance(operation, dict)
and operation.get("action") == "overwrite"
]
if created:
console.print(f"[green]Create:[/green] {', '.join(created)}")
Expand All @@ -53,6 +55,26 @@ def _print_submission_summary(response_body: object) -> None:
console.print(json.dumps(response_body, indent=2))


def _print_skipped_private_actions(skipped_actions: list[str]) -> None:
if not skipped_actions:
return

sorted_actions = sorted(skipped_actions)
display_limit = 20
displayed_actions = ", ".join(
escape(action_name) for action_name in sorted_actions[:display_limit]
)
remaining_count = len(sorted_actions) - display_limit
if remaining_count > 0:
displayed_actions += f", and {remaining_count} more"

noun = "action" if len(sorted_actions) == 1 else "actions"
console.print(
f"[yellow]Skipped {len(sorted_actions)} private QIIME {noun}:[/yellow] "
f"{displayed_actions}"
)


def build_qapi(
*,
action_url: Annotated[
Expand Down Expand Up @@ -139,13 +161,16 @@ def build_qapi(
raise SystemExit("Use either --all or --plugin, not both.")

requested_plugins = None if all_plugins or not plugin else plugin
skipped_private_actions: list[str] = []
try:
request_body = generate_qapi_payload(
schema_version=schema_version,
plugins=requested_plugins,
on_skipped_private_action=skipped_private_actions.append,
)
except ValueError as exc:
raise SystemExit(str(exc)) from exc
_print_skipped_private_actions(skipped_private_actions)

if output is not None:
output.write_text(json.dumps(request_body, indent=2), encoding="utf-8")
Expand Down
63 changes: 54 additions & 9 deletions src/adagio/qapi/build.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,37 @@
import collections
from collections.abc import Sequence
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import Any, cast

DEFAULT_SCHEMA_VERSION = "0.1.0"
PRIVATE_QIIME_ACTION_PREFIXES = ("_", "-")


def _private_qiime_action_id(action_key: object, action: Any) -> str | None:
action_id = getattr(action, "id", None)
for value in (action_id, action_key):
if isinstance(value, str) and value.startswith(PRIVATE_QIIME_ACTION_PREFIXES):
return value
return None


def _iter_public_qiime_actions(
actions: Mapping[object, Any],
*,
plugin_name: str | None = None,
on_skipped_private_action: Callable[[str], None] | None = None,
) -> Iterator[tuple[object, Any]]:
for key, action in actions.items():
private_action_id = _private_qiime_action_id(key, action)
if private_action_id is not None:
if on_skipped_private_action is not None:
skipped_action_id = (
f"{plugin_name}.{private_action_id}"
if plugin_name is not None
else private_action_id
)
on_skipped_private_action(skipped_action_id)
continue
yield key, action


def normalize_plugin_selection(plugin_names: Sequence[str] | None) -> list[str] | None:
Expand All @@ -24,6 +53,7 @@ def generate_qapi_payload(
*,
schema_version: str = DEFAULT_SCHEMA_VERSION,
plugins: Sequence[str] | None = None,
on_skipped_private_action: Callable[[str], None] | None = None,
) -> dict[str, Any]:
"""Generate a QAPI payload for all plugins or a selected subset."""
import qiime2
Expand Down Expand Up @@ -53,7 +83,9 @@ def flatten_type_maps(qiime_type: Any) -> Any:
final_predicate = None
if isinstance(qiime_type.predicate, UnionExp):
predicate = qiime_type.predicate.unpack_union()
final_predicate = UnionExp([flatten_type_maps(elem) for elem in predicate])
final_predicate = UnionExp(
[flatten_type_maps(elem) for elem in predicate]
)
final_predicate.normalize()
elif isinstance(qiime_type.predicate, IntersectionExp):
predicate = qiime_type.predicate.unpack_intersection()
Expand All @@ -72,7 +104,10 @@ def ast_to_basename(ast: dict[str, Any]) -> str:
if not ast.get("fields"):
return cast(str, ast["name"])

fields = [ast_to_basename(field) for field in cast(list[dict[str, Any]], ast["fields"])]
fields = [
ast_to_basename(field)
for field in cast(list[dict[str, Any]], ast["fields"])
]
return f"{ast['name']}[{', '.join(fields)}]"

def add_metadata_flag(ast: dict[str, Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -118,7 +153,9 @@ def build_inspect_dict(action: Any) -> dict[str, Any]:
{
"name": name,
"type": repr(spec.qiime_type),
"ast": add_metadata_flag(flatten_type_maps(spec.qiime_type).to_ast()),
"ast": add_metadata_flag(
flatten_type_maps(spec.qiime_type).to_ast()
),
"description": optional_desc(spec.description),
}
for name, spec in action.signature.outputs.items()
Expand All @@ -128,10 +165,16 @@ def build_inspect_dict(action: Any) -> dict[str, Any]:
"source": action.source.replace("\n```python\n", "").replace("```\n", ""),
}

def build_data_dict(data: Any) -> dict[str, Any]:
def build_data_dict(
*, plugin_name: str, data: Mapping[object, Any]
) -> dict[str, Any]:
result: dict[str, Any] = collections.defaultdict(dict)
for key, value in data.items():
result[key] = build_inspect_dict(value)
for key, value in _iter_public_qiime_actions(
data,
plugin_name=plugin_name,
on_skipped_private_action=on_skipped_private_action,
):
result[str(key)] = build_inspect_dict(value)
return result

qapi: dict[str, Any] = {}
Expand All @@ -147,8 +190,10 @@ def build_data_dict(data: Any) -> dict[str, Any]:

for plugin_name in selected_plugins:
plugin = plugin_manager.plugins[plugin_name]
methods_dict = build_data_dict(plugin.actions)
methods_dict.update(build_data_dict(plugin.pipelines))
methods_dict = build_data_dict(plugin_name=plugin_name, data=plugin.actions)
methods_dict.update(
build_data_dict(plugin_name=plugin_name, data=plugin.pipelines)
)
qapi[plugin_name] = {"methods": methods_dict}

return {
Expand Down
94 changes: 94 additions & 0 deletions tests/test_qapi_build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import io
import unittest
from types import SimpleNamespace
from unittest.mock import patch

from rich.console import Console

from adagio.cli import qapi as qapi_cli
from adagio.qapi.build import _iter_public_qiime_actions


class QapiBuildTests(unittest.TestCase):
def test_iter_public_qiime_actions_skips_private_action_names(self) -> None:
public_action = SimpleNamespace(id="public_action")
skipped_actions: list[str] = []

actions = {
"public_action": public_action,
"_private_by_key": SimpleNamespace(id="private_by_key"),
"-private_by_key": SimpleNamespace(id="private_by_key"),
"private_by_id": SimpleNamespace(id="_private_by_id"),
"private_by_hyphen_id": SimpleNamespace(id="-private_by_hyphen_id"),
}

public_actions = list(
_iter_public_qiime_actions(
actions,
plugin_name="example",
on_skipped_private_action=skipped_actions.append,
)
)

self.assertEqual(public_actions, [("public_action", public_action)])
self.assertEqual(
skipped_actions,
[
"example._private_by_key",
"example.-private_by_key",
"example._private_by_id",
"example.-private_by_hyphen_id",
],
)

def test_build_qapi_submits_payload_after_private_actions_are_skipped(self) -> None:
output = io.StringIO()
original_console = qapi_cli.console
qapi_cli.console = Console(file=output, force_terminal=False, color_system=None)

def fake_generate_qapi_payload(*, on_skipped_private_action, **kwargs):
on_skipped_private_action("example._private_action")
return {
"qiime_version": "2024.10.0",
"schema_version": "0.1.0",
"data": {
"example": {
"methods": {
"public_action": {
"id": "public_action",
},
},
},
},
}

try:
with (
patch(
"adagio.cli.qapi.generate_qapi_payload",
side_effect=fake_generate_qapi_payload,
),
patch("adagio.cli.qapi.submit_qapi_payload") as submit_mock,
):
submit_mock.return_value = (
"https://adagiodata.com/api/v1/qapi/",
200,
{"message": "ok"},
)

qapi_cli.build_qapi(action_url="https://adagiodata.com/api/v1")
finally:
qapi_cli.console = original_console

submit_mock.assert_called_once()
submitted_payload = submit_mock.call_args.args[0]
self.assertEqual(
submitted_payload["data"]["example"]["methods"],
{"public_action": {"id": "public_action"}},
)
self.assertIn("Skipped 1 private QIIME action", output.getvalue())
self.assertIn("example._private_action", output.getvalue())


if __name__ == "__main__":
unittest.main()
Loading