Skip to content

Commit 98c9ea8

Browse files
bgotthold-aaiAssemblyAI
andauthored
chore: sync sdk code with DeepLearning repo (#196)
Co-authored-by: AssemblyAI <engineering.sdk@assemblyai.com>
1 parent efb4448 commit 98c9ea8

5 files changed

Lines changed: 938 additions & 58 deletions

File tree

assemblyai/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.63.1"
1+
__version__ = "0.64.0"

assemblyai/streaming/v3/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
StreamingError,
1212
StreamingEvents,
1313
StreamingParameters,
14+
StreamingPiiPolicy,
15+
StreamingPiiSubstitution,
1416
StreamingSessionParameters,
1517
TerminationEvent,
1618
TurnEvent,
@@ -31,6 +33,8 @@
3133
"StreamingError",
3234
"StreamingEvents",
3335
"StreamingParameters",
36+
"StreamingPiiPolicy",
37+
"StreamingPiiSubstitution",
3438
"StreamingSessionParameters",
3539
"TerminationEvent",
3640
"TurnEvent",

assemblyai/streaming/v3/client.py

Lines changed: 194 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ def _dump_model(model: BaseModel):
4343
return model.dict(exclude_none=True)
4444

4545

46+
def _parse_model(model_class, data):
47+
if hasattr(model_class, "model_validate"):
48+
return model_class.model_validate(data)
49+
return model_class.parse_obj(data)
50+
51+
4652
def _normalize_min_turn_silence(params_dict: dict) -> dict:
4753
"""Collapse `min_end_of_turn_silence_when_confident` into `min_turn_silence` so only
4854
one wire key is ever sent. Emits deprecation warnings."""
@@ -65,6 +71,31 @@ def _normalize_min_turn_silence(params_dict: dict) -> dict:
6571
return params_dict
6672

6773

74+
def _normalize_voice_focus(params_dict: dict) -> dict:
75+
"""Collapse `noise_suppression_model` / `noise_suppression_threshold` into
76+
`voice_focus` / `voice_focus_threshold` so only the new wire keys are sent.
77+
Emits deprecation warnings."""
78+
for old_key, new_key in (
79+
("noise_suppression_model", "voice_focus"),
80+
("noise_suppression_threshold", "voice_focus_threshold"),
81+
):
82+
old = params_dict.pop(old_key, None)
83+
if old is None:
84+
continue
85+
if new_key in params_dict:
86+
logger.warning(
87+
f"[Deprecation Warning] Both `{old_key}` and `{new_key}` are set. "
88+
f"Using `{new_key}`; `{old_key}` is deprecated."
89+
)
90+
else:
91+
logger.warning(
92+
f"[Deprecation Warning] `{old_key}` is deprecated and will be removed "
93+
f"in a future release. Please use `{new_key}` instead."
94+
)
95+
params_dict[new_key] = old
96+
return params_dict
97+
98+
6899
def _dump_model_json(model: BaseModel):
69100
if hasattr(model, "model_dump_json"):
70101
return model.model_dump_json(exclude_none=True)
@@ -94,6 +125,17 @@ def __init__(self, options: StreamingClientOptions):
94125
self._write_thread = threading.Thread(target=self._write_message)
95126
self._read_thread = threading.Thread(target=self._read_message)
96127
self._stop_event = threading.Event()
128+
# Both flags are read and set only on the read thread (or on the main
129+
# thread before workers start, for handshake errors). Plain bools are
130+
# sufficient — no cross-thread synchronization is needed.
131+
self._connection_closed_reported = False
132+
self._server_error_reported = False
133+
# Deliberate single-slot shared-memory handoff: the write thread parks
134+
# a ConnectionClosed here and the read thread drains it. Synchronization
135+
# is provided by `_stop_event.set()` (write side) + `recv(timeout=1)`
136+
# (read side), which together give a happens-before within ~1s.
137+
self._pending_close_error: Optional[Exception] = None
138+
self._websocket = None
97139

98140
def connect(self, params: StreamingParameters) -> None:
99141
if params.speech_model == "u3-pro":
@@ -102,7 +144,15 @@ def connect(self, params: StreamingParameters) -> None:
102144
"Please use `u3-rt-pro` instead."
103145
)
104146

105-
params_dict = _normalize_min_turn_silence(_dump_model(params))
147+
if params.customer_support_audio_capture:
148+
logger.warning(
149+
"`customer_support_audio_capture=True` will record session audio. "
150+
"Only enable this when explicitly coordinating with AssemblyAI support."
151+
)
152+
153+
params_dict = _normalize_voice_focus(
154+
_normalize_min_turn_silence(_dump_model(params))
155+
)
106156

107157
# JSON-encode list and dict parameters for proper API compatibility (e.g., keyterms_prompt, llm_gateway)
108158
for key, value in params_dict.items():
@@ -132,8 +182,22 @@ def connect(self, params: StreamingParameters) -> None:
132182
additional_headers=headers,
133183
open_timeout=15,
134184
)
135-
except websockets.exceptions.ConnectionClosed as exc:
136-
self._handle_error(exc)
185+
except websockets.exceptions.InvalidStatus as exc:
186+
status_code = getattr(getattr(exc, "response", None), "status_code", None)
187+
self._report_connection_closed(
188+
StreamingError(
189+
message=f"WebSocket handshake rejected (HTTP {status_code})",
190+
code=status_code,
191+
)
192+
)
193+
return
194+
except (
195+
websockets.exceptions.InvalidHandshake,
196+
websockets.exceptions.ConnectionClosed,
197+
OSError,
198+
TimeoutError,
199+
) as exc:
200+
self._report_connection_closed(exc)
137201
return
138202

139203
self._write_thread.start()
@@ -145,23 +209,40 @@ def disconnect(self, terminate: bool = False) -> None:
145209
if terminate and not self._stop_event.is_set():
146210
self._write_queue.put(TerminateSession())
147211

148-
try:
149-
self._read_thread.join()
150-
self._write_thread.join()
212+
self._stop_event.set()
213+
214+
current = threading.current_thread()
215+
for thread in (self._read_thread, self._write_thread):
216+
if thread is current or not thread.is_alive():
217+
continue
218+
try:
219+
thread.join()
220+
except RuntimeError as exc:
221+
logger.debug("Thread join skipped: %s", exc)
151222

152-
if self._websocket:
153-
self._websocket.close()
154-
except Exception:
155-
pass
223+
self._close_websocket()
224+
225+
def _close_websocket(self) -> None:
226+
if not self._websocket:
227+
return
228+
try:
229+
self._websocket.close()
230+
except (OSError, websockets.exceptions.WebSocketException) as exc:
231+
logger.debug("Error closing websocket: %s", exc)
156232

157233
def stream(
158234
self, data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]]
159235
) -> None:
236+
if self._stop_event.is_set():
237+
return
238+
160239
if isinstance(data, bytes):
161240
self._write_queue.put(data)
162241
return
163242

164243
for chunk in data:
244+
if self._stop_event.is_set():
245+
return
165246
self._write_queue.put(chunk)
166247

167248
def set_params(self, params: StreamingSessionParameters):
@@ -178,15 +259,24 @@ def on(self, event: StreamingEvents, handler: Callable) -> None:
178259
self._handlers[event].append(handler)
179260

180261
def _write_message(self) -> None:
181-
while not self._stop_event.is_set():
262+
while True:
182263
if not self._websocket:
183264
raise ValueError("Not connected to the WebSocket server")
184265

185266
try:
186267
data = self._write_queue.get(timeout=1)
187268
except queue.Empty:
269+
if self._stop_event.is_set():
270+
return
188271
continue
189272

273+
# TerminateSession bypasses the stop gate so disconnect(terminate=True)
274+
# can always send it, even when stop is set between put() and the
275+
# write loop's next iteration.
276+
is_terminate = isinstance(data, TerminateSession)
277+
if not is_terminate and self._stop_event.is_set():
278+
return
279+
190280
try:
191281
if isinstance(data, bytes):
192282
self._websocket.send(data)
@@ -195,20 +285,36 @@ def _write_message(self) -> None:
195285
else:
196286
raise ValueError(f"Attempted to send invalid message: {type(data)}")
197287
except websockets.exceptions.ConnectionClosed as exc:
198-
self._handle_error(exc)
288+
# Defer reporting to the read thread so all on_error dispatch
289+
# happens on a single thread (no cross-thread dedup race).
290+
self._pending_close_error = exc
291+
self._stop_event.set()
292+
return
293+
294+
if is_terminate:
199295
return
200296

201297
def _read_message(self) -> None:
202-
while not self._stop_event.is_set():
298+
while True:
203299
if not self._websocket:
204300
raise ValueError("Not connected to the WebSocket server")
205301

302+
# Drain a write-thread close before honoring stop, so a stop set by
303+
# the write thread doesn't cause us to exit silently with an
304+
# unreported close.
305+
if self._pending_close_error is not None:
306+
pending, self._pending_close_error = self._pending_close_error, None
307+
self._report_connection_closed(pending)
308+
return
309+
if self._stop_event.is_set():
310+
return
311+
206312
try:
207313
message_data = self._websocket.recv(timeout=1)
208314
except TimeoutError:
209315
continue
210316
except websockets.exceptions.ConnectionClosed as exc:
211-
self._handle_error(exc)
317+
self._report_connection_closed(exc)
212318
return
213319

214320
try:
@@ -220,7 +326,7 @@ def _read_message(self) -> None:
220326
message = self._parse_message(message_json)
221327

222328
if isinstance(message, ErrorEvent):
223-
self._handle_error(message)
329+
self._report_server_error(message)
224330
elif isinstance(message, WarningEvent):
225331
self._handle_warning(message)
226332
elif message:
@@ -244,23 +350,23 @@ def _parse_message(self, data: Dict[str, Any]) -> Optional[EventMessage]:
244350
event_type = self._parse_event_type(message_type)
245351

246352
if event_type == StreamingEvents.Begin:
247-
return BeginEvent.model_validate(data)
353+
return _parse_model(BeginEvent, data)
248354
elif event_type == StreamingEvents.Termination:
249-
return TerminationEvent.model_validate(data)
355+
return _parse_model(TerminationEvent, data)
250356
elif event_type == StreamingEvents.Turn:
251-
return TurnEvent.model_validate(data)
357+
return _parse_model(TurnEvent, data)
252358
elif event_type == StreamingEvents.SpeechStarted:
253-
return SpeechStartedEvent.model_validate(data)
359+
return _parse_model(SpeechStartedEvent, data)
254360
elif event_type == StreamingEvents.LLMGatewayResponse:
255-
return LLMGatewayResponseEvent.model_validate(data)
361+
return _parse_model(LLMGatewayResponseEvent, data)
256362
elif event_type == StreamingEvents.Error:
257-
return ErrorEvent.model_validate(data)
363+
return _parse_model(ErrorEvent, data)
258364
elif event_type == StreamingEvents.Warning:
259-
return WarningEvent.model_validate(data)
365+
return _parse_model(WarningEvent, data)
260366
else:
261367
return None
262368
elif "error" in data:
263-
return ErrorEvent.model_validate(data)
369+
return _parse_model(ErrorEvent, data)
264370

265371
return None
266372

@@ -281,44 +387,85 @@ def _handle_warning(self, warning: WarningEvent):
281387
for handler in self._handlers[StreamingEvents.Warning]:
282388
handler(self, warning)
283389

284-
def _handle_error(
390+
def _report_server_error(self, error: ErrorEvent) -> None:
391+
self._server_error_reported = True
392+
streaming_error = StreamingError(
393+
message=error.error,
394+
code=error.error_code,
395+
)
396+
logger.error("Streaming error: %s (code=%s)", error.error, error.error_code)
397+
self._dispatch_error(streaming_error)
398+
399+
def _report_connection_closed(
285400
self,
286401
error: Union[
402+
StreamingError,
287403
ErrorEvent,
288404
websockets.exceptions.ConnectionClosed,
405+
OSError,
289406
],
290-
):
291-
parsed_error = self._parse_error(error)
407+
) -> None:
408+
# Idempotent: defensive guard in case future callers (e.g. another
409+
# connect-time error path) reach this method twice.
410+
if self._connection_closed_reported:
411+
return
412+
self._connection_closed_reported = True
413+
self._stop_event.set()
292414

293-
for handler in self._handlers[StreamingEvents.Error]:
294-
handler(self, parsed_error)
415+
streaming_error = self._build_connection_closed_error(error)
295416

296-
self.disconnect()
417+
# Clean close (code 1000) → no streaming_error, nothing to report.
418+
if streaming_error is None:
419+
self._close_websocket()
420+
return
297421

298-
def _parse_error(
299-
self,
422+
if isinstance(error, websockets.exceptions.ConnectionClosed):
423+
reason = error.reason or "no reason given"
424+
logger.error("Connection closed: %s (code=%s)", reason, error.code)
425+
else:
426+
logger.error(
427+
"Connection failed: %s (code=%s)",
428+
streaming_error,
429+
streaming_error.code,
430+
)
431+
432+
# If a server Error frame already fired on_error, the close is the
433+
# effect, not a new cause — log it (above) but skip the duplicate
434+
# user-visible error.
435+
if not self._server_error_reported:
436+
self._dispatch_error(streaming_error)
437+
438+
self._close_websocket()
439+
440+
def _dispatch_error(self, error: StreamingError) -> None:
441+
for handler in self._handlers[StreamingEvents.Error]:
442+
try:
443+
handler(self, error)
444+
except Exception:
445+
logger.exception("on_error handler raised")
446+
447+
@staticmethod
448+
def _build_connection_closed_error(
300449
error: Union[
450+
StreamingError,
301451
ErrorEvent,
302452
websockets.exceptions.ConnectionClosed,
453+
OSError,
303454
],
304-
) -> StreamingError:
455+
) -> Optional[StreamingError]:
456+
if isinstance(error, StreamingError):
457+
return error
305458
if isinstance(error, ErrorEvent):
306-
return StreamingError(
307-
message=error.error,
308-
code=error.error_code,
309-
)
310-
elif isinstance(error, websockets.exceptions.ConnectionClosed):
311-
if error.code in StreamingErrorCodes:
312-
error_message = StreamingErrorCodes[error.code]
459+
return StreamingError(message=error.error, code=error.error_code)
460+
if isinstance(error, websockets.exceptions.ConnectionClosed):
461+
if error.code == 1000:
462+
return None
463+
if error.code is not None and error.code in StreamingErrorCodes:
464+
message = StreamingErrorCodes[error.code]
313465
else:
314-
error_message = error.reason
315-
316-
if error.code != 1000:
317-
return StreamingError(message=error_message, code=error.code)
318-
319-
return StreamingError(
320-
message=f"Unknown error: {error}",
321-
)
466+
message = error.reason or f"Connection closed (code={error.code})"
467+
return StreamingError(message=message, code=error.code)
468+
return StreamingError(message=f"Connection failed: {error}")
322469

323470
def create_temporary_token(
324471
self,

0 commit comments

Comments
 (0)