Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,7 @@ cython_debug/

# PyPI configuration file
.pypirc

# runtime
triton_python_backend_utils.py
*.wav
141 changes: 107 additions & 34 deletions sparktts/models/audio_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np

from pathlib import Path
from typing import Any, Dict, Tuple
from typing import Any, Dict, Literal, Optional, Tuple, List, Union
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

from sparktts.utils.file import load_config
Expand All @@ -39,19 +39,25 @@ def __init__(self, model_dir: Path, device: torch.device = None, **kwargs):
self.device = device
self.model_dir = model_dir
self.config = load_config(f"{model_dir}/config.yaml")
self._initialize_model()
self._initialize_model(**kwargs)

def _initialize_model(
    self,
    attn_implementation: Optional[Literal["sdpa", "flash_attention_2", "eager"]] = None,
):
    """Load and initialize the BiCodec model and Wav2Vec2 feature extractor.

    Args:
        attn_implementation: attention backend forwarded to
            ``Wav2Vec2Model.from_pretrained``; ``None`` lets transformers
            choose its default implementation.
    """
    self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to(self.device)
    self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
        f"{self.model_dir}/wav2vec2-large-xlsr-53"
    )
    # .eval() disables dropout; the extractor is only used for inference.
    self.feature_extractor = (
        Wav2Vec2Model.from_pretrained(
            f"{self.model_dir}/wav2vec2-large-xlsr-53",
            attn_implementation=attn_implementation,
        )
        .to(self.device)
        .eval()
    )
    # Intermediate hidden states (layers 11/14/16) are mixed downstream in
    # extract_wav2vec2_features, so hidden-state output must be enabled.
    self.feature_extractor.config.output_hidden_states = True

def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
Expand All @@ -69,8 +75,11 @@ def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:

return wav[:ref_segment_length]

def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]:
"""load auido and get reference audio from wav path"""
def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, np.ndarray]:
"""
load audio and get reference audio from wav path
return (wav, wav_ref) # (shape:(seq_len), shape:(seq_len))
"""
wav = load_audio(
wav_path,
sampling_rate=self.config["sample_rate"],
Expand All @@ -79,24 +88,23 @@ def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]:

wav_ref = self.get_ref_clip(wav)

wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float()
return wav, wav_ref

def extract_wav2vec2_features(self, wavs: np.ndarray | List[np.ndarray]) -> torch.Tensor:
    """Extract wav2vec2 features for one waveform or a batch of waveforms.

    Args:
        wavs: a single 1-D waveform, or a list of 1-D waveforms that the
            processor pads to a common length.

    Returns:
        torch.Tensor of shape (batch_size, features_seq_len, feature_dim).
    """
    # padding=True makes the processor pad list inputs and emit an
    # attention mask; everything it returns is forwarded via **inputs.
    inputs = self.processor(
        wavs,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
    ).to(self.feature_extractor.device)
    feat = self.feature_extractor(**inputs)
    # Mix of intermediate hidden states expected by BiCodec; requires
    # output_hidden_states=True (set in _initialize_model).
    feats_mix = (feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]) / 3
    # detach() drops the autograd graph — features are inference-only here.
    return feats_mix.detach()

def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor:
"""tokenize the batch of audio
Expand All @@ -116,22 +124,66 @@ def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor:

return global_tokens, semantic_tokens

def batch_tokenize(
    self, audio_paths: Union[str, List[str]]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Tokenize one or more audio files in a single batch.

    Args:
        audio_paths: a single audio file path or a list of paths.

    Returns:
        (global_tokens, semantic_tokens):
            - global_tokens: global tokens, shape (batch_size, channel, global_dim)
            - semantic_tokens: semantic tokens, shape (batch_size, latent_dim)
    """
    # Normalize the single-path case so the loop handles one shape only.
    if isinstance(audio_paths, str):
        audio_paths = [audio_paths]

    wav_list = []
    ref_clips = []
    for audio_path in audio_paths:
        wav, wav_ref = self.process_audio(audio_path)
        wav_list.append(wav)
        ref_clips.append(torch.from_numpy(wav_ref))

    # NOTE(review): torch.stack assumes every reference clip has identical
    # length (get_ref_clip should guarantee ref_segment_length) — confirm
    # for very short inputs.
    ref_batch = torch.stack(ref_clips).float().to(self.device)  # [batch_size, seq_len]
    audio_features = self.extract_wav2vec2_features(wav_list)

    batch = {
        "ref_wav": ref_batch,  # [batch_size, seq_len]
        "feat": audio_features.to(self.device),  # [batch_size, features_seq_len, feature_dim]
    }
    semantic_tokens, global_tokens = self.model.tokenize(batch)

    # Release transient CUDA activations from feature extraction/tokenization.
    if self.device.type == "cuda":
        torch.cuda.empty_cache()

    return global_tokens, semantic_tokens

def batch_detokenize(
    self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor
) -> np.ndarray:
    """Reconstruct a batch of waveforms from global and semantic tokens.

    Args:
        global_tokens: global tokens for the batch.
        semantic_tokens: semantic tokens for the batch.

    Returns:
        np.ndarray: reconstructed waveform(s); singleton dims are squeezed out.
    """
    wav_rec = self.model.detokenize(semantic_tokens, global_tokens)
    # Release transient CUDA memory held by the decoder forward pass.
    if self.device.type == "cuda":
        torch.cuda.empty_cache()
    return wav_rec.squeeze().cpu().numpy()

def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Tokenize a single audio file.

    Args:
        audio_path: path to the audio file to tokenize.

    Returns:
        (global_tokens, semantic_tokens):
            - global_tokens: global tokens, shape (batch_size, channel, global_dim)
            - semantic_tokens: semantic tokens, shape (batch_size, latent_dim)
    """
    wav, ref_wav = self.process_audio(audio_path)
    feat = self.extract_wav2vec2_features(wav)
    # NOTE(review): "wav" is intentionally no longer passed — BiCodec.tokenize
    # appears to read only "ref_wav" and "feat"; confirm against the model.
    batch = {
        "ref_wav": torch.from_numpy(ref_wav).unsqueeze(0).float().to(self.device),
        "feat": feat.to(self.device),
    }
    semantic_tokens, global_tokens = self.model.tokenize(batch)

    # Release transient CUDA activations from feature extraction/tokenization.
    if self.device.type == "cuda":
        torch.cuda.empty_cache()

    return global_tokens, semantic_tokens

def detokenize(self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor) -> np.ndarray:
    """Reconstruct a waveform from tokens produced by `tokenize`.

    Args:
        global_tokens: global tokens; a channel dim is inserted at dim 1
            before decoding.
        semantic_tokens: semantic tokens.

    Returns:
        np.ndarray: reconstructed waveform with singleton dims squeezed out.
    """
    global_tokens = global_tokens.unsqueeze(1)
    wav_rec = self.model.detokenize(semantic_tokens, global_tokens)
    # Release transient CUDA memory held by the decoder forward pass.
    if self.device.type == "cuda":
        torch.cuda.empty_cache()
    return wav_rec.squeeze().cpu().numpy()


# test
if __name__ == "__main__":
    import soundfile as sf
    import os
    from time import perf_counter

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BiCodecTokenizer(
        model_dir=os.getenv("MODEL_DIR", "pretrained_models/Spark-TTS-0.5B"),
        device=device,
    )

    # Exercise both the single-file and batched paths with the same clip.
    wav_cases = {
        "single": "example/prompt_audio.wav",
        "multi": ["example/prompt_audio.wav", "example/prompt_audio.wav"],
    }
    for case, wav_path in wav_cases.items():
        start_time = perf_counter()
        if isinstance(wav_path, list):
            global_tokens, semantic_tokens = tokenizer.batch_tokenize(wav_path)
        else:
            global_tokens, semantic_tokens = tokenizer.tokenize(wav_path)
        print(f"{case} encode elapsed time: {perf_counter() - start_time:.4f} seconds")
        print(semantic_tokens.shape, global_tokens.shape)

        start_time = perf_counter()
        wav_rec = tokenizer.detokenize(global_tokens.squeeze(1), semantic_tokens)
        print(f"{case} decode elapsed time: {perf_counter() - start_time:.4f} seconds")
        print(wav_rec.shape)
        # Batched output (2-D) -> one file per item; single output (1-D) -> one file.
        if wav_rec.ndim > 1:
            for i, wav in enumerate(wav_rec):
                sf.write(f"example/prompt_recon_{i}.wav", wav, 16000)
        else:
            sf.write("example/prompt_recon.wav", wav_rec, 16000)
29 changes: 17 additions & 12 deletions sparktts/models/bicodec.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import torch.nn as nn
from pathlib import Path
from typing import Dict, Any
from omegaconf import DictConfig
from safetensors.torch import load_file

from sparktts.utils.file import load_config
Expand All @@ -43,7 +42,7 @@ def __init__(
speaker_encoder: nn.Module,
prenet: nn.Module,
postnet: nn.Module,
**kwargs
**kwargs,
) -> None:
"""
Initializes the BiCodec model with the required components.
Expand Down Expand Up @@ -73,12 +72,12 @@ def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":

Args:
model_dir (Path): Path to the model directory containing checkpoint and config.

Returns:
BiCodec: The initialized BiCodec model.
"""
ckpt_path = f'{model_dir}/model.safetensors'
config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer']
ckpt_path = f"{model_dir}/model.safetensors"
config = load_config(f"{model_dir}/config.yaml")["audio_tokenizer"]
mel_params = config["mel_params"]
encoder = Encoder(**config["encoder"])
quantizer = FactorizedVectorQuantize(**config["quantizer"])
Expand Down Expand Up @@ -116,7 +115,7 @@ def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:

Args:
batch (dict): A dictionary containing features, reference waveform, and target waveform.

Returns:
dict: A dictionary containing the reconstruction, features, and other metrics.
"""
Expand Down Expand Up @@ -166,7 +165,7 @@ def tokenize(self, batch: Dict[str, Any]):
semantic_tokens = self.quantizer.tokenize(z)
global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2))

return semantic_tokens, global_tokens
return semantic_tokens.detach(), global_tokens.detach()

@torch.no_grad()
def detokenize(self, semantic_tokens, global_tokens):
Expand All @@ -186,7 +185,7 @@ def detokenize(self, semantic_tokens, global_tokens):
x = x + d_vector.unsqueeze(-1)
wav_recon = self.decoder(x)

return wav_recon
return wav_recon.detach()

def init_mel_transformer(self, config: Dict[str, Any]):
"""
Expand All @@ -212,6 +211,7 @@ def init_mel_transformer(self, config: Dict[str, Any]):

def remove_weight_norm(self):
"""Removes weight normalization from all layers."""

def _remove_weight_norm(m):
try:
torch.nn.utils.remove_weight_norm(m)
Expand All @@ -223,16 +223,20 @@ def _remove_weight_norm(m):

# Test the model
if __name__ == "__main__":

config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml")
model = BiCodec.load_from_checkpoint(
model_dir="pretrained_models/SparkTTS-0.5B/BiCodec",
)
device = "cpu" if not torch.cuda.is_available() else "cuda"
print(model)
model_million_params = sum(p.numel() for p in model.parameters()) / 1e6
print(f"{model_million_params}M parameters")
model.to(device)

# Generate random inputs for testing
duration = 0.96
x = torch.randn(20, 1, int(duration * 16000))
feat = torch.randn(20, int(duration * 50), 1024)
x = torch.randn(20, 1, int(duration * 16000)).to(device)
feat = torch.randn(20, int(duration * 50), 1024).to(device)
inputs = {"feat": feat, "wav": x, "ref_wav": x}

# Forward pass
Expand All @@ -241,7 +245,8 @@ def _remove_weight_norm(m):
wav_recon = model.detokenize(semantic_tokens, global_tokens)

# Verify if the reconstruction matches
if torch.allclose(outputs["recons"].detach(), wav_recon):
if torch.allclose(outputs["recons"].detach(), wav_recon, rtol=1e-3, atol=1e-5):
# if torch.allclose(outputs["recons"].detach(), wav_recon):
print("Test successful")
else:
print("Test failed")