Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 23 additions & 45 deletions airflow-core/src/airflow/cli/commands/asset_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,100 +19,78 @@

import typing

from sqlalchemy import select

from airflow.api_fastapi.core_api.datamodels.assets import AssetAliasResponse, AssetResponse
from airflow.cli.api_client import NEW_API_CLIENT, Client, provide_api_client
from airflow.cli.simple_table import AirflowConsole
from airflow.cli.utils import deprecated_for_airflowctl
from airflow.models.asset import AssetAliasModel, AssetModel
from airflow.utils import cli as cli_utils
from airflow.utils.session import NEW_SESSION, provide_session

if typing.TYPE_CHECKING:
from typing import Any

from sqlalchemy.orm import Session

from airflow.api_fastapi.core_api.base import BaseModel


def _list_asset_aliases(args, *, session: Session) -> tuple[Any, type[BaseModel]]:
aliases = session.scalars(select(AssetAliasModel).order_by(AssetAliasModel.name))
return aliases, AssetAliasResponse


def _list_assets(args, *, session: Session) -> tuple[Any, type[BaseModel]]:
assets = session.scalars(select(AssetModel).order_by(AssetModel.name)).all()
for asset in assets:
for watcher in asset.watchers:
# ``AssetWatcherModel`` has no ``created_date`` column; like the public API
# serializer, derive it from the watcher's trigger so ``AssetResponse`` validation
# succeeds. Set on the instance so ``model_validate`` reads it via ``from_attributes``.
watcher.created_date = watcher.trigger.created_date
return assets, AssetResponse


@cli_utils.action_cli
@provide_session
def asset_list(args, *, session: Session = NEW_SESSION) -> None:
@deprecated_for_airflowctl("airflowctl assets list / airflowctl assets list-by-alias")
@provide_api_client
def asset_list(args, api_client: Client = NEW_API_CLIENT) -> None:
"""Display assets in the command line."""
if args.alias:
data, model_cls = _list_asset_aliases(args, session=session)
data = api_client.assets.list_by_alias().asset_aliases
else:
data, model_cls = _list_assets(args, session=session)
data = api_client.assets.list().assets

def detail_mapper(asset: Any) -> dict[str, Any]:
model = model_cls.model_validate(asset)
return model.model_dump(mode="json", include=args.columns)
return asset.model_dump(mode="json", include=args.columns)

AirflowConsole().print_as(data=data, output=args.output, mapper=detail_mapper)


def _detail_asset_alias(args, *, session: Session) -> BaseModel:
def _detail_asset_alias(args, api_client: Client = NEW_API_CLIENT) -> BaseModel:
if not args.name:
raise SystemExit("Required --name with --alias")
if args.uri:
raise SystemExit("Cannot use --uri with --alias")

alias = session.scalar(select(AssetAliasModel).where(AssetAliasModel.name == args.name))
alias = api_client.assets.get_by_alias(alias=args.name)
if alias is None:
raise SystemExit(f"Asset alias with name {args.name} does not exist.")

return AssetAliasResponse.model_validate(alias)
return alias


def _detail_asset(args, *, session: Session) -> BaseModel:
def _detail_asset(args, api_client: Client = NEW_API_CLIENT) -> BaseModel:
if not args.name and not args.uri:
raise SystemExit("Either --name or --uri is required")

stmt = select(AssetModel)
select_message_parts = []
if args.name:
stmt = stmt.where(AssetModel.name == args.name)
select_message_parts.append(f"name {args.name}")
if args.uri:
stmt = stmt.where(AssetModel.uri == args.uri)
select_message_parts.append(f"URI {args.uri}")
asset_it = iter(session.scalars(stmt.limit(2)))
matches = [
asset
for asset in api_client.assets.list().assets
if (not args.name or asset.name == args.name) and (not args.uri or asset.uri == args.uri)
]
select_message = " and ".join(select_message_parts)

if (asset := next(asset_it, None)) is None:
if not matches:
raise SystemExit(f"Asset with {select_message} does not exist.")
if next(asset_it, None) is not None:
if len(matches) > 1:
raise SystemExit(f"More than one asset exists with {select_message}.")

return AssetResponse.model_validate(asset)
return matches[0]


@cli_utils.action_cli
@provide_session
def asset_details(args, *, session: Session = NEW_SESSION) -> None:
@deprecated_for_airflowctl("airflowctl assets get / airflowctl assets get-by-alias")
@provide_api_client
def asset_details(args, api_client: Client = NEW_API_CLIENT) -> None:
"""Display details of an asset."""
if args.alias:
model = _detail_asset_alias(args, session=session)
model = _detail_asset_alias(args, api_client)
else:
model = _detail_asset(args, session=session)
model = _detail_asset(args, api_client)

model_data = model.model_dump(mode="json")
if args.output in ["table", "plain"]:
Expand Down
171 changes: 105 additions & 66 deletions airflow-core/tests/unit/cli/commands/test_asset_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,72 +53,111 @@ def parser() -> ArgumentParser:
return cli_parser.get_parser()


def test_cli_assets_list(prepare_examples, parser: ArgumentParser, stdout_capture) -> None:
args = parser.parse_args(["assets", "list", "--output=json"])
with stdout_capture as capture:
asset_command.asset_list(args)

asset_list = json.loads(capture.getvalue())
assert len(asset_list) > 0
assert set(asset_list[0]) == {"name", "uri", "group", "extra"}
assert any(asset["uri"] == "s3://dag1/output_1.txt" for asset in asset_list), asset_list


def test_cli_assets_alias_list(prepare_examples, parser: ArgumentParser, stdout_capture) -> None:
args = parser.parse_args(["assets", "list", "--alias", "--output=json"])
with stdout_capture as capture:
asset_command.asset_list(args)

alias_list = json.loads(capture.getvalue())
assert len(alias_list) > 0
assert set(alias_list[0]) == {"name", "group"}
assert any(alias["name"] == "example-alias" for alias in alias_list), alias_list


def test_cli_assets_details(prepare_examples, parser: ArgumentParser, stdout_capture) -> None:
args = parser.parse_args(["assets", "details", "--name=asset1_producer", "--output=json"])
with stdout_capture as capture:
asset_command.asset_details(args)

asset_detail_list = json.loads(capture.getvalue())
assert len(asset_detail_list) == 1

# No good way to statically compare these.
undeterministic = {
"id": None,
"created_at": None,
"updated_at": None,
"scheduled_dags": None,
"producing_tasks": None,
"consuming_tasks": None,
}

assert asset_detail_list[0] | undeterministic == undeterministic | {
"name": "asset1_producer",
"uri": "s3://bucket/asset1_producer",
"group": "asset",
"extra": {},
"aliases": [],
"watchers": [],
"last_asset_event": None,
}


def test_cli_assets_alias_details(prepare_examples, parser: ArgumentParser, stdout_capture) -> None:
args = parser.parse_args(["assets", "details", "--alias", "--name=example-alias", "--output=json"])
with stdout_capture as capture:
asset_command.asset_details(args)

alias_detail_list = json.loads(capture.getvalue())
assert len(alias_detail_list) == 1

# No good way to statically compare these.
undeterministic = {"id": None}

assert alias_detail_list[0] | undeterministic == undeterministic | {
"name": "example-alias",
"group": "asset",
}
@pytest.mark.non_db_test_override
class TestCliAssetsList:
"""`assets list` goes through the airflowctl client; mocked here (no DB/server)."""

def test_list(self, parser: ArgumentParser, mock_cli_api_client, stdout_capture) -> None:
mock_cli_api_client.assets.list.return_value.assets = [
SimpleNamespace(
model_dump=lambda **kwargs: {
"name": "asset1",
"uri": "s3://dag1/output_1.txt",
"group": "asset",
"extra": {},
}
),
]
args = parser.parse_args(["assets", "list", "--output=json"])
with stdout_capture as capture:
asset_command.asset_list(args)

asset_list = json.loads(capture.getvalue())
assert asset_list == [
{"name": "asset1", "uri": "s3://dag1/output_1.txt", "group": "asset", "extra": {}}
]
mock_cli_api_client.assets.list.assert_called_once()
mock_cli_api_client.assets.list_by_alias.assert_not_called()

def test_list_aliases(self, parser: ArgumentParser, mock_cli_api_client, stdout_capture) -> None:
mock_cli_api_client.assets.list_by_alias.return_value.asset_aliases = [
SimpleNamespace(model_dump=lambda **kwargs: {"name": "example-alias", "group": "asset"}),
]
args = parser.parse_args(["assets", "list", "--alias", "--output=json"])
with stdout_capture as capture:
asset_command.asset_list(args)

alias_list = json.loads(capture.getvalue())
assert alias_list == [{"name": "example-alias", "group": "asset"}]
mock_cli_api_client.assets.list_by_alias.assert_called_once()
mock_cli_api_client.assets.list.assert_not_called()


@pytest.mark.non_db_test_override
class TestCliAssetsDetails:
"""`assets details` goes through the airflowctl client; mocked here (no DB/server)."""

def test_details(self, parser: ArgumentParser, mock_cli_api_client, stdout_capture) -> None:
mock_cli_api_client.assets.list.return_value.assets = [
SimpleNamespace(
name="asset1_producer",
uri="s3://bucket/asset1_producer",
model_dump=lambda **kwargs: {
"name": "asset1_producer",
"uri": "s3://bucket/asset1_producer",
"group": "asset",
"extra": {},
},
),
SimpleNamespace(name="other", uri="s3://bucket/other", model_dump=lambda **kwargs: {}),
]
args = parser.parse_args(["assets", "details", "--name=asset1_producer", "--output=json"])
with stdout_capture as capture:
asset_command.asset_details(args)

detail_list = json.loads(capture.getvalue())
assert detail_list == [
{
"name": "asset1_producer",
"uri": "s3://bucket/asset1_producer",
"group": "asset",
"extra": {},
}
]
mock_cli_api_client.assets.list.assert_called_once()
mock_cli_api_client.assets.get_by_alias.assert_not_called()

def test_details_by_alias(self, parser: ArgumentParser, mock_cli_api_client, stdout_capture) -> None:
mock_cli_api_client.assets.get_by_alias.return_value.model_dump.return_value = {
"name": "example-alias",
"group": "asset",
}
args = parser.parse_args(["assets", "details", "--alias", "--name=example-alias", "--output=json"])
with stdout_capture as capture:
asset_command.asset_details(args)

detail_list = json.loads(capture.getvalue())
assert detail_list == [{"name": "example-alias", "group": "asset"}]
mock_cli_api_client.assets.get_by_alias.assert_called_once_with(alias="example-alias")
mock_cli_api_client.assets.list.assert_not_called()

def test_details_requires_name_or_uri(self, parser: ArgumentParser, mock_cli_api_client) -> None:
with pytest.raises(SystemExit, match="Either --name or --uri is required"):
asset_command.asset_details(parser.parse_args(["assets", "details"]))
mock_cli_api_client.assets.list.assert_not_called()

def test_details_missing(self, parser: ArgumentParser, mock_cli_api_client) -> None:
mock_cli_api_client.assets.list.return_value.assets = []
with pytest.raises(SystemExit, match="Asset with name nope does not exist"):
asset_command.asset_details(parser.parse_args(["assets", "details", "--name=nope"]))

def test_details_ambiguous(self, parser: ArgumentParser, mock_cli_api_client) -> None:
mock_cli_api_client.assets.list.return_value.assets = [
SimpleNamespace(name="dup", uri="s3://a", model_dump=lambda **kwargs: {}),
SimpleNamespace(name="dup", uri="s3://b", model_dump=lambda **kwargs: {}),
]
with pytest.raises(SystemExit, match="More than one asset exists with name dup"):
asset_command.asset_details(parser.parse_args(["assets", "details", "--name=dup"]))


@pytest.mark.non_db_test_override
Expand Down
10 changes: 10 additions & 0 deletions airflow-core/tests/unit/cli/commands/test_command_deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@
["assets", "materialize", "--name=foo"],
"airflowctl assets materialize",
),
(
asset_command.asset_list,
["assets", "list"],
"airflowctl assets list / airflowctl assets list-by-alias",
),
(
asset_command.asset_details,
["assets", "details", "--name=food"],
"airflowctl assets get / airflowctl assets get-by-alias",
),
]


Expand Down
Loading