11#!/usr/bin/env python3
22"""
3- gRPC server of LocalAI for NVIDIA NEMO Toolkit ASR.
3+ GRPC server of LocalAI for NVIDIA NEMO Toolkit ASR.
44"""
55from concurrent import futures
66import time
1212import backend_pb2_grpc
1313import torch
1414import nemo .collections .asr as nemo_asr
15+ import numpy as np
16+
17+ try :
18+ import torchaudio
19+ TORCHAUDIO_AVAILABLE = True
20+ except ImportError :
21+ TORCHAUDIO_AVAILABLE = False
22+ print ("[WARNING] torchaudio not available, will use fallback audio loading" , file = sys .stderr )
1523
1624import grpc
1725
@@ -36,6 +44,50 @@ def is_int(s):
3644MAX_WORKERS = int (os .environ .get ('PYTHON_GRPC_MAX_WORKERS' , '1' ))
3745
3846
47+ def load_audio_np (audio_path , target_sample_rate = 16000 ):
48+ """Load audio file as numpy array using available methods."""
49+ if TORCHAUDIO_AVAILABLE :
50+ try :
51+ waveform , sample_rate = torchaudio .load (audio_path )
52+ # Convert to mono if stereo
53+ if waveform .shape [0 ] > 1 :
54+ waveform = waveform .mean (dim = 0 , keepdim = True )
55+ # Resample if needed
56+ if sample_rate != target_sample_rate :
57+ resampler = torchaudio .transforms .Resample (sample_rate , target_sample_rate )
58+ waveform = resampler (waveform )
59+ # Convert to numpy
60+ audio_np = waveform .squeeze ().numpy ()
61+ return audio_np , target_sample_rate
62+ except Exception as e :
63+ print (f"[WARNING] torchaudio loading failed: { e } , trying fallback" , file = sys .stderr )
64+
65+ # Fallback: try using scipy or soundfile
66+ try :
67+ import soundfile as sf
68+ audio_np , sample_rate = sf .read (audio_path )
69+ if audio_np .ndim > 1 :
70+ audio_np = audio_np .mean (axis = 1 )
71+ if sample_rate != target_sample_rate :
72+ from scipy .signal import resample
73+ num_samples = int (len (audio_np ) * target_sample_rate / sample_rate )
74+ audio_np = resample (audio_np , num_samples )
75+ return audio_np , target_sample_rate
76+ except ImportError :
77+ pass
78+
79+ try :
80+ from scipy .io import wavfile
81+ sample_rate , audio_np = wavfile .read (audio_path )
82+ if audio_np .ndim > 1 :
83+ audio_np = audio_np .mean (axis = 1 )
84+ return audio_np , sample_rate
85+ except ImportError :
86+ pass
87+
88+ raise RuntimeError ("No audio loading library available (torchaudio, soundfile, scipy)" )
89+
90+
3991class BackendServicer (backend_pb2_grpc .BackendServicer ):
4092 def Health (self , request , context ):
4193 return backend_pb2 .Reply (message = bytes ("OK" , 'utf-8' ))
@@ -89,14 +141,37 @@ def AudioTranscription(self, request, context):
89141 print (f"Error: Audio file not found: { audio_path } " , file = sys .stderr )
90142 return backend_pb2 .TranscriptResult (segments = [], text = "" )
91143
92- # NEMO's transcribe method accepts a list of audio paths and returns a list of transcripts
93- results = self .model .transcribe ([audio_path ])
94-
144+ # Load audio as numpy array to avoid lhotse dataloader issues
145+ audio_np , sample_rate = load_audio_np (audio_path , target_sample_rate = 16000 )
146+
147+ # Convert to torch tensor
148+ audio_tensor = torch .from_numpy (audio_np ).float ()
149+ audio_tensor = audio_tensor .unsqueeze (0 ) # Add batch dimension
150+
151+ # Use the model's transcribe method with the tensor directly
152+ # Some NEMO models accept audio tensors directly
153+ try :
154+ # Try passing the waveform tensor directly
155+ results = self .model .transcribe (audio_tensor , return_char_alignments = False )
156+ except TypeError :
157+ # Fallback: try with dict format
158+ results = self .model .transcribe (
159+ [{"audio_file" : audio_path }],
160+ return_char_alignments = False
161+ )
162+
95163 if not results or len (results ) == 0 :
164+ print ("[WARNING] No transcription results returned" , file = sys .stderr )
96165 return backend_pb2 .TranscriptResult (segments = [], text = "" )
97166
98167 # Get the transcript text from the first result
99- text = results [0 ]
168+ if isinstance (results , list ) and len (results ) > 0 :
169+ text = results [0 ]
170+ elif isinstance (results , dict ) and "text" in results :
171+ text = results ["text" ]
172+ else :
173+ text = str (results ) if results else ""
174+
100175 if text :
101176 # Create a single segment with the full transcription
102177 result_segments .append (backend_pb2 .TranscriptSegment (
0 commit comments