Skip to content

Commit cee1968

Browse files
author
Sameer Mesiah
committed
Fix upstream map index resolution after placeholder expansion with unit test.
1 parent 3cfe4b9 commit cee1968

2 files changed

Lines changed: 145 additions & 0 deletions

File tree

airflow-core/src/airflow/models/taskinstance.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2026,6 +2026,14 @@ def tg2(inp):
20262026
# and "ti_count == ancestor_ti_count" does not work, since the further
20272027
# expansion may be of length 1.
20282028
if not _is_further_mapped_inside(relative, common_ancestor):
2029+
placeholder_index = resolve_placeholder_map_index(
2030+
task=task, relative=relative, map_index=ancestor_map_index, run_id=run_id, session=session
2031+
)
2032+
# Handle cases where an upstream mapped placeholder (map_index = -1) has already
2033+
# been expanded and replaced by its successor (map_index = 0) at evaluation time.
2034+
if placeholder_index is not None:
2035+
return placeholder_index
2036+
20292037
return ancestor_map_index
20302038

20312039
# Otherwise we need a partial aggregation for values from selected task
@@ -2100,6 +2108,54 @@ def _visit_relevant_relatives_for_mapped(mapped_tasks: Iterable[tuple[str, int]]
21002108
return visited
21012109

21022110

2111+
def resolve_placeholder_map_index(
    *,
    task: Operator,
    relative: Operator,
    map_index: int,
    run_id: str,
    session: Session,
) -> int | None:
    """
    Resolve the correct map_index for upstream dependency evaluation.

    This handles the transition from map_index = -1 (pre-expansion placeholder)
    to map_index = 0 (post-expansion placeholder successor).

    :param task: Task whose upstream dependency is being evaluated.
    :param relative: Upstream task whose task instances are inspected.
    :param map_index: Map index computed by the caller; only ``-1`` can be
        overridden, any other value short-circuits without touching the database.
    :param run_id: Run whose task instances are looked up.
    :param session: Database session used to run the lookup query.
    :return: ``0`` if the upstream placeholder has transitioned from -1 to 0
        while *task* is still a placeholder, ``None`` if no override should
        be applied.
    """
    # Only the placeholder index can ever be rewritten; skip the query otherwise.
    if map_index != -1:
        return None

    rows = session.execute(
        select(TaskInstance.task_id, TaskInstance.map_index).where(
            TaskInstance.dag_id == relative.dag_id,
            TaskInstance.run_id == run_id,
            TaskInstance.task_id.in_([task.task_id, relative.task_id]),
            TaskInstance.map_index.in_([-1, 0]),
        )
    ).all()

    # One set of (task_id, map_index) pairs is enough to answer the three
    # membership questions below.
    existing = {(task_id, mi) for task_id, mi in rows}

    # We only rewrite when:
    # 1) the current task is still using the placeholder (-1)
    # 2) the upstream placeholder (-1) no longer exists
    # 3) the post-expansion placeholder (0) does exist
    if (
        (task.task_id, -1) in existing
        and (relative.task_id, -1) not in existing
        and (relative.task_id, 0) in existing
    ):
        return 0

    return None
2157+
2158+
21032159
class TaskInstanceNote(Base):
21042160
"""For storage of arbitrary notes concerning the task instance."""
21052161

airflow-core/tests/unit/models/test_taskinstance.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3019,6 +3019,95 @@ def g(v):
30193019
assert result == expected
30203020

30213021

3022+
def test_downstream_placeholder_handles_upstream_post_expansion(dag_maker, session):
    """
    Test dynamic task mapping behavior when an upstream placeholder task
    (map_index = -1) has been replaced by the first expanded task
    (map_index = 0).

    This verifies that trigger rule evaluation correctly resolves relevant
    upstream map indexes both when referencing the original placeholder
    and when referencing the first expanded task instance.
    """

    with dag_maker(session=session) as dag:

        @task
        def get_mapping_source():
            return ["one", "two", "three"]

        @task
        def mapped_task(x):
            output = f"{x}"
            return output

        @task_group(prefix_group_id=False)
        def the_task_group(x):
            start = MockOperator(task_id="start")
            upstream = mapped_task(x)

            # Plain downstream inside task group (no mapping source).
            downstream = MockOperator(task_id="downstream")

            start >> upstream >> downstream

        mapping_source = get_mapping_source()
        mapped_tg = the_task_group.expand(x=mapping_source)

        mapping_source >> mapped_tg

    # Create DAG run and execute prerequisites.
    dr = dag_maker.create_dagrun()

    dag_maker.run_ti("get_mapping_source", map_index=-1, dag_run=dr, session=session)

    # Force expansion of the upstream mapped task.
    upstream_task = dag.get_task("mapped_task")
    _, max_index = TaskMap.expand_mapped_task(
        upstream_task,
        dr.run_id,
        session=session,
    )
    expanded_ti_count = max_index + 1

    downstream_task = dag.get_task("downstream")

    # Grab the downstream placeholder TI.
    downstream_ti = dr.get_task_instance(task_id="downstream", map_index=-1, session=session)
    downstream_ti.refresh_from_task(downstream_task)

    result = downstream_ti.get_relevant_upstream_map_indexes(
        upstream=upstream_task,
        ti_count=expanded_ti_count,
        session=session,
    )

    assert result == 0

    # Now do the same for downstream expanded (map_index = 0) to ensure existing behavior is not broken.
    # Force expansion of the downstream task inside the mapped task group.
    _, max_index = TaskMap.expand_mapped_task(
        downstream_task,
        dr.run_id,
        session=session,
    )
    expanded_ti_count = max_index + 1

    # Grab the first expanded downstream task. Behavior is the same for all cases where map_index >= 0.
    downstream_ti = dr.get_task_instance(task_id="downstream", map_index=0, session=session)
    downstream_ti.refresh_from_task(downstream_task)

    result = downstream_ti.get_relevant_upstream_map_indexes(
        upstream=upstream_task,
        ti_count=expanded_ti_count,
        session=session,
    )

    # Verify behavior remains unchanged once the downstream task itself
    # has expanded (map_index >= 0).
    assert result == 0
3109+
3110+
30223111
def test_find_relevant_relatives_with_non_mapped_task_as_tuple(dag_maker, session):
30233112
"""Test that specifying a non-mapped task as a tuple doesn't raise NotMapped exception."""
30243113
# t1 -> t2 (non-mapped) -> t3

0 commit comments

Comments
 (0)