@@ -43,6 +43,12 @@ def _dump_model(model: BaseModel):
4343 return model .dict (exclude_none = True )
4444
4545
46+ def _parse_model (model_class , data ):
47+ if hasattr (model_class , "model_validate" ):
48+ return model_class .model_validate (data )
49+ return model_class .parse_obj (data )
50+
51+
4652def _normalize_min_turn_silence (params_dict : dict ) -> dict :
4753 """Collapse `min_end_of_turn_silence_when_confident` into `min_turn_silence` so only
4854 one wire key is ever sent. Emits deprecation warnings."""
@@ -65,6 +71,31 @@ def _normalize_min_turn_silence(params_dict: dict) -> dict:
6571 return params_dict
6672
6773
74+ def _normalize_voice_focus (params_dict : dict ) -> dict :
75+ """Collapse `noise_suppression_model` / `noise_suppression_threshold` into
76+ `voice_focus` / `voice_focus_threshold` so only the new wire keys are sent.
77+ Emits deprecation warnings."""
78+ for old_key , new_key in (
79+ ("noise_suppression_model" , "voice_focus" ),
80+ ("noise_suppression_threshold" , "voice_focus_threshold" ),
81+ ):
82+ old = params_dict .pop (old_key , None )
83+ if old is None :
84+ continue
85+ if new_key in params_dict :
86+ logger .warning (
87+ f"[Deprecation Warning] Both `{ old_key } ` and `{ new_key } ` are set. "
88+ f"Using `{ new_key } `; `{ old_key } ` is deprecated."
89+ )
90+ else :
91+ logger .warning (
92+ f"[Deprecation Warning] `{ old_key } ` is deprecated and will be removed "
93+ f"in a future release. Please use `{ new_key } ` instead."
94+ )
95+ params_dict [new_key ] = old
96+ return params_dict
97+
98+
6899def _dump_model_json (model : BaseModel ):
69100 if hasattr (model , "model_dump_json" ):
70101 return model .model_dump_json (exclude_none = True )
@@ -94,6 +125,17 @@ def __init__(self, options: StreamingClientOptions):
94125 self ._write_thread = threading .Thread (target = self ._write_message )
95126 self ._read_thread = threading .Thread (target = self ._read_message )
96127 self ._stop_event = threading .Event ()
128+ # Both flags are read and set only on the read thread (or on the main
129+ # thread before workers start, for handshake errors). Plain bools are
130+ # sufficient — no cross-thread synchronization is needed.
131+ self ._connection_closed_reported = False
132+ self ._server_error_reported = False
133+ # Deliberate single-slot shared-memory handoff: the write thread parks
134+ # a ConnectionClosed here and the read thread drains it. Synchronization
135+ # is provided by `_stop_event.set()` (write side) + `recv(timeout=1)`
136+ # (read side), which together give a happens-before within ~1s.
137+ self ._pending_close_error : Optional [Exception ] = None
138+ self ._websocket = None
97139
98140 def connect (self , params : StreamingParameters ) -> None :
99141 if params .speech_model == "u3-pro" :
@@ -102,7 +144,15 @@ def connect(self, params: StreamingParameters) -> None:
102144 "Please use `u3-rt-pro` instead."
103145 )
104146
105- params_dict = _normalize_min_turn_silence (_dump_model (params ))
147+ if params .customer_support_audio_capture :
148+ logger .warning (
149+ "`customer_support_audio_capture=True` will record session audio. "
150+ "Only enable this when explicitly coordinating with AssemblyAI support."
151+ )
152+
153+ params_dict = _normalize_voice_focus (
154+ _normalize_min_turn_silence (_dump_model (params ))
155+ )
106156
107157 # JSON-encode list and dict parameters for proper API compatibility (e.g., keyterms_prompt, llm_gateway)
108158 for key , value in params_dict .items ():
@@ -132,8 +182,22 @@ def connect(self, params: StreamingParameters) -> None:
132182 additional_headers = headers ,
133183 open_timeout = 15 ,
134184 )
135- except websockets .exceptions .ConnectionClosed as exc :
136- self ._handle_error (exc )
185+ except websockets .exceptions .InvalidStatus as exc :
186+ status_code = getattr (getattr (exc , "response" , None ), "status_code" , None )
187+ self ._report_connection_closed (
188+ StreamingError (
189+ message = f"WebSocket handshake rejected (HTTP { status_code } )" ,
190+ code = status_code ,
191+ )
192+ )
193+ return
194+ except (
195+ websockets .exceptions .InvalidHandshake ,
196+ websockets .exceptions .ConnectionClosed ,
197+ OSError ,
198+ TimeoutError ,
199+ ) as exc :
200+ self ._report_connection_closed (exc )
137201 return
138202
139203 self ._write_thread .start ()
@@ -145,23 +209,40 @@ def disconnect(self, terminate: bool = False) -> None:
145209 if terminate and not self ._stop_event .is_set ():
146210 self ._write_queue .put (TerminateSession ())
147211
148- try :
149- self ._read_thread .join ()
150- self ._write_thread .join ()
212+ self ._stop_event .set ()
213+
214+ current = threading .current_thread ()
215+ for thread in (self ._read_thread , self ._write_thread ):
216+ if thread is current or not thread .is_alive ():
217+ continue
218+ try :
219+ thread .join ()
220+ except RuntimeError as exc :
221+ logger .debug ("Thread join skipped: %s" , exc )
151222
152- if self ._websocket :
153- self ._websocket .close ()
154- except Exception :
155- pass
223+ self ._close_websocket ()
224+
225+ def _close_websocket (self ) -> None :
226+ if not self ._websocket :
227+ return
228+ try :
229+ self ._websocket .close ()
230+ except (OSError , websockets .exceptions .WebSocketException ) as exc :
231+ logger .debug ("Error closing websocket: %s" , exc )
156232
157233 def stream (
158234 self , data : Union [bytes , Generator [bytes , None , None ], Iterable [bytes ]]
159235 ) -> None :
236+ if self ._stop_event .is_set ():
237+ return
238+
160239 if isinstance (data , bytes ):
161240 self ._write_queue .put (data )
162241 return
163242
164243 for chunk in data :
244+ if self ._stop_event .is_set ():
245+ return
165246 self ._write_queue .put (chunk )
166247
167248 def set_params (self , params : StreamingSessionParameters ):
@@ -178,15 +259,24 @@ def on(self, event: StreamingEvents, handler: Callable) -> None:
178259 self ._handlers [event ].append (handler )
179260
180261 def _write_message (self ) -> None :
181- while not self . _stop_event . is_set () :
262+ while True :
182263 if not self ._websocket :
183264 raise ValueError ("Not connected to the WebSocket server" )
184265
185266 try :
186267 data = self ._write_queue .get (timeout = 1 )
187268 except queue .Empty :
269+ if self ._stop_event .is_set ():
270+ return
188271 continue
189272
273+ # TerminateSession bypasses the stop gate so disconnect(terminate=True)
274+ # can always send it, even when stop is set between put() and the
275+ # write loop's next iteration.
276+ is_terminate = isinstance (data , TerminateSession )
277+ if not is_terminate and self ._stop_event .is_set ():
278+ return
279+
190280 try :
191281 if isinstance (data , bytes ):
192282 self ._websocket .send (data )
@@ -195,20 +285,36 @@ def _write_message(self) -> None:
195285 else :
196286 raise ValueError (f"Attempted to send invalid message: { type (data )} " )
197287 except websockets .exceptions .ConnectionClosed as exc :
198- self ._handle_error (exc )
288+ # Defer reporting to the read thread so all on_error dispatch
289+ # happens on a single thread (no cross-thread dedup race).
290+ self ._pending_close_error = exc
291+ self ._stop_event .set ()
292+ return
293+
294+ if is_terminate :
199295 return
200296
201297 def _read_message (self ) -> None :
202- while not self . _stop_event . is_set () :
298+ while True :
203299 if not self ._websocket :
204300 raise ValueError ("Not connected to the WebSocket server" )
205301
302+ # Drain a write-thread close before honoring stop, so a stop set by
303+ # the write thread doesn't cause us to exit silently with an
304+ # unreported close.
305+ if self ._pending_close_error is not None :
306+ pending , self ._pending_close_error = self ._pending_close_error , None
307+ self ._report_connection_closed (pending )
308+ return
309+ if self ._stop_event .is_set ():
310+ return
311+
206312 try :
207313 message_data = self ._websocket .recv (timeout = 1 )
208314 except TimeoutError :
209315 continue
210316 except websockets .exceptions .ConnectionClosed as exc :
211- self ._handle_error (exc )
317+ self ._report_connection_closed (exc )
212318 return
213319
214320 try :
@@ -220,7 +326,7 @@ def _read_message(self) -> None:
220326 message = self ._parse_message (message_json )
221327
222328 if isinstance (message , ErrorEvent ):
223- self ._handle_error (message )
329+ self ._report_server_error (message )
224330 elif isinstance (message , WarningEvent ):
225331 self ._handle_warning (message )
226332 elif message :
@@ -244,23 +350,23 @@ def _parse_message(self, data: Dict[str, Any]) -> Optional[EventMessage]:
244350 event_type = self ._parse_event_type (message_type )
245351
246352 if event_type == StreamingEvents .Begin :
247- return BeginEvent . model_validate ( data )
353+ return _parse_model ( BeginEvent , data )
248354 elif event_type == StreamingEvents .Termination :
249- return TerminationEvent . model_validate ( data )
355+ return _parse_model ( TerminationEvent , data )
250356 elif event_type == StreamingEvents .Turn :
251- return TurnEvent . model_validate ( data )
357+ return _parse_model ( TurnEvent , data )
252358 elif event_type == StreamingEvents .SpeechStarted :
253- return SpeechStartedEvent . model_validate ( data )
359+ return _parse_model ( SpeechStartedEvent , data )
254360 elif event_type == StreamingEvents .LLMGatewayResponse :
255- return LLMGatewayResponseEvent . model_validate ( data )
361+ return _parse_model ( LLMGatewayResponseEvent , data )
256362 elif event_type == StreamingEvents .Error :
257- return ErrorEvent . model_validate ( data )
363+ return _parse_model ( ErrorEvent , data )
258364 elif event_type == StreamingEvents .Warning :
259- return WarningEvent . model_validate ( data )
365+ return _parse_model ( WarningEvent , data )
260366 else :
261367 return None
262368 elif "error" in data :
263- return ErrorEvent . model_validate ( data )
369+ return _parse_model ( ErrorEvent , data )
264370
265371 return None
266372
@@ -281,44 +387,85 @@ def _handle_warning(self, warning: WarningEvent):
281387 for handler in self ._handlers [StreamingEvents .Warning ]:
282388 handler (self , warning )
283389
284- def _handle_error (
390+ def _report_server_error (self , error : ErrorEvent ) -> None :
391+ self ._server_error_reported = True
392+ streaming_error = StreamingError (
393+ message = error .error ,
394+ code = error .error_code ,
395+ )
396+ logger .error ("Streaming error: %s (code=%s)" , error .error , error .error_code )
397+ self ._dispatch_error (streaming_error )
398+
399+ def _report_connection_closed (
285400 self ,
286401 error : Union [
402+ StreamingError ,
287403 ErrorEvent ,
288404 websockets .exceptions .ConnectionClosed ,
405+ OSError ,
289406 ],
290- ):
291- parsed_error = self ._parse_error (error )
407+ ) -> None :
408+ # Idempotent: defensive guard in case future callers (e.g. another
409+ # connect-time error path) reach this method twice.
410+ if self ._connection_closed_reported :
411+ return
412+ self ._connection_closed_reported = True
413+ self ._stop_event .set ()
292414
293- for handler in self ._handlers [StreamingEvents .Error ]:
294- handler (self , parsed_error )
415+ streaming_error = self ._build_connection_closed_error (error )
295416
296- self .disconnect ()
417+ # Clean close (code 1000) → no streaming_error, nothing to report.
418+ if streaming_error is None :
419+ self ._close_websocket ()
420+ return
297421
298- def _parse_error (
299- self ,
422+ if isinstance (error , websockets .exceptions .ConnectionClosed ):
423+ reason = error .reason or "no reason given"
424+ logger .error ("Connection closed: %s (code=%s)" , reason , error .code )
425+ else :
426+ logger .error (
427+ "Connection failed: %s (code=%s)" ,
428+ streaming_error ,
429+ streaming_error .code ,
430+ )
431+
432+ # If a server Error frame already fired on_error, the close is the
433+ # effect, not a new cause — log it (above) but skip the duplicate
434+ # user-visible error.
435+ if not self ._server_error_reported :
436+ self ._dispatch_error (streaming_error )
437+
438+ self ._close_websocket ()
439+
440+ def _dispatch_error (self , error : StreamingError ) -> None :
441+ for handler in self ._handlers [StreamingEvents .Error ]:
442+ try :
443+ handler (self , error )
444+ except Exception :
445+ logger .exception ("on_error handler raised" )
446+
447+ @staticmethod
448+ def _build_connection_closed_error (
300449 error : Union [
450+ StreamingError ,
301451 ErrorEvent ,
302452 websockets .exceptions .ConnectionClosed ,
453+ OSError ,
303454 ],
304- ) -> StreamingError :
455+ ) -> Optional [StreamingError ]:
456+ if isinstance (error , StreamingError ):
457+ return error
305458 if isinstance (error , ErrorEvent ):
306- return StreamingError (
307- message = error .error ,
308- code = error .error_code ,
309- )
310- elif isinstance (error , websockets .exceptions .ConnectionClosed ):
311- if error .code in StreamingErrorCodes :
312- error_message = StreamingErrorCodes [error .code ]
459+ return StreamingError (message = error .error , code = error .error_code )
460+ if isinstance (error , websockets .exceptions .ConnectionClosed ):
461+ if error .code == 1000 :
462+ return None
463+ if error .code is not None and error .code in StreamingErrorCodes :
464+ message = StreamingErrorCodes [error .code ]
313465 else :
314- error_message = error .reason
315-
316- if error .code != 1000 :
317- return StreamingError (message = error_message , code = error .code )
318-
319- return StreamingError (
320- message = f"Unknown error: { error } " ,
321- )
466+ message = error .reason or f"Connection closed (code={ error .code } )"
467+ return StreamingError (message = message , code = error .code )
468+ return StreamingError (message = f"Connection failed: { error } " )
322469
323470 def create_temporary_token (
324471 self ,
0 commit comments