Skip to content

Commit d72c677

Browse files
committed
Linting and Smart Turn Refactor
1 parent c3d1dd6 commit d72c677

2 files changed

Lines changed: 76 additions & 42 deletions

File tree

sdk/voice/speechmatics/voice/_smart_turn.py

Lines changed: 66 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ class SmartTurnDetector:
7878
Further information at https://github.com/pipecat-ai/smart-turn
7979
"""
8080

81+
WINDOW_SECONDS = 8
82+
DEFAULT_SAMPLE_RATE = 16000
83+
8184
def __init__(self, auto_init: bool = True, threshold: float = 0.8):
8285
"""Create the new SmartTurnDetector.
8386
@@ -125,7 +128,7 @@ def setup(self) -> None:
125128
self.session = self.build_session(SMART_TURN_MODEL_LOCAL_PATH)
126129

127130
# Load the feature extractor
128-
self.feature_extractor = WhisperFeatureExtractor(chunk_length=8)
131+
self.feature_extractor = WhisperFeatureExtractor(chunk_length=self.WINDOW_SECONDS)
129132

130133
# Set initialized
131134
self._is_initialized = True
@@ -156,83 +159,113 @@ def build_session(self, onnx_path: str) -> ort.InferenceSession:
156159
# Return the new session
157160
return ort.InferenceSession(onnx_path, sess_options=so)
158161

159-
async def predict(
160-
self, audio_array: bytes, language: str, sample_rate: int = 16000, sample_width: int = 2
161-
) -> SmartTurnPredictionResult:
162-
"""Predict whether an audio segment is complete (turn ended) or incomplete.
162+
def _prepare_audio(self, audio_array: bytes, sample_rate: int, sample_width: int) -> np.ndarray:
    """Decode raw PCM bytes into a float32 waveform for inference.

    The audio is truncated to the last ``WINDOW_SECONDS`` seconds (keeping the
    end of the audio). Shorter audio is returned as-is; no padding is applied
    by this method.

    Args:
        audio_array: Raw PCM audio bytes. Decoded as int16 when
            ``sample_width`` is 2, otherwise as signed int8.
        sample_rate: Sample rate of the audio, used to size the window.
        sample_width: Sample width of the audio in bytes.

    Returns:
        1-D float32 Numpy array of samples scaled into the range [-1, 1].
    """

    # Pick the integer type and the matching full-scale divisor so that both
    # 16-bit and 8-bit input end up in [-1, 1]
    # NOTE(review): any width other than 2 is decoded as signed int8, as before -
    # confirm whether unsupported widths should be rejected instead
    if sample_width == 2:
        dtype = np.int16
        divisor = 32768.0
    else:
        dtype = np.int8
        divisor = 128.0

    # Convert into numpy array
    samples: np.ndarray = np.frombuffer(audio_array, dtype=dtype)

    # Truncate to last WINDOW_SECONDS seconds if needed (keep the tail/end of audio)
    max_samples = self.WINDOW_SECONDS * sample_rate
    if len(samples) > max_samples:
        samples = samples[-max_samples:]

    # Convert to float32 in range [-1, 1]
    float32_array: np.ndarray = samples.astype(np.float32) / divisor

    return float32_array
188+
189+
def _get_input_features(self, audio_data: np.ndarray, sample_rate: int) -> np.ndarray:
    """Get the input features for the audio data using Whisper's feature extractor.

    The extractor is asked to pad or truncate the audio to
    ``WINDOW_SECONDS * sample_rate`` samples and to normalize it.

    Args:
        audio_data: Numpy array containing audio samples.
        sample_rate: Sample rate of the audio.

    Returns:
        float32 Numpy array of features with a leading axis of size 1.
    """

    # Process audio using the feature extractor
    inputs = self.feature_extractor(
        audio_data,
        sampling_rate=sample_rate,
        return_tensors="np",
        padding="max_length",
        max_length=self.WINDOW_SECONDS * sample_rate,
        truncation=True,
        do_normalize=True,
    )

    # Ensure dimensions are correct shape for ONNX: drop the extractor's
    # leading axis (must be size 1), cast, then restore a single leading axis
    input_features = inputs.input_features.squeeze(0).astype(np.float32)
    input_features = np.expand_dims(input_features, axis=0)

    return input_features
213+
214+
async def predict(
    self, audio_array: bytes, language: str, sample_rate: int = DEFAULT_SAMPLE_RATE, sample_width: int = 2
) -> SmartTurnPredictionResult:
    """Predict whether an audio segment is complete (turn ended) or incomplete.

    Args:
        audio_array: Raw PCM audio bytes. The audio is converted into float32,
            truncated to the last ``WINDOW_SECONDS`` seconds (keeping the end)
            and padded by the feature extractor if shorter.
        language: Language of the audio. An unsupported language only logs a
            warning; prediction still runs.
        sample_rate: Sample rate of the audio.
        sample_width: Sample width of the audio.

    Returns:
        Prediction result containing completion status and probability, or a
        result carrying an error message when the detector is not initialized.
    """

    # Local import keeps this block self-contained; only the timer is needed
    import time

    # Check if initialized
    if not self._is_initialized:
        return SmartTurnPredictionResult(error="SmartTurnDetector is not initialized")

    # Check a valid language
    if not self.valid_language(language):
        logger.warning(f"Invalid language: {language}. Results may be unreliable.")

    # Record start time (monotonic clock, so wall-clock adjustments cannot
    # produce negative or inflated durations)
    start_time = time.perf_counter()

    # Convert the audio into required format
    prepared_audio = self._prepare_audio(audio_array, sample_rate, sample_width)

    # Feature extraction
    input_features = self._get_input_features(prepared_audio, sample_rate)

    # Model inference
    outputs = self.session.run(None, {"input_features": input_features})
    probability = outputs[0][0].item()

    # Make prediction (True for Complete, False for Incomplete)
    prediction = probability >= self._threshold

    # Result formatting
    duration = time.perf_counter() - start_time

    # Return the result
    return SmartTurnPredictionResult(
        prediction=prediction,
        probability=round(probability, 3),
        processing_time=round(duration, 3),
    )
232265

233266
@staticmethod
234267
def truncate_audio_to_last_n_seconds(
235-
audio_array: np.ndarray, n_seconds: float = 8.0, sample_rate: int = 16000
268+
audio_array: np.ndarray, n_seconds: float = 8.0, sample_rate: int = DEFAULT_SAMPLE_RATE
236269
) -> np.ndarray:
237270
"""Truncate audio to last n seconds or pad with zeros to meet n seconds.
238271
@@ -300,7 +333,8 @@ def model_exists() -> bool:
300333

301334
@staticmethod
302335
def valid_language(language: str) -> bool:
303-
"""Check if the language is valid.
336+
"""Check if the language is valid against list of supported languages
337+
for the Pipecat model.
304338
305339
Args:
306340
language: Language code to validate.

sdk/voice/speechmatics/voice/_vad.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def process_chunk(self, chunk_f32: np.ndarray) -> float:
225225

226226
# Return probability (out shape is (1, 1))
227227
return float(out[0][0])
228-
228+
229229
def _validate_input(self, sample_rate: int) -> bool:
230230
"""
231231
Ensures the VAD is ready and the incoming audio format
@@ -244,7 +244,7 @@ def _validate_input(self, sample_rate: int) -> bool:
244244
if sample_rate != SILERO_SAMPLE_RATE:
245245
logger.error(f"Sample rate must be {SILERO_SAMPLE_RATE}Hz, got {sample_rate}Hz")
246246
return False
247-
247+
248248
return True
249249

250250
def _get_audio_chunks(self, sample_width: int):
@@ -282,17 +282,17 @@ def _prepare_chunk(self, chunk_bytes: bytes, sample_width: int) -> np.ndarray:
282282
"""
283283
if sample_width == 2:
284284
dtype = np.int16
285-
divisor = 32768.0
285+
divisor = 32768.0
286286
elif sample_width == 1:
287287
dtype = np.int8
288288
divisor = 128.0
289289
else:
290290
raise ValueError(f"Unsupported sample_width {sample_width}")
291291

292-
# Decode and normalize the chunk data
292+
# Decode and normalize the chunk data
293293
int_array = np.frombuffer(chunk_bytes, dtype=dtype)
294294
float32_array: np.ndarray = int_array.astype(np.float32) / divisor
295-
295+
296296
return float32_array
297297

298298
def _evaluate_activity_change(self) -> None:
@@ -353,7 +353,7 @@ def _trigger_on_state_change(self, is_speech: bool, probability: float) -> None:
353353

354354
# Trigger callback with result
355355
self._on_state_change(result)
356-
356+
357357
async def process_audio(self, audio_bytes: bytes, sample_rate: int = 16000, sample_width: int = 2) -> None:
358358
"""Process incoming audio bytes and invoke callback on state changes.
359359
@@ -367,21 +367,21 @@ async def process_audio(self, audio_bytes: bytes, sample_rate: int = 16000, samp
367367
"""
368368
if not self._validate_input(sample_rate):
369369
return
370-
370+
371371
# Add new bytes to the buffer
372372
self._audio_buffer += audio_bytes
373-
373+
374374
# Process all complete chunks in the buffer
375375
for chunk in self._get_audio_chunks(sample_width):
376376
audio_f32 = self._prepare_chunk(chunk, sample_width)
377-
377+
378378
try:
379379
probability = self.process_chunk(audio_f32)
380380
self._prediction_window.append(probability)
381381
except Exception as e:
382382
logger.error(f"Error processing VAD chunk: {e}")
383383
continue
384-
384+
385385
# Check if VAD state has changed
386386
self._evaluate_activity_change()
387387

0 commit comments

Comments
 (0)