From 3814ab46f0e36b9a03122477979746f0c8845b11 Mon Sep 17 00:00:00 2001 From: Dewan Shakil Date: Thu, 21 May 2026 17:30:14 +0530 Subject: [PATCH] Support loading local model directories --- README.md | 11 +++++- kittentts/__init__.py | 7 ++-- kittentts/get_model.py | 68 ++++++++++++++++++++++++++++++-- tests/test_text_normalization.py | 62 ++++++++++++++++++++++++++++- 4 files changed, 139 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 6d8bb73..e2925d7 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,15 @@ print(model.available_voices) # ['Bella', 'Jasper', 'Luna', 'Bruno', 'Rosie', 'Hugo', 'Kiki', 'Leo'] ``` +### Loading local model files + +If you already downloaded a model repository, pass the local directory that +contains `config.json`, the ONNX model file, and the voices file: + +```python +model = KittenTTS("/path/to/kitten-tts-mini-0.8") +``` + ### Using with GPU ``` @@ -119,7 +128,7 @@ Load a model from Hugging Face Hub. | Parameter | Type | Default | Description | |---|---|---|---| -| `model_name` | `str` | `"KittenML/kitten-tts-nano-0.8"` | Hugging Face repository ID | +| `model_name` | `str` | `"KittenML/kitten-tts-nano-0.8"` | Hugging Face repository ID, model name, or local model directory | | `cache_dir` | `str` | `None` | Local directory for caching downloaded model files | ### `model.generate(text, voice, speed, clean_text)` diff --git a/kittentts/__init__.py b/kittentts/__init__.py index 6c39e74..71cf4fd 100644 --- a/kittentts/__init__.py +++ b/kittentts/__init__.py @@ -7,6 +7,7 @@ __all__ = [ "get_model", "KittenTTS", + "load_from_local", "normalize_text", "normalize_text_result", "NormalizedSpan", @@ -15,8 +16,8 @@ def __getattr__(name): - if name in {"get_model", "KittenTTS"}: - from kittentts.get_model import KittenTTS, get_model + if name in {"get_model", "KittenTTS", "load_from_local"}: + from kittentts.get_model import KittenTTS, get_model, load_from_local - return {"get_model": get_model, "KittenTTS": KittenTTS}[name] + return {"get_model": get_model, "KittenTTS": KittenTTS, "load_from_local": load_from_local}[name] raise AttributeError(f"module 'kittentts' has no attribute {name!r}") diff --git a/kittentts/get_model.py b/kittentts/get_model.py index d8d2225..b72eb06 100644 --- a/kittentts/get_model.py +++ b/kittentts/get_model.py @@ -1,7 +1,6 @@ import json import os -from huggingface_hub import hf_hub_download -from .onnx_model import KittenTTS_1_Onnx +from pathlib import Path from .preprocess import normalize_text @@ -12,9 +11,14 @@ def __init__(self, model_name="KittenML/kitten-tts-nano-0.8", cache_dir=None, ba """Initialize KittenTTS with a model from Hugging Face. Args: - model_name: Hugging Face repository ID or model name + model_name: Hugging Face repository ID, model name, or local model directory cache_dir: Directory to cache downloaded files """ + local_model_path = _local_model_path(model_name) + if local_model_path is not None: + self.model = load_from_local(local_model_path, backend=backend) + return + # Handle different model name formats if "/" not in model_name: # If just model name provided, assume it's from KittenML @@ -78,6 +82,8 @@ def download_from_huggingface(repo_id="KittenML/kitten-tts-nano-0.1", cache_dir= Returns: KittenTTS_1_Onnx: Instantiated model ready for use """ + from huggingface_hub import hf_hub_download + # Download config file first config_path = hf_hub_download( repo_id=repo_id, @@ -106,11 +112,65 @@ def download_from_huggingface(repo_id="KittenML/kitten-tts-nano-0.1", cache_dir= ) # Instantiate and return model - model = KittenTTS_1_Onnx(model_path=model_path, voices_path=voices_path, speed_priors=config.get("speed_priors", {}) , voice_aliases=config.get("voice_aliases", {}), backend=backend) + model = _create_onnx_model(model_path=model_path, voices_path=voices_path, speed_priors=config.get("speed_priors", {}) , voice_aliases=config.get("voice_aliases", {}), backend=backend) return model +def load_from_local(model_path, backend=None): + """Load model files directly from a local model directory. + + The directory must contain config.json plus the model and voice files named + by that config. + """ + model_dir = Path(model_path).expanduser() + if not model_dir.is_dir(): + raise FileNotFoundError(f"Local model directory not found: {model_dir}") + + config_path = model_dir / "config.json" + if not config_path.is_file(): + raise FileNotFoundError(f"Local model config not found: {config_path}") + + with config_path.open("r", encoding="utf-8") as f: + config = json.load(f) + + if config.get("type") not in ["ONNX1", "ONNX2"]: + raise ValueError("Unsupported model type in local config.") + + try: + model_file = config["model_file"] + voices_file = config["voices"] + except KeyError as exc: + raise ValueError(f"Missing required local model config key: {exc.args[0]}") from exc + + model_file_path = model_dir / model_file + voices_path = model_dir / voices_file + for path in [model_file_path, voices_path]: + if not path.is_file(): + raise FileNotFoundError(f"Local model file not found: {path}") + + return _create_onnx_model( + model_path=str(model_file_path), + voices_path=str(voices_path), + speed_priors=config.get("speed_priors", {}), + voice_aliases=config.get("voice_aliases", {}), + backend=backend, + ) + + def get_model(repo_id="KittenML/kitten-tts-nano-0.1", cache_dir=None, backend=None): """Get a KittenTTS model (legacy function for backward compatibility).""" return KittenTTS(repo_id, cache_dir, backend=backend) + + +def _local_model_path(model_name): + if not isinstance(model_name, (str, os.PathLike)): + return None + path = Path(model_name).expanduser() + return path if path.exists() else None + + +def _create_onnx_model(**kwargs): + from .onnx_model import KittenTTS_1_Onnx + + return KittenTTS_1_Onnx(**kwargs) diff --git a/tests/test_text_normalization.py b/tests/test_text_normalization.py index 6a9eb92..471e06d 100644 --- a/tests/test_text_normalization.py +++ b/tests/test_text_normalization.py @@ -1,6 +1,10 @@ +import json +import tempfile import unittest +from pathlib import Path +from unittest.mock import patch -from kittentts import NormalizedTextResult, normalize_text +from kittentts import KittenTTS, NormalizedTextResult, load_from_local, normalize_text from kittentts.preprocess import chunk_text @@ -62,5 +66,61 @@ def test_unsupported_locale_fails_explicitly(self): normalize_text("Bonjour 2026", locale="fr-FR") +class LocalModelLoadingTests(unittest.TestCase): + def _write_local_model(self, model_dir: Path): + (model_dir / "config.json").write_text( + json.dumps( + { + "type": "ONNX1", + "model_file": "model.onnx", + "voices": "voices.npz", + "speed_priors": {"Bella": 0.95}, + "voice_aliases": {"Bella": "expr-voice-2-f"}, + } + ), + encoding="utf-8", + ) + (model_dir / "model.onnx").write_bytes(b"onnx") + (model_dir / "voices.npz").write_bytes(b"voices") + + def test_load_from_local_uses_configured_files(self): + with tempfile.TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) + self._write_local_model(model_dir) + + with patch("kittentts.get_model._create_onnx_model") as model_cls: + model = load_from_local(model_dir, backend="cpu") + + self.assertIs(model, model_cls.return_value) + model_cls.assert_called_once_with( + model_path=str(model_dir / "model.onnx"), + voices_path=str(model_dir / "voices.npz"), + speed_priors={"Bella": 0.95}, + voice_aliases={"Bella": "expr-voice-2-f"}, + backend="cpu", + ) + + def test_kittentts_accepts_existing_local_model_directory(self): + with tempfile.TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) + self._write_local_model(model_dir) + + with patch("kittentts.get_model._create_onnx_model") as model_cls: + model = KittenTTS(str(model_dir), backend="cpu") + + self.assertIs(model.model, model_cls.return_value) + + def test_load_from_local_requires_model_assets(self): + with tempfile.TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) + (model_dir / "config.json").write_text( + json.dumps({"type": "ONNX1", "model_file": "missing.onnx", "voices": "voices.npz"}), + encoding="utf-8", + ) + + with self.assertRaises(FileNotFoundError): + load_from_local(model_dir) + + if __name__ == "__main__": unittest.main()