Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions diracx-core/src/diracx/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,12 @@ def check_retention_greater_than_expiration(self) -> Self:
through a new authentication flow. Default: 60 minutes.
"""

revoked_refresh_token_retention_minutes: int = 43200
"""Retention time in minutes for revoked refresh tokens.
refresh_token_retention_months: int = 6
"""Retention time in months for refresh tokens.

The maximum retention time of refresh tokens after being
revoked and before they are deleted. Default: 43200 minutes (30 days).
Refresh tokens live in monthly partitions that are dropped once the whole
month is older than this many months. It is therefore the longest a refresh
token (revoked or not) is kept before removal. Default: 6 months.
"""

available_properties: set[SecurityProperty] = Field(
Expand Down
151 changes: 126 additions & 25 deletions diracx-db/src/diracx/db/sql/auth/db.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import logging
import re
import secrets
from datetime import UTC, datetime
from itertools import pairwise

from dateutil.relativedelta import relativedelta
from dateutil.rrule import MONTHLY, rrule
from sqlalchemy import delete, insert, select, text, update
from sqlalchemy.exc import IntegrityError, NoResultFound
Expand Down Expand Up @@ -32,6 +34,58 @@

logger = logging.getLogger(__name__)

# Always keep at least this many months of future RefreshTokens partitions ahead
# of "now" so the ``p_future`` catch-all partition never accumulates rows.
PARTITION_MONTHS_AHEAD = 12


def _month_start(dt: datetime) -> datetime:
"""Truncate ``dt`` to the first instant of its month."""
return dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0)


def _partition_name(month_start: datetime) -> str:
"""Name of the partition holding the tokens created during ``month_start``."""
return f"p_{month_start.year}_{month_start.month}"


def _partition_boundary(dt: datetime) -> str:
"""``RANGE COLUMNS(JTI)`` upper bound (exclusive) for tokens created before ``dt``."""
return str(uuid7_from_datetime(dt, randomize=False)).replace("-", "")


def plan_partition_maintenance(
existing_months: list[datetime],
now: datetime,
retention_months: int,
months_ahead: int,
) -> tuple[list[datetime], list[datetime]]:
"""Decide which monthly ``RefreshTokens`` partitions to drop and to add.
``existing_months`` are the month-start datetimes of the existing
``p_<year>_<month>`` partitions (excluding ``p_future``). Returns
``(months_to_drop, months_to_add)`` as month-start datetimes.
"""
existing = sorted(existing_months)

# A partition for month ``m`` holds tokens created before ``m + 1 month``, so
# the whole partition is expired once that upper bound is older than the
# retention horizon. Keeping ``retention_months`` worth of partitions never
# drops a token younger than that many calendar months.
horizon = now - relativedelta(months=retention_months)
months_to_drop = [m for m in existing if m + relativedelta(months=1) <= horizon]

# Ensure a partition exists for every month up to ``now + months_ahead`` by
# appending months above the highest existing partition.
target_last = _month_start(now) + relativedelta(months=months_ahead)
cursor = max(existing) if existing else _month_start(now) - relativedelta(months=1)
months_to_add: list[datetime] = []
while cursor < target_last:
cursor += relativedelta(months=1)
months_to_add.append(cursor)

return months_to_drop, months_to_add


class AuthDB(BaseSQLDB):
metadata = AuthDBBase.metadata
Expand Down Expand Up @@ -67,8 +121,8 @@ async def post_create(cls, conn: AsyncConnection) -> None:
partition_list = []
for name, limit in pairwise(dates):
partition_list.append(
f"PARTITION p_{name.year}_{name.month} "
f"VALUES LESS THAN ('{str(uuid7_from_datetime(limit, randomize=False)).replace('-', '')}')"
f"PARTITION {_partition_name(name)} "
f"VALUES LESS THAN ('{_partition_boundary(limit)}')"
)
partition_list.append("PARTITION p_future VALUES LESS THAN (MAXVALUE)")

Expand Down Expand Up @@ -340,37 +394,84 @@ async def revoke_user_refresh_tokens(self, subject):
.values(status=RefreshTokenStatus.REVOKED)
)

async def clean_expired_refresh_tokens(self, max_validity: int) -> int:
"""Delete expired refresh tokens.
async def maintain_refresh_token_partitions(
self,
retention_months: int,
months_ahead: int = PARTITION_MONTHS_AHEAD,
) -> None:
"""Maintain the monthly partitions of the RefreshTokens table.
max_validity: Maximum validity time in minutes for refresh tokens.
Drops partitions whose entire month is older than ``retention_months``
and adds partitions ahead of time so the ``p_future`` catch-all never
fills. Cleanup of expired refresh tokens is achieved by dropping whole
partitions rather than deleting rows.
Only implemented for MySQL; raises ``NotImplementedError`` for any other
dialect (the table is only partitioned on MySQL).
"""
expired_date = str(
uuid7_from_datetime(substract_date(minutes=max_validity), randomize=False)
)
stmt_expired = delete(RefreshTokens).where(
RefreshTokens.status == RefreshTokenStatus.CREATED,
RefreshTokens.jti < expired_date,
dialect = self.conn.dialect.name
if dialect != "mysql":
raise NotImplementedError(
"Refresh token partition maintenance is only implemented for "
f"MySQL, not {dialect!r}"
)

check_partition_query = text(
"SELECT PARTITION_NAME FROM information_schema.partitions "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'RefreshTokens' "
"AND PARTITION_NAME IS NOT NULL"
)
res_expired = await self.conn.execute(stmt_expired)
partition_names = (await self.conn.execute(check_partition_query)).all()

return res_expired.rowcount
existing_months = []
for (name,) in partition_names:
if match := re.fullmatch(r"p_(\d+)_(\d+)", name):
existing_months.append(
datetime(int(match.group(1)), int(match.group(2)), 1, tzinfo=UTC)
)

async def clean_revoked_refresh_tokens(self, max_retention: int) -> int:
"""Delete old revoked refresh tokens.
if not existing_months:
logger.warning(
"RefreshTokens is not partitioned; skipping partition maintenance. "
"Partition the table manually (see AuthDB.post_create)."
)
return

max_retention: Maximum retention time in minutes for revoked refresh tokens.
"""
revoked_date = str(
uuid7_from_datetime(substract_date(minutes=max_retention), randomize=False)
months_to_drop, months_to_add = plan_partition_maintenance(
existing_months,
now=datetime.now(tz=UTC),
retention_months=retention_months,
months_ahead=months_ahead,
)
stmt_revoked = delete(RefreshTokens).where(
RefreshTokens.status == RefreshTokenStatus.REVOKED,
RefreshTokens.jti < revoked_date,
)
res_revoked = await self.conn.execute(stmt_revoked)

return res_revoked.rowcount
# Add new partitions first, by splitting the p_future catch-all.
if months_to_add:
new_partitions = [
f"PARTITION {_partition_name(m)} "
f"VALUES LESS THAN ('{_partition_boundary(m + relativedelta(months=1))}')"
for m in months_to_add
]
new_partitions.append("PARTITION p_future VALUES LESS THAN (MAXVALUE)")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are already creating partitions in the post-update, isn't that a duplicate ?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From chatting here, yes I'd missed that. Though at least LHCb doesn't have init-sql enabled at the moment so it's probably safer having it here and mostly harmless having it in two places.

await self.conn.execute(
text(
"ALTER TABLE RefreshTokens REORGANIZE PARTITION p_future INTO ("
+ ", ".join(new_partitions)
+ ")"
)
)

# Then drop the partitions whose whole month is past the retention horizon.
if months_to_drop:
drop_names = ", ".join(_partition_name(m) for m in months_to_drop)
await self.conn.execute(
text(f"ALTER TABLE RefreshTokens DROP PARTITION {drop_names}")
)

logger.info(
"Refresh token partition maintenance: added %d, dropped %d",
len(months_to_add),
len(months_to_drop),
)

async def clean_expired_authorization_flows(self, max_retention: int) -> int:
"""Delete old authorization flows.
Expand Down
185 changes: 185 additions & 0 deletions diracx-db/tests/auth/test_partitions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""Tests for the RefreshTokens partition-maintenance logic.

The pure planner (``plan_partition_maintenance``) and the name/boundary helpers
are dialect-independent and exercised directly here. The MySQL-only executor
(``maintain_refresh_token_partitions``) cannot be run against the in-memory
SQLite test database, so we only assert that it refuses to run on SQLite.
"""

from __future__ import annotations

from datetime import UTC, datetime

import pytest
from dateutil.relativedelta import relativedelta

from diracx.db.sql.auth.db import (
AuthDB,
_partition_boundary,
_partition_name,
plan_partition_maintenance,
)
from diracx.db.sql.utils import uuid7_from_datetime


def m(year: int, month: int) -> datetime:
"""Month-start datetime helper."""
return datetime(year, month, 1, tzinfo=UTC)


@pytest.fixture
async def auth_db(tmp_path):
auth_db = AuthDB("sqlite+aiosqlite:///:memory:")
async with auth_db.engine_context():
async with auth_db.engine.begin() as conn:
await conn.run_sync(auth_db.metadata.create_all)
yield auth_db


# --- helpers ---------------------------------------------------------------


def test_partition_name():
assert _partition_name(m(2026, 3)) == "p_2026_3"
assert _partition_name(m(2026, 12)) == "p_2026_12"


def test_partition_boundary_matches_uuid7():
dt = m(2026, 4)
boundary = _partition_boundary(dt)
# The boundary is the dash-stripped lowest UUIDv7 for the timestamp.
assert boundary == str(uuid7_from_datetime(dt, randomize=False)).replace("-", "")
assert len(boundary) == 32 # 32 hex chars, no dashes


def test_partition_boundary_is_monotonic():
# The executor relies on lexical ordering of the JTI string boundaries.
assert _partition_boundary(m(2026, 1)) < _partition_boundary(m(2026, 2))
assert _partition_boundary(m(2026, 12)) < _partition_boundary(m(2027, 1))


# --- planner: drop ---------------------------------------------------------


def test_plan_drops_only_fully_expired_partitions():
now = datetime(2026, 6, 15, tzinfo=UTC)
existing = [m(2026, month) for month in range(1, 9)] # Jan..Aug 2026

to_drop, _ = plan_partition_maintenance(
existing, now=now, retention_months=1, months_ahead=0
)

# A partition for month X has upper bound X+1mo; drop when that is older than
# now - 1 month (2026-05-15). Jan..Apr have bounds Feb1..May1 (all <= May15).
assert to_drop == [m(2026, 1), m(2026, 2), m(2026, 3), m(2026, 4)]


def test_plan_drop_boundary_is_inclusive():
# Upper bound exactly equal to the horizon must be dropped (<=).
now = m(2026, 6)
# now - 1 month == 2026-05-01
existing = [m(2026, 4), m(2026, 5)] # bounds: May1, Jun1

to_drop, _ = plan_partition_maintenance(
existing, now=now, retention_months=1, months_ahead=0
)
assert to_drop == [m(2026, 4)] # May1 <= May1 drops April; June kept


def test_plan_keeps_last_six_months_by_default():
# The deployment policy: keep the last 6 months worth of refresh tokens.
now = datetime(2026, 7, 15, tzinfo=UTC)
existing = [m(2025, month) for month in range(6, 13)] + [
m(2026, month) for month in range(1, 8)
] # 2025-06 .. 2026-07

to_drop, _ = plan_partition_maintenance(
existing, now=now, retention_months=6, months_ahead=0
)

# Horizon is 2026-01-15: nothing from the last 6 months is dropped.
assert all(d < m(2026, 1) for d in to_drop)
assert m(2025, 12) in to_drop
assert m(2026, 1) not in to_drop
assert max(to_drop) == m(2025, 12)


def test_plan_keeps_everything_when_retention_is_huge():
now = datetime(2026, 6, 15, tzinfo=UTC)
existing = [m(2026, month) for month in range(1, 9)]
to_drop, _ = plan_partition_maintenance(
existing, now=now, retention_months=120, months_ahead=0
)
assert to_drop == []


# --- planner: add ----------------------------------------------------------


def test_plan_adds_months_up_to_horizon():
now = datetime(2026, 6, 15, tzinfo=UTC)
existing = [m(2026, 7)] # highest existing partition is July
_, to_add = plan_partition_maintenance(
existing, now=now, retention_months=6, months_ahead=3
)
# target_last = month_start(now) + 3 = 2026-09; append above July.
assert to_add == [m(2026, 8), m(2026, 9)]


def test_plan_adds_nothing_when_buffer_already_covered():
now = datetime(2026, 6, 15, tzinfo=UTC)
existing = [m(2026, month) for month in range(6, 10)] # Jun..Sep
_, to_add = plan_partition_maintenance(
existing, now=now, retention_months=6, months_ahead=3
)
assert to_add == [] # highest existing (Sep) already == now+3mo


def test_plan_crosses_year_boundary():
now = datetime(2026, 11, 15, tzinfo=UTC)
existing = [m(2026, 11)]
_, to_add = plan_partition_maintenance(
existing, now=now, retention_months=6, months_ahead=3
)
assert to_add == [m(2026, 12), m(2027, 1), m(2027, 2)]


def test_plan_empty_existing_seeds_from_current_month():
now = datetime(2026, 6, 15, tzinfo=UTC)
_, to_add = plan_partition_maintenance(
[], now=now, retention_months=6, months_ahead=2
)
# No partitions yet: seed current month + buffer.
assert to_add == [m(2026, 6), m(2026, 7), m(2026, 8)]


def test_plan_combined_drop_and_add():
now = datetime(2026, 6, 15, tzinfo=UTC)
existing = [m(2026, month) for month in range(1, 8)] # Jan..Jul
to_drop, to_add = plan_partition_maintenance(
existing, now=now, retention_months=1, months_ahead=2
)
assert to_drop == [m(2026, 1), m(2026, 2), m(2026, 3), m(2026, 4)]
assert to_add == [m(2026, 8)] # target_last = 2026-08, append above July


def test_plan_added_months_are_contiguous_and_increasing():
now = datetime(2026, 6, 15, tzinfo=UTC)
existing = [m(2026, 6)]
_, to_add = plan_partition_maintenance(
existing, now=now, retention_months=6, months_ahead=12
)
# Each added month is exactly one month after the previous.
for previous, current in zip(to_add, to_add[1:]):
assert current == previous + relativedelta(months=1)
assert to_add[0] == m(2026, 7)
assert to_add[-1] == m(2027, 6) # now + 12 months


# --- executor: dialect guard ----------------------------------------------


async def test_maintain_partitions_requires_mysql(auth_db: AuthDB):
async with auth_db as auth_db:
with pytest.raises(NotImplementedError, match="MySQL"):
await auth_db.maintain_refresh_token_partitions(retention_months=6)
Loading
Loading