Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 67 additions & 3 deletions cosyvoice/cli/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import random
from typing import Callable, List, Tuple, Union, Optional

import contractions
import librosa
import inflect
import onnxruntime
import torch
Expand Down Expand Up @@ -580,6 +580,10 @@ class TTSFrontEnd:
"""
Unified Frontend for TTS, managing Text Frontend, Speech Tokenizer, and Speaker Embedding extraction.
"""
LOWER_SR = 16000
#HIGH_SR = 22050
HIGH_SR = 24000

def __init__(self,
tokenize_fn: Callable,
speech_tokenizer: SpeechTokenizer,
Expand Down Expand Up @@ -651,12 +655,72 @@ def _extract_speech_feat(self, speech, sample_rate=24000):
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat

def postprocess(self, speech, top_db=60, hop_length=220, win_length=440):
    """Trim silence, peak-normalize, and append trailing silence to a prompt wav.

    Args:
        speech: Waveform tensor of shape (1, T) — assumed already resampled
            to ``self.HIGH_SR`` by the caller (see ``load_spk_from_wav``).
        top_db: Threshold (dB below peak) used by librosa to classify
            frames as silence when trimming.
        hop_length: Hop size in samples for the trim analysis.
        win_length: Frame length in samples for the trim analysis.

    Returns:
        The trimmed, normalized waveform with 0.2 s of silence appended.
    """
    max_val = 0.8  # peak ceiling; keeps headroom and avoids clipping downstream

    # Strip leading/trailing silence from the clip.
    speech, _ = librosa.effects.trim(
        speech, top_db=top_db,
        frame_length=win_length,
        hop_length=hop_length
    )

    # Rescale only when the peak exceeds the ceiling.
    if speech.abs().max() > max_val:
        speech = speech / speech.abs().max() * max_val

    # Append 0.2 s of silence at the high sample rate so synthesis does not
    # cut off abruptly at the prompt boundary.
    zeros = torch.zeros(1, int(self.HIGH_SR * 0.2))
    speech = torch.concat([speech, zeros], dim=1)

    return speech

def load_spk_from_wav(self, wav_file):
    """Extract speaker prompt features from a reference wav file.

    Loads the audio, downmixes stereo to mono, then resamples twice:
    to ``self.HIGH_SR`` for mel-feature extraction and to
    ``self.LOWER_SR`` for speech-token and speaker-embedding extraction.

    Args:
        wav_file: Path to the reference audio file.

    Returns:
        Dict with keys ``"speech_feat"``, ``"speech_token"`` and
        ``"embedding"`` holding the extracted prompt tensors.
    """
    target_wav, sample_rate = torchaudio.load(wav_file)
    if target_wav.shape[0] == 2:
        # Average the two channels to downmix stereo to mono.
        target_wav = target_wav.mean(dim=0, keepdim=True)

    target_wav_high = torchaudio.transforms.Resample(sample_rate, self.HIGH_SR)(target_wav)
    target_wav_high = self.postprocess(target_wav_high)
    # Tokenizer and speaker-embedding models consume 16 kHz audio,
    # so downsample the already-cleaned high-rate waveform.
    target_wav_lower = torchaudio.transforms.Resample(self.HIGH_SR, self.LOWER_SR)(target_wav_high)

    speech_feat = self._extract_speech_feat(target_wav_high)
    speech_token = self._extract_speech_token(target_wav_lower)
    embedding = self._extract_spk_embedding(target_wav_lower)

    return {
        "speech_feat": speech_feat,
        "speech_token": speech_token,
        "embedding": embedding
    }

def test_text_frontend():
    """Smoke test for TextFrontEnd normalization and G2P conversion."""
    frontend = TextFrontEnd(use_phoneme=True)
    text = frontend.text_normalize("You're absolutely killing it! Keep that amazing energy up—nothing can stop you, girl! You're gonna rock it!")
    print(f"English Normalization: {text}")

    text = frontend.text_normalize("噢,我知道了。")
    print(f"Chinese Normalization: {text}")

    # Mixed-language input with digits and symbols.
    text = "你好,世界!This is a test sentence with numbers 123 and symbols #@$%."
    normalized_text = frontend.text_normalize(text)
    print(f"Normalized Text: {normalized_text}")

    g2p_text = frontend.g2p_infer(normalized_text)
    print(f"G2P Text: {g2p_text}")


def test_speech_frontend():
    # TODO: add a smoke test for the speech frontend (tokenizer/embedding).
    pass


"""
python -m cosyvoice.cli.frontend
"""
if __name__ == "__main__":
    test_text_frontend()
11 changes: 9 additions & 2 deletions llm/glmtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import json
import yaml
import queue
from typing import Union, Optional, List, Dict, Any
import torch
import torch.nn as nn
Expand Down Expand Up @@ -158,7 +159,8 @@ def inference(
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
sample_method: str = "ras",
spk: str = "tongtong"
spk: str = "tongtong",
queue: queue.Queue = None,
) -> torch.Tensor:
"""
Autoregressive inference loop to generate speech tokens from text.
Expand Down Expand Up @@ -270,11 +272,16 @@ def inference(
if top_ids == self.eoa:
break

if queue is not None:
queue.put_nowait(top_ids - self.ats)
out_tokens.append(top_ids)

# Prepare input for the next step (auto-regressive)
# Prepare input for the next step (auto-regressive) use cache prefix
inputs_embeds = self.llama_embedding(torch.LongTensor([top_ids]).to(device))[None]

if queue is not None:
queue.put_nowait(None) # Signal completion

# 5. Validation and Output Construction
# Ensure all tokens are within the valid audio token range
for token in out_tokens:
Expand Down
4 changes: 2 additions & 2 deletions utils/hift_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def extract_mel(wav: torch.Tensor) -> torch.Tensor:
center=False
)

def load_hift(device: str = "cuda", load_only_nsf: bool = False, ckpt_path: str = '') -> HiFTInference:
    """Factory function to load HiFT model.

    Args:
        device: Torch device string the model is loaded onto.
        load_only_nsf: If True, load only the NSF sub-module.
        ckpt_path: Optional checkpoint path; when empty, falls back to the
            bundled default ``ckpt/hift/hift.pt``.

    Returns:
        An initialized ``HiFTInference`` instance.
    """
    ckpt_path = ckpt_path or 'ckpt/hift/hift.pt'
    print(f"Loading HiFT model from {ckpt_path} on {device}...")
    return HiFTInference(ckpt_path, device=device, load_only_nsf=load_only_nsf)
31 changes: 24 additions & 7 deletions utils/tts_model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
from typing import List, Tuple, Generator, Optional, Union

import torch
import numpy as np
from typing import List, Tuple, Generator, Optional, Union

from utils.vocos_util import load_vocos_jit
from utils.hift_util import load_hift

class Token2Wav:
def __init__(self, flow, sample_rate: int = 24000, device: str = "cuda"):
def __init__(self, flow, sample_rate: int = 24000, device: str = "cuda",ckpt_path: str = ""):
self.device = device
self.flow = flow
self.input_frame_rate = flow.input_frame_rate
Expand All @@ -28,11 +31,11 @@ def __init__(self, flow, sample_rate: int = 24000, device: str = "cuda"):
if sample_rate == 32000:
self.hop_size = 640
self.sample_rate = 32000
self.vocoder = load_vocos_jit(device)
self.vocoder = load_vocos_jit(device, ckpt_path)
elif sample_rate == 24000:
self.hop_size = 480
self.sample_rate = 24000
self.vocoder = load_hift(device)
self.vocoder = load_hift(device, ckpt_path=ckpt_path)
else:
raise ValueError(f"Unsupported sample_rate: {sample_rate}")

Expand All @@ -44,7 +47,9 @@ def token2wav_stream(self,
embedding: Optional[torch.Tensor] = None,
prompt_token_list: Optional[torch.Tensor] = None,
prompt_feat_td: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[float], List[float], List[np.ndarray]]:
n_timesteps: int = 10,
queue: queue.Queue = None,
) -> Tuple[torch.Tensor, List[float], List[float], List[np.ndarray], list[torch.Tensor]]:

if not isinstance(syn_token, list):
raise TypeError("syn_token must be a list.")
Expand Down Expand Up @@ -73,6 +78,7 @@ def token2wav_stream(self,
prompt_token=prompt_token_list.to(self.device),
prompt_feat=prompt_feat_td.to(self.device),
embedding=embedding.to(self.device),
n_timesteps=n_timesteps,
last_step_cache=diff_cache,
is_causal=True,
block_pattern=[len(prompt_token_list)] + block_sizes
Expand Down Expand Up @@ -112,12 +118,16 @@ def token2wav_stream(self,
if i == 0:
if len(chunked_list) == 1:
result_wav_list.append(wav_npy)
if queue is not None:
queue.put_nowait(wav_npy)
continue

# 1. Non-overlap area, safe to return/play
# Ensure we don't slice with negative index if wav is too short
valid_len = max(0, len(wav_npy) - overlap_len)
result_wav_list.append(wav_npy[:valid_len])
if queue is not None:
queue.put_nowait(wav_npy[:valid_len])
wav_len_pointer += len(result_wav_list[-1])

# 2. Fade area, stored for next iteration
Expand All @@ -143,16 +153,23 @@ def token2wav_stream(self,
# Case 3: Last chunk
if i == len(chunked_list) - 1:
result_wav_list.append(current_wav)
if queue is not None:
queue.put_nowait(current_wav)
break

# 3. Return content (minus the overlap for the next chunk)
valid_len = max(0, len(current_wav) - overlap_len)
result_wav_list.append(current_wav[:valid_len])
if queue is not None:
queue.put_nowait(current_wav[:valid_len])
wav_len_pointer += len(result_wav_list[-1])

# 4. Update fade area for the next iteration
last_fade_out_array = current_wav[-overlap_len:-look_back_len] if look_back_len > 0 else current_wav[-overlap_len:]

if queue is not None:
queue.put_nowait(None) # Signal completion

# Statistics: length of each segment
sec_list = [len(wav) / self.sample_rate for wav in result_wav_list]

Expand Down Expand Up @@ -190,10 +207,10 @@ def token2wav_stream(self,
mel_big = mel_big[:, :, -overlap_mel_len:]

diff = self.calc_ratio(mel_small, mel_big) * 100
# print(f"Chunk {i}: diff:{diff :.2f}%") # Optional logging
print(f"Chunk {i}: diff:{diff :.2f}%") # Optional logging
diff_list.append(diff)

return wav_bt, sec_list, diff_list, result_wav_list
return wav_bt, sec_list, diff_list, result_wav_list, mel_list

def token2wav_with_cache(self,
token_bt: Union[List[int], np.ndarray, torch.Tensor],
Expand Down
4 changes: 2 additions & 2 deletions utils/vocos_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def stft_mel(self, xs: torch.Tensor) -> torch.Tensor:
xs_mel_tr: torch.Tensor = xs_mel.transpose(-1, -2) - MEL_LOGDIFF
return xs_mel_tr

def load_vocos_jit(device: str = "cuda", ckpt_path: str = "") -> Vocos2DInference:
    """Factory function to load Vocos model.

    Args:
        device: Torch device string the model is loaded onto.
        ckpt_path: Optional checkpoint path; when empty, falls back to the
            bundled default ``ckpt/vocos2d/generator_jit.ckpt``.

    Returns:
        An initialized ``Vocos2DInference`` instance.
    """
    ckpt_path = ckpt_path or 'ckpt/vocos2d/generator_jit.ckpt'
    print(f"Loading Vocos JIT model from {ckpt_path} on {device}...")
    return Vocos2DInference(ckpt_path, device=device)