Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/anthropic/lib/bedrock/_stream_decoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Iterator, AsyncIterator

from ..._utils import lru_cache
Expand Down Expand Up @@ -37,7 +38,11 @@ def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
yield ServerSentEvent(data=message, event="completion")
try:
event_type = json.loads(message).get("type", "completion")
except Exception:
event_type = "completion"
yield ServerSentEvent(data=message, event=event_type)

async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
Expand All @@ -49,7 +54,11 @@ async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[Ser
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
yield ServerSentEvent(data=message, event="completion")
try:
event_type = json.loads(message).get("type", "completion")
except Exception:
event_type = "completion"
yield ServerSentEvent(data=message, event=event_type)

def _parse_message_from_event(self, event: EventStreamMessage) -> str | None:
response_dict = event.to_response_dict()
Expand Down
87 changes: 86 additions & 1 deletion tests/lib/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import typing as t
import tempfile
from typing import TypedDict, cast
from typing import List, TypedDict, cast
from typing_extensions import Protocol

import httpx
Expand Down Expand Up @@ -275,3 +275,88 @@ def test_region_infer_from_specified_profile(
client = AnthropicBedrock()

assert client.aws_region == next(profile for profile in profiles if profile["name"] == aws_profile)["region"]


def test_stream_decoder_preserves_sse_event_type() -> None:
import json
from unittest.mock import MagicMock, patch

from anthropic.lib.bedrock._stream_decoder import AWSEventStreamDecoder

message_json = json.dumps(
{
"type": "message_start",
"message": {
"id": "msg_1",
"type": "message",
"role": "assistant",
"content": [],
"model": "claude-3-5-sonnet-20241022",
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 10, "output_tokens": 0},
},
}
)

mock_event = MagicMock()
mock_buffer_instance = MagicMock()
mock_buffer_instance.__iter__.return_value = iter([mock_event])

decoder = AWSEventStreamDecoder()

with patch("botocore.eventstream.EventStreamBuffer", return_value=mock_buffer_instance), patch.object(
decoder, "_parse_message_from_event", return_value=message_json
):
events = list(decoder.iter_bytes(iter([b"fake_chunk"])))

assert len(events) == 1
assert events[0].event == "message_start"
assert events[0].event != "completion"


@pytest.mark.asyncio()
async def test_stream_decoder_preserves_sse_event_type_async() -> None:
import json
from unittest.mock import MagicMock, patch

from anthropic.lib.bedrock._stream_decoder import AWSEventStreamDecoder

message_json = json.dumps(
{
"type": "message_start",
"message": {
"id": "msg_1",
"type": "message",
"role": "assistant",
"content": [],
"model": "claude-3-5-sonnet-20241022",
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 10, "output_tokens": 0},
},
}
)

mock_event = MagicMock()
mock_buffer_instance = MagicMock()
mock_buffer_instance.__iter__.return_value = iter([mock_event])

decoder = AWSEventStreamDecoder()

async def fake_aiter(chunks: t.List[bytes]) -> t.AsyncIterator[bytes]:
for chunk in chunks:
yield chunk

from anthropic._streaming import ServerSentEvent

with patch("botocore.eventstream.EventStreamBuffer", return_value=mock_buffer_instance), patch.object(
decoder, "_parse_message_from_event", return_value=message_json
):
events: List[ServerSentEvent] = []
async for event in decoder.aiter_bytes(fake_aiter([b"fake_chunk"])):
events.append(event)

assert len(events) == 1
assert events[0].event == "message_start"
assert events[0].event != "completion"