diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index e1687206d5547..c24d016453279 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -35,7 +35,7 @@ from pydantic import JsonValue from sqlalchemy import and_, func, or_, tuple_, update from sqlalchemy.engine import CursorResult -from sqlalchemy.exc import NoResultFound, SQLAlchemyError +from sqlalchemy.exc import IntegrityError, NoResultFound, SQLAlchemyError from sqlalchemy.orm import joinedload from sqlalchemy.sql import select from structlog.contextvars import bind_contextvars @@ -797,8 +797,19 @@ def ti_put_rtif( raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, ) - task_instance.update_rtif(put_rtif_payload, session) - log.debug("RenderedTaskInstanceFields updated successfully") + try: + task_instance.update_rtif(put_rtif_payload, session) + except IntegrityError: + session.rollback() + # Re-fetch the task instance after rollback since the previous one is detached + task_instance = session.scalar(select(TI).where(TI.id == task_instance_id)) + if task_instance: + # Retry: the record now exists from the concurrent request, + # so merge will find it and update rather than insert. + task_instance.update_rtif(put_rtif_payload, session) + log.info("RenderedTaskInstanceFields updated after concurrent write conflict") + else: + log.debug("RenderedTaskInstanceFields updated successfully") return {"message": "Rendered task instance fields successfully set"} diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index c6135711be9bf..2c6b05cf85f51 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -1995,6 +1995,82 @@ def test_ti_put_rtif_missing_ti(self, client, session, create_task_instance): assert response.json()["detail"] == "Not Found" + def test_ti_put_rtif_concurrent_write(self, client, session, create_task_instance): + """Test that concurrent RTIF writes don't cause 409 errors. + + When two workers try to write rendered fields for the same task instance + simultaneously, the second write should succeed by updating the existing record + rather than failing with a unique constraint violation. + """ + ti = create_task_instance( + task_id="test_ti_put_rtif_concurrent", + state=State.RUNNING, + session=session, + ) + session.commit() + + payload1 = {"field1": "value1"} + payload2 = {"field1": "value2"} + + # First write should succeed + response1 = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload1) + assert response1.status_code == 201 + + # Second write (simulating concurrent update) should also succeed by merging + response2 = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload2) + assert response2.status_code == 201 + + session.expire_all() + rtifs = session.scalars(select(RenderedTaskInstanceFields)).all() + assert len(rtifs) == 1 + assert rtifs[0].rendered_fields == payload2 + + def test_ti_put_rtif_integrity_error_handled(self, client, session, create_task_instance): + """Test that IntegrityError from a race condition is handled gracefully. + + Simulates the race condition where the first update_rtif call raises + IntegrityError (as if another concurrent request already inserted the record), + and verifies the endpoint retries successfully. + """ + from unittest.mock import patch + + from sqlalchemy.exc import IntegrityError + + from airflow.models.taskinstance import TaskInstance + + ti = create_task_instance( + task_id="test_ti_put_rtif_integrity", + state=State.RUNNING, + session=session, + ) + session.commit() + + payload = {"field1": "rendered_value1"} + + original_update_rtif = TaskInstance.update_rtif + call_count = 0 + + def mock_update_rtif(self_ti, rendered_fields, session): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise IntegrityError( + statement="INSERT INTO rendered_task_instance_fields", + params={}, + orig=Exception( + 'duplicate key value violates unique constraint "rendered_task_instance_fields_pkey"' + ), + ) + return original_update_rtif(self_ti, rendered_fields, session) + + with patch.object(TaskInstance, "update_rtif", mock_update_rtif): + response = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload) + + assert response.status_code == 201 + assert response.json() == {"message": "Rendered task instance fields successfully set"} + assert call_count == 2 # First call raises, second succeeds + + class TestPreviousDagRun: def setup_method(self): clear_db_runs()