Skip to content

Commit 630de66

Browse files
author
Sameer Mesiah
committed
Fix upstream map index resolution after placeholder expansion with unit test.
1 parent 8c42617 commit 630de66

2 files changed

Lines changed: 145 additions & 0 deletions

File tree

airflow-core/src/airflow/models/taskinstance.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2028,6 +2028,14 @@ def tg2(inp):
20282028
# and "ti_count == ancestor_ti_count" does not work, since the further
20292029
# expansion may be of length 1.
20302030
if not _is_further_mapped_inside(relative, common_ancestor):
2031+
placeholder_index = resolve_placeholder_map_index(
2032+
task=task, relative=relative, map_index=ancestor_map_index, run_id=run_id, session=session
2033+
)
2034+
# Handle cases where an upstream mapped placeholder (map_index = -1) has already
2035+
# been expanded and replaced by its successor (map_index = 0) at evaluation time.
2036+
if placeholder_index is not None:
2037+
return placeholder_index
2038+
20312039
return ancestor_map_index
20322040

20332041
# Otherwise we need a partial aggregation for values from selected task
@@ -2102,6 +2110,54 @@ def _visit_relevant_relatives_for_mapped(mapped_tasks: Iterable[tuple[str, int]]
21022110
return visited
21032111

21042112

2113+
def resolve_placeholder_map_index(
    *,
    task: Operator,
    relative: Operator,
    map_index: int,
    run_id: str,
    session: Session,
) -> int | None:
    """
    Decide whether an upstream placeholder map index must be rewritten to 0.

    A map_index of -1 denotes the pre-expansion placeholder; after expansion
    its successor carries map_index 0. When the current task still sits on the
    placeholder while the relative has already moved on, the upstream lookup
    has to target index 0 instead of -1.

    :param task: The task whose upstream dependency is being evaluated.
    :param relative: The upstream task whose task instances are inspected.
    :param map_index: The map index resolved so far; only -1 is ever rewritten.
    :param run_id: The DAG run to look task instances up in.
    :param session: Database session used for the lookup.
    :return: 0 if the relative's placeholder was replaced by index 0 while
        *task* still has its own placeholder row, otherwise None (meaning
        "apply no override").
    """
    # Anything other than the placeholder index is left alone, without a query.
    if map_index != -1:
        return None

    # One round trip fetches the -1/0 rows of both tasks within this run.
    query = select(TaskInstance.task_id, TaskInstance.map_index).where(
        TaskInstance.dag_id == relative.dag_id,
        TaskInstance.run_id == run_id,
        TaskInstance.task_id.in_([task.task_id, relative.task_id]),
        TaskInstance.map_index.in_([-1, 0]),
    )
    found = session.execute(query).all()

    own_indexes = {mi for tid, mi in found if tid == task.task_id}
    upstream_indexes = {mi for tid, mi in found if tid == relative.task_id}

    # The current task must still be on its placeholder (-1) ...
    if -1 not in own_indexes:
        return None
    # ... the upstream placeholder (-1) must no longer exist ...
    if -1 in upstream_indexes:
        return None
    # ... and the post-expansion successor (0) must exist.
    return 0 if 0 in upstream_indexes else None
2159+
2160+
21052161
class TaskInstanceNote(Base):
21062162
"""For storage of arbitrary notes concerning the task instance."""
21072163

airflow-core/tests/unit/models/test_taskinstance.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2981,6 +2981,95 @@ def g(v):
29812981
assert result == expected
29822982

29832983

2984+
def test_downstream_placeholder_handles_upstream_post_expansion(dag_maker, session):
    """
    Check upstream map index resolution around placeholder expansion.

    The upstream mapped task is expanded first, so its placeholder
    (map_index = -1) is expected to have been replaced by the first expanded
    task instance (map_index = 0). The relevant upstream map index is then
    resolved twice: once from the downstream placeholder (map_index = -1) and
    once from the first expanded downstream task instance (map_index = 0).
    Both lookups must yield 0.
    """
    with dag_maker(session=session) as dag:

        @task
        def get_mapping_source():
            return ["one", "two", "three"]

        @task
        def mapped_task(x):
            return f"{x}"

        @task_group(prefix_group_id=False)
        def the_task_group(x):
            first = MockOperator(task_id="start")
            middle = mapped_task(x)

            # The downstream operator has no mapping source of its own; it is
            # only a member of the mapped task group.
            last = MockOperator(task_id="downstream")

            first >> middle >> last

        source = get_mapping_source()
        source >> the_task_group.expand(x=source)

    # Create the DAG run and execute the task that feeds the expansion.
    dag_run = dag_maker.create_dagrun()
    dag_maker.run_ti("get_mapping_source", map_index=-1, dag_run=dag_run, session=session)

    # Expand the upstream mapped task explicitly.
    upstream_task = dag.get_task("mapped_task")
    _, last_index = TaskMap.expand_mapped_task(upstream_task, dag_run.run_id, session=session)
    ti_count = last_index + 1

    downstream_task = dag.get_task("downstream")

    # Case 1: the downstream TI is still the placeholder (map_index = -1).
    placeholder_ti = dag_run.get_task_instance(task_id="downstream", map_index=-1, session=session)
    placeholder_ti.refresh_from_task(downstream_task)

    resolved = placeholder_ti.get_relevant_upstream_map_indexes(
        upstream=upstream_task,
        ti_count=ti_count,
        session=session,
    )
    assert resolved == 0

    # Case 2: expand the downstream task as well and repeat the lookup from
    # its first expanded TI, to make sure existing behaviour is not broken.
    _, last_index = TaskMap.expand_mapped_task(downstream_task, dag_run.run_id, session=session)
    ti_count = last_index + 1

    # map_index = 0 is used as the representative expanded instance.
    expanded_ti = dag_run.get_task_instance(task_id="downstream", map_index=0, session=session)
    expanded_ti.refresh_from_task(downstream_task)

    resolved = expanded_ti.get_relevant_upstream_map_indexes(
        upstream=upstream_task,
        ti_count=ti_count,
        session=session,
    )

    # Behaviour must be unchanged once the downstream task itself has
    # expanded (map_index >= 0).
    assert resolved == 0
3071+
3072+
29843073
def test_find_relevant_relatives_with_non_mapped_task_as_tuple(dag_maker, session):
29853074
"""Test that specifying a non-mapped task as a tuple doesn't raise NotMapped exception."""
29863075
# t1 -> t2 (non-mapped) -> t3

0 commit comments

Comments
 (0)