diff --git a/src/celeste/client.py b/src/celeste/client.py index e79ee9b..236e17d 100644 --- a/src/celeste/client.py +++ b/src/celeste/client.py @@ -16,7 +16,7 @@ from celeste.mime_types import ApplicationMimeType from celeste.models import Model from celeste.parameters import ParameterMapper, Parameters -from celeste.streaming import Stream +from celeste.streaming import Stream, enrich_stream_errors from celeste.types import RawUsage @@ -250,6 +250,7 @@ def _stream( extra_headers=extra_headers, **parameters, ) + sse_iterator = enrich_stream_errors(sse_iterator, self._handle_error_response) return stream_class( sse_iterator, transform_output=self._transform_output, diff --git a/src/celeste/http.py b/src/celeste/http.py index 651130d..6054e80 100644 --- a/src/celeste/http.py +++ b/src/celeste/http.py @@ -185,7 +185,9 @@ async def stream_post( headers=headers, timeout=timeout, ) as event_source: - event_source.response.raise_for_status() + if not event_source.response.is_success: + await event_source.response.aread() + event_source.response.raise_for_status() async for sse in event_source.aiter_sse(): try: yield json.loads(sse.data) @@ -221,7 +223,9 @@ async def stream_post_ndjson( headers=headers, timeout=timeout, ) as response: - response.raise_for_status() + if not response.is_success: + await response.aread() + response.raise_for_status() async for line in response.aiter_lines(): if line: yield json.loads(line) diff --git a/src/celeste/streaming.py b/src/celeste/streaming.py index 95068a5..0ab6b43 100644 --- a/src/celeste/streaming.py +++ b/src/celeste/streaming.py @@ -6,6 +6,7 @@ from types import TracebackType from typing import Any, ClassVar, Self, Unpack +import httpx from anyio.from_thread import BlockingPortal, start_blocking_portal from celeste.exceptions import StreamEventError, StreamNotExhaustedError @@ -15,6 +16,19 @@ from celeste.types import RawUsage +async def enrich_stream_errors( + iterator: AsyncIterator[dict[str, Any]], + error_handler: Callable[[httpx.Response], None], +) -> AsyncIterator[dict[str, Any]]: + """Wrap stream iterator to enrich HTTP errors with provider-specific messages.""" + try: + async for event in iterator: + yield event + except httpx.HTTPStatusError as e: + error_handler(e.response) + raise # Unreachable — error_handler always raises for error responses + + class Stream[Out: Output, Params: Parameters, Chunk: ChunkBase](ABC): """Async iterator wrapper providing final Output access after stream exhaustion. @@ -332,4 +346,4 @@ async def aclose(self) -> None: await self._sse_iterator.aclose() -__all__ = ["Stream"] +__all__ = ["Stream", "enrich_stream_errors"] diff --git a/tests/unit_tests/test_http.py b/tests/unit_tests/test_http.py index 11f22f3..264ea95 100644 --- a/tests/unit_tests/test_http.py +++ b/tests/unit_tests/test_http.py @@ -823,8 +823,83 @@ async def test_stream_post_passes_parameters_correctly( timeout=timeout, ) + async def test_stream_post_raises_http_error_with_readable_body( + self, mock_httpx_client: AsyncMock + ) -> None: + """stream_post reads response body before raising on HTTP errors.""" + # Arrange + http_client = HTTPClient() + mock_response = httpx.Response( + 401, + content=b'{"error": {"message": "Invalid API Key"}}', + request=httpx.Request("POST", "https://api.example.com/stream"), + ) + + mock_source = MagicMock() + mock_source.response = mock_response + mock_source.__aenter__ = AsyncMock(return_value=mock_source) + mock_source.__aexit__ = AsyncMock(return_value=False) + + # Act & Assert + with ( + patch("celeste.http.httpx.AsyncClient", return_value=mock_httpx_client), + patch("celeste.http.aconnect_sse", return_value=mock_source), + pytest.raises(httpx.HTTPStatusError) as exc_info, + ): + async for _ in http_client.stream_post( + url="https://api.example.com/stream", + headers={"Authorization": "Bearer bad-key"}, + json_body={"prompt": "test"}, + ): + pass + + # Body should be readable for downstream enrichment + assert exc_info.value.response.json()["error"]["message"] == "Invalid API Key" + + async def test_stream_post_ndjson_raises_http_error_with_readable_body( + self, mock_httpx_client: AsyncMock + ) -> None: + """stream_post_ndjson reads response body before raising on HTTP errors.""" + # Arrange + http_client = HTTPClient() + error_body = b'{"error": {"message": "Forbidden"}}' + mock_response = httpx.Response( + 403, + content=error_body, + request=httpx.Request("POST", "https://api.example.com/stream"), + ) + mock_httpx_client.stream = MagicMock(return_value=_async_context(mock_response)) + + # Act & Assert + with ( + patch("celeste.http.httpx.AsyncClient", return_value=mock_httpx_client), + pytest.raises(httpx.HTTPStatusError) as exc_info, + ): + async for _ in http_client.stream_post_ndjson( + url="https://api.example.com/stream", + headers={"Authorization": "Bearer bad-key"}, + json_body={"prompt": "test"}, + ): + pass + + # Body should be readable for downstream enrichment + assert exc_info.value.response.json()["error"]["message"] == "Forbidden" + @staticmethod async def _async_iter(items: list) -> AsyncIterator: """Helper to create async iterator from list.""" for item in items: yield item + + +class _async_context: + """Async context manager wrapping a response for client.stream() mocking.""" + + def __init__(self, response: httpx.Response) -> None: + self._response = response + + async def __aenter__(self) -> httpx.Response: + return self._response + + async def __aexit__(self, *args: object) -> None: + pass diff --git a/tests/unit_tests/test_streaming.py b/tests/unit_tests/test_streaming.py index 59a9b3f..f94295b 100644 --- a/tests/unit_tests/test_streaming.py +++ b/tests/unit_tests/test_streaming.py @@ -4,13 +4,14 @@ from typing import Any, ClassVar, Unpack from unittest.mock import AsyncMock +import httpx import pytest from pydantic import Field from celeste.exceptions import StreamEventError, StreamNotExhaustedError from celeste.io import Chunk, FinishReason, Output, Usage from celeste.parameters import Parameters -from celeste.streaming import Stream +from celeste.streaming import Stream, enrich_stream_errors class ConcreteOutput(Output[str]): @@ -876,3 +877,86 @@ async def test_error_provides_full_event_data(self) -> None: async for _ in stream: pass assert exc_info.value.event_data == event + + +class TestEnrichStreamErrors: + """Test enrich_stream_errors wraps streaming HTTP errors with provider messages.""" + + async def test_enriches_http_error_with_provider_message(self) -> None: + """HTTP errors from stream iterators are enriched via error_handler.""" + + async def _failing_stream() -> AsyncIterator[dict[str, Any]]: + response = httpx.Response( + 401, + content=b'{"error": {"message": "Invalid API Key"}}', + request=httpx.Request("POST", "https://api.example.com/v1/chat"), + ) + raise httpx.HTTPStatusError( + "Client error '401 Unauthorized'", + request=response.request, + response=response, + ) + yield # type: ignore[misc] # Make this an async generator + + def _handle_error(response: httpx.Response) -> None: + error_msg = response.json()["error"]["message"] + raise httpx.HTTPStatusError( + f"TestProvider API error: {error_msg}", + request=response.request, + response=response, + ) + + enriched = enrich_stream_errors(_failing_stream(), _handle_error) + + with pytest.raises( + httpx.HTTPStatusError, match="TestProvider API error: Invalid API Key" + ): + async for _ in enriched: + pass + + async def test_passes_through_events_on_success(self) -> None: + """Successful streams pass through events unmodified.""" + + async def _ok_stream() -> AsyncIterator[dict[str, Any]]: + yield {"delta": "Hello"} + yield {"delta": " world"} + + enriched = enrich_stream_errors(_ok_stream(), lambda r: None) + events = [event async for event in enriched] + + assert events == [{"delta": "Hello"}, {"delta": " world"}] + + async def test_enriches_error_with_non_json_body(self) -> None: + """Error handler receives response even when body isn't valid JSON.""" + + async def _failing_stream() -> AsyncIterator[dict[str, Any]]: + response = httpx.Response( + 500, + content=b"Internal Server Error", + request=httpx.Request("POST", "https://api.example.com/v1/chat"), + ) + raise httpx.HTTPStatusError( + "Server error '500 Internal Server Error'", + request=response.request, + response=response, + ) + yield # type: ignore[misc] # Make this an async generator + + def _handle_error(response: httpx.Response) -> None: + try: + error_msg = response.json()["error"]["message"] + except Exception: + error_msg = response.text or f"HTTP {response.status_code}" + raise httpx.HTTPStatusError( + f"TestProvider API error: {error_msg}", + request=response.request, + response=response, + ) + + enriched = enrich_stream_errors(_failing_stream(), _handle_error) + + with pytest.raises( + httpx.HTTPStatusError, match="TestProvider API error: Internal Server Error" + ): + async for _ in enriched: + pass