Skip to content

Commit 6628c9d

Browse files
committed
fix task group serialize logics
1 parent 34b2fd5 commit 6628c9d

7 files changed

Lines changed: 15 additions & 30 deletions

File tree

airflow-core/src/airflow/serialization/definitions/baseoperator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,8 @@ def __getattr__(self, name):
326326
# For regular attributes, use task_type in the error message
327327
raise AttributeError(f"'{self.task_type}' object has no attribute '{name}'")
328328

329-
def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
330-
return DagAttributeTypes.OP, self.task_id
329+
def serialize_for_task_group(self) -> list:
330+
return [DagAttributeTypes.OP.value, self.task_id]
331331

332332
def expand_start_from_trigger(self, *, context: Context) -> bool:
333333
"""

airflow-core/src/airflow/serialization/definitions/mappedoperator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,9 +413,9 @@ def get_extra_links(self, ti: TaskInstance, name: str) -> str | None:
413413
return None
414414
return link.get_link(self, ti_key=ti.key)
415415

416-
def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
416+
def serialize_for_task_group(self) -> list:
417417
"""Implement DAGNode."""
418-
return DagAttributeTypes.OP, self.task_id
418+
return [DagAttributeTypes.OP.value, self.task_id]
419419

420420
# TODO (GH-52141): Copied from sdk. Find a better place for this to live in.
421421
def _get_specified_expand_input(self) -> SchedulerExpandInput:

airflow-core/src/airflow/serialization/serialized_objects.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2092,7 +2092,7 @@ def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None:
20922092
"ui_color": task_group.ui_color,
20932093
"ui_fgcolor": task_group.ui_fgcolor,
20942094
"children": {
2095-
label: child.match_serialize_task_group_form() for label, child in task_group.children.items()
2095+
label: child.serialize_for_task_group() for label, child in task_group.children.items()
20962096
},
20972097
"upstream_group_ids": cls.serialize(sorted(task_group.upstream_group_ids)),
20982098
"downstream_group_ids": cls.serialize(sorted(task_group.downstream_group_ids)),

task-sdk/src/airflow/sdk/bases/operator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def db_safe_priority(priority_weight: int) -> int:
9595
from airflow.sdk.definitions.operator_resources import Resources
9696
from airflow.sdk.definitions.taskgroup import TaskGroup
9797
from airflow.sdk.definitions.xcom_arg import XComArg
98-
from airflow.serialization.enums import DagAttributeTypes
9998
from airflow.task.priority_strategy import PriorityWeightStrategy
10099
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
101100

@@ -1552,11 +1551,11 @@ def prepare_for_execution(self) -> Self:
15521551
other._lock_for_execution = True
15531552
return other
15541553

1555-
def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
1554+
def serialize_for_task_group(self) -> list:
15561555
"""Serialize; required by DAGNode."""
15571556
from airflow.serialization.enums import DagAttributeTypes
15581557

1559-
return DagAttributeTypes.OP, self.task_id
1558+
return [DagAttributeTypes.OP.value, self.task_id]
15601559

15611560
def unmap(self, resolve: None | Mapping[str, Any]) -> Self:
15621561
"""

task-sdk/src/airflow/sdk/definitions/_internal/node.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from abc import ABCMeta, abstractmethod
2222
from collections.abc import Sequence
2323
from datetime import datetime
24-
from typing import TYPE_CHECKING, Any
24+
from typing import TYPE_CHECKING
2525

2626
from airflow.sdk._shared.dagnode.node import GenericDAGNode
2727
from airflow.sdk.definitions._internal.mixins import DependencyMixin
@@ -31,7 +31,6 @@
3131
from airflow.sdk.definitions.edges import EdgeModifier
3232
from airflow.sdk.definitions.taskgroup import TaskGroup # noqa: F401
3333
from airflow.sdk.types import Operator # noqa: F401
34-
from airflow.serialization.enums import DagAttributeTypes
3534

3635

3736
KEY_REGEX = re.compile(r"^[\w.-]+$")
@@ -157,18 +156,6 @@ def set_upstream(
157156
"""Set a node (or nodes) to be directly upstream from the current node."""
158157
self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)
159158

160-
def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
159+
def serialize_for_task_group(self) -> list:
161160
"""Serialize a task group's content; used by TaskGroupSerialization."""
162161
raise NotImplementedError()
163-
164-
def match_serialize_task_group_form(self) -> list:
165-
"""
166-
Match the serialized task_group format modified during inter-process communication.
167-
168-
The serialized task_group from dag-process gets modified during inter-process communication.
169-
(<DagAttributeTypes.OP: 'operator'>, 'task_id') -> ['operator', 'task_id']
170-
171-
This function aligns the values to match the modified state after IPC.
172-
"""
173-
serialized_task_group = self.serialize_for_task_group()
174-
return [serialized_task_group[0].value, serialized_task_group[1]]

task-sdk/src/airflow/sdk/definitions/mappedoperator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -713,9 +713,9 @@ def output(self) -> XComArg:
713713

714714
return XComArg(operator=self)
715715

716-
def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
716+
def serialize_for_task_group(self) -> list:
717717
"""Implement DAGNode."""
718-
return DagAttributeTypes.OP, self.task_id
718+
return [DagAttributeTypes.OP.value, self.task_id]
719719

720720
def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
721721
"""

task-sdk/src/airflow/sdk/definitions/taskgroup.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
from airflow.sdk.definitions.dag import DAG
4444
from airflow.sdk.definitions.edges import EdgeModifier
4545
from airflow.sdk.types import Operator
46-
from airflow.serialization.enums import DagAttributeTypes
4746

4847

4948
def _default_parent_group() -> TaskGroup | None:
@@ -491,15 +490,15 @@ def get_child_by_label(self, label: str) -> DAGNode:
491490
"""Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)."""
492491
return self.children[self.child_id(label)]
493492

494-
def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
493+
def serialize_for_task_group(self) -> list:
495494
"""Serialize task group; required by DagNode."""
496495
from airflow.serialization.enums import DagAttributeTypes
497496
from airflow.serialization.serialized_objects import TaskGroupSerialization
498497

499-
return (
500-
DagAttributeTypes.TASK_GROUP,
498+
return [
499+
DagAttributeTypes.TASK_GROUP.value,
501500
TaskGroupSerialization.serialize_task_group(self),
502-
)
501+
]
503502

504503
def hierarchical_alphabetical_sort(self):
505504
"""

0 commit comments

Comments
 (0)