From 4a60822ba878b1704bf22c4fd78166ba0648c98f Mon Sep 17 00:00:00 2001 From: Aditya Date: Tue, 3 Feb 2026 01:46:49 +0530 Subject: [PATCH] refactor: tts and stt up-to-date with latest spec and modularized client --- smallestai/waves/__init__.py | 7 +- smallestai/waves/async_waves_client.py | 282 ++++++++++++------------- smallestai/waves/models.py | 31 ++- smallestai/waves/stream_tts.py | 51 ++++- smallestai/waves/utils.py | 120 ++++++----- smallestai/waves/waves_client.py | 264 ++++++++++++----------- 6 files changed, 400 insertions(+), 355 deletions(-) diff --git a/smallestai/waves/__init__.py b/smallestai/waves/__init__.py index 95d83d63..8162cea3 100644 --- a/smallestai/waves/__init__.py +++ b/smallestai/waves/__init__.py @@ -2,4 +2,9 @@ from smallestai.waves.async_waves_client import AsyncWavesClient from smallestai.waves.stream_tts import WavesStreamingTTS, TTSConfig -__all__ = ["WavesClient", "AsyncWavesClient", "WavesStreamingTTS", "TTSConfig"] \ No newline at end of file +__all__ = [ + "WavesClient", + "AsyncWavesClient", + "WavesStreamingTTS", + "TTSConfig", +] \ No newline at end of file diff --git a/smallestai/waves/async_waves_client.py b/smallestai/waves/async_waves_client.py index b583d125..bd8c694f 100644 --- a/smallestai/waves/async_waves_client.py +++ b/smallestai/waves/async_waves_client.py @@ -1,99 +1,55 @@ import os -import copy import json import aiohttp import aiofiles import requests -from typing import Optional, Union, List +from typing import Optional, List from smallestai.waves.exceptions import InvalidError, APIError -from smallestai.waves.utils import (TTSOptions, validate_input, validate_asr_input, - get_smallest_languages, get_smallest_models, ALLOWED_AUDIO_EXTENSIONS, API_BASE_URL) +from smallestai.waves.utils import (validate_tts_input, validate_stt_input, + get_smallest_languages, get_tts_models, get_stt_models, ALLOWED_AUDIO_EXTENSIONS, API_BASE_URL, + DEFAULT_SAMPLE_RATES) class AsyncWavesClient: - def __init__( - self, - api_key: str = None, - model: Optional[str] = "lightning", - sample_rate: Optional[int] = 24000, - voice_id: Optional[str] = "emily", - speed: Optional[float] = 1.0, - consistency: Optional[float] = 0.5, - similarity: Optional[float] = 0.0, - enhancement: Optional[int] = 1, - language: Optional[str] = "en", - output_format: Optional[str] = "wav" - ) -> None: + def __init__(self, api_key: str = None) -> None: """ - AsyncSmallest Instance for asynchronous text-to-speech synthesis. - - This class provides an asynchronous implementation of the text-to-speech functionality. - It allows for non-blocking synthesis of speech from text, making it suitable for applications - that require async processing. + Asynchronous Waves Client for Text-to-Speech and Speech-to-Text. Args: - - api_key (str): The API key for authentication, export it as 'SMALLEST_API_KEY' in your environment variables. - - model (TTSModels): The model to be used for synthesis. - - sample_rate (int): The sample rate for the audio output. - - voice_id (TTSVoices): The voice to be used for synthesis. - - speed (float): The speed of the speech synthesis. - - consistency (float): This parameter controls word repetition and skipping. Decrease it to prevent skipped words, and increase it to prevent repetition. Only supported in `lightning-large` model. Range - [0, 1] - - similarity (float): This parameter controls the similarity between the synthesized audio and the reference audio. Increase it to make the speech more similar to the reference audio. Only supported in `lightning-large` model. Range - [0, 1] - - enhancement (int): Enhances speech quality at the cost of increased latency. Only supported in `lightning-large` model. Range - [0, 2]. - - language (str): The language for synthesis. Default is "en". - - output_format (str): The output audio format. Options: "pcm", "mp3", "wav", "mulaw". Default is "pcm". + - api_key (str): The API key for authentication. + Set via parameter or 'SMALLEST_API_KEY' environment variable. Methods: - - get_languages: Returns a list of available languages for synthesis. - - get_voices: Returns a list of available voices for synthesis. - - get_models: Returns a list of available models for synthesis. - - synthesize: Asynchronously converts the provided text into speech and returns the audio content. + - synthesize: Async text to speech. + - transcribe: Async speech to text. + - get_languages: Returns available languages for a model. + - get_voices: Returns available voices for a model. + - get_models: Returns available TTS models. """ self.api_key = api_key or os.environ.get("SMALLEST_API_KEY") if not self.api_key: raise InvalidError() - if model == "lightning-large" and voice_id is None: - voice_id = "lakshya" - - self.chunk_size = 250 - - self.opts = TTSOptions( - model=model, - sample_rate=sample_rate, - voice_id=voice_id, - api_key=self.api_key, - speed=speed, - consistency=consistency, - similarity=similarity, - enhancement=enhancement, - language=language, - output_format=output_format - ) self.session = None - async def __aenter__(self): if self.session is None: self.session = aiohttp.ClientSession() return self - async def __aexit__(self, exc_type, exc_val, exc_tb): if self.session: await self.session.close() - async def _ensure_session(self): """Ensure session exists for direct calls""" if not self.session: self.session = aiohttp.ClientSession() return True return False - - def get_languages(self, model="lightning") -> List[str]: - """Returns a list of available languages.""" + def get_languages(self, model: str = "lightning-v3.1") -> List[str]: + """Returns a list of available languages for a model (TTS or STT).""" return get_smallest_languages(model) def get_cloned_voices(self) -> str: @@ -107,13 +63,9 @@ def get_cloned_voices(self) -> str: raise APIError(f"Failed to get cloned voices: {res.text}. For more information, visit https://waves.smallest.ai/") return json.dumps(res.json(), indent=4, ensure_ascii=False) - - def get_voices( - self, - model: Optional[str] = "lightning" - ) -> str: - """Returns a list of available voices.""" + def get_voices(self, model: str = "lightning-v3.1") -> str: + """Returns a list of available voices for a TTS model.""" headers = { "Authorization": f"Bearer {self.api_key}", } @@ -124,81 +76,91 @@ def get_voices( return json.dumps(res.json(), indent=4, ensure_ascii=False) + def get_tts_models(self) -> List[str]: + """Returns a list of available TTS models.""" + return get_tts_models() - def get_models(self) -> List[str]: - """Returns a list of available models.""" - return get_smallest_models() - + def get_stt_models(self) -> List[str]: + """Returns a list of available STT models.""" + return get_stt_models() async def synthesize( self, text: str, - **kwargs - ) -> Union[bytes]: + model: str = "lightning-v3.1", + voice_id: Optional[str] = None, + sample_rate: Optional[int] = None, + speed: float = 1.0, + language: str = "en", + output_format: str = "wav", + consistency: Optional[float] = 0.5, + similarity: Optional[float] = 0.0, + enhancement: Optional[int] = 1, + pronunciation_dicts: Optional[List[str]] = None + ) -> bytes: """ - Asynchronously synthesize speech from the provided text. + Async synthesize speech from text. Args: - - text (str): The text to be converted to speech. - - stream (Optional[bool]): If True, returns an iterator yielding audio chunks instead of a full byte array. - - kwargs: Additional optional parameters to override `__init__` options for this call. + - text (str): The text to convert to speech. + - model (str): TTS model. Options: "lightning-v3.1", "lightning-v2". Default: "lightning-v3.1". + - voice_id (str): Voice ID. Default: "sophia" for v3.1, "alice" for v2. + - sample_rate (int): Sample rate in Hz. Default: 44100 for v3.1, 24000 for v2. + - speed (float): Speech speed (0.5-2.0). Default: 1.0. + - language (str): Language code. Default: "en". + - output_format (str): Output format ("pcm", "mp3", "wav", "mulaw"). Default: "wav". + - consistency (float): Word repetition control (0-1). Only for lightning-v2. Default: 0.5. + - similarity (float): Reference audio similarity (0-1). Only for lightning-v2. Default: 0.0. + - enhancement (int): Quality enhancement (0-2). Only for lightning-v2. Default: 1. + - pronunciation_dicts (List[str]): Pronunciation dictionary IDs. Default: None. Returns: - - Union[bytes, None, Iterator[bytes]]: - - If `stream=True`, returns an iterator yielding audio chunks. - - If `save_as` is provided, saves the file and returns None. - - Otherwise, returns the synthesized audio content as bytes. + - bytes: The synthesized audio content. Raises: - - InvalidError: If the provided file name does not have a .wav or .mp3 extension when `save_as` is specified. - - APIError: If the API request fails or returns an error. - - ValueError: If an unexpected parameter is passed in `kwargs`. + - ValidationError: If input parameters are invalid. + - APIError: If the API request fails. """ - should_cleanup = False + if sample_rate is None: + sample_rate = DEFAULT_SAMPLE_RATES.get(model, 24000) + + if voice_id is None: + voice_id = "sophia" if model == "lightning-v3.1" else "alice" + + validate_tts_input(text, model, sample_rate, speed, consistency, similarity, enhancement) + should_cleanup = False if self.session is None or self.session.closed: self.session = aiohttp.ClientSession() - should_cleanup = True # Cleanup only if we created a new session + should_cleanup = True try: - opts = copy.deepcopy(self.opts) - valid_keys = set(vars(opts).keys()) - - invalid_keys = [key for key in kwargs if key not in valid_keys] - if invalid_keys: - raise ValueError(f"Invalid parameter(s) in kwargs: {', '.join(invalid_keys)}. Allowed parameters are: {', '.join(valid_keys)}") - - for key, value in kwargs.items(): - setattr(opts, key, value) - - validate_input(text, opts.model, opts.sample_rate, opts.speed, opts.consistency, opts.similarity, opts.enhancement) - payload = { "text": text, - "voice_id": opts.voice_id, - "sample_rate": opts.sample_rate, - "speed": opts.speed, - "consistency": opts.consistency, - "similarity": opts.similarity, - "enhancement": opts.enhancement, - "language": opts.language, - "output_format": opts.output_format + "voice_id": voice_id, + "sample_rate": sample_rate, + "speed": speed, + "language": language, + "output_format": output_format } - if opts.model == "lightning-large" or opts.model == "lightning-v2": - if opts.consistency is not None: - payload["consistency"] = opts.consistency - if opts.similarity is not None: - payload["similarity"] = opts.similarity - if opts.enhancement is not None: - payload["enhancement"] = opts.enhancement + if model == "lightning-v2": + if consistency is not None: + payload["consistency"] = consistency + if similarity is not None: + payload["similarity"] = similarity + if enhancement is not None: + payload["enhancement"] = enhancement + + if pronunciation_dicts: + payload["pronunciation_dicts"] = pronunciation_dicts headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } - async with self.session.post(f"{API_BASE_URL}/{opts.model}/get_speech", json=payload, headers=headers) as res: + async with self.session.post(f"{API_BASE_URL}/{model}/get_speech", json=payload, headers=headers) as res: if res.status != 200: raise APIError(f"Failed to synthesize speech: {await res.text()}. For more information, visit https://waves.smallest.ai/") @@ -210,21 +172,20 @@ async def synthesize( await self.session.close() self.session = None - async def add_voice(self, display_name: str, file_path: str) -> str: """ - Instantly clone your voice asynchronously. + Clone a voice from an audio file. Args: - - display_name (str): The display name for the new voice. - - file_path (str): The path to the reference audio file to be cloned. + - display_name (str): Display name for the new voice. + - file_path (str): Path to the reference audio file. Returns: - - str: The response from the API as a formatted JSON string. + - str: API response as JSON. Raises: - - InvalidError: If the file does not exist or is not a valid audio file. - - APIError: If the API request fails or returns an error. + - InvalidError: If the file is invalid. + - APIError: If the API request fails. """ url = f"{API_BASE_URL}/lightning-large/add_voice" @@ -261,20 +222,19 @@ async def add_voice(self, display_name: str, file_path: str) -> str: if should_cleanup and self.session: await self.session.close() self.session = None - - + async def delete_voice(self, voice_id: str) -> str: """ - Delete a cloned voice asynchronously. + Delete a cloned voice. Args: - - voice_id (str): The ID of the voice to be deleted. + - voice_id (str): The voice ID to delete. Returns: - - str: The response from the API. + - str: API response. Raises: - - APIError: If the API request fails or returns an error. + - APIError: If the API request fails. """ url = f"{API_BASE_URL}/lightning-large" payload = {'voiceId': voice_id} @@ -299,20 +259,48 @@ async def delete_voice(self, voice_id: str) -> str: async def transcribe( self, file_path: str, - language: Optional[str] = "en", - word_timestamps: Optional[bool] = False, - age_detection: Optional[bool] = False, - gender_detection: Optional[bool] = False, - emotion_detection: Optional[bool] = False, - model: Optional[str] = "lightning" + language: str = "en", + word_timestamps: bool = False, + diarize: bool = False, + age_detection: bool = False, + gender_detection: bool = False, + emotion_detection: bool = False, + model: str = "pulse" ) -> dict: - validate_asr_input(file_path, model, language) + """ + Async transcribe audio from a file. - url = f"{API_BASE_URL}/speech-to-text" - headers = { - 'Authorization': f"Bearer {self.api_key}", + Args: + - file_path (str): Path to the audio file. + - language (str): Language code. Use "multi" for auto-detection. Default: "en". + - word_timestamps (bool): Include word-level timestamps. Default: False. + - diarize (bool): Enable speaker diarization. Default: False. + - age_detection (bool): Predict speaker age. Default: False. + - gender_detection (bool): Predict speaker gender. Default: False. + - emotion_detection (bool): Predict speaker emotion. Default: False. + - model (str): STT model. Default: "pulse". + + Returns: + - dict: Transcription result. + + Raises: + - ValidationError: If inputs are invalid. + - APIError: If the API request fails. + """ + validate_stt_input(file_path, model, language) + + params = { + 'model': model, + 'language': language, + 'word_timestamps': str(bool(word_timestamps)).lower(), + 'diarize': str(bool(diarize)).lower(), + 'age_detection': str(bool(age_detection)).lower(), + 'gender_detection': str(bool(gender_detection)).lower(), + 'emotion_detection': str(bool(emotion_detection)).lower() } + url = f"{API_BASE_URL}/pulse/get_text" + should_cleanup = await self._ensure_session() try: @@ -322,26 +310,16 @@ async def transcribe( async with aiofiles.open(file_path, 'rb') as f: file_data = await f.read() - form = aiohttp.FormData() - form.add_field( - 'file', - file_data, - filename=os.path.basename(file_path), - content_type=content_type - ) - # Send options as multipart form fields (not query params) - form.add_field('model', model) - form.add_field('language', language) - form.add_field('word_timestamps', str(bool(word_timestamps)).lower()) - form.add_field('age_detection', str(bool(age_detection)).lower()) - form.add_field('gender_detection', str(bool(gender_detection)).lower()) - form.add_field('emotion_detection', str(bool(emotion_detection)).lower()) - - async with self.session.post(url, headers=headers, data=form) as res: + headers = { + 'Authorization': f"Bearer {self.api_key}", + 'Content-Type': content_type + } + + async with self.session.post(url, headers=headers, params=params, data=file_data) as res: if res.status != 200: raise APIError( f"Failed to transcribe audio: {await res.text()}. " - "For more information, visit https://waves-docs.smallest.ai/v4.0.0/content/api-references/asr-post-api" + "For more information, visit https://waves-docs.smallest.ai/" ) return await res.json() finally: diff --git a/smallestai/waves/models.py b/smallestai/waves/models.py index 78a92c00..44048d6a 100644 --- a/smallestai/waves/models.py +++ b/smallestai/waves/models.py @@ -1,16 +1,29 @@ -TTSLanguages_lightning = ["en", "hi"] -TTSLanguages_lightning_large = ["en", "hi"] -TTSLanguages_lightning_v2 = ["en", "hi", "mr", "kn", "ta", "bn", "gu", "de", "fr", "es", "it", "pl", "nl", "ru", "ar", "he"] +DEFAULT_TTS_MODEL = "lightning-v3.1" + +DEFAULT_STT_MODEL = "pulse" + +# Lightning v2: supports 19 languages +TTSLanguages_lightning_v2 = [ + "en", "hi", "ta", "kn", "mr", "bn", "gu", "ar", "he", + "fr", "de", "pl", "ru", "it", "nl", "es", "sv", "ml", "te" +] +# Lightning v3.1: supports 4 languages +TTSLanguages_lightning_v3_1 = ["en", "hi", "ta", "es"] + +# Available TTS Models TTSModels = [ - "lightning", - "lightning-large", - "lightning-v2" + "lightning-v2", + "lightning-v3.1" ] -ASRLanguages_lightning = [ + +# STT Languages (Pulse model) +STTLanguages_pulse = [ "it", "es", "en", "pt", "hi", "de", "fr", "uk", "ru", "kn", "ml", "pl", "mr", "gu", "cs", "sk", "te", "or", "nl", "bn", "lv", "et", "ro", "pa", "fi", "sv", "bg", "ta", "hu", "da", "lt", "mt", "multi" ] -ASRModels = [ - "lightning" + +# Available STT Models +STTModels = [ + "pulse" ] diff --git a/smallestai/waves/stream_tts.py b/smallestai/waves/stream_tts.py index 20f39c60..5318176f 100644 --- a/smallestai/waves/stream_tts.py +++ b/smallestai/waves/stream_tts.py @@ -3,16 +3,31 @@ import time import threading import queue -from typing import Generator +from typing import Generator, Optional from dataclasses import dataclass from websocket import WebSocketApp @dataclass class TTSConfig: + """Configuration for TTS WebSocket streaming. + + Attributes: + voice_id: The voice identifier to use for synthesis. + api_key: API key for authentication. + model: TTS model to use. Options: "lightning-v3.1", "lightning-v2". Default: "lightning-v3.1". + language: Language code. Default: "en". + sample_rate: Audio sample rate in Hz. Default: 44100 for v3.1, 24000 for v2. + speed: Speech speed multiplier (0.5-2.0). Default: 1.0. + consistency: Controls word repetition/skipping (0-1). Only for lightning-v2. Default: 0.5. + enhancement: Speech quality enhancement (0-2). Only for lightning-v2. Default: 1. + similarity: Reference audio similarity (0-1). Only for lightning-v2. Default: 0. + max_buffer_flush_ms: Buffer flush interval in ms. Default: 0. + """ voice_id: str api_key: str + model: str = "lightning-v3.1" language: str = "en" - sample_rate: int = 24000 + sample_rate: int = 44100 speed: float = 1.0 consistency: float = 0.5 enhancement: int = 1 @@ -20,9 +35,27 @@ class TTSConfig: max_buffer_flush_ms: int = 0 class WavesStreamingTTS: + """ + Streaming Text-to-Speech client using WebSocket API. + + Supports both Lightning v2 and Lightning v3.1 models. + + Example: + config = TTSConfig( + voice_id="sophia", + api_key="your_api_key", + model="lightning-v3.1" + ) + tts = WavesStreamingTTS(config) + + for audio_chunk in tts.synthesize("Hello, world!"): + # Process audio chunk + pass + """ + def __init__(self, config: TTSConfig): self.config = config - self.ws_url = "wss://waves-api.smallest.ai/api/v1/lightning-v2/get_speech/stream" + self.ws_url = f"wss://waves-api.smallest.ai/api/v1/{config.model}/get_speech/stream" self.ws = None self.audio_queue = queue.Queue() self.error_queue = queue.Queue() @@ -34,19 +67,23 @@ def _get_headers(self): return [f"Authorization: Bearer {self.config.api_key}"] def _create_payload(self, text: str, continue_stream: bool = False, flush: bool = False): - return { + payload = { "voice_id": self.config.voice_id, "text": text, "language": self.config.language, "sample_rate": self.config.sample_rate, "speed": self.config.speed, - "consistency": self.config.consistency, - "similarity": self.config.similarity, - "enhancement": self.config.enhancement, "max_buffer_flush_ms": self.config.max_buffer_flush_ms, "continue": continue_stream, "flush": flush } + + if self.config.model == "lightning-v2": + payload["consistency"] = self.config.consistency + payload["similarity"] = self.config.similarity + payload["enhancement"] = self.config.enhancement + + return payload def _on_open(self, ws): self.is_connected = True diff --git a/smallestai/waves/utils.py b/smallestai/waves/utils.py index 16a26784..146f3f02 100644 --- a/smallestai/waves/utils.py +++ b/smallestai/waves/utils.py @@ -1,76 +1,92 @@ import os -from typing import List -from typing import Optional -from dataclasses import dataclass +from typing import List, Optional from smallestai.waves.exceptions import ValidationError -from smallestai.waves.models import TTSModels, TTSLanguages_lightning, TTSLanguages_lightning_large, TTSLanguages_lightning_v2, ASRModels, ASRLanguages_lightning +from smallestai.waves.models import ( + TTSModels, + TTSLanguages_lightning_v2, + TTSLanguages_lightning_v3_1, + STTModels, + STTLanguages_pulse, +) API_BASE_URL = "https://waves-api.smallest.ai/api/v1" -WEBSOCKET_URL = "wss://waves-api.smallest.ai/api/v1/lightning-v2/get_speech/stream" + SAMPLE_WIDTH = 2 CHANNELS = 1 ALLOWED_AUDIO_EXTENSIONS = ['.mp3', '.wav'] +VALID_SAMPLE_RATES = { + "lightning-v2": [8000, 16000, 24000], + "lightning-v3.1": [8000, 16000, 24000, 44100], +} -@dataclass -class TTSOptions: - model: str - sample_rate: int - voice_id: str - api_key: str - speed: float - consistency: float - similarity: float - enhancement: int - language: str - output_format: str - -@dataclass -class ASROptions: - model: str - api_key: str - language: str - word_timestamps: bool - age_detection: bool - gender_detection: bool - emotion_detection: bool - -def validate_asr_input(file_path: str, model: str, language: str): +DEFAULT_SAMPLE_RATES = { + "lightning-v2": 24000, + "lightning-v3.1": 44100, +} + + +def validate_stt_input(file_path: str, model: str, language: str): + """Validate STT input parameters.""" if not os.path.isfile(file_path): raise ValidationError("Invalid file path. File does not exist.") - if model not in ASRModels: - raise ValidationError(f"Invalid model: {model}. Must be one of {ASRModels}") - if language not in ASRLanguages_lightning: - raise ValidationError(f"Invalid language: {language}. Must be one of {ASRLanguages_lightning}") + if model not in STTModels: + raise ValidationError(f"Invalid model: {model}. Must be one of {STTModels}") + if language not in STTLanguages_pulse: + raise ValidationError(f"Invalid language: {language}. Must be one of {STTLanguages_pulse}") -def validate_input(text: str, model: str, sample_rate: int, speed: float, consistency: Optional[float] = None, similarity: Optional[float] = None, enhancement: Optional[int] = None): + +def validate_tts_input( + text: str, + model: str, + sample_rate: int, + speed: float, + consistency: Optional[float] = None, + similarity: Optional[float] = None, + enhancement: Optional[int] = None +): + """Validate TTS input parameters.""" if not text: raise ValidationError("Text cannot be empty.") if model not in TTSModels: raise ValidationError(f"Invalid model: {model}. Must be one of {TTSModels}") - if not 8000 <= sample_rate <= 24000: - raise ValidationError(f"Invalid sample rate: {sample_rate}. Must be between 8000 and 24000") + + valid_rates = VALID_SAMPLE_RATES.get(model, [8000, 16000, 24000]) + if sample_rate not in valid_rates: + raise ValidationError(f"Invalid sample rate: {sample_rate}. Must be one of {valid_rates} for model {model}") + if not 0.5 <= speed <= 2.0: raise ValidationError(f"Invalid speed: {speed}. Must be between 0.5 and 2.0") - if consistency is not None and not 0.0 <= consistency <= 1.0: - raise ValidationError(f"Invalid consistency: {consistency}. Must be between 0.0 and 1.0") - if similarity is not None and not 0.0 <= similarity <= 1.0: - raise ValidationError(f"Invalid similarity: {similarity}. Must be between 0.0 and 1.0") - if enhancement is not None and not 0 <= enhancement <= 2: - raise ValidationError(f"Invalid enhancement: {enhancement}. Must be between 0 and 2.") - - -def get_smallest_languages(model: str = 'lightning') -> List[str]: - if model == 'lightning': - return TTSLanguages_lightning - elif model == 'lightning-large': - return TTSLanguages_lightning_large - elif model == 'lightning-v2': + + if model == "lightning-v2": + if consistency is not None and not 0.0 <= consistency <= 1.0: + raise ValidationError(f"Invalid consistency: {consistency}. Must be between 0.0 and 1.0") + if similarity is not None and not 0.0 <= similarity <= 1.0: + raise ValidationError(f"Invalid similarity: {similarity}. Must be between 0.0 and 1.0") + if enhancement is not None and not 0 <= enhancement <= 2: + raise ValidationError(f"Invalid enhancement: {enhancement}. Must be between 0 and 2.") + + +def get_smallest_languages(model: str = 'lightning-v3.1') -> List[str]: + """Get available languages for a model (TTS or STT).""" + if model == 'lightning-v2': return TTSLanguages_lightning_v2 + elif model == 'lightning-v3.1': + return TTSLanguages_lightning_v3_1 + elif model == 'pulse': + return STTLanguages_pulse else: - raise ValidationError(f"Invalid model: {model}. Must be one of {TTSModels}") + all_models = TTSModels + STTModels + raise ValidationError(f"Invalid model: {model}. Must be one of {all_models}") -def get_smallest_models() -> List[str]: + +def get_tts_models() -> List[str]: + """Get available TTS models.""" return TTSModels + + +def get_stt_models() -> List[str]: + """Get available STT models.""" + return STTModels diff --git a/smallestai/waves/waves_client.py b/smallestai/waves/waves_client.py index 6dcd7f9b..ab7f5eb1 100644 --- a/smallestai/waves/waves_client.py +++ b/smallestai/waves/waves_client.py @@ -1,75 +1,37 @@ import os import json -import copy import requests -from typing import Optional, Union, List +from typing import Optional, List from smallestai.waves.exceptions import InvalidError, APIError -from smallestai.waves.utils import (TTSOptions, validate_input, validate_asr_input, - get_smallest_languages, get_smallest_models, ALLOWED_AUDIO_EXTENSIONS, API_BASE_URL) +from smallestai.waves.utils import (validate_tts_input, validate_stt_input, + get_smallest_languages, get_tts_models, get_stt_models, ALLOWED_AUDIO_EXTENSIONS, API_BASE_URL, + DEFAULT_SAMPLE_RATES) + class WavesClient: - def __init__( - self, - api_key: str = None, - model: Optional[str] = "lightning", - sample_rate: Optional[int] = 24000, - voice_id: Optional[str] = "emily", - speed: Optional[float] = 1.0, - consistency: Optional[float] = 0.5, - similarity: Optional[float] = 0.0, - enhancement: Optional[int] = 1, - language: Optional[str] = "en", - output_format: Optional[str] = "wav" - ) -> None: + def __init__(self, api_key: str = None) -> None: """ - Smallest Instance for text-to-speech synthesis. - - This is a synchronous implementation of the text-to-speech functionality. - For an asynchronous version, please refer to the AsyncSmallest Instance. + Waves Client for synchronous Text-to-Speech and Speech-to-Text. + For an asynchronous version, please refer to the AsyncWavesClient. Args: - - api_key (str): The API key for authentication, export it as 'SMALLEST_API_KEY' in your environment variables. - - model (TTSModels): The model to be used for synthesis. - - sample_rate (int): The sample rate for the audio output. - - voice_id (TTSVoices): The voice to be used for synthesis. - - speed (float): The speed of the speech synthesis. - - consistency (float): This parameter controls word repetition and skipping. Decrease it to prevent skipped words, and increase it to prevent repetition. Only supported in `lightning-large` model. Range - [0, 1] - - similarity (float): This parameter controls the similarity between the synthesized audio and the reference audio. Increase it to make the speech more similar to the reference audio. Only supported in `lightning-large` model. Range - [0, 1] - - enhancement (int): Enhances speech quality at the cost of increased latency. Only supported in `lightning-large` model. Range - [0, 2]. - - language (str): The language for synthesis. Default is "en". - - output_format (str): The output audio format. Options: "pcm", "mp3", "wav", "mulaw". Default is "pcm". + - api_key (str): The API key for authentication. + Set via parameter or 'SMALLEST_API_KEY' environment variable. Methods: - - get_languages: Returns a list of available languages for synthesis. - - get_voices: Returns a list of available voices for synthesis. - - get_models: Returns a list of available models for synthesis. - - synthesize: Converts the provided text into speech and returns the audio content. + - synthesize: Converts text to speech. + - transcribe: Converts speech to text. + - get_languages: Returns available languages for a model. + - get_voices: Returns available voices for a model. + - get_models: Returns available TTS models. """ self.api_key = api_key or os.environ.get("SMALLEST_API_KEY") if not self.api_key: raise InvalidError() - if model == "lightning-large" and voice_id is None: - voice_id = "lakshya" - self.chunk_size = 250 - - self.opts = TTSOptions( - model=model, - sample_rate=sample_rate, - voice_id=voice_id, - api_key=self.api_key, - speed=speed, - consistency=consistency, - similarity=similarity, - enhancement=enhancement, - language=language, - output_format=output_format - ) - - - def get_languages(self, model:str="lightning") -> List[str]: - """Returns a list of available languages.""" + def get_languages(self, model: str = "lightning-v3.1") -> List[str]: + """Returns a list of available languages for a model (TTS or STT).""" return get_smallest_languages(model) def get_cloned_voices(self) -> str: @@ -83,13 +45,9 @@ def get_cloned_voices(self) -> str: raise APIError(f"Failed to get cloned voices: {res.text}. For more information, visit https://waves.smallest.ai/") return json.dumps(res.json(), indent=4, ensure_ascii=False) - - def get_voices( - self, - model: Optional[str] = "lightning" - ) -> str: - """Returns a list of available voices.""" + def get_voices(self, model: str = "lightning-v3.1") -> str: + """Returns a list of available voices for a TTS model.""" headers = { "Authorization": f"Bearer {self.api_key}", } @@ -99,93 +57,105 @@ def get_voices( raise APIError(f"Failed to get voices: {res.text}. For more information, visit https://waves.smallest.ai/") return json.dumps(res.json(), indent=4, ensure_ascii=False) - - def get_models(self) -> List[str]: - """Returns a list of available models.""" - return get_smallest_models() - - + def get_tts_models(self) -> List[str]: + """Returns a list of available TTS models.""" + return get_tts_models() + + def get_stt_models(self) -> List[str]: + """Returns a list of available STT models.""" + return get_stt_models() + def synthesize( self, text: str, - **kwargs - ) -> Union[bytes]: + model: str = "lightning-v3.1", + voice_id: Optional[str] = None, + sample_rate: Optional[int] = None, + speed: float = 1.0, + language: str = "en", + output_format: str = "wav", + consistency: Optional[float] = 0.5, + similarity: Optional[float] = 0.0, + enhancement: Optional[int] = 1, + pronunciation_dicts: Optional[List[str]] = None + ) -> bytes: """ - Synthesize speech from the provided text. + Synthesize speech from text. - - text (str): The text to be converted to speech. - - stream (Optional[bool]): If True, returns an iterator yielding audio chunks instead of a full byte array. - - kwargs: Additional optional parameters to override `__init__` options for this call. + Args: + - text (str): The text to convert to speech. + - model (str): TTS model. Options: "lightning-v3.1", "lightning-v2". Default: "lightning-v3.1". + - voice_id (str): Voice ID. Default: "sophia" for v3.1, "alice" for v2. + - sample_rate (int): Sample rate in Hz. Default: 44100 for v3.1, 24000 for v2. + - speed (float): Speech speed (0.5-2.0). Default: 1.0. + - language (str): Language code. Default: "en". + - output_format (str): Output format ("pcm", "mp3", "wav", "mulaw"). Default: "wav". + - consistency (float): Word repetition control (0-1). Only for lightning-v2. Default: 0.5. + - similarity (float): Reference audio similarity (0-1). Only for lightning-v2. Default: 0.0. + - enhancement (int): Quality enhancement (0-2). Only for lightning-v2. Default: 1. + - pronunciation_dicts (List[str]): Pronunciation dictionary IDs. Default: None. Returns: - - Union[bytes, None, Iterator[bytes]]: - - If `stream=True`, returns an iterator yielding audio chunks. - - If `save_as` is provided, saves the file and returns None. - - Otherwise, returns the synthesized audio content as bytes. + - bytes: The synthesized audio content. Raises: - - InvalidError: If the provided file name does not have a .wav or .mp3 extension when `save_as` is specified. - - APIError: If the API request fails or returns an error. + - ValidationError: If input parameters are invalid. + - APIError: If the API request fails. """ - opts = copy.deepcopy(self.opts) - valid_keys = set(vars(opts).keys()) - - invalid_keys = [key for key in kwargs if key not in valid_keys] - if invalid_keys: - raise ValueError(f"Invalid parameter(s) in kwargs: {', '.join(invalid_keys)}. Allowed parameters are: {', '.join(valid_keys)}") - - for key, value in kwargs.items(): - setattr(opts, key, value) + if sample_rate is None: + sample_rate = DEFAULT_SAMPLE_RATES.get(model, 24000) + + if voice_id is None: + voice_id = "sophia" if model == "lightning-v3.1" else "alice" - validate_input(text, opts.model, opts.sample_rate, opts.speed, opts.consistency, opts.similarity, opts.enhancement) + validate_tts_input(text, model, sample_rate, speed, consistency, similarity, enhancement) payload = { "text": text, - "voice_id": opts.voice_id, - "sample_rate": opts.sample_rate, - "speed": opts.speed, - "consistency": opts.consistency, - "similarity": opts.similarity, - "enhancement": opts.enhancement, - "language": opts.language, - "output_format": opts.output_format + "voice_id": voice_id, + "sample_rate": sample_rate, + "speed": speed, + "language": language, + "output_format": output_format } - if opts.model == "lightning-large" or opts.model == "lightning-v2": - if opts.consistency is not None: - payload["consistency"] = opts.consistency - if opts.similarity is not None: - payload["similarity"] = opts.similarity - if opts.enhancement is not None: - payload["enhancement"] = opts.enhancement + if model == "lightning-v2": + if consistency is not None: + payload["consistency"] = consistency + if similarity is not None: + payload["similarity"] = similarity + if enhancement is not None: + payload["enhancement"] = enhancement + + if pronunciation_dicts: + payload["pronunciation_dicts"] = pronunciation_dicts headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } - res = requests.post(f"{API_BASE_URL}/{opts.model}/get_speech", json=payload, headers=headers) + res = requests.post(f"{API_BASE_URL}/{model}/get_speech", json=payload, headers=headers) if res.status_code != 200: raise APIError(f"Failed to synthesize speech: {res.text}. Please check if you have set the correct API key. For more information, visit https://waves.smallest.ai/") return res.content - - + def add_voice(self, display_name: str, file_path: str) -> str: """ - Instantly clone your voice synchronously. + Instantly clone your voice from an audio file. Args: - - display_name (str): The display name for the new voice. - - file_path (str): The path to the reference audio file to be cloned. + - display_name (str): Display name for the new voice. + - file_path (str): Path to the reference audio file. Returns: - - str: The response from the API as a formatted JSON string. + - str: API response as JSON. Raises: - - InvalidError: If the file does not exist or is not a valid audio file. - - APIError: If the API request fails or returns an error. + - InvalidError: If the file is invalid. + - APIError: If the API request fails. """ if not os.path.isfile(file_path): raise InvalidError("Invalid file path. File does not exist.") @@ -209,19 +179,18 @@ def add_voice(self, display_name: str, file_path: str) -> str: return response.json() - def delete_voice(self, voice_id: str) -> str: """ - Delete a cloned voice synchronously. + Delete a cloned voice. Args: - - voice_id (str): The ID of the voice to be deleted. + - voice_id (str): The voice ID to delete. Returns: - - str: The response from the API. + - str: API response. Raises: - - APIError: If the API request fails or returns an error. + - APIError: If the API request fails. """ url = f"{API_BASE_URL}/lightning-large" payload = {'voiceId': voice_id} @@ -239,36 +208,63 @@ def delete_voice(self, voice_id: str) -> str: def transcribe( self, file_path: str, - language: Optional[str] = "en", - word_timestamps: Optional[bool] = False, - age_detection: Optional[bool] = False, - gender_detection: Optional[bool] = False, - emotion_detection: Optional[bool] = False, - model: Optional[str] = "lightning" + language: str = "en", + word_timestamps: bool = False, + diarize: bool = False, + age_detection: bool = False, + gender_detection: bool = False, + emotion_detection: bool = False, + model: str = "pulse" ) -> dict: - validate_asr_input(file_path, model, language) + """ + Transcribe audio from a file. - url = f"{API_BASE_URL}/speech-to-text" - headers = { - 'Authorization': f"Bearer {self.api_key}", - } - payload = { + Args: + - file_path (str): Path to the audio file. + - language (str): Language code. Use "multi" for auto-detection. Default: "en". + - word_timestamps (bool): Include word-level timestamps. Default: False. + - diarize (bool): Enable speaker diarization. Default: False. + - age_detection (bool): Predict speaker age. Default: False. + - gender_detection (bool): Predict speaker gender. Default: False. + - emotion_detection (bool): Predict speaker emotion. Default: False. + - model (str): STT model. Default: "pulse". + + Returns: + - dict: Transcription result with transcript, words, utterances, and metadata. + + Raises: + - ValidationError: If inputs are invalid. + - APIError: If the API request fails. + """ + validate_stt_input(file_path, model, language) + + params = { 'model': model, 'language': language, 'word_timestamps': str(bool(word_timestamps)).lower(), + 'diarize': str(bool(diarize)).lower(), 'age_detection': str(bool(age_detection)).lower(), 'gender_detection': str(bool(gender_detection)).lower(), 'emotion_detection': str(bool(emotion_detection)).lower() } + + url = f"{API_BASE_URL}/pulse/get_text" + headers = { + 'Authorization': f"Bearer {self.api_key}", + } file_extension = os.path.splitext(file_path)[1].lower() content_type = f"audio/{file_extension[1:]}" if file_extension else "application/octet-stream" with open(file_path, 'rb') as f: - files = {'file': (os.path.basename(file_path), f, content_type)} - response = requests.post(url, headers=headers, files=files, data=payload) + response = requests.post( + url, + headers={**headers, 'Content-Type': content_type}, + params=params, + data=f.read() + ) if response.status_code != 200: - raise APIError(f"Failed to transcribe audio: {response.text}. For more information, visit https://waves-docs.smallest.ai/v4.0.0/content/api-references/asr-post-api") + raise APIError(f"Failed to transcribe audio: {response.text}. For more information, visit https://waves-docs.smallest.ai/") - return response.json() \ No newline at end of file + return response.json()