From 07c971beb1d00c247109a539741521a1be4d62a9 Mon Sep 17 00:00:00 2001 From: Benjamin Capodanno Date: Tue, 12 May 2026 14:25:09 -0700 Subject: [PATCH] fix(authentication): handle concurrent first-login race condition When multiple servers process a user's first login simultaneously, each may query, find no existing record, and attempt to INSERT. The losing request raised an unhandled IntegrityError. - Wrap the commit in a try/except IntegrityError block - On collision: rollback and re-fetch the existing user record - Add a test that simulates the race and asserts the correct user is returned with exactly one DB row --- src/mavedb/lib/authentication.py | 11 ++++++- tests/lib/test_authentication.py | 53 ++++++++++++++++++++++++++++++-- 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/src/mavedb/lib/authentication.py b/src/mavedb/lib/authentication.py index 4ff59272d..fed13ae1f 100644 --- a/src/mavedb/lib/authentication.py +++ b/src/mavedb/lib/authentication.py @@ -13,6 +13,7 @@ HTTPBearer, ) from jose import jwt +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from mavedb import deps @@ -230,7 +231,15 @@ async def get_current_user( ) db.add(user) - db.commit() + try: + db.commit() + except IntegrityError: + # A concurrent request created this user between our initial query and this commit. + # Roll back and re-fetch the existing record. 
+ db.rollback() + user = db.query(User).filter(User.username == username).one() + logger.debug(msg="Concurrent first-login resolved; returning existing user.", extra=logging_context()) + db.refresh(user) logger.info(msg="Successfully authenticated user via JWT.", extra=logging_context()) diff --git a/tests/lib/test_authentication.py b/tests/lib/test_authentication.py index 534271930..80ee503cb 100644 --- a/tests/lib/test_authentication.py +++ b/tests/lib/test_authentication.py @@ -1,8 +1,10 @@ # ruff: noqa: E402 -import pytest from unittest.mock import patch +import pytest +from sqlalchemy.exc import IntegrityError + arq = pytest.importorskip("arq") cdot = pytest.importorskip("cdot") fastapi = pytest.importorskip("fastapi") @@ -11,7 +13,6 @@ from mavedb.models.enums.user_role import UserRole from mavedb.models.user import User from tests.helpers.constants import ADMIN_USER, ADMIN_USER_DECODED_JWT, TEST_USER, TEST_USER_DECODED_JWT - from tests.helpers.util.access_key import create_api_key_for_user from tests.helpers.util.user import mark_user_inactive @@ -121,3 +122,51 @@ async def test_get_current_user_user_extraneous_roles(session, setup_lib_db): assert user_data.user.username == TEST_USER["username"] assert user_data.active_roles == [] + + +@pytest.mark.asyncio +async def test_get_current_user_concurrent_first_login_integrity_error_returns_existing_user(session, setup_lib_db): + """ + Simulate two servers racing on first login: the commit raises IntegrityError because a + concurrent request already inserted the row. The handler should roll back and return the + existing user rather than surfacing the error. + """ + new_user_jwt = { + "sub": "9999-0000-0000-9999", + "given_name": "Race", + "family_name": "Condition", + } + + # Insert the user as if a concurrent request already committed it. 
+    pre_existing = User(
+        username=new_user_jwt["sub"],
+        first_name=new_user_jwt["given_name"],
+        last_name=new_user_jwt["family_name"],
+        is_active=True,
+        is_first_login=True,
+    )
+
+    # The row is only built above; it is committed from inside the first intercepted commit so
+    # that the handler's initial lookup finds nothing and its insert path actually runs.
+    original_commit = session.commit
+    commit_calls = []
+
+    def fake_commit():
+        commit_calls.append(1)
+        if len(commit_calls) > 1:
+            return original_commit()
+        session.rollback()  # drop the handler's pending duplicate before landing the rival row
+        session.add(pre_existing)
+        original_commit()
+        raise IntegrityError(statement=None, params=None, orig=Exception("duplicate key"))
+
+    with patch.object(session, "commit", fake_commit), patch(
+        "mavedb.lib.authentication.fetch_orcid_user_email", return_value=None
+    ):
+        user_data = await get_current_user(None, new_user_jwt, session, None)
+
+    assert commit_calls, "first-login insert path was never exercised"
+    assert user_data is not None and user_data.user.username == new_user_jwt["sub"]
+
+    # Only one user record should exist in the database.
+    assert session.query(User).filter(User.username == new_user_jwt["sub"]).count() == 1