Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

import copy
import json
from typing import Annotated

from fastapi import Depends, HTTPException, Query, status
Expand Down Expand Up @@ -270,27 +269,24 @@ def create_xcom_entry(
)

try:
value = json.dumps(request_body.value)
except (ValueError, TypeError):
XComModel.set(
key=request_body.key,
value=request_body.value,
dag_id=dag_id,
task_id=task_id,
run_id=dag_run_id,
map_index=request_body.map_index,
serialize=False,
session=session,
)
except (ValueError, TypeError) as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, f"Couldn't serialise the XCom with key: `{request_body.key}`"
)

new = XComModel(
dag_run_id=dag_run.id,
key=request_body.key,
value=value,
run_id=dag_run_id,
task_id=task_id,
dag_id=dag_id,
map_index=request_body.map_index,
)
session.add(new)
session.flush()
) from e

xcom = session.scalar(
select(XComModel)
.filter(
.where(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change?

Hi @amoghrajesh, I updated this to align with the modern SQLAlchemy 2.0 syntax. While .filter() still works here, .where() is the canonical method when using select() constructs.
Let me know if you'd prefer I revert this part.

XComModel.dag_id == dag_id,
XComModel.task_id == task_id,
XComModel.run_id == dag_run_id,
Expand Down Expand Up @@ -324,11 +320,12 @@ def update_xcom_entry(
dag_run_id: str,
xcom_key: str,
patch_body: XComUpdateBody,
*,
session: SessionDep,
) -> XComResponseNative:
"""Update an existing XCom entry."""
# Check if XCom entry exists
xcom_entry = session.scalar(
xcom_query = (
select(XComModel)
.where(
XComModel.dag_id == dag_id,
Expand All @@ -340,16 +337,32 @@ def update_xcom_entry(
.limit(1)
.options(joinedload(XComModel.task), joinedload(XComModel.dag_run).joinedload(DR.dag_model))
)
xcom_entry = session.scalar(xcom_query)

if not xcom_entry:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
f"The XCom with key: `{xcom_key}` with mentioned task instance doesn't exist.",
)

# Update XCom entry
xcom_entry.value = json.dumps(patch_body.value)
try:
XComModel.set(
key=xcom_key,
value=patch_body.value,
dag_id=dag_id,
task_id=task_id,
run_id=dag_run_id,
map_index=patch_body.map_index,
serialize=False,
session=session,
)
except (ValueError, TypeError) as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, f"Couldn't serialise the XCom with key: `{xcom_key}`"
) from e

# Fetch after setting, to get fresh object for response
xcom_entry = session.scalar(xcom_query)
return XComResponseNative.model_validate(xcom_entry)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def test_create_xcom_entry(
# Validate the created XCom response
current_data = response.json()
assert current_data["key"] == request_body.key
assert current_data["value"] == XComModel.serialize_value(request_body.value)
assert current_data["value"] == request_body.value
assert current_data["dag_id"] == dag_id
assert current_data["task_id"] == task_id
assert current_data["run_id"] == dag_run_id
Expand Down Expand Up @@ -716,7 +716,7 @@ def test_create_xcom_entry_with_slash_key(self, test_client):
)
assert get_resp.status_code == 200
assert get_resp.json()["key"] == slash_key
assert get_resp.json()["value"] == json.dumps(TEST_XCOM_VALUE)
assert get_resp.json()["value"] == TEST_XCOM_VALUE

@pytest.mark.parametrize(
("key", "value"),
Expand Down Expand Up @@ -833,7 +833,7 @@ def test_patch_xcom_entry(self, key, patch_body, expected_status, expected_detai
assert response.status_code == expected_status

if expected_status == 200:
assert response.json()["value"] == json.dumps(patch_body["value"])
assert response.json()["value"] == patch_body["value"]
else:
assert response.json()["detail"] == expected_detail
check_last_log(session, dag_id=TEST_DAG_ID, event="update_xcom_entry", logical_date=None)
Expand Down Expand Up @@ -862,5 +862,23 @@ def test_patch_xcom_entry_with_slash_key(self, test_client, session):
)
assert response.status_code == 200
assert response.json()["key"] == slash_key
assert response.json()["value"] == json.dumps(new_value)
assert response.json()["value"] == new_value
check_last_log(session, dag_id=TEST_DAG_ID, event="update_xcom_entry", logical_date=None)

def test_patch_xcom_preserves_int_type(self, test_client, session):
"""Test scenario described in #59032: if existing XCom value type is int,
after patching with different value, it should still be int in the API response.
"""
key = "int_type_xcom"
# Create with int value
self._create_xcom(key, 42)
patch_value = 100
response = test_client.patch(
f"/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{key}",
json={"value": patch_value},
)
assert response.status_code == 200
data = response.json()
assert data["value"] == patch_value
assert isinstance(data["value"], int), f"Expected int type but got {type(data['value'])}"
check_last_log(session, dag_id=TEST_DAG_ID, event="update_xcom_entry", logical_date=None)
Loading