Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 2 additions & 13 deletions task-sdk/src/airflow/sdk/definitions/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from __future__ import annotations

import logging
from collections.abc import Sequence
from typing import Any

Expand All @@ -26,8 +25,6 @@
from airflow.sdk.definitions._internal.types import NOTSET
from airflow.sdk.log import mask_secret

log = logging.getLogger(__name__)


@attrs.define
class Variable:
Expand Down Expand Up @@ -60,13 +57,9 @@ def get(cls, key: str, default: Any = NOTSET, deserialize_json: bool = False):

@classmethod
def set(cls, key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None:
from airflow.sdk.exceptions import AirflowRuntimeError
from airflow.sdk.execution_time.context import _set_variable

try:
return _set_variable(key, value, description, serialize_json=serialize_json)
except AirflowRuntimeError as e:
log.exception(e)
_set_variable(key, value, description, serialize_json=serialize_json)

@classmethod
def keys(cls, prefix: str | None = None) -> Sequence[str]:
Expand Down Expand Up @@ -94,10 +87,6 @@ def keys(cls, prefix: str | None = None) -> Sequence[str]:

@classmethod
def delete(cls, key: str) -> None:
from airflow.sdk.exceptions import AirflowRuntimeError
from airflow.sdk.execution_time.context import _delete_variable

try:
_delete_variable(key=key)
except AirflowRuntimeError as e:
log.exception(e)
_delete_variable(key=key)
24 changes: 24 additions & 0 deletions task-sdk/tests/task_sdk/definitions/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,30 @@ def test_var_set(self, key, value, description, serialize_json, mock_supervisor_
),
)

def test_var_set_raises_on_error(self, mock_supervisor_comms):
"""Variable.set() must propagate AirflowRuntimeError so the task fails on a rejected write."""
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.execution_time.comms import ErrorResponse

mock_supervisor_comms.send.side_effect = AirflowRuntimeError(
error=ErrorResponse(error=ErrorType.API_SERVER_ERROR, detail={"message": "forbidden"})
)

with pytest.raises(AirflowRuntimeError):
Variable.set(key="forbidden_key", value="v")

def test_var_delete_raises_on_error(self, mock_supervisor_comms):
"""Variable.delete() must propagate AirflowRuntimeError so the task fails on a rejected delete."""
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.execution_time.comms import ErrorResponse

mock_supervisor_comms.send.side_effect = AirflowRuntimeError(
error=ErrorResponse(error=ErrorType.API_SERVER_ERROR, detail={"message": "forbidden"})
)

with pytest.raises(AirflowRuntimeError):
Variable.delete(key="forbidden_key")


class TestVariableKeys:
@pytest.mark.parametrize(
Expand Down