Skip to content

Commit e80b6b3

Browse files
committed
chore: add unit tests
1 parent 60a9e96 commit e80b6b3

2 files changed

Lines changed: 160 additions & 0 deletions

File tree

airflow-core/tests/unit/models/test_dagrun.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2422,6 +2422,76 @@ def mapped_print_value(arg):
24222422
assert len(success_tis) == rerun_length
24232423

24242424

2425+
def test_mapped_task_length_reduction_rerun_downstream_not_deadlocked(session, dag_maker):
2426+
@task
2427+
def producer():
2428+
context = get_current_context()
2429+
if context["ti"].try_number == 0:
2430+
return [i for i in range(3)]
2431+
return [i for i in range(2)]
2432+
2433+
@task
2434+
def work(arg):
2435+
return arg
2436+
2437+
@task
2438+
def finish(data):
2439+
return sum(data)
2440+
2441+
def _task_ids(tis):
2442+
return [(ti.task_id, ti.map_index) for ti in tis]
2443+
2444+
with dag_maker(session=session):
2445+
produced = producer()
2446+
mapped = work.expand(arg=produced)
2447+
done = finish(produced)
2448+
mapped >> done
2449+
2450+
dr: DagRun = dag_maker.create_dagrun()
2451+
2452+
# First run with 3 mapped task instances.
2453+
dag_maker.run_ti("producer", dr)
2454+
decision = dr.task_instance_scheduling_decisions(session=session)
2455+
assert _task_ids(decision.schedulable_tis) == [("work", 0), ("work", 1), ("work", 2)]
2456+
2457+
for ti in decision.schedulable_tis:
2458+
dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index)
2459+
decision = dr.task_instance_scheduling_decisions(session=session)
2460+
assert _task_ids(decision.schedulable_tis) == [("finish", -1)]
2461+
dag_maker.run_ti("finish", dr)
2462+
2463+
# Clear and rerun with one fewer mapped task instance.
2464+
clear_task_instances(dr.get_task_instances(session=session), session=session)
2465+
ti = dr.get_task_instance(task_id="producer", session=session)
2466+
ti.try_number += 1
2467+
session.merge(ti)
2468+
2469+
dag_maker.run_ti("producer", dr)
2470+
decision = dr.task_instance_scheduling_decisions(session=session)
2471+
assert _task_ids(decision.schedulable_tis) == [("work", 0), ("work", 1)]
2472+
2473+
mapped_states = session.execute(
2474+
select(TI.map_index, TI.state)
2475+
.where(TI.task_id == "work", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
2476+
.order_by(TI.map_index)
2477+
).all()
2478+
assert mapped_states == [
2479+
(0, State.NONE),
2480+
(1, State.NONE),
2481+
(2, TaskInstanceState.REMOVED),
2482+
]
2483+
2484+
for ti in decision.schedulable_tis:
2485+
dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index)
2486+
decision = dr.task_instance_scheduling_decisions(session=session)
2487+
assert _task_ids(decision.schedulable_tis) == [("finish", -1)]
2488+
2489+
dag_maker.run_ti("finish", dr)
2490+
finish_ti = dr.get_task_instance(task_id="finish", map_index=-1, session=session)
2491+
assert finish_ti
2492+
assert finish_ti.state == TaskInstanceState.SUCCESS
2493+
2494+
24252495
def test_operator_mapped_task_group_receives_value(dag_maker, session):
24262496
with dag_maker(session=session):
24272497

airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,6 +1322,96 @@ def test_mapped_task_upstream_removed_with_none_failed_trigger_rules(
13221322

13231323
_test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed)
13241324

1325+
@pytest.mark.parametrize("flag_upstream_failed", [True, False])
1326+
@pytest.mark.parametrize(
1327+
("trigger_rule", "upstream_states"),
1328+
[
1329+
(
1330+
TriggerRule.ALL_SUCCESS,
1331+
_UpstreamTIStates(
1332+
success=3,
1333+
skipped=0,
1334+
failed=0,
1335+
upstream_failed=0,
1336+
removed=2,
1337+
done=5,
1338+
skipped_setup=0,
1339+
success_setup=0,
1340+
),
1341+
),
1342+
(
1343+
TriggerRule.ALL_FAILED,
1344+
_UpstreamTIStates(
1345+
success=0,
1346+
skipped=0,
1347+
failed=3,
1348+
upstream_failed=0,
1349+
removed=2,
1350+
done=5,
1351+
skipped_setup=0,
1352+
success_setup=0,
1353+
),
1354+
),
1355+
(
1356+
TriggerRule.NONE_FAILED,
1357+
_UpstreamTIStates(
1358+
success=3,
1359+
skipped=0,
1360+
failed=0,
1361+
upstream_failed=0,
1362+
removed=2,
1363+
done=5,
1364+
skipped_setup=0,
1365+
success_setup=0,
1366+
),
1367+
),
1368+
(
1369+
TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
1370+
_UpstreamTIStates(
1371+
success=3,
1372+
skipped=0,
1373+
failed=0,
1374+
upstream_failed=0,
1375+
removed=2,
1376+
done=5,
1377+
skipped_setup=0,
1378+
success_setup=0,
1379+
),
1380+
),
1381+
(
1382+
TriggerRule.ALL_DONE_MIN_ONE_SUCCESS,
1383+
_UpstreamTIStates(
1384+
success=3,
1385+
skipped=0,
1386+
failed=0,
1387+
upstream_failed=0,
1388+
removed=2,
1389+
done=5,
1390+
skipped_setup=0,
1391+
success_setup=0,
1392+
),
1393+
),
1394+
],
1395+
)
1396+
def test_non_mapped_task_ignores_removed_upstream_tis(
1397+
self,
1398+
monkeypatch,
1399+
session,
1400+
get_task_instance,
1401+
flag_upstream_failed,
1402+
trigger_rule,
1403+
upstream_states,
1404+
):
1405+
"""
1406+
Non-mapped trigger-rule checks should exclude removed upstream task instances.
1407+
"""
1408+
ti = get_task_instance(
1409+
trigger_rule,
1410+
normal_tasks=["upstream_1", "upstream_2", "upstream_3", "upstream_4", "upstream_5"],
1411+
)
1412+
monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states)
1413+
_test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed)
1414+
13251415

13261416
def test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session):
13271417
from airflow.sdk import task, task_group

0 commit comments

Comments
 (0)