From 15127753edb9691bca4d498fde0bcf408752aed1 Mon Sep 17 00:00:00 2001 From: Dewan Shakil Date: Thu, 21 May 2026 17:31:14 +0530 Subject: [PATCH] Return flat mono audio arrays --- kittentts/onnx_model.py | 8 +++++- tests/test_text_normalization.py | 46 ++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/kittentts/onnx_model.py b/kittentts/onnx_model.py index b368e4f..71c3519 100644 --- a/kittentts/onnx_model.py +++ b/kittentts/onnx_model.py @@ -15,6 +15,12 @@ def basic_english_tokenize(text): tokens = re.findall(r"\w+|[^\w\s]", text) return tokens + +def mono_audio_array(audio): + """Return generated mono audio as a flat samples array.""" + return np.asarray(audio).squeeze() + + class TextCleaner: def __init__(self, dummy=None): _pad = "$" @@ -154,7 +160,7 @@ def generate_single_chunk(self, text: str, voice: str = "expr-voice-5-m", speed: # Trim audio audio = outputs[0][..., :-5000] - return audio + return mono_audio_array(audio) def generate_to_file(self, text: str, output_path: str, voice: str = "expr-voice-5-m", speed: float = 1.0, sample_rate: int = 24000, clean_text: bool=True) -> None: diff --git a/tests/test_text_normalization.py b/tests/test_text_normalization.py index 6a9eb92..dfa494d 100644 --- a/tests/test_text_normalization.py +++ b/tests/test_text_normalization.py @@ -1,4 +1,7 @@ import unittest +import sys +import types +from unittest.mock import patch from kittentts import NormalizedTextResult, normalize_text from kittentts.preprocess import chunk_text @@ -62,5 +65,48 @@ def test_unsupported_locale_fails_explicitly(self): normalize_text("Bonjour 2026", locale="fr-FR") +class AudioArrayTests(unittest.TestCase): + def test_mono_audio_array_flattens_single_channel_output(self): + espeakng_loader = types.SimpleNamespace( + get_library_path=lambda: "/tmp/libespeak-ng.dylib", + get_data_path=lambda: "/tmp/espeak-ng-data", + ) + espeak_wrapper = types.SimpleNamespace(set_library=lambda path: None) + phonemizer = types.SimpleNamespace( + backend=types.SimpleNamespace( + EspeakBackend=object, + espeak=types.SimpleNamespace(wrapper=types.SimpleNamespace(EspeakWrapper=espeak_wrapper)), + ) + ) + + with patch.dict( + sys.modules, + { + "espeakng_loader": espeakng_loader, + "phonemizer": phonemizer, + "phonemizer.backend": phonemizer.backend, + "phonemizer.backend.espeak": phonemizer.backend.espeak, + "phonemizer.backend.espeak.wrapper": phonemizer.backend.espeak.wrapper, + "numpy": types.SimpleNamespace(asarray=lambda audio: audio, ndarray=object), + "onnxruntime": types.SimpleNamespace(), + "soundfile": types.SimpleNamespace(), + }, + ): + from kittentts.onnx_model import mono_audio_array + + audio = mono_audio_array(_FakeAudioArray((1, 24000), (24000,))) + + self.assertEqual(audio.shape, (24000,)) + + +class _FakeAudioArray: + def __init__(self, shape, squeezed_shape): + self.shape = shape + self._squeezed_shape = squeezed_shape + + def squeeze(self): + return _FakeAudioArray(self._squeezed_shape, self._squeezed_shape) + + if __name__ == "__main__": unittest.main()