diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py index fb971b9a..fdc748ea 100644 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -24,12 +24,15 @@ _package_fastapi_installed = False +from starlette.exceptions import HTTPException as StarletteHTTPException + from a2a.server.apps.jsonrpc.jsonrpc_app import CallContextBuilder from a2a.server.apps.rest.rest_adapter import RESTAdapter from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import AgentCard from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH +from a2a.utils.error_handlers import rest_error_handler logger = logging.getLogger(__name__) @@ -112,10 +115,28 @@ def build( f'{rpc_url}{route[0]}', callback, methods=[route[1]] ) + # Catch exceptions thrown by card modifiers. @router.get(f'{rpc_url}{agent_card_url}') + @rest_error_handler async def get_agent_card(request: Request) -> Response: card = await self._adapter.handle_get_agent_card(request) return JSONResponse(card) app.include_router(router) + + @app.exception_handler(StarletteHTTPException) + async def http_exception_handler( + request: Request, exc: StarletteHTTPException + ) -> JSONResponse: + return JSONResponse( + status_code=exc.status_code, + content={ + 'type': 'about:blank', + 'title': 'HTTP Error', + 'status': exc.status_code, + 'detail': exc.detail, + }, + media_type='application/problem+json', + ) + return app diff --git a/src/a2a/utils/error_handlers.py b/src/a2a/utils/error_handlers.py index bd30595a..8b401460 100644 --- a/src/a2a/utils/error_handlers.py +++ b/src/a2a/utils/error_handlers.py @@ -62,19 +62,86 @@ JSONParseError: 400, InvalidRequestError: 400, MethodNotFoundError: 404, - InvalidParamsError: 422, + InvalidParamsError: 400, InternalError: 500, JSONRPCInternalError: 500, TaskNotFoundError: 404, TaskNotCancelableError: 409, - PushNotificationNotSupportedError: 501, - UnsupportedOperationError: 501, + PushNotificationNotSupportedError: 400, + UnsupportedOperationError: 400, ContentTypeNotSupportedError: 415, InvalidAgentResponseError: 502, - AuthenticatedExtendedCardNotConfiguredError: 404, + AuthenticatedExtendedCardNotConfiguredError: 400, +} + +A2AErrorToTypeURI: dict[_A2AErrorType, str] = { + TaskNotFoundError: 'https://a2a-protocol.org/errors/task-not-found', + TaskNotCancelableError: 'https://a2a-protocol.org/errors/task-not-cancelable', + PushNotificationNotSupportedError: 'https://a2a-protocol.org/errors/push-notification-not-supported', + UnsupportedOperationError: 'https://a2a-protocol.org/errors/unsupported-operation', + ContentTypeNotSupportedError: 'https://a2a-protocol.org/errors/content-type-not-supported', + InvalidAgentResponseError: 'https://a2a-protocol.org/errors/invalid-agent-response', + AuthenticatedExtendedCardNotConfiguredError: 'https://a2a-protocol.org/errors/extended-agent-card-not-configured', +} + +A2AErrorToTitle: dict[_A2AErrorType, str] = { + JSONRPCError: 'JSON RPC Error', + JSONParseError: 'JSON Parse Error', + InvalidRequestError: 'Invalid Request Error', + MethodNotFoundError: 'Method Not Found Error', + InvalidParamsError: 'Invalid Params Error', + InternalError: 'Internal Error', + JSONRPCInternalError: 'Internal Error', + TaskNotFoundError: 'Task Not Found', + TaskNotCancelableError: 'Task Not Cancelable', + PushNotificationNotSupportedError: 'Push Notification Not Supported', + UnsupportedOperationError: 'Unsupported Operation', + ContentTypeNotSupportedError: 'Content Type Not Supported', + InvalidAgentResponseError: 'Invalid Agent Response', + AuthenticatedExtendedCardNotConfiguredError: 'Extended Agent Card Not Configured', } +def _build_problem_details_response(error: A2AError) -> JSONResponse: + """Helper to convert exceptions to RFC 9457 Problem Details responses.""" + error_type = cast('_A2AErrorType', type(error)) + http_code = A2AErrorToHttpStatus.get(error_type, 500) + type_uri = A2AErrorToTypeURI.get(error_type, 'about:blank') + title = A2AErrorToTitle.get(error_type, error.__class__.__name__) + + log_level = ( + logging.ERROR if isinstance(error, InternalError) else logging.WARNING + ) + logger.log( + log_level, + "Request error: Code=%s, Message='%s'%s", + getattr(error, 'code', 'N/A'), + getattr(error, 'message', str(error)), + ', Data=' + str(getattr(error, 'data', '')) + if getattr(error, 'data', None) + else '', + ) + + payload = { + 'type': type_uri, + 'title': title, + 'status': http_code, + 'detail': getattr(error, 'message', str(error)), + } + + data = getattr(error, 'data', None) + if isinstance(data, dict): + for key, value in data.items(): + if key not in payload: + payload[key] = value + + return JSONResponse( + content=payload, + status_code=http_code, + media_type='application/problem+json', + ) + + def rest_error_handler( func: Callable[..., Awaitable[Response]], ) -> Callable[..., Awaitable[Response]]: @@ -85,37 +152,18 @@ async def wrapper(*args: Any, **kwargs: Any) -> Response: try: return await func(*args, **kwargs) except A2AError as error: - http_code = A2AErrorToHttpStatus.get( - cast('_A2AErrorType', type(error)), 500 - ) - - log_level = ( - logging.ERROR - if isinstance(error, InternalError) - else logging.WARNING - ) - logger.log( - log_level, - "Request error: Code=%s, Message='%s'%s", - getattr(error, 'code', 'N/A'), - getattr(error, 'message', str(error)), - ', Data=' + str(getattr(error, 'data', '')) - if getattr(error, 'data', None) - else '', - ) - # TODO(#722): Standardize error response format. - return JSONResponse( - content={ - 'message': getattr(error, 'message', str(error)), - 'type': type(error).__name__, - }, - status_code=http_code, - ) + return _build_problem_details_response(error) except Exception: logger.exception('Unknown error occurred') return JSONResponse( - content={'message': 'unknown exception', 'type': 'Exception'}, + content={ + 'type': 'about:blank', + 'title': 'Internal Error', + 'status': 500, + 'detail': 'Unknown exception', + }, status_code=500, + media_type='application/problem+json', ) return wrapper diff --git a/tests/server/apps/rest/test_rest_fastapi_app.py b/tests/server/apps/rest/test_rest_fastapi_app.py index a58936b3..f0fc6238 100644 --- a/tests/server/apps/rest/test_rest_fastapi_app.py +++ b/tests/server/apps/rest/test_rest_fastapi_app.py @@ -1,7 +1,7 @@ import logging from typing import Any -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest @@ -23,6 +23,7 @@ TaskState, TaskStatus, ) +from a2a.utils.errors import InternalError logger = logging.getLogger(__name__) @@ -396,5 +397,54 @@ async def test_send_message_rejected_task( assert expected_response == actual_response +@pytest.mark.anyio +async def test_global_http_exception_handler_returns_problem_details( + client: AsyncClient, +) -> None: + """Test that a standard FastAPI 404 is transformed into RFC 9457 format.""" + + # Send a request to an endpoint that does not exist + response = await client.get('/non-existent-route') + + # Verify it returns a 404, but in the new RFC 9457 format + assert response.status_code == 404 + assert response.headers.get('content-type') == 'application/problem+json' + + data = response.json() + assert data['type'] == 'about:blank' + assert data['title'] == 'HTTP Error' + assert data['status'] == 404 + assert 'Not Found' in data['detail'] + + +@pytest.mark.anyio +async def test_get_agent_card_error_handling( + client: AsyncClient, +) -> None: + """Test that the agent card endpoint properly catches and formats A2A errors.""" + + # Mock the REST adapter to simulate an internal failure when fetching the card + with patch( + 'a2a.server.apps.rest.rest_adapter.RESTAdapter.handle_get_agent_card', + side_effect=InternalError( + message='Failed to load customized agent card' + ), + ): + # In the fixtures, the agent card URL is set to /well-known/agent.json + response = await client.get('/well-known/agent.json') + + # Verify the error was caught and serialized cleanly + assert response.status_code == 500 + assert ( + response.headers.get('content-type') == 'application/problem+json' + ) + + data = response.json() + assert data['type'] == 'about:blank' + assert data['title'] == 'Internal Error' + assert data['status'] == 500 + assert data['detail'] == 'Failed to load customized agent card' + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/utils/test_error_handlers.py b/tests/utils/test_error_handlers.py index e20c402a..73d6d7f2 100644 --- a/tests/utils/test_error_handlers.py +++ b/tests/utils/test_error_handlers.py @@ -1,7 +1,6 @@ """Tests for a2a.utils.error_handlers module.""" from unittest.mock import patch - import pytest from a2a.types import ( @@ -14,15 +13,18 @@ ) from a2a.utils.error_handlers import ( A2AErrorToHttpStatus, + A2AErrorToTitle, + A2AErrorToTypeURI, rest_error_handler, rest_stream_error_handler, ) class MockJSONResponse: - def __init__(self, content, status_code): + def __init__(self, content, status_code, media_type=None): self.content = content self.status_code = status_code + self.media_type = media_type @pytest.mark.asyncio @@ -39,9 +41,39 @@ async def failing_func(): assert isinstance(result, MockJSONResponse) assert result.status_code == 400 + assert result.media_type == 'application/problem+json' assert result.content == { - 'message': 'Bad request', - 'type': 'InvalidRequestError', + 'type': 'about:blank', + 'title': 'Invalid Request Error', + 'status': 400, + 'detail': 'Bad request', + } + + +@pytest.mark.asyncio +async def test_rest_error_handler_with_data_extensions(): + """Test rest_error_handler maps A2AError.data to extension fields.""" + error = TaskNotFoundError(message='Task not found') + # Dynamically attach data since __init__ no longer accepts it + error.data = {'taskId': '123', 'retry': False} + + @rest_error_handler + async def failing_func(): + raise error + + with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse): + result = await failing_func() + + assert isinstance(result, MockJSONResponse) + assert result.status_code == 404 + assert result.media_type == 'application/problem+json' + assert result.content == { + 'type': 'https://a2a-protocol.org/errors/task-not-found', + 'title': 'Task Not Found', + 'status': 404, + 'detail': 'Task not found', + 'taskId': '123', + 'retry': False, } @@ -58,9 +90,12 @@ async def failing_func(): assert isinstance(result, MockJSONResponse) assert result.status_code == 500 + assert result.media_type == 'application/problem+json' assert result.content == { - 'message': 'unknown exception', - 'type': 'Exception', + 'type': 'about:blank', + 'title': 'Internal Error', + 'status': 500, + 'detail': 'Unknown exception', } @@ -91,9 +126,20 @@ async def failing_stream(): await failing_stream() -def test_a2a_error_to_http_status_mapping(): - """Test A2AErrorToHttpStatus mapping.""" +def test_a2a_error_mappings(): + """Test A2A error mappings.""" + # HTTP Status assert A2AErrorToHttpStatus[InvalidRequestError] == 400 assert A2AErrorToHttpStatus[MethodNotFoundError] == 404 assert A2AErrorToHttpStatus[TaskNotFoundError] == 404 assert A2AErrorToHttpStatus[InternalError] == 500 + + # Type URI + assert ( + A2AErrorToTypeURI[TaskNotFoundError] + == 'https://a2a-protocol.org/errors/task-not-found' + ) + + # Title + assert A2AErrorToTitle[TaskNotFoundError] == 'Task Not Found' + assert A2AErrorToTitle[InvalidRequestError] == 'Invalid Request Error'