Skip to content

Commit 529673d

Browse files
author
Sameer Mesiah
committed
Fix upstream map index resolution after placeholder expansion with unit test.
1 parent 057c73f commit 529673d

2 files changed

Lines changed: 145 additions & 0 deletions

File tree

airflow-core/src/airflow/models/taskinstance.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2287,6 +2287,14 @@ def tg2(inp):
22872287
# and "ti_count == ancestor_ti_count" does not work, since the further
22882288
# expansion may be of length 1.
22892289
if not _is_further_mapped_inside(relative, common_ancestor):
2290+
placeholder_index = resolve_placeholder_map_index(
2291+
task=task, relative=relative, map_index=ancestor_map_index, run_id=run_id, session=session
2292+
)
2293+
# Handle cases where an upstream mapped placeholder (map_index = -1) has already
2294+
# been expanded and replaced by its successor (map_index = 0) at evaluation time.
2295+
if placeholder_index is not None:
2296+
return placeholder_index
2297+
22902298
return ancestor_map_index
22912299

22922300
# Otherwise we need a partial aggregation for values from selected task
@@ -2364,6 +2372,54 @@ def _visit_relevant_relatives_for_mapped(mapped_tasks: Iterable[tuple[str, int]]
23642372
return visited
23652373

23662374

2375+
def resolve_placeholder_map_index(
2376+
*,
2377+
task: Operator,
2378+
relative: Operator,
2379+
map_index: int,
2380+
run_id: str,
2381+
session: Session,
2382+
) -> int | None:
2383+
"""
2384+
Resolve the correct map_index for upstream dependency evaluation.
2385+
2386+
This handles the transition from map_index = -1 (pre-expansion placeholder)
2387+
to map_index = 0 (post-expansion placeholder successor).
2388+
2389+
Returns:
2390+
- 0 if the placeholder has transitioned from -1 to 0
2391+
- None if no override should be applied
2392+
"""
2393+
if map_index != -1:
2394+
return None
2395+
2396+
rows = session.execute(
2397+
select(TaskInstance.task_id, TaskInstance.map_index).where(
2398+
TaskInstance.dag_id == relative.dag_id,
2399+
TaskInstance.run_id == run_id,
2400+
TaskInstance.task_id.in_([task.task_id, relative.task_id]),
2401+
TaskInstance.map_index.in_([-1, 0]),
2402+
)
2403+
).all()
2404+
2405+
task_to_map_indexes: dict[str, list[int]] = defaultdict(list)
2406+
for task_id, mi in rows:
2407+
task_to_map_indexes[task_id].append(mi)
2408+
2409+
# We only rewrite when:
2410+
# 1) the current task is still using the placeholder (-1)
2411+
# 2) the upstream placeholder (-1) no longer exists
2412+
# 3) the post-expansion placeholder (0) does exist
2413+
if (
2414+
-1 in task_to_map_indexes.get(task.task_id, [])
2415+
and -1 not in task_to_map_indexes.get(relative.task_id, [])
2416+
and 0 in task_to_map_indexes.get(relative.task_id, [])
2417+
):
2418+
return 0
2419+
2420+
return None
2421+
2422+
23672423
class TaskInstanceNote(Base):
23682424
"""For storage of arbitrary notes concerning the task instance."""
23692425

airflow-core/tests/unit/models/test_taskinstance.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3235,6 +3235,95 @@ def g(v):
32353235
assert result == expected
32363236

32373237

3238+
def test_downstream_placeholder_handles_upstream_post_expansion(dag_maker, session):
3239+
"""
3240+
Test dynamic task mapping behavior when an upstream placeholder task
3241+
(map_index = -1) has been replaced by the first expanded task
3242+
(map_index = 0).
3243+
3244+
This verifies that trigger rule evaluation correctly resolves relevant
3245+
upstream map indexes both when referencing the original placeholder
3246+
and when referencing the first expanded task instance.
3247+
"""
3248+
3249+
with dag_maker(session=session) as dag:
3250+
3251+
@task
3252+
def get_mapping_source():
3253+
return ["one", "two", "three"]
3254+
3255+
@task
3256+
def mapped_task(x):
3257+
output = f"{x}"
3258+
return output
3259+
3260+
@task_group(prefix_group_id=False)
3261+
def the_task_group(x):
3262+
start = MockOperator(task_id="start")
3263+
upstream = mapped_task(x)
3264+
3265+
# Plain downstream inside task group (no mapping source).
3266+
downstream = MockOperator(task_id="downstream")
3267+
3268+
start >> upstream >> downstream
3269+
3270+
mapping_source = get_mapping_source()
3271+
mapped_tg = the_task_group.expand(x=mapping_source)
3272+
3273+
mapping_source >> mapped_tg
3274+
3275+
# Create DAG run and execute prerequisites.
3276+
dr = dag_maker.create_dagrun()
3277+
3278+
dag_maker.run_ti("get_mapping_source", map_index=-1, dag_run=dr, session=session)
3279+
3280+
# Force expansion of the upstream mapped task.
3281+
upstream_task = dag.get_task("mapped_task")
3282+
_, max_index = TaskMap.expand_mapped_task(
3283+
upstream_task,
3284+
dr.run_id,
3285+
session=session,
3286+
)
3287+
expanded_ti_count = max_index + 1
3288+
3289+
downstream_task = dag.get_task("downstream")
3290+
3291+
# Grab the downstream placeholder TI.
3292+
downstream_ti = dr.get_task_instance(task_id="downstream", map_index=-1, session=session)
3293+
downstream_ti.refresh_from_task(downstream_task)
3294+
3295+
result = downstream_ti.get_relevant_upstream_map_indexes(
3296+
upstream=upstream_task,
3297+
ti_count=expanded_ti_count,
3298+
session=session,
3299+
)
3300+
3301+
assert result == 0
3302+
3303+
# Now do the same for downstream expanded (map_index = 0) to ensure existing behavior is not broken.
3304+
# Force expansion of the downstream mapped task.
3305+
_, max_index = TaskMap.expand_mapped_task(
3306+
downstream_task,
3307+
dr.run_id,
3308+
session=session,
3309+
)
3310+
expanded_ti_count = max_index + 1
3311+
3312+
# Grab the first expanded downstream task. Behavior is the same for all cases where map_inddex >= 0.
3313+
downstream_ti = dr.get_task_instance(task_id="downstream", map_index=0, session=session)
3314+
downstream_ti.refresh_from_task(downstream_task)
3315+
3316+
result = downstream_ti.get_relevant_upstream_map_indexes(
3317+
upstream=upstream_task,
3318+
ti_count=expanded_ti_count,
3319+
session=session,
3320+
)
3321+
3322+
# Verify behavior remains unchanged once the downstream task itself
3323+
# has expanded (map_index >= 0).
3324+
assert result == 0
3325+
3326+
32383327
def test_find_relevant_relatives_with_non_mapped_task_as_tuple(dag_maker, session):
32393328
"""Test that specifying a non-mapped task as a tuple doesn't raise NotMapped exception."""
32403329
# t1 -> t2 (non-mapped) -> t3

0 commit comments

Comments
 (0)