Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion kittentts/onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ def basic_english_tokenize(text):
tokens = re.findall(r"\w+|[^\w\s]", text)
return tokens


def mono_audio_array(audio):
    """Return generated mono audio as a flat samples array.

    Args:
        audio: Array-like model output. It is converted with ``np.asarray``
            and every length-1 axis is removed with ``squeeze()``.

    Returns:
        The squeezed array. If squeezing leaves no axes at all (for example
        a ``(1, 1)`` input holding a single sample), the result is reshaped
        to ``(1,)`` so callers always receive at least one dimension.
    """
    samples = np.asarray(audio).squeeze()
    # squeeze() collapses a one-sample input to a 0-d array, which has no
    # len() and cannot be concatenated; keep it as a one-element 1-D array.
    if not samples.shape:
        return samples.reshape(1)
    return samples


class TextCleaner:
def __init__(self, dummy=None):
_pad = "$"
Expand Down Expand Up @@ -154,7 +160,7 @@ def generate_single_chunk(self, text: str, voice: str = "expr-voice-5-m", speed:
# Trim audio
audio = outputs[0][..., :-5000]

return audio
return mono_audio_array(audio)

def generate_to_file(self, text: str, output_path: str, voice: str = "expr-voice-5-m",
speed: float = 1.0, sample_rate: int = 24000, clean_text: bool=True) -> None:
Expand Down
46 changes: 46 additions & 0 deletions tests/test_text_normalization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import unittest
import sys
import types
from unittest.mock import patch

from kittentts import NormalizedTextResult, normalize_text
from kittentts.preprocess import chunk_text
Expand Down Expand Up @@ -62,5 +65,48 @@ def test_unsupported_locale_fails_explicitly(self):
normalize_text("Bonjour 2026", locale="fr-FR")


class AudioArrayTests(unittest.TestCase):
    """Tests for the ``mono_audio_array`` helper in ``kittentts.onnx_model``."""

    def test_mono_audio_array_flattens_single_channel_output(self):
        """An input reporting shape ``(1, 24000)`` comes back as ``(24000,)``."""
        namespace = types.SimpleNamespace

        # Build the fake ``phonemizer`` package from the innermost module out.
        wrapper_module = namespace(
            EspeakWrapper=namespace(set_library=lambda path: None)
        )
        espeak_module = namespace(wrapper=wrapper_module)
        backend_module = namespace(EspeakBackend=object, espeak=espeak_module)
        phonemizer_module = namespace(backend=backend_module)

        # Entries placed in ``sys.modules`` for the duration of the import;
        # presumably these let ``kittentts.onnx_model`` import without the
        # real packages installed -- confirm against that module's imports.
        fake_modules = {
            "espeakng_loader": namespace(
                get_library_path=lambda: "/tmp/libespeak-ng.dylib",
                get_data_path=lambda: "/tmp/espeak-ng-data",
            ),
            "phonemizer": phonemizer_module,
            "phonemizer.backend": backend_module,
            "phonemizer.backend.espeak": espeak_module,
            "phonemizer.backend.espeak.wrapper": wrapper_module,
            # ``asarray`` is an identity function, so the fake array below
            # reaches ``squeeze()`` unchanged.
            "numpy": namespace(asarray=lambda audio: audio, ndarray=object),
            "onnxruntime": namespace(),
            "soundfile": namespace(),
        }

        with patch.dict(sys.modules, fake_modules):
            from kittentts.onnx_model import mono_audio_array

            single_channel = _FakeAudioArray((1, 24000), (24000,))
            audio = mono_audio_array(single_channel)

        self.assertEqual(audio.shape, (24000,))


class _FakeAudioArray:
def __init__(self, shape, squeezed_shape):
self.shape = shape
self._squeezed_shape = squeezed_shape

def squeeze(self):
return _FakeAudioArray(self._squeezed_shape, self._squeezed_shape)


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()