Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,33 @@ def _decode_priority_weight_strategy(var: str) -> PriorityWeightStrategy:
return priority_weight_strategy_class()


# Module prefixes whose code is trusted to import while deserializing a stored DAG.
# The class path from the serialized blob is validated against these *before*
# import_string runs, so a path outside the trusted namespaces is rejected without
# importing it. Mirrors decode_timetable's prefix gate / the priority-weight registry.
_TRUSTED_DESERIALIZE_PREFIXES = ("airflow.",)


def _safe_import_for_deserialize(cls_name: str, base: type, *, allow_builtins: bool = False) -> type:
"""
Resolve ``cls_name`` to a subclass of ``base``, validating the name before importing.

The check runs on the string before any import, so a class path outside the
trusted namespaces is never imported. A post-import ``issubclass`` check is
kept as a second gate.
"""
module_path = cls_name.rpartition(".")[0]
trusted = cls_name.startswith(_TRUSTED_DESERIALIZE_PREFIXES) or (
allow_builtins and module_path == "builtins"
)
if not trusted:
raise ValueError(f"Refusing to deserialize disallowed class path {cls_name!r}")
cls = import_string(cls_name)
if not (isinstance(cls, type) and issubclass(cls, base)):
raise ValueError(f"{cls_name!r} is not a {base.__name__} subclass")
return cls


def _encode_start_trigger_args(var: StartTriggerArgs) -> dict[str, Any]:
"""Encode a StartTriggerArgs."""

Expand Down Expand Up @@ -674,7 +701,7 @@ def deserialize(cls, encoded_var: Any) -> Any:
return exc_cls(*args, **kwargs)
elif type_ == DAT.BASE_TRIGGER:
tr_cls_name, kwargs = cls.deserialize(var)
tr_cls = import_string(tr_cls_name)
tr_cls = _safe_import_for_deserialize(tr_cls_name, BaseTrigger)
return tr_cls(**kwargs)
elif type_ == DAT.SET:
return {cls.deserialize(v) for v in var}
Expand Down
24 changes: 24 additions & 0 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2786,6 +2786,30 @@ def test_create_dagrun_accepts_partition_key_for_partition_at_runtime_dag(self,
dr = dag_maker.create_dagrun(partition_key="runtime-key")
assert dr.partition_key == "runtime-key"

def test_base_trigger_deserialization_rejects_disallowed_class_path(self):
"""A BASE_TRIGGER class path outside the trusted namespaces is rejected before import."""
from airflow.serialization.enums import DagAttributeTypes

# subprocess.run is importable; asserting the pre-import message (not the
# subclass message) proves the import is never attempted.
encoded = BaseSerialization._encode(
BaseSerialization.serialize(["subprocess.run", {"args": ["true"]}]),
type_=DagAttributeTypes.BASE_TRIGGER,
)
with pytest.raises(ValueError, match="Refusing to deserialize disallowed class path"):
BaseSerialization.deserialize(encoded)

def test_base_trigger_deserialization_rejects_non_trigger_class(self):
"""A trusted-namespace class that is not a BaseTrigger subclass is still rejected."""
from airflow.serialization.enums import DagAttributeTypes

encoded = BaseSerialization._encode(
BaseSerialization.serialize(["airflow.models.dag.DAG", {}]),
type_=DagAttributeTypes.BASE_TRIGGER,
)
with pytest.raises(ValueError, match="not a BaseTrigger subclass"):
BaseSerialization.deserialize(encoded)


def test_kubernetes_optional():
"""Test that serialization module loads without kubernetes, but deserialization of PODs requires it"""
Expand Down
Loading