Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/celeste/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from celeste.mime_types import ApplicationMimeType
from celeste.models import Model
from celeste.parameters import ParameterMapper, Parameters
from celeste.streaming import Stream
from celeste.streaming import Stream, enrich_stream_errors
from celeste.types import RawUsage


Expand Down Expand Up @@ -250,6 +250,7 @@ def _stream(
extra_headers=extra_headers,
**parameters,
)
sse_iterator = enrich_stream_errors(sse_iterator, self._handle_error_response)
return stream_class(
sse_iterator,
transform_output=self._transform_output,
Expand Down
8 changes: 6 additions & 2 deletions src/celeste/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@ async def stream_post(
headers=headers,
timeout=timeout,
) as event_source:
event_source.response.raise_for_status()
if not event_source.response.is_success:
await event_source.response.aread()
event_source.response.raise_for_status()
async for sse in event_source.aiter_sse():
try:
yield json.loads(sse.data)
Expand Down Expand Up @@ -221,7 +223,9 @@ async def stream_post_ndjson(
headers=headers,
timeout=timeout,
) as response:
response.raise_for_status()
if not response.is_success:
await response.aread()
response.raise_for_status()
async for line in response.aiter_lines():
if line:
yield json.loads(line)
Expand Down
16 changes: 15 additions & 1 deletion src/celeste/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from types import TracebackType
from typing import Any, ClassVar, Self, Unpack

import httpx
from anyio.from_thread import BlockingPortal, start_blocking_portal

from celeste.exceptions import StreamEventError, StreamNotExhaustedError
Expand All @@ -15,6 +16,19 @@
from celeste.types import RawUsage


async def enrich_stream_errors(
iterator: AsyncIterator[dict[str, Any]],
error_handler: Callable[[httpx.Response], None],
) -> AsyncIterator[dict[str, Any]]:
"""Wrap stream iterator to enrich HTTP errors with provider-specific messages."""
try:
async for event in iterator:
yield event
except httpx.HTTPStatusError as e:
error_handler(e.response)
raise # Unreachable — error_handler always raises for error responses


class Stream[Out: Output, Params: Parameters, Chunk: ChunkBase](ABC):
"""Async iterator wrapper providing final Output access after stream exhaustion.

Expand Down Expand Up @@ -332,4 +346,4 @@ async def aclose(self) -> None:
await self._sse_iterator.aclose()


__all__ = ["Stream"]
__all__ = ["Stream", "enrich_stream_errors"]
75 changes: 75 additions & 0 deletions tests/unit_tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,8 +823,83 @@ async def test_stream_post_passes_parameters_correctly(
timeout=timeout,
)

async def test_stream_post_raises_http_error_with_readable_body(
self, mock_httpx_client: AsyncMock
) -> None:
"""stream_post reads response body before raising on HTTP errors."""
# Arrange
http_client = HTTPClient()
mock_response = httpx.Response(
401,
content=b'{"error": {"message": "Invalid API Key"}}',
request=httpx.Request("POST", "https://api.example.com/stream"),
)

mock_source = MagicMock()
mock_source.response = mock_response
mock_source.__aenter__ = AsyncMock(return_value=mock_source)
mock_source.__aexit__ = AsyncMock(return_value=False)

# Act & Assert
with (
patch("celeste.http.httpx.AsyncClient", return_value=mock_httpx_client),
patch("celeste.http.aconnect_sse", return_value=mock_source),
pytest.raises(httpx.HTTPStatusError) as exc_info,
):
async for _ in http_client.stream_post(
url="https://api.example.com/stream",
headers={"Authorization": "Bearer bad-key"},
json_body={"prompt": "test"},
):
pass

# Body should be readable for downstream enrichment
assert exc_info.value.response.json()["error"]["message"] == "Invalid API Key"

async def test_stream_post_ndjson_raises_http_error_with_readable_body(
self, mock_httpx_client: AsyncMock
) -> None:
"""stream_post_ndjson reads response body before raising on HTTP errors."""
# Arrange
http_client = HTTPClient()
error_body = b'{"error": {"message": "Forbidden"}}'
mock_response = httpx.Response(
403,
content=error_body,
request=httpx.Request("POST", "https://api.example.com/stream"),
)
mock_httpx_client.stream = MagicMock(return_value=_async_context(mock_response))

# Act & Assert
with (
patch("celeste.http.httpx.AsyncClient", return_value=mock_httpx_client),
pytest.raises(httpx.HTTPStatusError) as exc_info,
):
async for _ in http_client.stream_post_ndjson(
url="https://api.example.com/stream",
headers={"Authorization": "Bearer bad-key"},
json_body={"prompt": "test"},
):
pass

# Body should be readable for downstream enrichment
assert exc_info.value.response.json()["error"]["message"] == "Forbidden"

@staticmethod
async def _async_iter(items: list) -> AsyncIterator:
"""Helper to create async iterator from list."""
for item in items:
yield item


class _async_context:
"""Async context manager wrapping a response for client.stream() mocking."""

def __init__(self, response: httpx.Response) -> None:
self._response = response

async def __aenter__(self) -> httpx.Response:
return self._response

async def __aexit__(self, *args: object) -> None:
pass
86 changes: 85 additions & 1 deletion tests/unit_tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from typing import Any, ClassVar, Unpack
from unittest.mock import AsyncMock

import httpx
import pytest
from pydantic import Field

from celeste.exceptions import StreamEventError, StreamNotExhaustedError
from celeste.io import Chunk, FinishReason, Output, Usage
from celeste.parameters import Parameters
from celeste.streaming import Stream
from celeste.streaming import Stream, enrich_stream_errors


class ConcreteOutput(Output[str]):
Expand Down Expand Up @@ -876,3 +877,86 @@ async def test_error_provides_full_event_data(self) -> None:
async for _ in stream:
pass
assert exc_info.value.event_data == event


class TestEnrichStreamErrors:
"""Test enrich_stream_errors wraps streaming HTTP errors with provider messages."""

async def test_enriches_http_error_with_provider_message(self) -> None:
"""HTTP errors from stream iterators are enriched via error_handler."""

async def _failing_stream() -> AsyncIterator[dict[str, Any]]:
response = httpx.Response(
401,
content=b'{"error": {"message": "Invalid API Key"}}',
request=httpx.Request("POST", "https://api.example.com/v1/chat"),
)
raise httpx.HTTPStatusError(
"Client error '401 Unauthorized'",
request=response.request,
response=response,
)
yield # type: ignore[misc] # Make this an async generator

def _handle_error(response: httpx.Response) -> None:
error_msg = response.json()["error"]["message"]
raise httpx.HTTPStatusError(
f"TestProvider API error: {error_msg}",
request=response.request,
response=response,
)

enriched = enrich_stream_errors(_failing_stream(), _handle_error)

with pytest.raises(
httpx.HTTPStatusError, match="TestProvider API error: Invalid API Key"
):
async for _ in enriched:
pass

async def test_passes_through_events_on_success(self) -> None:
"""Successful streams pass through events unmodified."""

async def _ok_stream() -> AsyncIterator[dict[str, Any]]:
yield {"delta": "Hello"}
yield {"delta": " world"}

enriched = enrich_stream_errors(_ok_stream(), lambda r: None)
events = [event async for event in enriched]

assert events == [{"delta": "Hello"}, {"delta": " world"}]

async def test_enriches_error_with_non_json_body(self) -> None:
"""Error handler receives response even when body isn't valid JSON."""

async def _failing_stream() -> AsyncIterator[dict[str, Any]]:
response = httpx.Response(
500,
content=b"Internal Server Error",
request=httpx.Request("POST", "https://api.example.com/v1/chat"),
)
raise httpx.HTTPStatusError(
"Server error '500 Internal Server Error'",
request=response.request,
response=response,
)
yield # type: ignore[misc] # Make this an async generator

def _handle_error(response: httpx.Response) -> None:
try:
error_msg = response.json()["error"]["message"]
except Exception:
error_msg = response.text or f"HTTP {response.status_code}"
raise httpx.HTTPStatusError(
f"TestProvider API error: {error_msg}",
request=response.request,
response=response,
)

enriched = enrich_stream_errors(_failing_stream(), _handle_error)

with pytest.raises(
httpx.HTTPStatusError, match="TestProvider API error: Internal Server Error"
):
async for _ in enriched:
pass
Loading