diff --git a/airflow-core/src/airflow/cli/commands/variable_command.py b/airflow-core/src/airflow/cli/commands/variable_command.py index 5216aabb446c6..2ad1c88ea6a2f 100644 --- a/airflow-core/src/airflow/cli/commands/variable_command.py +++ b/airflow-core/src/airflow/cli/commands/variable_command.py @@ -23,10 +23,13 @@ import os from typing import TYPE_CHECKING +from airflowctl.api.operations import ServerResponseError from sqlalchemy import select +from airflow.api_fastapi.core_api.datamodels.variables import VariableBody +from airflow.cli.api_client import NEW_API_CLIENT, Client, provide_api_client from airflow.cli.simple_table import AirflowConsole -from airflow.cli.utils import SENSITIVE_PLACEHOLDER, print_export_output +from airflow.cli.utils import SENSITIVE_PLACEHOLDER, deprecated_for_airflowctl, print_export_output from airflow.exceptions import ( AirflowFileParseException, AirflowUnsupportedFileTypeException, @@ -108,11 +111,24 @@ def variables_get(args): @cli_utils.action_cli +@deprecated_for_airflowctl("airflowctl variables set") +@suppress_logs_and_warning @providers_configuration_loaded -def variables_set(args): - """Create new variable with a given name, value and description.""" - Variable.set(args.key, args.value, args.description, serialize_json=args.json) - print(f"Variable {args.key} created") +@provide_api_client +def variables_set(args, api_client: Client = NEW_API_CLIENT): + """Set a variable, creating it if it does not exist and updating it otherwise.""" + value = json.dumps(args.value, indent=2) if args.json else args.value + variable_body = VariableBody(key=args.key, value=value, description=args.description) + try: + api_client.variables.get(variable_key=args.key) + except ServerResponseError as e: + if e.response.status_code == 404: + api_client.variables.create(variable=variable_body) # type: ignore[arg-type] + print(f"Variable {args.key} created") + return + raise + api_client.variables.update(variable=variable_body) # type: ignore[arg-type] + print(f"Variable {args.key} updated") @cli_utils.action_cli diff --git a/airflow-core/tests/unit/cli/commands/test_command_deprecations.py b/airflow-core/tests/unit/cli/commands/test_command_deprecations.py index b4eb6840c9069..c329278a66c5a 100644 --- a/airflow-core/tests/unit/cli/commands/test_command_deprecations.py +++ b/airflow-core/tests/unit/cli/commands/test_command_deprecations.py @@ -30,7 +30,7 @@ import pytest -from airflow.cli.commands import asset_command, dag_command, pool_command +from airflow.cli.commands import asset_command, dag_command, pool_command, variable_command from airflow.exceptions import RemovedInAirflow4Warning # (command callable, argv to parse, expected airflowctl replacement named in the warning) @@ -52,6 +52,7 @@ ["assets", "materialize", "--name=foo"], "airflowctl assets materialize", ), + (variable_command.variables_set, ["variables", "set", "foo", "bar"], "airflowctl variables set"), ] diff --git a/airflow-core/tests/unit/cli/commands/test_variable_command.py b/airflow-core/tests/unit/cli/commands/test_variable_command.py index a02c95aa31722..26c94a61c9301 100644 --- a/airflow-core/tests/unit/cli/commands/test_variable_command.py +++ b/airflow-core/tests/unit/cli/commands/test_variable_command.py @@ -22,8 +22,11 @@ from contextlib import redirect_stdout from io import StringIO +import httpx import pytest import yaml +from airflowctl.api.datamodels.generated import VariableResponse +from airflowctl.api.operations import ServerResponseError from sqlalchemy import select from airflow import models @@ -37,6 +40,12 @@ pytestmark = pytest.mark.db_test +def _server_error(status_code: int) -> ServerResponseError: + request = httpx.Request("GET", "http://testserver/api/v2/variables/foo") + response = httpx.Response(status_code, request=request, json={"detail": "boom"}) + return ServerResponseError(message="boom", request=request, response=response) + + # Test data fixtures @pytest.fixture def simple_variable_data(): @@ -129,28 +138,63 @@ def setup_method(self): def teardown_method(self): clear_db_variables() - def test_variables_set(self): - """Test variable_set command""" + def test_variables_set_creates_when_missing(self, mock_cli_api_client): + """``set`` creates the variable when it does not yet exist.""" + mock_cli_api_client.variables.get.side_effect = _server_error(404) + variable_command.variables_set(self.parser.parse_args(["variables", "set", "foo", "bar"])) - assert Variable.get("foo") is not None - with pytest.raises(KeyError): - Variable.get("foo1") - def test_variables_set_with_description(self): - """Test variable_set command with optional description argument""" - expected_var_desc = "foo_bar_description" - var_key = "foo" + mock_cli_api_client.variables.create.assert_called_once() + mock_cli_api_client.variables.update.assert_not_called() + body = mock_cli_api_client.variables.create.call_args.kwargs["variable"] + assert body.key == "foo" + assert body.value == "bar" + assert body.description is None + + def test_variables_set_updates_when_exists(self, mock_cli_api_client): + """``set`` updates the variable when it already exists.""" + mock_cli_api_client.variables.get.return_value = VariableResponse( + key="foo", value="old", is_encrypted=False + ) + + variable_command.variables_set(self.parser.parse_args(["variables", "set", "foo", "new"])) + + mock_cli_api_client.variables.update.assert_called_once() + mock_cli_api_client.variables.create.assert_not_called() + body = mock_cli_api_client.variables.update.call_args.kwargs["variable"] + assert body.key == "foo" + assert body.value == "new" + + def test_variables_set_with_description(self, mock_cli_api_client): + """``set`` forwards the optional description to the client.""" + mock_cli_api_client.variables.get.side_effect = _server_error(404) + variable_command.variables_set( - self.parser.parse_args(["variables", "set", var_key, "bar", "--description", expected_var_desc]) + self.parser.parse_args(["variables", "set", "foo", "bar", "--description", "foo_bar_description"]) ) - assert Variable.get(var_key) == "bar" - with create_session() as session: - actual_var_desc = session.scalar(select(Variable.description).where(Variable.key == var_key)) - assert actual_var_desc == expected_var_desc + body = mock_cli_api_client.variables.create.call_args.kwargs["variable"] + assert body.key == "foo" + assert body.value == "bar" + assert body.description == "foo_bar_description" - with pytest.raises(KeyError): - Variable.get("foo1") + def test_variables_set_serialize_json(self, mock_cli_api_client): + """``--json`` serializes the value before sending it.""" + mock_cli_api_client.variables.get.side_effect = _server_error(404) + + variable_command.variables_set( + self.parser.parse_args(["variables", "set", "foo", '{"a": 1}', "--json"]) + ) + + body = mock_cli_api_client.variables.create.call_args.kwargs["variable"] + assert body.key == "foo" + assert body.value == json.dumps('{"a": 1}', indent=2) + + def test_variables_set_reraises_non_404_error(self, mock_cli_api_client): + """Errors other than 404 from the existence check propagate.""" + mock_cli_api_client.variables.get.side_effect = _server_error(500) + with pytest.raises(ServerResponseError): + variable_command.variables_set(self.parser.parse_args(["variables", "set", "foo", "bar"])) def test_variables_get(self, stdout_capture): Variable.set("foo", {"foo": "bar"}, serialize_json=True) @@ -171,25 +215,19 @@ def test_get_variable_missing_variable(self): variable_command.variables_get(self.parser.parse_args(["variables", "get", "no-existing-VAR"])) def test_variables_set_different_types(self): - """Test storage of various data types""" - # Set a dict - variable_command.variables_set( - self.parser.parse_args(["variables", "set", "dict", '{"foo": "oops"}']) - ) - # Set a list - variable_command.variables_set(self.parser.parse_args(["variables", "set", "list", '["oops"]'])) - # Set str - variable_command.variables_set(self.parser.parse_args(["variables", "set", "str", "hello string"])) - # Set int - variable_command.variables_set(self.parser.parse_args(["variables", "set", "int", "42"])) - # Set float - variable_command.variables_set(self.parser.parse_args(["variables", "set", "float", "42.0"])) - # Set true - variable_command.variables_set(self.parser.parse_args(["variables", "set", "true", "true"])) - # Set false - variable_command.variables_set(self.parser.parse_args(["variables", "set", "false", "false"])) - # Set none - variable_command.variables_set(self.parser.parse_args(["variables", "set", "null", "null"])) + """Test export/import round-trips storage of various data types. + + ``set`` is migrated to the airflowctl client, so the variables are seeded directly + through the model here; ``export``/``import`` remain local DB commands. + """ + Variable.set("dict", '{"foo": "oops"}') + Variable.set("list", '["oops"]') + Variable.set("str", "hello string") + Variable.set("int", "42") + Variable.set("float", "42.0") + Variable.set("true", "true") + Variable.set("false", "false") + Variable.set("null", "null") # Export and then import variable_command.variables_export( @@ -210,8 +248,8 @@ def test_variables_set_different_types(self): assert Variable.get("null", deserialize_json=True) is None # test variable import skip existing - # set varliable list to ["airflow"] and have it skip during import - variable_command.variables_set(self.parser.parse_args(["variables", "set", "list", '["airflow"]'])) + # set variable list to ["airflow"] and have it skip during import + Variable.set("list", '["airflow"]') variable_command.variables_import( self.parser.parse_args( ["variables", "import", "variables_types.json", "--action-on-existing-key", "skip"] @@ -325,8 +363,8 @@ def test_variables_list_edge_cases(self): assert item["val"] == "***" def test_variables_delete(self): - """Test variable_delete command""" - variable_command.variables_set(self.parser.parse_args(["variables", "set", "foo", "bar"])) + """Test variable_delete command (``set`` is migrated, so seed via the model)""" + Variable.set("foo", "bar") variable_command.variables_delete(self.parser.parse_args(["variables", "delete", "foo"])) with pytest.raises(KeyError): Variable.get("foo") @@ -365,13 +403,13 @@ def test_variables_isolation(self, tmp_path): path1 = tmp_path / "testfile1.json" path2 = tmp_path / "testfile2.json" - # First export - variable_command.variables_set(self.parser.parse_args(["variables", "set", "foo", '{"foo":"bar"}'])) - variable_command.variables_set(self.parser.parse_args(["variables", "set", "bar", "original"])) + # First export (``set`` is migrated to airflowctl, so seed via the model) + Variable.set("foo", '{"foo":"bar"}') + Variable.set("bar", "original") variable_command.variables_export(self.parser.parse_args(["variables", "export", os.fspath(path1)])) - variable_command.variables_set(self.parser.parse_args(["variables", "set", "bar", "updated"])) - variable_command.variables_set(self.parser.parse_args(["variables", "set", "foo", '{"foo":"oops"}'])) + Variable.set("bar", "updated") + Variable.set("foo", '{"foo":"oops"}') variable_command.variables_delete(self.parser.parse_args(["variables", "delete", "foo"])) with create_session() as session: variable_command.variables_import( @@ -389,13 +427,10 @@ def test_variables_isolation(self, tmp_path): def test_variables_import_and_export_with_description(self, tmp_path): """Test variables_import with file-description parameter""" variables_types_file = tmp_path / "variables_types.json" - variable_command.variables_set( - self.parser.parse_args(["variables", "set", "foo", "bar", "--description", "Foo var description"]) - ) - variable_command.variables_set( - self.parser.parse_args(["variables", "set", "foo1", "bar1", "--description", "12"]) - ) - variable_command.variables_set(self.parser.parse_args(["variables", "set", "foo2", "bar2"])) + # ``set`` is migrated to airflowctl, so seed the variables via the model + Variable.set("foo", "bar", description="Foo var description") + Variable.set("foo1", "bar1", description="12") + Variable.set("foo2", "bar2") variable_command.variables_export( self.parser.parse_args(["variables", "export", os.fspath(variables_types_file)]) ) diff --git a/airflow-ctl/docs/images/command_hashes.txt b/airflow-ctl/docs/images/command_hashes.txt index 53c93e7546d1e..9fc2811f099bb 100644 --- a/airflow-ctl/docs/images/command_hashes.txt +++ b/airflow-ctl/docs/images/command_hashes.txt @@ -9,7 +9,7 @@ dagrun:c32e0011aa9a845456c778786717208e jobs:a5b644c5da8889443bb40ee10b599270 pools:19efe105b9515ab1926ebcaf0e028d71 providers:34502fe09dc0b8b0a13e7e46efdffda6 -variables:f8fc76d3d398b2780f4e97f7cd816646 +variables:68cf6c7b27960c35e5e96895053a349f version:31f4efdf8de0dbaaa4fac71ff7efecc3 plugins:4864fd8f356704bd2b3cd1aec3567e35 auth login:9fe2bb1dd5c602beea2eefb33a2b20a8 diff --git a/airflow-ctl/docs/images/output_variables.svg b/airflow-ctl/docs/images/output_variables.svg index a8833a923899d..91dfe2863342b 100644 --- a/airflow-ctl/docs/images/output_variables.svg +++ b/airflow-ctl/docs/images/output_variables.svg @@ -1,4 +1,4 @@ - + - - + + - + - + - + - + - + - + - + - + - + - + - + - + - + - + + + + + + + - + - + - - Usage:airflowctl variables [-hCOMMAND... - -Perform Variables operations - -Positional Arguments: -COMMAND -createCreate a new variable -deleteDelete a variable by its key -getRetrieve a variable by its key -importImport variables from a file exported with local CLI. -listList all variables -updateUpdate an existing variable - -Options: --h--helpshow this help message and exit + + Usage:airflowctl variables [-hCOMMAND... + +Perform Variables operations + +Positional Arguments: +COMMAND +createCreate a new variable +deleteDelete a variable by its key +getRetrieve a variable by its key +importImport variables from a file exported with local CLI. +listList all variables +setSet a variable, creating it if it does not exist and updating +it otherwise. +updateUpdate an existing variable + +Options: +-h--helpshow this help message and exit diff --git a/airflow-ctl/src/airflowctl/ctl/cli_config.py b/airflow-ctl/src/airflowctl/ctl/cli_config.py index 11ff4542e01ef..83210a3873c5f 100755 --- a/airflow-ctl/src/airflowctl/ctl/cli_config.py +++ b/airflow-ctl/src/airflowctl/ctl/cli_config.py @@ -276,6 +276,21 @@ def _load_help_texts_yaml() -> dict[str, dict[str, str]]: choices=("overwrite", "fail", "skip"), ) +# Variable command args +ARG_VAR_KEY = Arg(flags=("key",), type=str, help="Variable key") +ARG_VAR_VALUE = Arg(flags=("value",), metavar="VALUE", type=str, help="Variable value") +ARG_VAR_DESCRIPTION = Arg( + flags=("--description",), + type=str, + default=None, + help="Variable description, optional when setting a variable", +) +ARG_VAR_SERIALIZE_JSON = Arg( + flags=("-j", "--serialize-json"), + action="store_true", + help="Serialize JSON variable", +) + # Config arguments ARG_CONFIG_SECTION = Arg( flags=("--section",), @@ -1007,6 +1022,12 @@ def merge_commands( ) VARIABLE_COMMANDS = ( + ActionCommand( + name="set", + help="Set a variable, creating it if it does not exist and updating it otherwise.", + func=lazy_load_command("airflowctl.ctl.commands.variable_command.set_"), + args=(ARG_VAR_KEY, ARG_VAR_VALUE, ARG_VAR_DESCRIPTION, ARG_VAR_SERIALIZE_JSON), + ), ActionCommand( name="import", help="Import variables from a file exported with local CLI.", diff --git a/airflow-ctl/src/airflowctl/ctl/commands/variable_command.py b/airflow-ctl/src/airflowctl/ctl/commands/variable_command.py index 19321002e4442..9291c6cd66cb8 100644 --- a/airflow-ctl/src/airflowctl/ctl/commands/variable_command.py +++ b/airflow-ctl/src/airflowctl/ctl/commands/variable_command.py @@ -30,12 +30,34 @@ BulkCreateActionVariableBody, VariableBody, ) +from airflowctl.api.operations import ServerResponseError def _print_file_error(message: str, file_path: str) -> None: Console().print(f"[red]{message}: {file_path}", soft_wrap=True) +@provide_api_client(kind=ClientKind.CLI) +def set_(args, api_client=NEW_API_CLIENT) -> None: + """Set a variable, creating it if it does not exist and updating it otherwise.""" + value = args.value + if args.serialize_json: + value = json.dumps(value) + variable_body = VariableBody(key=args.key, value=value, description=args.description) + + try: + api_client.variables.get(variable_key=args.key) + except ServerResponseError as e: + if e.response.status_code == 404: + api_client.variables.create(variable=variable_body) + rich.print(f"[green]Variable {args.key} created[/green]") + return + raise + + api_client.variables.update(variable=variable_body) + rich.print(f"[green]Variable {args.key} updated[/green]") + + @provide_api_client(kind=ClientKind.CLI) def import_(args, api_client=NEW_API_CLIENT) -> list[str]: """Import variables from a given file.""" diff --git a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_variable_command.py b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_variable_command.py index f573585935fd4..7267cf441e5ae 100644 --- a/airflow-ctl/tests/airflow_ctl/ctl/commands/test_variable_command.py +++ b/airflow-ctl/tests/airflow_ctl/ctl/commands/test_variable_command.py @@ -18,7 +18,9 @@ import json from types import SimpleNamespace +from unittest import mock +import httpx import pytest from airflowctl.api.client import ClientKind @@ -28,10 +30,17 @@ VariableCollectionResponse, VariableResponse, ) +from airflowctl.api.operations import ServerResponseError from airflowctl.ctl import cli_parser from airflowctl.ctl.commands import variable_command +def _server_error(status_code: int) -> ServerResponseError: + request = httpx.Request("GET", "http://testserver/api/v2/variables/key") + response = httpx.Response(status_code, request=request, json={"detail": "boom"}) + return ServerResponseError(message="boom", request=request, response=response) + + class TestCliVariableCommands: key = "key" value = "value" @@ -63,6 +72,75 @@ class TestCliVariableCommands: delete=None, ) + def test_set_creates_when_missing(self): + """When the key does not exist, ``set`` falls back to creating it.""" + api_client = mock.MagicMock() + api_client.variables.get.side_effect = _server_error(404) + + variable_command.set_( + self.parser.parse_args(["variables", "set", "new_key", "new_value"]), + api_client=api_client, + ) + + api_client.variables.create.assert_called_once() + api_client.variables.update.assert_not_called() + body = api_client.variables.create.call_args.kwargs["variable"] + assert body.key == "new_key" + assert body.value.root == "new_value" + + def test_set_updates_when_exists(self): + """When the key already exists, ``set`` updates it instead of creating.""" + api_client = mock.MagicMock() + api_client.variables.get.return_value = self.variable_collection_response.variables[0] + + variable_command.set_( + self.parser.parse_args(["variables", "set", self.key, "updated_value"]), + api_client=api_client, + ) + + api_client.variables.update.assert_called_once() + api_client.variables.create.assert_not_called() + body = api_client.variables.update.call_args.kwargs["variable"] + assert body.key == self.key + assert body.value.root == "updated_value" + + def test_set_serialize_json(self): + """``--serialize-json`` JSON-encodes the value before sending it.""" + api_client = mock.MagicMock() + api_client.variables.get.side_effect = _server_error(404) + + variable_command.set_( + self.parser.parse_args(["variables", "set", "json_key", '{"a": 1}', "--serialize-json"]), + api_client=api_client, + ) + + body = api_client.variables.create.call_args.kwargs["variable"] + assert body.value.root == json.dumps('{"a": 1}') + + def test_set_forwards_description(self): + """``--description`` is forwarded to the created variable.""" + api_client = mock.MagicMock() + api_client.variables.get.side_effect = _server_error(404) + + variable_command.set_( + self.parser.parse_args(["variables", "set", "key", "value", "--description", "a description"]), + api_client=api_client, + ) + + body = api_client.variables.create.call_args.kwargs["variable"] + assert body.description == "a description" + + def test_set_reraises_non_404_error(self): + """Errors other than 404 from the existence check propagate.""" + api_client = mock.MagicMock() + api_client.variables.get.side_effect = _server_error(500) + + with pytest.raises(ServerResponseError): + variable_command.set_( + self.parser.parse_args(["variables", "set", "key", "value"]), + api_client=api_client, + ) + def test_import_success(self, api_client_maker, tmp_path, monkeypatch): api_client = api_client_maker( path="/api/v2/variables",