Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ print(model.available_voices)
# ['Bella', 'Jasper', 'Luna', 'Bruno', 'Rosie', 'Hugo', 'Kiki', 'Leo']
```

### Loading local model files

If you have already downloaded a model repository, pass the path of the local
directory that contains `config.json` and the ONNX model and voices files it names:

```python
model = KittenTTS("/path/to/kitten-tts-mini-0.8")
```

### Using with GPU

```
Expand All @@ -119,7 +128,7 @@ Load a model from Hugging Face Hub.

| Parameter | Type | Default | Description |
|---|---|---|---|
| `model_name` | `str` | `"KittenML/kitten-tts-nano-0.8"` | Hugging Face repository ID |
| `model_name` | `str` | `"KittenML/kitten-tts-nano-0.8"` | Hugging Face repository ID, model name, or local model directory |
| `cache_dir` | `str` | `None` | Local directory for caching downloaded model files |

### `model.generate(text, voice, speed, clean_text)`
Expand Down
7 changes: 4 additions & 3 deletions kittentts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
__all__ = [
"get_model",
"KittenTTS",
"load_from_local",
"normalize_text",
"normalize_text_result",
"NormalizedSpan",
Expand All @@ -15,8 +16,8 @@


def __getattr__(name):
if name in {"get_model", "KittenTTS"}:
from kittentts.get_model import KittenTTS, get_model
if name in {"get_model", "KittenTTS", "load_from_local"}:
from kittentts.get_model import KittenTTS, get_model, load_from_local

return {"get_model": get_model, "KittenTTS": KittenTTS}[name]
return {"get_model": get_model, "KittenTTS": KittenTTS, "load_from_local": load_from_local}[name]
raise AttributeError(f"module 'kittentts' has no attribute {name!r}")
68 changes: 64 additions & 4 deletions kittentts/get_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import os
from huggingface_hub import hf_hub_download
from .onnx_model import KittenTTS_1_Onnx
from pathlib import Path
from .preprocess import normalize_text


Expand All @@ -12,9 +11,14 @@ def __init__(self, model_name="KittenML/kitten-tts-nano-0.8", cache_dir=None, ba
"""Initialize KittenTTS with a model from Hugging Face.

Args:
model_name: Hugging Face repository ID or model name
model_name: Hugging Face repository ID, model name, or local model directory
cache_dir: Directory to cache downloaded files
"""
local_model_path = _local_model_path(model_name)
if local_model_path is not None:
self.model = load_from_local(local_model_path, backend=backend)
return

# Handle different model name formats
if "/" not in model_name:
# If just model name provided, assume it's from KittenML
Expand Down Expand Up @@ -78,6 +82,8 @@ def download_from_huggingface(repo_id="KittenML/kitten-tts-nano-0.1", cache_dir=
Returns:
KittenTTS_1_Onnx: Instantiated model ready for use
"""
from huggingface_hub import hf_hub_download

# Download config file first
config_path = hf_hub_download(
repo_id=repo_id,
Expand Down Expand Up @@ -106,11 +112,65 @@ def download_from_huggingface(repo_id="KittenML/kitten-tts-nano-0.1", cache_dir=
)

# Instantiate and return model
model = KittenTTS_1_Onnx(model_path=model_path, voices_path=voices_path, speed_priors=config.get("speed_priors", {}) , voice_aliases=config.get("voice_aliases", {}), backend=backend)
model = _create_onnx_model(model_path=model_path, voices_path=voices_path, speed_priors=config.get("speed_priors", {}) , voice_aliases=config.get("voice_aliases", {}), backend=backend)

return model


def load_from_local(model_path, backend=None):
    """Load model files directly from a local model directory.

    The directory must contain config.json plus the model and voice files named
    by that config.

    Args:
        model_path: Path (str or os.PathLike) of the local model directory.
            A leading ``~`` is expanded.
        backend: Passed through unchanged to the ONNX model factory.

    Returns:
        The object returned by ``_create_onnx_model``.

    Raises:
        FileNotFoundError: If the directory, its config.json, or a file named
            by the config does not exist.
        ValueError: If config.json is not valid JSON, is not a JSON object,
            declares an unsupported ``type``, lacks ``model_file`` or
            ``voices``, or gives a non-string value for either of them.
    """
    model_dir = Path(model_path).expanduser()
    if not model_dir.is_dir():
        raise FileNotFoundError(f"Local model directory not found: {model_dir}")

    config_path = model_dir / "config.json"
    if not config_path.is_file():
        raise FileNotFoundError(f"Local model config not found: {config_path}")

    with config_path.open("r", encoding="utf-8") as f:
        config = json.load(f)

    # json.load can yield a list, string, number, ... for a malformed config;
    # reject those here instead of failing below with an opaque AttributeError.
    if not isinstance(config, dict):
        raise ValueError(
            f"Local model config must be a JSON object: {config_path}"
        )

    if config.get("type") not in ["ONNX1", "ONNX2"]:
        raise ValueError("Unsupported model type in local config.")

    try:
        model_file = config["model_file"]
        voices_file = config["voices"]
    except KeyError as exc:
        raise ValueError(f"Missing required local model config key: {exc.args[0]}") from exc

    # Joining a Path with a non-string (null, number, list) raises TypeError;
    # report it as an invalid config instead.
    for key, value in (("model_file", model_file), ("voices", voices_file)):
        if not isinstance(value, str):
            raise ValueError(
                f"Local model config key {key!r} must be a file name string"
            )

    model_file_path = model_dir / model_file
    voices_path = model_dir / voices_file
    for path in [model_file_path, voices_path]:
        if not path.is_file():
            raise FileNotFoundError(f"Local model file not found: {path}")

    return _create_onnx_model(
        model_path=str(model_file_path),
        voices_path=str(voices_path),
        speed_priors=config.get("speed_priors", {}),
        voice_aliases=config.get("voice_aliases", {}),
        backend=backend,
    )


def get_model(repo_id="KittenML/kitten-tts-nano-0.1", cache_dir=None, backend=None):
    """Get a KittenTTS model (legacy function for backward compatibility).

    Args:
        repo_id: Forwarded to ``KittenTTS`` as its ``model_name``.
        cache_dir: Forwarded to ``KittenTTS`` as its ``cache_dir``.
        backend: Forwarded to ``KittenTTS`` as its ``backend``.

    Returns:
        A new ``KittenTTS`` instance.
    """
    tts = KittenTTS(model_name=repo_id, cache_dir=cache_dir, backend=backend)
    return tts


def _local_model_path(model_name):
if not isinstance(model_name, (str, os.PathLike)):
return None
path = Path(model_name).expanduser()
return path if path.exists() else None


def _create_onnx_model(**kwargs):
    """Build a ``KittenTTS_1_Onnx`` from keyword arguments.

    The ``.onnx_model`` import happens inside the function, so it only runs
    when a model is actually constructed.
    """
    from .onnx_model import KittenTTS_1_Onnx as onnx_model_cls

    instance = onnx_model_cls(**kwargs)
    return instance
62 changes: 61 additions & 1 deletion tests/test_text_normalization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import json
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from kittentts import NormalizedTextResult, normalize_text
from kittentts import KittenTTS, NormalizedTextResult, load_from_local, normalize_text
from kittentts.preprocess import chunk_text


Expand Down Expand Up @@ -62,5 +66,61 @@ def test_unsupported_locale_fails_explicitly(self):
normalize_text("Bonjour 2026", locale="fr-FR")


class LocalModelLoadingTests(unittest.TestCase):
    """Local-directory model loading, with the ONNX model factory patched out."""

    def _write_local_model(self, model_dir: Path):
        """Fill *model_dir* with a config.json and placeholder asset files."""
        config = {
            "type": "ONNX1",
            "model_file": "model.onnx",
            "voices": "voices.npz",
            "speed_priors": {"Bella": 0.95},
            "voice_aliases": {"Bella": "expr-voice-2-f"},
        }
        (model_dir / "config.json").write_text(json.dumps(config), encoding="utf-8")
        (model_dir / "model.onnx").write_bytes(b"onnx")
        (model_dir / "voices.npz").write_bytes(b"voices")

    def test_load_from_local_uses_configured_files(self):
        # The factory must receive the paths and settings taken from config.json.
        with tempfile.TemporaryDirectory() as tmpdir:
            model_dir = Path(tmpdir)
            self._write_local_model(model_dir)

            with patch("kittentts.get_model._create_onnx_model") as model_cls:
                model = load_from_local(model_dir, backend="cpu")

            expected_kwargs = {
                "model_path": str(model_dir / "model.onnx"),
                "voices_path": str(model_dir / "voices.npz"),
                "speed_priors": {"Bella": 0.95},
                "voice_aliases": {"Bella": "expr-voice-2-f"},
                "backend": "cpu",
            }
            self.assertIs(model, model_cls.return_value)
            model_cls.assert_called_once_with(**expected_kwargs)

    def test_kittentts_accepts_existing_local_model_directory(self):
        # Passing an existing directory to KittenTTS must take the local path.
        with tempfile.TemporaryDirectory() as tmpdir:
            model_dir = Path(tmpdir)
            self._write_local_model(model_dir)

            with patch("kittentts.get_model._create_onnx_model") as model_cls:
                model = KittenTTS(str(model_dir), backend="cpu")

            self.assertIs(model.model, model_cls.return_value)

    def test_load_from_local_requires_model_assets(self):
        # Only config.json is written, so the files it names are missing.
        with tempfile.TemporaryDirectory() as tmpdir:
            model_dir = Path(tmpdir)
            config = {"type": "ONNX1", "model_file": "missing.onnx", "voices": "voices.npz"}
            (model_dir / "config.json").write_text(json.dumps(config), encoding="utf-8")

            self.assertRaises(FileNotFoundError, load_from_local, model_dir)


if __name__ == "__main__":
unittest.main()