Skip to content

Commit 67aa20e

Browse files
feat(client): allow enqueuing to websockets even when not connected
1 parent 0280d05 commit 67aa20e

File tree

7 files changed

+645
-6
lines changed

7 files changed

+645
-6
lines changed

src/openai/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@
2929
InternalServerError,
3030
PermissionDeniedError,
3131
LengthFinishReasonError,
32+
WebSocketQueueFullError,
3233
UnprocessableEntityError,
3334
APIResponseValidationError,
3435
InvalidWebhookSignatureError,
3536
ContentFilterFinishReasonError,
37+
WebSocketConnectionClosedError,
3638
)
3739
from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
3840
from ._utils._logs import setup_logging as _setup_logging
@@ -87,6 +89,8 @@
8789
"DefaultAioHttpClient",
8890
"ReconnectingEvent",
8991
"ReconnectingOverrides",
92+
"WebSocketQueueFullError",
93+
"WebSocketConnectionClosedError",
9094
]
9195

9296
if not _t.TYPE_CHECKING:

src/openai/_event_handler.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,16 @@ def has_handlers(self, event_type: str) -> bool:
7070
return bool(handlers)
7171
finally:
7272
self._release()
73+
74+
def merge_into(self, target: EventHandlerRegistry) -> None:
75+
"""Move all handlers from this registry into *target*, then clear self."""
76+
self._acquire()
77+
try:
78+
for event_type, handlers in self._handlers.items():
79+
for handler in handlers:
80+
once = id(handler) in self._once_ids
81+
target.add(event_type, handler, once=once)
82+
self._handlers.clear()
83+
self._once_ids.clear()
84+
finally:
85+
self._release()

src/openai/_exceptions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
"ContentFilterFinishReasonError",
2929
"InvalidWebhookSignatureError",
3030
"SubjectTokenProviderError",
31+
"WebSocketConnectionClosedError",
32+
"WebSocketQueueFullError",
3133
]
3234

3335

@@ -187,3 +189,19 @@ def __init__(self) -> None:
187189

188190
class InvalidWebhookSignatureError(ValueError):
189191
"""Raised when a webhook signature is invalid, meaning the computed signature does not match the expected signature."""
192+
193+
194+
class WebSocketConnectionClosedError(OpenAIError):
195+
"""Raised when a WebSocket connection closes with unsent messages."""
196+
197+
unsent_messages: list[str]
198+
199+
def __init__(self, message: str, *, unsent_messages: list[str]) -> None:
200+
super().__init__(message)
201+
self.unsent_messages = unsent_messages
202+
203+
204+
class WebSocketQueueFullError(OpenAIError):
205+
"""Raised when the outgoing WebSocket message queue exceeds its byte-size limit."""
206+
207+
pass

src/openai/_send_queue.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2+
3+
from __future__ import annotations
4+
5+
import typing
6+
import threading
7+
8+
from ._exceptions import WebSocketQueueFullError
9+
10+
11+
class SendQueue:
12+
"""Bounded byte-size queue for outgoing WebSocket messages.
13+
14+
Messages are stored as pre-serialized strings. The queue enforces a
15+
maximum byte budget so that unbounded buffering cannot occur during
16+
reconnection windows.
17+
"""
18+
19+
def __init__(self, max_bytes: int = 1_048_576) -> None:
20+
self._queue: list[tuple[str, int]] = [] # (data, byte_length)
21+
self._bytes: int = 0
22+
self._max_bytes = max_bytes
23+
self._lock = threading.Lock()
24+
25+
def enqueue(self, data: str) -> None:
26+
"""Append *data* to the queue.
27+
28+
Raises :class:`WebSocketQueueFullError` if the message would
29+
exceed the byte-size limit.
30+
"""
31+
byte_length = len(data.encode("utf-8"))
32+
with self._lock:
33+
if self._bytes + byte_length > self._max_bytes:
34+
raise WebSocketQueueFullError("send queue is full, message discarded")
35+
self._queue.append((data, byte_length))
36+
self._bytes += byte_length
37+
38+
def flush_sync(self, send: typing.Callable[[str], object]) -> None:
39+
"""Send every queued message via *send*.
40+
41+
If *send* raises, the failing message and all subsequent messages
42+
are re-queued and the error is re-raised.
43+
"""
44+
with self._lock:
45+
pending = list(self._queue)
46+
self._queue.clear()
47+
self._bytes = 0
48+
49+
for i, (data, _byte_length) in enumerate(pending):
50+
try:
51+
send(data)
52+
except Exception:
53+
with self._lock:
54+
remaining = pending[i:]
55+
self._queue = remaining + self._queue
56+
self._bytes = sum(bl for _, bl in self._queue)
57+
raise
58+
59+
async def flush_async(self, send: typing.Callable[[str], typing.Awaitable[object]]) -> None:
60+
"""Async variant of :meth:`flush_sync`."""
61+
with self._lock:
62+
pending = list(self._queue)
63+
self._queue.clear()
64+
self._bytes = 0
65+
66+
for i, (data, _byte_length) in enumerate(pending):
67+
try:
68+
await send(data)
69+
except Exception:
70+
with self._lock:
71+
remaining = pending[i:]
72+
self._queue = remaining + self._queue
73+
self._bytes = sum(bl for _, bl in self._queue)
74+
raise
75+
76+
def drain(self) -> list[str]:
77+
"""Remove and return all queued messages."""
78+
with self._lock:
79+
items = [data for data, _ in self._queue]
80+
self._queue.clear()
81+
self._bytes = 0
82+
return items
83+
84+
def __len__(self) -> int:
85+
with self._lock:
86+
return len(self._queue)
87+
88+
def __bool__(self) -> bool:
89+
with self._lock:
90+
return len(self._queue) > 0

0 commit comments

Comments
 (0)