diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py
index bba12e90..eb2604e7 100644
--- a/src/a2a/server/models.py
+++ b/src/a2a/server/models.py
@@ -11,10 +11,12 @@ def override(func):  # noqa: ANN001, ANN201
         return func
 
 
-from google.protobuf.json_format import MessageToDict, ParseDict
+from google.protobuf.json_format import MessageToDict, ParseDict, ParseError
 from google.protobuf.message import Message as ProtoMessage
-from pydantic import BaseModel
+from pydantic import BaseModel, ValidationError
 
+from a2a.compat.v0_3 import conversions
+from a2a.compat.v0_3 import types as types_v03
 from a2a.types.a2a_pb2 import Artifact, Message, TaskStatus
 
 
@@ -81,7 +83,19 @@ def process_result_value(
         if isinstance(self.pydantic_type, type) and issubclass(
             self.pydantic_type, ProtoMessage
         ):
-            return ParseDict(value, self.pydantic_type())  # type: ignore[return-value]
+            try:
+                return ParseDict(value, self.pydantic_type())  # type: ignore[return-value]
+            except (ParseError, ValueError):
+                # Try legacy conversion
+                legacy_map = _get_legacy_conversions()
+                if self.pydantic_type in legacy_map:
+                    legacy_type, convert_func = legacy_map[self.pydantic_type]
+                    try:
+                        legacy_instance = legacy_type.model_validate(value)
+                        return convert_func(legacy_instance)
+                    except ValidationError:
+                        pass
+                raise
         # Assume it's a Pydantic model
         return self.pydantic_type.model_validate(value)  # type: ignore[attr-defined]
 
@@ -130,7 +144,24 @@ def process_result_value(
         if isinstance(self.pydantic_type, type) and issubclass(
             self.pydantic_type, ProtoMessage
         ):
-            return [ParseDict(item, self.pydantic_type()) for item in value]  # type: ignore[misc]
+            result = []
+            legacy_map = _get_legacy_conversions()
+            legacy_info = legacy_map.get(self.pydantic_type)
+
+            for item in value:
+                try:
+                    result.append(ParseDict(item, self.pydantic_type()))
+                except (ParseError, ValueError):  # noqa: PERF203
+                    if legacy_info:
+                        legacy_type, convert_func = legacy_info
+                        try:
+                            legacy_instance = legacy_type.model_validate(item)
+                            result.append(convert_func(legacy_instance))
+                            continue
+                        except ValidationError:
+                            pass
+                    raise
+            return result  # type: ignore[return-value]
         # Assume it's a Pydantic model
         return [self.pydantic_type.model_validate(item) for item in value]  # type: ignore[attr-defined]
 
@@ -292,3 +323,15 @@ class PushNotificationConfigModel(PushNotificationConfigMixin, Base):
     """Default push notification config model with standard table name."""
 
     __tablename__ = 'push_notification_configs'
+
+
+def _get_legacy_conversions() -> dict[type, tuple[type, Any]]:
+    """Get the mapping of current types to their legacy counterparts and conversion functions."""
+    return {
+        TaskStatus: (
+            types_v03.TaskStatus,
+            conversions.to_core_task_status,
+        ),
+        Message: (types_v03.Message, conversions.to_core_message),
+        Artifact: (types_v03.Artifact, conversions.to_core_artifact),
+    }
diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py
index b71fd709..9514a07a 100644
--- a/tests/server/tasks/test_database_task_store.py
+++ b/tests/server/tasks/test_database_task_store.py
@@ -683,4 +683,139 @@ async def test_owner_resource_scoping(
     await task_store.delete('u2-task1', context_user2)
 
 
+@pytest.mark.asyncio
+async def test_get_0_3_task_detailed(
+    db_store_parameterized: DatabaseTaskStore,
+) -> None:
+    """Test retrieving a detailed legacy v0.3 task from the database.
+
+    This test simulates a database that already contains legacy v0.3 JSON data
+    (string-based enums, different field names) and verifies that the store
+    correctly converts it to the modern Protobuf-based Task model.
+    """
+    from a2a.compat.v0_3 import types as types_v03
+    from sqlalchemy import insert
+
+    task_id = 'legacy-detailed-1'
+    owner = 'legacy_user'
+    context_user = ServerCallContext(user=SampleUser(user_name=owner))
+
+    # 1. Create a detailed legacy Task using v0.3 models
+    legacy_task = types_v03.Task(
+        id=task_id,
+        context_id='legacy-ctx-1',
+        status=types_v03.TaskStatus(
+            state=types_v03.TaskState.working,
+            message=types_v03.Message(
+                message_id='msg-status',
+                role=types_v03.Role.agent,
+                parts=[
+                    types_v03.Part(
+                        root=types_v03.TextPart(text='Legacy status message')
+                    )
+                ],
+            ),
+            timestamp='2023-10-27T10:00:00Z',
+        ),
+        history=[
+            types_v03.Message(
+                message_id='msg-1',
+                role=types_v03.Role.user,
+                parts=[
+                    types_v03.Part(root=types_v03.TextPart(text='Hello legacy'))
+                ],
+            ),
+            types_v03.Message(
+                message_id='msg-2',
+                role=types_v03.Role.agent,
+                parts=[
+                    types_v03.Part(
+                        root=types_v03.DataPart(data={'legacy_key': 'value'})
+                    )
+                ],
+            ),
+        ],
+        artifacts=[
+            types_v03.Artifact(
+                artifact_id='art-1',
+                name='Legacy Artifact',
+                parts=[
+                    types_v03.Part(
+                        root=types_v03.FilePart(
+                            file=types_v03.FileWithUri(
+                                uri='https://example.com/legacy.txt',
+                                mime_type='text/plain',
+                            )
+                        )
+                    )
+                ],
+            )
+        ],
+        metadata={'meta_key': 'meta_val'},
+    )
+
+    # 2. Manually insert the legacy data into the database
+    # We must bypass the store's save() method because it expects Protobuf objects.
+    async with db_store_parameterized.async_session_maker.begin() as session:
+        # Pydantic model_dump(mode='json') produces exactly what would be in the legacy DB
+        legacy_data = legacy_task.model_dump(mode='json')
+
+        stmt = insert(db_store_parameterized.task_model).values(
+            id=task_id,
+            context_id=legacy_task.context_id,
+            owner=owner,
+            status=legacy_data['status'],
+            history=legacy_data['history'],
+            artifacts=legacy_data['artifacts'],
+            task_metadata=legacy_data['metadata'],
+            kind='task',
+            last_updated=None,
+        )
+        await session.execute(stmt)
+
+    # 3. Retrieve the task using the standard store.get()
+    # This will trigger the PydanticType/PydanticListType legacy fallback
+    retrieved_task = await db_store_parameterized.get(task_id, context_user)
+
+    # 4. Verify the conversion to modern Protobuf
+    assert retrieved_task is not None
+    assert retrieved_task.id == task_id
+    assert retrieved_task.context_id == 'legacy-ctx-1'
+
+    # Check Status & State (The most critical part: string 'working' -> enum TASK_STATE_WORKING)
+    assert retrieved_task.status.state == TaskState.TASK_STATE_WORKING
+    assert retrieved_task.status.message.message_id == 'msg-status'
+    assert retrieved_task.status.message.role == Role.ROLE_AGENT
+    assert (
+        retrieved_task.status.message.parts[0].text == 'Legacy status message'
+    )
+
+    # Check History
+    assert len(retrieved_task.history) == 2
+    assert retrieved_task.history[0].message_id == 'msg-1'
+    assert retrieved_task.history[0].role == Role.ROLE_USER
+    assert retrieved_task.history[0].parts[0].text == 'Hello legacy'
+
+    assert retrieved_task.history[1].message_id == 'msg-2'
+    assert retrieved_task.history[1].role == Role.ROLE_AGENT
+    assert (
+        MessageToDict(retrieved_task.history[1].parts[0].data)['legacy_key']
+        == 'value'
+    )
+
+    # Check Artifacts
+    assert len(retrieved_task.artifacts) == 1
+    assert retrieved_task.artifacts[0].artifact_id == 'art-1'
+    assert retrieved_task.artifacts[0].name == 'Legacy Artifact'
+    assert (
+        retrieved_task.artifacts[0].parts[0].url
+        == 'https://example.com/legacy.txt'
+    )
+
+    # Check Metadata
+    assert dict(retrieved_task.metadata) == {'meta_key': 'meta_val'}
+
+    await db_store_parameterized.delete(task_id, context_user)
+
+
 # Ensure aiosqlite, asyncpg, and aiomysql are installed in the test environment (added to pyproject.toml).