Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from pydantic import JsonValue
from sqlalchemy import and_, func, or_, tuple_, update
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.exc import IntegrityError, NoResultFound, SQLAlchemyError
from sqlalchemy.orm import joinedload
from sqlalchemy.sql import select
from structlog.contextvars import bind_contextvars
Expand Down Expand Up @@ -729,8 +729,19 @@ def ti_put_rtif(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
)
task_instance.update_rtif(put_rtif_payload, session)
log.debug("RenderedTaskInstanceFields updated successfully")
try:
task_instance.update_rtif(put_rtif_payload, session)
except IntegrityError:
session.rollback()
# Re-fetch the task instance after rollback since the previous one is detached
task_instance = session.scalar(select(TI).where(TI.id == task_instance_id))
if task_instance:
# Retry: the record now exists from the concurrent request,
# so merge will find it and update rather than insert.
task_instance.update_rtif(put_rtif_payload, session)
log.info("RenderedTaskInstanceFields updated after concurrent write conflict")
else:
log.debug("RenderedTaskInstanceFields updated successfully")

return {"message": "Rendered task instance fields successfully set"}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1953,6 +1953,82 @@ def test_ti_put_rtif_missing_ti(self, client, session, create_task_instance):
assert response.json()["detail"] == "Not Found"


def test_ti_put_rtif_concurrent_write(self, client, session, create_task_instance):
    """Check that repeated RTIF writes for one task instance all return 201.

    Two PUT requests carrying different payloads are sent one after the other
    for the same running task instance. Both must be accepted, and afterwards
    exactly one RenderedTaskInstanceFields row must exist, holding the payload
    of the later request (i.e. the second write updated instead of inserting).
    """
    ti = create_task_instance(
        task_id="test_ti_put_rtif_concurrent",
        state=State.RUNNING,
        session=session,
    )
    session.commit()

    initial_fields = {"field1": "value1"}
    latest_fields = {"field1": "value2"}

    # Send the writes in order; the second stands in for the competing worker
    # and has to be merged into the row created by the first.
    for fields in (initial_fields, latest_fields):
        response = client.put(f"/execution/task-instances/{ti.id}/rtif", json=fields)
        assert response.status_code == 201

    # Drop cached ORM state so the check below reads what is actually stored.
    session.expire_all()
    stored_rows = session.scalars(select(RenderedTaskInstanceFields)).all()
    assert len(stored_rows) == 1
    assert stored_rows[0].rendered_fields == latest_fields

def test_ti_put_rtif_integrity_error_handled(self, client, session, create_task_instance):
    """Check that the RTIF endpoint recovers from an IntegrityError by retrying.

    ``TaskInstance.update_rtif`` is replaced with a stub that raises
    IntegrityError on its first invocation (standing in for a competing request
    that already inserted the row) and delegates to the real method afterwards.
    The request must still return 201 with the usual success message, and the
    stub must have been invoked exactly twice.
    """
    from unittest.mock import patch

    from sqlalchemy.exc import IntegrityError

    from airflow.models.taskinstance import TaskInstance

    ti = create_task_instance(
        task_id="test_ti_put_rtif_integrity",
        state=State.RUNNING,
        session=session,
    )
    session.commit()

    payload = {"field1": "rendered_value1"}

    # Keep a handle on the genuine implementation so the stub can fall through to it.
    real_update_rtif = TaskInstance.update_rtif
    attempts = 0

    def flaky_update_rtif(self_ti, rendered_fields, session):
        nonlocal attempts
        attempts += 1
        if attempts != 1:
            return real_update_rtif(self_ti, rendered_fields, session)
        # First attempt only: mimic the unique-constraint failure of a lost race.
        duplicate_key = Exception(
            'duplicate key value violates unique constraint "rendered_task_instance_fields_pkey"'
        )
        raise IntegrityError(
            statement="INSERT INTO rendered_task_instance_fields",
            params={},
            orig=duplicate_key,
        )

    with patch.object(TaskInstance, "update_rtif", flaky_update_rtif):
        response = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload)

    assert response.status_code == 201
    assert response.json() == {"message": "Rendered task instance fields successfully set"}
    assert attempts == 2  # one failed attempt followed by one successful retry


class TestPreviousDagRun:
def setup_method(self):
clear_db_runs()
Expand Down
Loading