From d2750ce2e2b87589b7a82bbba6e718b48a83d206 Mon Sep 17 00:00:00 2001 From: Maitreyee Deshmukh Date: Thu, 4 Jun 2026 16:28:11 -0700 Subject: [PATCH] fix(bedrock): preserve SSE event type in stream decoder instead of hardcoding completion --- src/anthropic/lib/bedrock/_stream_decoder.py | 13 ++- tests/lib/test_bedrock.py | 87 +++++++++++++++++++- 2 files changed, 97 insertions(+), 3 deletions(-) diff --git a/src/anthropic/lib/bedrock/_stream_decoder.py b/src/anthropic/lib/bedrock/_stream_decoder.py index 02e81a3ca..ecb3222c9 100644 --- a/src/anthropic/lib/bedrock/_stream_decoder.py +++ b/src/anthropic/lib/bedrock/_stream_decoder.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from typing import TYPE_CHECKING, Iterator, AsyncIterator from ..._utils import lru_cache @@ -37,7 +38,11 @@ def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]: for event in event_stream_buffer: message = self._parse_message_from_event(event) if message: - yield ServerSentEvent(data=message, event="completion") + try: + event_type = json.loads(message).get("type", "completion") + except Exception: + event_type = "completion" + yield ServerSentEvent(data=message, event=event_type) async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]: """Given an async iterator that yields lines, iterate over it & yield every event encountered""" @@ -49,7 +54,11 @@ async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[Ser for event in event_stream_buffer: message = self._parse_message_from_event(event) if message: - yield ServerSentEvent(data=message, event="completion") + try: + event_type = json.loads(message).get("type", "completion") + except Exception: + event_type = "completion" + yield ServerSentEvent(data=message, event=event_type) def _parse_message_from_event(self, event: EventStreamMessage) -> str | None: response_dict = event.to_response_dict() diff --git a/tests/lib/test_bedrock.py b/tests/lib/test_bedrock.py index 6e45c27f7..b5ab5f0ef 100644 --- a/tests/lib/test_bedrock.py +++ b/tests/lib/test_bedrock.py @@ -1,7 +1,7 @@ import re import typing as t import tempfile -from typing import TypedDict, cast +from typing import List, TypedDict, cast from typing_extensions import Protocol import httpx @@ -275,3 +275,88 @@ def test_region_infer_from_specified_profile( client = AnthropicBedrock() assert client.aws_region == next(profile for profile in profiles if profile["name"] == aws_profile)["region"] + + +def test_stream_decoder_preserves_sse_event_type() -> None: + import json + from unittest.mock import MagicMock, patch + + from anthropic.lib.bedrock._stream_decoder import AWSEventStreamDecoder + + message_json = json.dumps( + { + "type": "message_start", + "message": { + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-3-5-sonnet-20241022", + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 10, "output_tokens": 0}, + }, + } + ) + + mock_event = MagicMock() + mock_buffer_instance = MagicMock() + mock_buffer_instance.__iter__.return_value = iter([mock_event]) + + decoder = AWSEventStreamDecoder() + + with patch("botocore.eventstream.EventStreamBuffer", return_value=mock_buffer_instance), patch.object( + decoder, "_parse_message_from_event", return_value=message_json + ): + events = list(decoder.iter_bytes(iter([b"fake_chunk"]))) + + assert len(events) == 1 + assert events[0].event == "message_start" + assert events[0].event != "completion" + + +@pytest.mark.asyncio() +async def test_stream_decoder_preserves_sse_event_type_async() -> None: + import json + from unittest.mock import MagicMock, patch + + from anthropic.lib.bedrock._stream_decoder import AWSEventStreamDecoder + + message_json = json.dumps( + { + "type": "message_start", + "message": { + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-3-5-sonnet-20241022", + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 10, "output_tokens": 0}, + }, + } + ) + + mock_event = MagicMock() + mock_buffer_instance = MagicMock() + mock_buffer_instance.__iter__.return_value = iter([mock_event]) + + decoder = AWSEventStreamDecoder() + + async def fake_aiter(chunks: t.List[bytes]) -> t.AsyncIterator[bytes]: + for chunk in chunks: + yield chunk + + from anthropic._streaming import ServerSentEvent + + with patch("botocore.eventstream.EventStreamBuffer", return_value=mock_buffer_instance), patch.object( + decoder, "_parse_message_from_event", return_value=message_json + ): + events: List[ServerSentEvent] = [] + async for event in decoder.aiter_bytes(fake_aiter([b"fake_chunk"])): + events.append(event) + + assert len(events) == 1 + assert events[0].event == "message_start" + assert events[0].event != "completion"