diff --git a/task-sdk/src/airflow/sdk/definitions/variable.py b/task-sdk/src/airflow/sdk/definitions/variable.py index a379022a90918..4d2d3b17e8440 100644 --- a/task-sdk/src/airflow/sdk/definitions/variable.py +++ b/task-sdk/src/airflow/sdk/definitions/variable.py @@ -67,6 +67,7 @@ def set(cls, key: str, value: Any, description: str | None = None, serialize_jso return _set_variable(key, value, description, serialize_json=serialize_json) except AirflowRuntimeError as e: log.exception(e) + raise @classmethod def keys(cls, prefix: str | None = None) -> Sequence[str]: @@ -101,3 +102,4 @@ def delete(cls, key: str) -> None: _delete_variable(key=key) except AirflowRuntimeError as e: log.exception(e) + raise diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py b/task-sdk/tests/task_sdk/definitions/test_variables.py index 6e94ccf503f8c..2f9d3d9e4d6e6 100644 --- a/task-sdk/tests/task_sdk/definitions/test_variables.py +++ b/task-sdk/tests/task_sdk/definitions/test_variables.py @@ -25,7 +25,15 @@ from airflow.sdk import Variable from airflow.sdk.configuration import initialize_secrets_backends -from airflow.sdk.execution_time.comms import GetVariableKeys, PutVariable, VariableKeysResult, VariableResult +from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType +from airflow.sdk.execution_time.comms import ( + DeleteVariable, + ErrorResponse, + GetVariableKeys, + PutVariable, + VariableKeysResult, + VariableResult, +) from airflow.sdk.execution_time.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS from tests_common.test_utils.config import conf_vars @@ -89,6 +97,31 @@ def test_var_set(self, key, value, description, serialize_json, mock_supervisor_ ), ) + def test_var_set_raises_on_runtime_error(self, mock_supervisor_comms): + with mock.patch( + "airflow.sdk.execution_time.context._set_variable", + side_effect=AirflowRuntimeError( + ErrorResponse(error=ErrorType.GENERIC_ERROR, detail={"message": "API_SERVER_ERROR"}) + ), + ): + with pytest.raises(AirflowRuntimeError): + Variable.set(key="my_key", value="my_value") + + def test_var_delete(self, mock_supervisor_comms): + Variable.delete(key="my_key") + + mock_supervisor_comms.send.assert_called_once_with(msg=DeleteVariable(key="my_key")) + + def test_var_delete_raises_on_runtime_error(self, mock_supervisor_comms): + with mock.patch( + "airflow.sdk.execution_time.context._delete_variable", + side_effect=AirflowRuntimeError( + ErrorResponse(error=ErrorType.GENERIC_ERROR, detail={"message": "API_SERVER_ERROR"}) + ), + ): + with pytest.raises(AirflowRuntimeError): + Variable.delete(key="my_key") + class TestVariableKeys: @pytest.mark.parametrize(