Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 117 additions & 18 deletions riva/client/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,11 @@ async def connect(self):
await self._initialize_session()

except requests.exceptions.RequestException as e:
logger.error("HTTP request failed: %s", e)
raise
except WebSocketException as e:
logger.error("WebSocket connection failed: %s", e)
raise
except Exception as e:
logger.error("Unexpected error during connection: %s", e)
raise

async def _initialize_http_session(self) -> Dict[str, Any]:
Expand All @@ -70,23 +68,83 @@ async def _initialize_http_session(self) -> Dict[str, Any]:
if self.args.use_ssl:
uri = f"https://{self.args.server}/v1/realtime/transcription_sessions"
logger.debug("Initializing session via HTTP POST request to: %s", uri)
response = requests.post(
uri,
headers=headers,
json={},
cert=(self.args.ssl_client_cert, self.args.ssl_client_key) if self.args.ssl_client_cert and self.args.ssl_client_key else None,
verify=self.args.ssl_root_cert if self.args.ssl_root_cert else True
)

if response.status_code != 200:
raise Exception(
f"Failed to initialize session. Status: {response.status_code}, "
f"Error: {response.text}"

try:
response = requests.post(
uri,
headers=headers,
json={},
cert=(self.args.ssl_client_cert, self.args.ssl_client_key) if self.args.ssl_client_cert and self.args.ssl_client_key else None,
verify=self.args.ssl_root_cert if self.args.ssl_root_cert else True,
timeout=30 # Add timeout to prevent hanging
)

session_data = response.json()
logger.debug("Session initialized: %s", session_data)
return session_data
if response.status_code != 200:
raise Exception(
f"Failed to initialize session. Status: {response.status_code}, "
f"Error: {response.text}"
)

session_data = response.json()
logger.debug("Session initialized: %s", session_data)
return session_data

except requests.exceptions.ConnectionError as e:
# Handle connection errors more gracefully
if "Connection refused" in str(e):
error_msg = f"Cannot connect to server at {self.args.server}. The server may be down or not running."
elif "Name or service not known" in str(e):
error_msg = f"Cannot resolve server hostname '{self.args.server}'. Please check the server address."
elif "timeout" in str(e).lower():
error_msg = f"Connection to {self.args.server} timed out. The server may be overloaded or unreachable."
elif "Connection aborted." in str(e):
error_msg = f"Connection aborted. Failed to establish a new connection to {self.args.server}. The server may be down or not running."
else:
error_msg = f"Connection failed to {self.args.server}: {str(e)}"

logger.error("HTTP connection error: %s", error_msg)
raise Exception(error_msg) from e

except requests.exceptions.SSLError as e:
error_msg = f"SSL/TLS connection failed to {self.args.server}: {str(e)}"
logger.error("SSL error: %s", error_msg)
raise Exception(error_msg) from e

except requests.exceptions.Timeout as e:
error_msg = f"Request to {self.args.server} timed out after 30 seconds"
logger.error("Request timeout: %s", error_msg)
raise Exception(error_msg) from e

except requests.exceptions.RequestException as e:
# Handle other HTTP-related errors
if hasattr(e, 'response') and e.response is not None:
status_code = e.response.status_code
try:
error_text = e.response.text
except:
error_text = "Unable to read error response"

if status_code == 401:
error_msg = f"Authentication failed (401). Please check your SSL certificates and credentials."
elif status_code == 403:
error_msg = f"Access forbidden (403). You may not have permission to access this service."
elif status_code == 404:
error_msg = f"Service not found (404). The transcription service endpoint may not exist at {uri}"
elif status_code == 500:
error_msg = f"Server error (500). The transcription service encountered an internal error."
else:
error_msg = f"HTTP request failed with status {status_code}: {error_text}"
else:
error_msg = f"HTTP request failed: {str(e)}"

logger.error("HTTP request error: %s", error_msg)
raise Exception(error_msg) from e

except Exception as e:
# Handle any other unexpected errors
error_msg = f"Unexpected error during HTTP session initialization: {str(e)}"
logger.error("Unexpected error: %s", error_msg)
raise Exception(error_msg) from e

async def _connect_websocket(self):
"""Connect to WebSocket endpoint."""
Expand All @@ -107,7 +165,48 @@ async def _connect_websocket(self):
# ssl_context.verify_mode = ssl.CERT_REQUIRED

logger.debug("Connecting to WebSocket: %s", ws_url)
self.websocket = await websockets.connect(ws_url, ssl=ssl_context)

try:
self.websocket = await websockets.connect(
ws_url,
ssl=ssl_context,
ping_interval=20, # Send ping every 20 seconds
ping_timeout=10, # Wait 10 seconds for pong response
close_timeout=10 # Wait 10 seconds for close response
)
except websockets.exceptions.InvalidURI as e:
error_msg = f"Invalid WebSocket URI: {ws_url}. Please check the server address and endpoint."
logger.error("WebSocket URI error: %s", error_msg)
raise Exception(error_msg) from e
except websockets.exceptions.InvalidHandshake as e:
error_msg = f"WebSocket handshake failed. The server may not support WebSocket connections or the endpoint is incorrect."
logger.error("WebSocket handshake error: %s", error_msg)
raise Exception(error_msg) from e
except websockets.exceptions.ConnectionClosed as e:
error_msg = f"WebSocket connection was closed unexpectedly: {str(e)}"
logger.error("WebSocket connection closed: %s", error_msg)
raise Exception(error_msg) from e
except websockets.exceptions.WebSocketException as e:
error_msg = f"WebSocket connection failed: {str(e)}"
logger.error("WebSocket error: %s", error_msg)
raise Exception(error_msg) from e
except ConnectionRefusedError as e:
error_msg = f"Cannot connect to WebSocket server at {self.args.server}. The server may be down or not running."
logger.error("WebSocket connection refused: %s", error_msg)
raise Exception(error_msg) from e
except OSError as e:
if "Name or service not known" in str(e):
error_msg = f"Cannot resolve server hostname '{self.args.server}'. Please check the server address."
elif "timeout" in str(e).lower():
error_msg = f"WebSocket connection to {self.args.server} timed out."
else:
error_msg = f"WebSocket connection failed: {str(e)}"
logger.error("WebSocket OS error: %s", error_msg)
raise Exception(error_msg) from e
except Exception as e:
error_msg = f"Unexpected error during WebSocket connection: {str(e)}"
logger.error("Unexpected WebSocket error: %s", error_msg)
raise Exception(error_msg) from e

async def _initialize_session(self):
"""Initialize the WebSocket session."""
Expand Down
2 changes: 1 addition & 1 deletion scripts/asr/realtime_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ async def run_transcription(args):
print("Transcription stopped gracefully.")

except Exception as e:
print(f"Error during realtime transcription: {e}")
print(f"Error during realtime transcription. Aborting transcription.")
raise

finally:
Expand Down