From 4edbe5c9b0192525c4eed37835b3717c162c9acb Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Tue, 2 Jun 2026 13:30:23 +0200 Subject: [PATCH] feat: clean up AuthDB refresh tokens by dropping partitions Replace the row-level DELETE sweeps for the RefreshTokens table with maintenance of its monthly JTI range-partitions: drop partitions whose whole month is older than the retention horizon, and add partitions ahead of time so the p_future catch-all never fills. Dropping a partition is an O(1) metadata operation and avoids the row-lock and lock-memory cost of large DELETEs. Retention is now expressed in calendar months via the new DIRACX_SERVICE_AUTH_REFRESH_TOKEN_RETENTION_MONTHS setting (default 6), replacing revoked_refresh_token_retention_minutes. Partition maintenance is implemented for MySQL only and raises NotImplementedError for other dialects. The unpartitioned flow tables keep their existing DELETE-based cleanup. --- diracx-core/src/diracx/core/settings.py | 9 +- diracx-db/src/diracx/db/sql/auth/db.py | 151 +++++++++++--- diracx-db/tests/auth/test_partitions.py | 185 ++++++++++++++++++ diracx-db/tests/auth/test_refresh_token.py | 43 ---- diracx-logic/src/diracx/logic/__main__.py | 6 +- .../src/diracx/logic/auth/management.py | 15 +- docs/admin/reference/env-variables.md | 11 +- 7 files changed, 332 insertions(+), 88 deletions(-) create mode 100644 diracx-db/tests/auth/test_partitions.py diff --git a/diracx-core/src/diracx/core/settings.py b/diracx-core/src/diracx/core/settings.py index dd55a88a8..820de242c 100644 --- a/diracx-core/src/diracx/core/settings.py +++ b/diracx-core/src/diracx/core/settings.py @@ -251,11 +251,12 @@ def check_retention_greater_than_expiration(self) -> Self: through a new authentication flow. Default: 60 minutes. """ - revoked_refresh_token_retention_minutes: int = 43200 - """Retention time in minutes for revoked refresh tokens. + refresh_token_retention_months: int = 6 + """Retention time in months for refresh tokens. - The maximum retention time of refresh tokens after being - revoked and before they are deleted. Default: 43200 minutes (30 days). + Refresh tokens live in monthly partitions that are dropped once the whole + month is older than this many months. It is therefore the longest a refresh + token (revoked or not) is kept before removal. Default: 6 months. """ available_properties: set[SecurityProperty] = Field( diff --git a/diracx-db/src/diracx/db/sql/auth/db.py b/diracx-db/src/diracx/db/sql/auth/db.py index caf1b0226..a35a0ad19 100644 --- a/diracx-db/src/diracx/db/sql/auth/db.py +++ b/diracx-db/src/diracx/db/sql/auth/db.py @@ -1,10 +1,12 @@ from __future__ import annotations import logging +import re import secrets from datetime import UTC, datetime from itertools import pairwise +from dateutil.relativedelta import relativedelta from dateutil.rrule import MONTHLY, rrule from sqlalchemy import delete, insert, select, text, update from sqlalchemy.exc import IntegrityError, NoResultFound @@ -32,6 +34,58 @@ logger = logging.getLogger(__name__) +# Always keep at least this many months of future RefreshTokens partitions ahead +# of "now" so the ``p_future`` catch-all partition never accumulates rows. +PARTITION_MONTHS_AHEAD = 12 + + +def _month_start(dt: datetime) -> datetime: + """Truncate ``dt`` to the first instant of its month.""" + return dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + + +def _partition_name(month_start: datetime) -> str: + """Name of the partition holding the tokens created during ``month_start``.""" + return f"p_{month_start.year}_{month_start.month}" + + +def _partition_boundary(dt: datetime) -> str: + """``RANGE COLUMNS(JTI)`` upper bound (exclusive) for tokens created before ``dt``.""" + return str(uuid7_from_datetime(dt, randomize=False)).replace("-", "") + + +def plan_partition_maintenance( + existing_months: list[datetime], + now: datetime, + retention_months: int, + months_ahead: int, +) -> tuple[list[datetime], list[datetime]]: + """Decide which monthly ``RefreshTokens`` partitions to drop and to add. + + ``existing_months`` are the month-start datetimes of the existing + ``p__`` partitions (excluding ``p_future``). Returns + ``(months_to_drop, months_to_add)`` as month-start datetimes. + """ + existing = sorted(existing_months) + + # A partition for month ``m`` holds tokens created before ``m + 1 month``, so + # the whole partition is expired once that upper bound is older than the + # retention horizon. Keeping ``retention_months`` worth of partitions never + # drops a token younger than that many calendar months. + horizon = now - relativedelta(months=retention_months) + months_to_drop = [m for m in existing if m + relativedelta(months=1) <= horizon] + + # Ensure a partition exists for every month up to ``now + months_ahead`` by + # appending months above the highest existing partition. + target_last = _month_start(now) + relativedelta(months=months_ahead) + cursor = max(existing) if existing else _month_start(now) - relativedelta(months=1) + months_to_add: list[datetime] = [] + while cursor < target_last: + cursor += relativedelta(months=1) + months_to_add.append(cursor) + + return months_to_drop, months_to_add + class AuthDB(BaseSQLDB): metadata = AuthDBBase.metadata @@ -67,8 +121,8 @@ async def post_create(cls, conn: AsyncConnection) -> None: partition_list = [] for name, limit in pairwise(dates): partition_list.append( - f"PARTITION p_{name.year}_{name.month} " - f"VALUES LESS THAN ('{str(uuid7_from_datetime(limit, randomize=False)).replace('-', '')}')" + f"PARTITION {_partition_name(name)} " + f"VALUES LESS THAN ('{_partition_boundary(limit)}')" ) partition_list.append("PARTITION p_future VALUES LESS THAN (MAXVALUE)") @@ -340,37 +394,84 @@ async def revoke_user_refresh_tokens(self, subject): .values(status=RefreshTokenStatus.REVOKED) ) - async def clean_expired_refresh_tokens(self, max_validity: int) -> int: - """Delete expired refresh tokens. + async def maintain_refresh_token_partitions( + self, + retention_months: int, + months_ahead: int = PARTITION_MONTHS_AHEAD, + ) -> None: + """Maintain the monthly partitions of the RefreshTokens table. - max_validity: Maximum validity time in minutes for refresh tokens. + Drops partitions whose entire month is older than ``retention_months`` + and adds partitions ahead of time so the ``p_future`` catch-all never + fills. Cleanup of expired refresh tokens is achieved by dropping whole + partitions rather than deleting rows. + + Only implemented for MySQL; raises ``NotImplementedError`` for any other + dialect (the table is only partitioned on MySQL). """ - expired_date = str( - uuid7_from_datetime(substract_date(minutes=max_validity), randomize=False) - ) - stmt_expired = delete(RefreshTokens).where( - RefreshTokens.status == RefreshTokenStatus.CREATED, - RefreshTokens.jti < expired_date, + dialect = self.conn.dialect.name + if dialect != "mysql": + raise NotImplementedError( + "Refresh token partition maintenance is only implemented for " + f"MySQL, not {dialect!r}" + ) + + check_partition_query = text( + "SELECT PARTITION_NAME FROM information_schema.partitions " + "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'RefreshTokens' " + "AND PARTITION_NAME IS NOT NULL" ) - res_expired = await self.conn.execute(stmt_expired) + partition_names = (await self.conn.execute(check_partition_query)).all() - return res_expired.rowcount + existing_months = [] + for (name,) in partition_names: + if match := re.fullmatch(r"p_(\d+)_(\d+)", name): + existing_months.append( + datetime(int(match.group(1)), int(match.group(2)), 1, tzinfo=UTC) + ) - async def clean_revoked_refresh_tokens(self, max_retention: int) -> int: - """Delete old revoked refresh tokens. + if not existing_months: + logger.warning( + "RefreshTokens is not partitioned; skipping partition maintenance. " + "Partition the table manually (see AuthDB.post_create)." + ) + return - max_retention: Maximum retention time in minutes for revoked refresh tokens. - """ - revoked_date = str( - uuid7_from_datetime(substract_date(minutes=max_retention), randomize=False) + months_to_drop, months_to_add = plan_partition_maintenance( + existing_months, + now=datetime.now(tz=UTC), + retention_months=retention_months, + months_ahead=months_ahead, ) - stmt_revoked = delete(RefreshTokens).where( - RefreshTokens.status == RefreshTokenStatus.REVOKED, - RefreshTokens.jti < revoked_date, - ) - res_revoked = await self.conn.execute(stmt_revoked) - return res_revoked.rowcount + # Add new partitions first, by splitting the p_future catch-all. + if months_to_add: + new_partitions = [ + f"PARTITION {_partition_name(m)} " + f"VALUES LESS THAN ('{_partition_boundary(m + relativedelta(months=1))}')" + for m in months_to_add + ] + new_partitions.append("PARTITION p_future VALUES LESS THAN (MAXVALUE)") + await self.conn.execute( + text( + "ALTER TABLE RefreshTokens REORGANIZE PARTITION p_future INTO (" + + ", ".join(new_partitions) + + ")" + ) + ) + + # Then drop the partitions whose whole month is past the retention horizon. + if months_to_drop: + drop_names = ", ".join(_partition_name(m) for m in months_to_drop) + await self.conn.execute( + text(f"ALTER TABLE RefreshTokens DROP PARTITION {drop_names}") + ) + + logger.info( + "Refresh token partition maintenance: added %d, dropped %d", + len(months_to_add), + len(months_to_drop), + ) async def clean_expired_authorization_flows(self, max_retention: int) -> int: """Delete old authorization flows. diff --git a/diracx-db/tests/auth/test_partitions.py b/diracx-db/tests/auth/test_partitions.py new file mode 100644 index 000000000..979981a8b --- /dev/null +++ b/diracx-db/tests/auth/test_partitions.py @@ -0,0 +1,185 @@ +"""Tests for the RefreshTokens partition-maintenance logic. + +The pure planner (``plan_partition_maintenance``) and the name/boundary helpers +are dialect-independent and exercised directly here. The MySQL-only executor +(``maintain_refresh_token_partitions``) cannot be run against the in-memory +SQLite test database, so we only assert that it refuses to run on SQLite. +""" + +from __future__ import annotations + +from datetime import UTC, datetime + +import pytest +from dateutil.relativedelta import relativedelta + +from diracx.db.sql.auth.db import ( + AuthDB, + _partition_boundary, + _partition_name, + plan_partition_maintenance, +) +from diracx.db.sql.utils import uuid7_from_datetime + + +def m(year: int, month: int) -> datetime: + """Month-start datetime helper.""" + return datetime(year, month, 1, tzinfo=UTC) + + +@pytest.fixture +async def auth_db(tmp_path): + auth_db = AuthDB("sqlite+aiosqlite:///:memory:") + async with auth_db.engine_context(): + async with auth_db.engine.begin() as conn: + await conn.run_sync(auth_db.metadata.create_all) + yield auth_db + + +# --- helpers --------------------------------------------------------------- + + +def test_partition_name(): + assert _partition_name(m(2026, 3)) == "p_2026_3" + assert _partition_name(m(2026, 12)) == "p_2026_12" + + +def test_partition_boundary_matches_uuid7(): + dt = m(2026, 4) + boundary = _partition_boundary(dt) + # The boundary is the dash-stripped lowest UUIDv7 for the timestamp. + assert boundary == str(uuid7_from_datetime(dt, randomize=False)).replace("-", "") + assert len(boundary) == 32 # 32 hex chars, no dashes + + +def test_partition_boundary_is_monotonic(): + # The executor relies on lexical ordering of the JTI string boundaries. + assert _partition_boundary(m(2026, 1)) < _partition_boundary(m(2026, 2)) + assert _partition_boundary(m(2026, 12)) < _partition_boundary(m(2027, 1)) + + +# --- planner: drop --------------------------------------------------------- + + +def test_plan_drops_only_fully_expired_partitions(): + now = datetime(2026, 6, 15, tzinfo=UTC) + existing = [m(2026, month) for month in range(1, 9)] # Jan..Aug 2026 + + to_drop, _ = plan_partition_maintenance( + existing, now=now, retention_months=1, months_ahead=0 + ) + + # A partition for month X has upper bound X+1mo; drop when that is older than + # now - 1 month (2026-05-15). Jan..Apr have bounds Feb1..May1 (all <= May15). + assert to_drop == [m(2026, 1), m(2026, 2), m(2026, 3), m(2026, 4)] + + +def test_plan_drop_boundary_is_inclusive(): + # Upper bound exactly equal to the horizon must be dropped (<=). + now = m(2026, 6) + # now - 1 month == 2026-05-01 + existing = [m(2026, 4), m(2026, 5)] # bounds: May1, Jun1 + + to_drop, _ = plan_partition_maintenance( + existing, now=now, retention_months=1, months_ahead=0 + ) + assert to_drop == [m(2026, 4)] # May1 <= May1 drops April; June kept + + +def test_plan_keeps_last_six_months_by_default(): + # The deployment policy: keep the last 6 months worth of refresh tokens. + now = datetime(2026, 7, 15, tzinfo=UTC) + existing = [m(2025, month) for month in range(6, 13)] + [ + m(2026, month) for month in range(1, 8) + ] # 2025-06 .. 2026-07 + + to_drop, _ = plan_partition_maintenance( + existing, now=now, retention_months=6, months_ahead=0 + ) + + # Horizon is 2026-01-15: nothing from the last 6 months is dropped. + assert all(d < m(2026, 1) for d in to_drop) + assert m(2025, 12) in to_drop + assert m(2026, 1) not in to_drop + assert max(to_drop) == m(2025, 12) + + +def test_plan_keeps_everything_when_retention_is_huge(): + now = datetime(2026, 6, 15, tzinfo=UTC) + existing = [m(2026, month) for month in range(1, 9)] + to_drop, _ = plan_partition_maintenance( + existing, now=now, retention_months=120, months_ahead=0 + ) + assert to_drop == [] + + +# --- planner: add ---------------------------------------------------------- + + +def test_plan_adds_months_up_to_horizon(): + now = datetime(2026, 6, 15, tzinfo=UTC) + existing = [m(2026, 7)] # highest existing partition is July + _, to_add = plan_partition_maintenance( + existing, now=now, retention_months=6, months_ahead=3 + ) + # target_last = month_start(now) + 3 = 2026-09; append above July. + assert to_add == [m(2026, 8), m(2026, 9)] + + +def test_plan_adds_nothing_when_buffer_already_covered(): + now = datetime(2026, 6, 15, tzinfo=UTC) + existing = [m(2026, month) for month in range(6, 10)] # Jun..Sep + _, to_add = plan_partition_maintenance( + existing, now=now, retention_months=6, months_ahead=3 + ) + assert to_add == [] # highest existing (Sep) already == now+3mo + + +def test_plan_crosses_year_boundary(): + now = datetime(2026, 11, 15, tzinfo=UTC) + existing = [m(2026, 11)] + _, to_add = plan_partition_maintenance( + existing, now=now, retention_months=6, months_ahead=3 + ) + assert to_add == [m(2026, 12), m(2027, 1), m(2027, 2)] + + +def test_plan_empty_existing_seeds_from_current_month(): + now = datetime(2026, 6, 15, tzinfo=UTC) + _, to_add = plan_partition_maintenance( + [], now=now, retention_months=6, months_ahead=2 + ) + # No partitions yet: seed current month + buffer. + assert to_add == [m(2026, 6), m(2026, 7), m(2026, 8)] + + +def test_plan_combined_drop_and_add(): + now = datetime(2026, 6, 15, tzinfo=UTC) + existing = [m(2026, month) for month in range(1, 8)] # Jan..Jul + to_drop, to_add = plan_partition_maintenance( + existing, now=now, retention_months=1, months_ahead=2 + ) + assert to_drop == [m(2026, 1), m(2026, 2), m(2026, 3), m(2026, 4)] + assert to_add == [m(2026, 8)] # target_last = 2026-08, append above July + + +def test_plan_added_months_are_contiguous_and_increasing(): + now = datetime(2026, 6, 15, tzinfo=UTC) + existing = [m(2026, 6)] + _, to_add = plan_partition_maintenance( + existing, now=now, retention_months=6, months_ahead=12 + ) + # Each added month is exactly one month after the previous. + for previous, current in zip(to_add, to_add[1:]): + assert current == previous + relativedelta(months=1) + assert to_add[0] == m(2026, 7) + assert to_add[-1] == m(2027, 6) # now + 12 months + + +# --- executor: dialect guard ---------------------------------------------- + + +async def test_maintain_partitions_requires_mysql(auth_db: AuthDB): + async with auth_db as auth_db: + with pytest.raises(NotImplementedError, match="MySQL"): + await auth_db.maintain_refresh_token_partitions(retention_months=6) diff --git a/diracx-db/tests/auth/test_refresh_token.py b/diracx-db/tests/auth/test_refresh_token.py index 39fc69b37..28d6dfe9d 100644 --- a/diracx-db/tests/auth/test_refresh_token.py +++ b/diracx-db/tests/auth/test_refresh_token.py @@ -257,46 +257,3 @@ async def test_get_refresh_tokens(auth_db: AuthDB): # Check the number of retrieved refresh tokens (should be 3 refresh tokens) assert len(refresh_tokens) == 2 - - -async def test_clean_refresh_tokens(auth_db: AuthDB): - # Insert two refresh tokens - jtis = [] - async with auth_db as auth_db: - for _ in range(2): - jti = uuid7() - await auth_db.insert_refresh_token( - jti, - "subject", - "scope", - ) - jtis.append(jti) - - # Revoke one of the refresh token - async with auth_db as auth_db: - await auth_db.revoke_refresh_token(jtis[0]) - - # Check the number of deleted refresh tokens (should be 0) - async with auth_db as auth_db: - deleted_expired = await auth_db.clean_expired_refresh_tokens(max_validity=10) - assert deleted_expired == 0 - - async with auth_db as auth_db: - deleted_revoked = await auth_db.clean_revoked_refresh_tokens(max_retention=30) - assert deleted_revoked == 0 - - # Check the number of deleted refresh tokens (should be 1 of each) - async with auth_db as auth_db: - deleted_expired = await auth_db.clean_expired_refresh_tokens(max_validity=0) - assert deleted_expired == 1 - - async with auth_db as auth_db: - deleted_revoked = await auth_db.clean_revoked_refresh_tokens(max_retention=0) - assert deleted_revoked == 1 - - # Get all refresh tokens (Admin) - async with auth_db as auth_db: - refresh_tokens = await auth_db.get_user_refresh_tokens() - - # Check the number of retrieved refresh tokens (should be 0) - assert len(refresh_tokens) == 0 diff --git a/diracx-logic/src/diracx/logic/__main__.py b/diracx-logic/src/diracx/logic/__main__.py index d7b7674a5..e54837092 100644 --- a/diracx-logic/src/diracx/logic/__main__.py +++ b/diracx-logic/src/diracx/logic/__main__.py @@ -91,8 +91,8 @@ async def delete_jwk(args): async def cleanup_authdb(args): - """Delete expired tokens and flows from the AuthDB.""" - logger.info("Deleting expired tokens and flows") + """Maintain AuthDB partitions and remove expired flows.""" + logger.info("Maintaining AuthDB partitions and removing expired flows") import os from diracx.core.settings import AuthSettings @@ -138,7 +138,7 @@ def parse_args(): delete_jwk_parser.set_defaults(func=delete_jwk) cleanup_authdb_parser = subparsers.add_parser( - "cleanup-authdb", help="Delete expired tokens and flows from the AuthDB" + "cleanup-authdb", help="Maintain AuthDB partitions and remove expired flows" ) cleanup_authdb_parser.set_defaults(func=cleanup_authdb) diff --git a/diracx-logic/src/diracx/logic/auth/management.py b/diracx-logic/src/diracx/logic/auth/management.py index 4dfab8ce6..a70ed96a3 100644 --- a/diracx-logic/src/diracx/logic/auth/management.py +++ b/diracx-logic/src/diracx/logic/auth/management.py @@ -65,16 +65,15 @@ async def revoke_refresh_token_by_refresh_token( async def cleanup_expired_data(auth_db: AuthDB, settings: AuthSettings) -> None: - """Remove expired data from the auth database.""" - expired_tokens = await auth_db.clean_expired_refresh_tokens( - max_validity=settings.refresh_token_expire_minutes, - ) - logger.info("Deleted %d expired refresh tokens", expired_tokens) + """Remove expired data from the auth database. - revoked_tokens = await auth_db.clean_revoked_refresh_tokens( - max_retention=settings.revoked_refresh_token_retention_minutes, + Expired refresh tokens are removed by dropping whole monthly partitions of + the RefreshTokens table (see ``AuthDB.maintain_refresh_token_partitions``). + The flow tables are not partitioned, so their expired rows are deleted. + """ + await auth_db.maintain_refresh_token_partitions( + retention_months=settings.refresh_token_retention_months, ) - logger.info("Deleted %d revoked refresh tokens", revoked_tokens) auth = await auth_db.clean_expired_authorization_flows( max_retention=settings.completed_flow_retention_minutes, diff --git a/docs/admin/reference/env-variables.md b/docs/admin/reference/env-variables.md index ea9a04a49..d81d1e85e 100644 --- a/docs/admin/reference/env-variables.md +++ b/docs/admin/reference/env-variables.md @@ -120,14 +120,15 @@ Expiration time in minutes for refresh tokens. The maximum lifetime of refresh tokens before they must be re-issued through a new authentication flow. Default: 60 minutes. -### `DIRACX_SERVICE_AUTH_REVOKED_REFRESH_TOKEN_RETENTION_MINUTES` +### `DIRACX_SERVICE_AUTH_REFRESH_TOKEN_RETENTION_MONTHS` -*Optional*, default value: `43200` +*Optional*, default value: `6` -Retention time in minutes for revoked refresh tokens. +Retention time in months for refresh tokens. -The maximum retention time of refresh tokens after being -revoked and before they are deleted. Default: 43200 minutes (30 days). +Refresh tokens live in monthly partitions that are dropped once the whole +month is older than this many months. It is therefore the longest a refresh +token (revoked or not) is kept before removal. Default: 6 months. ### `DIRACX_SERVICE_AUTH_AVAILABLE_PROPERTIES`