Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions src/a2a/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ def override(func): # noqa: ANN001, ANN201
return func


from google.protobuf.json_format import MessageToDict, ParseDict
from google.protobuf.json_format import MessageToDict, ParseDict, ParseError
from google.protobuf.message import Message as ProtoMessage
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError

from a2a.compat.v0_3 import conversions
from a2a.compat.v0_3 import types as types_v03
from a2a.types.a2a_pb2 import Artifact, Message, TaskStatus


Expand Down Expand Up @@ -81,7 +83,19 @@ def process_result_value(
if isinstance(self.pydantic_type, type) and issubclass(
self.pydantic_type, ProtoMessage
):
return ParseDict(value, self.pydantic_type()) # type: ignore[return-value]
try:
return ParseDict(value, self.pydantic_type()) # type: ignore[return-value]
except (ParseError, ValueError):
# Try legacy conversion
legacy_map = _get_legacy_conversions()
if self.pydantic_type in legacy_map:
legacy_type, convert_func = legacy_map[self.pydantic_type]
try:
legacy_instance = legacy_type.model_validate(value)
return convert_func(legacy_instance)
except ValidationError:
pass
raise
# Assume it's a Pydantic model
return self.pydantic_type.model_validate(value) # type: ignore[attr-defined]

Expand Down Expand Up @@ -130,7 +144,24 @@ def process_result_value(
if isinstance(self.pydantic_type, type) and issubclass(
self.pydantic_type, ProtoMessage
):
return [ParseDict(item, self.pydantic_type()) for item in value] # type: ignore[misc]
result = []
legacy_map = _get_legacy_conversions()
legacy_info = legacy_map.get(self.pydantic_type)

for item in value:
try:
result.append(ParseDict(item, self.pydantic_type()))
except (ParseError, ValueError): # noqa: PERF203
if legacy_info:
legacy_type, convert_func = legacy_info
try:
legacy_instance = legacy_type.model_validate(item)
result.append(convert_func(legacy_instance))
continue
except ValidationError:
pass
raise
return result # type: ignore[return-value]
# Assume it's a Pydantic model
return [self.pydantic_type.model_validate(item) for item in value] # type: ignore[attr-defined]

Expand Down Expand Up @@ -292,3 +323,15 @@ class PushNotificationConfigModel(PushNotificationConfigMixin, Base):
"""Default push notification config model with standard table name."""

__tablename__ = 'push_notification_configs'


def _get_legacy_conversions() -> dict[type, tuple[type, Any]]:
"""Get the mapping of current types to their legacy counterparts and conversion functions."""
return {
TaskStatus: (
types_v03.TaskStatus,
conversions.to_core_task_status,
),
Message: (types_v03.Message, conversions.to_core_message),
Artifact: (types_v03.Artifact, conversions.to_core_artifact),
}
135 changes: 135 additions & 0 deletions tests/server/tasks/test_database_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,4 +683,139 @@ async def test_owner_resource_scoping(
await task_store.delete('u2-task1', context_user2)


@pytest.mark.asyncio
async def test_get_0_3_task_detailed(
db_store_parameterized: DatabaseTaskStore,
) -> None:
"""Test retrieving a detailed legacy v0.3 task from the database.

This test simulates a database that already contains legacy v0.3 JSON data
(string-based enums, different field names) and verifies that the store
correctly converts it to the modern Protobuf-based Task model.
"""
from a2a.compat.v0_3 import types as types_v03
from sqlalchemy import insert

task_id = 'legacy-detailed-1'
owner = 'legacy_user'
context_user = ServerCallContext(user=SampleUser(user_name=owner))

# 1. Create a detailed legacy Task using v0.3 models
legacy_task = types_v03.Task(
id=task_id,
context_id='legacy-ctx-1',
status=types_v03.TaskStatus(
state=types_v03.TaskState.working,
message=types_v03.Message(
message_id='msg-status',
role=types_v03.Role.agent,
parts=[
types_v03.Part(
root=types_v03.TextPart(text='Legacy status message')
)
],
),
timestamp='2023-10-27T10:00:00Z',
),
history=[
types_v03.Message(
message_id='msg-1',
role=types_v03.Role.user,
parts=[
types_v03.Part(root=types_v03.TextPart(text='Hello legacy'))
],
),
types_v03.Message(
message_id='msg-2',
role=types_v03.Role.agent,
parts=[
types_v03.Part(
root=types_v03.DataPart(data={'legacy_key': 'value'})
)
],
),
],
artifacts=[
types_v03.Artifact(
artifact_id='art-1',
name='Legacy Artifact',
parts=[
types_v03.Part(
root=types_v03.FilePart(
file=types_v03.FileWithUri(
uri='https://example.com/legacy.txt',
mime_type='text/plain',
)
)
)
],
)
],
metadata={'meta_key': 'meta_val'},
)

# 2. Manually insert the legacy data into the database
# We must bypass the store's save() method because it expects Protobuf objects.
async with db_store_parameterized.async_session_maker.begin() as session:
# Pydantic model_dump(mode='json') produces exactly what would be in the legacy DB
legacy_data = legacy_task.model_dump(mode='json')

stmt = insert(db_store_parameterized.task_model).values(
id=task_id,
context_id=legacy_task.context_id,
owner=owner,
status=legacy_data['status'],
history=legacy_data['history'],
artifacts=legacy_data['artifacts'],
task_metadata=legacy_data['metadata'],
kind='task',
last_updated=None,
)
await session.execute(stmt)

# 3. Retrieve the task using the standard store.get()
# This will trigger the PydanticType/PydanticListType legacy fallback
retrieved_task = await db_store_parameterized.get(task_id, context_user)

# 4. Verify the conversion to modern Protobuf
assert retrieved_task is not None
assert retrieved_task.id == task_id
assert retrieved_task.context_id == 'legacy-ctx-1'

# Check Status & State (The most critical part: string 'working' -> enum TASK_STATE_WORKING)
assert retrieved_task.status.state == TaskState.TASK_STATE_WORKING
assert retrieved_task.status.message.message_id == 'msg-status'
assert retrieved_task.status.message.role == Role.ROLE_AGENT
assert (
retrieved_task.status.message.parts[0].text == 'Legacy status message'
)

# Check History
assert len(retrieved_task.history) == 2
assert retrieved_task.history[0].message_id == 'msg-1'
assert retrieved_task.history[0].role == Role.ROLE_USER
assert retrieved_task.history[0].parts[0].text == 'Hello legacy'

assert retrieved_task.history[1].message_id == 'msg-2'
assert retrieved_task.history[1].role == Role.ROLE_AGENT
assert (
MessageToDict(retrieved_task.history[1].parts[0].data)['legacy_key']
== 'value'
)

# Check Artifacts
assert len(retrieved_task.artifacts) == 1
assert retrieved_task.artifacts[0].artifact_id == 'art-1'
assert retrieved_task.artifacts[0].name == 'Legacy Artifact'
assert (
retrieved_task.artifacts[0].parts[0].url
== 'https://example.com/legacy.txt'
)

# Check Metadata
assert dict(retrieved_task.metadata) == {'meta_key': 'meta_val'}

await db_store_parameterized.delete(task_id, context_user)


# Ensure aiosqlite, asyncpg, and aiomysql are installed in the test environment (added to pyproject.toml).
Loading