diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml
index 55313198d5b39..d28d39cbb2039 100644
--- a/airflow-core/src/airflow/config_templates/config.yml
+++ b/airflow-core/src/airflow/config_templates/config.yml
@@ -517,6 +517,17 @@ core:
       type: boolean
       example: ~
       default: "False"
+    max_dag_version_cache_size:
+      description: |
+        Maximum number of DAG versions to keep in the in-memory cache used by DBDagBag.
+        When the limit is reached the least-recently-used version is evicted. Increase this
+        value if you have many concurrently active DAG versions and can afford the memory;
+        decrease it to reduce the memory footprint of long-running API server or scheduler
+        processes.
+      version_added: 3.2.0
+      type: integer
+      example: ~
+      default: "512"
 database:
   description: ~
   options:
diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py
index 98799bbde0c0c..964052c09539f 100644
--- a/airflow-core/src/airflow/models/dagbag.py
+++ b/airflow-core/src/airflow/models/dagbag.py
@@ -18,12 +18,14 @@
 from __future__ import annotations
 
 import hashlib
+from collections import OrderedDict
 from typing import TYPE_CHECKING, Any
 from uuid import UUID
 
 from sqlalchemy import String, select
 from sqlalchemy.orm import Mapped, joinedload, mapped_column
 
+from airflow.configuration import conf
 from airflow.models.base import Base, StringID
 from airflow.models.dag_version import DagVersion
 
@@ -45,13 +47,23 @@ class DBDagBag:
     """
 
     def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[UUID, SerializedDagModel] = {}  # dag_version_id to dag
+        # Upper bound on cached DAG versions; read once when the bag is constructed.
+        self._max_dag_version_cache_size: int = conf.getint(
+            "core", "max_dag_version_cache_size", fallback=512
+        )
+        # dag_version_id -> model, ordered from least- to most-recently used.
+        self._dags: OrderedDict[UUID, SerializedDagModel] = OrderedDict()
         self.load_op_links = load_op_links
 
     def _read_dag(self, serialized_dag_model: SerializedDagModel) -> SerializedDAG | None:
         serialized_dag_model.load_op_links = self.load_op_links
         if dag := serialized_dag_model.dag:
-            self._dags[serialized_dag_model.dag_version_id] = serialized_dag_model
+            version_id = serialized_dag_model.dag_version_id
+            self._dags[version_id] = serialized_dag_model
+            # Overwriting an existing key keeps its old position, so promote it explicitly.
+            self._dags.move_to_end(version_id)
+            if len(self._dags) > self._max_dag_version_cache_size:
+                self._dags.popitem(last=False)  # evict LRU entry
         return dag
 
     def get_serialized_dag_model(self, version_id: UUID, session: Session) -> SerializedDagModel | None:
@@ -77,6 +89,13 @@ def get_serialized_dag_model(self, version_id: UUID, session: Session) -> Serial
             if not dag_version or not (serialized_dag_model := dag_version.serialized_dag):
                 return None
             self._read_dag(serialized_dag_model)
+        else:
+            try:
+                self._dags.move_to_end(version_id)  # promote to MRU on cache hit
+            except KeyError:
+                # Evicted by a concurrent insert between the lookup and the promotion;
+                # the model already in hand is still valid, so just skip the promotion.
+                pass
         return serialized_dag_model
 
     def get_dag(self, version_id: UUID, session: Session) -> SerializedDAG | None:
diff --git a/airflow-core/tests/unit/api_fastapi/common/test_dagbag.py b/airflow-core/tests/unit/api_fastapi/common/test_dagbag.py
index 27f34064e5f77..85c915e083a1c 100644
--- a/airflow-core/tests/unit/api_fastapi/common/test_dagbag.py
+++ b/airflow-core/tests/unit/api_fastapi/common/test_dagbag.py
@@ -16,11 +16,14 @@
 # under the License.
 from __future__ import annotations
 
+from collections import OrderedDict
 from unittest import mock
+from uuid import uuid4
 
 import pytest
 
 from airflow.api_fastapi.app import purge_cached_app
+from airflow.models.dagbag import DBDagBag
 from airflow.sdk import BaseOperator
 
 from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
@@ -82,3 +85,78 @@ def test_dagbag_used_as_singleton_in_dependency(self, session, dag_maker, test_c
         assert resp2.status_code == 200
 
         assert self.dagbag_call_counter["count"] == 1
+
+
+class TestDBDagBagLRUCache:
+    """Tests for the bounded LRU eviction behaviour of DBDagBag._dags."""
+
+    def _make_bag(self, max_size: int) -> DBDagBag:
+        bag = DBDagBag.__new__(DBDagBag)
+        bag.load_op_links = True
+        bag._max_dag_version_cache_size = max_size
+        bag._dags = OrderedDict()
+        return bag
+
+    def _make_model(self, version_id):
+        m = mock.MagicMock()
+        m.dag_version_id = version_id
+        m.dag = mock.MagicMock()  # truthy — deserialization succeeds
+        return m
+
+    def test_cache_bounded_by_max_size(self):
+        """Inserting beyond max_size evicts the least-recently-used entry."""
+        bag = self._make_bag(max_size=3)
+        ids = [uuid4() for _ in range(4)]
+        for uid in ids:
+            bag._read_dag(self._make_model(uid))
+
+        assert len(bag._dags) == 3
+        assert ids[0] not in bag._dags  # first inserted → LRU → evicted
+        assert ids[3] in bag._dags
+
+    def test_cache_hit_promotes_to_mru(self):
+        """A cache hit via get_serialized_dag_model promotes the entry to MRU."""
+        bag = self._make_bag(max_size=3)
+        ids = [uuid4() for _ in range(3)]
+        models = {uid: self._make_model(uid) for uid in ids}
+        for uid in ids:
+            bag._read_dag(models[uid])
+
+        # Re-access ids[0] through get_serialized_dag_model to promote it
+        session = mock.MagicMock()
+        bag.get_serialized_dag_model(ids[0], session=session)
+        session.get.assert_not_called()  # should be a pure cache hit
+
+        # Insert a 4th entry — ids[1] (now LRU) should be evicted, not ids[0]
+        bag._read_dag(self._make_model(uuid4()))
+
+        assert ids[0] in bag._dags  # promoted to MRU, survives
+        assert ids[1] not in bag._dags  # was LRU after ids[0] promoted, evicted
+
+    def test_failed_deserialization_not_cached(self):
+        """Entries whose .dag property is falsy are not inserted into the cache."""
+        bag = self._make_bag(max_size=10)
+        m = mock.MagicMock()
+        m.dag_version_id = uuid4()
+        m.dag = None  # deserialization failure
+
+        bag._read_dag(m)
+
+        assert len(bag._dags) == 0
+
+    def test_cache_hit_survives_concurrent_eviction(self):
+        """A hit still returns the model when the entry is evicted before it can be promoted."""
+
+        class EvictedBeforePromotion(OrderedDict):
+            def move_to_end(self, key, last=True):
+                del self[key]  # what an eviction from another thread would have done
+                super().move_to_end(key, last)  # now raises KeyError, as in the real race
+
+        bag = self._make_bag(max_size=3)
+        bag._dags = EvictedBeforePromotion()
+        version_id = uuid4()
+        model = self._make_model(version_id)
+        bag._dags[version_id] = model
+
+        assert bag.get_serialized_dag_model(version_id, session=mock.MagicMock()) is model
+        assert version_id not in bag._dags