Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion airflow-core/src/airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ class DagAttributeTypes(str, Enum):
TIMEDELTA = "timedelta"
TIMEZONE = "timezone"
RELATIVEDELTA = "relativedelta"
BASE_TRIGGER = "base_trigger"
AIRFLOW_EXC_SER = "airflow_exc_ser"
BASE_EXC_SER = "base_exc_ser"
DICT = "dict"
Expand Down
17 changes: 2 additions & 15 deletions airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
validate_and_load_priority_weight_strategy,
)
from airflow.timetables.base import DagRunInfo, Timetable
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.code_utils import get_python_source
from airflow.utils.db import LazySelectSequence

Expand Down Expand Up @@ -470,7 +470,6 @@ def serialize(
:meta private:
"""
from airflow.sdk.definitions._internal.types import is_arg_set
from airflow.sdk.exceptions import TaskDeferred

if not is_arg_set(var):
return cls._encode(None, type_=DAT.ARG_NOT_SET)
Expand Down Expand Up @@ -535,7 +534,7 @@ def serialize(
var._asdict(),
type_=DAT.TASK_INSTANCE_KEY,
)
elif isinstance(var, (AirflowException, TaskDeferred)) and hasattr(var, "serialize"):
elif isinstance(var, AirflowException) and hasattr(var, "serialize"):
exc_cls_name, args, kwargs = var.serialize()
return cls._encode(
cls.serialize(
Expand All @@ -556,14 +555,6 @@ def serialize(
),
type_=DAT.BASE_EXC_SER,
)
elif isinstance(var, BaseTrigger):
return cls._encode(
cls.serialize(
var.serialize(),
strict=strict,
),
type_=DAT.BASE_TRIGGER,
)
elif callable(var):
return str(get_python_source(var))
elif isinstance(var, set):
Expand Down Expand Up @@ -672,10 +663,6 @@ def deserialize(cls, encoded_var: Any) -> Any:
else:
exc_cls = import_string(f"builtins.{exc_cls_name}")
return exc_cls(*args, **kwargs)
elif type_ == DAT.BASE_TRIGGER:
tr_cls_name, kwargs = cls.deserialize(var)
tr_cls = import_string(tr_cls_name)
return tr_cls(**kwargs)
elif type_ == DAT.SET:
return {cls.deserialize(v) for v in var}
elif type_ == DAT.TUPLE:
Expand Down
32 changes: 1 addition & 31 deletions airflow-core/tests/unit/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
AirflowFailException,
AirflowRescheduleException,
SerializationError,
TaskDeferred,
)
from airflow.models.connection import Connection
from airflow.models.dag import DAG
Expand Down Expand Up @@ -104,7 +103,6 @@
LazyDeserializedDAG,
_has_kubernetes,
)
from airflow.triggers.base import BaseTrigger
from airflow.utils.db import LazySelectSequence

from unit.models import DEFAULT_DATE
Expand Down Expand Up @@ -570,42 +568,14 @@ def test_ser_of_asset_event_accessor():
assert d[Asset(name="yo", uri="test://yo")].extra == {"this": "that", "the": "other"}


class MyTrigger(BaseTrigger):
def __init__(self, hi):
self.hi = hi

def serialize(self):
return "unit.serialization.test_serialized_objects.MyTrigger", {"hi": self.hi}

async def run(self):
yield


def test_roundtrip_exceptions():
"""
This is for AIP-44 when we need to send certain non-error exceptions
as part of an RPC call e.g. TaskDeferred or AirflowRescheduleException.
"""
"""Non-error AirflowExceptions (e.g. AirflowRescheduleException) round-trip through BaseSerialization."""
some_date = pendulum.now()
resched_exc = AirflowRescheduleException(reschedule_date=some_date)
ser = BaseSerialization.serialize(resched_exc)
deser = BaseSerialization.deserialize(ser)
assert isinstance(deser, AirflowRescheduleException)
assert deser.reschedule_date == some_date
del ser
del deser
exc = TaskDeferred(
trigger=MyTrigger(hi="yo"),
method_name="meth_name",
kwargs={"have": "pie"},
timeout=timedelta(seconds=30),
)
ser = BaseSerialization.serialize(exc)
deser = BaseSerialization.deserialize(ser)
assert deser.trigger.hi == "yo"
assert deser.method_name == "meth_name"
assert deser.kwargs == {"have": "pie"}
assert deser.timeout == timedelta(seconds=30)


@pytest.mark.parametrize(
Expand Down
Loading