Skip to content

Commit eac9478

Browse files
author
Sameer Mesiah
committed
Fix upstream map index resolution after placeholder expansion with unit test.
1 parent a9a9c5a commit eac9478

File tree

2 files changed

+145
-0
lines changed

2 files changed

+145
-0
lines changed

airflow-core/src/airflow/models/taskinstance.py

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2186,6 +2186,14 @@ def tg2(inp):
21862186
# and "ti_count == ancestor_ti_count" does not work, since the further
21872187
# expansion may be of length 1.
21882188
if not _is_further_mapped_inside(relative, common_ancestor):
2189+
placeholder_index = resolve_placeholder_map_index(
2190+
task=task, relative=relative, map_index=ancestor_map_index, run_id=run_id, session=session
2191+
)
2192+
# Handle cases where an upstream mapped placeholder (map_index = -1) has already
2193+
# been expanded and replaced by its successor (map_index = 0) at evaluation time.
2194+
if placeholder_index is not None:
2195+
return placeholder_index
2196+
21892197
return ancestor_map_index
21902198

21912199
# Otherwise we need a partial aggregation for values from selected task
@@ -2260,6 +2268,54 @@ def _visit_relevant_relatives_for_mapped(mapped_tasks: Iterable[tuple[str, int]]
22602268
return visited
22612269

22622270

2271+
def resolve_placeholder_map_index(
    *,
    task: Operator,
    relative: Operator,
    map_index: int,
    run_id: str,
    session: Session,
) -> int | None:
    """
    Resolve the correct map_index for upstream dependency evaluation.

    This handles the transition from map_index = -1 (pre-expansion placeholder)
    to map_index = 0 (post-expansion placeholder successor): the current task
    may still be represented by its placeholder while the upstream relative's
    placeholder has already been replaced by its first expanded instance.

    :param task: The task whose upstream dependency is being evaluated.
    :param relative: The upstream task whose task instances are inspected.
    :param map_index: The map index computed so far; anything other than -1
        is never overridden.
    :param run_id: Run whose task instances are looked up.
    :param session: Session used to query the task instances.
    :return: 0 if the upstream placeholder has transitioned from -1 to 0
        while the current task is still a placeholder, otherwise None
        (meaning no override should be applied).
    """
    # Only the placeholder index can be stale; skip the query for real indexes.
    if map_index != -1:
        return None

    rows = session.execute(
        select(TaskInstance.task_id, TaskInstance.map_index).where(
            TaskInstance.dag_id == relative.dag_id,
            TaskInstance.run_id == run_id,
            TaskInstance.task_id.in_([task.task_id, relative.task_id]),
            TaskInstance.map_index.in_([-1, 0]),
        )
    ).all()

    # Set of (task_id, map_index) pairs that currently exist for this run.
    existing = {(task_id, mi) for task_id, mi in rows}

    # We only rewrite when:
    #   1) the current task is still using the placeholder (-1)
    #   2) the upstream placeholder (-1) no longer exists
    #   3) the post-expansion placeholder (0) does exist
    if (
        (task.task_id, -1) in existing
        and (relative.task_id, -1) not in existing
        and (relative.task_id, 0) in existing
    ):
        return 0

    return None
2317+
2318+
22632319
class TaskInstanceNote(Base):
22642320
"""For storage of arbitrary notes concerning the task instance."""
22652321

airflow-core/tests/unit/models/test_taskinstance.py

Lines changed: 89 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3102,6 +3102,95 @@ def g(v):
31023102
assert result == expected
31033103

31043104

3105+
def test_downstream_placeholder_handles_upstream_post_expansion(dag_maker, session):
    """
    Check upstream map index resolution after an upstream placeholder expands.

    The upstream mapped task's placeholder (map_index = -1) is replaced by its
    first expanded instance (map_index = 0). The relevant upstream map index is
    then resolved twice: once from the downstream placeholder, and once from
    the first expanded downstream task instance.
    """
    with dag_maker(session=session) as dag:

        @task
        def get_mapping_source():
            return ["one", "two", "three"]

        @task
        def mapped_task(x):
            return f"{x}"

        @task_group(prefix_group_id=False)
        def the_task_group(x):
            first = MockOperator(task_id="start")
            middle = mapped_task(x)
            # The last task in the group is a plain operator with no mapping source of its own.
            last = MockOperator(task_id="downstream")

            first >> middle >> last

        source = get_mapping_source()
        source >> the_task_group.expand(x=source)

    # Create the DAG run and execute the task that feeds the mapping.
    dag_run = dag_maker.create_dagrun()
    dag_maker.run_ti("get_mapping_source", map_index=-1, dag_run=dag_run, session=session)

    # Expand the upstream mapped task so its placeholder is replaced.
    upstream_op = dag.get_task("mapped_task")
    _, last_index = TaskMap.expand_mapped_task(upstream_op, dag_run.run_id, session=session)
    ti_total = last_index + 1

    downstream_op = dag.get_task("downstream")

    # Case 1: the downstream task is still represented by its placeholder.
    placeholder_ti = dag_run.get_task_instance(task_id="downstream", map_index=-1, session=session)
    placeholder_ti.refresh_from_task(downstream_op)

    resolved = placeholder_ti.get_relevant_upstream_map_indexes(
        upstream=upstream_op,
        ti_count=ti_total,
        session=session,
    )

    assert resolved == 0

    # Case 2: expand the downstream task as well, to ensure existing behavior is not broken.
    _, last_index = TaskMap.expand_mapped_task(downstream_op, dag_run.run_id, session=session)
    ti_total = last_index + 1

    # The first expanded downstream instance stands in for every map_index >= 0.
    expanded_ti = dag_run.get_task_instance(task_id="downstream", map_index=0, session=session)
    expanded_ti.refresh_from_task(downstream_op)

    resolved = expanded_ti.get_relevant_upstream_map_indexes(
        upstream=upstream_op,
        ti_count=ti_total,
        session=session,
    )

    # Behavior must remain unchanged once the downstream task itself has expanded.
    assert resolved == 0
3192+
3193+
31053194
def test_find_relevant_relatives_with_non_mapped_task_as_tuple(dag_maker, session):
31063195
"""Test that specifying a non-mapped task as a tuple doesn't raise NotMapped exception."""
31073196
# t1 -> t2 (non-mapped) -> t3

0 commit comments

Comments
 (0)