Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions airflow-core/src/airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,17 @@ core:
type: boolean
example: ~
default: "False"
max_dag_version_cache_size:
description: |
Maximum number of DAG versions to keep in the in-memory cache used by DBDagBag.
When the limit is reached the least-recently-used version is evicted. Increase this
value if you have many concurrently active DAG versions and can afford the memory;
decrease it to reduce the memory footprint of long-running API server or scheduler
processes.
version_added: 3.2.0
type: integer
example: ~
default: "512"
database:
description: ~
options:
Expand Down
15 changes: 13 additions & 2 deletions airflow-core/src/airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from __future__ import annotations

import hashlib
from collections import OrderedDict
from typing import TYPE_CHECKING, Any
from uuid import UUID

from sqlalchemy import String, select
from sqlalchemy.orm import Mapped, joinedload, mapped_column

from airflow.configuration import conf
from airflow.models.base import Base, StringID
from airflow.models.dag_version import DagVersion

Expand All @@ -45,13 +47,20 @@ class DBDagBag:
"""

def __init__(self, load_op_links: bool = True) -> None:
    """Initialize the DAG bag.

    :param load_op_links: whether operator extra links should be loaded when
        a serialized DAG is deserialized from the database.
    """
    self.load_op_links = load_op_links
    # Bounded LRU cache of serialized DAG models, keyed by dag_version_id.
    # OrderedDict insertion order doubles as recency order (oldest first).
    self._dags: OrderedDict[UUID, SerializedDagModel] = OrderedDict()
    # Cache capacity comes from [core] max_dag_version_cache_size; the
    # fallback keeps older config files without the option working.
    self._max_dag_version_cache_size: int = conf.getint(
        "core", "max_dag_version_cache_size", fallback=512
    )

def _read_dag(self, serialized_dag_model: SerializedDagModel) -> SerializedDAG | None:
serialized_dag_model.load_op_links = self.load_op_links
if dag := serialized_dag_model.dag:
self._dags[serialized_dag_model.dag_version_id] = serialized_dag_model
version_id = serialized_dag_model.dag_version_id
self._dags[version_id] = serialized_dag_model
self._dags.move_to_end(version_id)
if len(self._dags) > self._max_dag_version_cache_size:
self._dags.popitem(last=False) # evict LRU entry
return dag

def get_serialized_dag_model(self, version_id: UUID, session: Session) -> SerializedDagModel | None:
Expand All @@ -77,6 +86,8 @@ def get_serialized_dag_model(self, version_id: UUID, session: Session) -> Serial
if not dag_version or not (serialized_dag_model := dag_version.serialized_dag):
return None
self._read_dag(serialized_dag_model)
else:
self._dags.move_to_end(version_id) # promote to MRU on cache hit
return serialized_dag_model

def get_dag(self, version_id: UUID, session: Session) -> SerializedDAG | None:
Expand Down
61 changes: 61 additions & 0 deletions airflow-core/tests/unit/api_fastapi/common/test_dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
# under the License.
from __future__ import annotations

from collections import OrderedDict
from unittest import mock
from uuid import uuid4

import pytest

from airflow.api_fastapi.app import purge_cached_app
from airflow.models.dagbag import DBDagBag
from airflow.sdk import BaseOperator

from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
Expand Down Expand Up @@ -82,3 +85,61 @@ def test_dagbag_used_as_singleton_in_dependency(self, session, dag_maker, test_c
assert resp2.status_code == 200

assert self.dagbag_call_counter["count"] == 1


class TestDBDagBagLRUCache:
    """Verify the bounded LRU eviction behaviour of ``DBDagBag._dags``."""

    @staticmethod
    def _bag_with_capacity(capacity: int) -> DBDagBag:
        """Build a DBDagBag bypassing __init__ (avoids reading Airflow config)."""
        bag = DBDagBag.__new__(DBDagBag)
        bag.load_op_links = True
        bag._max_dag_version_cache_size = capacity
        bag._dags = OrderedDict()
        return bag

    @staticmethod
    def _serialized_model(version_id):
        """Mock a SerializedDagModel whose ``.dag`` deserializes successfully."""
        model = mock.MagicMock()
        model.dag_version_id = version_id
        model.dag = mock.MagicMock()  # truthy — deserialization succeeds
        return model

    def test_cache_bounded_by_max_size(self):
        """Inserting beyond max_size evicts the least-recently-used entry."""
        bag = self._bag_with_capacity(3)
        version_ids = [uuid4() for _ in range(4)]
        for version_id in version_ids:
            bag._read_dag(self._serialized_model(version_id))

        assert len(bag._dags) == 3
        # Oldest insertion is the LRU victim; the newest must survive.
        assert version_ids[0] not in bag._dags
        assert version_ids[3] in bag._dags

    def test_cache_hit_promotes_to_mru(self):
        """A cache hit via get_serialized_dag_model promotes the entry to MRU."""
        bag = self._bag_with_capacity(3)
        version_ids = [uuid4() for _ in range(3)]
        for version_id in version_ids:
            bag._read_dag(self._serialized_model(version_id))

        # Touch the oldest entry through the public accessor to promote it.
        fake_session = mock.MagicMock()
        bag.get_serialized_dag_model(version_ids[0], session=fake_session)
        fake_session.get.assert_not_called()  # must be a pure cache hit

        # A fourth insertion should now evict version_ids[1], not version_ids[0].
        bag._read_dag(self._serialized_model(uuid4()))

        assert version_ids[0] in bag._dags  # promoted to MRU, survives
        assert version_ids[1] not in bag._dags  # became LRU, evicted

    def test_failed_deserialization_not_cached(self):
        """Entries whose .dag property is falsy are not inserted into the cache."""
        bag = self._bag_with_capacity(10)
        broken = mock.MagicMock()
        broken.dag_version_id = uuid4()
        broken.dag = None  # deserialization failure

        bag._read_dag(broken)

        assert not bag._dags
Loading