From 3554e3c97aa76726a799c3a74e1d52f82328c72e Mon Sep 17 00:00:00 2001 From: rmittal-github <61574997+rmittal-github@users.noreply.github.com> Date: Fri, 8 Aug 2025 14:17:42 +0530 Subject: [PATCH 1/4] Add SSL mutual authentication support (#147) * Fix import for I/O device only in case of mic (#146) * conditional import for mic to suppress warnings * Remove unused default_device_index variable * Allow grpc options (#140) * allow passing grpc channel create options * add support for using client cert for MTLS * rename ssl-cert to ssl-root-cert * fix: renaming ssl_cert -> ssl_root_cert * fix auth init and docstring * fix missing typos --------- Co-authored-by: yhayarannvidia Co-authored-by: Viraj Karandikar <16838694+virajkarandikar@users.noreply.github.com> --- riva/client/argparse_utils.py | 21 +++++- riva/client/auth.py | 82 ++++++++++++++++++------ scripts/asr/realtime_asr_client.py | 30 +++++---- scripts/asr/riva_streaming_asr_client.py | 10 ++- scripts/asr/transcribe_file.py | 10 ++- scripts/asr/transcribe_file_offline.py | 10 ++- scripts/asr/transcribe_mic.py | 10 ++- scripts/nlp/punctuation_client.py | 20 +++++- scripts/nmt/nmt.py | 10 ++- scripts/nmt/nmt_speech_to_speech.py | 10 ++- scripts/nmt/nmt_speech_to_text.py | 10 ++- scripts/tts/talk.py | 13 +++- 12 files changed, 195 insertions(+), 41 deletions(-) diff --git a/riva/client/argparse_utils.py b/riva/client/argparse_utils.py index d662697f..a8cc1a72 100644 --- a/riva/client/argparse_utils.py +++ b/riva/client/argparse_utils.py @@ -3,6 +3,20 @@ import argparse +def validate_grpc_message_size(value): + """Validate that the GRPC message size is within acceptable limits.""" + min_size = 4 * 1024 * 1024 # 4MB + max_size = 1024 * 1024 * 1024 # 1GB + + try: + size = int(value) + if size < min_size: + raise argparse.ArgumentTypeError(f"GRPC message size must be at least {min_size} bytes (4MB)") + if size > max_size: + raise argparse.ArgumentTypeError(f"GRPC message size must be at most {max_size} bytes (1GB)") + return size + except ValueError: + raise argparse.ArgumentTypeError(f"'{value}' is not a valid integer") def add_asr_config_argparse_parameters( parser: argparse.ArgumentParser, max_alternatives: bool = False, profanity_filter: bool = False, word_time_offsets: bool = False @@ -102,13 +116,16 @@ def add_asr_config_argparse_parameters( def add_connection_argparse_parameters(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument("--server", default="localhost:50051", help="URI to GRPC server endpoint.") - parser.add_argument("--ssl-cert", help="Path to SSL client certificates file.") + parser.add_argument("--ssl-root-cert", help="Path to SSL root certificates file.") + parser.add_argument("--ssl-client-cert", help="Path to SSL client certificates file.") + parser.add_argument("--ssl-client-key", help="Path to SSL client key file.") parser.add_argument( "--use-ssl", action='store_true', help="Boolean to control if SSL/TLS encryption should be used." ) parser.add_argument("--metadata", action='append', nargs='+', help="Send HTTP Header(s) to server") + parser.add_argument("--options", action='append', nargs='+', help="Send GRPC options to server") parser.add_argument( - "--max-message-length", type=int, default=64 * 1024 * 1024, help="Maximum message length for GRPC server." + "--max-message-length", type=validate_grpc_message_size, default=64 * 1024 * 1024, help="Maximum message length for GRPC server." ) return parser diff --git a/riva/client/auth.py b/riva/client/auth.py index 046f409d..4ee70854 100644 --- a/riva/client/auth.py +++ b/riva/client/auth.py @@ -8,23 +8,34 @@ def create_channel( - ssl_cert: Optional[Union[str, os.PathLike]] = None, + ssl_root_cert: Optional[Union[str, os.PathLike]] = None, + ssl_client_cert: Optional[Union[str, os.PathLike]] = None, + ssl_client_key: Optional[Union[str, os.PathLike]] = None, use_ssl: bool = False, uri: str = "localhost:50051", metadata: Optional[List[Tuple[str, str]]] = None, - max_message_length: int = 64 * 1024 * 1024, + options: Optional[List[Tuple[str, str]]] = [], ) -> grpc.Channel: def metadata_callback(context, callback): callback(metadata, None) - options = [('grpc.max_receive_message_length', max_message_length), ('grpc.max_send_message_length', max_message_length)] - if ssl_cert is not None or use_ssl: + if ssl_root_cert is not None or ssl_client_cert is not None or ssl_client_key is not None or use_ssl: root_certificates = None - if ssl_cert is not None: - ssl_cert = Path(ssl_cert).expanduser() - with open(ssl_cert, 'rb') as f: + client_certificates = None + client_key = None + if ssl_root_cert is not None: + ssl_root_cert = Path(ssl_root_cert).expanduser() + with open(ssl_root_cert, 'rb') as f: root_certificates = f.read() - creds = grpc.ssl_channel_credentials(root_certificates) + if ssl_client_cert is not None: + ssl_client_cert = Path(ssl_client_cert).expanduser() + with open(ssl_client_cert, 'rb') as f: + client_certificates = f.read() + if ssl_client_key is not None: + ssl_client_key = Path(ssl_client_key).expanduser() + with open(ssl_client_key, 'rb') as f: + client_key = f.read() + creds = grpc.ssl_channel_credentials(root_certificates=root_certificates, private_key=client_key, certificate_chain=client_certificates) if metadata: auth_creds = grpc.metadata_call_credentials(metadata_callback) creds = grpc.composite_channel_credentials(creds, auth_creds) @@ -37,23 +48,58 @@ def metadata_callback(context, callback): class Auth: def __init__( self, - ssl_cert: Optional[Union[str, os.PathLike]] = None, + ssl_root_cert: Optional[Union[str, os.PathLike]] = None, use_ssl: bool = False, uri: str = "localhost:50051", metadata_args: List[List[str]] = None, - max_message_length: int = 64 * 1024 * 1024, + ssl_client_cert: Optional[Union[str, os.PathLike]] = None, + ssl_client_key: Optional[Union[str, os.PathLike]] = None, + options: Optional[List[Tuple[str, str]]] = [], ) -> None: """ - A class responsible for establishing connection with a server and providing security metadata. + Initialize the Auth class for establishing secure connections with a server. + + This class handles SSL/TLS configuration, authentication metadata, and gRPC channel creation + for secure communication with Riva services. Args: - ssl_cert (:obj:`Union[str, os.PathLike]`, `optional`): a path to SSL certificate file. If :param:`use_ssl` - is :obj:`False` and :param:`ssl_cert` is not :obj:`None`, then SSL is used. - use_ssl (:obj:`bool`, defaults to :obj:`False`): whether to use SSL. If :param:`ssl_cert` is :obj:`None`, - then SSL is still used but with default credentials. - uri (:obj:`str`, defaults to :obj:`"localhost:50051"`): a Riva URI. + ssl_root_cert (Optional[Union[str, os.PathLike]], optional): Path to the SSL root certificate file. + If provided and use_ssl is False, SSL will still be enabled. Defaults to None. + use_ssl (bool, optional): Whether to use SSL/TLS encryption. If True and ssl_root_cert is None, + SSL will be used with default credentials. Defaults to False. + uri (str, optional): The Riva server URI in format "host:port". Defaults to "localhost:50051". + metadata_args (List[List[str]], optional): List of metadata key-value pairs for authentication. + Each inner list should contain exactly 2 elements: [key, value]. Defaults to None. + ssl_client_cert (Optional[Union[str, os.PathLike]], optional): Path to the SSL client certificate file. + Used for mutual TLS authentication. Defaults to None. + ssl_client_key (Optional[Union[str, os.PathLike]], optional): Path to the SSL client private key file. + Used for mutual TLS authentication. Defaults to None. + options (Optional[List[Tuple[str, str]]], optional): Additional gRPC channel options. + Each tuple should contain (option_name, option_value). Defaults to []. + + Raises: + ValueError: If any metadata argument doesn't contain exactly 2 elements (key-value pair). + + Example: + >>> # Basic connection without SSL + >>> auth = Auth(uri="localhost:50051") + + >>> # SSL connection with custom certificate + >>> auth = Auth( + ... use_ssl=True, + ... ssl_root_cert="/path/to/cert.pem", + ... uri="secure-server:50051" + ... ) + + >>> # Connection with authentication metadata + >>> auth = Auth( + ... metadata_args=[["api-key", "your-api-key"], ["user-id", "12345"]], + ... uri="auth-server:50051" + ... ) """ - self.ssl_cert: Optional[Path] = None if ssl_cert is None else Path(ssl_cert).expanduser() + self.ssl_root_cert: Optional[Path] = None if ssl_root_cert is None else Path(ssl_root_cert).expanduser() + self.ssl_client_cert: Optional[Path] = None if ssl_client_cert is None else Path(ssl_client_cert).expanduser() + self.ssl_client_key: Optional[Path] = None if ssl_client_key is None else Path(ssl_client_key).expanduser() self.uri: str = uri self.use_ssl: bool = use_ssl self.metadata = [] @@ -65,7 +111,7 @@ def __init__( ) self.metadata.append(tuple(meta)) self.channel: grpc.Channel = create_channel( - self.ssl_cert, self.use_ssl, self.uri, self.metadata, max_message_length=max_message_length + self.ssl_root_cert, self.ssl_client_cert, self.ssl_client_key, self.use_ssl, self.uri, self.metadata, options=options ) def get_auth_metadata(self) -> List[Tuple[str, str]]: diff --git a/scripts/asr/realtime_asr_client.py b/scripts/asr/realtime_asr_client.py index 6c256018..ea1e6f8c 100644 --- a/scripts/asr/realtime_asr_client.py +++ b/scripts/asr/realtime_asr_client.py @@ -46,20 +46,12 @@ def parse_args() -> argparse.Namespace: help="Duration in seconds to record from microphone (only used with --mic)", default=None ) - - # Audio device configuration - try: - import riva.client.audio_io - default_device_info = riva.client.audio_io.get_default_input_device_info() - default_device_index = None if default_device_info is None else default_device_info['index'] - except ModuleNotFoundError: - default_device_index = None parser.add_argument( "--input-device", type=int, - default=default_device_index, - help="Input audio device index to use (only used with --mic)" + default=None, + help="Input audio device index to use (only used with --mic). If not specified, will use default device." ) parser.add_argument( "--list-devices", @@ -126,6 +118,15 @@ def parse_args() -> argparse.Namespace: return args +def get_default_device_index(): + """Get default audio device index only when needed.""" + try: + import riva.client.audio_io + default_device_info = riva.client.audio_io.get_default_input_device_info() + return None if default_device_info is None else default_device_info['index'] + except ModuleNotFoundError: + return None + def setup_signal_handler(): """Set up signal handler for graceful shutdown.""" def signal_handler(sig, frame): @@ -145,11 +146,18 @@ async def create_audio_iterator(args): Audio iterator for streaming audio data """ if args.mic: + # Only import when using microphone from riva.client.audio_io import MicrophoneStream + + # Get default device index if not specified + device_index = args.input_device + if device_index is None: + device_index = get_default_device_index() + audio_chunk_iterator = MicrophoneStream( args.sample_rate_hz, args.file_streaming_chunk, - device=args.input_device + device=device_index ) args.num_channels = 1 else: diff --git a/scripts/asr/riva_streaming_asr_client.py b/scripts/asr/riva_streaming_asr_client.py index 43873942..f600af66 100644 --- a/scripts/asr/riva_streaming_asr_client.py +++ b/scripts/asr/riva_streaming_asr_client.py @@ -50,7 +50,15 @@ def streaming_transcription_worker( ) -> None: output_file = Path(output_file).expanduser() try: - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) asr_service = riva.client.ASRService(auth) config = riva.client.StreamingRecognitionConfig( config=riva.client.RecognitionConfig( diff --git a/scripts/asr/transcribe_file.py b/scripts/asr/transcribe_file.py index f9cbd7c9..1849a675 100644 --- a/scripts/asr/transcribe_file.py +++ b/scripts/asr/transcribe_file.py @@ -66,7 +66,15 @@ def main() -> None: if args.list_devices: riva.client.audio_io.list_output_devices() return - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) asr_service = riva.client.ASRService(auth) if args.list_models: diff --git a/scripts/asr/transcribe_file_offline.py b/scripts/asr/transcribe_file_offline.py index 8d1fd482..afedb957 100644 --- a/scripts/asr/transcribe_file_offline.py +++ b/scripts/asr/transcribe_file_offline.py @@ -33,7 +33,15 @@ def parse_args() -> argparse.Namespace: def main() -> None: args = parse_args() - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + options = [('grpc.max_receive_message_length', args.max_message_length), ('grpc.max_send_message_length', args.max_message_length)] + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=options) asr_service = riva.client.ASRService(auth) if args.list_models: diff --git a/scripts/asr/transcribe_mic.py b/scripts/asr/transcribe_mic.py index 77d38fb3..3fd2b5a2 100644 --- a/scripts/asr/transcribe_mic.py +++ b/scripts/asr/transcribe_mic.py @@ -45,7 +45,15 @@ def main() -> None: if args.list_devices: riva.client.audio_io.list_input_devices() return - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) asr_service = riva.client.ASRService(auth) config = riva.client.StreamingRecognitionConfig( config=riva.client.RecognitionConfig( diff --git a/scripts/nlp/punctuation_client.py b/scripts/nlp/punctuation_client.py index ab843a77..437cce21 100644 --- a/scripts/nlp/punctuation_client.py +++ b/scripts/nlp/punctuation_client.py @@ -39,7 +39,15 @@ def parse_args() -> argparse.Namespace: def run_punct_capit(args: argparse.Namespace) -> None: - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nlp_service = riva.client.NLPService(auth) if args.interactive: while True: @@ -134,7 +142,15 @@ def run_tests(args: argparse.Namespace) -> int: ], } - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nlp_service = riva.client.NLPService(auth) fail_count = 0 diff --git a/scripts/nmt/nmt.py b/scripts/nmt/nmt.py index a978cc9b..f3d13dbe 100644 --- a/scripts/nmt/nmt.py +++ b/scripts/nmt/nmt.py @@ -123,7 +123,15 @@ def request(inputs,args): args = parse_args() - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nmt_client = riva.client.NeuralMachineTranslationClient(auth) if args.list_models: diff --git a/scripts/nmt/nmt_speech_to_speech.py b/scripts/nmt/nmt_speech_to_speech.py index 49426bfe..7348525b 100644 --- a/scripts/nmt/nmt_speech_to_speech.py +++ b/scripts/nmt/nmt_speech_to_speech.py @@ -32,7 +32,15 @@ def main(): if not os.path.exists(args.audio_file): raise FileNotFoundError(f"Input audio file not found: {args.audio_file}") - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nmt_client = riva.client.NeuralMachineTranslationClient(auth) if args.list_models: diff --git a/scripts/nmt/nmt_speech_to_text.py b/scripts/nmt/nmt_speech_to_text.py index 7d75a6c7..bc8c0f2a 100644 --- a/scripts/nmt/nmt_speech_to_text.py +++ b/scripts/nmt/nmt_speech_to_text.py @@ -43,7 +43,15 @@ def main(): if not os.path.exists(args.audio_file): raise FileNotFoundError(f"Input audio file not found: {args.audio_file}") - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nmt_client = riva.client.NeuralMachineTranslationClient(auth) if args.list_models: diff --git a/scripts/tts/talk.py b/scripts/tts/talk.py index ac3ed6d0..2df233b0 100644 --- a/scripts/tts/talk.py +++ b/scripts/tts/talk.py @@ -99,8 +99,19 @@ def main() -> None: riva.client.audio_io.list_output_devices() return + if args.options is None: + args.options = [] + args.options.append(('grpc.max_receive_message_length', args.max_message_length)) + args.options.append(('grpc.max_send_message_length', args.max_message_length)) + auth = riva.client.Auth( - args.ssl_cert, args.use_ssl, args.server, args.metadata, max_message_length=args.max_message_length + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options ) service = riva.client.SpeechSynthesisService(auth) nchannels = 1 From 72b2657a667f54f08b6bf5d7280a5730f7c62935 Mon Sep 17 00:00:00 2001 From: Viraj Karandikar <16838694+virajkarandikar@users.noreply.github.com> Date: Tue, 12 Aug 2025 12:44:39 +0530 Subject: [PATCH 2/4] Add SSL support in realtime client (#150) * add ssl support for realtime api * add option to open aio channel * add websockets to requirements --- requirements.txt | 1 + riva/client/auth.py | 22 ++++- riva/client/realtime.py | 149 +++++++++++++++++------------ scripts/asr/realtime_asr_client.py | 122 ++++++++++++----------- 4 files changed, 166 insertions(+), 128 deletions(-) diff --git a/requirements.txt b/requirements.txt index 0db03ce0..2e4b2d22 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ setuptools==78.1.1 grpcio==1.67.1 grpcio-tools==1.67.1 +websockets==15.0.1 diff --git a/riva/client/auth.py b/riva/client/auth.py index 4ee70854..8a4688d4 100644 --- a/riva/client/auth.py +++ b/riva/client/auth.py @@ -15,6 +15,7 @@ def create_channel( uri: str = "localhost:50051", metadata: Optional[List[Tuple[str, str]]] = None, options: Optional[List[Tuple[str, str]]] = [], + use_aio: Optional[bool] = False, ) -> grpc.Channel: def metadata_callback(context, callback): callback(metadata, None) @@ -39,9 +40,15 @@ def metadata_callback(context, callback): if metadata: auth_creds = grpc.metadata_call_credentials(metadata_callback) creds = grpc.composite_channel_credentials(creds, auth_creds) - channel = grpc.secure_channel(uri, creds, options=options) + if use_aio: + channel = grpc.aio.secure_channel(uri, creds, options=options) + else: + channel = grpc.secure_channel(uri, creds, options=options) else: - channel = grpc.insecure_channel(uri, options=options) + if use_aio: + channel = grpc.aio.insecure_channel(uri, options=options) + else: + channel = grpc.insecure_channel(uri, options=options) return channel @@ -55,6 +62,7 @@ def __init__( ssl_client_cert: Optional[Union[str, os.PathLike]] = None, ssl_client_key: Optional[Union[str, os.PathLike]] = None, options: Optional[List[Tuple[str, str]]] = [], + use_aio: bool = False, ) -> None: """ Initialize the Auth class for establishing secure connections with a server. @@ -76,6 +84,7 @@ def __init__( Used for mutual TLS authentication. Defaults to None. options (Optional[List[Tuple[str, str]]], optional): Additional gRPC channel options. Each tuple should contain (option_name, option_value). Defaults to []. + use_aio (bool, optional): Whether to use asyncio for the channel. Defaults to False. Raises: ValueError: If any metadata argument doesn't contain exactly 2 elements (key-value pair). @@ -111,7 +120,14 @@ def __init__( ) self.metadata.append(tuple(meta)) self.channel: grpc.Channel = create_channel( - self.ssl_root_cert, self.ssl_client_cert, self.ssl_client_key, self.use_ssl, self.uri, self.metadata, options=options + self.ssl_root_cert, + self.ssl_client_cert, + self.ssl_client_key, + self.use_ssl, + self.uri, + self.metadata, + options=options, + use_aio=use_aio, ) def get_auth_metadata(self) -> List[Tuple[str, str]]: diff --git a/riva/client/realtime.py b/riva/client/realtime.py index 5de310bf..7ff2ad66 100644 --- a/riva/client/realtime.py +++ b/riva/client/realtime.py @@ -10,10 +10,11 @@ import requests import websockets +import ssl from websockets.exceptions import WebSocketException logging.basicConfig( - level=logging.INFO, + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) @@ -21,40 +22,41 @@ class RealtimeClient: """Client for real-time transcription via WebSocket connection.""" - + def __init__(self, args: argparse.Namespace): """Initialize the RealtimeClient. - + Args: args: Command line arguments containing configuration """ self.args = args self.websocket = None self.session_config = None - + # Input audio playback self.input_audio_queue = queue.Queue() self.input_playback_thread = None self.is_input_playing = False self.input_buffer_size = 1024 # Buffer size for input audio playback - + # Transcription results self.delta_transcripts: List[str] = [] self.interim_final_transcripts: List[str] = [] self.final_transcript: str = "" self.is_config_updated = False + async def connect(self): """Establish connection to the ASR server.""" try: # Initialize session via HTTP POST session_data = await self._initialize_http_session() self.session_config = session_data - + # Connect to WebSocket await self._connect_websocket() await self._initialize_session() - + except requests.exceptions.RequestException as e: logger.error(f"HTTP request failed: {e}") raise @@ -68,27 +70,48 @@ async def connect(self): async def _initialize_http_session(self) -> Dict[str, Any]: """Initialize session via HTTP POST request.""" headers = {"Content-Type": "application/json"} + uri = f"http://{self.args.server}/v1/realtime/transcription_sessions" + if self.args.use_ssl: + uri = f"https://{self.args.server}/v1/realtime/transcription_sessions" + logger.info(f"Initializing session via HTTP POST request to: {uri}") response = requests.post( - f"http://{self.args.server}/v1/realtime/transcription_sessions", + uri, headers=headers, - json={} + json={}, + cert=(self.args.ssl_client_cert, self.args.ssl_client_key) if self.args.ssl_client_cert and self.args.ssl_client_key else None, + verify=self.args.ssl_root_cert if self.args.ssl_root_cert else True ) - + if response.status_code != 200: raise Exception( f"Failed to initialize session. Status: {response.status_code}, " f"Error: {response.text}" ) - + session_data = response.json() logger.info(f"Session initialized: {session_data}") return session_data async def _connect_websocket(self): """Connect to WebSocket endpoint.""" + ssl_context = None ws_url = f"ws://{self.args.server}{self.args.endpoint}?{self.args.query_params}" + if self.args.use_ssl: + ws_url = f"wss://{self.args.server}{self.args.endpoint}?{self.args.query_params}" + + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + # Load a custom CA certificate bundle + if self.args.ssl_root_cert: + ssl_context.load_verify_locations(self.args.ssl_root_cert) + # Load a client certificate and key + if self.args.ssl_client_cert and self.args.ssl_client_key: + ssl_context.load_cert_chain(self.args.ssl_client_cert, self.args.ssl_client_key) + # Disable hostname verification + ssl_context.check_hostname = False + # ssl_context.verify_mode = ssl.CERT_REQUIRED + logger.info(f"Connecting to WebSocket: {ws_url}") - self.websocket = await websockets.connect(ws_url) + self.websocket = await websockets.connect(ws_url, ssl=ssl_context) async def _initialize_session(self): """Initialize the WebSocket session.""" @@ -97,7 +120,7 @@ async def _initialize_session(self): response = await self.websocket.recv() response_data = json.loads(response) logger.info("Session created: %s", response_data) - + event_type = response_data.get("type", "") if event_type == "conversation.created": logger.info("Conversation created successfully") @@ -111,9 +134,9 @@ async def _initialize_session(self): if not self.is_config_updated: logger.error("Failed to update session") raise Exception("Failed to update session") - + logger.info("Session initialization complete") - + except json.JSONDecodeError as e: logger.error(f"Failed to parse JSON response: {e}") raise @@ -126,7 +149,7 @@ async def _initialize_session(self): def _safe_update_config(self, config: Dict[str, Any], key: str, value: Any, section: str = None): """Safely update a configuration value, creating the section if it doesn't exist. - + Args: config: The configuration dictionary to update key: The key to update @@ -144,62 +167,62 @@ def _safe_update_config(self, config: Dict[str, Any], key: str, value: Any, sect async def _update_session(self) -> bool: """Update session configuration by selectively overriding server defaults. - + Returns: True if session was updated successfully, False otherwise """ logger.info("Updating session configuration...") logger.info(f"Server default config: {self.session_config}") - + # Create a copy of the session config from server defaults session_config = self.session_config.copy() - + # Track what we're overriding overrides = [] - + # Update input audio transcription - only override if args are provided if hasattr(self.args, 'language_code') and self.args.language_code: self._safe_update_config(session_config, "language", self.args.language_code, "input_audio_transcription") overrides.append("language") - + if hasattr(self.args, 'model_name') and self.args.model_name: self._safe_update_config(session_config, "model", self.args.model_name, "input_audio_transcription") overrides.append("model") - + if hasattr(self.args, 'prompt') and self.args.prompt: self._safe_update_config(session_config, "prompt", self.args.prompt, "input_audio_transcription") overrides.append("prompt") - + # Update input audio parameters - only override if args are provided if hasattr(self.args, 'sample_rate_hz') and self.args.sample_rate_hz: self._safe_update_config(session_config, "sample_rate_hz", self.args.sample_rate_hz, "input_audio_params") overrides.append("sample_rate_hz") - + if hasattr(self.args, 'num_channels') and self.args.num_channels: self._safe_update_config(session_config, "num_channels", self.args.num_channels, "input_audio_params") overrides.append("num_channels") - + # Update recognition settings - only override if args are provided if hasattr(self.args, 'max_alternatives') and self.args.max_alternatives is not None: self._safe_update_config(session_config, "max_alternatives", self.args.max_alternatives, "recognition_config") overrides.append("max_alternatives") - + if hasattr(self.args, 'automatic_punctuation') and self.args.automatic_punctuation is not None: self._safe_update_config(session_config, "enable_automatic_punctuation", self.args.automatic_punctuation, "recognition_config") overrides.append("automatic_punctuation") - + if hasattr(self.args, 'word_time_offsets') and self.args.word_time_offsets is not None: self._safe_update_config(session_config, "enable_word_time_offsets", self.args.word_time_offsets, "recognition_config") overrides.append("word_time_offsets") - + if hasattr(self.args, 'profanity_filter') and self.args.profanity_filter is not None: self._safe_update_config(session_config, "enable_profanity_filter", self.args.profanity_filter, "recognition_config") overrides.append("profanity_filter") - + if hasattr(self.args, 'no_verbatim_transcripts') and self.args.no_verbatim_transcripts is not None: self._safe_update_config(session_config, "enable_verbatim_transcripts", self.args.no_verbatim_transcripts, "recognition_config") overrides.append("verbatim_transcripts") - + # Configure speaker diarization if enabled if hasattr(self.args, 'speaker_diarization') and self.args.speaker_diarization: session_config["speaker_diarization"] = { @@ -207,10 +230,10 @@ async def _update_session(self) -> bool: "max_speaker_count": getattr(self.args, 'diarization_max_speakers', 2) } overrides.append("speaker_diarization") - + # Configure word boosting if enabled - if (hasattr(self.args, 'boosted_lm_words') and - self.args.boosted_lm_words and + if (hasattr(self.args, 'boosted_lm_words') and + self.args.boosted_lm_words and len(self.args.boosted_lm_words)): word_boosting_list = [ { @@ -223,44 +246,44 @@ async def _update_session(self) -> bool: "word_boosting_list": word_boosting_list } overrides.append("word_boosting") - + # Configure endpointing if any parameters are set if self._has_endpointing_config(): session_config["endpointing_config"] = self._build_endpointing_config() overrides.append("endpointing_config") - + # Configure custom configuration if provided if hasattr(self.args, 'custom_configuration') and self.args.custom_configuration: custom_config = self._parse_custom_configuration(self.args.custom_configuration) if custom_config: session_config["custom_configuration"] = custom_config overrides.append("custom_configuration") - + if overrides: logger.info(f"Overriding server defaults for: {', '.join(overrides)}") else: logger.info("Using server default configuration (no overrides)") - + logger.info(f"Final session config: {session_config}") - + # Send update request update_session_request = { "type": "transcription_session.update", "session": session_config } await self._send_message(update_session_request) - + # Handle response return await self._handle_session_update_response() def _has_endpointing_config(self) -> bool: """Check if any endpointing configuration parameters are set.""" return ( - self.args.start_history > 0 or - self.args.start_threshold > 0 or - self.args.stop_history > 0 or - self.args.stop_history_eou > 0 or - self.args.stop_threshold > 0 or + self.args.start_history > 0 or + self.args.start_threshold > 0 or + self.args.stop_history > 0 or + self.args.stop_history_eou > 0 or + self.args.stop_threshold > 0 or self.args.stop_threshold_eou > 0 ) @@ -277,41 +300,41 @@ def _build_endpointing_config(self) -> Dict[str, Any]: def _parse_custom_configuration(self, custom_configuration: str) -> Dict[str, str]: """Parse custom configuration string into a dictionary. - + Args: custom_configuration: String in format "key1:value1,key2:value2" - + Returns: Dictionary of custom configuration key-value pairs - + Raises: ValueError: If the custom configuration format is invalid """ custom_config = {} custom_configuration = custom_configuration.strip().replace(" ", "") - + if not custom_configuration: return custom_config - + for pair in custom_configuration.split(","): key_value = pair.split(":") if len(key_value) == 2: custom_config[key_value[0]] = key_value[1] else: raise ValueError(f"Invalid key:value pair {key_value}") - + return custom_config async def _handle_session_update_response(self) -> bool: """Handle session update response. - + Returns: True if session was updated successfully, False otherwise """ response = await self.websocket.recv() response_data = json.loads(response) logger.info("Session updated: %s", response_data) - + event_type = response_data.get("type", "") if event_type == "transcription_session.updated": logger.info("Transcription session updated successfully") @@ -330,23 +353,23 @@ async def _send_message(self, message: Dict[str, Any]): async def send_audio_chunks(self, audio_chunks): """Send audio chunks to the server for transcription.""" logger.info("Sending audio chunks...") - + for chunk in audio_chunks: chunk_base64 = base64.b64encode(chunk).decode("utf-8") - + # Send chunk to the server await self._send_message({ "type": "input_audio_buffer.append", "audio": chunk_base64, }) - + # Commit the chunk await self._send_message({ "type": "input_audio_buffer.commit", }) - + logger.info("All chunks sent") - + # Tell the server that we are done sending chunks await self._send_message({ "type": "input_audio_buffer.done", @@ -356,7 +379,7 @@ async def receive_responses(self): """Receive and process transcription responses from the server.""" logger.info("Listening for responses...") received_final_response = False - + while not received_final_response: try: response = await asyncio.wait_for(self.websocket.recv(), 10.0) @@ -367,13 +390,13 @@ async def receive_responses(self): delta = event.get("delta", "") logger.info("Transcript: %s", delta) self.delta_transcripts.append(delta) - + elif event_type == "conversation.item.input_audio_transcription.completed": is_last_result = event.get("is_last_result", False) interim_final_transcript = event.get("transcript", "") self.interim_final_transcripts.append(interim_final_transcript) self.final_transcript = interim_final_transcript - + if is_last_result: logger.info("Final Transcript: %s", self.final_transcript) logger.info("Transcription completed") @@ -381,9 +404,9 @@ async def receive_responses(self): break else: logger.info("Interim Transcript: %s", interim_final_transcript) - + logger.info("Words Info: %s", event.get("words_info", "")) - + elif "error" in event_type.lower(): logger.error( f"Error: {event.get('error', {}).get('message', 'Unknown error')}" @@ -399,7 +422,7 @@ async def receive_responses(self): def save_responses(self, output_text_file: str): """Save collected transcription text to a file. - + Args: output_text_file: Path to the output text file """ diff --git a/scripts/asr/realtime_asr_client.py b/scripts/asr/realtime_asr_client.py index ea1e6f8c..eb13c7bc 100644 --- a/scripts/asr/realtime_asr_client.py +++ b/scripts/asr/realtime_asr_client.py @@ -11,6 +11,7 @@ from riva.client.argparse_utils import ( add_asr_config_argparse_parameters, add_realtime_config_argparse_parameters, + add_connection_argparse_parameters, ) @@ -27,94 +28,91 @@ def parse_args() -> argparse.Namespace: ), formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - + # Input configuration parser.add_argument( - "--input-file", - required=False, + "--input-file", + required=False, help="Input audio file (required when not using --mic)" ) parser.add_argument( - "--mic", - action="store_true", - help="Use microphone input instead of file input", + "--mic", + action="store_true", + help="Use microphone input instead of file input", default=False ) parser.add_argument( - "--duration", - type=int, - help="Duration in seconds to record from microphone (only used with --mic)", + "--duration", + type=int, + help="Duration in seconds to record from microphone (only used with --mic)", default=None ) - + parser.add_argument( - "--input-device", - type=int, - default=None, + "--input-device", + type=int, + default=None, help="Input audio device index to use (only used with --mic). If not specified, will use default device." ) parser.add_argument( - "--list-devices", - action="store_true", + "--list-devices", + action="store_true", help="List available input audio device indices" ) - + # Audio parameters parser.add_argument( - "--sample-rate-hz", - type=int, - help="Number of frames per second in audio streamed from a microphone.", + "--sample-rate-hz", + type=int, + help="Number of frames per second in audio streamed from a microphone.", default=16000 ) parser.add_argument( - "--num-channels", - type=int, - help="Number of audio channels.", + "--num-channels", + type=int, + help="Number of audio channels.", default=1 ) parser.add_argument( - "--file-streaming-chunk", - type=int, - default=1600, + "--file-streaming-chunk", + type=int, + default=1600, help="Maximum number of frames in one chunk sent to server." ) - + # Output configuration parser.add_argument( - "--output-text", - type=str, + "--output-text", + type=str, help="Output text file" ) parser.add_argument( - "--prompt", - default="", + "--prompt", + default="", help="Prompt to be used for transcription." ) - - parser.add_argument( - "--server", - default="localhost:9090", - help="URI to WebSocket server endpoint." - ) - + + # Add connection parameters + parser = add_connection_argparse_parameters(parser) + # Add ASR and realtime configuration parameters parser = add_asr_config_argparse_parameters( - parser, - max_alternatives=True, - profanity_filter=True, + parser, + max_alternatives=True, + profanity_filter=True, word_time_offsets=True ) parser = add_realtime_config_argparse_parameters(parser) - + args = parser.parse_args() - + # Validate input configuration if not args.mic and not args.input_file: parser.error("Either --input-file or --mic must be specified") - + if args.mic and args.input_file: parser.error("Cannot specify both --input-file and --mic") - + return args @@ -138,25 +136,25 @@ def signal_handler(sig, frame): async def create_audio_iterator(args): """Create appropriate audio iterator based on input type. - + Args: args: Command line arguments containing input configuration - + Returns: Audio iterator for streaming audio data """ if args.mic: # Only import when using microphone from riva.client.audio_io import MicrophoneStream - + # Get default device index if not specified device_index = args.input_device if device_index is None: device_index = get_default_device_index() - + audio_chunk_iterator = MicrophoneStream( - args.sample_rate_hz, - args.file_streaming_chunk, + args.sample_rate_hz, + args.file_streaming_chunk, device=device_index ) args.num_channels = 1 @@ -166,29 +164,29 @@ async def create_audio_iterator(args): args.sample_rate_hz = wav_parameters['framerate'] args.num_channels = wav_parameters['nchannels'] audio_chunk_iterator = AudioChunkFileIterator( - args.input_file, - args.file_streaming_chunk, + args.input_file, + args.file_streaming_chunk, delay_callback=None ) - + return audio_chunk_iterator async def run_transcription(args): """Run the transcription process. - + Args: args: Command line arguments containing all configuration """ client = RealtimeClient(args=args) - + try: # Create audio iterator audio_chunk_iterator = await create_audio_iterator(args) - + # Connect and start transcription await client.connect() - + # Run send and receive tasks concurrently send_task = asyncio.create_task( client.send_audio_chunks(audio_chunk_iterator) @@ -196,13 +194,13 @@ async def run_transcription(args): receive_task = asyncio.create_task( client.receive_responses() ) - + await asyncio.gather(send_task, receive_task) - + # Save results if output file specified if args.output_text: client.save_responses(args.output_text) - + except Exception as e: print(f"Error: {e}") raise @@ -213,7 +211,7 @@ async def run_transcription(args): async def main() -> None: """Main entry point for the realtime ASR client.""" args = parse_args() - + # Handle list devices option if args.list_devices: try: @@ -222,7 +220,7 @@ async def main() -> None: except ModuleNotFoundError: print("PyAudio not available. Please install PyAudio to list audio devices.") return - + setup_signal_handler() try: From 61cf9541fadd7e513e683b8304b582a47b72e470 Mon Sep 17 00:00:00 2001 From: yhayarannvidia Date: Fri, 22 Aug 2025 19:44:28 +0530 Subject: [PATCH 3/4] FIx microphone case for realtime ASR client (#151) * Realtime ASR micropphone fix * Set default server to localhost for realtime ASR client * refactor argument parsing in realtime ASR client to use mutually exclusive input options and update default server port to 9000 * refactor logging in RealtimeClient to use debug level for detailed internal state and error messages * Enhance RealtimeClient to support microphone input with PCM16 encoding and improve audio chunk handling with async iteration. Update logging for word information formatting and handle timeouts during audio processing. --- riva/client/realtime.py | 149 +++++++++++++++++++---------- scripts/asr/realtime_asr_client.py | 132 ++++++++++++++++++++----- 2 files changed, 206 insertions(+), 75 deletions(-) diff --git a/riva/client/realtime.py b/riva/client/realtime.py index 7ff2ad66..1bd382ab 100644 --- a/riva/client/realtime.py +++ b/riva/client/realtime.py @@ -38,10 +38,6 @@ def __init__(self, args: argparse.Namespace): self.input_playback_thread = None self.is_input_playing = False self.input_buffer_size = 1024 # Buffer size for input audio playback - - # Transcription results - self.delta_transcripts: List[str] = [] - self.interim_final_transcripts: List[str] = [] self.final_transcript: str = "" self.is_config_updated = False @@ -58,13 +54,13 @@ async def connect(self): await self._initialize_session() except requests.exceptions.RequestException as e: - logger.error(f"HTTP request failed: {e}") + logger.error("HTTP request failed: %s", e) raise except WebSocketException as e: - logger.error(f"WebSocket connection failed: {e}") + logger.error("WebSocket connection failed: %s", e) raise except Exception as e: - logger.error(f"Unexpected error during connection: {e}") + logger.error("Unexpected error during connection: %s", e) raise async def _initialize_http_session(self) -> Dict[str, Any]: @@ -73,7 +69,7 @@ async def _initialize_http_session(self) -> Dict[str, Any]: uri = f"http://{self.args.server}/v1/realtime/transcription_sessions" if self.args.use_ssl: uri = f"https://{self.args.server}/v1/realtime/transcription_sessions" - logger.info(f"Initializing session via HTTP POST request to: {uri}") + logger.debug("Initializing session via HTTP POST request to: %s", uri) response = requests.post( uri, headers=headers, @@ -89,7 +85,7 @@ async def _initialize_http_session(self) -> Dict[str, Any]: ) session_data = response.json() - logger.info(f"Session initialized: {session_data}") + logger.debug("Session initialized: %s", session_data) return session_data async def _connect_websocket(self): @@ -110,7 +106,7 @@ async def _connect_websocket(self): ssl_context.check_hostname = False # ssl_context.verify_mode = ssl.CERT_REQUIRED - logger.info(f"Connecting to WebSocket: {ws_url}") + logger.debug("Connecting to WebSocket: %s", ws_url) self.websocket = await websockets.connect(ws_url, ssl=ssl_context) async def _initialize_session(self): @@ -119,14 +115,14 @@ async def _initialize_session(self): # Handle first response: "conversation.created" response = await self.websocket.recv() response_data = json.loads(response) - logger.info("Session created: %s", response_data) + logger.debug("Session created: %s", response_data) event_type = response_data.get("type", "") if event_type == "conversation.created": - logger.info("Conversation created successfully") + logger.debug("Conversation created successfully") logger.debug("Response structure: %s", list(response_data.keys())) else: - logger.warning(f"Unexpected first response type: {event_type}") + logger.warning("Unexpected first response type: %s", event_type) logger.debug("Full response: %s", response_data) # Update session configuration @@ -135,16 +131,16 @@ async def _initialize_session(self): logger.error("Failed to update session") raise Exception("Failed to update session") - logger.info("Session initialization complete") + logger.debug("Session initialization complete") except json.JSONDecodeError as e: - logger.error(f"Failed to parse JSON response: {e}") + logger.error("Failed to parse JSON response: %s", e) raise except KeyError as e: - logger.error(f"Missing expected key in response: {e}") + logger.error("Missing expected key in response: %s", e) raise except Exception as e: - logger.error(f"Unexpected error during session initialization: {e}") + logger.error("Unexpected error during session initialization: %s", e) raise def _safe_update_config(self, config: Dict[str, Any], key: str, value: Any, section: str = None): @@ -160,10 +156,10 @@ def _safe_update_config(self, config: Dict[str, Any], key: str, value: Any, sect if section not in config: config[section] = {} config[section][key] = value - logger.debug(f"Updated {section}.{key} = {value}") + logger.debug("Updated %s.%s = %s", section, key, value) else: config[key] = value - logger.debug(f"Updated {key} = {value}") + logger.debug("Updated %s = %s", key, value) async def _update_session(self) -> bool: """Update session configuration by selectively overriding server defaults. @@ -171,14 +167,22 @@ async def _update_session(self) -> bool: Returns: True if session was updated successfully, False otherwise """ - logger.info("Updating session configuration...") - logger.info(f"Server default config: {self.session_config}") + logger.debug("Updating session configuration...") + logger.debug("Server default config: %s", self.session_config) # Create a copy of the session config from server defaults session_config = self.session_config.copy() # Track what we're overriding overrides = [] + + # Check if the input is microphone, then set the encoding to pcm16 + if hasattr(self.args, 'mic') and self.args.mic: + self._safe_update_config(session_config, "input_audio_format", "pcm16") + overrides.append("input_audio_format") + else: + self._safe_update_config(session_config, "input_audio_format", "none") + overrides.append("input_audio_format") # Update input audio transcription - only override if args are provided if hasattr(self.args, 'language_code') and self.args.language_code: @@ -260,11 +264,11 @@ async def _update_session(self) -> bool: overrides.append("custom_configuration") if overrides: - logger.info(f"Overriding server defaults for: {', '.join(overrides)}") + logger.debug("Overriding server defaults for: %s", ', '.join(overrides)) else: - logger.info("Using server default configuration (no overrides)") + logger.debug("Using server default configuration (no overrides)") - logger.info(f"Final session config: {session_config}") + logger.debug("Final session config: %s", session_config) # Send update request update_session_request = { @@ -333,16 +337,16 @@ async def _handle_session_update_response(self) -> bool: """ response = await self.websocket.recv() response_data = json.loads(response) - logger.info("Session updated: %s", response_data) + logger.info("Current Session Config: %s", response_data) event_type = response_data.get("type", "") if event_type == "transcription_session.updated": - logger.info("Transcription session updated successfully") + logger.debug("Transcription session updated successfully") logger.debug("Response structure: %s", list(response_data.keys())) self.session_config = response_data["session"] return True else: - logger.warning(f"Unexpected response type: {event_type}") + logger.warning("Unexpected response type: %s", event_type) logger.debug("Full response: %s", response_data) return False @@ -352,23 +356,49 @@ async def _send_message(self, message: Dict[str, Any]): async def send_audio_chunks(self, audio_chunks): """Send audio chunks to the server for transcription.""" - logger.info("Sending audio chunks...") - - for chunk in audio_chunks: - chunk_base64 = base64.b64encode(chunk).decode("utf-8") - - # Send chunk to the server - await self._send_message({ - "type": "input_audio_buffer.append", - "audio": chunk_base64, - }) - - # Commit the chunk - await self._send_message({ - "type": "input_audio_buffer.commit", - }) - - logger.info("All chunks sent") + logger.debug("Sending audio chunks...") + + # Check if the audio_chunks supports async iteration + if hasattr(audio_chunks, '__aiter__'): + # Use async for for async iterators - this allows proper task switching + async for chunk in audio_chunks: + try: + chunk_base64 = base64.b64encode(chunk).decode("utf-8") + + # Send chunk to the server + await self._send_message({ + "type": "input_audio_buffer.append", + "audio": chunk_base64, + }) + + # Commit the chunk + await self._send_message({ + "type": "input_audio_buffer.commit", + }) + except TimeoutError: + # Handle timeout from AsyncAudioIterator - no audio available, continue + logger.debug("No audio chunk available within timeout, continuing...") + continue + except Exception as e: + logger.error(f"Error processing audio chunk: {e}") + continue + else: + # Fallback for regular iterators + for chunk in audio_chunks: + chunk_base64 = base64.b64encode(chunk).decode("utf-8") + + # Send chunk to the server + await self._send_message({ + "type": "input_audio_buffer.append", + "audio": chunk_base64, + }) + + # Commit the chunk + await self._send_message({ + "type": "input_audio_buffer.commit", + }) + + logger.debug("All chunks sent") # Tell the server that we are done sending chunks await self._send_message({ @@ -377,7 +407,7 @@ async def send_audio_chunks(self, audio_chunks): async def receive_responses(self): """Receive and process transcription responses from the server.""" - logger.info("Listening for responses...") + logger.debug("Listening for responses...") received_final_response = False while not received_final_response: @@ -389,12 +419,10 @@ async def receive_responses(self): if event_type == "conversation.item.input_audio_transcription.delta": delta = event.get("delta", "") logger.info("Transcript: %s", delta) - self.delta_transcripts.append(delta) elif event_type == "conversation.item.input_audio_transcription.completed": is_last_result = event.get("is_last_result", False) interim_final_transcript = event.get("transcript", "") - self.interim_final_transcripts.append(interim_final_transcript) self.final_transcript = interim_final_transcript if is_last_result: @@ -405,7 +433,28 @@ async def receive_responses(self): else: logger.info("Interim Transcript: %s", interim_final_transcript) - logger.info("Words Info: %s", event.get("words_info", "")) + # Format Words Info similar to print_streaming function + words_info = event.get("words_info", {}) + if words_info and "words" in words_info: + print("Words Info:") + + # Create header format similar to print_streaming + header_format = '{: <40s}{: <16s}{: <16s}{: <16s}{: <16s}' + header_values = ['Word', 'Start (ms)', 'End (ms)', 'Confidence', 'Speaker'] + print(header_format.format(*header_values)) + + # Print each word with formatted information + for word_data in words_info["words"]: + word = word_data.get("word", "") + start_time = word_data.get("start_time", 0) + end_time = word_data.get("end_time", 0) + confidence = word_data.get("confidence", 0.0) + speaker_tag = word_data.get("speaker_tag", 0) + + # Format the word info line similar to print_streaming + word_format = '{: <40s}{: <16.0f}{: <16.0f}{: <16.4f}{: <16d}' + word_values = [word, start_time, end_time, confidence, speaker_tag] + print(word_format.format(*word_values)) elif "error" in event_type.lower(): logger.error( @@ -417,7 +466,7 @@ async def receive_responses(self): except asyncio.TimeoutError: continue except Exception as e: - logger.error(f"Error: {e}") + logger.error("Error: %s", e) break def save_responses(self, output_text_file: str): @@ -431,7 +480,7 @@ def save_responses(self, output_text_file: str): with open(output_text_file, "w") as f: f.write(self.final_transcript) except Exception as e: - logger.error(f"Error saving text: {e}") + logger.error("Error saving text: %s", e) async def disconnect(self): """Close the WebSocket connection.""" diff --git a/scripts/asr/realtime_asr_client.py b/scripts/asr/realtime_asr_client.py index eb13c7bc..045595da 100644 --- a/scripts/asr/realtime_asr_client.py +++ b/scripts/asr/realtime_asr_client.py @@ -30,17 +30,22 @@ def parse_args() -> argparse.Namespace: ) # Input configuration - parser.add_argument( + input_group = parser.add_mutually_exclusive_group(required=True) + input_group.add_argument( "--input-file", - required=False, - help="Input audio file (required when not using --mic)" + help="Input audio file" ) - parser.add_argument( + input_group.add_argument( "--mic", action="store_true", - help="Use microphone input instead of file input", - default=False + help="Use microphone input instead of file input" + ) + input_group.add_argument( + "--list-devices", + action="store_true", + help="List available input audio device indices" ) + parser.add_argument( "--duration", type=int, @@ -54,11 +59,6 @@ def parse_args() -> argparse.Namespace: default=None, help="Input audio device index to use (only used with --mic). If not specified, will use default device." ) - parser.add_argument( - "--list-devices", - action="store_true", - help="List available input audio device indices" - ) # Audio parameters parser.add_argument( @@ -95,6 +95,9 @@ def parse_args() -> argparse.Namespace: # Add connection parameters parser = add_connection_argparse_parameters(parser) + # Override default server for realtime ASR (WebSocket endpoint, not gRPC) + parser.set_defaults(server="localhost:9000") + # Add ASR and realtime configuration parameters parser = add_asr_config_argparse_parameters( parser, @@ -106,13 +109,6 @@ def parse_args() -> argparse.Namespace: args = parser.parse_args() - # Validate input configuration - if not args.mic and not args.input_file: - parser.error("Either --input-file or --mic must be specified") - - if args.mic and args.input_file: - parser.error("Cannot specify both --input-file and --mic") - return args @@ -143,8 +139,7 @@ async def create_audio_iterator(args): Returns: Audio iterator for streaming audio data """ - if args.mic: - # Only import when using microphone + if args.mic: from riva.client.audio_io import MicrophoneStream # Get default device index if not specified @@ -152,11 +147,55 @@ async def create_audio_iterator(args): if device_index is None: device_index = get_default_device_index() - audio_chunk_iterator = MicrophoneStream( - args.sample_rate_hz, - args.file_streaming_chunk, + mic_stream = MicrophoneStream( + args.sample_rate_hz, + args.file_streaming_chunk, device=device_index ) + + # Initialize the stream (this starts the microphone) + audio_chunk_iterator = mic_stream.__enter__() + # Store the stream object for cleanup later + args._mic_stream = mic_stream + print("Recording indefinitely (press Ctrl+C to stop gracefully)...") + + class AsyncAudioIterator: + """Async wrapper for blocking audio iterators to prevent event loop starvation.""" + def __init__(self, audio_iterator): + self.audio_iterator = audio_iterator + self._stop_requested = False + self.chunk_count = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self._stop_requested: + raise StopAsyncIteration + + try: + # Add timeout to prevent hanging when no audio is available + chunk = await asyncio.wait_for( + asyncio.to_thread(lambda: next(self.audio_iterator)), + timeout=1.0 # 1 second timeout + ) + self.chunk_count += 1 + return chunk + except asyncio.TimeoutError: + # Return empty chunk or raise custom exception + raise TimeoutError("No audio chunk available within timeout") + except StopIteration: + print(f"Audio iterator exhausted after {self.chunk_count} chunks") + raise StopAsyncIteration + except Exception as e: + print(f"Error getting audio chunk #{self.chunk_count + 1}: {e}") + raise + + def stop(self): + self._stop_requested = True + + # Use async iterator to prevent event loop starvation + audio_chunk_iterator = AsyncAudioIterator(audio_chunk_iterator) args.num_channels = 1 else: wav_parameters = get_wav_file_parameters(args.input_file) @@ -179,6 +218,8 @@ async def run_transcription(args): args: Command line arguments containing all configuration """ client = RealtimeClient(args=args) + send_task = None + receive_task = None try: # Create audio iterator @@ -200,11 +241,52 @@ async def run_transcription(args): # Save results if output file specified if args.output_text: client.save_responses(args.output_text) - + + except KeyboardInterrupt: + if hasattr(args, '_interruptible_iterator'): + args._interruptible_iterator.stop() + print("Audio input stopped") + + # Cancel the send task and wait for it to finish + if send_task and not send_task.done(): + print("Cancelling send task...") + send_task.cancel() + try: + await send_task + except asyncio.CancelledError: + pass + print("Send task cancelled") + + # Wait a bit for the receive task to process any remaining audio + if receive_task and not receive_task.done(): + print("Processing remaining audio...") + try: + await asyncio.wait_for(receive_task, timeout=5.0) + print("Receive task completed") + except asyncio.TimeoutError: + print("Receive task timeout, cancelling...") + receive_task.cancel() + try: + await receive_task + except asyncio.CancelledError: + pass + print("Receive task cancelled") + + print("Transcription stopped gracefully.") + except Exception as e: - print(f"Error: {e}") + print(f"Error during realtime transcription: {e}") raise + finally: + # Clean up microphone stream if it was created + if hasattr(args, '_mic_stream') and args._mic_stream is not None: + try: + args._mic_stream.close() + print("Microphone stream closed") + except Exception as e: + print(f"Warning: Error closing microphone stream: {e}") + await client.disconnect() From fe3c138bce1eb30307cba3888a83fd9d70857e3e Mon Sep 17 00:00:00 2001 From: Yash Hayaran Date: Mon, 25 Aug 2025 12:31:21 +0530 Subject: [PATCH 4/4] chore: adding http error handling --- riva/client/realtime.py | 135 +++++++++++++++++++++++++---- scripts/asr/realtime_asr_client.py | 2 +- 2 files changed, 118 insertions(+), 19 deletions(-) diff --git a/riva/client/realtime.py b/riva/client/realtime.py index 1bd382ab..01310f0e 100644 --- a/riva/client/realtime.py +++ b/riva/client/realtime.py @@ -54,13 +54,11 @@ async def connect(self): await self._initialize_session() except requests.exceptions.RequestException as e: - logger.error("HTTP request failed: %s", e) raise except WebSocketException as e: logger.error("WebSocket connection failed: %s", e) raise except Exception as e: - logger.error("Unexpected error during connection: %s", e) raise async def _initialize_http_session(self) -> Dict[str, Any]: @@ -70,23 +68,83 @@ async def _initialize_http_session(self) -> Dict[str, Any]: if self.args.use_ssl: uri = f"https://{self.args.server}/v1/realtime/transcription_sessions" logger.debug("Initializing session via HTTP POST request to: %s", uri) - response = requests.post( - uri, - headers=headers, - json={}, - cert=(self.args.ssl_client_cert, self.args.ssl_client_key) if self.args.ssl_client_cert and self.args.ssl_client_key else None, - verify=self.args.ssl_root_cert if self.args.ssl_root_cert else True - ) - - if response.status_code != 200: - raise Exception( - f"Failed to initialize session. Status: {response.status_code}, " - f"Error: {response.text}" + + try: + response = requests.post( + uri, + headers=headers, + json={}, + cert=(self.args.ssl_client_cert, self.args.ssl_client_key) if self.args.ssl_client_cert and self.args.ssl_client_key else None, + verify=self.args.ssl_root_cert if self.args.ssl_root_cert else True, + timeout=30 # Add timeout to prevent hanging ) - session_data = response.json() - logger.debug("Session initialized: %s", session_data) - return session_data + if response.status_code != 200: + raise Exception( + f"Failed to initialize session. Status: {response.status_code}, " + f"Error: {response.text}" + ) + + session_data = response.json() + logger.debug("Session initialized: %s", session_data) + return session_data + + except requests.exceptions.ConnectionError as e: + # Handle connection errors more gracefully + if "Connection refused" in str(e): + error_msg = f"Cannot connect to server at {self.args.server}. The server may be down or not running." + elif "Name or service not known" in str(e): + error_msg = f"Cannot resolve server hostname '{self.args.server}'. Please check the server address." + elif "timeout" in str(e).lower(): + error_msg = f"Connection to {self.args.server} timed out. The server may be overloaded or unreachable." + elif "Connection aborted." in str(e): + error_msg = f"Connection aborted. Failed to establish a new connection to {self.args.server}. The server may be down or not running." + else: + error_msg = f"Connection failed to {self.args.server}: {str(e)}" + + logger.error("HTTP connection error: %s", error_msg) + raise Exception(error_msg) from e + + except requests.exceptions.SSLError as e: + error_msg = f"SSL/TLS connection failed to {self.args.server}: {str(e)}" + logger.error("SSL error: %s", error_msg) + raise Exception(error_msg) from e + + except requests.exceptions.Timeout as e: + error_msg = f"Request to {self.args.server} timed out after 30 seconds" + logger.error("Request timeout: %s", error_msg) + raise Exception(error_msg) from e + + except requests.exceptions.RequestException as e: + # Handle other HTTP-related errors + if hasattr(e, 'response') and e.response is not None: + status_code = e.response.status_code + try: + error_text = e.response.text + except: + error_text = "Unable to read error response" + + if status_code == 401: + error_msg = f"Authentication failed (401). Please check your SSL certificates and credentials." + elif status_code == 403: + error_msg = f"Access forbidden (403). You may not have permission to access this service." + elif status_code == 404: + error_msg = f"Service not found (404). The transcription service endpoint may not exist at {uri}" + elif status_code == 500: + error_msg = f"Server error (500). The transcription service encountered an internal error." + else: + error_msg = f"HTTP request failed with status {status_code}: {error_text}" + else: + error_msg = f"HTTP request failed: {str(e)}" + + logger.error("HTTP request error: %s", error_msg) + raise Exception(error_msg) from e + + except Exception as e: + # Handle any other unexpected errors + error_msg = f"Unexpected error during HTTP session initialization: {str(e)}" + logger.error("Unexpected error: %s", error_msg) + raise Exception(error_msg) from e async def _connect_websocket(self): """Connect to WebSocket endpoint.""" @@ -107,7 +165,48 @@ async def _connect_websocket(self): # ssl_context.verify_mode = ssl.CERT_REQUIRED logger.debug("Connecting to WebSocket: %s", ws_url) - self.websocket = await websockets.connect(ws_url, ssl=ssl_context) + + try: + self.websocket = await websockets.connect( + ws_url, + ssl=ssl_context, + ping_interval=20, # Send ping every 20 seconds + ping_timeout=10, # Wait 10 seconds for pong response + close_timeout=10 # Wait 10 seconds for close response + ) + except websockets.exceptions.InvalidURI as e: + error_msg = f"Invalid WebSocket URI: {ws_url}. Please check the server address and endpoint." + logger.error("WebSocket URI error: %s", error_msg) + raise Exception(error_msg) from e + except websockets.exceptions.InvalidHandshake as e: + error_msg = f"WebSocket handshake failed. The server may not support WebSocket connections or the endpoint is incorrect." + logger.error("WebSocket handshake error: %s", error_msg) + raise Exception(error_msg) from e + except websockets.exceptions.ConnectionClosed as e: + error_msg = f"WebSocket connection was closed unexpectedly: {str(e)}" + logger.error("WebSocket connection closed: %s", error_msg) + raise Exception(error_msg) from e + except websockets.exceptions.WebSocketException as e: + error_msg = f"WebSocket connection failed: {str(e)}" + logger.error("WebSocket error: %s", error_msg) + raise Exception(error_msg) from e + except ConnectionRefusedError as e: + error_msg = f"Cannot connect to WebSocket server at {self.args.server}. The server may be down or not running." + logger.error("WebSocket connection refused: %s", error_msg) + raise Exception(error_msg) from e + except OSError as e: + if "Name or service not known" in str(e): + error_msg = f"Cannot resolve server hostname '{self.args.server}'. Please check the server address." + elif "timeout" in str(e).lower(): + error_msg = f"WebSocket connection to {self.args.server} timed out." + else: + error_msg = f"WebSocket connection failed: {str(e)}" + logger.error("WebSocket OS error: %s", error_msg) + raise Exception(error_msg) from e + except Exception as e: + error_msg = f"Unexpected error during WebSocket connection: {str(e)}" + logger.error("Unexpected WebSocket error: %s", error_msg) + raise Exception(error_msg) from e async def _initialize_session(self): """Initialize the WebSocket session.""" diff --git a/scripts/asr/realtime_asr_client.py b/scripts/asr/realtime_asr_client.py index 045595da..3a6b9131 100644 --- a/scripts/asr/realtime_asr_client.py +++ b/scripts/asr/realtime_asr_client.py @@ -275,7 +275,7 @@ async def run_transcription(args): print("Transcription stopped gracefully.") except Exception as e: - print(f"Error during realtime transcription: {e}") + print(f"Error during realtime transcription. Aborting transcription.") raise finally: