@@ -78,6 +78,9 @@ class SmartTurnDetector:
7878 Further information at https://github.com/pipecat-ai/smart-turn
7979 """
8080
81+ WINDOW_SECONDS = 8
82+ DEFAULT_SAMPLE_RATE = 16000
83+
8184 def __init__ (self , auto_init : bool = True , threshold : float = 0.8 ):
8285 """Create the new SmartTurnDetector.
8386
@@ -125,7 +128,7 @@ def setup(self) -> None:
125128 self .session = self .build_session (SMART_TURN_MODEL_LOCAL_PATH )
126129
127130 # Load the feature extractor
128- self .feature_extractor = WhisperFeatureExtractor (chunk_length = 8 )
131+ self .feature_extractor = WhisperFeatureExtractor (chunk_length = self . WINDOW_SECONDS )
129132
130133 # Set initialized
131134 self ._is_initialized = True
@@ -156,83 +159,113 @@ def build_session(self, onnx_path: str) -> ort.InferenceSession:
156159 # Return the new session
157160 return ort .InferenceSession (onnx_path , sess_options = so )
158161
159- async def predict (
160- self , audio_array : bytes , language : str , sample_rate : int = 16000 , sample_width : int = 2
161- ) -> SmartTurnPredictionResult :
162- """Predict whether an audio segment is complete (turn ended) or incomplete.
162+ def _prepare_audio (self , audio_array : bytes , sample_rate : int , sample_width : int ) -> np .ndarray :
163+ """Prepare the audio for inference.
163164
164165 Args:
165166 audio_array: Numpy array containing audio samples at 16kHz. The function
166167 will convert the audio into float32 and truncate to 8 seconds (keeping the end)
167168 or pad to 8 seconds.
168- language: Language of the audio.
169169 sample_rate: Sample rate of the audio.
170170 sample_width: Sample width of the audio.
171171
172172 Returns:
173- Prediction result containing completion status and probability .
173+ Float32 numpy array of audio samples normalized to [-1, 1].
174174 """
175-
176- # Check if initialized
177- if not self ._is_initialized :
178- return SmartTurnPredictionResult (error = "SmartTurnDetector is not initialized" )
179-
180- # Check a valid language
181- if not self .valid_language (language ):
182- logger .warning (f"Invalid language: { language } . Results may be unreliable." )
183-
184- # Record start time
185- start_time = datetime .datetime .now ()
186-
187175 # Convert into numpy array
188176 dtype = np .int16 if sample_width == 2 else np .int8
189177 int16_array : np .ndarray = np .frombuffer (audio_array , dtype = dtype ).astype (np .int16 )
190178
191- # Truncate to last 8 seconds if needed (keep the tail/end of audio)
192- max_samples = 8 * sample_rate
179+ # Truncate to last WINDOW_SECONDS seconds if needed (keep the tail/end of audio)
180+ max_samples = self . WINDOW_SECONDS * sample_rate
193181 if len (int16_array ) > max_samples :
194182 int16_array = int16_array [- max_samples :]
195183
196184 # Convert int16 to float32 in range [-1, 1] (same as reference implementation)
197185 float32_array : np .ndarray = int16_array .astype (np .float32 ) / 32768.0
198186
199- # Process audio using Whisper's feature extractor
187+ return float32_array
188+
189+ def _get_input_features (self , audio_data : np .ndarray , sample_rate : int ) -> np .ndarray :
190+ """
191+ Get the input features for the audio data using Whisper's feature extractor.
192+
193+ Args:
194+ audio_data: Numpy array containing audio samples.
195+ sample_rate: Sample rate of the audio.
196+ """
197+
200198 inputs = self .feature_extractor (
201- float32_array ,
199+ audio_data ,
202200 sampling_rate = sample_rate ,
203201 return_tensors = "np" ,
204202 padding = "max_length" ,
205- max_length = max_samples ,
203+ max_length = self . WINDOW_SECONDS * sample_rate ,
206204 truncation = True ,
207205 do_normalize = True ,
208206 )
209207
210- # Extract features and ensure correct shape for ONNX
208+ # Ensure dimensions are correct shape for ONNX
211209 input_features = inputs .input_features .squeeze (0 ).astype (np .float32 )
212210 input_features = np .expand_dims (input_features , axis = 0 )
213211
214- # Run ONNX inference
215- outputs = self .session .run (None , {"input_features" : input_features })
212+ return input_features
213+
214+ async def predict (
215+ self , audio_array : bytes , language : str , sample_rate : int = DEFAULT_SAMPLE_RATE , sample_width : int = 2
216+ ) -> SmartTurnPredictionResult :
217+ """Predict whether an audio segment is complete (turn ended) or incomplete.
218+
219+ Args:
220+ audio_array: Raw PCM audio bytes sampled at 16kHz. The function
221+ will convert the audio into float32 and truncate to 8 seconds (keeping the end)
222+ or pad to 8 seconds.
223+ language: Language of the audio.
224+ sample_rate: Sample rate of the audio.
225+ sample_width: Sample width of the audio.
226+
227+ Returns:
228+ Prediction result containing completion status and probability.
229+ """
216230
217- # Extract probability (ONNX model returns sigmoid probabilities)
231+ # Check if initialized
232+ if not self ._is_initialized :
233+ return SmartTurnPredictionResult (error = "SmartTurnDetector is not initialized" )
234+
235+ # Check a valid language
236+ if not self .valid_language (language ):
237+ logger .warning (f"Invalid language: { language } . Results may be unreliable." )
238+
239+ # Record start time
240+ start_time = datetime .datetime .now ()
241+
242+ # Convert the audio into required format
243+ prepared_audio = self ._prepare_audio (audio_array , sample_rate , sample_width )
244+
245+ # Feature extraction
246+ input_features = self ._get_input_features (prepared_audio , sample_rate )
247+
248+ # Model inference
249+ outputs = self .session .run (None , {"input_features" : input_features })
218250 probability = outputs [0 ][0 ].item ()
219251
220252 # Make prediction (True for Complete, False for Incomplete)
221253 prediction = probability >= self ._threshold
222254
223- # Record end time
255+ # Result Formatting
224256 end_time = datetime .datetime .now ()
257+ duration = float ((end_time - start_time ).total_seconds ())
225258
226259 # Return the result
227260 return SmartTurnPredictionResult (
228261 prediction = prediction ,
229262 probability = round (probability , 3 ),
230- processing_time = round (float (( end_time - start_time ). total_seconds ()) , 3 ),
263+ processing_time = round (duration , 3 ),
231264 )
232265
233266 @staticmethod
234267 def truncate_audio_to_last_n_seconds (
235- audio_array : np .ndarray , n_seconds : float = 8.0 , sample_rate : int = 16000
268+ audio_array : np .ndarray , n_seconds : float = 8.0 , sample_rate : int = DEFAULT_SAMPLE_RATE
236269 ) -> np .ndarray :
237270 """Truncate audio to last n seconds or pad with zeros to meet n seconds.
238271
@@ -300,7 +333,8 @@ def model_exists() -> bool:
300333
301334 @staticmethod
302335 def valid_language (language : str ) -> bool :
303- """Check if the language is valid.
336+ """Check if the language is valid against the list of supported languages
337+ for the Pipecat model.
304338
305339 Args:
306340 language: Language code to validate.
0 commit comments