diff --git a/src/openai/_send_queue.py b/src/openai/_send_queue.py index b35d0fbcba..998b14850a 100644 --- a/src/openai/_send_queue.py +++ b/src/openai/_send_queue.py @@ -11,31 +11,32 @@ class SendQueue: """Bounded byte-size queue for outgoing WebSocket messages. - Messages are stored as pre-serialized strings. The queue enforces a - maximum byte budget so that unbounded buffering cannot occur during - reconnection windows. + Messages are stored as either ``str`` (text frames) or ``bytes`` (binary + frames), preserving the original frame type so that binary payloads are + not corrupted on replay. The queue enforces a maximum byte budget so that + unbounded buffering cannot occur during reconnection windows. """ def __init__(self, max_bytes: int = 1_048_576) -> None: - self._queue: list[tuple[str, int]] = [] # (data, byte_length) + self._queue: list[tuple[bytes | str, int]] = [] # (data, byte_length) self._bytes: int = 0 self._max_bytes = max_bytes self._lock = threading.Lock() - def enqueue(self, data: str) -> None: + def enqueue(self, data: bytes | str) -> None: """Append *data* to the queue. Raises :class:`WebSocketQueueFullError` if the message would exceed the byte-size limit. """ - byte_length = len(data.encode("utf-8")) + byte_length = len(data) if isinstance(data, bytes) else len(data.encode("utf-8")) with self._lock: if self._bytes + byte_length > self._max_bytes: raise WebSocketQueueFullError("send queue is full, message discarded") self._queue.append((data, byte_length)) self._bytes += byte_length - def flush_sync(self, send: typing.Callable[[str], object]) -> None: + def flush_sync(self, send: typing.Callable[[bytes | str], object]) -> None: """Send every queued message via *send*. If *send* raises, the failing message and all subsequent messages @@ -56,7 +57,7 @@ def flush_sync(self, send: typing.Callable[[str], object]) -> None: self._bytes = sum(bl for _, bl in self._queue) raise - async def flush_async(self, send: typing.Callable[[str], typing.Awaitable[object]]) -> None: + async def flush_async(self, send: typing.Callable[[bytes | str], typing.Awaitable[object]]) -> None: """Async variant of :meth:`flush_sync`.""" with self._lock: pending = list(self._queue) @@ -73,7 +74,7 @@ async def flush_async(self, send: typing.Callable[[str], typing.Awaitable[object self._bytes = sum(bl for _, bl in self._queue) raise - def drain(self) -> list[str]: + def drain(self) -> list[bytes | str]: """Remove and return all queued messages.""" with self._lock: items = [data for data, _ in self._queue] diff --git a/src/openai/lib/streaming/responses/_responses.py b/src/openai/lib/streaming/responses/_responses.py index 6975a9260d..41d855e90a 100644 --- a/src/openai/lib/streaming/responses/_responses.py +++ b/src/openai/lib/streaming/responses/_responses.py @@ -25,6 +25,7 @@ ParsedResponseOutputMessage, ParsedResponseFunctionToolCall, ) +from ....types.responses.response_reasoning_item import Content as ReasoningItemContent class ResponseStream(Generic[TextFormatT]): @@ -346,6 +347,10 @@ def accumulate_event(self, event: RawResponseStreamEvent) -> ParsedResponseSnaps output.content.append( construct_type_unchecked(type_=cast(Any, ParsedContent), value=event.part.to_dict()) ) + elif output.type == "reasoning" and event.part.type == "reasoning_text": + if output.content is None: + output.content = [] + output.content.append(construct_type_unchecked(type_=ReasoningItemContent, value=event.part.to_dict())) elif event.type == "response.output_text.delta": output = snapshot.output[event.output_index] if output.type == "message": @@ -356,6 +361,12 @@ def accumulate_event(self, event: RawResponseStreamEvent) -> ParsedResponseSnaps output = snapshot.output[event.output_index] if output.type == "function_call": output.arguments += event.delta + elif event.type == "response.reasoning_text.delta": + output = snapshot.output[event.output_index] + if output.type == "reasoning" and output.content is not None: + content = output.content[event.content_index] + assert content.type == "reasoning_text" + content.text += event.delta elif event.type == "response.completed": self._completed_response = parse_response( text_format=self._text_format, diff --git a/src/openai/resources/realtime/realtime.py b/src/openai/resources/realtime/realtime.py index e4c5bd8163..a29248e3c8 100644 --- a/src/openai/resources/realtime/realtime.py +++ b/src/openai/resources/realtime/realtime.py @@ -359,10 +359,13 @@ async def send(self, event: RealtimeClientEvent | RealtimeClientEventParam) -> N async def send_raw(self, data: bytes | str) -> None: if self._is_reconnecting: - raw = data if isinstance(data, str) else data.decode("utf-8") - self._send_queue.enqueue(raw) + self._send_queue.enqueue(data) return - await self._connection.send(data) + try: + await self._connection.send(data) + except Exception: + self._send_queue.enqueue(data) + raise async def close(self, *, code: int = 1000, reason: str = "") -> None: self._intentionally_closed = True @@ -839,10 +842,13 @@ def send(self, event: RealtimeClientEvent | RealtimeClientEventParam) -> None: def send_raw(self, data: bytes | str) -> None: if self._is_reconnecting: - raw = data if isinstance(data, str) else data.decode("utf-8") - self._send_queue.enqueue(raw) + self._send_queue.enqueue(data) return - self._connection.send(data) + try: + self._connection.send(data) + except Exception: + self._send_queue.enqueue(data) + raise def close(self, *, code: int = 1000, reason: str = "") -> None: self._intentionally_closed = True diff --git a/src/openai/resources/responses/responses.py b/src/openai/resources/responses/responses.py index 5019d7e831..06f3d8e28a 100644 --- a/src/openai/resources/responses/responses.py +++ b/src/openai/resources/responses/responses.py @@ -3852,10 +3852,13 @@ async def send(self, event: ResponsesClientEvent | ResponsesClientEventParam) -> async def send_raw(self, data: bytes | str) -> None: if self._is_reconnecting: - raw = data if isinstance(data, str) else data.decode("utf-8") - self._send_queue.enqueue(raw) + self._send_queue.enqueue(data) return - await self._connection.send(data) + try: + await self._connection.send(data) + except Exception: + self._send_queue.enqueue(data) + raise async def close(self, *, code: int = 1000, reason: str = "") -> None: self._intentionally_closed = True @@ -4309,10 +4312,13 @@ def send(self, event: ResponsesClientEvent | ResponsesClientEventParam) -> None: def send_raw(self, data: bytes | str) -> None: if self._is_reconnecting: - raw = data if isinstance(data, str) else data.decode("utf-8") - self._send_queue.enqueue(raw) + self._send_queue.enqueue(data) return - self._connection.send(data) + try: + self._connection.send(data) + except Exception: + self._send_queue.enqueue(data) + raise def close(self, *, code: int = 1000, reason: str = "") -> None: self._intentionally_closed = True diff --git a/tests/lib/responses/test_responses.py b/tests/lib/responses/test_responses.py index 8e5f16df95..b32f6f7a8b 100644 --- a/tests/lib/responses/test_responses.py +++ b/tests/lib/responses/test_responses.py @@ -7,7 +7,15 @@ from inline_snapshot import snapshot from openai import OpenAI, AsyncOpenAI +from openai._types import omit from openai._utils import assert_signatures_in_sync +from openai.types.responses import ( + ResponseCreatedEvent, + ResponseOutputItemAddedEvent, + ResponseContentPartAddedEvent, + ResponseReasoningTextDeltaEvent, +) +from openai.lib.streaming.responses._responses import ResponseStreamState from ...conftest import base_url from ..snapshots import make_snapshot_request @@ -61,3 +69,83 @@ def test_parse_method_definition_in_sync(sync: bool, client: OpenAI, async_clien checking_client.responses.parse, exclude_params={"tools"}, ) + + +_RESPONSE_CREATED_PAYLOAD = { + "type": "response.created", + "sequence_number": 0, + "response": { + "id": "resp_1", + "object": "response", + "created_at": 0, + "status": "in_progress", + "error": None, + "incomplete_details": None, + "instructions": None, + "max_output_tokens": None, + "model": "o3", + "output": [], + "parallel_tool_calls": True, + "temperature": 1.0, + "tool_choice": "auto", + "tools": [], + "top_p": 1.0, + "metadata": {}, + }, +} + + +def test_stream_state_accumulates_reasoning_text_delta() -> None: + state: ResponseStreamState[object] = ResponseStreamState(input_tools=omit, text_format=omit) + + state.handle_event(ResponseCreatedEvent.model_validate(_RESPONSE_CREATED_PAYLOAD)) + state.handle_event( + ResponseOutputItemAddedEvent.model_validate( + { + "type": "response.output_item.added", + "sequence_number": 1, + "output_index": 0, + "item": { + "id": "rs_1", + "type": "reasoning", + "status": "in_progress", + "summary": [], + "content": None, + }, + } + ) + ) + state.handle_event( + ResponseContentPartAddedEvent.model_validate( + { + "type": "response.content_part.added", + "sequence_number": 2, + "output_index": 0, + "content_index": 0, + "item_id": "rs_1", + "part": {"type": "reasoning_text", "text": ""}, + } + ) + ) + for index, delta in enumerate(["Let me ", "think ", "carefully."]): + state.handle_event( + ResponseReasoningTextDeltaEvent.model_validate( + { + "type": "response.reasoning_text.delta", + "sequence_number": 3 + index, + "output_index": 0, + "content_index": 0, + "item_id": "rs_1", + "delta": delta, + } + ) + ) + + current_snapshot = state._ResponseStreamState__current_snapshot # type: ignore[attr-defined] + assert current_snapshot is not None + reasoning = current_snapshot.output[0] + assert reasoning.type == "reasoning" + assert reasoning.content is not None + assert len(reasoning.content) == 1 + assert reasoning.content[0].type == "reasoning_text" + assert reasoning.content[0].text == "Let me think carefully." # type: ignore[union-attr] diff --git a/tests/test_realtime_reconnect.py b/tests/test_realtime_reconnect.py new file mode 100644 index 0000000000..27b14744bc --- /dev/null +++ b/tests/test_realtime_reconnect.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from openai.resources.realtime.realtime import AsyncRealtimeConnection + + +def _connection_closed_error(code: int = 1011) -> Exception: + from websockets.frames import Close + from websockets.exceptions import ConnectionClosedError + + return ConnectionClosedError(Close(code=code, reason=""), None) + + +class _DeadConnection: + """A connection whose send() always fails, simulating a dropped socket.""" + + async def send(self, _data: bytes | str) -> None: + raise _connection_closed_error() + + async def close(self, *, code: int = 1000, reason: str = "") -> None: + pass + + +class _RecordingConnection: + """The connection returned after a successful reconnect.""" + + def __init__(self) -> None: + self.sent: list[bytes | str] = [] + + async def send(self, data: bytes | str) -> None: + self.sent.append(data) + + +def _make_connection(new_conn: _RecordingConnection) -> AsyncRealtimeConnection: + async def make_ws(_extra_query: Any, _extra_headers: Any) -> Any: + return new_conn + + return AsyncRealtimeConnection( + _DeadConnection(), # type: ignore[arg-type] + make_ws=make_ws, + on_reconnecting=lambda _event: None, + max_retries=1, + initial_delay=0.0, + max_delay=0.0, + ) + + +@pytest.mark.asyncio +async def test_reconnect_resends_binary_payload_unchanged() -> None: + """End-to-end: a binary send_raw() that fails mid-send is queued and + replayed byte-for-byte after reconnect, without UTF-8 corruption.""" + from websockets.exceptions import ConnectionClosedError + + new_conn = _RecordingConnection() + conn = _make_connection(new_conn) + + binary = b"\xff\xfe\x00audio" # not valid UTF-8 (would crash on decode) + + # send fails on the dead socket -> the original connection error must + # surface (NOT a UnicodeDecodeError from decoding the binary payload), + # and the payload must be queued for replay. + with pytest.raises(ConnectionClosedError): + await conn.send_raw(binary) + + # Drive the real reconnect path, which flushes the queue to the new socket. + reconnected = await conn._reconnect(_connection_closed_error()) + assert reconnected is True + + assert new_conn.sent == [binary] + assert isinstance(new_conn.sent[0], bytes) + + +@pytest.mark.asyncio +async def test_reconnect_resends_text_payload() -> None: + """A str send_raw() is replayed as text after reconnect.""" + from websockets.exceptions import ConnectionClosedError + + new_conn = _RecordingConnection() + conn = _make_connection(new_conn) + + with pytest.raises(ConnectionClosedError): + await conn.send_raw('{"type": "input_audio_buffer.append"}') + + assert await conn._reconnect(_connection_closed_error()) is True + assert new_conn.sent == ['{"type": "input_audio_buffer.append"}'] + assert isinstance(new_conn.sent[0], str) diff --git a/tests/test_send_queue.py b/tests/test_send_queue.py index 61db916bc4..6d0676337f 100644 --- a/tests/test_send_queue.py +++ b/tests/test_send_queue.py @@ -19,6 +19,39 @@ def test_enqueue_and_drain(self) -> None: assert items == ['{"type": "session.update"}', '{"type": "response.create"}'] assert len(q) == 0 + def test_enqueue_preserves_binary_frames(self) -> None: + """Binary payloads must be stored as-is, not decoded to text. + + Decoding to UTF-8 would corrupt binary frames and raise + ``UnicodeDecodeError`` for arbitrary bytes (e.g. audio chunks). + """ + q = SendQueue() + binary = b"\xff\xfe\x00audio" # not valid UTF-8 + q.enqueue(binary) + q.enqueue("text") + + items = q.drain() + assert items == [binary, "text"] + assert isinstance(items[0], bytes) + assert isinstance(items[1], str) + + def test_enqueue_counts_binary_byte_length(self) -> None: + q = SendQueue(max_bytes=4) + q.enqueue(b"\xff\xfe\xfd\xfc") # 4 bytes, fits exactly + with pytest.raises(WebSocketQueueFullError): + q.enqueue(b"\x00") # would exceed + assert len(q) == 1 + + def test_flush_sync_preserves_binary(self) -> None: + q = SendQueue() + binary = b"\xff\xfe" + q.enqueue(binary) + q.enqueue("text") + + sent: list[bytes | str] = [] + q.flush_sync(sent.append) + assert sent == [binary, "text"] + def test_enqueue_respects_byte_limit(self) -> None: q = SendQueue(max_bytes=10) q.enqueue("12345") # 5 bytes, fits