Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion airflow-core/src/airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ class DagAttributeTypes(str, Enum):
TIMEDELTA = "timedelta"
TIMEZONE = "timezone"
RELATIVEDELTA = "relativedelta"
BASE_TRIGGER = "base_trigger"
AIRFLOW_EXC_SER = "airflow_exc_ser"
BASE_EXC_SER = "base_exc_ser"
DICT = "dict"
Expand Down
50 changes: 33 additions & 17 deletions airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
validate_and_load_priority_weight_strategy,
)
from airflow.timetables.base import DagRunInfo, Timetable
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.code_utils import get_python_source
from airflow.utils.db import LazySelectSequence

Expand Down Expand Up @@ -235,6 +235,33 @@ def _decode_priority_weight_strategy(var: str) -> PriorityWeightStrategy:
return priority_weight_strategy_class()


# Dotted-path prefixes of modules that are trusted to be imported while
# deserializing a stored DAG. A class path taken from the serialized blob is
# checked against these prefixes *before* import_string runs, so a path outside
# the trusted namespaces is rejected without ever being imported. This mirrors
# decode_timetable's prefix gate and the priority-weight registry.
# Kept as a tuple so it can be handed straight to str.startswith().
_TRUSTED_DESERIALIZE_PREFIXES = ("airflow.",)


def _safe_import_for_deserialize(cls_name: str, base: type, *, allow_builtins: bool = False) -> type:
"""
Resolve ``cls_name`` to a subclass of ``base``, validating the name before importing.

The check runs on the string before any import, so a class path outside the
trusted namespaces is never imported. A post-import ``issubclass`` check is
kept as a second gate.
"""
module_path = cls_name.rpartition(".")[0]
trusted = cls_name.startswith(_TRUSTED_DESERIALIZE_PREFIXES) or (
allow_builtins and module_path == "builtins"
)
if not trusted:
raise ValueError(f"Refusing to deserialize disallowed class path {cls_name!r}")
cls = import_string(cls_name)
if not (isinstance(cls, type) and issubclass(cls, base)):
raise ValueError(f"{cls_name!r} is not a {base.__name__} subclass")
return cls


def _encode_start_trigger_args(var: StartTriggerArgs) -> dict[str, Any]:
"""Encode a StartTriggerArgs."""

Expand Down Expand Up @@ -470,7 +497,6 @@ def serialize(
:meta private:
"""
from airflow.sdk.definitions._internal.types import is_arg_set
from airflow.sdk.exceptions import TaskDeferred

if not is_arg_set(var):
return cls._encode(None, type_=DAT.ARG_NOT_SET)
Expand Down Expand Up @@ -535,7 +561,7 @@ def serialize(
var._asdict(),
type_=DAT.TASK_INSTANCE_KEY,
)
elif isinstance(var, (AirflowException, TaskDeferred)) and hasattr(var, "serialize"):
elif isinstance(var, AirflowException) and hasattr(var, "serialize"):
exc_cls_name, args, kwargs = var.serialize()
return cls._encode(
cls.serialize(
Expand All @@ -556,14 +582,6 @@ def serialize(
),
type_=DAT.BASE_EXC_SER,
)
elif isinstance(var, BaseTrigger):
return cls._encode(
cls.serialize(
var.serialize(),
strict=strict,
),
type_=DAT.BASE_TRIGGER,
)
elif callable(var):
return str(get_python_source(var))
elif isinstance(var, set):
Expand Down Expand Up @@ -668,14 +686,12 @@ def deserialize(cls, encoded_var: Any) -> Any:
kwargs = deser["kwargs"]
del deser
if type_ == DAT.AIRFLOW_EXC_SER:
exc_cls = import_string(exc_cls_name)
exc_cls = _safe_import_for_deserialize(exc_cls_name, BaseException)
else:
exc_cls = import_string(f"builtins.{exc_cls_name}")
exc_cls = _safe_import_for_deserialize(
f"builtins.{exc_cls_name}", BaseException, allow_builtins=True
)
return exc_cls(*args, **kwargs)
elif type_ == DAT.BASE_TRIGGER:
tr_cls_name, kwargs = cls.deserialize(var)
tr_cls = import_string(tr_cls_name)
return tr_cls(**kwargs)
elif type_ == DAT.SET:
return {cls.deserialize(v) for v in var}
elif type_ == DAT.TUPLE:
Expand Down
36 changes: 36 additions & 0 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2786,6 +2786,42 @@ def test_create_dagrun_accepts_partition_key_for_partition_at_runtime_dag(self,
dr = dag_maker.create_dagrun(partition_key="runtime-key")
assert dr.partition_key == "runtime-key"

def test_airflow_exc_deserialization_rejects_disallowed_class_path(self):
    """Deserializing an AIRFLOW_EXC_SER payload that names an untrusted class path raises ValueError."""
    from airflow.serialization.enums import DagAttributeTypes

    # A dotted path that does not live under any trusted namespace.
    payload = {"exc_cls_name": "subprocess.check_output", "args": [], "kwargs": {}}
    serialized_payload = BaseSerialization.serialize(payload)
    blob = BaseSerialization._encode(serialized_payload, type_=DagAttributeTypes.AIRFLOW_EXC_SER)

    expected_message = "Refusing to deserialize disallowed class path"
    with pytest.raises(ValueError, match=expected_message):
        BaseSerialization.deserialize(blob)

def test_base_exc_deserialization_rejects_non_exception(self):
    """Deserializing a BASE_EXC_SER payload whose builtins name is not an exception class raises ValueError."""
    from airflow.serialization.enums import DagAttributeTypes

    # ``eval`` exists in builtins but is not a BaseException subclass.
    payload = {"exc_cls_name": "eval", "args": ["1"], "kwargs": {}}
    serialized_payload = BaseSerialization.serialize(payload)
    blob = BaseSerialization._encode(serialized_payload, type_=DagAttributeTypes.BASE_EXC_SER)

    expected_message = "not a BaseException subclass"
    with pytest.raises(ValueError, match=expected_message):
        BaseSerialization.deserialize(blob)

def test_base_exc_deserialization_roundtrips_builtin_exception(self):
    """A BASE_EXC_SER payload naming a real builtin exception is rebuilt with its args."""
    from airflow.serialization.enums import DagAttributeTypes

    payload = {"exc_cls_name": "ValueError", "args": ["boom"], "kwargs": {}}
    serialized_payload = BaseSerialization.serialize(payload)
    blob = BaseSerialization._encode(serialized_payload, type_=DagAttributeTypes.BASE_EXC_SER)

    rebuilt = BaseSerialization.deserialize(blob)

    assert isinstance(rebuilt, ValueError)
    assert rebuilt.args == ("boom",)


def test_kubernetes_optional():
"""Test that serialization module loads without kubernetes, but deserialization of PODs requires it"""
Expand Down
32 changes: 1 addition & 31 deletions airflow-core/tests/unit/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
AirflowFailException,
AirflowRescheduleException,
SerializationError,
TaskDeferred,
)
from airflow.models.connection import Connection
from airflow.models.dag import DAG
Expand Down Expand Up @@ -104,7 +103,6 @@
LazyDeserializedDAG,
_has_kubernetes,
)
from airflow.triggers.base import BaseTrigger
from airflow.utils.db import LazySelectSequence

from unit.models import DEFAULT_DATE
Expand Down Expand Up @@ -570,42 +568,14 @@ def test_ser_of_asset_event_accessor():
assert d[Asset(name="yo", uri="test://yo")].extra == {"this": "that", "the": "other"}


class MyTrigger(BaseTrigger):
    """Minimal trigger holding a single ``hi`` value, used by the round-trip test in this module."""

    def __init__(self, hi):
        self.hi = hi

    def serialize(self):
        """Return this trigger's dotted class path together with its constructor kwargs."""
        classpath = "unit.serialization.test_serialized_objects.MyTrigger"
        init_kwargs = {"hi": self.hi}
        return classpath, init_kwargs

    async def run(self):
        """Yield a single ``None``."""
        yield


def test_roundtrip_exceptions():
"""
This is for AIP-44 when we need to send certain non-error exceptions
as part of an RPC call e.g. TaskDeferred or AirflowRescheduleException.
"""
"""Non-error AirflowExceptions (e.g. AirflowRescheduleException) round-trip through BaseSerialization."""
some_date = pendulum.now()
resched_exc = AirflowRescheduleException(reschedule_date=some_date)
ser = BaseSerialization.serialize(resched_exc)
deser = BaseSerialization.deserialize(ser)
assert isinstance(deser, AirflowRescheduleException)
assert deser.reschedule_date == some_date
del ser
del deser
exc = TaskDeferred(
trigger=MyTrigger(hi="yo"),
method_name="meth_name",
kwargs={"have": "pie"},
timeout=timedelta(seconds=30),
)
ser = BaseSerialization.serialize(exc)
deser = BaseSerialization.deserialize(ser)
assert deser.trigger.hi == "yo"
assert deser.method_name == "meth_name"
assert deser.kwargs == {"have": "pie"}
assert deser.timeout == timedelta(seconds=30)


@pytest.mark.parametrize(
Expand Down
Loading