From e573ee33d5d3a3c292bbba4c81fd75fd81730946 Mon Sep 17 00:00:00 2001 From: rulerman <1330677461@qq.com> Date: Fri, 31 Oct 2025 08:43:13 +0000 Subject: [PATCH 1/4] update v0.7 --- README.md | 26 +- README_zh.md | 25 +- XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml | 101 +++++ XY_Tokenizer/xy_tokenizer/model.py | 21 +- generation_utils.py | 411 +++++++++++++------ inference.py | 6 +- 6 files changed, 427 insertions(+), 163 deletions(-) create mode 100644 XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml diff --git a/README.md b/README.md index e026944..fa02cdc 100644 --- a/README.md +++ b/README.md @@ -28,20 +28,21 @@ MOSS-TTSD supports voice cloning and long single-session speech generation, maki ## Highlights -- **Highly Expressive Dialogue Speech**: Built on unified semantic-acoustic neural audio codec, a pre-trained large language model, millions of hours of TTS data, and 400k hours synthetic and real conversational speech, MOSS-TTSD generates highly expressive, human-like dialogue speech with natural conversational prosody. +- **Highly Expressive Dialogue Speech**: Built on unified semantic-acoustic neural audio codec, a pre-trained large language model, millions of hours of TTS data, and 600k hours synthetic and real conversational speech, MOSS-TTSD generates highly expressive, human-like dialogue speech with natural conversational prosody. - **Two-Speaker Voice Cloning**: MOSS-TTSD supports zero-shot two speakers voice cloning and can generate conversational speech with accurate speaker swithcing based on dialogue scripts. Only 10 to 20 seconds of reference audio is needed. - **Chinese-English Bilingual Support**: MOSS-TTSD enables highly expressive speech generation in both Chinese and English. -- **Long-Form Speech Generation**: Thanks to low-bitrate codec and training framework optimization, MOSS-TTSD has been trained for long speech generation (Training maximum length is 960s). +- **Long-Form Speech Generation**: Thanks to low-bitrate codec and training framework optimization, MOSS-TTSD has been trained for long speech generation (Training maximum length is 1700s). - **Fully Open Source & Commercial-Ready**: MOSS-TTSD and its future updates will be fully open-source and support free commercial use. ## News 🚀 + - **[2025-10-31]** MOSS-TTSD v0.7 is released! v0.7 has significantly improved audio quality, voice cloning capability, and stability, greatly extended single-pass generation length (960s→1700s), and more reliably generates speech events following speaker tags. We recommend using the v0.7 model by default. - **[2025-09-09]** We supported SGLang inference engine to accelerate model inference by up to **16x**. - **[2025-08-25]** We released the 32khz version of XY-Tokenizer. - **[2025-08-12]** We add support for streaming inference in MOSS-TTSD v0.5. - **[2025-07-29]** We provide the SiliconFlow API interface and usage examples for MOSS-TTSD v0.5. - **[2025-07-16]** We open-source the fine-tuning code for MOSS-TTSD v0.5, supporting full-parameter fine-tuning, LoRA fine-tuning, and multi-node training. -- **[2025-07-04]** MOSS-TTSD v0.5 is released! v0.5 has enhanced the accuracy of timbre switching, voice cloning capability, and model stability. We recommend using the v0.5 model by default. +- **[2025-07-04]** MOSS-TTSD v0.5 is released! v0.5 has enhanced the accuracy of timbre switching, voice cloning capability, and model stability. - **[2025-06-20]** MOSS-TTSD v0 is released! 
Moreover, we provide a podcast generation pipeline named Podever, which can automatically convert PDF, URL, or long text files into high-quality podcasts. ## Installation @@ -62,7 +63,7 @@ You also need to download the XY Tokenizer model weights. You can find the weigh ```bash mkdir -p XY_Tokenizer/weights -huggingface-cli download fnlp/XY_Tokenizer_TTSD_V0_32k xy_tokenizer.ckpt --local-dir ./XY_Tokenizer/weights/ +huggingface-cli download fnlp/MOSS_TTSD_tokenizer MOSS_TTSD_tokenizer --local-dir ./XY_Tokenizer/weights/ ``` ## Usage @@ -89,16 +90,9 @@ Parameters: #### JSONL Input Format -The input JSONL file should contain one JSON object per line. MOSS-TTSD supports multiple input formats: +The input JSONL file should contain one JSON object per line. MOSS-TTSD supports two input formats: -**Format 1: Text-only input (No voice cloning, using the model's random timbre)** -```json -{ - "text": "[S1]Speaker 1 dialogue content[S2]Speaker 2 dialogue content[S1]..." -} -``` - -**Format 2: Separate speaker audio references** +**Format 1: Separate speaker audio references** ```json { "base_path": "/path/to/audio/files", @@ -110,7 +104,7 @@ The input JSONL file should contain one JSON object per line. MOSS-TTSD supports } ``` -**Format 3: Shared audio reference** +**Format 2: Shared audio reference** ```json { "base_path": "/path/to/audio/files", @@ -126,11 +120,11 @@ The input JSONL file should contain one JSON object per line. MOSS-TTSD supports - `text`: Dialogue script with speaker tags `[S1]` and `[S2]` indicating speaker turns (required) - `base_path`: Base directory path for relative file paths (optional) -**For voice cloning (Format 2):** +**For voice cloning (Format 1):** - `prompt_audio_speaker1/2`: Path to reference audio files for voice cloning (relative to `base_path`) - `prompt_text_speaker1/2`: Reference text corresponding to the audio prompts for better voice matching -**For shared reference (Format 3):** +**For shared reference (Format 2):** - `prompt_audio`: Path to shared reference audio file containing both speakers' voices (relative to `base_path`) - `prompt_text`: Reference text corresponding to the audio, also using `[S1]` and `[S2]` tags to distinguish speakers diff --git a/README_zh.md b/README_zh.md index e626eab..5b19669 100644 --- a/README_zh.md +++ b/README_zh.md @@ -26,14 +26,15 @@ MOSS-TTSD(text to spoken dialogue)是一个开源的中英双语口语对话 ## 亮点 -- **高表现力对话语音**:基于统一语义-声学神经音频Codec、预训练大语言模型、百万小时TTS数据与约40万小时的真实/合成对话语音数据,MOSS-TTSD能够生成高表现力,高自然度,具有自然对话韵律的拟人对话语音。 +- **高表现力对话语音**:基于统一语义-声学神经音频Codec、预训练大语言模型、百万小时TTS数据与约60万小时的真实/合成对话语音数据,MOSS-TTSD能够生成高表现力,高自然度,具有自然对话韵律的拟人对话语音。 - **双说话人零样本声音克隆**:MOSS-TTSD支持零样本双说话人克隆,按脚本精确进行角色/声线切换。只需要提供10到20秒的参考音频片段。 - **中英双语**:MOSS-TTSD支持中英两种语言的高表现力语音生成。 -- **长音频生成**:得益于低码率Codec与训练框架优化,MOSS-TTSD在长音频生成场景进行了大量训练(训练最大长度达到960s),能够单次生成超长音频。 +- **长音频生成**:得益于低码率Codec与训练框架优化,MOSS-TTSD在长音频生成场景进行了大量训练(训练最大长度达到1700s),能够单次生成超长音频。 - **开源可商用**:当前与后续版本将保持开源,支持免费商用。 ## 最新动态 🚀 +- **[2025-10-31]** 我们发布了 MOSS-TTSD v0.7:显著提升了音质、声音克隆能力与稳定性,大幅拓展了单次生成长度(960s->1700s),更够比较稳定地根据说话人标签生成语音事件。 - **[2025-09-09]** 我们支持了 SGLang 推理引擎加速模型推理,最高可加速**16倍**。 - **[2025-08-25]** 我们发布了 32khz XY-Tokenizer。 - **[2025-08-12]** 我们支持了 MOSS-TTSD v0.5 的流式推理。 @@ -60,7 +61,7 @@ pip install flash-attn ```bash mkdir -p XY_Tokenizer/weights -huggingface-cli download fnlp/XY_Tokenizer_TTSD_V0_32k xy_tokenizer.ckpt --local-dir ./XY_Tokenizer/weights/ +huggingface-cli download fnlp/MOSS_TTSD_tokenizer MOSS_TTSD_tokenizer --local-dir ./XY_Tokenizer/weights/ ``` ## 使用方法 @@ -87,17 +88,9 @@ python inference.py 
--jsonl examples/examples.jsonl --output_dir outputs --seed #### JSONL 输入格式 -MOSS-TTSD支持多种输入格式: +MOSS-TTSD 支持两种输入格式: -**格式1:仅文本(不进行声音克隆,使用模型随机音色)** - -```json -{ - "text": "[S1]说话人1的内容[S2]说话人2的内容[S1]..." -} -``` - -**格式2:分别提供两位说话人的参考音频** +**格式1:分别提供两位说话人的参考音频** ```json { @@ -110,7 +103,7 @@ MOSS-TTSD支持多种输入格式: } ``` -**格式3:共享参考音频(一个参考音频包含两个说话人的内容)** +**格式2:共享参考音频(一个参考音频包含两个说话人的内容)** ```json { @@ -128,12 +121,12 @@ MOSS-TTSD支持多种输入格式: - `text`:带 `[S1]`、`[S2]` 说话人标签的对话脚本(必填) - `base_path`:相对路径的基准目录(可选) -**用于声音克隆(格式2):** +**用于声音克隆(格式1):** - `prompt_audio_speaker1/2`:两位说话人的参考音频(可相对 `base_path`) - `prompt_text_speaker1/2`:对应参考音频的文本,有助于更好匹配音色 -**用于共享参考(格式3):** +**用于共享参考(格式2):** - `prompt_audio`:包含两位说话人的共享参考音频(可相对 `base_path`) - `prompt_text`:对应的参考文本,亦使用 `[S1]`、`[S2]` 区分 diff --git a/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml b/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml new file mode 100644 index 0000000..ce539c3 --- /dev/null +++ b/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml @@ -0,0 +1,101 @@ +generator_params: + input_sample_rate: 16000 + output_sample_rate: 32000 + + # Codec / model architecture (inference required) + semantic_encoder_kwargs: # 100hz -> 50hz + num_mel_bins: 80 + sampling_rate: 16000 + hop_length: 160 + stride_size: 2 + kernel_size: 3 + d_model: 768 + scale_embedding: false + max_audio_seconds: 30 + encoder_layers: 12 + encoder_attention_heads: 12 + encoder_ffn_dim: 3072 + activation_function: "gelu" + + semantic_encoder_adapter_kwargs: # 50hz + input_dim: 768 + output_dim: 768 + d_model: 768 + max_source_positions: 1500 + encoder_layers: 4 + encoder_attention_heads: 12 + encoder_ffn_dim: 3072 + + acoustic_encoder_kwargs: # 100hz -> 50hz + num_mel_bins: 80 + sampling_rate: 16000 + hop_length: 160 + stride_size: 2 + kernel_size: 3 + d_model: 768 + scale_embedding: false + max_audio_seconds: 30 + encoder_layers: 12 + encoder_attention_heads: 12 + encoder_ffn_dim: 3072 + activation_function: "gelu" + + pre_rvq_adapter_kwargs: # 50hz + input_dim: 1536 + output_dim: 768 + d_model: 768 + max_source_positions: 1500 + encoder_layers: 4 + encoder_attention_heads: 12 + encoder_ffn_dim: 3072 + + downsample_kwargs: # 50hz -> 12.5hz + d_model: 768 + avg_pooler: 4 + + quantizer_kwargs: # 12.5hz + input_dim: 3072 + rvq_dim: 512 + output_dim: 3072 + num_quantizers: 8 + codebook_size: 1024 + codebook_dim: 512 + quantizer_dropout: 0.0 + commitment: 1 + + post_rvq_adapter_kwargs: # 12.5hz + input_dim: 3072 + output_dim: 3072 + d_model: 768 + max_source_positions: 375 + encoder_layers: 4 + encoder_attention_heads: 12 + encoder_ffn_dim: 3072 + + upsample_kwargs: # 12.5hz -> 50hz + d_model: 768 + stride: 4 + + acoustic_decoder_kwargs: # 50hz -> 100hz + num_mel_bins: 80 + sampling_rate: 16000 + hop_length: 160 + stride_size: 2 + kernel_size: 3 + d_model: 768 + scale_embedding: false + max_audio_seconds: 30 + decoder_layers: 12 + decoder_attention_heads: 12 + decoder_ffn_dim: 3072 + activation_function: "gelu" + + vocos_kwargs: # 100hz -> 32khz + input_channels: 80 + dim: 512 + intermediate_dim: 4096 + num_layers: 30 + n_fft: 1280 + hop_size: 320 + padding: "same" + diff --git a/XY_Tokenizer/xy_tokenizer/model.py b/XY_Tokenizer/xy_tokenizer/model.py index fa1a672..4c51d5d 100644 --- a/XY_Tokenizer/xy_tokenizer/model.py +++ b/XY_Tokenizer/xy_tokenizer/model.py @@ -17,8 +17,8 @@ def __init__(self, generator_params): self.input_sample_rate = generator_params['input_sample_rate'] self.output_sample_rate = generator_params['output_sample_rate'] - self.encoder_downsample_rate = 
generator_params['encoder_downsample_rate'] - self.decoder_upsample_rate = generator_params['decoder_upsample_rate'] + self.encoder_downsample_rate = 1280 + self.decoder_upsample_rate = int(self.output_sample_rate / 12.5) self.code_dim = generator_params['quantizer_kwargs']['input_dim'] ## Codec part @@ -49,7 +49,22 @@ def __init__(self, generator_params): self.enhanced_vocos = Vocos(**generator_params['vocos_kwargs']) ## Feature extractor - self.feature_extractor = MelFeatureExtractor(**generator_params['feature_extractor_kwargs']) + default_feature_extractor_kwargs = { + 'chunk_length': 30, + 'feature_size': 80, + 'hop_length': 160, + 'n_fft': 400, + 'n_samples': 480000, + 'nb_max_frames': 3000, + 'padding_side': 'right', + 'padding_value': 0.0, + 'return_attention_mask': False, + 'sampling_rate': self.input_sample_rate, + } + fe_kwargs = generator_params.get('feature_extractor_kwargs', {}) + merged_fe_kwargs = {**default_feature_extractor_kwargs, **fe_kwargs} + merged_fe_kwargs['sampling_rate'] = self.input_sample_rate + self.feature_extractor = MelFeatureExtractor(**merged_fe_kwargs) @torch.inference_mode() def inference_tokenize(self, x, input_lengths): diff --git a/generation_utils.py b/generation_utils.py index 2d69f5a..b7c2d09 100644 --- a/generation_utils.py +++ b/generation_utils.py @@ -8,6 +8,66 @@ MAX_CHANNELS = 8 +def pad_or_truncate_to_seconds(wav: torch.Tensor, target_seconds: float, sr: int) -> torch.Tensor: + """Pad or truncate a mono waveform to target length in seconds. + + Args: + wav: (1, T) or (T,) tensor + target_seconds: target duration in seconds + sr: sample rate + Returns: + (1, T_target) tensor + """ + if wav.dim() == 2 and wav.shape[0] == 1: + wav_1d = wav.squeeze(0) + else: + wav_1d = wav.reshape(-1) + target_len = int(round(target_seconds * sr)) + cur_len = wav_1d.shape[-1] + if cur_len == target_len: + out = wav_1d + elif cur_len > target_len: + out = wav_1d[:target_len] + else: + pad_len = target_len - cur_len + out = torch.cat( + [wav_1d, torch.zeros(pad_len, dtype=wav_1d.dtype, device=wav_1d.device)], dim=-1 + ) + return out.unsqueeze(0) + + +def crossfade_concat(segments: list, sample_rate: int, crossfade_seconds: float = 0.1) -> torch.Tensor: + """Concatenate segments with linear crossfade. + + Args: + segments: list of (1, T) tensors + sample_rate: sampling rate + crossfade_seconds: overlap time for crossfade + Returns: + (1, T_total) tensor + """ + if len(segments) == 0: + return torch.zeros(1, 0) + if len(segments) == 1: + return segments[0] + out = segments[0] + cf_len_target = int(round(crossfade_seconds * sample_rate)) + for k in range(1, len(segments)): + nxt = segments[k] + if cf_len_target <= 0: + out = torch.cat([out, nxt], dim=-1) + continue + cf_len = min(cf_len_target, out.shape[-1], nxt.shape[-1]) + if cf_len <= 0: + out = torch.cat([out, nxt], dim=-1) + continue + fade_out = torch.linspace(1.0, 0.0, steps=cf_len, dtype=out.dtype, device=out.device) + fade_in = torch.linspace(0.0, 1.0, steps=cf_len, dtype=nxt.dtype, device=nxt.device) + overlap = out[0, -cf_len:] * fade_out + nxt[0, :cf_len] * fade_in + out = torch.cat([out[:, :-cf_len], overlap.unsqueeze(0), nxt[:, cf_len:]], dim=-1) + return out + + def load_model( model_path, spt_config_path, @@ -34,71 +94,72 @@ def load_model( def process_jsonl_item(item): - """Process JSONL data items and extract audio and text information according to the new format""" - base_path = item.get("base_path", "") + """Parse a JSONL item enforcing prompt requirement. 
+ + Only supports Format 1 (separate speaker refs) and Format 2 (shared ref), + consistent with the updated README. If `base_path` is missing/empty, any + string paths must be absolute. Text-only input is not supported and will raise. + """ + base_path = item.get("base_path", "") or "" text = item.get("text", "") + def _resolve_path(p: str) -> str: + if not isinstance(p, str) or not p: + return p + if base_path: + return os.path.join(base_path, p) + # base_path missing: require absolute path + if not os.path.isabs(p): + raise ValueError( + "When base_path is omitted, audio paths must be absolute. Got: " + p + ) + return p + + # Try Format 2 first: shared audio reference prompt_audio = None prompt_text = "" - - # Process prompt audio and text - if "prompt_audio" in item and "prompt_text" in item: - print("Using prompt_audio and prompt_text directly from item.") - # If prompt_audio and prompt_text exist, use them directly - prompt_audio_val = item["prompt_audio"] - if prompt_audio_val: # Only assign if not empty + if "prompt_audio" in item: + prompt_audio_val = item.get("prompt_audio") + if not prompt_audio_val: + raise ValueError("Format 2 requires non-empty 'prompt_audio'.") + if isinstance(prompt_audio_val, str): + prompt_audio = _resolve_path(prompt_audio_val) + else: + # allow tuple form for backward-compatibility prompt_audio = prompt_audio_val - prompt_text = item["prompt_text"] - - # Only perform path joining when prompt_audio is a string path - if isinstance(prompt_audio, str) and base_path and prompt_audio: - prompt_audio = os.path.join(base_path, prompt_audio) - else: - # Otherwise, merge speaker1 and speaker2 information - prompt_audio_speaker1 = item.get("prompt_audio_speaker1", "") - prompt_text_speaker1 = item.get("prompt_text_speaker1", "") - prompt_audio_speaker2 = item.get("prompt_audio_speaker2", "") - prompt_text_speaker2 = item.get("prompt_text_speaker2", "") - - has_speaker1_audio = ( - isinstance(prompt_audio_speaker1, str) and prompt_audio_speaker1 - ) or isinstance(prompt_audio_speaker1, tuple) - has_speaker2_audio = ( - isinstance(prompt_audio_speaker2, str) and prompt_audio_speaker2 - ) or isinstance(prompt_audio_speaker2, tuple) - - if has_speaker1_audio or has_speaker2_audio: - print("Using speaker1 and speaker2 information for prompt audio and text.") - # Process audio: if it's a string path, perform path joining; if it's a tuple, use directly - if isinstance(prompt_audio_speaker1, str): - speaker1_audio = ( - os.path.join(base_path, prompt_audio_speaker1) - if base_path and prompt_audio_speaker1 - else prompt_audio_speaker1 - ) - else: - speaker1_audio = prompt_audio_speaker1 # Use tuple directly - - if isinstance(prompt_audio_speaker2, str): - speaker2_audio = ( - os.path.join(base_path, prompt_audio_speaker2) - if base_path and prompt_audio_speaker2 - else prompt_audio_speaker2 - ) - else: - speaker2_audio = prompt_audio_speaker2 # Use tuple directly - - prompt_audio = {"speaker1": speaker1_audio, "speaker2": speaker2_audio} - - # Merge text - temp_prompt_text = "" - if prompt_text_speaker1: - temp_prompt_text += f"[S1]{prompt_text_speaker1}" - if prompt_text_speaker2: - temp_prompt_text += f"[S2]{prompt_text_speaker2}" - prompt_text = temp_prompt_text.strip() - - return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} + prompt_text = item.get("prompt_text", "") + return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} + + # Try Format 1: separate speaker references + s1 = item.get("prompt_audio_speaker1", "") + 
s2 = item.get("prompt_audio_speaker2", "") + has_s1 = ((isinstance(s1, str) and s1) or isinstance(s1, tuple)) + has_s2 = ((isinstance(s2, str) and s2) or isinstance(s2, tuple)) + + if has_s1 and has_s2: + if isinstance(s1, str) and s1: + s1_resolved = _resolve_path(s1) + else: + s1_resolved = s1 + if isinstance(s2, str) and s2: + s2_resolved = _resolve_path(s2) + else: + s2_resolved = s2 + # Build merged prompt audio dict + prompt_audio = {"speaker1": s1_resolved, "speaker2": s2_resolved} + # Merge texts + pt1 = item.get("prompt_text_speaker1", "") + pt2 = item.get("prompt_text_speaker2", "") + merged = "" + if pt1: + merged += f"[S1]{pt1}" + if pt2: + merged += f"[S2]{pt2}" + prompt_text = merged.strip() + return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} + + # Otherwise, no supported prompt found → reject (text-only unsupported) + raise ValueError("Input must include prompt (Format 1 or 2). Text-only is not supported.") def load_audio_data(prompt_audio, target_sample_rate=16000): @@ -288,25 +349,22 @@ def normalize_text(text: str) -> str: Normalize multi-speaker script. 1. Don't preserve line breaks. - 2. Remove brackets for non-speaker tags (if [] doesn't contain S1/S2...Sx format, remove the brackets themselves). - 3. Remove decorative symbols: 【】《》()『』「」"-“” . - 4. Internal punctuation !;:、 → ,;only allow ? and ,。 + 2. Preserve bracketed segments like [] () <> even when they are not speaker tags. + 3. Remove decorative symbols: 【】《》()『』「」~~-_. + 4. Internal punctuation ;:、 → ,;keep ?!?. 5. Multiple 。 keep only the last one, others → ,。 6. Replace consecutive "哈" (>=2) with "(笑)". 7. Auto-recognize [S1] / [S2] … tags; if missing, treat as whole segment. 8. Merge adjacent identical speaker tags. """ # Replace [1], [2] etc. format with [S1], [S2] etc. 
format - text = re.sub(r"\[(\d+)\]", r"[S\1]", text) + text = re.sub(r'\[(\d+)\]', r'[S\1]', text) # Remove decorative characters - remove_chars = "【】《》()『』「」" '"-_“”~~' - - # Remove brackets for non-speaker tags (keep content, only remove brackets themselves) - text = re.sub(r"\[(?!S\d+\])([^\]]*)\]", r"\1", text) + remove_chars = "【】《》()『』「」~~-_" # Use positive lookahead to split text by speaker tags (tags themselves are still preserved) - segments = re.split(r"(?=\[S\d+\])", text.replace("\n", " ")) + segments = re.split(r'(?=\[S\d+\])', text.replace("\n", " ")) processed_parts = [] for seg in segments: @@ -315,70 +373,58 @@ def normalize_text(text: str) -> str: continue # Extract tags - m = re.match(r"^(\[S\d+\])\s*(.*)", seg) - tag, content = m.groups() if m else ("", seg) + m = re.match(r'^(\[S\d+\])\s*(.*)', seg) + tag, content = m.groups() if m else ('', seg) # Remove irrelevant symbols content = re.sub(f"[{re.escape(remove_chars)}]", "", content) # Handle consecutive "哈" characters: replace 2 or more with "(笑)" - content = re.sub(r"哈{2,}", "(笑)", content) + content = re.sub(r'哈{2,}', '[笑]', content) # Handle English laughter (e.g., "haha", "ha ha") - content = re.sub(r"\b(ha(\s*ha)+)\b", "(laughs)", content, flags=re.IGNORECASE) + content = re.sub(r'\b(ha(\s*ha)+)\b', '[laugh]', content, flags=re.IGNORECASE) # First handle multi-character punctuation marks - content = content.replace("——", ",") - content = content.replace("……", ",") + content = content.replace('——', ',') + content = content.replace('……', ',') # Handle single-character internal punctuation marks - internal_punct_map = str.maketrans( - { - "!": ",", - "!": ",", - ";": ",", - ";": ",", - ":": ",", - ":": ",", - "、": ",", - "?": ",", - "?": ",", - } - ) + internal_punct_map = str.maketrans({ + ';': ',', ';': ',', + ':': ',', ':': ',', + '、': ',' + }) content = content.translate(internal_punct_map) content = content.strip() # Keep only the final period if len(content) > 1: - last_ch = ( - "。" - if content[-1] == "," - else ("." if content[-1] == "," else content[-1]) - ) - body = content[:-1].replace("。", ",") + last_ch = "。" if content[-1] == "," else ("." 
if content[-1] == "," else content[-1]) + body = content[:-1].replace('。', ',') content = body + last_ch - processed_parts.append({"tag": tag, "content": content}) + processed_parts.append({'tag': tag, 'content': content}) if not processed_parts: return "" # Merge consecutive same speakers merged_lines = [] - current_tag = processed_parts[0]["tag"] - current_content = [processed_parts[0]["content"]] + current_tag = processed_parts[0]['tag'] + current_content = [processed_parts[0]['content']] for part in processed_parts[1:]: - if part["tag"] == current_tag and current_tag: - current_content.append(part["content"]) + if part['tag'] == current_tag and current_tag: + current_content.append(part['content']) else: merged_lines.append(f"{current_tag}{''.join(current_content)}".strip()) - current_tag = part["tag"] - current_content = [part["content"]] + current_tag = part['tag'] + current_content = [part['content']] merged_lines.append(f"{current_tag}{''.join(current_content)}".strip()) - - return "".join(merged_lines).replace("‘", "'").replace("’", "'") + + return "".join(merged_lines).replace('‘', "'").replace('’', "'") def process_batch( @@ -507,28 +553,143 @@ def process_batch( f"Speech token shape for sample {start_idx + i}: {this_speech_id.shape}" ) - # Decode generated audio - with torch.no_grad(): - codes_list = [ - this_speech_id.permute(1, 0) - ] # Convert to SPT expected format - decode_result = spt.decode(codes_list, overlap_seconds=10) - audio_result = decode_result["syn_wav_list"][0].cpu().detach() - - if audio_result.ndim == 1: # If 1D [samples] - audio_result = audio_result.unsqueeze( - 0 - ) # Convert to 2D [1, samples] - - # Save audio data instead of file path - audio_results.append( - { - "audio_data": audio_result, - "sample_rate": spt.output_sample_rate, - "index": start_idx + i, - } - ) - print(f"Audio generation completed: sample {start_idx + i}") + # Prompt-Augmented Decode (rvq8-style); fall back to original decode if no prompt + prompt_audio = prompt_audios[i] + if prompt_audio is None: + # Fallback to original decode + with torch.no_grad(): + codes_list = [this_speech_id.permute(1, 0)] + decode_result = spt.decode(codes_list, overlap_seconds=10) + audio_out = decode_result["syn_wav_list"][0].cpu().detach() + if audio_out.ndim == 1: + audio_out = audio_out.unsqueeze(0) + audio_results.append( + { + "audio_data": audio_out, + "sample_rate": spt.output_sample_rate, + "index": start_idx + i, + } + ) + print(f"Audio generation completed (orig): sample {start_idx + i}") + else: + # 1) Load prompt at SPT input sr and force to 20s + ref_sr_in = ( + getattr(spt, "input_sample_rate", None) + or getattr(spt, "sampling_rate", None) + or 24000 + ) + ref_wav = load_audio_data(prompt_audio, target_sample_rate=ref_sr_in) + if ref_wav is None: + # If ref missing, use original decode + with torch.no_grad(): + codes_list = [this_speech_id.permute(1, 0)] + decode_result = spt.decode(codes_list, overlap_seconds=10) + audio_out = decode_result["syn_wav_list"][0].cpu().detach() + if audio_out.ndim == 1: + audio_out = audio_out.unsqueeze(0) + audio_results.append( + { + "audio_data": audio_out, + "sample_rate": spt.output_sample_rate, + "index": start_idx + i, + } + ) + print(f"Audio generation completed (orig no-ref): sample {start_idx + i}") + else: + # Encode 20s reference to tokens + ref_wav_20s = pad_or_truncate_to_seconds(ref_wav, 20.0, ref_sr_in).to(device) + with torch.no_grad(): + enc = spt.encode([ref_wav_20s.squeeze(0)]) + ref_codes = enc["codes_list"][0].to(device).long() # (nq, 
T_ref) + + # Prepare token-to-sample mapping and windowing params + out_sr = ( + getattr(spt, "output_sample_rate", None) + or getattr(spt, "sample_rate", None) + or 24000 + ) + tokens_per_second = float(ref_sr_in) / float(spt.encoder_downsample_rate) + tokens_per_chunk = int(round(10.0 * tokens_per_second)) + stride_tokens = 85 + keep_tokens = 85 + left_ctx_tokens = 20 + total_tokens = this_speech_id.shape[0] + samples_per_token = int(round(out_sr / tokens_per_second)) + crossfade_seconds = 0.1 + crossfade_samples = int(round(crossfade_seconds * out_sr)) + + kept_segments = [] + chunk_idx = 0 + while True: + st_tok = chunk_idx * stride_tokens + if st_tok >= total_tokens: + break + ed_tok = min(st_tok + tokens_per_chunk, total_tokens) + gen_chunk = this_speech_id[st_tok:ed_tok] # (len, C) + if gen_chunk.shape[0] == 0: + break + + # Concatenate reference tokens with current window tokens + combined_codes = torch.cat( + [ref_codes, gen_chunk.permute(1, 0).long()], dim=1 + ).to(device) # (nq, T_ref + T_chunk) + codes_lengths = torch.tensor( + [combined_codes.shape[-1]], dtype=torch.long, device=device + ) + combined_codes_batched = combined_codes.unsqueeze(1) # (nq, 1, T) + + with torch.no_grad(): + detok = spt.inference_detokenize(combined_codes_batched, codes_lengths) + y = detok["y"][0, 0] # (T_samples) + + # Remove 20s reference portion (in samples) + ref_samples = int(round(20.0 * out_sr)) + if y.shape[-1] <= ref_samples: + chunk_idx += 1 + continue + chunk_y = y[ref_samples:] + + # Determine kept region within current window + window_len = gen_chunk.shape[0] + remains = total_tokens - st_tok + is_first = chunk_idx == 0 + is_last = ed_tok >= total_tokens + + if is_first: + keep_start_tok = 0 + keep_end_tok = min(keep_tokens + left_ctx_tokens, window_len) + elif is_last and remains < 105: + keep_start_tok = 0 if is_first else min(left_ctx_tokens, window_len) + keep_end_tok = window_len + else: + keep_start_tok = min(left_ctx_tokens, window_len) + keep_end_tok = min(left_ctx_tokens + keep_tokens, window_len) + + keep_start_smps = keep_start_tok * samples_per_token + keep_end_smps = keep_end_tok * samples_per_token + left_margin = 0 + right_margin = crossfade_samples if not is_last else 0 + seg_start = max(0, keep_start_smps - left_margin) + seg_end = min(chunk_y.shape[-1], keep_end_smps + right_margin) + if seg_end > seg_start: + kept_segments.append(chunk_y[seg_start:seg_end].detach().cpu().unsqueeze(0)) + + chunk_idx += 1 + + # Concatenate with crossfade; if empty, return tiny silence + if len(kept_segments) == 0: + audio_out = torch.zeros(1, int(0.01 * out_sr)) + else: + audio_out = crossfade_concat(kept_segments, out_sr, crossfade_seconds=crossfade_seconds) + + audio_results.append( + { + "audio_data": audio_out, + "sample_rate": out_sr, + "index": start_idx + i, + } + ) + print(f"Audio generation completed (prompt-aug): sample {start_idx + i}") except Exception as e: print(f"Error processing sample {start_idx + i}: {str(e)}, skipping...") diff --git a/inference.py b/inference.py index 31f5a30..df4b8de 100644 --- a/inference.py +++ b/inference.py @@ -7,10 +7,10 @@ from generation_utils import load_model, process_batch -MODEL_PATH = "fnlp/MOSS-TTSD-v0.5" +MODEL_PATH = "fnlp/MOSS-TTSD-v0.7" SYSTEM_PROMPT = "You are a speech synthesizer that generates natural, realistic, and human-like conversational audio from dialogue text." 
-SPT_CONFIG_PATH = "XY_Tokenizer/config/xy_tokenizer_32k_config.yaml" -SPT_CHECKPOINT_PATH = "XY_Tokenizer/weights/xy_tokenizer.ckpt" +SPT_CONFIG_PATH = "XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml" +SPT_CHECKPOINT_PATH ="XY_Tokenizer/weights/MOSS_TTSD_tokenizer" MAX_CHANNELS = 8 def main(): From b5fc87d400eaddae99bdb4380e8ee394deb52072 Mon Sep 17 00:00:00 2001 From: rulerman <1330677461@qq.com> Date: Fri, 31 Oct 2025 16:03:18 +0000 Subject: [PATCH 2/4] update v0.7 --- README.md | 4 ++-- README_zh.md | 4 ++-- XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml | 14 ++++++++++++++ XY_Tokenizer/xy_tokenizer/model.py | 20 +++----------------- generation_utils.py | 2 +- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index fa02cdc..87edd1d 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ MOSS-TTSD supports voice cloning and long single-session speech generation, maki ## Highlights -- **Highly Expressive Dialogue Speech**: Built on unified semantic-acoustic neural audio codec, a pre-trained large language model, millions of hours of TTS data, and 600k hours synthetic and real conversational speech, MOSS-TTSD generates highly expressive, human-like dialogue speech with natural conversational prosody. +- **Highly Expressive Dialogue Speech**: Built on unified semantic-acoustic neural audio codec, a pre-trained large language model, millions of hours of TTS data and conversational speech, MOSS-TTSD generates highly expressive, human-like dialogue speech with natural conversational prosody. - **Two-Speaker Voice Cloning**: MOSS-TTSD supports zero-shot two speakers voice cloning and can generate conversational speech with accurate speaker swithcing based on dialogue scripts. Only 10 to 20 seconds of reference audio is needed. - **Chinese-English Bilingual Support**: MOSS-TTSD enables highly expressive speech generation in both Chinese and English. - **Long-Form Speech Generation**: Thanks to low-bitrate codec and training framework optimization, MOSS-TTSD has been trained for long speech generation (Training maximum length is 1700s). @@ -36,7 +36,7 @@ MOSS-TTSD supports voice cloning and long single-session speech generation, maki ## News 🚀 - - **[2025-10-31]** MOSS-TTSD v0.7 is released! v0.7 has significantly improved audio quality, voice cloning capability, and stability, greatly extended single-pass generation length (960s→1700s), and more reliably generates speech events following speaker tags. We recommend using the v0.7 model by default. + - **[2025-11-01]** MOSS-TTSD v0.7 is released! v0.7 significantly improves audio quality, voice cloning capability, and stability, adds support for 32 kHz high‑quality output, greatly extends single‑pass generation length (960s→1700s), and more reliably generates speech events following speaker tags. We recommend using the v0.7 model by default. - **[2025-09-09]** We supported SGLang inference engine to accelerate model inference by up to **16x**. - **[2025-08-25]** We released the 32khz version of XY-Tokenizer. - **[2025-08-12]** We add support for streaming inference in MOSS-TTSD v0.5. 
diff --git a/README_zh.md b/README_zh.md index 5b19669..ecf5c79 100644 --- a/README_zh.md +++ b/README_zh.md @@ -26,7 +26,7 @@ MOSS-TTSD(text to spoken dialogue)是一个开源的中英双语口语对话 ## 亮点 -- **高表现力对话语音**:基于统一语义-声学神经音频Codec、预训练大语言模型、百万小时TTS数据与约60万小时的真实/合成对话语音数据,MOSS-TTSD能够生成高表现力,高自然度,具有自然对话韵律的拟人对话语音。 +- **高表现力对话语音**:基于统一语义-声学神经音频Codec、预训练大语言模型、百万小时TTS数据与对话语音数据,MOSS-TTSD能够生成高表现力,高自然度,具有自然对话韵律的拟人对话语音。 - **双说话人零样本声音克隆**:MOSS-TTSD支持零样本双说话人克隆,按脚本精确进行角色/声线切换。只需要提供10到20秒的参考音频片段。 - **中英双语**:MOSS-TTSD支持中英两种语言的高表现力语音生成。 - **长音频生成**:得益于低码率Codec与训练框架优化,MOSS-TTSD在长音频生成场景进行了大量训练(训练最大长度达到1700s),能够单次生成超长音频。 @@ -34,7 +34,7 @@ MOSS-TTSD(text to spoken dialogue)是一个开源的中英双语口语对话 ## 最新动态 🚀 -- **[2025-10-31]** 我们发布了 MOSS-TTSD v0.7:显著提升了音质、声音克隆能力与稳定性,大幅拓展了单次生成长度(960s->1700s),更够比较稳定地根据说话人标签生成语音事件。 +- **[2025-11-01]** 我们发布了 MOSS-TTSD v0.7:显著提升了音质、声音克隆能力与稳定性,支持32khz高音质输出,并大幅拓展了单次生成长度(960s->1700s),更够比较稳定地根据说话人标签生成语音事件。 - **[2025-09-09]** 我们支持了 SGLang 推理引擎加速模型推理,最高可加速**16倍**。 - **[2025-08-25]** 我们发布了 32khz XY-Tokenizer。 - **[2025-08-12]** 我们支持了 MOSS-TTSD v0.5 的流式推理。 diff --git a/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml b/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml index ce539c3..76b7618 100644 --- a/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml +++ b/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml @@ -1,6 +1,20 @@ generator_params: input_sample_rate: 16000 output_sample_rate: 32000 + encoder_downsample_rate: 1280 + decoder_upsample_rate: 2560 + + feature_extractor_kwargs: + chunk_length: 30 + feature_size: 80 + hop_length: 160 + n_fft: 400 + n_samples: 480000 + nb_max_frames: 3000 + padding_side: right + padding_value: 0.0 + return_attention_mask: false + sampling_rate: 16000 # Codec / model architecture (inference required) semantic_encoder_kwargs: # 100hz -> 50hz diff --git a/XY_Tokenizer/xy_tokenizer/model.py b/XY_Tokenizer/xy_tokenizer/model.py index 4c51d5d..971ec93 100644 --- a/XY_Tokenizer/xy_tokenizer/model.py +++ b/XY_Tokenizer/xy_tokenizer/model.py @@ -17,8 +17,8 @@ def __init__(self, generator_params): self.input_sample_rate = generator_params['input_sample_rate'] self.output_sample_rate = generator_params['output_sample_rate'] - self.encoder_downsample_rate = 1280 - self.decoder_upsample_rate = int(self.output_sample_rate / 12.5) + self.encoder_downsample_rate = generator_params['encoder_downsample_rate'] + self.decoder_upsample_rate = generator_params['decoder_upsample_rate'] self.code_dim = generator_params['quantizer_kwargs']['input_dim'] ## Codec part @@ -49,22 +49,8 @@ def __init__(self, generator_params): self.enhanced_vocos = Vocos(**generator_params['vocos_kwargs']) ## Feature extractor - default_feature_extractor_kwargs = { - 'chunk_length': 30, - 'feature_size': 80, - 'hop_length': 160, - 'n_fft': 400, - 'n_samples': 480000, - 'nb_max_frames': 3000, - 'padding_side': 'right', - 'padding_value': 0.0, - 'return_attention_mask': False, - 'sampling_rate': self.input_sample_rate, - } fe_kwargs = generator_params.get('feature_extractor_kwargs', {}) - merged_fe_kwargs = {**default_feature_extractor_kwargs, **fe_kwargs} - merged_fe_kwargs['sampling_rate'] = self.input_sample_rate - self.feature_extractor = MelFeatureExtractor(**merged_fe_kwargs) + self.feature_extractor = MelFeatureExtractor(**fe_kwargs) @torch.inference_mode() def inference_tokenize(self, x, input_lengths): diff --git a/generation_utils.py b/generation_utils.py index b7c2d09..2746a70 100644 --- a/generation_utils.py +++ b/generation_utils.py @@ -361,7 +361,7 @@ def normalize_text(text: str) -> str: text = re.sub(r'\[(\d+)\]', r'[S\1]', text) # 
Remove decorative characters - remove_chars = "【】《》()『』「」~~-_" + remove_chars = "【】《》()『』「」" '"-_“”~~' # Use positive lookahead to split text by speaker tags (tags themselves are still preserved) segments = re.split(r'(?=\[S\d+\])', text.replace("\n", " ")) From 7d567c1b8b862fe50322e0d2311b0d4647e08ada Mon Sep 17 00:00:00 2001 From: xiami2019 <435350193@qq.com> Date: Mon, 3 Nov 2025 11:44:18 +0800 Subject: [PATCH 3/4] format --- .gitignore | 2 +- LICENSE | 2 +- XY_Tokenizer/.gitignore | 2 +- XY_Tokenizer/README.md | 2 +- XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml | 1 - .../config/xy_tokenizer_32k_config.yaml | 12 +- XY_Tokenizer/config/xy_tokenizer_config.yaml | 12 +- XY_Tokenizer/inference.py | 85 ++- XY_Tokenizer/requirements.txt | 2 +- XY_Tokenizer/utils/helpers.py | 108 ++-- XY_Tokenizer/xy_tokenizer/model.py | 385 +++++++----- .../xy_tokenizer/nn/feature_extractor.py | 82 ++- XY_Tokenizer/xy_tokenizer/nn/modules.py | 577 +++++++++++------- XY_Tokenizer/xy_tokenizer/nn/quantizer.py | 332 ++++++---- examples/example.txt | 2 +- examples/examples.jsonl | 2 +- examples/examples_only_text.jsonl | 2 +- examples/examples_single_reference.jsonl | 2 +- finetune/data_preprocess.py | 307 ++++++---- finetune/finetune.py | 269 +++++--- finetune/finetune_config.yaml | 4 +- finetune/finetune_workflow.py | 104 ++-- finetune/lora_config.yaml | 2 +- finetune/requirements_finetune.txt | 2 +- finetune/training_config.yaml | 2 +- generation_utils.py | 150 +++-- gradio_demo.py | 318 ++++++---- inference.py | 130 ++-- modeling_asteroid.py | 277 ++++++--- podcast_generate.py | 211 ++++--- requirements.txt | 2 +- streamer.py | 248 +++++--- use_api.py | 222 ++++--- 33 files changed, 2460 insertions(+), 1400 deletions(-) diff --git a/.gitignore b/.gitignore index b286de5..beed4f7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ *.pyc -outputs/* \ No newline at end of file +outputs/* diff --git a/LICENSE b/LICENSE index daab1b6..afe8540 100644 --- a/LICENSE +++ b/LICENSE @@ -198,4 +198,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file + limitations under the License. 
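PATCH 2/4 moves the resampling factors out of `model.py` and into the tokenizer config, so the relationship between the rates can be checked straight from the YAML. A minimal sketch, assuming the repository layout from this patch (run from the repo root, where `XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml` now exists):

```python
import yaml

# Read the generator parameters that PATCH 2/4 adds to the tokenizer config.
with open("XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml") as f:
    params = yaml.safe_load(f)["generator_params"]

# 16 kHz input / 1280 = 12.5 Hz token rate; 32 kHz output / 12.5 Hz = 2560 samples per token.
token_rate = params["input_sample_rate"] / params["encoder_downsample_rate"]
assert token_rate == 12.5
assert params["decoder_upsample_rate"] == params["output_sample_rate"] / token_rate
print(f"codec frame rate: {token_rate} Hz, samples per token: {params['decoder_upsample_rate']}")
```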
diff --git a/XY_Tokenizer/.gitignore b/XY_Tokenizer/.gitignore index 7804221..d60efb5 100644 --- a/XY_Tokenizer/.gitignore +++ b/XY_Tokenizer/.gitignore @@ -190,4 +190,4 @@ backup* output_wavs/ *.pt *.pth -output.log \ No newline at end of file +output.log diff --git a/XY_Tokenizer/README.md b/XY_Tokenizer/README.md index 1aecd1a..2387a53 100644 --- a/XY_Tokenizer/README.md +++ b/XY_Tokenizer/README.md @@ -69,4 +69,4 @@ The model processes audio through several stages: ## License -[Specify your license here] \ No newline at end of file +[Specify your license here] diff --git a/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml b/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml index 76b7618..b2de9dd 100644 --- a/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml +++ b/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml @@ -112,4 +112,3 @@ generator_params: n_fft: 1280 hop_size: 320 padding: "same" - diff --git a/XY_Tokenizer/config/xy_tokenizer_32k_config.yaml b/XY_Tokenizer/config/xy_tokenizer_32k_config.yaml index 141cdc6..8a33dcc 100644 --- a/XY_Tokenizer/config/xy_tokenizer_32k_config.yaml +++ b/XY_Tokenizer/config/xy_tokenizer_32k_config.yaml @@ -62,7 +62,7 @@ generator_params: activation_function: "gelu" - ## semantic & acoustic shared parameters + ## semantic & acoustic shared parameters pre_rvq_adapter_kwargs: # 50hz input_dim: 1536 output_dim: 768 @@ -71,7 +71,7 @@ generator_params: encoder_layers: 4 encoder_attention_heads: 12 encoder_ffn_dim: 3072 - + downsample_kwargs: # 50hz -> 12.5hz d_model: 768 avg_pooler: 4 @@ -85,7 +85,7 @@ generator_params: codebook_dim: 512 quantizer_dropout: 0.0 commitment: 1 - + post_rvq_adapter_kwargs: # 12.5hz input_dim: 3072 output_dim: 3072 @@ -98,7 +98,7 @@ generator_params: upsample_kwargs: # 12.5hz -> 50hz d_model: 768 stride: 4 - + ## acoustic channel acoustic_decoder_kwargs: # 50hz -> 100hz num_mel_bins: 80 @@ -113,7 +113,7 @@ generator_params: decoder_attention_heads: 12 decoder_ffn_dim: 3072 activation_function: "gelu" - + vocos_kwargs: # 100hz -> 32khz input_channels: 80 dim: 512 @@ -121,4 +121,4 @@ generator_params: num_layers: 30 n_fft: 1280 hop_size: 320 - padding: "same" \ No newline at end of file + padding: "same" diff --git a/XY_Tokenizer/config/xy_tokenizer_config.yaml b/XY_Tokenizer/config/xy_tokenizer_config.yaml index 331c5d7..7232d15 100644 --- a/XY_Tokenizer/config/xy_tokenizer_config.yaml +++ b/XY_Tokenizer/config/xy_tokenizer_config.yaml @@ -62,7 +62,7 @@ generator_params: activation_function: "gelu" - ## semantic & acoustic shared parameters + ## semantic & acoustic shared parameters pre_rvq_adapter_kwargs: # 50hz input_dim: 1536 output_dim: 768 @@ -71,7 +71,7 @@ generator_params: encoder_layers: 4 encoder_attention_heads: 12 encoder_ffn_dim: 3072 - + downsample_kwargs: # 50hz -> 12.5hz d_model: 768 avg_pooler: 4 @@ -85,7 +85,7 @@ generator_params: codebook_dim: 512 quantizer_dropout: 0.0 commitment: 1 - + post_rvq_adapter_kwargs: # 12.5hz input_dim: 3072 output_dim: 3072 @@ -98,7 +98,7 @@ generator_params: upsample_kwargs: # 12.5hz -> 50hz d_model: 768 stride: 4 - + ## acoustic channel acoustic_decoder_kwargs: # 50hz -> 100hz num_mel_bins: 80 @@ -113,7 +113,7 @@ generator_params: decoder_attention_heads: 12 decoder_ffn_dim: 3072 activation_function: "gelu" - + vocos_kwargs: # 100hz -> 24khz input_channels: 80 dim: 512 @@ -121,4 +121,4 @@ generator_params: num_layers: 30 n_fft: 960 hop_size: 240 - padding: "same" \ No newline at end of file + padding: "same" diff --git a/XY_Tokenizer/inference.py b/XY_Tokenizer/inference.py index 689893a..d01ea7a 
100644 --- a/XY_Tokenizer/inference.py +++ b/XY_Tokenizer/inference.py @@ -1,27 +1,40 @@ -import os import argparse import logging -import torch +import os -from utils.helpers import set_logging, waiting_for_debug, load_audio, save_audio, find_audio_files +import torch +from utils.helpers import ( + find_audio_files, + load_audio, + save_audio, + set_logging, + waiting_for_debug, +) from xy_tokenizer.model import XY_Tokenizer if __name__ == "__main__": set_logging() - + parser = argparse.ArgumentParser() - parser.add_argument("--config_path", type=str, default="./config/xy_tokenizer_32k_config.yaml") - parser.add_argument("--checkpoint_path", type=str, default="./weights/xy_tokenizer.ckpt") + parser.add_argument( + "--config_path", type=str, default="./config/xy_tokenizer_32k_config.yaml" + ) + parser.add_argument( + "--checkpoint_path", type=str, default="./weights/xy_tokenizer.ckpt" + ) parser.add_argument("--device", type=str, default="cuda") - + parser.add_argument("--input_dir", type=str, required=True) parser.add_argument("--output_dir", type=str, required=True) - - + parser.add_argument("--debug_ip", type=str) parser.add_argument("--debug_port", type=int) - parser.add_argument("--debug", default=0, type=int, nargs="?", - help='whether debug or not', + parser.add_argument( + "--debug", + default=0, + type=int, + nargs="?", + help="whether debug or not", ) args = parser.parse_args() if args.debug == 1: @@ -30,42 +43,66 @@ device = torch.device(args.device) ## Load codec model - generator = XY_Tokenizer.load_from_checkpoint(config_path=args.config_path, ckpt_path=args.checkpoint_path).to(device).eval() - + generator = ( + XY_Tokenizer.load_from_checkpoint( + config_path=args.config_path, ckpt_path=args.checkpoint_path + ) + .to(device) + .eval() + ) + ## Find audios audio_paths = find_audio_files(input_dir=args.input_dir) - + ## Create output directory if not exists os.makedirs(args.output_dir, exist_ok=True) - logging.info(f"Processing {len(audio_paths)} audio files, output will be saved to {args.output_dir}") + logging.info( + f"Processing {len(audio_paths)} audio files, output will be saved to {args.output_dir}" + ) with torch.no_grad(): ## Process audios in batches batch_size = 8 for i in range(0, len(audio_paths), batch_size): - batch_paths = audio_paths[i:i + batch_size] - logging.info(f"Processing batch {i // batch_size + 1}/{len(audio_paths) // batch_size + 1}, files: {batch_paths}") + batch_paths = audio_paths[i : i + batch_size] + logging.info( + f"Processing batch {i // batch_size + 1}/{len(audio_paths) // batch_size + 1}, files: {batch_paths}" + ) # Load audio files - wav_list = [load_audio(path, target_sample_rate=generator.input_sample_rate).squeeze().to(device) for path in batch_paths] - logging.info(f"Successfully loaded {len(wav_list)} audio files with lengths {[len(wav) for wav in wav_list]} samples") + wav_list = [ + load_audio(path, target_sample_rate=generator.input_sample_rate) + .squeeze() + .to(device) + for path in batch_paths + ] + logging.info( + f"Successfully loaded {len(wav_list)} audio files with lengths {[len(wav) for wav in wav_list]} samples" + ) # Encode encode_result = generator.encode(wav_list, overlap_seconds=10) codes_list = encode_result["codes_list"] # B * (nq, T) - logging.info(f"Encoding completed, code lengths: {[codes.shape[-1] for codes in codes_list]}") + logging.info( + f"Encoding completed, code lengths: {[codes.shape[-1] for codes in codes_list]}" + ) logging.info(f"{codes_list = }") # Decode decode_result = generator.decode(codes_list, 
overlap_seconds=10) syn_wav_list = decode_result["syn_wav_list"] # B * (T,) - logging.info(f"Decoding completed, generated waveform lengths: {[len(wav) for wav in syn_wav_list]} samples") + logging.info( + f"Decoding completed, generated waveform lengths: {[len(wav) for wav in syn_wav_list]} samples" + ) # Save generated audios for path, syn_wav in zip(batch_paths, syn_wav_list): output_path = os.path.join(args.output_dir, os.path.basename(path)) - save_audio(output_path, syn_wav.cpu().reshape(1, -1), sample_rate=generator.output_sample_rate) + save_audio( + output_path, + syn_wav.cpu().reshape(1, -1), + sample_rate=generator.output_sample_rate, + ) logging.info(f"Saved generated audio to {output_path}") - - logging.info("All audio processing completed") \ No newline at end of file + logging.info("All audio processing completed") diff --git a/XY_Tokenizer/requirements.txt b/XY_Tokenizer/requirements.txt index fba07a1..c16cb4f 100644 --- a/XY_Tokenizer/requirements.txt +++ b/XY_Tokenizer/requirements.txt @@ -20,4 +20,4 @@ stopes s3prl onnxscript jiwer -orjson \ No newline at end of file +orjson diff --git a/XY_Tokenizer/utils/helpers.py b/XY_Tokenizer/utils/helpers.py index 9b144a4..e7ef1c8 100644 --- a/XY_Tokenizer/utils/helpers.py +++ b/XY_Tokenizer/utils/helpers.py @@ -1,61 +1,72 @@ +import glob import logging -import torchaudio import os +import re import sys -import glob + import debugpy -import torch import numpy as np -import re +import torch +import torchaudio + def count_params_by_module(model_name, model): logging.info(f"Counting num_parameters of {model_name}:") - + param_stats = {} total_params = 0 # Count total parameters total_requires_grad_params = 0 # Count parameters with requires_grad=True total_no_grad_params = 0 # Count parameters with requires_grad=False - + for name, param in model.named_parameters(): - module_name = name.split('.')[0] + module_name = name.split(".")[0] if module_name not in param_stats: - param_stats[module_name] = {'total': 0, 'requires_grad': 0, 'no_grad': 0} - + param_stats[module_name] = {"total": 0, "requires_grad": 0, "no_grad": 0} + param_num = param.numel() - param_stats[module_name]['total'] += param_num + param_stats[module_name]["total"] += param_num total_params += param_num - + if param.requires_grad: - param_stats[module_name]['requires_grad'] += param_num + param_stats[module_name]["requires_grad"] += param_num total_requires_grad_params += param_num else: - param_stats[module_name]['no_grad'] += param_num + param_stats[module_name]["no_grad"] += param_num total_no_grad_params += param_num - + # Calculate maximum width for each column max_module_name_length = max(len(module) for module in param_stats) - max_param_length = max(len(f"{stats['total'] / 1e6:.2f}M") for stats in param_stats.values()) - + max_param_length = max( + len(f"{stats['total'] / 1e6:.2f}M") for stats in param_stats.values() + ) + # Output parameter statistics for each module for module, stats in param_stats.items(): - logging.info(f"\t{module:<{max_module_name_length}}: " - f"Total: {stats['total'] / 1e6:<{max_param_length}.2f}M, " - f"Requires Grad: {stats['requires_grad'] / 1e6:<{max_param_length}.2f}M, " - f"No Grad: {stats['no_grad'] / 1e6:<{max_param_length}.2f}M") - + logging.info( + f"\t{module:<{max_module_name_length}}: " + f"Total: {stats['total'] / 1e6:<{max_param_length}.2f}M, " + f"Requires Grad: {stats['requires_grad'] / 1e6:<{max_param_length}.2f}M, " + f"No Grad: {stats['no_grad'] / 1e6:<{max_param_length}.2f}M" + ) + # Output total parameter 
statistics logging.info(f"\tTotal parameters: {total_params / 1e6:.2f}M parameters") - logging.info(f"\tRequires Grad parameters: {total_requires_grad_params / 1e6:.2f}M parameters") + logging.info( + f"\tRequires Grad parameters: {total_requires_grad_params / 1e6:.2f}M parameters" + ) logging.info(f"\tNo Grad parameters: {total_no_grad_params / 1e6:.2f}M parameters") logging.info(f"################################################################") def load_and_resample_audio(audio_path, target_sample_rate): - wav, raw_sample_rate = torchaudio.load(audio_path) # (1, T) tensor - if raw_sample_rate != target_sample_rate: - wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate) # tensor + wav, raw_sample_rate = torchaudio.load(audio_path) # (1, T) tensor + if raw_sample_rate != target_sample_rate: + wav = torchaudio.functional.resample( + wav, raw_sample_rate, target_sample_rate + ) # tensor return wav.squeeze() + def set_logging(): rank = os.environ.get("RANK", 0) logging.basicConfig( @@ -63,55 +74,64 @@ def set_logging(): stream=sys.stdout, format=f"%(asctime)s [RANK {rank}] (%(module)s:%(lineno)d) %(levelname)s : %(message)s", ) - + + def waiting_for_debug(ip, port): rank = os.environ.get("RANK", "0") - debugpy.listen((ip, port)) # Replace localhost with cluster node IP + debugpy.listen((ip, port)) # Replace localhost with cluster node IP logging.info(f"[rank = {rank}] Waiting for debugger attach...") debugpy.wait_for_client() logging.info(f"[rank = {rank}] Debugger attached") - + + def load_audio(audio_path, target_sample_rate): # Load audio file, wav shape: (channels, time) wav, raw_sample_rate = torchaudio.load(audio_path) - + # If multi-channel, convert to mono by averaging across channels if wav.shape[0] > 1: - wav = torch.mean(wav, dim=0, keepdim=True) # Average across channels, keep channel dim - + wav = torch.mean( + wav, dim=0, keepdim=True + ) # Average across channels, keep channel dim + # Resample if necessary if raw_sample_rate != target_sample_rate: wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate) - + # Convert to numpy, add channel dimension, then back to tensor with desired shape wav = np.expand_dims(wav.squeeze(0).numpy(), axis=1) # Shape: (time, 1) wav = torch.tensor(wav).reshape(1, 1, -1) # Shape: (1, 1, time) - + return wav + def save_audio(audio_outpath, audio_out, sample_rate): torchaudio.save( - audio_outpath, - audio_out, - sample_rate=sample_rate, - encoding='PCM_S', - bits_per_sample=16 + audio_outpath, + audio_out, + sample_rate=sample_rate, + encoding="PCM_S", + bits_per_sample=16, ) logging.info(f"Successfully saved audio at {audio_outpath}") - + + def find_audio_files(input_dir): - audio_extensions = ['*.flac', '*.mp3', '*.wav'] + audio_extensions = ["*.flac", "*.mp3", "*.wav"] audios_input = [] for ext in audio_extensions: - audios_input.extend(glob.glob(os.path.join(input_dir, '**', ext), recursive=True)) + audios_input.extend( + glob.glob(os.path.join(input_dir, "**", ext), recursive=True) + ) logging.info(f"Found {len(audios_input)} audio files in {input_dir}") return sorted(audios_input) + def normalize_text(text): # Remove all punctuation (including English and Chinese punctuation) - text = re.sub(r'[^\w\s\u4e00-\u9fff]', '', text) + text = re.sub(r"[^\w\s\u4e00-\u9fff]", "", text) # Convert to lowercase (effective for English, no effect on Chinese) text = text.lower() # Remove extra spaces - text = ' '.join(text.split()) - return text \ No newline at end of file + text = " ".join(text.split()) + 
return text diff --git a/XY_Tokenizer/xy_tokenizer/model.py b/XY_Tokenizer/xy_tokenizer/model.py index 971ec93..5b73d0a 100644 --- a/XY_Tokenizer/xy_tokenizer/model.py +++ b/XY_Tokenizer/xy_tokenizer/model.py @@ -1,148 +1,198 @@ # -*- coding: utf-8 -*- -import yaml import logging + import torch import torch.nn as nn import torch.nn.functional as F - +import yaml from .nn.feature_extractor import MelFeatureExtractor -from .nn.modules import OmniAudioEncoder, OmniAudioDecoder, ResidualDownConv, UpConv, Transformer, Vocos +from .nn.modules import ( + OmniAudioDecoder, + OmniAudioEncoder, + ResidualDownConv, + Transformer, + UpConv, + Vocos, +) from .nn.quantizer import ResidualVQ + class XY_Tokenizer(nn.Module): def __init__(self, generator_params): super().__init__() # Basic parameters - self.input_sample_rate = generator_params['input_sample_rate'] - self.output_sample_rate = generator_params['output_sample_rate'] - - self.encoder_downsample_rate = generator_params['encoder_downsample_rate'] - self.decoder_upsample_rate = generator_params['decoder_upsample_rate'] - self.code_dim = generator_params['quantizer_kwargs']['input_dim'] - + self.input_sample_rate = generator_params["input_sample_rate"] + self.output_sample_rate = generator_params["output_sample_rate"] + + self.encoder_downsample_rate = generator_params["encoder_downsample_rate"] + self.decoder_upsample_rate = generator_params["decoder_upsample_rate"] + self.code_dim = generator_params["quantizer_kwargs"]["input_dim"] + ## Codec part ## Semantic channel - self.semantic_encoder = OmniAudioEncoder(**generator_params['semantic_encoder_kwargs']) - - self.semantic_encoder_adapter = Transformer(**generator_params['semantic_encoder_adapter_kwargs']) - + self.semantic_encoder = OmniAudioEncoder( + **generator_params["semantic_encoder_kwargs"] + ) + + self.semantic_encoder_adapter = Transformer( + **generator_params["semantic_encoder_adapter_kwargs"] + ) + ## Acoustic channel - self.acoustic_encoder = OmniAudioEncoder(**generator_params['acoustic_encoder_kwargs']) - + self.acoustic_encoder = OmniAudioEncoder( + **generator_params["acoustic_encoder_kwargs"] + ) + ## Semantic & acoustic shared parameters - self.pre_rvq_adapter = Transformer(**generator_params['pre_rvq_adapter_kwargs']) - - self.downsample = ResidualDownConv(**generator_params['downsample_kwargs']) - - self.quantizer = ResidualVQ(**generator_params['quantizer_kwargs']) - self.nq = generator_params['quantizer_kwargs']['num_quantizers'] - - self.post_rvq_adapter = Transformer(**generator_params['post_rvq_adapter_kwargs']) - + self.pre_rvq_adapter = Transformer(**generator_params["pre_rvq_adapter_kwargs"]) + + self.downsample = ResidualDownConv(**generator_params["downsample_kwargs"]) + + self.quantizer = ResidualVQ(**generator_params["quantizer_kwargs"]) + self.nq = generator_params["quantizer_kwargs"]["num_quantizers"] + + self.post_rvq_adapter = Transformer( + **generator_params["post_rvq_adapter_kwargs"] + ) + ## Acoustic channel - self.upsample = UpConv(**generator_params['upsample_kwargs']) + self.upsample = UpConv(**generator_params["upsample_kwargs"]) - self.acoustic_decoder = OmniAudioDecoder(**generator_params['acoustic_decoder_kwargs']) + self.acoustic_decoder = OmniAudioDecoder( + **generator_params["acoustic_decoder_kwargs"] + ) - self.enhanced_vocos = Vocos(**generator_params['vocos_kwargs']) + self.enhanced_vocos = Vocos(**generator_params["vocos_kwargs"]) ## Feature extractor - fe_kwargs = generator_params.get('feature_extractor_kwargs', {}) + fe_kwargs = 
generator_params.get("feature_extractor_kwargs", {}) self.feature_extractor = MelFeatureExtractor(**fe_kwargs) @torch.inference_mode() def inference_tokenize(self, x, input_lengths): """ - Input: - x: Waveform tensor # (B, 1, T), T <= 30s * sample_rate - input_lengths: Valid length for each sample # (B,) - Output: - dict: Contains the following key-value pairs - "zq": Quantized embeddings # (B, D, T) - "codes": Quantization codes # (nq, B, T) - "codes_lengths": Quantization code lengths # (B,) + Input: + x: Waveform tensor # (B, 1, T), T <= 30s * sample_rate + input_lengths: Valid length for each sample # (B,) + Output: + dict: Contains the following key-value pairs + "zq": Quantized embeddings # (B, D, T) + "codes": Quantization codes # (nq, B, T) + "codes_lengths": Quantization code lengths # (B,) """ - list_x = [xi[:, :x_len].reshape(-1).cpu().numpy() for xi, x_len in zip(x, input_lengths)] + list_x = [ + xi[:, :x_len].reshape(-1).cpu().numpy() + for xi, x_len in zip(x, input_lengths) + ] features = self.feature_extractor( list_x, sampling_rate=self.input_sample_rate, return_tensors="pt", - return_attention_mask=True + return_attention_mask=True, ) - input_mel = features['input_features'].to(x.device).to(x.dtype) # (B, D, 3000) - audio_attention_mask = features['attention_mask'].to(x.device) # (B, 3000) - + input_mel = features["input_features"].to(x.device).to(x.dtype) # (B, D, 3000) + audio_attention_mask = features["attention_mask"].to(x.device) # (B, 3000) + # Get batch size and sequence length of the input - mel_output_length = torch.sum(audio_attention_mask, dim=-1).long() # (B,) - + mel_output_length = torch.sum(audio_attention_mask, dim=-1).long() # (B,) + # Semantic channel - semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(input_mel, mel_output_length) # (B, D, T), 100hz -> 50hz - - semantic_encoder_adapter_output, semantic_encoder_adapter_output_length = self.semantic_encoder_adapter(semantic_encoder_output, semantic_encoder_output_length) # (B, D, T), 50hz - + semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder( + input_mel, mel_output_length + ) # (B, D, T), 100hz -> 50hz + + semantic_encoder_adapter_output, semantic_encoder_adapter_output_length = ( + self.semantic_encoder_adapter( + semantic_encoder_output, semantic_encoder_output_length + ) + ) # (B, D, T), 50hz + # Acoustic channel - acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder(input_mel, mel_output_length) # (B, D, T), 100hz -> 50hz - + acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder( + input_mel, mel_output_length + ) # (B, D, T), 100hz -> 50hz + # Semantic & acoustic mixing - concated_semantic_acoustic_channel = torch.concat([semantic_encoder_adapter_output, acoustic_encoder_output], dim=1) # (B, D, T) + concated_semantic_acoustic_channel = torch.concat( + [semantic_encoder_adapter_output, acoustic_encoder_output], dim=1 + ) # (B, D, T) concated_semantic_acoustic_channel_length = acoustic_encoder_output_length - - pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(concated_semantic_acoustic_channel, concated_semantic_acoustic_channel_length) # (B, D, T), 50hz - - downsample_output, downsample_output_length = self.downsample(pre_rvq_adapter_output, pre_rvq_adapter_output_length) # (B, D, T), 50hz -> 12.5hz - zq, codes, vq_loss, _, quantizer_output_length = self.quantizer(downsample_output, downsample_output_length) # (B, D, T), (nq, B, T), (nq,), (nq, B, D, T), (B,) + 
pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter( + concated_semantic_acoustic_channel, + concated_semantic_acoustic_channel_length, + ) # (B, D, T), 50hz + + downsample_output, downsample_output_length = self.downsample( + pre_rvq_adapter_output, pre_rvq_adapter_output_length + ) # (B, D, T), 50hz -> 12.5hz + + zq, codes, vq_loss, _, quantizer_output_length = self.quantizer( + downsample_output, downsample_output_length + ) # (B, D, T), (nq, B, T), (nq,), (nq, B, D, T), (B,) return { - "zq": zq, # (B, D, T) - "codes": codes, # (nq, B, T) - "codes_lengths": quantizer_output_length # (B,) + "zq": zq, # (B, D, T) + "codes": codes, # (nq, B, T) + "codes_lengths": quantizer_output_length, # (B,) } - - @torch.inference_mode() + + @torch.inference_mode() def inference_detokenize(self, codes, codes_lengths): """ - Input: - codes: Quantization codes # (nq, B, T) - codes_lengths: Quantization code lengths for each sample # (B,) - Output: - dict: Contains the following key-value pairs - "y": Synthesized audio waveform # (B, 1, T) - "output_length": Output lengths # (B,) + Input: + codes: Quantization codes # (nq, B, T) + codes_lengths: Quantization code lengths for each sample # (B,) + Output: + dict: Contains the following key-value pairs + "y": Synthesized audio waveform # (B, 1, T) + "output_length": Output lengths # (B,) """ - zq = self.quantizer.decode_codes(codes) # (B, D, T) - - post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter(zq, codes_lengths) # (B, D, T), 12.5hz - - # Acoustic channel - upsample_output, upsample_output_length = self.upsample(post_rvq_adapter_output, post_rvq_adapter_output_length) # (B, D, T), 12.5hz -> 50hz + zq = self.quantizer.decode_codes(codes) # (B, D, T) + + post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter( + zq, codes_lengths + ) # (B, D, T), 12.5hz - acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder(upsample_output, upsample_output_length) # (B, D, T), 50hz -> 100hz + # Acoustic channel + upsample_output, upsample_output_length = self.upsample( + post_rvq_adapter_output, post_rvq_adapter_output_length + ) # (B, D, T), 12.5hz -> 50hz + + acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder( + upsample_output, upsample_output_length + ) # (B, D, T), 50hz -> 100hz + + y, vocos_output_length = self.enhanced_vocos( + acoustic_decoder_output, acoustic_decoder_output_length + ) # (B, 1, T), 100hz -> 16khz - y, vocos_output_length = self.enhanced_vocos(acoustic_decoder_output, acoustic_decoder_output_length) # (B, 1, T), 100hz -> 16khz - return { - "y": y, # (B, 1, T) - "output_length": vocos_output_length, # (B,) + "y": y, # (B, 1, T) + "output_length": vocos_output_length, # (B,) } - + @torch.inference_mode() def encode(self, wav_list, overlap_seconds=10): """ - Input: - wav_list: List of audio waveforms, each with potentially different length, may exceed 30 seconds # B * (T,) - overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output - Output: - dict: Contains the following key-value pairs - "codes_list": List of quantization codes # B * (nq, T) + Input: + wav_list: List of audio waveforms, each with potentially different length, may exceed 30 seconds # B * (T,) + overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output + Output: + dict: Contains the following key-value pairs + "codes_list": List 
of quantization codes # B * (nq, T) """ device = wav_list[0].device duration_seconds = 30 - overlap_seconds - chunk_size = int(30 * self.input_sample_rate) # Maximum samples per chunk - duration_size = int(duration_seconds * self.input_sample_rate) # Valid output samples per chunk - code_duration_length = duration_size // self.encoder_downsample_rate # Valid code length per chunk + chunk_size = int(30 * self.input_sample_rate) # Maximum samples per chunk + duration_size = int( + duration_seconds * self.input_sample_rate + ) # Valid output samples per chunk + code_duration_length = ( + duration_size // self.encoder_downsample_rate + ) # Valid code length per chunk # Get maximum waveform length max_length = max(len(wav) for wav in wav_list) @@ -150,8 +200,8 @@ def encode(self, wav_list, overlap_seconds=10): wav_tensor = torch.zeros(batch_size, 1, max_length, device=device) input_lengths = torch.zeros(batch_size, dtype=torch.long, device=device) for i, wav in enumerate(wav_list): - wav_tensor[i, 0, :len(wav)] = wav - input_lengths[i] = len(wav) # (B,) + wav_tensor[i, 0, : len(wav)] = wav + input_lengths[i] = len(wav) # (B,) # Calculate number of chunks needed max_chunks = (max_length + duration_size - 1) // duration_size @@ -161,122 +211,161 @@ def encode(self, wav_list, overlap_seconds=10): for chunk_idx in range(max_chunks): start = chunk_idx * duration_size end = min(start + chunk_size, max_length) - chunk = wav_tensor[:, :, start:end] # (B, 1, T') - chunk_lengths = torch.clamp(input_lengths - start, 0, end - start) # (B,) + chunk = wav_tensor[:, :, start:end] # (B, 1, T') + chunk_lengths = torch.clamp(input_lengths - start, 0, end - start) # (B,) # Skip empty chunks if chunk_lengths.max() == 0: continue # Encode - result = self.inference_tokenize(chunk, chunk_lengths) # {"zq": (B, D, T'), "codes": (nq, B, T'), "codes_lengths": (B,)} - chunk_codes = result["codes"] # (nq, B, T') - chunk_code_lengths = result["codes_lengths"] # (B,) + result = self.inference_tokenize( + chunk, chunk_lengths + ) # {"zq": (B, D, T'), "codes": (nq, B, T'), "codes_lengths": (B,)} + chunk_codes = result["codes"] # (nq, B, T') + chunk_code_lengths = result["codes_lengths"] # (B,) # Extract valid portion - valid_code_lengths = torch.clamp(chunk_code_lengths, 0, code_duration_length) # (B,) - valid_chunk_codes = torch.zeros(self.nq, batch_size, code_duration_length, device=device, dtype=chunk_codes.dtype) + valid_code_lengths = torch.clamp( + chunk_code_lengths, 0, code_duration_length + ) # (B,) + valid_chunk_codes = torch.zeros( + self.nq, + batch_size, + code_duration_length, + device=device, + dtype=chunk_codes.dtype, + ) for b in range(batch_size): if valid_code_lengths[b] > 0: - valid_chunk_codes[:, b, :valid_code_lengths[b]] = chunk_codes[:, b, :valid_code_lengths[b]] # (nq, B, valid_code_length) + valid_chunk_codes[:, b, : valid_code_lengths[b]] = chunk_codes[ + :, b, : valid_code_lengths[b] + ] # (nq, B, valid_code_length) - codes_list.append(valid_chunk_codes) # (nq, B, valid_code_length) + codes_list.append(valid_chunk_codes) # (nq, B, valid_code_length) # Concatenate all chunks if codes_list: - codes_tensor = torch.cat(codes_list, dim=-1) # (nq, B, T_total) - codes_list = [codes_tensor[:, i, :input_lengths[i] // self.encoder_downsample_rate] for i in range(batch_size)] # B * (nq, T) + codes_tensor = torch.cat(codes_list, dim=-1) # (nq, B, T_total) + codes_list = [ + codes_tensor[:, i, : input_lengths[i] // self.encoder_downsample_rate] + for i in range(batch_size) + ] # B * (nq, T) else: - codes_list = 
[torch.zeros(self.nq, 0, device=device, dtype=torch.long) for _ in range(batch_size)] # B * (nq, 0) + codes_list = [ + torch.zeros(self.nq, 0, device=device, dtype=torch.long) + for _ in range(batch_size) + ] # B * (nq, 0) + + return {"codes_list": codes_list} # B * (nq, T) - return { - "codes_list": codes_list # B * (nq, T) - } - @torch.inference_mode() def decode(self, codes_list, overlap_seconds=10): """ - Input: - codes_list: List of quantization codes # B * (nq, T) - overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output - Output: - dict: Contains the following key-value pairs - "syn_wav_list": List of synthesized audio waveforms # B * (T,) + Input: + codes_list: List of quantization codes # B * (nq, T) + overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output + Output: + dict: Contains the following key-value pairs + "syn_wav_list": List of synthesized audio waveforms # B * (T,) """ device = codes_list[0].device duration_seconds = 30 - overlap_seconds - chunk_code_length = int(30 * self.input_sample_rate // self.encoder_downsample_rate) # Maximum code length per chunk - duration_code_length = int(duration_seconds * self.input_sample_rate // self.encoder_downsample_rate) # Valid code length per chunk - duration_wav_length = duration_code_length * self.decoder_upsample_rate # Valid waveform length per chunk + chunk_code_length = int( + 30 * self.input_sample_rate // self.encoder_downsample_rate + ) # Maximum code length per chunk + duration_code_length = int( + duration_seconds * self.input_sample_rate // self.encoder_downsample_rate + ) # Valid code length per chunk + duration_wav_length = ( + duration_code_length * self.decoder_upsample_rate + ) # Valid waveform length per chunk # Get maximum code length max_code_length = max(codes.shape[-1] for codes in codes_list) batch_size = len(codes_list) - codes_tensor = torch.zeros(self.nq, batch_size, max_code_length, device=device, dtype=torch.long) + codes_tensor = torch.zeros( + self.nq, batch_size, max_code_length, device=device, dtype=torch.long + ) code_lengths = torch.zeros(batch_size, dtype=torch.long, device=device) for i, codes in enumerate(codes_list): - codes_tensor[:, i, :codes.shape[-1]] = codes.to(device) - code_lengths[i] = codes.shape[-1] # (B,) + codes_tensor[:, i, : codes.shape[-1]] = codes.to(device) + code_lengths[i] = codes.shape[-1] # (B,) # Calculate number of chunks needed - max_chunks = (max_code_length + duration_code_length - 1) // duration_code_length + max_chunks = ( + max_code_length + duration_code_length - 1 + ) // duration_code_length wav_list = [] # Process the entire batch in chunks for chunk_idx in range(max_chunks): start = chunk_idx * duration_code_length end = min(start + chunk_code_length, max_code_length) - chunk_codes = codes_tensor[:, :, start:end] # (nq, B, T') - chunk_code_lengths = torch.clamp(code_lengths - start, 0, end - start) # (B,) + chunk_codes = codes_tensor[:, :, start:end] # (nq, B, T') + chunk_code_lengths = torch.clamp( + code_lengths - start, 0, end - start + ) # (B,) # Skip empty chunks if chunk_code_lengths.max() == 0: continue # Decode - result = self.inference_detokenize(chunk_codes, chunk_code_lengths) # {"y": (B, 1, T'), "output_length": (B,)} - chunk_wav = result["y"] # (B, 1, T') - chunk_wav_lengths = result["output_length"] # (B,) + result = self.inference_detokenize( + chunk_codes, chunk_code_lengths + ) # {"y": (B, 1, T'), "output_length": 
(B,)} + chunk_wav = result["y"] # (B, 1, T') + chunk_wav_lengths = result["output_length"] # (B,) # Extract valid portion - valid_wav_lengths = torch.clamp(chunk_wav_lengths, 0, duration_wav_length) # (B,) - valid_chunk_wav = torch.zeros(batch_size, 1, duration_wav_length, device=device) + valid_wav_lengths = torch.clamp( + chunk_wav_lengths, 0, duration_wav_length + ) # (B,) + valid_chunk_wav = torch.zeros( + batch_size, 1, duration_wav_length, device=device + ) for b in range(batch_size): if valid_wav_lengths[b] > 0: - valid_chunk_wav[b, :, :valid_wav_lengths[b]] = chunk_wav[b, :, :valid_wav_lengths[b]] # (B, 1, valid_wav_length) + valid_chunk_wav[b, :, : valid_wav_lengths[b]] = chunk_wav[ + b, :, : valid_wav_lengths[b] + ] # (B, 1, valid_wav_length) - wav_list.append(valid_chunk_wav) # (B, 1, valid_wav_length) + wav_list.append(valid_chunk_wav) # (B, 1, valid_wav_length) # Concatenate all chunks if wav_list: - wav_tensor = torch.cat(wav_list, dim=-1) # (B, 1, T_total) - syn_wav_list = [wav_tensor[i, 0, :code_lengths[i] * self.decoder_upsample_rate] for i in range(batch_size)] # B * (T,) + wav_tensor = torch.cat(wav_list, dim=-1) # (B, 1, T_total) + syn_wav_list = [ + wav_tensor[i, 0, : code_lengths[i] * self.decoder_upsample_rate] + for i in range(batch_size) + ] # B * (T,) else: - syn_wav_list = [torch.zeros(0, device=device) for _ in range(batch_size)] # B * (0,) - - return { - "syn_wav_list": syn_wav_list # B * (T,) - } - + syn_wav_list = [ + torch.zeros(0, device=device) for _ in range(batch_size) + ] # B * (0,) + + return {"syn_wav_list": syn_wav_list} # B * (T,) + @classmethod def load_from_checkpoint(cls, config_path: str, ckpt_path: str): # Load model from configuration file and checkpoint logging.info(f"Loading model from {config_path} and {ckpt_path}") - + # Load configuration - with open(config_path, 'r') as f: + with open(config_path, "r") as f: config = yaml.safe_load(f) - + # Create model instance - model = cls(config['generator_params']) - + model = cls(config["generator_params"]) + # Load checkpoint - checkpoint = torch.load(ckpt_path, map_location='cpu') - + checkpoint = torch.load(ckpt_path, map_location="cpu") + # Check if checkpoint contains 'generator' key - if 'generator' in checkpoint: - model.load_state_dict(checkpoint['generator']) + if "generator" in checkpoint: + model.load_state_dict(checkpoint["generator"]) else: model.load_state_dict(checkpoint) - - return model \ No newline at end of file + + return model diff --git a/XY_Tokenizer/xy_tokenizer/nn/feature_extractor.py b/XY_Tokenizer/xy_tokenizer/nn/feature_extractor.py index 4d397b0..ccb14a6 100644 --- a/XY_Tokenizer/xy_tokenizer/nn/feature_extractor.py +++ b/XY_Tokenizer/xy_tokenizer/nn/feature_extractor.py @@ -1,16 +1,17 @@ +from typing import List, Optional, Union + import numpy as np import torch - -from typing import Union, List, Optional +from transformers.audio_utils import mel_filter_bank, spectrogram, window_function from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor from transformers.feature_extraction_utils import BatchFeature from transformers.utils import TensorType, logging from transformers.utils.import_utils import is_torch_available -from transformers.audio_utils import mel_filter_bank, spectrogram, window_function + class MelFeatureExtractor(SequenceFeatureExtractor): model_input_names = ["input_features"] - + def __init__( self, feature_size=80, @@ -38,7 +39,9 @@ def __init__( self.nb_max_frames = self.n_samples // hop_length self.sampling_rate = 
sampling_rate self.dither = dither - self.max_frequency = max_frequency if max_frequency is not None else sampling_rate / 2 + self.max_frequency = ( + max_frequency if max_frequency is not None else sampling_rate / 2 + ) self.mel_filters = mel_filter_bank( num_frequency_bins=1 + n_fft // 2, num_mel_filters=feature_size, @@ -49,7 +52,9 @@ def __init__( mel_scale="slaney", ) - def _np_extract_fbank_features(self, waveform_batch: np.array, device: str) -> np.ndarray: + def _np_extract_fbank_features( + self, waveform_batch: np.array, device: str + ) -> np.ndarray: if device != "cpu": raise ValueError( f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator " @@ -75,7 +80,9 @@ def _np_extract_fbank_features(self, waveform_batch: np.array, device: str) -> n log_spec_batch = np.array(log_spec_batch) return log_spec_batch - def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu") -> np.ndarray: + def _torch_extract_fbank_features( + self, waveform: np.array, device: str = "cpu" + ) -> np.ndarray: """ Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching, yielding results similar to cpu computing with 1e-5 tolerance. @@ -84,9 +91,13 @@ def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu") window = torch.hann_window(self.n_fft, device=device) if self.dither != 0.0: - waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device) + waveform += self.dither * torch.randn( + waveform.shape, dtype=waveform.dtype, device=waveform.device + ) - stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True) + stft = torch.stft( + waveform, self.n_fft, self.hop_length, window=window, return_complex=True + ) magnitudes = stft[..., :-1].abs() ** 2 mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) @@ -105,7 +116,9 @@ def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu") @staticmethod def zero_mean_unit_var_norm( - input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0 + input_values: List[np.ndarray], + attention_mask: List[np.ndarray], + padding_value: float = 0.0, ) -> List[np.ndarray]: """ Every array in the list is normalized to have zero mean and unit variance @@ -115,13 +128,17 @@ def zero_mean_unit_var_norm( normed_input_values = [] for vector, length in zip(input_values, attention_mask.sum(-1)): - normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7) + normed_slice = (vector - vector[:length].mean()) / np.sqrt( + vector[:length].var() + 1e-7 + ) if length < normed_slice.shape[0]: normed_slice[length:] = padding_value normed_input_values.append(normed_slice) else: - normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values] + normed_input_values = [ + (x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values + ] return normed_input_values @@ -177,18 +194,27 @@ def __call__( f"extractor's sampling rate to ensure correct feature extraction." 
) - is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + is_batched_numpy = ( + isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + ) if is_batched_numpy and len(raw_speech.shape) > 2: - raise ValueError(f"Only mono-channel audio is supported for input to {self}") + raise ValueError( + f"Only mono-channel audio is supported for input to {self}" + ) is_batched = is_batched_numpy or ( - isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + isinstance(raw_speech, (list, tuple)) + and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) ) if is_batched: - raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech] + raw_speech = [ + np.asarray([speech], dtype=np.float32).T for speech in raw_speech + ] elif not is_batched and not isinstance(raw_speech, np.ndarray): raw_speech = np.asarray(raw_speech, dtype=np.float32) - elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype( + np.float64 + ): raw_speech = raw_speech.astype(np.float32) if not is_batched: @@ -211,27 +237,37 @@ def __call__( attention_mask=padded_inputs["attention_mask"], padding_value=self.padding_value, ) - padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0) + padded_inputs["input_features"] = np.stack( + padded_inputs["input_features"], axis=0 + ) input_features = padded_inputs.get("input_features").transpose(2, 0, 1) extract_fbank_features = ( - self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features + self._torch_extract_fbank_features + if is_torch_available() + else self._np_extract_fbank_features ) input_features = extract_fbank_features(input_features[0], device) if isinstance(input_features[0], List): - padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features] + padded_inputs["input_features"] = [ + np.asarray(feature, dtype=np.float32) for feature in input_features + ] else: padded_inputs["input_features"] = input_features if return_attention_mask: - padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length] + padded_inputs["attention_mask"] = padded_inputs["attention_mask"][ + :, :: self.hop_length + ] if return_token_timestamps is not None: - padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech] + padded_inputs["num_frames"] = [ + len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech + ] if return_tensors is not None: padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - return padded_inputs \ No newline at end of file + return padded_inputs diff --git a/XY_Tokenizer/xy_tokenizer/nn/modules.py b/XY_Tokenizer/xy_tokenizer/nn/modules.py index cc186d9..eabc01b 100644 --- a/XY_Tokenizer/xy_tokenizer/nn/modules.py +++ b/XY_Tokenizer/xy_tokenizer/nn/modules.py @@ -1,24 +1,21 @@ -import torch -import torch.distributed -import numpy as np +import copy import logging import math -import copy +from dataclasses import dataclass +from typing import Optional, Tuple + +import librosa import numpy as np import scipy import torch -import librosa - -from typing import Optional, Tuple -from torch import nn, view_as_real, view_as_complex -from torch import nn +import torch.distributed +from torch import nn, view_as_complex, view_as_real from torch.nn import functional as F -from torch.nn.utils import 
weight_norm, remove_weight_norm +from torch.nn.utils import remove_weight_norm, weight_norm from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz +from transformers import WhisperModel from transformers.activations import ACT2FN -from dataclasses import dataclass from transformers.modeling_outputs import ModelOutput -from transformers import WhisperModel # Define function to generate positional embeddings using sine and cosine functions to represent sequence position information @@ -30,6 +27,7 @@ def sinusoids(length, channels, max_timescale=10000): scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + # Generate sequence mask to distinguish valid sequence and padding parts def get_sequence_mask(inputs, inputs_length): if inputs.dim() == 3: @@ -37,9 +35,12 @@ def get_sequence_mask(inputs, inputs_length): else: bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length) sequence_mask = torch.arange(0, tgt_len).to(inputs.device) - sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1) + sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view( + bsz, tgt_len, 1 + ) return sequence_mask + # Define RMSNorm layer for normalizing hidden states and stabilizing training process class RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -54,6 +55,7 @@ def forward(self, hidden_states): hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states + # Modified variable-length attention mechanism, supporting FP32 with unified interface class VarLenAttention(nn.Module): def __init__(self, embed_dim, num_heads, causal=False, dropout=0.0): @@ -73,7 +75,7 @@ def __init__(self, embed_dim, num_heads, causal=False, dropout=0.0): assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" self.causal = causal self.dropout = nn.Dropout(dropout) - self.scaling = self.head_dim ** -0.5 # Scaling factor + self.scaling = self.head_dim**-0.5 # Scaling factor # Linear projection layers for Q, K, V and output self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) @@ -104,18 +106,29 @@ def _create_attention_mask(self, seq_len, max_len, device, dtype): # Mark valid positions (less than seq_len) valid_mask = seq_indices < seq_len_expanded.unsqueeze(-1) # [bsz, 1, max_len] - mask = mask * (valid_mask.unsqueeze(2) & valid_mask.unsqueeze(3)).to(dtype) # [bsz, 1, max_len, max_len] + mask = mask * (valid_mask.unsqueeze(2) & valid_mask.unsqueeze(3)).to( + dtype + ) # [bsz, 1, max_len, max_len] # If causal attention, add upper triangular mask if self.causal: - causal_mask = torch.triu(torch.ones(max_len, max_len, device=device, dtype=torch.bool), diagonal=1) - mask = mask * (~causal_mask.unsqueeze(0).unsqueeze(1)).to(dtype) # Keep only lower triangular part + causal_mask = torch.triu( + torch.ones(max_len, max_len, device=device, dtype=torch.bool), + diagonal=1, + ) + mask = mask * (~causal_mask.unsqueeze(0).unsqueeze(1)).to( + dtype + ) # Keep only lower triangular part # Set invalid positions (0) to dtype's minimum value - mask = mask + (1.0 - mask) * torch.finfo(dtype).min # Valid positions unchanged, invalid positions to minimum value + mask = ( + mask + (1.0 - mask) * torch.finfo(dtype).min + ) # Valid positions unchanged, invalid positions to minimum value return mask - def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor: + def forward( + self, 
hidden_states: torch.Tensor, seq_len: torch.Tensor + ) -> torch.Tensor: """ Forward propagation, input and output are [bsz, max_len, embed_dim]. @@ -130,43 +143,73 @@ def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.T # Project to Q, K, V query = self.q_proj(hidden_states) * self.scaling # [bsz, max_len, embed_dim] - key = self.k_proj(hidden_states) # [bsz, max_len, embed_dim] - value = self.v_proj(hidden_states) # [bsz, max_len, embed_dim] + key = self.k_proj(hidden_states) # [bsz, max_len, embed_dim] + value = self.v_proj(hidden_states) # [bsz, max_len, embed_dim] # Reshape to multi-head form - query = query.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2) # [bsz, num_heads, max_len, head_dim] - key = key.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2) # [bsz, num_heads, max_len, head_dim] - value = value.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2) # [bsz, num_heads, max_len, head_dim] + query = query.view(bsz, max_len, self.num_heads, self.head_dim).transpose( + 1, 2 + ) # [bsz, num_heads, max_len, head_dim] + key = key.view(bsz, max_len, self.num_heads, self.head_dim).transpose( + 1, 2 + ) # [bsz, num_heads, max_len, head_dim] + value = value.view(bsz, max_len, self.num_heads, self.head_dim).transpose( + 1, 2 + ) # [bsz, num_heads, max_len, head_dim] # Calculate attention scores - attn_scores = torch.matmul(query, key.transpose(-1, -2)) # [bsz, num_heads, max_len, max_len] + attn_scores = torch.matmul( + query, key.transpose(-1, -2) + ) # [bsz, num_heads, max_len, max_len] # Generate attention mask - attn_mask = self._create_attention_mask(seq_len, max_len, hidden_states.device, attn_scores.dtype) # [bsz, 1, max_len, max_len] + attn_mask = self._create_attention_mask( + seq_len, max_len, hidden_states.device, attn_scores.dtype + ) # [bsz, 1, max_len, max_len] # Apply mask (additive form, consistent with HubertEncoder) - attn_scores = attn_scores + attn_mask # Invalid positions set to very small value + attn_scores = ( + attn_scores + attn_mask + ) # Invalid positions set to very small value # Softmax calculate attention weights - attn_weights = F.softmax(attn_scores, dim=-1) # [bsz, num_heads, max_len, max_len] + attn_weights = F.softmax( + attn_scores, dim=-1 + ) # [bsz, num_heads, max_len, max_len] attn_weights = self.dropout(attn_weights) # Calculate attention output - attn_output = torch.matmul(attn_weights, value) # [bsz, num_heads, max_len, head_dim] - attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, max_len, self.embed_dim) # [bsz, max_len, embed_dim] + attn_output = torch.matmul( + attn_weights, value + ) # [bsz, num_heads, max_len, head_dim] + attn_output = ( + attn_output.transpose(1, 2).contiguous().view(bsz, max_len, self.embed_dim) + ) # [bsz, max_len, embed_dim] # Output projection attn_output = self.out_proj(attn_output) # [bsz, max_len, embed_dim] return attn_output + # Define Transformer layer containing attention mechanism and feedforward network for feature extraction and transformation class OmniWhisperTransformerLayer(nn.Module): - def __init__(self, activation_function="gelu", d_model=1280, attention_heads=20, ffn_dim=5120, causal=False, ln_type="LayerNorm", attn_type="varlen"): + def __init__( + self, + activation_function="gelu", + d_model=1280, + attention_heads=20, + ffn_dim=5120, + causal=False, + ln_type="LayerNorm", + attn_type="varlen", + ): super().__init__() self.embed_dim = d_model # Only keep varlen attention mechanism if attn_type != "varlen": - 
raise ValueError(f"Unknown attn_type: {attn_type}. Only 'varlen' is supported.") + raise ValueError( + f"Unknown attn_type: {attn_type}. Only 'varlen' is supported." + ) self.self_attn = VarLenAttention(self.embed_dim, attention_heads, causal) if ln_type == "LayerNorm": self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -184,85 +227,105 @@ def __init__(self, activation_function="gelu", d_model=1280, attention_heads=20, else: raise ValueError(f"Unknown ln_type: {ln_type}") - def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, seq_len: torch.Tensor + ) -> torch.Tensor: residual = hidden_states # [bsz, max_len, embed_dim] hidden_states = self.self_attn_layer_norm(hidden_states) # from torch.cuda.amp import autocast # print(f"{residual.dtype = }") # print(f"Autocast enabled: {torch.is_autocast_enabled():}") # print(f"after layernorm {hidden_states.dtype = }") - hidden_states = self.self_attn(hidden_states, seq_len) # [bsz, max_len, embed_dim] + hidden_states = self.self_attn( + hidden_states, seq_len + ) # [bsz, max_len, embed_dim] hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.fc2(hidden_states) hidden_states = residual + hidden_states - if (hidden_states.dtype == torch.float16 or hidden_states.dtype == torch.bfloat16) and \ - (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): + if ( + hidden_states.dtype == torch.float16 + or hidden_states.dtype == torch.bfloat16 + ) and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) return hidden_states + # Define audio encoder to convert input audio features to hidden state representation class OmniAudioEncoder(nn.Module): def __init__( - self, - num_mel_bins=128, # Input feature Mel band number, usually the dimension of Mel spectrogram - sampling_rate=16000, # Audio sampling rate, unit Hz - hop_length=160, # Frame shift length (sample number) when calculating Mel spectrogram - stride_size=2, # Convolution layer step, used for downsampling - kernel_size=3, # Convolution kernel size, controlling receptive field - d_model=1280, # Model's hidden state dimension (embedding dimension) - scale_embedding=True, # Whether to scale embedding (usually used for stabilizing training) - max_audio_seconds=30, # Maximum audio duration supported (seconds) - encoder_layers=32, # Transformer encoder layer number - encoder_attention_heads=20, # Attention head number for each Transformer layer - encoder_ffn_dim=5120, # Intermediate dimension for feedforward network - activation_function="gelu", # Activation function type, default GELU - attn_type="varlen" # New parameter, select attention mechanism type - ): + self, + num_mel_bins=128, # Input feature Mel band number, usually the dimension of Mel spectrogram + sampling_rate=16000, # Audio sampling rate, unit Hz + hop_length=160, # Frame shift length (sample number) when calculating Mel spectrogram + stride_size=2, # Convolution layer step, used for downsampling + kernel_size=3, # Convolution kernel size, controlling receptive field + d_model=1280, # Model's hidden state dimension (embedding dimension) + 
scale_embedding=True, # Whether to scale embedding (usually used for stabilizing training) + max_audio_seconds=30, # Maximum audio duration supported (seconds) + encoder_layers=32, # Transformer encoder layer number + encoder_attention_heads=20, # Attention head number for each Transformer layer + encoder_ffn_dim=5120, # Intermediate dimension for feedforward network + activation_function="gelu", # Activation function type, default GELU + attn_type="varlen", # New parameter, select attention mechanism type + ): super().__init__() # Calculate maximum sequence length: Convert sampling rate to frame number after considering downsampling step - self.max_source_positions = (max_audio_seconds * sampling_rate // hop_length) // stride_size + self.max_source_positions = ( + max_audio_seconds * sampling_rate // hop_length + ) // stride_size # Embedding scaling factor, if enabled sqrt(d_model), otherwise 1.0 self.embed_scale = math.sqrt(d_model) if scale_embedding else 1.0 self.num_mel_bins = num_mel_bins # Save Mel band number self.d_model = d_model # Save hidden state dimension self.stride_size = stride_size - + # First convolution layer: Convert Mel spectrogram features (num_mel_bins) to hidden dimension (d_model) - self.conv1 = nn.Conv1d(num_mel_bins, d_model, kernel_size=kernel_size, padding=1) + self.conv1 = nn.Conv1d( + num_mel_bins, d_model, kernel_size=kernel_size, padding=1 + ) # Second convolution layer: Apply downsampling with stride_size - self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=kernel_size, stride=stride_size, padding=1) - + self.conv2 = nn.Conv1d( + d_model, d_model, kernel_size=kernel_size, stride=stride_size, padding=1 + ) + # Register positional embedding buffer, using sine function to generate, shape (max_source_positions, d_model) - self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model)) - + self.register_buffer( + "positional_embedding", sinusoids(self.max_source_positions, d_model) + ) + # Create Transformer encoder layer list, each layer contains attention mechanism and feedforward network - self.layers = nn.ModuleList([ - OmniWhisperTransformerLayer( - activation_function=activation_function, - d_model=d_model, - attention_heads=encoder_attention_heads, - ffn_dim=encoder_ffn_dim, - causal=False, # Encoder does not need causal attention - attn_type=attn_type # Pass attention type - ) for _ in range(encoder_layers) - ]) - + self.layers = nn.ModuleList( + [ + OmniWhisperTransformerLayer( + activation_function=activation_function, + d_model=d_model, + attention_heads=encoder_attention_heads, + ffn_dim=encoder_ffn_dim, + causal=False, # Encoder does not need causal attention + attn_type=attn_type, # Pass attention type + ) + for _ in range(encoder_layers) + ] + ) + # Last layer normalization for stable output self.layer_norm = nn.LayerNorm(d_model) def forward(self, input_features, input_length, output_hidden_states=False): """ Forward propagation function to convert input audio features to hidden state representation - + Parameters: input_features (torch.Tensor): Input Mel spectrogram features, shape [bsz, num_mel_bins, seq_len] input_length (torch.Tensor): Input sequence length for each sample, shape [bsz] output_hidden_states (bool, optional): Whether to return hidden states for each layer, default False - + Returns: if output_hidden_states is False: hidden_states (torch.Tensor): Encoded hidden states, shape [bsz, d_model, tgt_len] @@ -274,160 +337,187 @@ def forward(self, input_features, input_length, output_hidden_states=False): 
""" # Ensure input feature data type consistent with convolution layer weights input_features = input_features.to(self.conv1.weight.dtype) # (B, D, T) - + # First layer convolution + GELU activation, Convert Mel spectrogram to hidden states inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (B, D, T) - + # Second layer convolution + GELU activation, Apply downsampling with stride_size inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (B, D, T) - + # Calculate output length: Result after downsampling with stride_size output_length = (input_length // self.stride_size).long() # (B,) - + # Adjust dimension order to [bsz, seq_len, d_model] for Transformer input hidden_states = inputs_embeds.permute(0, 2, 1) # (B, T, D) - + # Get batch size and target sequence length bsz, tgt_len, _ = hidden_states.size() - + # According to current sequence length, take or use complete positional embedding if tgt_len < self.positional_embedding.shape[0]: current_positional_embedding = self.positional_embedding[:tgt_len] else: current_positional_embedding = self.positional_embedding - + # Add input embedding to positional embedding, convert to float to avoid precision issues - hidden_states = (hidden_states.to(torch.float32) + current_positional_embedding).to(hidden_states.dtype) - + hidden_states = ( + hidden_states.to(torch.float32) + current_positional_embedding + ).to(hidden_states.dtype) + # Generate sequence mask for processing variable-length sequence - attention_mask = get_sequence_mask(hidden_states, output_length) # [bsz, tgt_len, 1] - + attention_mask = get_sequence_mask( + hidden_states, output_length + ) # [bsz, tgt_len, 1] + # Initialize hidden states list for storing output for each layer (if needed) hidden_states_all_layers = () if output_hidden_states else None - + # Process hidden states through Transformer encoder layer by layer for encoder_layer in self.layers: if output_hidden_states: hidden_states_all_layers = hidden_states_all_layers + (hidden_states,) - hidden_states = encoder_layer(hidden_states, output_length) # [bsz, tgt_len, d_model] - + hidden_states = encoder_layer( + hidden_states, output_length + ) # [bsz, tgt_len, d_model] + # Normalize hidden states hidden_states = self.layer_norm(hidden_states) # [bsz, tgt_len, d_model] if output_hidden_states: hidden_states_all_layers = hidden_states_all_layers + (hidden_states,) - + # Use mask to zero out padding parts and ensure output only retains valid data - hidden_states = torch.where(attention_mask, hidden_states, 0) # [bsz, tgt_len, d_model] + hidden_states = torch.where( + attention_mask, hidden_states, 0 + ) # [bsz, tgt_len, d_model] hidden_states = hidden_states.transpose(1, 2) # [bsz, d_model, tgt_len] - + if not output_hidden_states: - return hidden_states, output_length + return hidden_states, output_length else: return hidden_states, output_length, hidden_states_all_layers - + + # Define audio decoder to convert hidden states to Mel spectrogram class OmniAudioDecoder(nn.Module): def __init__( - self, - num_mel_bins=128, - sampling_rate=16000, - hop_length=160, - stride_size=2, - kernel_size=3, - d_model=1280, - scale_embedding=True, - max_audio_seconds=30, - decoder_layers=32, - decoder_attention_heads=20, - decoder_ffn_dim=5120, - activation_function="gelu", - attn_type="varlen" # New parameter, select attention mechanism type - ): + self, + num_mel_bins=128, + sampling_rate=16000, + hop_length=160, + stride_size=2, + kernel_size=3, + d_model=1280, + scale_embedding=True, + max_audio_seconds=30, + 
decoder_layers=32, + decoder_attention_heads=20, + decoder_ffn_dim=5120, + activation_function="gelu", + attn_type="varlen", # New parameter, select attention mechanism type + ): super().__init__() - self.max_source_positions = (max_audio_seconds * sampling_rate // hop_length) // stride_size + self.max_source_positions = ( + max_audio_seconds * sampling_rate // hop_length + ) // stride_size self.embed_scale = math.sqrt(d_model) if scale_embedding else 1.0 self.num_mel_bins = num_mel_bins self.d_model = d_model self.stride_size = stride_size - + # Correct transpose convolution layer to ensure output length close to stride_size times self.deconv1 = nn.ConvTranspose1d( - d_model, - d_model, - kernel_size=kernel_size, - stride=stride_size, + d_model, + d_model, + kernel_size=kernel_size, + stride=stride_size, padding=0, # Do not fill input side - output_padding=0 # Can be adjusted to precisely control length + output_padding=0, # Can be adjusted to precisely control length ) self.deconv2 = nn.ConvTranspose1d( - d_model, - num_mel_bins, - kernel_size=kernel_size, + d_model, + num_mel_bins, + kernel_size=kernel_size, stride=1, # Only convert channels, do not change length - padding=0 + padding=0, ) - + # Positional embedding remains consistent - self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model)) # (T, D) - + self.register_buffer( + "positional_embedding", sinusoids(self.max_source_positions, d_model) + ) # (T, D) + # Transformer decoder layer - self.layers = nn.ModuleList([ - OmniWhisperTransformerLayer( - activation_function=activation_function, - d_model=d_model, - attention_heads=decoder_attention_heads, - ffn_dim=decoder_ffn_dim, - causal=False, # Decoder uses causal attention - attn_type=attn_type # Pass attention type - ) for _ in range(decoder_layers) - ]) + self.layers = nn.ModuleList( + [ + OmniWhisperTransformerLayer( + activation_function=activation_function, + d_model=d_model, + attention_heads=decoder_attention_heads, + ffn_dim=decoder_ffn_dim, + causal=False, # Decoder uses causal attention + attn_type=attn_type, # Pass attention type + ) + for _ in range(decoder_layers) + ] + ) self.layer_norm = nn.LayerNorm(d_model) - def forward(self, hidden_states, input_length): # (B, D, T) + def forward(self, hidden_states, input_length): # (B, D, T) # Input is hidden state output from encoder - hidden_states = hidden_states.transpose(1, 2) # (B, T, D) + hidden_states = hidden_states.transpose(1, 2) # (B, T, D) bsz, tgt_len, _ = hidden_states.size() - + # Add positional embedding if tgt_len < self.positional_embedding.shape[0]: - current_positional_embedding = self.positional_embedding[:tgt_len] # (T, D) + current_positional_embedding = self.positional_embedding[:tgt_len] # (T, D) else: current_positional_embedding = self.positional_embedding - hidden_states = (hidden_states.to(torch.float32) + current_positional_embedding).to(hidden_states.dtype) # (B, T, D) - + hidden_states = ( + hidden_states.to(torch.float32) + current_positional_embedding + ).to( + hidden_states.dtype + ) # (B, T, D) + # Generate sequence mask - attention_mask = get_sequence_mask(hidden_states, input_length) # [bsz, tgt_len, 1] - + attention_mask = get_sequence_mask( + hidden_states, input_length + ) # [bsz, tgt_len, 1] + # Process through decoder layer for decoder_layer in self.layers: - hidden_states = decoder_layer(hidden_states, input_length) # [bsz, tgt_len, d_model] - + hidden_states = decoder_layer( + hidden_states, input_length + ) # [bsz, tgt_len, d_model] + # Final 
layer normalization hidden_states = self.layer_norm(hidden_states) # [bsz, tgt_len, d_model] - + # Use mask to zero out padding parts - hidden_states = torch.where(attention_mask, hidden_states, 0) # [bsz, tgt_len, d_model] - + hidden_states = torch.where( + attention_mask, hidden_states, 0 + ) # [bsz, tgt_len, d_model] + # Process through transpose convolution layer to reconstruct audio features hidden_states = hidden_states.permute(0, 2, 1) # (B, D, T) - output_features = nn.functional.gelu(self.deconv1(hidden_states)) # (B, D, T) - output_features = nn.functional.gelu(self.deconv2(output_features)) # (B, D, T) - + output_features = nn.functional.gelu(self.deconv1(hidden_states)) # (B, D, T) + output_features = nn.functional.gelu(self.deconv2(output_features)) # (B, D, T) + # If strictly stride_size times length is needed, can trim extra parts expected_length = tgt_len * self.stride_size if output_features.size(2) > expected_length: output_features = output_features[:, :, :expected_length] - + output_length = input_length * self.stride_size # Output shape: [bsz, num_mel_bins, seq_len] return output_features, output_length + # The following part remains unchanged class ResidualDownConv(nn.Module): def __init__(self, d_model=1280, avg_pooler=4): """ Downsampling module containing residual connection and convolution operation - + Parameters: d_model (int): Input and output hidden dimension avg_pooler (int): Downsampling factor (convolution step) @@ -436,52 +526,58 @@ def __init__(self, d_model=1280, avg_pooler=4): self.d_model = d_model self.avg_pooler = avg_pooler self.intermediate_dim = d_model * avg_pooler - + # Convolution layer for downsampling - self.gate_proj = nn.Conv1d(d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False) - self.up_proj = nn.Conv1d(d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False) - + self.gate_proj = nn.Conv1d( + d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False + ) + self.up_proj = nn.Conv1d( + d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False + ) + # Downsampled linear projection - self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False) - + self.down_proj = nn.Linear( + self.intermediate_dim, self.intermediate_dim, bias=False + ) + # Activation function and layer normalization - self.act_fn = ACT2FN['silu'] + self.act_fn = ACT2FN["silu"] self.layer_norm = nn.LayerNorm(self.intermediate_dim) def forward(self, x, input_length): """ Forward propagation, execute downsampling and residual processing - + Parameters: x (torch.Tensor): Input tensor, shape [B, D, T] - + Returns: res (torch.Tensor): Downsampled feature, shape [B, intermediate_dim, seq_len // avg_pooler] valid_mask (torch.Tensor): Valid sequence mask """ output_length = input_length // self.avg_pooler - x = x.transpose(1, 2) # (B, T, D) - batch_size, seq_len, _ = x.shape # (B, T, D) + x = x.transpose(1, 2) # (B, T, D) + batch_size, seq_len, _ = x.shape # (B, T, D) if seq_len % self.avg_pooler != 0: pad_size = self.avg_pooler - seq_len % self.avg_pooler x = F.pad(x, (0, pad_size), "constant", 0) - - xt = x.permute(0, 2, 1) # (B, D, T) + + xt = x.permute(0, 2, 1) # (B, D, T) g = self.gate_proj(xt).permute(0, 2, 1) # (B, T, D) - u = self.up_proj(xt).permute(0, 2, 1) # (B, T, D) + u = self.up_proj(xt).permute(0, 2, 1) # (B, T, D) x = x.reshape(batch_size, -1, self.intermediate_dim) # (B, T, D) - c = self.down_proj(self.act_fn(g) * u) # (B, T, D) - res = self.layer_norm(c + x) # (B, T, D) - res = res.transpose(1, 2) # 
(B, D, T) - return res, output_length # (B, D, T) - - + c = self.down_proj(self.act_fn(g) * u) # (B, T, D) + res = self.layer_norm(c + x) # (B, T, D) + res = res.transpose(1, 2) # (B, D, T) + return res, output_length # (B, D, T) + + class UpConv(nn.Module): def __init__(self, d_model=1280, stride=4): """ Simple upsampling module using transpose convolution - + Parameters: d_model (int): Input and output hidden dimension stride (int): Upsampling factor (transpose convolution step) @@ -489,23 +585,23 @@ def __init__(self, d_model=1280, stride=4): super().__init__() self.d_model = d_model self.stride = stride - + # Simple transpose convolution layer to keep channel number consistent self.up_conv = nn.ConvTranspose1d( - self.stride * d_model, - d_model, - kernel_size=stride, - stride=stride, - bias=False + self.stride * d_model, + d_model, + kernel_size=stride, + stride=stride, + bias=False, ) def forward(self, x, input_length): """ Forward propagation, execute upsampling - + Parameters: x (torch.Tensor): Input tensor, shape [B, D * stride, T] - + Returns: res (torch.Tensor): Upsampled feature, shape [B, D, T * stride] """ @@ -513,21 +609,21 @@ def forward(self, x, input_length): res = self.up_conv(x) output_length = input_length * self.stride return res, output_length - + # Define Transformer encoder containing multiple Transformer layers for feature extraction and transformation class Transformer(nn.Module): def __init__( - self, - input_dim=1280, # Input feature dimension - d_model=1280, # Model's hidden state dimension (embedding dimension) - output_dim=1280, # Output feature dimension - max_source_positions=1500, # Maximum sequence length for positional embedding - encoder_layers=32, # Transformer encoder layer number - encoder_attention_heads=20, # Attention head number for each Transformer layer - encoder_ffn_dim=5120, # Intermediate dimension for feedforward network - activation_function="gelu", # Activation function type, default GELU - attn_type="varlen" # Attention mechanism type + self, + input_dim=1280, # Input feature dimension + d_model=1280, # Model's hidden state dimension (embedding dimension) + output_dim=1280, # Output feature dimension + max_source_positions=1500, # Maximum sequence length for positional embedding + encoder_layers=32, # Transformer encoder layer number + encoder_attention_heads=20, # Attention head number for each Transformer layer + encoder_ffn_dim=5120, # Intermediate dimension for feedforward network + activation_function="gelu", # Activation function type, default GELU + attn_type="varlen", # Attention mechanism type ): super().__init__() self.input_dim = input_dim # Save input dimension @@ -542,20 +638,25 @@ def __init__( self.proj = None # No need for input projection layer # Register positional embedding buffer, using sine function to generate, shape (max_source_positions, d_model) - self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model)) - + self.register_buffer( + "positional_embedding", sinusoids(self.max_source_positions, d_model) + ) + # Create Transformer encoder layer list, each layer contains attention mechanism and feedforward network - self.layers = nn.ModuleList([ - OmniWhisperTransformerLayer( - activation_function=activation_function, - d_model=d_model, - attention_heads=encoder_attention_heads, - ffn_dim=encoder_ffn_dim, - causal=False, # Encoder does not need causal attention - attn_type=attn_type # Pass attention type - ) for _ in range(encoder_layers) - ]) - + self.layers = nn.ModuleList( + [ + 
OmniWhisperTransformerLayer( + activation_function=activation_function, + d_model=d_model, + attention_heads=encoder_attention_heads, + ffn_dim=encoder_ffn_dim, + causal=False, # Encoder does not need causal attention + attn_type=attn_type, # Pass attention type + ) + for _ in range(encoder_layers) + ] + ) + # Last layer normalization for stable output self.layer_norm = nn.LayerNorm(d_model) @@ -565,15 +666,20 @@ def __init__( else: self.out_proj = None # No need for output projection layer - def forward(self, input_features: torch.Tensor, input_length: torch.Tensor, output_hidden_states: bool = False): + def forward( + self, + input_features: torch.Tensor, + input_length: torch.Tensor, + output_hidden_states: bool = False, + ): """ Forward propagation function to convert input features through Transformer layer to hidden state representation - + Parameters: input_features (torch.Tensor): Input features, shape [bsz, input_dim, seq_len] (B, input_dim, T) input_length (torch.Tensor): Input sequence length for each sample, shape [bsz] output_hidden_states (bool, optional): Whether to return hidden states for each layer, default False - + Returns: if output_hidden_states is False: hidden_states (torch.Tensor): Encoded hidden states, shape [bsz, output_dim, seq_len] (B, output_dim, T) @@ -588,57 +694,79 @@ def forward(self, input_features: torch.Tensor, input_length: torch.Tensor, outp # If there is input projection layer, map input features from input_dim to d_model if self.proj is not None: - hidden_states = self.proj(input_features.permute(0, 2, 1)).permute(0, 2, 1) # [bsz, d_model, seq_len] (B, D, T) + hidden_states = self.proj(input_features.permute(0, 2, 1)).permute( + 0, 2, 1 + ) # [bsz, d_model, seq_len] (B, D, T) else: hidden_states = input_features # [bsz, d_model, seq_len] (B, D, T) # Adjust input dimension order to [bsz, seq_len, d_model] for Transformer input - hidden_states = hidden_states.permute(0, 2, 1) # [bsz, seq_len, d_model] (B, T, D) - + hidden_states = hidden_states.permute( + 0, 2, 1 + ) # [bsz, seq_len, d_model] (B, T, D) + # Get batch size and target sequence length bsz, tgt_len, _ = hidden_states.size() - + # According to current sequence length, take or use complete positional embedding if tgt_len < self.positional_embedding.shape[0]: - current_positional_embedding = self.positional_embedding[:tgt_len] # [tgt_len, d_model] + current_positional_embedding = self.positional_embedding[ + :tgt_len + ] # [tgt_len, d_model] else: - current_positional_embedding = self.positional_embedding # [max_source_positions, d_model] - + current_positional_embedding = ( + self.positional_embedding + ) # [max_source_positions, d_model] + # Add input features to positional embedding, convert to float to avoid precision issues - hidden_states = (hidden_states.to(torch.float32) + current_positional_embedding).to(hidden_states.dtype) # [bsz, seq_len, d_model] - + hidden_states = ( + hidden_states.to(torch.float32) + current_positional_embedding + ).to( + hidden_states.dtype + ) # [bsz, seq_len, d_model] + # Generate sequence mask for processing variable-length sequence - attention_mask = get_sequence_mask(hidden_states, output_length) # [bsz, tgt_len, 1] - + attention_mask = get_sequence_mask( + hidden_states, output_length + ) # [bsz, tgt_len, 1] + # Initialize hidden states list for storing output for each layer (if needed) hidden_states_all_layers = () if output_hidden_states else None - + # Process hidden states through Transformer encoder layer by layer for encoder_layer in self.layers: 
if output_hidden_states: hidden_states_all_layers = hidden_states_all_layers + (hidden_states,) - hidden_states = encoder_layer(hidden_states, output_length) # [bsz, seq_len, d_model] - + hidden_states = encoder_layer( + hidden_states, output_length + ) # [bsz, seq_len, d_model] + # Normalize hidden states hidden_states = self.layer_norm(hidden_states) # [bsz, seq_len, d_model] if output_hidden_states: hidden_states_all_layers = hidden_states_all_layers + (hidden_states,) - + # Use mask to zero out padding parts and ensure output only retains valid data - hidden_states = torch.where(attention_mask, hidden_states, 0) # [bsz, seq_len, d_model] - + hidden_states = torch.where( + attention_mask, hidden_states, 0 + ) # [bsz, seq_len, d_model] + # Adjust dimension order to [bsz, d_model, seq_len] - hidden_states = hidden_states.transpose(1, 2) # [bsz, d_model, seq_len] (B, D, T) + hidden_states = hidden_states.transpose( + 1, 2 + ) # [bsz, d_model, seq_len] (B, D, T) # If there is output projection layer, map hidden states from d_model to output_dim if self.out_proj is not None: - hidden_states = self.out_proj(hidden_states.permute(0, 2, 1)).permute(0, 2, 1) # [bsz, output_dim, seq_len] (B, output_dim, T) + hidden_states = self.out_proj(hidden_states.permute(0, 2, 1)).permute( + 0, 2, 1 + ) # [bsz, output_dim, seq_len] (B, output_dim, T) if not output_hidden_states: return hidden_states, output_length else: return hidden_states, output_length, hidden_states_all_layers - + def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: """ @@ -1477,4 +1605,3 @@ def forward(self, x, input_length): x = self.head(x) output_length = input_length * self.hop_size return x[:, None, :], output_length - diff --git a/XY_Tokenizer/xy_tokenizer/nn/quantizer.py b/XY_Tokenizer/xy_tokenizer/nn/quantizer.py index a7d28b9..79e2044 100644 --- a/XY_Tokenizer/xy_tokenizer/nn/quantizer.py +++ b/XY_Tokenizer/xy_tokenizer/nn/quantizer.py @@ -1,18 +1,21 @@ import logging + import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -import torch.distributed as dist - from einops import rearrange from torch.nn.utils import weight_norm + def WNConv1d(*args, **kwargs): return weight_norm(nn.Conv1d(*args, **kwargs)) + def WNConvTranspose1d(*args, **kwargs): return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + def sample_vectors(samples, num): # samples: (N, D), num_samples: N, feature dim: D num_samples, device = samples.shape[0], samples.device @@ -22,35 +25,49 @@ def sample_vectors(samples, num): indices = torch.randint(0, num_samples, (num,), device=device) return samples[indices].float() # (num, D), ensure fp32 + def kmeans(samples, num_clusters, num_iters=10): # samples: (N, D), N samples with D dimensions dim, dtype = samples.shape[-1], torch.float32 # Force fp32 - means = sample_vectors(samples, num_clusters).float() # (num_clusters, D), ensure fp32 - + means = sample_vectors( + samples, num_clusters + ).float() # (num_clusters, D), ensure fp32 + for _ in range(num_iters): - dists = -(samples.float().pow(2).sum(1, keepdim=True) - # (N, 1), ensure fp32 - 2 * samples.float() @ means.t() + # (N, num_clusters), ensure fp32 - means.t().float().pow(2).sum(0, keepdim=True)) # (1, num_clusters), ensure fp32 + dists = -( + samples.float().pow(2).sum(1, keepdim=True) # (N, 1), ensure fp32 + - 2 * samples.float() @ means.t() # (N, num_clusters), ensure fp32 + + means.t().float().pow(2).sum(0, keepdim=True) + ) # (1, num_clusters), ensure fp32 # dists: (N, num_clusters) 
buckets = dists.max(dim=-1).indices # (N) bins = torch.bincount(buckets, minlength=num_clusters) # (num_clusters) zero_mask = bins == 0 # (num_clusters) bins_min_clamped = bins.masked_fill(zero_mask, 1) # (num_clusters) - - new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32) # (num_clusters, D), ensure fp32 - new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples.float()) # (num_clusters, D), ensure fp32 + + new_means = buckets.new_zeros( + num_clusters, dim, dtype=torch.float32 + ) # (num_clusters, D), ensure fp32 + new_means.scatter_add_( + 0, buckets.unsqueeze(1).expand(-1, dim), samples.float() + ) # (num_clusters, D), ensure fp32 new_means = new_means / bins_min_clamped[..., None] # (num_clusters, D) means = torch.where(zero_mask[..., None], means, new_means) # (num_clusters, D) - + # Final cluster assignments for returning cluster sizes - dists = -(samples.float().pow(2).sum(1, keepdim=True) - - 2 * samples.float() @ means.t() + - means.t().float().pow(2).sum(0, keepdim=True)) # (N, num_clusters), ensure fp32 + dists = -( + samples.float().pow(2).sum(1, keepdim=True) + - 2 * samples.float() @ means.t() + + means.t().float().pow(2).sum(0, keepdim=True) + ) # (N, num_clusters), ensure fp32 buckets = dists.max(dim=-1).indices # (N) - bins = torch.bincount(buckets, minlength=num_clusters).float() # (num_clusters), ensure fp32 - + bins = torch.bincount( + buckets, minlength=num_clusters + ).float() # (num_clusters), ensure fp32 + return means, bins # (num_clusters, D), (num_clusters) + class VectorQuantize(nn.Module): def __init__( self, @@ -58,11 +75,11 @@ def __init__( codebook_size, codebook_dim, commitment=1.0, - decay=0.99, # EMA decay - epsilon=1e-5, # Laplace smoothing epsilon - threshold_ema_dead=2, # Dead code threshold - kmeans_init=True, # Use kmeans initialization - kmeans_iters=10, # Kmeans iterations + decay=0.99, # EMA decay + epsilon=1e-5, # Laplace smoothing epsilon + threshold_ema_dead=2, # Dead code threshold + kmeans_init=True, # Use kmeans initialization + kmeans_iters=10, # Kmeans iterations ): super().__init__() self.input_dim = input_dim @@ -74,20 +91,32 @@ def __init__( self.threshold_ema_dead = threshold_ema_dead self.kmeans_init = kmeans_init self.kmeans_iters = kmeans_iters - + if self.input_dim != self.codebook_dim: - self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1) # (B, D, T) -> (B, D', T) - self.out_project = WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1) # (B, D', T) -> (B, D, T) + self.in_project = WNConv1d( + self.input_dim, self.codebook_dim, kernel_size=1 + ) # (B, D, T) -> (B, D', T) + self.out_project = WNConv1d( + self.codebook_dim, self.input_dim, kernel_size=1 + ) # (B, D', T) -> (B, D, T) else: self.in_project = nn.Identity() self.out_project = nn.Identity() # Initialize codebook and EMA buffers init_fn = torch.zeros if kmeans_init else lambda x, y: torch.randn(x, y) - self.register_buffer("codebook", init_fn(codebook_size, codebook_dim).float()) # (codebook_size, D'), ensure fp32 - self.register_buffer("inited", torch.tensor([not kmeans_init], dtype=torch.bool)) # (1) - self.register_buffer("cluster_size", torch.zeros(codebook_size).float()) # (codebook_size), ensure fp32 - self.register_buffer("embed_avg", self.codebook.clone().float()) # (codebook_size, D'), ensure fp32 + self.register_buffer( + "codebook", init_fn(codebook_size, codebook_dim).float() + ) # (codebook_size, D'), ensure fp32 + self.register_buffer( + "inited", torch.tensor([not kmeans_init], 
dtype=torch.bool) + ) # (1) + self.register_buffer( + "cluster_size", torch.zeros(codebook_size).float() + ) # (codebook_size), ensure fp32 + self.register_buffer( + "embed_avg", self.codebook.clone().float() + ) # (codebook_size, D'), ensure fp32 def ema_update(self, encodings, embed_onehot): # encodings: (B*T, D'), embed_onehot: (B*T, codebook_size) @@ -96,56 +125,72 @@ def ema_update(self, encodings, embed_onehot): embed_onehot = embed_onehot.float() # Ensure fp32 cluster_size_new = embed_onehot.sum(0) # (codebook_size) embed_sum = encodings.t() @ embed_onehot # (D', codebook_size) - + # Distributed reduction if dist.is_initialized(): dist.all_reduce(cluster_size_new, op=dist.ReduceOp.SUM) dist.all_reduce(embed_sum, op=dist.ReduceOp.SUM) - + ema_inplace(self.cluster_size, cluster_size_new, self.decay) # (codebook_size) ema_inplace(self.embed_avg, embed_sum.t(), self.decay) # (codebook_size, D') - + # Laplace smoothing - cluster_size = (self.cluster_size + self.epsilon) / (self.cluster_size.sum() + self.codebook_size * self.epsilon) # (codebook_size) + cluster_size = (self.cluster_size + self.epsilon) / ( + self.cluster_size.sum() + self.codebook_size * self.epsilon + ) # (codebook_size) cluster_size = cluster_size * self.cluster_size.sum() # (codebook_size) - self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1)) # (codebook_size, D') + self.codebook.copy_( + self.embed_avg / cluster_size.unsqueeze(1) + ) # (codebook_size, D') def replace_dead_codes(self, encodings): # encodings: (B*T, D') """Replace dead codes with random samples from current batch""" if self.threshold_ema_dead == 0: return - + dead_mask = self.cluster_size < self.threshold_ema_dead # (codebook_size) if dead_mask.any(): if dist.is_initialized() and dist.get_rank() == 0: - samples = sample_vectors(encodings.float(), self.codebook_size) # (codebook_size, D'), ensure fp32 + samples = sample_vectors( + encodings.float(), self.codebook_size + ) # (codebook_size, D'), ensure fp32 else: - samples = torch.zeros_like(self.codebook).float() # Placeholder, ensure fp32 - + samples = torch.zeros_like( + self.codebook + ).float() # Placeholder, ensure fp32 + # Broadcast samples if dist.is_initialized(): dist.broadcast(samples, src=0) - - self.codebook[dead_mask] = samples[:dead_mask.sum()].to(self.codebook.dtype) # Update dead codes + + self.codebook[dead_mask] = samples[: dead_mask.sum()].to( + self.codebook.dtype + ) # Update dead codes def init_codebook(self, encodings): # encodings: (B*T, D') """Initialize codebook with k-means and update cluster_size""" if self.inited.item(): return - + if dist.is_initialized() and dist.get_rank() == 0: - embed, cluster_sizes = kmeans(encodings.float(), self.codebook_size, self.kmeans_iters) # (codebook_size, D'), (codebook_size), ensure fp32 + embed, cluster_sizes = kmeans( + encodings.float(), self.codebook_size, self.kmeans_iters + ) # (codebook_size, D'), (codebook_size), ensure fp32 else: - embed = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device).float() # ensure fp32 - cluster_sizes = torch.zeros(self.codebook_size, device=encodings.device, dtype=torch.float32) # ensure fp32 - + embed = torch.zeros( + self.codebook_size, self.codebook_dim, device=encodings.device + ).float() # ensure fp32 + cluster_sizes = torch.zeros( + self.codebook_size, device=encodings.device, dtype=torch.float32 + ) # ensure fp32 + # Broadcast results if dist.is_initialized(): dist.broadcast(embed, src=0) dist.broadcast(cluster_sizes, src=0) - + self.codebook.copy_(embed) # 
(codebook_size, D') self.embed_avg.copy_(embed.clone()) # (codebook_size, D') self.cluster_size.copy_(cluster_sizes.float()) # (codebook_size) @@ -155,32 +200,39 @@ def forward(self, z): # z: (B, D, T) # logging.info(f"{self.cluster_size = }, {self.codebook = }, {self.embed_avg = }, {self.inited = }") z = z.float() # Ensure fp32 z_e = self.in_project(z).float() # (B, D', T), ensure fp32 - + # Rearrange for quantization encodings = rearrange(z_e, "b d t -> (b t) d").float() # (B*T, D'), ensure fp32 - + # Initialize codebook if needed if self.kmeans_init and not self.inited.item(): self.init_codebook(encodings) # Quantization - dist = (encodings.pow(2).sum(1, keepdim=True) - # (B*T, 1) - 2 * encodings @ self.codebook.float().t() + # (B*T, codebook_size) - self.codebook.float().pow(2).sum(1, keepdim=True).t()) # (1, codebook_size) + dist = ( + encodings.pow(2).sum(1, keepdim=True) # (B*T, 1) + - 2 * encodings @ self.codebook.float().t() # (B*T, codebook_size) + + self.codebook.float().pow(2).sum(1, keepdim=True).t() + ) # (1, codebook_size) # dist: (B*T, codebook_size) - + indices = (-dist).max(1)[1] # (B*T) indices = rearrange(indices, "(b t) -> b t", b=z.size(0)) # (B, T) - + # Get quantized vectors z_q = self.decode_code(indices).float() # (B, D', T), ensure fp32 - + # Commitment loss - commit_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) * self.commitment # (B) - + commit_loss = ( + F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + * self.commitment + ) # (B) + # EMA updates and dead code replacement during training if self.training and torch.is_grad_enabled(): - embed_onehot = F.one_hot(indices.view(-1), self.codebook_size).float() # (B*T, codebook_size), ensure fp32 + embed_onehot = F.one_hot( + indices.view(-1), self.codebook_size + ).float() # (B*T, codebook_size), ensure fp32 self.ema_update(encodings, embed_onehot) self.replace_dead_codes(encodings) @@ -188,20 +240,29 @@ def forward(self, z): # z: (B, D, T) z_q = z_e + (z_q - z_e).detach() # (B, D', T) z_q = self.out_project(z_q).float() # (B, D, T), ensure fp32 - return z_q, commit_loss, torch.tensor(0.0, device=z.device, dtype=torch.float32), indices, z # (B, D, T), (B), scalar, (B, T), (B, D', T) + return ( + z_q, + commit_loss, + torch.tensor(0.0, device=z.device, dtype=torch.float32), + indices, + z, + ) # (B, D, T), (B), scalar, (B, T), (B, D', T) def decode_code(self, embed_id): # embed_id: (B, T) - return F.embedding(embed_id, self.codebook).transpose(1, 2).float() # (B, D', T), ensure fp32 + return ( + F.embedding(embed_id, self.codebook).transpose(1, 2).float() + ) # (B, D', T), ensure fp32 + class ResidualVQ(nn.Module): def __init__( self, - input_dim: int = 1280, # Input dimension, unrelated to RVQ - rvq_dim = None, # RVQ dimension. If different from input_dim/output_dim, will add input_dim->rvq_dim/rvq_dim->output_dim projection - output_dim: int = None, # Output dimension, unrelated to RVQ + input_dim: int = 1280, # Input dimension, unrelated to RVQ + rvq_dim=None, # RVQ dimension. If different from input_dim/output_dim, will add input_dim->rvq_dim/rvq_dim->output_dim projection + output_dim: int = None, # Output dimension, unrelated to RVQ num_quantizers: int = 32, codebook_size: int = 1024, - codebook_dim: int = 8, # Dimension of each codebook. If different from rvq_dim, will add rvq_dim->codebook_dim and codebook_dim->rvq_dim projections + codebook_dim: int = 8, # Dimension of each codebook. 
If different from rvq_dim, will add rvq_dim->codebook_dim and codebook_dim->rvq_dim projections quantizer_dropout: float = 0.5, decay=0.99, epsilon=1e-5, @@ -213,16 +274,24 @@ def __init__( ): super().__init__() self.input_dim = input_dim - + self.num_quantizers = num_quantizers self.codebook_size = codebook_size self.codebook_dim = codebook_dim self.quantizer_dropout = quantizer_dropout self.skip_rvq_ratio = skip_rvq_ratio # Store skip probability self.rvq_dim = rvq_dim - - self.input_proj = WNConv1d(input_dim, rvq_dim, kernel_size=1) if input_dim != rvq_dim else nn.Identity() - self.output_proj = WNConv1d(rvq_dim, output_dim, kernel_size=1) if rvq_dim != output_dim else nn.Identity() + + self.input_proj = ( + WNConv1d(input_dim, rvq_dim, kernel_size=1) + if input_dim != rvq_dim + else nn.Identity() + ) + self.output_proj = ( + WNConv1d(rvq_dim, output_dim, kernel_size=1) + if rvq_dim != output_dim + else nn.Identity() + ) self.quantizers = nn.ModuleList( [ @@ -241,14 +310,22 @@ def __init__( ] ) - def forward(self, z, input_length, n_quantizers: int = None): # z: (B, D, T), input_length: (B) + def forward( + self, z, input_length, n_quantizers: int = None + ): # z: (B, D, T), input_length: (B) z = self.input_proj(z) - with torch.autocast('cuda', enabled = False): + with torch.autocast("cuda", enabled=False): batch_size, _, max_time = z.shape - mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1) # (B, T) - - quantized_out = torch.zeros_like(z, dtype=torch.float32) # (B, D, T), ensure fp32 + mask = torch.arange(max_time, device=z.device).expand( + batch_size, max_time + ) < input_length.unsqueeze( + 1 + ) # (B, T) + + quantized_out = torch.zeros_like( + z, dtype=torch.float32 + ) # (B, D, T), ensure fp32 residual = z.clone().float() # (B, D, T), ensure fp32 all_commit_losses = [] @@ -261,36 +338,66 @@ def forward(self, z, input_length, n_quantizers: int = None): # z: (B, D, T), i skip_mask = None if self.training and torch.is_grad_enabled() and self.skip_rvq_ratio > 0: # Generate random mask with skip_rvq_ratio probability - skip_mask = torch.rand(batch_size, device=z.device) < self.skip_rvq_ratio # (B,) + skip_mask = ( + torch.rand(batch_size, device=z.device) < self.skip_rvq_ratio + ) # (B,) # If all samples are skipped, force the first sample to be unskipped if skip_mask.all(): - skip_mask[0] = False # Ensure at least one sample (index 0) is not skipped + skip_mask[0] = ( + False # Ensure at least one sample (index 0) is not skipped + ) if self.training and torch.is_grad_enabled(): - n_quantizers_tensor = torch.ones((z.shape[0],), dtype=torch.float32, device=z.device) * self.num_quantizers + 1 # (B) - dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],), dtype=torch.float32, device=z.device) # (B) + n_quantizers_tensor = ( + torch.ones((z.shape[0],), dtype=torch.float32, device=z.device) + * self.num_quantizers + + 1 + ) # (B) + dropout = torch.randint( + 1, + self.num_quantizers + 1, + (z.shape[0],), + dtype=torch.float32, + device=z.device, + ) # (B) n_dropout = int(z.shape[0] * self.quantizer_dropout) n_quantizers_tensor[:n_dropout] = dropout[:n_dropout] # (B) else: - n_quantizers_tensor = torch.full((z.shape[0],), n_quantizers, dtype=torch.float32, device=z.device) # (B) + n_quantizers_tensor = torch.full( + (z.shape[0],), n_quantizers, dtype=torch.float32, device=z.device + ) # (B) for i, quantizer in enumerate(self.quantizers): if not self.training and i >= n_quantizers: break masked_residual = residual * 
mask.unsqueeze(1) # (B, D, T) - + # If skipping RVQ, directly use input value if self.training and skip_mask is not None and skip_mask.any(): - z_q_i = torch.zeros_like(masked_residual, dtype=torch.float32) # (B, D, T), ensure fp32 - commit_loss_i = torch.zeros(batch_size, device=z.device, dtype=torch.float32) # (B), ensure fp32 - indices_i = torch.zeros(batch_size, max_time, device=z.device, dtype=torch.long) # (B, T) - z_e_i = torch.zeros_like(masked_residual, dtype=torch.float32) # (B, D, T), ensure fp32 - + z_q_i = torch.zeros_like( + masked_residual, dtype=torch.float32 + ) # (B, D, T), ensure fp32 + commit_loss_i = torch.zeros( + batch_size, device=z.device, dtype=torch.float32 + ) # (B), ensure fp32 + indices_i = torch.zeros( + batch_size, max_time, device=z.device, dtype=torch.long + ) # (B, T) + z_e_i = torch.zeros_like( + masked_residual, dtype=torch.float32 + ) # (B, D, T), ensure fp32 + # Quantize non-skipped samples non_skipped_mask = ~skip_mask # (B) if non_skipped_mask.any(): - z_q_i_non_skipped, commit_loss_i_non_skipped, _, indices_i_non_skipped, z_e_i_non_skipped = quantizer( + ( + z_q_i_non_skipped, + commit_loss_i_non_skipped, + _, + indices_i_non_skipped, + z_e_i_non_skipped, + ) = quantizer( masked_residual[non_skipped_mask].float() # Ensure fp32 ) z_q_i[non_skipped_mask] = z_q_i_non_skipped @@ -298,29 +405,48 @@ def forward(self, z, input_length, n_quantizers: int = None): # z: (B, D, T), i indices_i[non_skipped_mask] = indices_i_non_skipped z_e_i[non_skipped_mask] = z_e_i_non_skipped else: - z_q_i, commit_loss_i, _, indices_i, z_e_i = quantizer(masked_residual.float()) # (B, D, T), (B), scalar, (B, T), (B, D', T), ensure fp32 + z_q_i, commit_loss_i, _, indices_i, z_e_i = quantizer( + masked_residual.float() + ) # (B, D, T), (B), scalar, (B, T), (B, D', T), ensure fp32 + + quantizer_mask = ( + torch.full((z.shape[0],), i, device=z.device, dtype=torch.float32) + < n_quantizers_tensor + ) # (B) + update_mask = (mask & quantizer_mask.unsqueeze(-1)).unsqueeze( + 1 + ) # (B, 1, T) - quantizer_mask = (torch.full((z.shape[0],), i, device=z.device, dtype=torch.float32) < n_quantizers_tensor) # (B) - update_mask = (mask & quantizer_mask.unsqueeze(-1)).unsqueeze(1) # (B, 1, T) - # If skipping, output is directly the input if skip_mask is not None: - skip_mask_expanded = skip_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1) - z_q_i = torch.where(skip_mask_expanded, masked_residual, z_q_i) # (B, D, T) - commit_loss_i = torch.where(skip_mask, torch.zeros_like(commit_loss_i), commit_loss_i) # (B) + skip_mask_expanded = skip_mask.unsqueeze(1).unsqueeze( + 2 + ) # (B, 1, 1) + z_q_i = torch.where( + skip_mask_expanded, masked_residual, z_q_i + ) # (B, D, T) + commit_loss_i = torch.where( + skip_mask, torch.zeros_like(commit_loss_i), commit_loss_i + ) # (B) quantized_out = quantized_out + z_q_i * update_mask # (B, D, T) residual_fp32 = residual.to(dtype=torch.float32) # (B, D, T) z_q_i_fp32 = z_q_i.to(dtype=torch.float32) # (B, D, T) residual_fp32 = residual_fp32 - z_q_i_fp32 * update_mask # (B, D, T) - residual = residual_fp32.to(dtype=torch.float32) # (B, D, T), ensure fp32 + residual = residual_fp32.to( + dtype=torch.float32 + ) # (B, D, T), ensure fp32 valid_mask = mask & quantizer_mask.unsqueeze(-1) # (B, T) if valid_mask.any(): - commit_loss_i = (commit_loss_i * quantizer_mask).sum() / quantizer_mask.sum() # scalar + commit_loss_i = ( + commit_loss_i * quantizer_mask + ).sum() / quantizer_mask.sum() # scalar else: - commit_loss_i = torch.tensor(0.0, device=z.device, 
dtype=torch.float32) # scalar, ensure fp32 + commit_loss_i = torch.tensor( + 0.0, device=z.device, dtype=torch.float32 + ) # scalar, ensure fp32 all_commit_losses.append(commit_loss_i) # scalar all_indices.append(indices_i) # (B, T) @@ -335,16 +461,16 @@ def forward(self, z, input_length, n_quantizers: int = None): # z: (B, D, T), i quantized_out = self.output_proj(quantized_out) return ( - quantized_out, # (B, D, T) - all_indices, # (N, B, T) - all_commit_losses,# (N) - all_quantized, # (N, B, D, T) - output_length, # (B) + quantized_out, # (B, D, T) + all_indices, # (N, B, T) + all_commit_losses, # (N) + all_quantized, # (N, B, D, T) + output_length, # (B) ) def decode_codes(self, codes): # codes: (nq, B, T) """Decode codes from multiple quantizers to embeddings. - + Args: codes: Tensor of shape (nq, B, T) containing code indices for each quantizer. @@ -353,7 +479,9 @@ def decode_codes(self, codes): # codes: (nq, B, T) """ nq, B, T = codes.shape device = codes.device - emb = torch.zeros(B, self.rvq_dim, T, device=device, dtype=torch.float32) # (B, D, T) + emb = torch.zeros( + B, self.rvq_dim, T, device=device, dtype=torch.float32 + ) # (B, D, T) for i, quantizer in enumerate(self.quantizers[:nq]): code_i = codes[i] # (B, T) @@ -361,10 +489,10 @@ def decode_codes(self, codes): # codes: (nq, B, T) emb += quantized_i # Accumulate quantized embeddings emb = self.output_proj(emb) # (B, D, T), apply output projection - return emb # (B, D, T) + return emb # (B, D, T) def ema_inplace(moving_avg, new, decay): # moving_avg: (codebook_size) or (codebook_size, D'), new: same as moving_avg """Update exponential moving average in-place""" - moving_avg.data.mul_(decay).add_(new.float(), alpha=(1 - decay)) # ensure fp32 \ No newline at end of file + moving_avg.data.mul_(decay).add_(new.float(), alpha=(1 - decay)) # ensure fp32 diff --git a/examples/example.txt b/examples/example.txt index 0d4b3bd..db4779f 100644 --- a/examples/example.txt +++ b/examples/example.txt @@ -1 +1 @@ -MOSS-TTSD (text to spoken dialogue) is an open-source bilingual spoken dialogue synthesis model that supports both Chinese and English. It can transform dialogue scripts between two speakers into natural, expressive conversational speech. MOSS-TTSD supports voice cloning and single-session speech generation of up to 960 seconds, making it ideal for AI podcast production. \ No newline at end of file +MOSS-TTSD (text to spoken dialogue) is an open-source bilingual spoken dialogue synthesis model that supports both Chinese and English. It can transform dialogue scripts between two speakers into natural, expressive conversational speech. MOSS-TTSD supports voice cloning and single-session speech generation of up to 960 seconds, making it ideal for AI podcast production. 
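The quantizer hunks above are essentially a formatting pass (import regrouping and line wrapping); the EMA codebook maintenance itself is unchanged. As a quick reference for what `VectorQuantize.ema_update` and the `ema_inplace` helper do, here is a minimal standalone sketch. It is not code from the patch: the tensor shapes follow the comments in `quantizer.py`, while the toy sizes and the helper name `ema_codebook_update` are assumptions for illustration, and the distributed `all_reduce` step is omitted.

```python
import torch
import torch.nn.functional as F


def ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, decay: float) -> None:
    # Exponential moving average, updated in place (mirrors the helper in quantizer.py).
    moving_avg.data.mul_(decay).add_(new.float(), alpha=1 - decay)


def ema_codebook_update(codebook, embed_avg, cluster_size, encodings, embed_onehot,
                        decay: float = 0.99, epsilon: float = 1e-5) -> None:
    # encodings: (N, D'), embed_onehot: (N, codebook_size) one-hot code assignments.
    cluster_size_new = embed_onehot.sum(0)              # (codebook_size,) hits per code
    embed_sum = encodings.t() @ embed_onehot            # (D', codebook_size) summed vectors

    ema_inplace(cluster_size, cluster_size_new, decay)  # running usage counts
    ema_inplace(embed_avg, embed_sum.t(), decay)        # running sum of assigned vectors

    # Laplace smoothing keeps rarely used codes from dividing by zero.
    total = cluster_size.sum()
    smoothed = (cluster_size + epsilon) / (total + cluster_size.numel() * epsilon) * total
    codebook.copy_(embed_avg / smoothed.unsqueeze(1))   # refreshed code vectors


# Toy usage with assumed sizes: 4 codes of dimension 2, 8 encoded frames.
codebook = torch.randn(4, 2)
embed_avg = codebook.clone()
cluster_size = torch.zeros(4)
encodings = torch.randn(8, 2)
indices = torch.randint(0, 4, (8,))
ema_codebook_update(codebook, embed_avg, cluster_size,
                    encodings, F.one_hot(indices, 4).float())
```

The smoothing term plays the same role as in `ema_update`: it keeps codes with near-zero usage numerically stable until `replace_dead_codes` resamples them from the current batch.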
diff --git a/examples/examples.jsonl b/examples/examples.jsonl index c22a263..4a7b346 100644 --- a/examples/examples.jsonl +++ b/examples/examples.jsonl @@ -1,2 +1,2 @@ {"base_path":"examples","text": "[S1]诶,我最近看了一篇讲人工智能的文章,还挺有意思的,想跟你聊聊。[S2]哦?是吗,关于啥的啊?又是哪个公司发了什么逆天的新模型吗?[S1]那倒不是,是一个咱们国内的教授,复旦大学的邱锡鹏教授,他提了一个新概念,叫什么,呃,叫情境扩展,Context Scaling。[S2]Context Scaling?情境扩展?听起来有点,呃,有点玄乎啊,这是个啥意思?[S1]对,我一开始也觉得有点抽象,但你看完就觉得,诶,特别有道理。他大概意思就是说啊,咱们现在对人工智能的追求,不能光是把它做得更大,你知道吧,就是不能光堆参数,喂数据。[S2]嗯,是,这个我懂。就好像之前大家都在比谁的模型参数多,几千亿,上万亿的。[S1]对对对,就是那个意思。他说那个时代,算是第一幕,就是模型规模化的胜利,靠堆料,堆出了像这个ChatGPT这样厉害的通用模型。[S2]嗯,是的。[S1]然后呢,现在差不多是第二幕,就是大家发现光堆料好像不行了,收益越来越小,就开始搞一些,呃,后训练的优化。[S2]哦,后训练优化?比如呢?[S1]比如让AI学会用工具,或者搞那个什么思维链,让它像人一样一步一步思考问题,还有就是强化学习,让它自己在游戏里或者写代码的时候自己跟自己玩,然后越变越强。[S2]哦,明白了。就是让它不光是知识多,还得会用,变得更聪明,是这个意思吧。[S1]没错,就是这个理儿。但是呢,这两步走完了,就遇到了新的瓶颈。邱教授就觉得,关键问题出在了这个情境,也就是Context上。[S2]嗯?情境?[S1]是。很多时候AI做不出正确的决定,不是因为它笨,而是因为它没搞明白现在到底是个什么情况,就是我们给它的任务或者情境描述得不够清楚。[S2]哦,原来是这样。[S1]对。所以他觉得下一幕,也就是第三幕,就应该是这个情境扩展,Context Scaling。核心就是要让AI能真正理解那种特别复杂、多变,甚至有点模糊的真实世界的情境。[S2]嗯,这个听起来就高级了。那他说的这个情境,到底指什么啊?不就是我们输进去的那段话吗?[S1]诶,不是那么简单的。他说的情境啊,是个特别立体的东西。它不光包括你说了什么,还包括,呃,时间,地点,谁在说话,目的是什么,甚至还包括咱们文化里那种只可意会不可言传的规则,还有人与人之间的那种默契。[S2]哇,那这个范围可太广了。[S1]是吧。他举了个例子,就比如说,当一个人说不要的时候。[S2]嗯。[S1]你想想,在不同的情况下,这个不要的意思完全不一样。有可能是真的拒绝,也可能是在开玩笑,甚至可能是一种反向的请求,就是我们说的口是心非,嘴上说不要,身体很诚实那种。[S2]哈哈,确实确实,这个AI可怎么判断啊。[S1]对啊,所以就需要理解整个情境。他说这个里面最关键的,就是要捕获一种叫暗知识的东西。[S2]暗知识?听着像武林秘籍一样。[S1]有点那个意思。就是指我们人类会,但是很难用语言清楚地讲出来的那些能力。[S2]哦?比如说?[S1]比如说社交的智慧啊,怎么通过一个眼神,一个停顿,或者语气变化来理解对方的意思。[S2]嗯,是的。[S1]还有就是文化适应能力,在不同的文化里,有些事能做,有些事不能做,这些都没人写在书上,但我们就是知道。[S2]没错。[S1]他说AI要是能学会这些,那才算是真正的智能突破了。这也能解决一些AI安全的问题,比如那个很有名的回形针悖论。[S2]哦,那个我知道,就是你让AI造回形针,它最后可能会为了这个目标把整个世界都给占了,用来造回形针。[S1]对。但如果它有情境智能,它就能理解我们人类社会的复杂性,知道有些事是不能做的,就算你没有明确下指令禁止它。[S2]有道理,有道理。那,那技术上要怎么实现呢?听起来好难啊。[S1]是挺难的。他提了三大技术支柱。第一个叫强交互性。就是AI要不停地跟环境,特别是跟人来互动学习,不光要知道怎么做,还要知道为什么这么做。[S2]嗯,在互动里学习。[S1]第二个叫具身性。就是AI得有个身体,不一定是真的人形机器人,在虚拟世界里也行,总之它得能感知和行动。[S2]哦,明白。[S1]第三个,我觉得是最特别的,叫拟人化。[S2]拟人化?[S1]对,就是说AI需要有类似人类的情感共鸣能力。不是假装有感情,而是真正理解人的情绪,人的偏好,懂得社交的距离感,什么时候该关心,什么时候该保持沉默。[S2]哇,这个要求可太高了,感觉比养个孩子还难。[S1]哈哈,可不是嘛。所以他说这事儿吧,不是要去替代其他的技术路线,而是把它们都整合起来。像什么推理增强啊,多模态啊,强化学习啊,最终都是为了服务于一个目标,就是让AI能深刻地理解情境。[S2]嗯。[S1]所以总的来说,就是别光在已有的路上卷了,要去定义一些大家都觉得重要,但没人说清楚的问题。[S2]确实。听你这么一说,感觉这个情境扩展,Context Scaling,确实给人工智能指了一个挺不一样的方向。不再是冷冰冰的计算,而是要变得更懂我们人类,更懂这个复杂的世界。[S1]是啊,他说这可能是通往通用人工智能,就是AGI,非常关键的一步。[S2]嗯,听起来未来可期啊。","prompt_audio_speaker1":"zh_spk1_moon.wav","prompt_text_speaker1": "周一到周五,每天早晨七点半到九点半的直播片段。言下之意呢,就是废话有点多,大家也别嫌弃,因为这都是直播间最真实的状态了。","prompt_audio_speaker2":"zh_spk2_moon.wav","prompt_text_speaker2": "如果大家想听到更丰富更及时的直播内容,记得在周一到周五准时进入直播间,和大家一起畅聊新消费新科技新趋势。"} -{"base_path":"examples","text": "[S1]Hey, did you hear about that company called MoSi AI? [S2]MoSi AI? Yeah, I think I've heard of them. Aren't they the ones doing AI stuff? What new thing have they come up with now? [S1]Yeah, that's them! They recently launched this super hot new product called, um, Asteroid. [S2]Asteroid. That's a pretty cool name. Does it mean like the space rock? [S1]Yeah, I think that's what it means. Let me tell you, this thing is incredible. They say it's currently the most realistic, human-like conversational TTS model out there. [S2]Oh, TTS technology? You mean the text-to-speech thing? Aren't there already a lot of those on the market? What makes this one so special? [S1]Well, it's completely different. They say the voice produced by Asteroid sounds almost exactly like a real person talking. And it's super smooth and natural. Not at all like, you know, that stiff robotic tone. [S2]I see. 
Some voice assistants do still have that mechanical feel, especially during multi-turn conversations. So how amazing is this Asteroid exactly? [S1]I heard they internally call Asteroid China's own version of NotebookLM. [S2]NotebookLM? Oh, I know that one. Isn't that the personal AI that Google made? The one that helps organize notes and answers all kinds of questions? So Asteroid has similar functions? [S1]Right. That's probably what they mean. It's not just that the voice sounds incredibly human. The intelligence level is also really high. It can have these really logical, contextual, in-depth conversations with you. It's just like chatting with a real person. [S2]Wow, that sounds amazing. If they can really achieve that... [S1]Yeah, it's basically like having a personal assistant that's both articulate and really understands you. [S2]Hmm. That does sound appealing. [S1]And some people are saying it's like the, what's it called again in the voice technology circle? Oh right, DeepSeek. [S2]DeepSeek? Isn't that the company making large language models? Their models are pretty popular now. That's high praise. So they're saying Asteroid is top-tier technology? [S1]Yeah, I think that's what they mean. It's like they've reached a whole new level in voice synthesis. Similar to the impact DeepSeek has had in natural language processing. It could be that kind of groundbreaking technology. [S2]If Asteroid is really that impressive, where could it be used? I feel like there must be huge potential there. [S1]Absolutely. Just imagine future smart customer service, audiobook reading, and those virtual livestreamers that are so popular now. The quality would improve dramatically. We might even have personal assistants using Asteroid to talk to us directly. How natural would that be? [S2]Yeah. That does sound exciting. When can we actually try it out? Are there any demos available? [S1]I haven't looked into that carefully yet. But since they've already announced it, I'm guessing it won't be long. I'm really eager to try it and see just how human-like it is. [S2]Yeah, yeah. If it can really deliver what they're promising, getting information and interacting with machines will be so much more convenient. The experience will be much better too. [S1]Exactly, exactly. We're just waiting for MoSi AI to give us this big surprise.", "prompt_audio_speaker1":"m1.wav","prompt_text_speaker1": "How much do you know about her?","prompt_audio_speaker2":"m2.wav","prompt_text_speaker2": "Well, we know this much about her. You've been with her constantly since the first day you met her. And we followed you while you went dining, dancing, and sailing. And last night, I happened to be there when you were having dinner with her at Le Petit Tableau."} \ No newline at end of file +{"base_path":"examples","text": "[S1]Hey, did you hear about that company called MoSi AI? [S2]MoSi AI? Yeah, I think I've heard of them. Aren't they the ones doing AI stuff? What new thing have they come up with now? [S1]Yeah, that's them! They recently launched this super hot new product called, um, Asteroid. [S2]Asteroid. That's a pretty cool name. Does it mean like the space rock? [S1]Yeah, I think that's what it means. Let me tell you, this thing is incredible. They say it's currently the most realistic, human-like conversational TTS model out there. [S2]Oh, TTS technology? You mean the text-to-speech thing? Aren't there already a lot of those on the market? What makes this one so special? [S1]Well, it's completely different. 
They say the voice produced by Asteroid sounds almost exactly like a real person talking. And it's super smooth and natural. Not at all like, you know, that stiff robotic tone. [S2]I see. Some voice assistants do still have that mechanical feel, especially during multi-turn conversations. So how amazing is this Asteroid exactly? [S1]I heard they internally call Asteroid China's own version of NotebookLM. [S2]NotebookLM? Oh, I know that one. Isn't that the personal AI that Google made? The one that helps organize notes and answers all kinds of questions? So Asteroid has similar functions? [S1]Right. That's probably what they mean. It's not just that the voice sounds incredibly human. The intelligence level is also really high. It can have these really logical, contextual, in-depth conversations with you. It's just like chatting with a real person. [S2]Wow, that sounds amazing. If they can really achieve that... [S1]Yeah, it's basically like having a personal assistant that's both articulate and really understands you. [S2]Hmm. That does sound appealing. [S1]And some people are saying it's like the, what's it called again in the voice technology circle? Oh right, DeepSeek. [S2]DeepSeek? Isn't that the company making large language models? Their models are pretty popular now. That's high praise. So they're saying Asteroid is top-tier technology? [S1]Yeah, I think that's what they mean. It's like they've reached a whole new level in voice synthesis. Similar to the impact DeepSeek has had in natural language processing. It could be that kind of groundbreaking technology. [S2]If Asteroid is really that impressive, where could it be used? I feel like there must be huge potential there. [S1]Absolutely. Just imagine future smart customer service, audiobook reading, and those virtual livestreamers that are so popular now. The quality would improve dramatically. We might even have personal assistants using Asteroid to talk to us directly. How natural would that be? [S2]Yeah. That does sound exciting. When can we actually try it out? Are there any demos available? [S1]I haven't looked into that carefully yet. But since they've already announced it, I'm guessing it won't be long. I'm really eager to try it and see just how human-like it is. [S2]Yeah, yeah. If it can really deliver what they're promising, getting information and interacting with machines will be so much more convenient. The experience will be much better too. [S1]Exactly, exactly. We're just waiting for MoSi AI to give us this big surprise.", "prompt_audio_speaker1":"m1.wav","prompt_text_speaker1": "How much do you know about her?","prompt_audio_speaker2":"m2.wav","prompt_text_speaker2": "Well, we know this much about her. You've been with her constantly since the first day you met her. And we followed you while you went dining, dancing, and sailing. And last night, I happened to be there when you were having dinner with her at Le Petit Tableau."} diff --git a/examples/examples_only_text.jsonl b/examples/examples_only_text.jsonl index 9255952..1758785 100644 --- a/examples/examples_only_text.jsonl +++ b/examples/examples_only_text.jsonl @@ -1 +1 @@ -{"text": "[S1]Hey, did you hear about that company called MoSi AI? [S2]MoSi AI? Yeah, I think I've heard of them. Aren't they the ones doing AI stuff? What new thing have they come up with now? [S1]Yeah, that's them! They recently launched this super hot new product called, um, Asteroid. [S2]Asteroid. That's a pretty cool name. Does it mean like the space rock? 
[S1]Yeah, I think that's what it means. Let me tell you, this thing is incredible. They say it's currently the most realistic, human-like conversational TTS model out there. [S2]Oh, TTS technology? You mean the text-to-speech thing? Aren't there already a lot of those on the market? What makes this one so special? [S1]Well, it's completely different. They say the voice produced by Asteroid sounds almost exactly like a real person talking. And it's super smooth and natural. Not at all like, you know, that stiff robotic tone."} \ No newline at end of file +{"text": "[S1]Hey, did you hear about that company called MoSi AI? [S2]MoSi AI? Yeah, I think I've heard of them. Aren't they the ones doing AI stuff? What new thing have they come up with now? [S1]Yeah, that's them! They recently launched this super hot new product called, um, Asteroid. [S2]Asteroid. That's a pretty cool name. Does it mean like the space rock? [S1]Yeah, I think that's what it means. Let me tell you, this thing is incredible. They say it's currently the most realistic, human-like conversational TTS model out there. [S2]Oh, TTS technology? You mean the text-to-speech thing? Aren't there already a lot of those on the market? What makes this one so special? [S1]Well, it's completely different. They say the voice produced by Asteroid sounds almost exactly like a real person talking. And it's super smooth and natural. Not at all like, you know, that stiff robotic tone."} diff --git a/examples/examples_single_reference.jsonl b/examples/examples_single_reference.jsonl index 6f7fe10..64a3ff7 100644 --- a/examples/examples_single_reference.jsonl +++ b/examples/examples_single_reference.jsonl @@ -1 +1 @@ -{"base_path":"examples","text": "[S1]诶,跟你说个事儿啊,我最近听了不少那种AI生成的播客,不知道你有没有听过。[S2]哦,听过一些。怎么了,感觉怎么样?[S1]就是……怎么说呢,单听一句话,你觉得,哇,好像跟真人没啥区别。[S2]嗯。[S1]但是,你只要让它说上一段完整的对话,比如俩人聊天那种,那个感觉就立马不对了。[S2]对对对,我懂你的意思。就是那个所谓的“恐怖谷”效应,是吧?听着有点瘆人,感觉特别假,没有那个交流感。[S1]就是这个词儿,恐怖谷。结果你猜怎么着,这个事儿最近好像有救了。[S2]哦?什么情况?[S1]复旦大学那个邱锡鹏教授的团队,就是那个OpenMOSS团队,他们搞了个新东西,叫Moss TTSD。[S2]是吗?专门解决这个问题的?[S1]对,就是专门搞对话语音合成的。他们说,这个模型是基于一百万小时的音频训练出来的,号称要打破AI播客的恐怖谷魔咒。[S2]一百万小时,我的天,这数据量也太吓人了。[S1]是吧。而且最牛的是,这个模型,Moss TTSD,代码和模型权重,全都开源了,还能免费商用。[S2]哇,那可太棒了。意思就是,它不是生成一句话一句话的,而是把一整段多人对话的稿子给它,它直接生成一整段特逼真的对话录音?[S1]没错,就是这个意思。它能把对话里那种节奏啊、语气的变化啊,都给模拟出来。[S2]那效果到底怎么样?有对比吗?[S1]有啊。他们拿现在一个挺火的商业模型,叫豆包的播客生成功能,做了个对比。[S2]哦,豆包我知道,挺多人在用。[S1]结果呢,人家Moss TTSD作为一个开源模型,在情感丰富度啊、语气自然度这些方面,跟豆包比起来,基本上不相上下。[S2]真的假的?一个开源的能做到跟商业模型一个水平,那可太厉害了。[S1]是啊。我听了他们放出来的几段对比音频,确实很惊艳。[S2]那它技术上到底是怎么实现的?肯定有啥独门秘诀吧。[S1]嘿,你还真说对了。他们的核心创新,是一个叫XY Tokenizer的东西。[S2]XY Tokenizer?这是个啥玩意儿?[S1]它算是一个特别聪明的语音编码器。[S2]嗯。[S1]就是说,它在处理声音的时候,不光能明白你说的内容是啥意思,也就是语义信息,还能同时把你说话的那个味儿,那个声学信息,一块儿给编码进去。[S2]哦,我好像有点明白了。就是说,它不光知道“说的是什么”,还知道“是怎么说的”。[S1]对!而且它把这些信息压缩得特别特别小,比特率只有一kbps。这样一来呢,大语言模型学起来就轻松多了,能把声音里那些特别细微的特征都学到。[S2]原来是这样。那数据处理呢?你说那一百万小时的音频,肯定不能直接用吧,得要特别干净的数据才行。[S1]这个他们也下功夫了。他们自己搞了一套效率特别高的数据处理流水线。[S2]嗯。[S1]先用自己内部的一个模型,把一段录音里不同的人声给分开。他们说这个模型比现在市面上开源的、甚至一些商用的都要好。[S2]哇。[S1]分开之后呢,再用一个叫DNSMOS的工具给语音质量打分,只保留那些分数在二点八分以上的高质量片段。[S2]这筛选标准还挺严格的。[S1]可不是嘛。而且针对中文数据,他们还专门训练了一个模型,叫Whisper d,用来做文本转录。这个模型特别牛,连那种俩人说话声音叠在一起的地方,它都能准确地转写出来。[S2]这个确实是痛点,很多模型都搞不定这个。[S1]所以啊,料下足了,最后出来的东西效果就好。测试结果说,它跟顶尖的闭源模型豆包播客模型性能差不多。[S2]嗯。[S1]然后跟另一个开源模型MoonCast比呢,它的韵律和表现力都更自然。[S2]是的。[S1]最关键的一点来了, 
follow豆包的播客功能比,它除了效果差不多,还支持一个叫“零样本音色克隆”的功能。[S2]零样本音色克隆?这是说……我给它一段我的声音,它就能用我的声音去生成对话了?[S1]bingo。完全正确。就是这个意思。这样一来,可定制的程度就特别高了。[S2]我的天,那这应用场景可就广了去了。[S1]那可不。还有一个特别大的优点,它能一次性生成特别长的音频,最长能到九百六十秒。[S2]九百六十秒,那不就是十六分钟?[S1]对。一次生成十六分钟,就再也不用一小段一小段生成完了再拼起来了。那个拼接造成的停顿感和不自然,就都解决了。[S2]确实,这个对于做长音频,比如播客啊、影视配音啊、有声书什么的,简直是刚需。[S1]是啊。所以说,这个Moss TTSD一出来,感觉以后做播客、搞长篇访谈,甚至数字人直播带货这些事儿,门槛又得降一大截了。[S2]感觉AI播客真的要从“能听”变成“好听”了。对我们内容创作者来说,这绝对是个好工具啊。", "prompt_audio": "single_reference.wav", "prompt_text": "[S1]我意思是说通用计算机是指我们所有人都能用的这种普通计算机的意思啊?[S2]不是,不是。现在通用计算机就,比如说现在买了一台电脑嘛。[S1]嗯。[S2]电脑里面你可以装各种各样的软件,对吧?你可以用来聊微信,你可以用来-[S1]明白,可以做快机用的发票机啊-[S2]对对对对 [S1]...连接各种设备啊什么的。[S2]你只要放不同的软件,[S1]嗯。[S2]它都可以去完成那个软件所定义的相对应的工作。"} \ No newline at end of file +{"base_path":"examples","text": "[S1]诶,跟你说个事儿啊,我最近听了不少那种AI生成的播客,不知道你有没有听过。[S2]哦,听过一些。怎么了,感觉怎么样?[S1]就是……怎么说呢,单听一句话,你觉得,哇,好像跟真人没啥区别。[S2]嗯。[S1]但是,你只要让它说上一段完整的对话,比如俩人聊天那种,那个感觉就立马不对了。[S2]对对对,我懂你的意思。就是那个所谓的“恐怖谷”效应,是吧?听着有点瘆人,感觉特别假,没有那个交流感。[S1]就是这个词儿,恐怖谷。结果你猜怎么着,这个事儿最近好像有救了。[S2]哦?什么情况?[S1]复旦大学那个邱锡鹏教授的团队,就是那个OpenMOSS团队,他们搞了个新东西,叫Moss TTSD。[S2]是吗?专门解决这个问题的?[S1]对,就是专门搞对话语音合成的。他们说,这个模型是基于一百万小时的音频训练出来的,号称要打破AI播客的恐怖谷魔咒。[S2]一百万小时,我的天,这数据量也太吓人了。[S1]是吧。而且最牛的是,这个模型,Moss TTSD,代码和模型权重,全都开源了,还能免费商用。[S2]哇,那可太棒了。意思就是,它不是生成一句话一句话的,而是把一整段多人对话的稿子给它,它直接生成一整段特逼真的对话录音?[S1]没错,就是这个意思。它能把对话里那种节奏啊、语气的变化啊,都给模拟出来。[S2]那效果到底怎么样?有对比吗?[S1]有啊。他们拿现在一个挺火的商业模型,叫豆包的播客生成功能,做了个对比。[S2]哦,豆包我知道,挺多人在用。[S1]结果呢,人家Moss TTSD作为一个开源模型,在情感丰富度啊、语气自然度这些方面,跟豆包比起来,基本上不相上下。[S2]真的假的?一个开源的能做到跟商业模型一个水平,那可太厉害了。[S1]是啊。我听了他们放出来的几段对比音频,确实很惊艳。[S2]那它技术上到底是怎么实现的?肯定有啥独门秘诀吧。[S1]嘿,你还真说对了。他们的核心创新,是一个叫XY Tokenizer的东西。[S2]XY Tokenizer?这是个啥玩意儿?[S1]它算是一个特别聪明的语音编码器。[S2]嗯。[S1]就是说,它在处理声音的时候,不光能明白你说的内容是啥意思,也就是语义信息,还能同时把你说话的那个味儿,那个声学信息,一块儿给编码进去。[S2]哦,我好像有点明白了。就是说,它不光知道“说的是什么”,还知道“是怎么说的”。[S1]对!而且它把这些信息压缩得特别特别小,比特率只有一kbps。这样一来呢,大语言模型学起来就轻松多了,能把声音里那些特别细微的特征都学到。[S2]原来是这样。那数据处理呢?你说那一百万小时的音频,肯定不能直接用吧,得要特别干净的数据才行。[S1]这个他们也下功夫了。他们自己搞了一套效率特别高的数据处理流水线。[S2]嗯。[S1]先用自己内部的一个模型,把一段录音里不同的人声给分开。他们说这个模型比现在市面上开源的、甚至一些商用的都要好。[S2]哇。[S1]分开之后呢,再用一个叫DNSMOS的工具给语音质量打分,只保留那些分数在二点八分以上的高质量片段。[S2]这筛选标准还挺严格的。[S1]可不是嘛。而且针对中文数据,他们还专门训练了一个模型,叫Whisper d,用来做文本转录。这个模型特别牛,连那种俩人说话声音叠在一起的地方,它都能准确地转写出来。[S2]这个确实是痛点,很多模型都搞不定这个。[S1]所以啊,料下足了,最后出来的东西效果就好。测试结果说,它跟顶尖的闭源模型豆包播客模型性能差不多。[S2]嗯。[S1]然后跟另一个开源模型MoonCast比呢,它的韵律和表现力都更自然。[S2]是的。[S1]最关键的一点来了, follow豆包的播客功能比,它除了效果差不多,还支持一个叫“零样本音色克隆”的功能。[S2]零样本音色克隆?这是说……我给它一段我的声音,它就能用我的声音去生成对话了?[S1]bingo。完全正确。就是这个意思。这样一来,可定制的程度就特别高了。[S2]我的天,那这应用场景可就广了去了。[S1]那可不。还有一个特别大的优点,它能一次性生成特别长的音频,最长能到九百六十秒。[S2]九百六十秒,那不就是十六分钟?[S1]对。一次生成十六分钟,就再也不用一小段一小段生成完了再拼起来了。那个拼接造成的停顿感和不自然,就都解决了。[S2]确实,这个对于做长音频,比如播客啊、影视配音啊、有声书什么的,简直是刚需。[S1]是啊。所以说,这个Moss TTSD一出来,感觉以后做播客、搞长篇访谈,甚至数字人直播带货这些事儿,门槛又得降一大截了。[S2]感觉AI播客真的要从“能听”变成“好听”了。对我们内容创作者来说,这绝对是个好工具啊。", "prompt_audio": "single_reference.wav", "prompt_text": "[S1]我意思是说通用计算机是指我们所有人都能用的这种普通计算机的意思啊?[S2]不是,不是。现在通用计算机就,比如说现在买了一台电脑嘛。[S1]嗯。[S2]电脑里面你可以装各种各样的软件,对吧?你可以用来聊微信,你可以用来-[S1]明白,可以做快机用的发票机啊-[S2]对对对对 [S1]...连接各种设备啊什么的。[S2]你只要放不同的软件,[S1]嗯。[S2]它都可以去完成那个软件所定义的相对应的工作。"} diff --git a/finetune/data_preprocess.py b/finetune/data_preprocess.py index a584e2a..f1faba6 100644 --- a/finetune/data_preprocess.py +++ b/finetune/data_preprocess.py @@ -1,13 +1,16 @@ -import json -import torch -import numpy as np import argparse +import json import os import pickle import sys + +import numpy as np +import torch + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from generation_utils import process_jsonl_item, normalize_text, load_audio_data from transformers import AutoTokenizer + +from generation_utils import load_audio_data, normalize_text, process_jsonl_item 
from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer MODEL_PATH = "fnlp/MOSS-TTSD-v0.5" @@ -17,34 +20,53 @@ MAX_CHANNELS = 8 SILENCE_DURATION = 0.0 # Fixed silence duration: 0 seconds + def load_tokenizer(model_path, spt_config_path, spt_checkpoint_path): tokenizer = AutoTokenizer.from_pretrained(model_path) - spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path) + spt = XY_Tokenizer.load_from_checkpoint( + config_path=spt_config_path, ckpt_path=spt_checkpoint_path + ) spt.eval() return tokenizer, spt -def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, reference_audio=None, main_audio=None, max_channels=8, pad_token=1024): + +def process_inputs( + tokenizer, + spt, + prompt, + text, + device, + audio_data=None, + reference_audio=None, + main_audio=None, + max_channels=8, + pad_token=1024, +): # Decompose template into multiple parts # 1. Style prompt part seg1 = f"<|begin_of_style|>{prompt}<|end_of_style|>\n<|begin_of_text|>" inputs1 = np.array(tokenizer.encode(seg1)) inputs_expanded1 = np.full((len(inputs1), max_channels), pad_token) inputs_expanded1[:, 0] = inputs1 - labels1 = np.full(inputs_expanded1.shape, -100) # Style prompt does not compute loss - + labels1 = np.full( + inputs_expanded1.shape, -100 + ) # Style prompt does not compute loss + # 2. Text part text_tokens = tokenizer.encode(text, add_special_tokens=False) inputs2 = np.array(text_tokens) inputs_expanded2 = np.full((len(inputs2), max_channels), pad_token) inputs_expanded2[:, 0] = inputs2 labels2 = np.full(inputs_expanded2.shape, -100) # Text does not compute loss - + # 3. Text end/speech begin part seg3 = f"<|end_of_text|>\n<|begin_of_speech|>" inputs3 = np.array(tokenizer.encode(seg3)) inputs_expanded3 = np.full((len(inputs3), max_channels), pad_token) inputs_expanded3[:, 0] = inputs3 - labels3 = np.full(inputs_expanded3.shape, -100) # Start marker does not compute loss + labels3 = np.full( + inputs_expanded3.shape, -100 + ) # Start marker does not compute loss # 4. 
Audio processing part audio_token = None @@ -54,32 +76,42 @@ def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, refere # Add silence to the end of each audio silence_samples = int(SILENCE_DURATION * 16000) silence = torch.zeros(1, silence_samples) - - # Ensure audio has correct shape [1, samples] + + # Ensure audio has correct shape [1, samples] if len(reference_audio.shape) == 1: reference_audio = reference_audio.unsqueeze(0) if len(main_audio.shape) == 1: main_audio = main_audio.unsqueeze(0) - + # Add silence to each audio ref_audio_with_silence = torch.cat([reference_audio, silence], dim=1) main_audio_with_silence = torch.cat([main_audio, silence], dim=1) with torch.no_grad(): # Encode two audio files separately - ref_encode_result = spt.encode([ref_audio_with_silence.squeeze().to(device)]) - main_encode_result = spt.encode([main_audio_with_silence.squeeze().to(device)]) - - ref_audio_token = ref_encode_result["codes_list"][0].permute(1, 0).cpu().numpy() - main_audio_token = main_encode_result["codes_list"][0].permute(1, 0).cpu().numpy() - + ref_encode_result = spt.encode( + [ref_audio_with_silence.squeeze().to(device)] + ) + main_encode_result = spt.encode( + [main_audio_with_silence.squeeze().to(device)] + ) + + ref_audio_token = ( + ref_encode_result["codes_list"][0].permute(1, 0).cpu().numpy() + ) + main_audio_token = ( + main_encode_result["codes_list"][0].permute(1, 0).cpu().numpy() + ) + # Concatenate at token level - audio_token = np.concatenate([ref_audio_token, main_audio_token], axis=0) + audio_token = np.concatenate( + [ref_audio_token, main_audio_token], axis=0 + ) except Exception as e: print(f"Error processing two audio files: {e}") raise - + elif audio_data is not None: # Original format: single audio processing try: @@ -102,19 +134,19 @@ def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, refere if audio_token is not None: # Add offset (only for the first layer) audio_token[:, 0] = audio_token[:, 0] + 151665 - + # Channel count alignment processing if audio_token.shape[1] > max_channels: audio_token = audio_token[:, :max_channels] elif audio_token.shape[1] < max_channels: padded = np.full((audio_token.shape[0], max_channels), pad_token) - padded[:, :audio_token.shape[1]] = audio_token + padded[:, : audio_token.shape[1]] = audio_token audio_token = padded - + labels4 = audio_token.copy() # Audio tokens need to compute loss else: raise ValueError("No audio data provided") - + # 5. 
Speech end part seg5 = "<|end_of_speech|>" inputs5 = np.array(tokenizer.encode(seg5)) @@ -124,21 +156,25 @@ def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, refere labels5[:, 0] = inputs_expanded5[:, 0] # End marker needs to be learned # Concatenate all parts - input_ids = np.concatenate([ - inputs_expanded1, # Style prompt - inputs_expanded2, # Text - inputs_expanded3, # Speech start marker - audio_token, # Speech tokens (first layer with offset added) - inputs_expanded5 # End marker - ]) - - labels = np.concatenate([ - labels1, # Style prompt (no loss computation) - labels2, # Text (no loss computation) - labels3, # Start marker (no loss computation) - labels4, # Speech tokens (compute loss) - labels5 # End marker (compute loss) - ]) + input_ids = np.concatenate( + [ + inputs_expanded1, # Style prompt + inputs_expanded2, # Text + inputs_expanded3, # Speech start marker + audio_token, # Speech tokens (first layer with offset added) + inputs_expanded5, # End marker + ] + ) + + labels = np.concatenate( + [ + labels1, # Style prompt (no loss computation) + labels2, # Text (no loss computation) + labels3, # Start marker (no loss computation) + labels4, # Speech tokens (compute loss) + labels5, # End marker (compute loss) + ] + ) # Calculate length information total_length = input_ids.shape[0] # Total token length @@ -146,23 +182,25 @@ def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, refere return input_ids, labels, total_length, audio_length + def process_data( - jsonl: str , - model_path: str , - output_dir: str , - data_name: str = "processd_data", - use_normalize: bool = True): + jsonl: str, + model_path: str, + output_dir: str, + data_name: str = "processd_data", + use_normalize: bool = True, +): # Create output directory if it doesn't exist os.makedirs(output_dir, exist_ok=True) - + device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") - + # Load models print("Loading models...") tokenizer, spt = load_tokenizer(model_path, SPT_CONFIG_PATH, SPT_CHECKPOINT_PATH) spt = spt.to(device) - + # Load the items from the JSONL file try: with open(jsonl, "r") as f: @@ -179,139 +217,192 @@ def process_data( all_data = [] offsets = [] tokens_lengths = [] # Added: store total token length - tims_lengths = [] # Added: store audio token length - + tims_lengths = [] # Added: store audio token length + for idx, item in enumerate(items): # Support two JSONL formats # Format 1: {"file_path": "path/to/audio.wav", "full_transcript": "speech content..."} # Format 2: {"reference_audio": "path1", "reference_text": "text1", "audio": "path2", "text": "text2"} - + if "file_path" in item and "full_transcript" in item: # Original format file_path = item["file_path"] full_text = item["full_transcript"] - + # Check if audio file exists if not file_path: print(f"Warning: Item {idx} has empty file_path, skipping...") continue - + if not os.path.exists(file_path): - print(f"Warning: Audio file not found: {file_path}, skipping item {idx}...") + print( + f"Warning: Audio file not found: {file_path}, skipping item {idx}..." + ) continue - + try: # load_audio_data already includes 16kHz mono conversion functionality audio_data = load_audio_data(file_path) except Exception as e: - print(f"Warning: Failed to load audio from {file_path}: {e}, skipping item {idx}...") + print( + f"Warning: Failed to load audio from {file_path}: {e}, skipping item {idx}..." 
+ ) continue - + # Apply text normalization based on parameter if use_normalize: full_text = normalize_text(full_text) - + # Replace speaker tags - final_text = full_text.replace("[S1]", "").replace("[S2]", "") - + final_text = full_text.replace("[S1]", "").replace( + "[S2]", "" + ) + # Process single audio format - input_id, labels, total_length, audio_length = process_inputs(tokenizer, spt, SYSTEM_PROMPT, final_text, device, audio_data, max_channels=MAX_CHANNELS) - - elif "reference_audio" in item and "reference_text" in item and "audio" in item and "text" in item: + input_id, labels, total_length, audio_length = process_inputs( + tokenizer, + spt, + SYSTEM_PROMPT, + final_text, + device, + audio_data, + max_channels=MAX_CHANNELS, + ) + + elif ( + "reference_audio" in item + and "reference_text" in item + and "audio" in item + and "text" in item + ): # New format: requires concatenation reference_audio_path = item["reference_audio"] reference_text = item["reference_text"] audio_path = item["audio"] text = item["text"] - + # Concatenate text full_text = reference_text + text - + # Check if both audio files exist if not reference_audio_path or not audio_path: print(f"Warning: Item {idx} has empty audio paths, skipping...") continue - + if not os.path.exists(reference_audio_path): - print(f"Warning: Reference audio file not found: {reference_audio_path}, skipping item {idx}...") + print( + f"Warning: Reference audio file not found: {reference_audio_path}, skipping item {idx}..." + ) continue - + if not os.path.exists(audio_path): - print(f"Warning: Audio file not found: {audio_path}, skipping item {idx}...") + print( + f"Warning: Audio file not found: {audio_path}, skipping item {idx}..." + ) continue - + try: # load_audio_data already includes 16kHz mono conversion functionality reference_audio = load_audio_data(reference_audio_path) main_audio = load_audio_data(audio_path) - + # Apply text normalization based on parameter if use_normalize: full_text = normalize_text(full_text) - + # Replace speaker tags - final_text = full_text.replace("[S1]", "").replace("[S2]", "") - + final_text = full_text.replace("[S1]", "").replace( + "[S2]", "" + ) + # Pass two separate audio files to process_inputs - input_id, labels, total_length, audio_length = process_inputs(tokenizer, spt, SYSTEM_PROMPT, final_text, device, - reference_audio=reference_audio, main_audio=main_audio, - max_channels=MAX_CHANNELS) - + input_id, labels, total_length, audio_length = process_inputs( + tokenizer, + spt, + SYSTEM_PROMPT, + final_text, + device, + reference_audio=reference_audio, + main_audio=main_audio, + max_channels=MAX_CHANNELS, + ) + except Exception as e: - print(f"Warning: Failed to load audio files: {e}, skipping item {idx}...") + print( + f"Warning: Failed to load audio files: {e}, skipping item {idx}..." + ) continue - + else: - print(f"Warning: Item {idx} missing required fields for both supported formats, skipping...") + print( + f"Warning: Item {idx} missing required fields for both supported formats, skipping..." 
+ ) continue - + # Create data entry containing input_ids and labels data_entry = { - 'input_ids': input_id.tolist(), # shape: [seq_len, 8] - 'labels': labels.tolist() # shape: [seq_len, 8] + "input_ids": input_id.tolist(), # shape: [seq_len, 8] + "labels": labels.tolist(), # shape: [seq_len, 8] } - + all_data.append(data_entry) - tokens_lengths.append(total_length) # Record total length - tims_lengths.append(audio_length) # Record audio length + tokens_lengths.append(total_length) # Record total length + tims_lengths.append(audio_length) # Record audio length + + print( + f"Processed item {idx + 1}/{len(items)}: input_ids shape {input_id.shape}, labels shape {labels.shape}, total_len={total_length}, audio_len={audio_length}" + ) - print(f"Processed item {idx + 1}/{len(items)}: input_ids shape {input_id.shape}, labels shape {labels.shape}, total_len={total_length}, audio_len={audio_length}") - # Save pkl file - serialize one by one output_pkl_path = os.path.join(output_dir, f"{data_name}.pkl") - with open(output_pkl_path, 'wb') as f: + with open(output_pkl_path, "wb") as f: for data_entry in all_data: offsets.append(f.tell()) # Record current position as offset pickle.dump(data_entry, f) # Serialize each data entry separately - + # Save metadata file containing three arrays output_meta_path = os.path.join(output_dir, f"{data_name}_metas.npy") pointers = np.array(offsets) tokens = np.array(tokens_lengths) tims = np.array(tims_lengths) - + # Follow reference code format: stack([pointers, tokens, tims]) np.save(output_meta_path, np.stack([pointers, tokens, tims])) - + print(f"Saved {len(all_data)} processed items to {output_pkl_path}") - print(f"Saved metadata (pointers, tokens_lengths, tims_lengths) to {output_meta_path}") + print( + f"Saved metadata (pointers, tokens_lengths, tims_lengths) to {output_meta_path}" + ) print(f"Total sequences processed: {len(all_data)}") - print(f"Average total length: {np.mean(tokens_lengths):.1f}, Average audio length: {np.mean(tims_lengths):.1f}") + print( + f"Average total length: {np.mean(tokens_lengths):.1f}, Average audio length: {np.mean(tims_lengths):.1f}" + ) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="TTS inference with Asteroid model") - parser.add_argument("--jsonl", type=str, required=True, - help="Path to JSONL file") + parser.add_argument("--jsonl", type=str, required=True, help="Path to JSONL file") parser.add_argument("--model_path", type=str, help="Path to the pre-trained model") - parser.add_argument("--output_dir", type=str, required=True, - help="Output directory for generated audio files") - parser.add_argument("--data_name", default="processed_data", - help="Name of the processed data file (default: processed_data)") - parser.add_argument("--use_normalize", action="store_true", default=False, - help="Whether to use text normalization (default: False)") - + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Output directory for generated audio files", + ) + parser.add_argument( + "--data_name", + default="processed_data", + help="Name of the processed data file (default: processed_data)", + ) + parser.add_argument( + "--use_normalize", + action="store_true", + default=False, + help="Whether to use text normalization (default: False)", + ) + args = parser.parse_args() - + if not args.jsonl: raise ValueError("JSONL file path is required.") elif not os.path.exists(args.jsonl): @@ -324,5 +415,7 @@ def process_data( raise ValueError("Output directory is required.") elif not 
os.path.exists(args.output_dir): os.makedirs(args.output_dir) - - process_data(args.jsonl, args.model_path, args.output_dir, args.data_name, args.use_normalize) + + process_data( + args.jsonl, args.model_path, args.output_dir, args.data_name, args.use_normalize + ) diff --git a/finetune/finetune.py b/finetune/finetune.py index 215609e..d68af08 100644 --- a/finetune/finetune.py +++ b/finetune/finetune.py @@ -1,37 +1,45 @@ import os -import torch -import random import pickle -import numpy as np -from typing import Dict, List +import random +import sys from dataclasses import dataclass +from typing import Dict, List + +import numpy as np +import torch from torch.utils.data import Dataset -import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from modeling_asteroid import AsteroidTTSInstruct +import argparse + +# Import peft related modules +from peft import LoraConfig, PeftModel, TaskType, get_peft_model from transformers import AutoTokenizer +from transformers.tokenization_utils import PreTrainedTokenizer from transformers.trainer import Trainer from transformers.training_args import TrainingArguments -from transformers.tokenization_utils import PreTrainedTokenizer -import argparse -# Import peft related modules -from peft import LoraConfig, get_peft_model, PeftModel, TaskType +from modeling_asteroid import AsteroidTTSInstruct MODEL_PATH = "fnlp/MOSS-TTSD-v0.5" MAX_CHANNELS = 8 + class LazySupervisedDataset(Dataset): def __init__(self, data_dir, channels: int, tokenizer: PreTrainedTokenizer): super(LazySupervisedDataset, self).__init__() self.tokenizer, self.channels = tokenizer, channels - pkls = [os.path.join(data_dir, each) for each in os.listdir(data_dir) if each.endswith('.pkl')] + pkls = [ + os.path.join(data_dir, each) + for each in os.listdir(data_dir) + if each.endswith(".pkl") + ] self.data = [] for pkl_file in pkls: # Load metas file containing three arrays: [pointers, tokens_lengths, tims_lengths] metas = np.load(pkl_file.replace(".pkl", "_metas.npy")) pointers = metas[0] # Extract byte offset position array - + f = open(pkl_file, "rb") for start_pointer in pointers: f.seek(int(start_pointer)) # Ensure integer type @@ -42,12 +50,12 @@ def __init__(self, data_dir, channels: int, tokenizer: PreTrainedTokenizer): def __len__(self): return len(self.data) - + def truncate_and_shift(self, example: Dict[str, List]) -> Dict[str, np.ndarray]: # Read input_ids and labels from data instead of copying input_ids - input_ids = np.array(example["input_ids"])[:, :self.channels] - labels = np.array(example["labels"])[:, :self.channels] # Use labels from data - + input_ids = np.array(example["input_ids"])[:, : self.channels] + labels = np.array(example["labels"])[:, : self.channels] # Use labels from data + seq_len = input_ids.shape[0] new_seq_len = seq_len + self.channels - 1 @@ -59,21 +67,26 @@ def truncate_and_shift(self, example: Dict[str, List]) -> Dict[str, np.ndarray]: for i in range(self.channels): shifted_input_ids[i : (seq_len + i), i] = input_ids[:, i] shifted_labels[i : (seq_len + i), i] = labels[:, i] - + return { "input_ids": shifted_input_ids, "labels": shifted_labels, - "attention_mask": np.ones(new_seq_len) + "attention_mask": np.ones(new_seq_len), } def __getitem__(self, i) -> Dict[str, np.ndarray]: line = self.data[i] - + # Data validation if "input_ids" not in line or "labels" not in line: - raise ValueError(f"Data format error: sample {i} missing 'input_ids' or 'labels' field") - - return self.truncate_and_shift(line) # Return numpy arrays 
for consistency with original code + raise ValueError( + f"Data format error: sample {i} missing 'input_ids' or 'labels' field" + ) + + return self.truncate_and_shift( + line + ) # Return numpy arrays for consistency with original code + @dataclass class DataCollatorForSupervisedDataset: @@ -81,14 +94,16 @@ class DataCollatorForSupervisedDataset: max_length: int filler_token_id: int = 1024 - def __call__(self, instances: List[Dict[str, np.ndarray]]) -> Dict[str, torch.Tensor]: + def __call__( + self, instances: List[Dict[str, np.ndarray]] + ) -> Dict[str, torch.Tensor]: input_ids = [instance["input_ids"] for instance in instances] labels = [instance["labels"] for instance in instances] attention_masks = [instance["attention_mask"] for instance in instances] channels = input_ids[0].shape[1] max_length = min(max(ids.shape[0] for ids in input_ids), self.max_length) padded_input_ids, padded_labels, padded_attns = [], [], [] - + for ids, lbls, attn in zip(input_ids, labels, attention_masks): seq_len = ids.shape[0] if seq_len < max_length: @@ -112,103 +127,128 @@ def __call__(self, instances: List[Dict[str, np.ndarray]]) -> Dict[str, torch.Te return { "input_ids": input_ids, "labels": labels, - "attention_mask": attention_mask + "attention_mask": attention_mask, } -def train(model_path : str, data_dir : str, output_dir : str, training_config : Dict, device: str = "cuda", use_lora: bool = False, lora_cfg: Dict = None): + +def train( + model_path: str, + data_dir: str, + output_dir: str, + training_config: Dict, + device: str = "cuda", + use_lora: bool = False, + lora_cfg: Dict = None, +): print("Initializing tokenizer and model") tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer.padding_side = "left" - + # Load model with CPU offload support model = AsteroidTTSInstruct.from_pretrained( - model_path, - torch_dtype=torch.bfloat16, + model_path, + torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", offload_folder="offload", - offload_state_dict=True + offload_state_dict=True, ) - - model.set_weights([8,2,1,1,1,1,1,1]) + + model.set_weights([8, 2, 1, 1, 1, 1, 1, 1]) model.config.use_cache = False - + # Move model to device model.to(torch.device(device)) - + # Enable gradient checkpointing first (on base model) - if training_config.get('gradient_checkpointing', True): - model.gradient_checkpointing_enable( + if training_config.get("gradient_checkpointing", True): + model.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) print("Gradient checkpointing enabled with use_reentrant=False") - + # Configure LoRA parameters if using LoRA if use_lora: print("Configuring LoRA parameters...") - + # Default LoRA configuration default_lora_config = { - 'r': 16, - 'lora_alpha': 32, - 'target_modules': ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], - 'lora_dropout': 0.05, - 'bias': "none", - 'use_rslora': True + "r": 16, + "lora_alpha": 32, + "target_modules": [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + "lora_dropout": 0.05, + "bias": "none", + "use_rslora": True, } - + # Merge with user provided configuration if lora_cfg: default_lora_config.update(lora_cfg) - + print(f"Using LoRA configuration: {default_lora_config}") - + lora_config = LoraConfig( - r=int(default_lora_config['r']), - lora_alpha=int(default_lora_config['lora_alpha']), - target_modules=default_lora_config['target_modules'], - lora_dropout=float(default_lora_config['lora_dropout']), - 
bias=default_lora_config['bias'], + r=int(default_lora_config["r"]), + lora_alpha=int(default_lora_config["lora_alpha"]), + target_modules=default_lora_config["target_modules"], + lora_dropout=float(default_lora_config["lora_dropout"]), + bias=default_lora_config["bias"], task_type=TaskType.CAUSAL_LM, - use_rslora=bool(default_lora_config['use_rslora']), + use_rslora=bool(default_lora_config["use_rslora"]), ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() print("LoRA configuration completed") - + # Re-enable gradient checkpointing on PEFT model (to ensure compatibility) - if training_config.get('gradient_checkpointing', True): + if training_config.get("gradient_checkpointing", True): # Call base model's method with gradient_checkpointing_kwargs model.base_model.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) - print("Re-enabled gradient checkpointing on LoRA base model with use_reentrant=False") - + print( + "Re-enabled gradient checkpointing on LoRA base model with use_reentrant=False" + ) + # Ensure model is in training mode and verify trainable parameters model.train() trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) if trainable_params == 0: - raise ValueError("No trainable parameters! LoRA configuration might be problematic.") + raise ValueError( + "No trainable parameters! LoRA configuration might be problematic." + ) print(f"Number of trainable parameters: {trainable_params:,}") else: model.train() - + print("Initializing dataloader") train_dataset = LazySupervisedDataset(data_dir, MAX_CHANNELS, tokenizer) data_collator = DataCollatorForSupervisedDataset(tokenizer.pad_token_id, 16000) - + training_args = TrainingArguments( output_dir=output_dir, - per_device_train_batch_size=int(training_config.get('per_device_train_batch_size', 1)), - gradient_accumulation_steps=int(training_config.get('gradient_accumulation_steps', 1)), - num_train_epochs=int(training_config.get('num_train_epochs', 50)), - learning_rate=float(training_config.get('learning_rate', 1e-4)), - bf16=bool(training_config.get('bf16', True)), - logging_steps=int(training_config.get('logging_steps', 10)), - save_steps=int(training_config.get('save_steps', 10)), - save_total_limit=int(training_config.get('save_total_limit', 100)), - dataloader_num_workers=int(training_config.get('dataloader_num_workers', 1)), - warmup_ratio=float(training_config.get('warmup_ratio', 0.1)), - lr_scheduler_type=str(training_config.get('lr_scheduler_type', "cosine")), + per_device_train_batch_size=int( + training_config.get("per_device_train_batch_size", 1) + ), + gradient_accumulation_steps=int( + training_config.get("gradient_accumulation_steps", 1) + ), + num_train_epochs=int(training_config.get("num_train_epochs", 50)), + learning_rate=float(training_config.get("learning_rate", 1e-4)), + bf16=bool(training_config.get("bf16", True)), + logging_steps=int(training_config.get("logging_steps", 10)), + save_steps=int(training_config.get("save_steps", 10)), + save_total_limit=int(training_config.get("save_total_limit", 100)), + dataloader_num_workers=int(training_config.get("dataloader_num_workers", 1)), + warmup_ratio=float(training_config.get("warmup_ratio", 0.1)), + lr_scheduler_type=str(training_config.get("lr_scheduler_type", "cosine")), report_to="tensorboard", logging_dir=os.path.join(output_dir, "logs"), gradient_checkpointing=False, # Already enabled manually on model, don't duplicate @@ -216,9 +256,9 @@ def train(model_path : str, 
data_dir : str, output_dir : str, training_config : remove_unused_columns=False, # Keep all columns dataloader_pin_memory=False, # May help avoid certain CUDA issues save_safetensors=False, - ddp_find_unused_parameters=False + ddp_find_unused_parameters=False, ) - + trainer = Trainer( model=model, args=training_args, @@ -229,13 +269,13 @@ def train(model_path : str, data_dir : str, output_dir : str, training_config : trainer.train() torch.cuda.synchronize() - + # Save model if use_lora: # If using LoRA, merge LoRA weights to base model first, then save complete model print("Merging LoRA weights to base model...") merged_model = model.merge_and_unload() - + # Save the merged complete model with updated method merged_model.save_pretrained(output_dir, safe_serialization=False) print(f"LoRA weights merged and complete model saved to {output_dir}") @@ -243,20 +283,43 @@ def train(model_path : str, data_dir : str, output_dir : str, training_config : # If not using LoRA, save complete model trainer.save_model(output_dir) print(f"Complete model saved to {output_dir}") - + tokenizer.save_pretrained(output_dir) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Finetune Asteroid TTS Instruct Model") parser.add_argument("--model_path", type=str, help="Path to the pre-trained model") - parser.add_argument("--data_dir", type=str, required=True, help="Directory containing the training data") - parser.add_argument("--output_dir", type=str, required=True, help="Output directory for generated audio files") - parser.add_argument("--training_config", type=str, default="finetune/training_config.yaml", - help="Path to the training configuration file") - parser.add_argument("--lora_config", type=str, default="finetune/lora_config.yaml", - help="Path to the LoRA configuration file") - parser.add_argument("--lora", action="store_true", help="Use LoRA (Low-Rank Adaptation) for fine-tuning") - + parser.add_argument( + "--data_dir", + type=str, + required=True, + help="Directory containing the training data", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Output directory for generated audio files", + ) + parser.add_argument( + "--training_config", + type=str, + default="finetune/training_config.yaml", + help="Path to the training configuration file", + ) + parser.add_argument( + "--lora_config", + type=str, + default="finetune/lora_config.yaml", + help="Path to the LoRA configuration file", + ) + parser.add_argument( + "--lora", + action="store_true", + help="Use LoRA (Low-Rank Adaptation) for fine-tuning", + ) + args = parser.parse_args() if not args.model_path: args.model_path = MODEL_PATH @@ -274,26 +337,44 @@ def train(model_path : str, data_dir : str, output_dir : str, training_config : training_config = {} if args.training_config: import yaml + if os.path.exists(args.training_config): - with open(args.training_config, 'r') as f: + with open(args.training_config, "r") as f: training_config = yaml.safe_load(f) - print(f"Successfully loaded training configuration from {args.training_config}: {training_config}") + print( + f"Successfully loaded training configuration from {args.training_config}: {training_config}" + ) else: - print(f"Warning: Configuration file {args.training_config} does not exist, using default parameters.") - + print( + f"Warning: Configuration file {args.training_config} does not exist, using default parameters." 
+ ) + lora_cfg = {} if args.lora and args.lora_config: import yaml + if os.path.exists(args.lora_config): - with open(args.lora_config, 'r') as f: + with open(args.lora_config, "r") as f: lora_cfg = yaml.safe_load(f) - print(f"Successfully loaded LoRA configuration from {args.lora_config}: {lora_cfg}") + print( + f"Successfully loaded LoRA configuration from {args.lora_config}: {lora_cfg}" + ) else: - print(f"Warning: LoRA configuration file {args.lora_config} does not exist, using default LoRA parameters.") - + print( + f"Warning: LoRA configuration file {args.lora_config} does not exist, using default LoRA parameters." + ) + if args.lora: print("Using LoRA fine-tuning mode") else: print("Using full model fine-tuning mode") - - train(args.model_path, args.data_dir, args.output_dir, training_config, device="cuda" if torch.cuda.is_available() else "cpu", use_lora=args.lora, lora_cfg=lora_cfg) \ No newline at end of file + + train( + args.model_path, + args.data_dir, + args.output_dir, + training_config, + device="cuda" if torch.cuda.is_available() else "cpu", + use_lora=args.lora, + lora_cfg=lora_cfg, + ) diff --git a/finetune/finetune_config.yaml b/finetune/finetune_config.yaml index 9cb7a1f..85e24e8 100644 --- a/finetune/finetune_config.yaml +++ b/finetune/finetune_config.yaml @@ -1,9 +1,9 @@ path_to_jsonl : # Path to the training data in JSONL format data_output_directory : # Directory where the processed data will be saved -data_name : # Name of the dataset +data_name : # Name of the dataset use_normalize : # Whether to normalize the data (true/false) path_to_model : # Path to the pre-trained model (leave empty to use default HuggingFace model) finetuned_model_output : # Directory where the finetuned model will be saved training_config_file : finetune/training_config.yaml # Path to the training configuration file use_lora : # Whether to use LoRA fine-tuning (true/false) -lora_config_file : finetune/lora_config.yaml # Path to the LoRA configuration file \ No newline at end of file +lora_config_file : finetune/lora_config.yaml # Path to the LoRA configuration file diff --git a/finetune/finetune_workflow.py b/finetune/finetune_workflow.py index 4adc4b7..e739e7b 100644 --- a/finetune/finetune_workflow.py +++ b/finetune/finetune_workflow.py @@ -1,9 +1,11 @@ -import yaml import argparse +import os + +import torch +import yaml from data_preprocess import process_data + import finetune -import torch -import os # Set environment variable to disable tokenizers parallelism warning os.environ["TOKENIZERS_PARALLELISM"] = "true" @@ -12,78 +14,100 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Finetune Asteroid TTS Instruct Model") - parser.add_argument("-c","--config", type=str, default="./finetune/finetune_config.yaml", help="Path to the finetune workflow configuration file") - parser.add_argument("-pd","--pass_data_preprocess", action="store_true", default=False, help="Skip data preprocess step and proceed directly to fine-tuning") + parser.add_argument( + "-c", + "--config", + type=str, + default="./finetune/finetune_config.yaml", + help="Path to the finetune workflow configuration file", + ) + parser.add_argument( + "-pd", + "--pass_data_preprocess", + action="store_true", + default=False, + help="Skip data preprocess step and proceed directly to fine-tuning", + ) args = parser.parse_args() if not os.path.exists(args.config): raise ValueError(f"Configuration file '{args.config}' does not exist.") else: - with open(args.config, 'r') as f: + with open(args.config, "r") 
as f: config = yaml.safe_load(f) if not args.pass_data_preprocess: - if not config.get('path_to_jsonl'): + if not config.get("path_to_jsonl"): raise ValueError("JSONL file path is required in the configuration.") - elif not os.path.exists(config['path_to_jsonl']): + elif not os.path.exists(config["path_to_jsonl"]): raise ValueError(f"JSONL file '{config['path_to_jsonl']}' does not exist.") - if not config.get('path_to_model'): - config['path_to_model'] = DEFAULT_MODEL_PATH - elif config['path_to_model'] != DEFAULT_MODEL_PATH and not os.path.exists(config['path_to_model']): + if not config.get("path_to_model"): + config["path_to_model"] = DEFAULT_MODEL_PATH + elif config["path_to_model"] != DEFAULT_MODEL_PATH and not os.path.exists( + config["path_to_model"] + ): raise ValueError(f"Model path '{config['path_to_model']}' does not exist.") - if not config.get('data_output_directory'): + if not config.get("data_output_directory"): raise ValueError("Data output directory is required in the configuration.") - elif not os.path.exists(config['data_output_directory']): - os.makedirs(config['data_output_directory']) + elif not os.path.exists(config["data_output_directory"]): + os.makedirs(config["data_output_directory"]) print("Beginning data processing...") process_data( - jsonl=str(config['path_to_jsonl']), - model_path=str(config['path_to_model']), - output_dir=str(config['data_output_directory']), - data_name=str(config.get('data_name', 'processed_data')), - use_normalize=bool(config.get('use_normalize', True)) + jsonl=str(config["path_to_jsonl"]), + model_path=str(config["path_to_model"]), + output_dir=str(config["data_output_directory"]), + data_name=str(config.get("data_name", "processed_data")), + use_normalize=bool(config.get("use_normalize", True)), ) print("Data processing completed.") else: print("Skipping data preprocess step.") # Validate model path for fine-tuning when skipping data preprocess - if not config.get('path_to_model'): - config['path_to_model'] = DEFAULT_MODEL_PATH - elif config['path_to_model'] != DEFAULT_MODEL_PATH and not os.path.exists(config['path_to_model']): + if not config.get("path_to_model"): + config["path_to_model"] = DEFAULT_MODEL_PATH + elif config["path_to_model"] != DEFAULT_MODEL_PATH and not os.path.exists( + config["path_to_model"] + ): raise ValueError(f"Model path '{config['path_to_model']}' does not exist.") - if not config.get('finetuned_model_output'): + if not config.get("finetuned_model_output"): raise ValueError("Finetune output directory is required in the configuration.") - elif not os.path.exists(config['finetuned_model_output']): - os.makedirs(config['finetuned_model_output']) + elif not os.path.exists(config["finetuned_model_output"]): + os.makedirs(config["finetuned_model_output"]) training_config = {} - training_config_file = config.get('training_config_file') + training_config_file = config.get("training_config_file") if training_config_file and os.path.exists(training_config_file): - with open(training_config_file, 'r') as f: + with open(training_config_file, "r") as f: training_config = yaml.safe_load(f) else: - print("Training config file not found or not specified, using default training configuration.") - + print( + "Training config file not found or not specified, using default training configuration." 
+ ) + # Load LoRA configuration if using LoRA lora_cfg = {} - use_lora = bool(config.get('use_lora', False)) + use_lora = bool(config.get("use_lora", False)) if use_lora: - lora_config_file = config.get('lora_config_file', 'finetune/lora_config.yaml') + lora_config_file = config.get("lora_config_file", "finetune/lora_config.yaml") if lora_config_file and os.path.exists(lora_config_file): - with open(lora_config_file, 'r') as f: + with open(lora_config_file, "r") as f: lora_cfg = yaml.safe_load(f) print(f"Loaded LoRA configuration from {lora_config_file}") else: - print("LoRA config file not found or not specified, using default LoRA configuration.") - + print( + "LoRA config file not found or not specified, using default LoRA configuration." + ) + print("Beginning finetuning...") finetune.train( - model_path=str(config['path_to_model']), - data_dir=str(config['data_output_directory']), - output_dir=str(config['finetuned_model_output']), + model_path=str(config["path_to_model"]), + data_dir=str(config["data_output_directory"]), + output_dir=str(config["finetuned_model_output"]), training_config=training_config, - device=str(config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')), + device=str( + config.get("device", "cuda" if torch.cuda.is_available() else "cpu") + ), use_lora=use_lora, - lora_cfg=lora_cfg - ) \ No newline at end of file + lora_cfg=lora_cfg, + ) diff --git a/finetune/lora_config.yaml b/finetune/lora_config.yaml index 3da4b60..e71ecb7 100644 --- a/finetune/lora_config.yaml +++ b/finetune/lora_config.yaml @@ -10,4 +10,4 @@ target_modules: - "down_proj" lora_dropout: 0.05 bias: "none" -use_rslora: true \ No newline at end of file +use_rslora: true diff --git a/finetune/requirements_finetune.txt b/finetune/requirements_finetune.txt index 55f0450..82d56d0 100644 --- a/finetune/requirements_finetune.txt +++ b/finetune/requirements_finetune.txt @@ -18,4 +18,4 @@ tensorboard torch-tb-profiler peft liger_kernel -# flash-attn # Need to install separately \ No newline at end of file +# flash-attn # Need to install separately diff --git a/finetune/training_config.yaml b/finetune/training_config.yaml index 8549681..f33656d 100644 --- a/finetune/training_config.yaml +++ b/finetune/training_config.yaml @@ -9,4 +9,4 @@ save_total_limit: 100 dataloader_num_workers: 1 warmup_ratio: 0.1 lr_scheduler_type: "cosine" -gradient_checkpointing: true \ No newline at end of file +gradient_checkpointing: true diff --git a/generation_utils.py b/generation_utils.py index 2746a70..2638211 100644 --- a/generation_utils.py +++ b/generation_utils.py @@ -8,7 +8,9 @@ MAX_CHANNELS = 8 -def pad_or_truncate_to_seconds(wav: torch.Tensor, target_seconds: float, sr: int) -> torch.Tensor: +def pad_or_truncate_to_seconds( + wav: torch.Tensor, target_seconds: float, sr: int +) -> torch.Tensor: """Pad or truncate a mono waveform to target length in seconds. Args: @@ -31,12 +33,15 @@ def pad_or_truncate_to_seconds(wav: torch.Tensor, target_seconds: float, sr: int else: pad_len = target_len - cur_len out = torch.cat( - [wav_1d, torch.zeros(pad_len, dtype=wav_1d.dtype, device=wav_1d.device)], dim=-1 + [wav_1d, torch.zeros(pad_len, dtype=wav_1d.dtype, device=wav_1d.device)], + dim=-1, ) return out.unsqueeze(0) -def crossfade_concat(segments: list, sample_rate: int, crossfade_seconds: float = 0.1) -> torch.Tensor: +def crossfade_concat( + segments: list, sample_rate: int, crossfade_seconds: float = 0.1 +) -> torch.Tensor: """Concatenate segments with linear crossfade. 
Args: @@ -61,10 +66,16 @@ def crossfade_concat(segments: list, sample_rate: int, crossfade_seconds: float if cf_len <= 0: out = torch.cat([out, nxt], dim=-1) continue - fade_out = torch.linspace(1.0, 0.0, steps=cf_len, dtype=out.dtype, device=out.device) - fade_in = torch.linspace(0.0, 1.0, steps=cf_len, dtype=nxt.dtype, device=nxt.device) + fade_out = torch.linspace( + 1.0, 0.0, steps=cf_len, dtype=out.dtype, device=out.device + ) + fade_in = torch.linspace( + 0.0, 1.0, steps=cf_len, dtype=nxt.dtype, device=nxt.device + ) overlap = out[0, -cf_len:] * fade_out + nxt[0, :cf_len] * fade_in - out = torch.cat([out[:, :-cf_len], overlap.unsqueeze(0), nxt[:, cf_len:]], dim=-1) + out = torch.cat( + [out[:, :-cf_len], overlap.unsqueeze(0), nxt[:, cf_len:]], dim=-1 + ) return out @@ -133,8 +144,8 @@ def _resolve_path(p: str) -> str: # Try Format 1: separate speaker references s1 = item.get("prompt_audio_speaker1", "") s2 = item.get("prompt_audio_speaker2", "") - has_s1 = ((isinstance(s1, str) and s1) or isinstance(s1, tuple)) - has_s2 = ((isinstance(s2, str) and s2) or isinstance(s2, tuple)) + has_s1 = (isinstance(s1, str) and s1) or isinstance(s1, tuple) + has_s2 = (isinstance(s2, str) and s2) or isinstance(s2, tuple) if has_s1 and has_s2: if isinstance(s1, str) and s1: @@ -159,7 +170,9 @@ def _resolve_path(p: str) -> str: return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} # Otherwise, no supported prompt found → reject (text-only unsupported) - raise ValueError("Input must include prompt (Format 1 or 2). Text-only is not supported.") + raise ValueError( + "Input must include prompt (Format 1 or 2). Text-only is not supported." + ) def load_audio_data(prompt_audio, target_sample_rate=16000): @@ -358,13 +371,13 @@ def normalize_text(text: str) -> str: 8. Merge adjacent identical speaker tags. """ # Replace [1], [2] etc. format with [S1], [S2] etc. 
format - text = re.sub(r'\[(\d+)\]', r'[S\1]', text) + text = re.sub(r"\[(\d+)\]", r"[S\1]", text) # Remove decorative characters remove_chars = "【】《》()『』「」" '"-_“”~~' # Use positive lookahead to split text by speaker tags (tags themselves are still preserved) - segments = re.split(r'(?=\[S\d+\])', text.replace("\n", " ")) + segments = re.split(r"(?=\[S\d+\])", text.replace("\n", " ")) processed_parts = [] for seg in segments: @@ -373,58 +386,60 @@ def normalize_text(text: str) -> str: continue # Extract tags - m = re.match(r'^(\[S\d+\])\s*(.*)', seg) - tag, content = m.groups() if m else ('', seg) + m = re.match(r"^(\[S\d+\])\s*(.*)", seg) + tag, content = m.groups() if m else ("", seg) # Remove irrelevant symbols content = re.sub(f"[{re.escape(remove_chars)}]", "", content) # Handle consecutive "哈" characters: replace 2 or more with "(笑)" - content = re.sub(r'哈{2,}', '[笑]', content) + content = re.sub(r"哈{2,}", "[笑]", content) # Handle English laughter (e.g., "haha", "ha ha") - content = re.sub(r'\b(ha(\s*ha)+)\b', '[laugh]', content, flags=re.IGNORECASE) + content = re.sub(r"\b(ha(\s*ha)+)\b", "[laugh]", content, flags=re.IGNORECASE) # First handle multi-character punctuation marks - content = content.replace('——', ',') - content = content.replace('……', ',') + content = content.replace("——", ",") + content = content.replace("……", ",") # Handle single-character internal punctuation marks - internal_punct_map = str.maketrans({ - ';': ',', ';': ',', - ':': ',', ':': ',', - '、': ',' - }) + internal_punct_map = str.maketrans( + {";": ",", ";": ",", ":": ",", ":": ",", "、": ","} + ) content = content.translate(internal_punct_map) content = content.strip() # Keep only the final period if len(content) > 1: - last_ch = "。" if content[-1] == "," else ("." if content[-1] == "," else content[-1]) - body = content[:-1].replace('。', ',') + last_ch = ( + "。" + if content[-1] == "," + else ("." 
if content[-1] == "," else content[-1]) + ) + body = content[:-1].replace("。", ",") content = body + last_ch - processed_parts.append({'tag': tag, 'content': content}) + processed_parts.append({"tag": tag, "content": content}) if not processed_parts: return "" # Merge consecutive same speakers merged_lines = [] - current_tag = processed_parts[0]['tag'] - current_content = [processed_parts[0]['content']] + current_tag = processed_parts[0]["tag"] + current_content = [processed_parts[0]["content"]] for part in processed_parts[1:]: - if part['tag'] == current_tag and current_tag: - current_content.append(part['content']) + if part["tag"] == current_tag and current_tag: + current_content.append(part["content"]) else: merged_lines.append(f"{current_tag}{''.join(current_content)}".strip()) - current_tag = part['tag'] - current_content = [part['content']] + current_tag = part["tag"] + current_content = [part["content"]] merged_lines.append(f"{current_tag}{''.join(current_content)}".strip()) - - return "".join(merged_lines).replace('‘', "'").replace('’', "'") + + return "".join(merged_lines).replace("‘", "'").replace("’", "'") def process_batch( @@ -578,7 +593,9 @@ def process_batch( or getattr(spt, "sampling_rate", None) or 24000 ) - ref_wav = load_audio_data(prompt_audio, target_sample_rate=ref_sr_in) + ref_wav = load_audio_data( + prompt_audio, target_sample_rate=ref_sr_in + ) if ref_wav is None: # If ref missing, use original decode with torch.no_grad(): @@ -594,13 +611,19 @@ def process_batch( "index": start_idx + i, } ) - print(f"Audio generation completed (orig no-ref): sample {start_idx + i}") + print( + f"Audio generation completed (orig no-ref): sample {start_idx + i}" + ) else: # Encode 20s reference to tokens - ref_wav_20s = pad_or_truncate_to_seconds(ref_wav, 20.0, ref_sr_in).to(device) + ref_wav_20s = pad_or_truncate_to_seconds( + ref_wav, 20.0, ref_sr_in + ).to(device) with torch.no_grad(): enc = spt.encode([ref_wav_20s.squeeze(0)]) - ref_codes = enc["codes_list"][0].to(device).long() # (nq, T_ref) + ref_codes = ( + enc["codes_list"][0].to(device).long() + ) # (nq, T_ref) # Prepare token-to-sample mapping and windowing params out_sr = ( @@ -608,7 +631,9 @@ def process_batch( or getattr(spt, "sample_rate", None) or 24000 ) - tokens_per_second = float(ref_sr_in) / float(spt.encoder_downsample_rate) + tokens_per_second = float(ref_sr_in) / float( + spt.encoder_downsample_rate + ) tokens_per_chunk = int(round(10.0 * tokens_per_second)) stride_tokens = 85 keep_tokens = 85 @@ -632,14 +657,22 @@ def process_batch( # Concatenate reference tokens with current window tokens combined_codes = torch.cat( [ref_codes, gen_chunk.permute(1, 0).long()], dim=1 - ).to(device) # (nq, T_ref + T_chunk) + ).to( + device + ) # (nq, T_ref + T_chunk) codes_lengths = torch.tensor( - [combined_codes.shape[-1]], dtype=torch.long, device=device + [combined_codes.shape[-1]], + dtype=torch.long, + device=device, ) - combined_codes_batched = combined_codes.unsqueeze(1) # (nq, 1, T) + combined_codes_batched = combined_codes.unsqueeze( + 1 + ) # (nq, 1, T) with torch.no_grad(): - detok = spt.inference_detokenize(combined_codes_batched, codes_lengths) + detok = spt.inference_detokenize( + combined_codes_batched, codes_lengths + ) y = detok["y"][0, 0] # (T_samples) # Remove 20s reference portion (in samples) @@ -657,22 +690,35 @@ def process_batch( if is_first: keep_start_tok = 0 - keep_end_tok = min(keep_tokens + left_ctx_tokens, window_len) + keep_end_tok = min( + keep_tokens + left_ctx_tokens, window_len + ) elif 
is_last and remains < 105: - keep_start_tok = 0 if is_first else min(left_ctx_tokens, window_len) + keep_start_tok = ( + 0 if is_first else min(left_ctx_tokens, window_len) + ) keep_end_tok = window_len else: keep_start_tok = min(left_ctx_tokens, window_len) - keep_end_tok = min(left_ctx_tokens + keep_tokens, window_len) + keep_end_tok = min( + left_ctx_tokens + keep_tokens, window_len + ) keep_start_smps = keep_start_tok * samples_per_token keep_end_smps = keep_end_tok * samples_per_token left_margin = 0 right_margin = crossfade_samples if not is_last else 0 seg_start = max(0, keep_start_smps - left_margin) - seg_end = min(chunk_y.shape[-1], keep_end_smps + right_margin) + seg_end = min( + chunk_y.shape[-1], keep_end_smps + right_margin + ) if seg_end > seg_start: - kept_segments.append(chunk_y[seg_start:seg_end].detach().cpu().unsqueeze(0)) + kept_segments.append( + chunk_y[seg_start:seg_end] + .detach() + .cpu() + .unsqueeze(0) + ) chunk_idx += 1 @@ -680,7 +726,11 @@ def process_batch( if len(kept_segments) == 0: audio_out = torch.zeros(1, int(0.01 * out_sr)) else: - audio_out = crossfade_concat(kept_segments, out_sr, crossfade_seconds=crossfade_seconds) + audio_out = crossfade_concat( + kept_segments, + out_sr, + crossfade_seconds=crossfade_seconds, + ) audio_results.append( { @@ -689,7 +739,9 @@ def process_batch( "index": start_idx + i, } ) - print(f"Audio generation completed (prompt-aug): sample {start_idx + i}") + print( + f"Audio generation completed (prompt-aug): sample {start_idx + i}" + ) except Exception as e: print(f"Error processing sample {start_idx + i}: {str(e)}, skipping...") diff --git a/gradio_demo.py b/gradio_demo.py index 8f924b2..5614e63 100644 --- a/gradio_demo.py +++ b/gradio_demo.py @@ -1,18 +1,23 @@ -import gradio as gr -import torch -import torchaudio -import tempfile import json import os +import tempfile from typing import Optional, Tuple +import gradio as gr +import torch +import torchaudio + from generation_utils import load_model, process_batch + def load_examples_from_jsonl(): """ Load examples from examples/examples.jsonl and convert to format for both ROLE and SINGLE modes """ - jsonl_paths = ["examples/examples.jsonl", "examples/examples_single_reference.jsonl"] + jsonl_paths = [ + "examples/examples.jsonl", + "examples/examples_single_reference.jsonl", + ] role_examples = [] single_examples = [] @@ -23,7 +28,7 @@ def load_examples_from_jsonl(): if not os.path.exists(jsonl_path): print(f"Warning: {jsonl_path} not found") continue - with open(jsonl_path, 'r', encoding='utf-8') as f: + with open(jsonl_path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: @@ -33,40 +38,51 @@ def load_examples_from_jsonl(): for line in lines: line = line.strip() if not line: - continue + continue data = json.loads(line) - + # Extract required fields - text = data.get('text', '') - base_path = data.get('base_path', 'examples') - use_normalize = data.get('use_normalize', True) - + text = data.get("text", "") + base_path = data.get("base_path", "examples") + use_normalize = data.get("use_normalize", True) + # Check if this is a role-based example (has speaker1 and speaker2 audio) - if 'prompt_audio_speaker1' in data and 'prompt_audio_speaker2' in data: + if "prompt_audio_speaker1" in data and "prompt_audio_speaker2" in data: # Role mode example audio_mode = "Role" - prompt_audio_1 = os.path.join(base_path, data['prompt_audio_speaker1']) - prompt_text_1 = data.get('prompt_text_speaker1', '') - prompt_audio_2 = os.path.join(base_path, 
data['prompt_audio_speaker2']) - prompt_text_2 = data.get('prompt_text_speaker2', '') - - example = [text, audio_mode, prompt_audio_1, prompt_text_1, prompt_audio_2, prompt_text_2, use_normalize] + prompt_audio_1 = os.path.join(base_path, data["prompt_audio_speaker1"]) + prompt_text_1 = data.get("prompt_text_speaker1", "") + prompt_audio_2 = os.path.join(base_path, data["prompt_audio_speaker2"]) + prompt_text_2 = data.get("prompt_text_speaker2", "") + + example = [ + text, + audio_mode, + prompt_audio_1, + prompt_text_1, + prompt_audio_2, + prompt_text_2, + use_normalize, + ] role_examples.append(example) - + # Check if this is a single audio example (has prompt_audio and prompt_text) - elif 'prompt_audio' in data and 'prompt_text' in data: + elif "prompt_audio" in data and "prompt_text" in data: # Single mode example audio_mode = "Single" - prompt_audio = os.path.join(base_path, data['prompt_audio']) - prompt_text = data.get('prompt_text', '') - + prompt_audio = os.path.join(base_path, data["prompt_audio"]) + prompt_text = data.get("prompt_text", "") + example = [text, audio_mode, prompt_audio, prompt_text, use_normalize] single_examples.append(example) - - print(f"Loaded {len(role_examples)} role examples and {len(single_examples)} single examples from {jsonl_paths}") + + print( + f"Loaded {len(role_examples)} role examples and {len(single_examples)} single examples from {jsonl_paths}" + ) return role_examples, single_examples + # Load examples from JSONL file ROLE_EXAMPLES, SINGLE_EXAMPLES = load_examples_from_jsonl() @@ -101,8 +117,22 @@ def load_examples_from_jsonl(): "examples_desc": "Click on examples below to auto-fill the form", "role_examples": "Role Mode Examples", "single_examples": "Single Audio Mode Examples", - "role_headers": ["Text to Synthesize", "Input Mode", "Role1 Audio File", "Role1 Text", "Role2 Audio File", "Role2 Text", "Use Normalize"], - "single_headers": ["Text to Synthesize", "Input Mode", "Audio File", "Prompt Text", "Use Normalize"] + "role_headers": [ + "Text to Synthesize", + "Input Mode", + "Role1 Audio File", + "Role1 Text", + "Role2 Audio File", + "Role2 Text", + "Use Normalize", + ], + "single_headers": [ + "Text to Synthesize", + "Input Mode", + "Audio File", + "Prompt Text", + "Use Normalize", + ], }, "中文": { "title": "MOSS-TTSD🪐 对话语音生成", @@ -133,9 +163,23 @@ def load_examples_from_jsonl(): "examples_desc": "点击下方示例自动填充表单", "role_examples": "角色模式示例", "single_examples": "单音频模式示例", - "role_headers": ["要合成的文本", "输入模式", "角色1音频文件", "角色1文本", "角色2音频文件", "角色2文本", "使用规范化"], - "single_headers": ["要合成的文本", "输入模式", "音频文件", "提示文本", "使用规范化"] - } + "role_headers": [ + "要合成的文本", + "输入模式", + "角色1音频文件", + "角色1文本", + "角色2音频文件", + "角色2文本", + "使用规范化", + ], + "single_headers": [ + "要合成的文本", + "输入模式", + "音频文件", + "提示文本", + "使用规范化", + ], + }, } # Model configuration @@ -151,20 +195,24 @@ def load_examples_from_jsonl(): spt = None device = None + def initialize_model(): """Initialize model (load only on first call)""" global tokenizer, model, spt, device - + if tokenizer is None: print("Initializing model...") device = "cuda" if torch.cuda.is_available() else "cpu" - tokenizer, model, spt = load_model(MODEL_PATH, SPT_CONFIG_PATH, SPT_CHECKPOINT_PATH) + tokenizer, model, spt = load_model( + MODEL_PATH, SPT_CONFIG_PATH, SPT_CHECKPOINT_PATH + ) spt = spt.to(device) model = model.to(device) print("Model initialization completed!") - + return tokenizer, model, spt, device + def process_single_audio_generation( text_input: str, audio_mode: str, @@ -174,11 +222,11 @@ def 
process_single_audio_generation( prompt_audio_1: Optional[str] = None, prompt_text_2: str = "", prompt_audio_2: Optional[str] = None, - use_normalize: bool = True + use_normalize: bool = True, ) -> Tuple[Optional[str], str]: """ Process single audio generation request - + Args: text_input: Text to synthesize prompt_text_single: Prompt text for single audio @@ -188,19 +236,19 @@ def process_single_audio_generation( prompt_text_2: Role2 text prompt_audio_2: Role2 audio file path use_normalize: Whether to use text normalization - + Returns: Generated audio file path and status information """ try: # Initialize model tokenizer, model, spt, device = initialize_model() - + # Build input item item = { "text": text_input, } - + # Handle different audio input modes (mutually exclusive) if audio_mode == "Single": # Use single audio mode @@ -223,12 +271,15 @@ def process_single_audio_generation( item["prompt_audio"] = prompt_audio_2 item["prompt_text"] = prompt_text_2 if prompt_text_2 else "" else: - return None, "Error: Please select a mode and provide corresponding audio files\n- Single Audio Mode: Provide one audio file and corresponding text\n- Role Mode: Provide audio files for Role1 and Role2" - + return ( + None, + "Error: Please select a mode and provide corresponding audio files\n- Single Audio Mode: Provide one audio file and corresponding text\n- Role Mode: Provide audio files for Role1 and Role2", + ) + # Set random seed to ensure reproducible results # import accelerate # accelerate.utils.set_seed(42) - + # Process batch (single item) actual_texts_data, audio_results = process_batch( batch_items=[item], @@ -238,21 +289,23 @@ def process_single_audio_generation( device=device, system_prompt=SYSTEM_PROMPT, start_idx=0, - use_normalize=use_normalize + use_normalize=use_normalize, ) - + # Check results if not audio_results or audio_results[0] is None: return None, "Error: Audio generation failed" - + audio_result = audio_results[0] - + # Create temporary output file output_path = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name - + # Save audio - torchaudio.save(output_path, audio_result["audio_data"], audio_result["sample_rate"]) - + torchaudio.save( + output_path, audio_result["audio_data"], audio_result["sample_rate"] + ) + # Build status information (using English since this is server-side output) status_info = f""" ✅ Generation successful! @@ -266,60 +319,64 @@ def process_single_audio_generation( - Final Text: {actual_texts_data[0]['final_text'][:100]}... 
- Use Normalize: {actual_texts_data[0]['use_normalize']} """ - + return output_path, status_info - + except Exception as e: import traceback + error_msg = f"Error: Audio generation failed: {str(e)}\n\nDetails:\n{traceback.format_exc()}" return None, error_msg + # Create Gradio interface def create_gradio_interface() -> gr.Blocks: - with gr.Blocks(title="MOSS-TTSD🪐 Dialogue Generation", theme=gr.themes.Soft()) as demo: - + with gr.Blocks( + title="MOSS-TTSD🪐 Dialogue Generation", theme=gr.themes.Soft() + ) as demo: + # Language selection at the top with gr.Row(): language_selector = gr.Radio( choices=["English", "中文"], value="English", label="Language / 语言", - info="Select interface language / 选择界面语言" + info="Select interface language / 选择界面语言", ) - + # Title and header (will be updated based on language) title_md = gr.Markdown("# MOSS-TTSD🪐 Dialogue Generation") github_md = gr.Markdown("### [Github](https://github.com/OpenMOSS/MOSS-TTSD)") - + with gr.Row(): # Left input area with gr.Column(scale=1): script_input_md = gr.Markdown("### Script Input") - + text_input = gr.Textbox( label="Text to Synthesize", placeholder="Text to be synthesized, format: [S1]Role1 text[S2]Role2 text", lines=6, ) - + use_normalize_single = gr.Checkbox( label="Use text normalization", value=True, - info="Recommended to enable, improves handling of numbers, punctuation, etc." + info="Recommended to enable, improves handling of numbers, punctuation, etc.", ) - + # Right audio input area with gr.Column(scale=1): audio_input_mode_md = gr.Markdown("### Audio Input Mode") - + # Audio input mode selection audio_mode = gr.Radio( choices=["Single", "Role"], value="Single", label="Select input mode", - info="Single Audio: Upload one audio with [S1][S2] text; Role Audio: Upload separate audio for Role1 and Role2" + info="Single Audio: Upload one audio with [S1][S2] text; Role Audio: Upload separate audio for Role1 and Role2", ) - + # Single audio mode with gr.Group(visible=True) as single_mode_group: single_warning_md = gr.Markdown( @@ -328,14 +385,14 @@ def create_gradio_interface() -> gr.Blocks: prompt_audio_single = gr.File( label="Drag and drop audio here - or - click to upload", file_types=["audio"], - type="filepath" + type="filepath", ) prompt_text_single = gr.Textbox( label="Prompt Text", placeholder="Format: [S1]Role1 text[S2]Role2 text", lines=3, ) - + # Role audio mode with gr.Group(visible=False) as role_mode_group: with gr.Row(): @@ -344,73 +401,87 @@ def create_gradio_interface() -> gr.Blocks: prompt_audio_1 = gr.File( label="Role1 Audio File", file_types=["audio"], - type="filepath" + type="filepath", ) prompt_text_1 = gr.Textbox( label="Role1 Text", placeholder="Role1 text content", - lines=2 + lines=2, ) - + with gr.Column(): role2_audio_md = gr.Markdown("**Role2 Audio**") prompt_audio_2 = gr.File( label="Role2 Audio File", file_types=["audio"], - type="filepath" + type="filepath", ) prompt_text_2 = gr.Textbox( label="Role2 Text", placeholder="Role2 text content", - lines=2 + lines=2, ) - + # Generate button with gr.Row(): generate_btn = gr.Button("Generate Audio", variant="primary", size="lg") - + # Output area with gr.Row(): with gr.Column(): output_audio = gr.Audio(label="Generated Audio", type="filepath") status_info = gr.Textbox( - label="Status Information", - lines=10, - interactive=False + label="Status Information", lines=10, interactive=False ) - + # Examples area with gr.Row(): with gr.Column(): examples_md = gr.Markdown("### Examples") - examples_desc_md = gr.Markdown("Click on examples below to 
auto-fill the form") + examples_desc_md = gr.Markdown( + "Click on examples below to auto-fill the form" + ) # Role mode examples with gr.Group(): role_examples_md = gr.Markdown("**Role Mode Examples**") role_examples = gr.Examples( examples=ROLE_EXAMPLES, - inputs=[text_input, audio_mode, prompt_audio_1, prompt_text_1, prompt_audio_2, prompt_text_2, use_normalize_single], + inputs=[ + text_input, + audio_mode, + prompt_audio_1, + prompt_text_1, + prompt_audio_2, + prompt_text_2, + use_normalize_single, + ], ) - + # Single audio mode examples with gr.Group(): single_examples_md = gr.Markdown("**Single Audio Mode Examples**") single_examples = gr.Examples( examples=SINGLE_EXAMPLES, - inputs=[text_input, audio_mode, prompt_audio_single, prompt_text_single, use_normalize_single], + inputs=[ + text_input, + audio_mode, + prompt_audio_single, + prompt_text_single, + use_normalize_single, + ], ) - + # Event handlers - + # Language change event def update_language(lang): """Update interface language""" texts = LANGUAGES[lang] - + # Update demo title demo.title = texts["title"] - + return ( gr.Markdown(f"# {texts['title']}"), # title_md texts["script_input"], # script_input_md @@ -422,20 +493,20 @@ def update_language(lang): gr.Checkbox( label=texts["use_normalize"], value=True, - info=texts["normalize_info"] + info=texts["normalize_info"], ), # use_normalize_single texts["audio_input_mode"], # audio_input_mode_md gr.Radio( choices=["Single", "Role"], value="Single", label=texts["select_input_mode"], - info=texts["mode_info"] + info=texts["mode_info"], ), # audio_mode gr.Markdown(texts["single_warning"]), # single_warning_md gr.File( label=texts["drag_drop_audio"], file_types=["audio"], - type="filepath" + type="filepath", ), # prompt_audio_single gr.Textbox( label=texts["prompt_text"], @@ -446,30 +517,32 @@ def update_language(lang): gr.File( label=texts["role1_audio_file"], file_types=["audio"], - type="filepath" + type="filepath", ), # prompt_audio_1 gr.Textbox( label=texts["role1_text"], placeholder=texts["role1_placeholder"], - lines=2 + lines=2, ), # prompt_text_1 texts["role2_audio"], # role2_audio_md gr.File( label=texts["role2_audio_file"], file_types=["audio"], - type="filepath" + type="filepath", ), # prompt_audio_2 gr.Textbox( label=texts["role2_text"], placeholder=texts["role2_placeholder"], - lines=2 + lines=2, ), # prompt_text_2 - gr.Button(texts["generate_audio"], variant="primary", size="lg"), # generate_btn - gr.Audio(label=texts["generated_audio"], type="filepath"), # output_audio + gr.Button( + texts["generate_audio"], variant="primary", size="lg" + ), # generate_btn + gr.Audio( + label=texts["generated_audio"], type="filepath" + ), # output_audio gr.Textbox( - label=texts["status_info"], - lines=10, - interactive=False + label=texts["status_info"], lines=10, interactive=False ), # status_info texts["examples"], # examples_md texts["examples_desc"], # examples_desc_md @@ -478,35 +551,51 @@ def update_language(lang): gr.Dataset(headers=texts["role_headers"]), gr.Dataset(headers=texts["single_headers"]), ) - + language_selector.change( fn=update_language, inputs=[language_selector], outputs=[ - title_md, script_input_md, text_input, use_normalize_single, - audio_input_mode_md, audio_mode, single_warning_md, - prompt_audio_single, prompt_text_single, - role1_audio_md, prompt_audio_1, prompt_text_1, - role2_audio_md, prompt_audio_2, prompt_text_2, - generate_btn, output_audio, status_info, - examples_md, examples_desc_md, role_examples_md, single_examples_md, - 
role_examples.dataset, single_examples.dataset - ] + title_md, + script_input_md, + text_input, + use_normalize_single, + audio_input_mode_md, + audio_mode, + single_warning_md, + prompt_audio_single, + prompt_text_single, + role1_audio_md, + prompt_audio_1, + prompt_text_1, + role2_audio_md, + prompt_audio_2, + prompt_text_2, + generate_btn, + output_audio, + status_info, + examples_md, + examples_desc_md, + role_examples_md, + single_examples_md, + role_examples.dataset, + single_examples.dataset, + ], ) - + # Audio mode toggle event def toggle_audio_mode(mode): if mode == "Single": return gr.Group(visible=True), gr.Group(visible=False) else: return gr.Group(visible=False), gr.Group(visible=True) - + audio_mode.change( fn=toggle_audio_mode, inputs=[audio_mode], - outputs=[single_mode_group, role_mode_group] + outputs=[single_mode_group, role_mode_group], ) - + # Audio generation event generate_btn.click( fn=process_single_audio_generation, @@ -519,17 +608,18 @@ def toggle_audio_mode(mode): prompt_audio_1, prompt_text_2, prompt_audio_2, - use_normalize_single + use_normalize_single, ], outputs=[output_audio, status_info], - show_progress=True + show_progress=True, ) - + return demo + # Main function if __name__ == "__main__": demo = create_gradio_interface() - + # Launch interface demo.launch() diff --git a/inference.py b/inference.py index df4b8de..94f4016 100644 --- a/inference.py +++ b/inference.py @@ -1,61 +1,98 @@ +import argparse import json +import os + +import accelerate import torch import torchaudio -import accelerate -import argparse -import os from generation_utils import load_model, process_batch MODEL_PATH = "fnlp/MOSS-TTSD-v0.7" SYSTEM_PROMPT = "You are a speech synthesizer that generates natural, realistic, and human-like conversational audio from dialogue text." 
SPT_CONFIG_PATH = "XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml" -SPT_CHECKPOINT_PATH ="XY_Tokenizer/weights/MOSS_TTSD_tokenizer" +SPT_CHECKPOINT_PATH = "XY_Tokenizer/weights/MOSS_TTSD_tokenizer" MAX_CHANNELS = 8 + def main(): parser = argparse.ArgumentParser(description="TTS inference with Asteroid model") - parser.add_argument("--jsonl", default="examples/examples.jsonl",help="Path to JSONL file (default: examples/examples.jsonl)") - parser.add_argument("--seed", type=int, default=None, - help="Random seed for reproducibility (default: None)") - parser.add_argument("--output_dir", default="outputs", - help="Output directory for generated audio files (default: outputs)") - parser.add_argument("--summary_file", default=None, - help="Path to save summary jsonl file (default: None)") - parser.add_argument("--use_normalize", action="store_true", default=False, - help="Whether to use text normalization (default: False)") - parser.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16", - help="Model data type (default: bf16)") - parser.add_argument("--attn_implementation", choices=["flash_attention_2", "sdpa", "eager"], default="flash_attention_2", - help="Attention implementation (default: flash_attention_2)") - parser.add_argument("--silence_duration", type=float, default=0, - help="Silence duration between speech prompt and generated speech, which can be used to avoid noise problem at the beginning of generated audio") - + parser.add_argument( + "--jsonl", + default="examples/examples.jsonl", + help="Path to JSONL file (default: examples/examples.jsonl)", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Random seed for reproducibility (default: None)", + ) + parser.add_argument( + "--output_dir", + default="outputs", + help="Output directory for generated audio files (default: outputs)", + ) + parser.add_argument( + "--summary_file", + default=None, + help="Path to save summary jsonl file (default: None)", + ) + parser.add_argument( + "--use_normalize", + action="store_true", + default=False, + help="Whether to use text normalization (default: False)", + ) + parser.add_argument( + "--dtype", + choices=["bf16", "fp16", "fp32"], + default="bf16", + help="Model data type (default: bf16)", + ) + parser.add_argument( + "--attn_implementation", + choices=["flash_attention_2", "sdpa", "eager"], + default="flash_attention_2", + help="Attention implementation (default: flash_attention_2)", + ) + parser.add_argument( + "--silence_duration", + type=float, + default=0, + help="Silence duration between speech prompt and generated speech, which can be used to avoid noise problem at the beginning of generated audio", + ) + args = parser.parse_args() - + # Convert dtype string to torch dtype dtype_mapping = { "bf16": torch.bfloat16, "fp16": torch.float16, - "fp32": torch.float32 + "fp32": torch.float32, } torch_dtype = dtype_mapping[args.dtype] - + # Create output directory if it doesn't exist os.makedirs(args.output_dir, exist_ok=True) - + device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") print(f"Using dtype: {args.dtype} ({torch_dtype})") print(f"Using attention implementation: {args.attn_implementation}") - + # Load models print("Loading models...") - tokenizer, model, spt = load_model(MODEL_PATH, SPT_CONFIG_PATH, SPT_CHECKPOINT_PATH, - torch_dtype=torch_dtype, attn_implementation=args.attn_implementation) + tokenizer, model, spt = load_model( + MODEL_PATH, + SPT_CONFIG_PATH, + SPT_CHECKPOINT_PATH, + 
torch_dtype=torch_dtype, + attn_implementation=args.attn_implementation, + ) spt = spt.to(device) model = model.to(device) - + # Load the items from the JSONL file try: with open(args.jsonl, "r") as f: @@ -67,12 +104,12 @@ def main(): except json.JSONDecodeError as e: print(f"Error parsing JSONL file: {e}") return - + # Fix the seed for reproducibility if args.seed is not None: accelerate.utils.set_seed(args.seed) print(f"Set random seed to {args.seed}") - + # Process the batch of items print("Starting inference...") actual_texts_data, audio_results = process_batch( @@ -84,40 +121,43 @@ def main(): system_prompt=SYSTEM_PROMPT, start_idx=0, use_normalize=args.use_normalize, - silence_duration=args.silence_duration + silence_duration=args.silence_duration, ) - + # Save summary if requested if args.summary_file: summary_data = [] for item in actual_texts_data: - summary_data.append({ - "text": item["original_text"], - "normalized_text": item["normalized_text"], - "final_text": item["final_text"] - }) - + summary_data.append( + { + "text": item["original_text"], + "normalized_text": item["normalized_text"], + "final_text": item["final_text"], + } + ) + with open(args.summary_file, "w", encoding="utf-8") as f: for item in summary_data: f.write(json.dumps(item, ensure_ascii=False) + "\n") print(f"Saved summary to {args.summary_file}") - + # Save the audio results to files saved_count = 0 for idx, audio_result in enumerate(audio_results): if audio_result is not None: output_path = os.path.join(args.output_dir, f"output_{idx}.wav") torchaudio.save( - output_path, - audio_result["audio_data"], - audio_result["sample_rate"] + output_path, audio_result["audio_data"], audio_result["sample_rate"] ) print(f"Saved audio to {output_path}") saved_count += 1 else: print(f"Skipping sample {idx} due to generation error") - - print(f"Inference completed. Saved {saved_count}/{len(items)} audio files to {args.output_dir}") + + print( + f"Inference completed. 
Saved {saved_count}/{len(items)} audio files to {args.output_dir}" + ) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/modeling_asteroid.py b/modeling_asteroid.py index 6946ba9..e7325f6 100644 --- a/modeling_asteroid.py +++ b/modeling_asteroid.py @@ -1,31 +1,40 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + import torch import torch.nn as nn -from dataclasses import dataclass -from transformers.utils import ModelOutput +from transformers import GenerationMixin, PreTrainedModel, Qwen3Config, Qwen3Model from transformers.cache_utils import Cache -from typing import Optional, List, Tuple, Union -from transformers.loss.loss_utils import ForCausalLMLoss -from transformers.generation.streamers import BaseStreamer -from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.generation.configuration_utils import GenerationConfig +from transformers.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) from transformers.generation.stopping_criteria import StoppingCriteriaList -from transformers import PreTrainedModel, GenerationMixin, Qwen3Config, Qwen3Model -from transformers.generation.logits_process import LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper +from transformers.generation.streamers import BaseStreamer +from transformers.loss.loss_utils import ForCausalLMLoss +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.utils import ModelOutput class AsteroidTTSConfig(Qwen3Config): - def __init__(self, - channels = 8, - speech_pad_token = 1024, - speech_vocab_size = 1025, - speech_token_range = [], - **kwargs): + def __init__( + self, + channels=8, + speech_pad_token=1024, + speech_vocab_size=1025, + speech_token_range=[], + **kwargs, + ): super().__init__(**kwargs) self.channels = channels self.speech_pad_token = speech_pad_token self.speech_vocab_size = speech_vocab_size self.speech_token_range = speech_token_range - + @dataclass class AsteroidTTSOutputWithPast(ModelOutput): @@ -36,7 +45,7 @@ class AsteroidTTSOutputWithPast(ModelOutput): past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - + @dataclass class GenerateDecoderOnlyOutput(ModelOutput): @@ -61,7 +70,7 @@ def _sample( ) -> Union[GenerateDecoderOnlyOutput, torch.LongTensor]: # Extract configuration parameters speech_pad_idx = self.config.speech_pad_token - + eos_token_id = generation_config.eos_token_id output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states @@ -69,26 +78,40 @@ def _sample( output_logits = generation_config.output_logits return_dict_in_generate = generation_config.return_dict_in_generate max_length = generation_config.max_length - has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + has_eos_stopping_criteria = any( + hasattr(criteria, "eos_token_id") for criteria in stopping_criteria + ) do_sample = generation_config.do_sample # Initialize output tuples scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None - decoder_attentions = () if (return_dict_in_generate and 
output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) # Initialize tracking variables batch_size, cur_len, channels = input_ids.shape # channels = 8 this_peer_finished = False - unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - needs_additional_steps = -1 * torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + unfinished_sequences = torch.ones( + batch_size, dtype=torch.long, device=input_ids.device + ) + needs_additional_steps = -1 * torch.ones( + batch_size, dtype=torch.long, device=input_ids.device + ) tf_inputs = input_ids[:] - input_ids = input_ids[:, :-(channels - 1)] + input_ids = input_ids[:, : -(channels - 1)] cur_len = input_ids.shape[1] - model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :-(channels - 1)] + model_kwargs["attention_mask"] = model_kwargs["attention_mask"][ + :, : -(channels - 1) + ] base_length = input_ids.shape[1] - model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) + model_kwargs = self._get_initial_cache_position( + cur_len, input_ids.device, model_kwargs + ) # Define logits processor if generation_config.do_samples is not None: @@ -96,53 +119,87 @@ def _sample( realprocessor = [LogitsProcessorList() for _ in range(channels)] for i, layer_config in enumerate(generation_config.layers): if layer_config.get("repetition_penalty") is not None: - realprocessor[i].append(RepetitionPenaltyLogitsProcessor(penalty=layer_config.get("repetition_penalty"))) - if layer_config.get("temperature") is not None: - realprocessor[i].append(TemperatureLogitsWarper(temperature=layer_config.get("temperature"))) + realprocessor[i].append( + RepetitionPenaltyLogitsProcessor( + penalty=layer_config.get("repetition_penalty") + ) + ) + if layer_config.get("temperature") is not None: + realprocessor[i].append( + TemperatureLogitsWarper( + temperature=layer_config.get("temperature") + ) + ) if layer_config.get("top_k") is not None: - realprocessor[i].append(TopKLogitsWarper(top_k=layer_config.get("top_k"))) + realprocessor[i].append( + TopKLogitsWarper(top_k=layer_config.get("top_k")) + ) if layer_config.get("top_p") is not None: - realprocessor[i].append(TopPLogitsWarper(top_p=layer_config.get("top_p"))) + realprocessor[i].append( + TopPLogitsWarper(top_p=layer_config.get("top_p")) + ) else: do_samples = [do_sample for _ in range(channels)] realprocessor = [logits_processor for _ in range(channels)] - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + while self._has_unfinished_sequences( + this_peer_finished, synced_gpus, device=input_ids.device + ): # Prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) - model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + model_inputs.update( + {"output_attentions": output_attentions} if output_attentions else {} + ) + model_inputs.update( + {"output_hidden_states": output_hidden_states} + if output_hidden_states + else {} + ) # Forward pass outputs = self(**model_inputs, return_dict=True) - model_kwargs = self._update_model_kwargs_for_generation(outputs, 
model_kwargs) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs + ) if synced_gpus and this_peer_finished: continue # Get next token logits - next_token_logits = [logits[:, -1, :].clone().float().to(input_ids.device) for logits in outputs.logits_all] + next_token_logits = [ + logits[:, -1, :].clone().float().to(input_ids.device) + for logits in outputs.logits_all + ] for i, channel_logits in enumerate(next_token_logits): - if i != 0 and input_ids.shape[1] + 1 > tf_inputs.shape[1] - 7 + i: - channel_logits[:, 1024] = - torch.inf - if i == 0 and input_ids.shape[1] + 1 <= tf_inputs.shape[1]: - channel_logits[:, 152694] = - torch.inf - next_token_scores = [realprocessor[i](input_ids[..., i], logits) for i, logits in enumerate(next_token_logits)] + if i != 0 and input_ids.shape[1] + 1 > tf_inputs.shape[1] - 7 + i: + channel_logits[:, 1024] = -torch.inf + if i == 0 and input_ids.shape[1] + 1 <= tf_inputs.shape[1]: + channel_logits[:, 152694] = -torch.inf + next_token_scores = [ + realprocessor[i](input_ids[..., i], logits) + for i, logits in enumerate(next_token_logits) + ] # Generate next tokens next_tokens = [] for i, channel_score in enumerate(next_token_scores): if do_samples[i]: - channel_ntk = torch.multinomial(nn.functional.softmax(channel_score, dim=-1), num_samples=1).squeeze(1) + channel_ntk = torch.multinomial( + nn.functional.softmax(channel_score, dim=-1), num_samples=1 + ).squeeze(1) elif not do_samples[i]: channel_ntk = torch.argmax(channel_score, dim=-1) next_tokens.append(channel_ntk) next_tokens = torch.stack(next_tokens, dim=-1) # [batch_size, channels] # Additional steps logic - indices = (~self.is_speech_token(next_tokens[:, 0])) & (needs_additional_steps < 0) - needs_additional_steps[indices] = channels - 1 # For 8 channels, need 7 steps - + indices = (~self.is_speech_token(next_tokens[:, 0])) & ( + needs_additional_steps < 0 + ) + needs_additional_steps[indices] = ( + channels - 1 + ) # For 8 channels, need 7 steps + if input_ids.shape[1] + 1 <= tf_inputs.shape[1]: i = input_ids.shape[1] + 1 - base_length next_tokens[:, i:] = tf_inputs[:, input_ids.shape[1], i:] - + # Replace tokens in additional steps mask = (needs_additional_steps > 0) & (needs_additional_steps < 7) if mask.any().item(): @@ -150,19 +207,27 @@ def _sample( for i in range(1, channels): mask_i = mask & (needs_additional_steps < channels - i) next_tokens[mask_i, i] = speech_pad_idx - + if has_eos_stopping_criteria: for i in range(channels): pddp = self.config.eos_token_id if i == 0 else speech_pad_idx - next_tokens[:, i] = next_tokens[:, i] * unfinished_sequences + pddp * (1 - unfinished_sequences) - + next_tokens[:, i] = next_tokens[ + :, i + ] * unfinished_sequences + pddp * (1 - unfinished_sequences) + input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1) if streamer is not None: streamer.put(next_tokens.cpu()) - + # Update unfinished_sequences - needs_additional_steps = torch.where(needs_additional_steps > 0, needs_additional_steps - 1, needs_additional_steps) - stopping = stopping_criteria(input_ids[..., 0], scores) | (needs_additional_steps == 0) + needs_additional_steps = torch.where( + needs_additional_steps > 0, + needs_additional_steps - 1, + needs_additional_steps, + ) + stopping = stopping_criteria(input_ids[..., 0], scores) | ( + needs_additional_steps == 0 + ) unfinished_sequences = unfinished_sequences & ~stopping unfinished_sequences = unfinished_sequences | (needs_additional_steps > 0) this_peer_finished = unfinished_sequences.max() == 0 @@ 
-179,7 +244,7 @@ def _sample( cur_len += 1 del outputs - + if streamer is not None: streamer.end() @@ -194,8 +259,8 @@ def _sample( ) else: return input_ids - - + + class AsteroidTTSPretrainedModel(PreTrainedModel): config_class = AsteroidTTSConfig base_model_prefix = "model" @@ -217,10 +282,16 @@ def __init__(self, config: AsteroidTTSConfig): self.text_pad_idx = config.pad_token_id self.speech_pad_idx = config.speech_pad_token self.embedding_list = nn.ModuleList([]) - self.embedding_list.append(nn.Embedding(config.vocab_size, config.hidden_size, self.text_pad_idx)) + self.embedding_list.append( + nn.Embedding(config.vocab_size, config.hidden_size, self.text_pad_idx) + ) # Channels 1 to channels-1: Speech tokens only for _ in range(1, config.channels): - self.embedding_list.append(nn.Embedding(config.speech_vocab_size, config.hidden_size, self.speech_pad_idx)) + self.embedding_list.append( + nn.Embedding( + config.speech_vocab_size, config.hidden_size, self.speech_pad_idx + ) + ) self.language_model = Qwen3Model(config) self.post_init() @@ -231,19 +302,29 @@ def get_input_embeddings(self): def set_input_embeddings(self, value: nn.Embedding): self.embedding_list[0] = value - def _prepare_multi_modal_inputs(self, input_ids: torch.LongTensor) -> torch.FloatTensor: + def _prepare_multi_modal_inputs( + self, input_ids: torch.LongTensor + ) -> torch.FloatTensor: """ Prepares multi-modal embeddings from input_ids of shape (batch_size, channels, sequence_length). For channel 0: text + speech tokens, for channels 1 to channels-1: speech tokens padded with speech_pad_token. """ batch_size, seq_length, channels = input_ids.shape if channels != self.config.channels: - raise ValueError(f"Expected {self.config.channels} channels, got {channels}") - - inputs_embeds = torch.zeros(batch_size, seq_length, self.config.hidden_size, device=input_ids.device, dtype=self.embedding_list[0].weight.dtype) + raise ValueError( + f"Expected {self.config.channels} channels, got {channels}" + ) + + inputs_embeds = torch.zeros( + batch_size, + seq_length, + self.config.hidden_size, + device=input_ids.device, + dtype=self.embedding_list[0].weight.dtype, + ) for i in range(channels): embed_layer = self.embedding_list[i] - channel_input = input_ids[...,i] + channel_input = input_ids[..., i] inputs_embeds += embed_layer(channel_input) return inputs_embeds @@ -264,7 +345,9 @@ def forward( ) -> Union[Tuple, BaseModelOutputWithPast]: if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) if input_ids is not None: inputs_embeds = self._prepare_multi_modal_inputs(input_ids) @@ -282,8 +365,8 @@ def forward( cache_position=cache_position, ) return outputs - - + + class AsteroidTTSInstruct(AsteroidTTSPretrainedModel, CustomMixin): _tied_weights_keys = [] _tp_plan = {"lm_head": "colwise_rep"} @@ -297,20 +380,26 @@ def __init__(self, config: AsteroidTTSConfig): self._tied_weights_keys = [f"lm_heads.{i}.weight" for i in range(self.channels)] self.vocab_size = config.vocab_size self.lm_heads = nn.ModuleList([]) - self.lm_heads.append(nn.Linear(config.hidden_size, config.vocab_size, bias=False)) + self.lm_heads.append( + nn.Linear(config.hidden_size, config.vocab_size, bias=False) + ) for _ in range(1, config.channels): - self.lm_heads.append(nn.Linear(config.hidden_size, config.speech_vocab_size, bias=False)) + self.lm_heads.append( + nn.Linear(config.hidden_size, 
config.speech_vocab_size, bias=False) + ) self.post_init() def get_input_embeddings(self): return self.model.embedding_list[0] - + def can_generate(self): return True - + def is_speech_token(self, tokens): - return (tokens >= self.config.speech_token_range[0]) & (tokens < self.config.speech_token_range[1]) - + return (tokens >= self.config.speech_token_range[0]) & ( + tokens < self.config.speech_token_range[1] + ) + def tie_weights(self): for i in range(self.config.channels): self._tie_or_clone_weights(self.lm_heads[i], self.model.embedding_list[i]) @@ -329,7 +418,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model - + def set_weights(self, weights): self.weights = weights @@ -349,11 +438,25 @@ def forward( skip_logits: Optional[bool] = None, **kwargs, ) -> Union[Tuple, AsteroidTTSOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) - skip_logits = skip_logits if skip_logits is not None else (self.training and labels is not None) + skip_logits = ( + skip_logits + if skip_logits is not None + else (self.training and labels is not None) + ) if skip_logits and labels is None: skip_logits = False @@ -377,23 +480,25 @@ def forward( logits_all = None loss_all = None total_loss = None - + if labels is not None: from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss device = input_ids.device if input_ids is not None else inputs_embeds.device loss_all = torch.empty(self.channels, device=device) logits_list = [] - + for i in range(self.config.channels): - vocab_size = self.config.vocab_size if i == 0 else self.config.speech_vocab_size + vocab_size = ( + self.config.vocab_size if i == 0 else self.config.speech_vocab_size + ) if skip_logits: loss_all[i] = LigerForCausalLMLoss( hidden_states=hidden_states, lm_head_weight=self.lm_heads[i].weight, labels=labels[..., i], hidden_size=self.config.hidden_size, - **kwargs + **kwargs, ) else: logits = self.lm_heads[i](hidden_states) @@ -405,7 +510,7 @@ def forward( total_weight = sum(self.weights) normalized_weights = [w / total_weight for w in self.weights] - + total_loss = 0 for w, loss in zip(normalized_weights, loss_all): total_loss += w * loss @@ -414,7 +519,15 @@ def forward( if not return_dict: output = (logits_all,) + outputs[1:] - return (total_loss, loss_all, ) + output if total_loss is not None else output + return ( + ( + total_loss, + loss_all, + ) + + output + if total_loss is not None + else output + ) return AsteroidTTSOutputWithPast( loss=total_loss, @@ -424,4 +537,4 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - ) \ No newline at end of file + ) diff --git a/podcast_generate.py b/podcast_generate.py index 13d9cef..cdeba71 100644 --- a/podcast_generate.py +++ b/podcast_generate.py @@ -1,13 +1,15 @@ +import argparse import os +import re + +import openai +import requests import torch import torchaudio 
-import requests from bs4 import BeautifulSoup -import re from PyPDF2 import PdfReader -import openai + from generation_utils import load_model, process_batch -import argparse # =============== Configuration Section =============== SYSTEM_PROMPT = "You are a speech synthesizer that generates natural, realistic, and human-like conversational audio from dialogue text." @@ -19,7 +21,7 @@ # English audio examples EN_PROMPT_AUDIO_SPEAKER1 = "examples/m1.wav" EN_PROMPT_TEXT_SPEAKER1 = "How much do you know about her?" -EN_PROMPT_AUDIO_SPEAKER2 = "examples/m2.wav" +EN_PROMPT_AUDIO_SPEAKER2 = "examples/m2.wav" EN_PROMPT_TEXT_SPEAKER2 = "Well, we know this much about her. You've been with her constantly since the first day you met her. And we followed you while you went dining, dancing, and sailing. And last night, I happened to be there when you were having dinner with her at Le Petit Tableau." # Chinese audio examples @@ -32,12 +34,13 @@ # =============== Text Parsing Functions =============== + def extract_text_from_pdf(file_path): """Extract text content from PDF file - + Args: file_path (str): PDF file path - + Returns: str: Extracted text content, returns None if failed """ @@ -54,81 +57,83 @@ def extract_text_from_pdf(file_path): def extract_web_content(url): """Extract title and main content from web URL - + Args: url (str): Web URL address - + Returns: tuple: (title, content) - title and main content, returns None if failed """ try: # Send request and get web content headers = { - 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36' + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36" } response = requests.get(url, headers=headers, timeout=10) - response.encoding = 'utf-8' # Ensure correct encoding + response.encoding = "utf-8" # Ensure correct encoding response.raise_for_status() - + print(f"HTTP status: {response.status_code}") print(f"Response content length: {len(response.text)}") # Parse HTML content - soup = BeautifulSoup(response.text, 'html.parser') + soup = BeautifulSoup(response.text, "html.parser") # Extract title - try multiple approaches title = "" - + # Try h1 tags first - h1_tags = soup.find_all('h1') + h1_tags = soup.find_all("h1") if h1_tags: for h1 in h1_tags: if h1.text.strip(): title = h1.text.strip() break - + # Try title tag if h1 not found if not title and soup.title: title = soup.title.string.strip() if soup.title.string else "" - + # Try meta property="og:title" if not title: - og_title = soup.find('meta', property='og:title') - if og_title and og_title.get('content'): - title = og_title['content'].strip() + og_title = soup.find("meta", property="og:title") + if og_title and og_title.get("content"): + title = og_title["content"].strip() print(f"Extracted title: {title}") # Simply extract all text from the page # Remove script, style, and other non-content elements - for unwanted in soup.find_all(['script', 'style', 'noscript']): + for unwanted in soup.find_all(["script", "style", "noscript"]): unwanted.decompose() - + # Get all text content - text_content = soup.get_text(separator='\n', strip=True) - + text_content = soup.get_text(separator="\n", strip=True) + # Clean the text if text_content: # Remove extra blank lines - cleaned_text = re.sub(r'\n{3,}', '\n\n', text_content) + cleaned_text = re.sub(r"\n{3,}", "\n\n", text_content) # Remove extra spaces - cleaned_text = re.sub(r' {2,}', ' ', cleaned_text) + 
cleaned_text = re.sub(r" {2,}", " ", cleaned_text) # Remove very short lines and common noise - lines = cleaned_text.split('\n') + lines = cleaned_text.split("\n") filtered_lines = [] for line in lines: line = line.strip() # Filter out very short lines and common non-content patterns - if (len(line) > 3 and - 'browser does not support' not in line.lower() and - not re.match(r'^[0-9\s\-\/\.]+$', line)): # Filter date-only lines + if ( + len(line) > 3 + and "browser does not support" not in line.lower() + and not re.match(r"^[0-9\s\-\/\.]+$", line) + ): # Filter date-only lines filtered_lines.append(line) - - cleaned_text = '\n'.join(filtered_lines) - + + cleaned_text = "\n".join(filtered_lines) + print(f"Final content length: {len(cleaned_text)} characters") print(f"Content preview: {cleaned_text[:300]}...") - + return title, cleaned_text else: print("No content extracted") @@ -140,25 +145,26 @@ def extract_web_content(url): except Exception as e: print(f"Parse error: {e}") import traceback + traceback.print_exc() return None def extract_text_from_txt(file_path): """Extract text content from TXT file - + Args: file_path (str): TXT file path - + Returns: str: Extracted text content, returns None if failed """ try: - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: return f.read() except UnicodeDecodeError: try: - with open(file_path, 'r', encoding='gbk') as f: + with open(file_path, "r", encoding="gbk") as f: return f.read() except Exception as e: print(f"TXT file reading failed (tried GBK encoding): {str(e)}") @@ -170,17 +176,17 @@ def extract_text_from_txt(file_path): def parse_input_content(input_path): """Parse input content, supports URL, PDF or TXT files - + Args: input_path (str): URL address, PDF file path or TXT file path - + Returns: str: Parsed text content """ print(f"Parsing input: {input_path}") - + # Check if it's a URL - if input_path.startswith(('http://', 'https://')): + if input_path.startswith(("http://", "https://")): print("URL detected, extracting web content...") result = extract_web_content(input_path) if result: @@ -191,9 +197,9 @@ def parse_input_content(input_path): else: print("Web content extraction failed") return None - + # Check if it's a PDF file - elif input_path.lower().endswith('.pdf'): + elif input_path.lower().endswith(".pdf"): print("PDF file detected, extracting content...") content = extract_text_from_pdf(input_path) if content: @@ -202,9 +208,9 @@ def parse_input_content(input_path): else: print("PDF content extraction failed") return None - + # Check if it's a TXT file - elif input_path.lower().endswith('.txt'): + elif input_path.lower().endswith(".txt"): print("TXT file detected, reading content...") content = extract_text_from_txt(input_path) if content: @@ -213,7 +219,7 @@ def parse_input_content(input_path): else: print("TXT file reading failed") return None - + else: print(f"Unsupported input format: {input_path}") return None @@ -221,7 +227,8 @@ def parse_input_content(input_path): # =============== Dialogue Script Generation Function =============== -def generate_podcast_script(content, language='zh'): + +def generate_podcast_script(content, language="zh"): """Call large model to generate podcast dialogue script""" from openai import OpenAI @@ -230,7 +237,7 @@ def generate_podcast_script(content, language='zh'): base_url=os.getenv("OPENAI_API_BASE", "YOUR_API_BASE_URL"), ) - if language == 'zh': + if language == "zh": role_play = "两位中文播客主持人" instruction = 
f"""你是一位专业的中文播客文字脚本撰稿人。现在请你根据提供的有关最新AI及大模型相关进展的原始资料,生成一段模拟{role_play}之间的自然对话脚本。该脚本应符合以下具体要求: 一、语言风格 @@ -270,7 +277,7 @@ def generate_podcast_script(content, language='zh'): 请根据以上要求和提供的原始资料,将其转化为符合以上所有要求的播客对话脚本。一定要用[S1]和[S2]标记两位说话人,绝对不能使用任何其它符号标记说话人。 注意:直接输出结果,不要包含任何额外信息。 """ - else: # Default to English + else: # Default to English role_play = "two English podcast hosts" instruction = f"""You are a professional English podcast scriptwriter. Based on the provided source material about the latest developments in AI and large models, generate a natural conversational script simulating a dialogue between {role_play}. The script should meet the following specific requirements: I. Language Style @@ -314,23 +321,18 @@ def generate_podcast_script(content, language='zh'): try: print("Calling large model to generate dialogue script...") print(f"Input content length: {len(content)} characters") - + completion = client.chat.completions.create( model="gemini-2.5-pro", # model="gemini-2.5-flash-preview-04-17", - messages=[ - { - "role": "user", - "content": instruction - } - ] + messages=[{"role": "user", "content": instruction}], ) - + raw_result = completion.choices[0].message.content - + # Remove all newlines - processed_result = raw_result.replace('\n', '').replace('\r', '') - + processed_result = raw_result.replace("\n", "").replace("\r", "") + print("=" * 50) print("Large model generated dialogue script (original version):") print("=" * 50) @@ -342,9 +344,9 @@ def generate_podcast_script(content, language='zh'): print("=" * 50) print(f"Original script length: {len(raw_result)} characters") print(f"Processed script length: {len(processed_result)} characters") - + return processed_result - + except Exception as e: print(f"Large model call failed: {str(e)}") # If large model call fails, return a sample script for testing @@ -355,17 +357,20 @@ def generate_podcast_script(content, language='zh'): # =============== Main Function =============== -def process_input_to_audio(input_path: str, output_dir: str = "examples", language: str = 'zh'): + +def process_input_to_audio( + input_path: str, output_dir: str = "examples", language: str = "zh" +): """Complete processing pipeline: from input to audio output - + Args: input_path (str): Input path (URL, PDF or TXT file) output_dir (str): Output directory language (str): Language for the podcast script ('en' or 'zh') """ - + # Select prompts based on language - if language == 'zh': + if language == "zh": prompt_audio_speaker1 = ZH_PROMPT_AUDIO_SPEAKER1 prompt_text_speaker1 = ZH_PROMPT_TEXT_SPEAKER1 prompt_audio_speaker2 = ZH_PROMPT_AUDIO_SPEAKER2 @@ -375,20 +380,20 @@ def process_input_to_audio(input_path: str, output_dir: str = "examples", langua prompt_text_speaker1 = EN_PROMPT_TEXT_SPEAKER1 prompt_audio_speaker2 = EN_PROMPT_AUDIO_SPEAKER2 prompt_text_speaker2 = EN_PROMPT_TEXT_SPEAKER2 - + print(f"Using {language} prompts:") print(f"Speaker 1: {prompt_audio_speaker1}") print(f"Speaker 2: {prompt_audio_speaker2}") - + # 1. Parse input content print("Step 1: Parse input content") content = parse_input_content(input_path) if not content: print("Unable to parse input content, program terminated") return - + print(f"Content parsed successfully, content preview: {content[:200]}...") - + # 2. 
Use large model to generate dialogue script print("\nStep 2: Generate dialogue script") script = generate_podcast_script(content, language=language) @@ -402,22 +407,24 @@ def process_input_to_audio(input_path: str, output_dir: str = "examples", langua spt = spt.to(device) model = model.to(device) print("TTS model loading completed") - + # 4. Prepare TTS input data with language-specific prompts print("\nStep 4: Prepare TTS input data") - items = [{ - "text": script, - "base_path": "", - "prompt_audio_speaker1": prompt_audio_speaker1, - "prompt_text_speaker1": prompt_text_speaker1, - "prompt_audio_speaker2": prompt_audio_speaker2, - "prompt_text_speaker2": prompt_text_speaker2 - }] - + items = [ + { + "text": script, + "base_path": "", + "prompt_audio_speaker1": prompt_audio_speaker1, + "prompt_text_speaker1": prompt_text_speaker1, + "prompt_audio_speaker2": prompt_audio_speaker2, + "prompt_text_speaker2": prompt_text_speaker2, + } + ] + # 5. Set random seed # import accelerate # accelerate.utils.set_seed(42) - + # 6. Generate audio print("\nStep 5: Generate audio") actual_texts_data, audio_results = process_batch( @@ -428,21 +435,23 @@ def process_input_to_audio(input_path: str, output_dir: str = "examples", langua device=device, system_prompt=SYSTEM_PROMPT, start_idx=0, - use_normalize=True + use_normalize=True, ) - + # 7. Save audio files print("\nStep 6: Save audio files") os.makedirs(output_dir, exist_ok=True) - + for idx, audio_result in enumerate(audio_results): if audio_result is not None: output_path = os.path.join(output_dir, f"generated_podcast_{idx}.wav") - torchaudio.save(output_path, audio_result["audio_data"], audio_result["sample_rate"]) + torchaudio.save( + output_path, audio_result["audio_data"], audio_result["sample_rate"] + ) print(f"Audio saved to: {output_path}") else: print(f"Audio generation failed: sample {idx}") - + print("\nProcessing completed!") @@ -450,16 +459,28 @@ def process_input_to_audio(input_path: str, output_dir: str = "examples", langua if __name__ == "__main__": # Add command line argument parsing - parser = argparse.ArgumentParser(description="Generate podcast audio: supports URL, PDF or TXT file input") - parser.add_argument("input_path", help="Input path: URL address, PDF file path or TXT file path") - parser.add_argument("-o", "--output", default="outputs", help="Output directory (default: outputs)") - parser.add_argument("-l", "--language", default="zh", choices=['en', 'zh'], help="Language of the podcast script (en or zh, default: zh)") - + parser = argparse.ArgumentParser( + description="Generate podcast audio: supports URL, PDF or TXT file input" + ) + parser.add_argument( + "input_path", help="Input path: URL address, PDF file path or TXT file path" + ) + parser.add_argument( + "-o", "--output", default="outputs", help="Output directory (default: outputs)" + ) + parser.add_argument( + "-l", + "--language", + default="zh", + choices=["en", "zh"], + help="Language of the podcast script (en or zh, default: zh)", + ) + args = parser.parse_args() - + # Use command line arguments print(f"Input path: {args.input_path}") print(f"Output directory: {args.output}") print(f"Script language: {args.language}") - - process_input_to_audio(args.input_path, args.output, args.language) \ No newline at end of file + + process_input_to_audio(args.input_path, args.output, args.language) diff --git a/requirements.txt b/requirements.txt index 1c3d81e..20df480 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,4 @@ einops huggingface_hub liger_kernel 
pydub -# flash-attn # Need to install separately \ No newline at end of file +# flash-attn # Need to install separately diff --git a/streamer.py b/streamer.py index 49fc6de..f213631 100644 --- a/streamer.py +++ b/streamer.py @@ -1,25 +1,25 @@ +import argparse +import io import json +import os from queue import Queue + +import accelerate import torch import torchaudio -import accelerate -import argparse -import os from tqdm import tqdm -import io - from transformers.generation.streamers import BaseStreamer from generation_utils import ( - load_model, - process_jsonl_item, - normalize_text, - load_audio_data, - process_inputs, - shifting_inputs, - rpadding, find_max_valid_positions, - process_batch + load_audio_data, + load_model, + normalize_text, + process_batch, + process_inputs, + process_jsonl_item, + rpadding, + shifting_inputs, ) from modeling_asteroid import AsteroidTTSInstruct from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer @@ -30,6 +30,7 @@ SPT_CHECKPOINT_PATH = "XY_Tokenizer/weights/xy_tokenizer.ckpt" MAX_CHANNELS = 8 + class AudioIteratorStreamer(BaseStreamer): MAX_TOKEN_LENGTH = 16384 @@ -37,7 +38,7 @@ class AudioIteratorStreamer(BaseStreamer): OVERLAP_SECONDS = 10 def __init__( - self, + self, tokenizer, model: AsteroidTTSInstruct, spt: XY_Tokenizer, @@ -51,8 +52,12 @@ def __init__( self.use_tqdm = use_tqdm self.speech_offset = model.config.speech_token_range[0] self.channels = model.config.channels - self.duration_code_length = int(self.CHUNK_SIZE * spt.input_sample_rate // spt.encoder_downsample_rate) - self.overlap_code_length = int(self.OVERLAP_SECONDS * spt.input_sample_rate // spt.encoder_downsample_rate) + self.duration_code_length = int( + self.CHUNK_SIZE * spt.input_sample_rate // spt.encoder_downsample_rate + ) + self.overlap_code_length = int( + self.OVERLAP_SECONDS * spt.input_sample_rate // spt.encoder_downsample_rate + ) self.valid_code_length = self.duration_code_length - self.overlap_code_length self.valid_wav_length = int(self.valid_code_length * spt.decoder_upsample_rate) print(f"Speech offset: {self.speech_offset}") @@ -62,33 +67,44 @@ def __init__( print(f"Valid wav length: {self.valid_wav_length}") self.next_tokens_are_prompt = True - self.token_cache = torch.zeros(self.MAX_TOKEN_LENGTH, self.channels, dtype=torch.long, device=device) + self.token_cache = torch.zeros( + self.MAX_TOKEN_LENGTH, self.channels, dtype=torch.long, device=device + ) self.token_cache_length = 0 self.decoded_idx = 0 self.audio_queue = Queue() self.stop_signal = None - + # tqdm related self.pbar = None if self.use_tqdm: self.pbar = tqdm(desc="Processing tokens", unit="token", total=None) def decode(self, last_chunk=False): - duration_to_decode = min(self.duration_code_length, self.token_cache_length - self.decoded_idx - self.channels + 1) + duration_to_decode = min( + self.duration_code_length, + self.token_cache_length - self.decoded_idx - self.channels + 1, + ) speech_ids = torch.full((duration_to_decode, self.channels), 0).to(self.device) for j in range(self.channels): - speech_ids[..., j] = self.token_cache[self.decoded_idx + j : self.decoded_idx + j + duration_to_decode, j] + speech_ids[..., j] = self.token_cache[ + self.decoded_idx + j : self.decoded_idx + j + duration_to_decode, j + ] if j == 0: speech_ids[..., j] = speech_ids[..., j] - self.speech_offset # Decode generated audio with torch.no_grad(): - chunk_codes = speech_ids.permute(1, 0).unsqueeze(1) # Convert to SPT expected format - decode_result = self.spt.inference_detokenize(chunk_codes, 
torch.tensor([duration_to_decode], device=self.device)) + chunk_codes = speech_ids.permute(1, 0).unsqueeze( + 1 + ) # Convert to SPT expected format + decode_result = self.spt.inference_detokenize( + chunk_codes, torch.tensor([duration_to_decode], device=self.device) + ) audio_result = decode_result["y"][0].cpu().detach() if not last_chunk: - audio_result = audio_result[:, :self.valid_wav_length] + audio_result = audio_result[:, : self.valid_wav_length] self.audio_queue.put(audio_result) # will be out of bound of last chunk is not full @@ -101,33 +117,38 @@ def put(self, value: torch.LongTensor): self.token_cache[self.token_cache_length, :] = value self.token_cache_length += 1 - + # update tqdm pbar if self.use_tqdm and self.pbar is not None: - self.pbar.set_postfix({ - "cache_len": self.token_cache_length, - "decoded": self.decoded_idx, - "cache_usage": f"{self.token_cache_length/self.MAX_TOKEN_LENGTH*100:.1f}%" - }) + self.pbar.set_postfix( + { + "cache_len": self.token_cache_length, + "decoded": self.decoded_idx, + "cache_usage": f"{self.token_cache_length/self.MAX_TOKEN_LENGTH*100:.1f}%", + } + ) self.pbar.update(1) - + if self.token_cache_length >= self.MAX_TOKEN_LENGTH: raise ValueError("Token cache is full") - if self.token_cache_length >= self.decoded_idx + self.duration_code_length + self.channels - 1: + if ( + self.token_cache_length + >= self.decoded_idx + self.duration_code_length + self.channels - 1 + ): self.decode(last_chunk=False) - + def end(self): self.decode(last_chunk=True) self.audio_queue.put(self.stop_signal) - + # close tqdm if self.use_tqdm and self.pbar is not None: self.pbar.close() def __iter__(self): return self - + def __next__(self): value = self.audio_queue.get() if value == self.stop_signal: @@ -135,16 +156,17 @@ def __next__(self): else: return value + def streamer( - batch_items, + batch_items, tokenizer, model: AsteroidTTSInstruct, spt: XY_Tokenizer, device, - system_prompt, + system_prompt, use_normalize=False, use_tqdm=False, - ): +): """Process a batch of data items and generate audio, return audio data and metadata""" # Prepare batch data batch_size = len(batch_items) @@ -153,39 +175,45 @@ def streamer( prompts = [system_prompt] * batch_size prompt_audios = [] actual_texts_data = [] # Store actual text data used - + # Extract text and audio from each sample for i, item in enumerate(batch_items): # Use new processing function processed_item = process_jsonl_item(item) - + text = processed_item["text"] prompt_text = processed_item["prompt_text"] - + # Merge text, if prompt_text is empty, full_text is just text full_text = prompt_text + text if prompt_text else text original_full_text = full_text # Save original text - + # Apply text normalization based on parameter if use_normalize: full_text = normalize_text(full_text) - + # Replace speaker tags - final_text = full_text.replace("[S1]", "").replace("[S2]", "") + final_text = full_text.replace("[S1]", "").replace( + "[S2]", "" + ) texts.append(final_text) - + # Save actual text information used - actual_texts_data.append({ - "index": i, - "original_text": original_full_text, - "normalized_text": normalize_text(original_full_text) if use_normalize else None, - "final_text": final_text, - "use_normalize": use_normalize - }) - + actual_texts_data.append( + { + "index": i, + "original_text": original_full_text, + "normalized_text": ( + normalize_text(original_full_text) if use_normalize else None + ), + "final_text": final_text, + "use_normalize": use_normalize, + } + ) + # Get reference audio 
prompt_audios.append(processed_item["prompt_audio"]) - + # Process inputs input_ids_list = [] for i, (text, prompt, audio_path) in enumerate(zip(texts, prompts, prompt_audios)): @@ -194,13 +222,13 @@ def streamer( inputs = process_inputs(tokenizer, spt, prompt, text, device, audio_data) inputs = shifting_inputs(inputs, tokenizer) input_ids_list.append(inputs) - + # Pad batch inputs input_ids, attention_mask = rpadding(input_ids_list, MAX_CHANNELS, tokenizer) - + # Batch generation print(f"Starting batch audio generation...") - + # Move inputs to GPU input_ids = input_ids.to(device) attention_mask = attention_mask.to(device) @@ -208,11 +236,15 @@ def streamer( streamer = AudioIteratorStreamer(tokenizer, model, spt, device, use_tqdm=use_tqdm) from threading import Thread - thread = Thread(target=model.generate, kwargs={ - "input_ids": input_ids, - "attention_mask": attention_mask, - "streamer": streamer, - }) + + thread = Thread( + target=model.generate, + kwargs={ + "input_ids": input_ids, + "attention_mask": attention_mask, + "streamer": streamer, + }, + ) thread.start() yield from streamer @@ -222,46 +254,77 @@ def streamer( def main(): parser = argparse.ArgumentParser(description="TTS inference with Asteroid model") - parser.add_argument("--jsonl", default="examples/examples.jsonl", - help="Path to JSONL file (default: examples/examples.jsonl)") - parser.add_argument("--seed", type=int, default=None, - help="Random seed for reproducibility (default: None)") - parser.add_argument("--output_dir", default="outputs/streamer", - help="Output directory for generated audio files (default: outputs)") - parser.add_argument("--use_normalize", action="store_true", default=True, - help="Whether to use text normalization (default: True)") - parser.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16", - help="Model data type (default: bf16)") - parser.add_argument("--attn_implementation", choices=["flash_attention_2", "sdpa", "eager"], default="flash_attention_2", - help="Attention implementation (default: flash_attention_2)") - parser.add_argument("--use_tqdm", action="store_true", default=False, - help="Whether to show progress bar using tqdm (default: False)") - + parser.add_argument( + "--jsonl", + default="examples/examples.jsonl", + help="Path to JSONL file (default: examples/examples.jsonl)", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Random seed for reproducibility (default: None)", + ) + parser.add_argument( + "--output_dir", + default="outputs/streamer", + help="Output directory for generated audio files (default: outputs)", + ) + parser.add_argument( + "--use_normalize", + action="store_true", + default=True, + help="Whether to use text normalization (default: True)", + ) + parser.add_argument( + "--dtype", + choices=["bf16", "fp16", "fp32"], + default="bf16", + help="Model data type (default: bf16)", + ) + parser.add_argument( + "--attn_implementation", + choices=["flash_attention_2", "sdpa", "eager"], + default="flash_attention_2", + help="Attention implementation (default: flash_attention_2)", + ) + parser.add_argument( + "--use_tqdm", + action="store_true", + default=False, + help="Whether to show progress bar using tqdm (default: False)", + ) + args = parser.parse_args() - + # Convert dtype string to torch dtype dtype_mapping = { "bf16": torch.bfloat16, "fp16": torch.float16, - "fp32": torch.float32 + "fp32": torch.float32, } torch_dtype = dtype_mapping[args.dtype] - + # Create output directory if it doesn't exist 
os.makedirs(args.output_dir, exist_ok=True) - + device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") print(f"Using dtype: {args.dtype} ({torch_dtype})") print(f"Using attention implementation: {args.attn_implementation}") - + # Load models print("Loading models...") - tokenizer, model, spt = load_model(MODEL_PATH, SPT_CONFIG_PATH, SPT_CHECKPOINT_PATH, - torch_dtype=torch_dtype, attn_implementation=args.attn_implementation) + tokenizer, model, spt = load_model( + MODEL_PATH, + SPT_CONFIG_PATH, + SPT_CHECKPOINT_PATH, + torch_dtype=torch_dtype, + attn_implementation=args.attn_implementation, + ) spt = spt.to(device) model = model.to(device) - + # Load the items from the JSONL file try: with open(args.jsonl, "r") as f: @@ -273,18 +336,18 @@ def main(): except json.JSONDecodeError as e: print(f"Error parsing JSONL file: {e}") return - + # Fix the seed for reproducibility if args.seed is not None: accelerate.utils.set_seed(args.seed) print(f"Set random seed to {args.seed}") - + # Process the batch of items print("Starting streaming inference...") # Create output directory for FLAC files flac_output_dir = args.output_dir os.makedirs(flac_output_dir, exist_ok=True) - + audio_generator = streamer( batch_items=[items[0]], tokenizer=tokenizer, @@ -293,20 +356,21 @@ def main(): device=device, system_prompt=SYSTEM_PROMPT, use_normalize=args.use_normalize, - use_tqdm=args.use_tqdm + use_tqdm=args.use_tqdm, ) - + # Save each FLAC chunk to disk full_audio = [] for i, audio in enumerate(audio_generator): - chunk_path = os.path.join(flac_output_dir, f'chunk_{i}.flac') + chunk_path = os.path.join(flac_output_dir, f"chunk_{i}.flac") torchaudio.save(chunk_path, audio, spt.output_sample_rate, format="flac") full_audio.append(audio) full_audio = torch.cat(full_audio, dim=1) - full_audio_path = os.path.join(flac_output_dir, 'full_audio.flac') + full_audio_path = os.path.join(flac_output_dir, "full_audio.flac") torchaudio.save(full_audio_path, full_audio, spt.output_sample_rate, format="flac") print(f"Saved full audio to: {full_audio_path}") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/use_api.py b/use_api.py index 622e9ad..eefdd62 100644 --- a/use_api.py +++ b/use_api.py @@ -1,15 +1,16 @@ # -*- coding: utf-8 -*- -import json import base64 -import librosa -import soundfile as sf -from pathlib import Path -from openai import OpenAI -import tempfile -import os import concurrent.futures +import json +import os +import tempfile import threading from functools import partial +from pathlib import Path + +import librosa +import soundfile as sf +from openai import OpenAI from pydub import AudioSegment # Get API credentials from environment variables @@ -24,6 +25,7 @@ # Thread-safe file writing lock write_lock = threading.Lock() + def audio_to_base64(audio_path, target_sr=16000, target_channels=1): """ Convert audio file to 16k mono mp3 format and encode to base64 @@ -31,39 +33,40 @@ def audio_to_base64(audio_path, target_sr=16000, target_channels=1): # Check if file exists if not os.path.exists(audio_path): raise FileNotFoundError(f"Audio file not found: {audio_path}") - + try: # Load audio file audio, sr = librosa.load(audio_path, sr=target_sr, mono=True) - + # Create temporary wav file to save processed audio - with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_wav_file: + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav_file: temp_wav_path = temp_wav_file.name sf.write(temp_wav_path, audio, 
target_sr) - + # Convert wav to mp3 - with tempfile.NamedTemporaryFile(suffix='.mp3', delete=False) as temp_mp3_file: + with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_mp3_file: temp_mp3_path = temp_mp3_file.name - + # Use pydub for format conversion audio_segment = AudioSegment.from_wav(temp_wav_path) audio_segment.export(temp_mp3_path, format="mp3", bitrate="128k") - + # Read mp3 file and convert to base64 - with open(temp_mp3_path, 'rb') as f: + with open(temp_mp3_path, "rb") as f: audio_data = f.read() - base64_encoded = base64.b64encode(audio_data).decode('utf-8') - + base64_encoded = base64.b64encode(audio_data).decode("utf-8") + # Delete temporary files os.unlink(temp_wav_path) os.unlink(temp_mp3_path) - + return base64_encoded - + except Exception as e: print(f"[ERROR] Audio processing failed: {e}") raise + def process_single_item(line_data, output_dir, line_num, output_jsonl_path): """ Process single data item (for concurrent execution) @@ -74,35 +77,46 @@ def process_single_item(line_data, output_dir, line_num, output_jsonl_path): try: # Extract text input_text = line_data["text"] - + # Check which format - if "prompt_audio_speaker1" in line_data and "prompt_audio_speaker2" in line_data: + if ( + "prompt_audio_speaker1" in line_data + and "prompt_audio_speaker2" in line_data + ): # Separate format (original format) - prompt_audio_speaker1_path = os.path.join(line_data["base_path"], line_data["prompt_audio_speaker1"]) - prompt_audio_speaker2_path = os.path.join(line_data["base_path"], line_data["prompt_audio_speaker2"]) - + prompt_audio_speaker1_path = os.path.join( + line_data["base_path"], line_data["prompt_audio_speaker1"] + ) + prompt_audio_speaker2_path = os.path.join( + line_data["base_path"], line_data["prompt_audio_speaker2"] + ) + # Check if audio files exist if not os.path.exists(prompt_audio_speaker1_path): - raise FileNotFoundError(f"Speaker1 audio file not found: {prompt_audio_speaker1_path}") + raise FileNotFoundError( + f"Speaker1 audio file not found: {prompt_audio_speaker1_path}" + ) if not os.path.exists(prompt_audio_speaker2_path): - raise FileNotFoundError(f"Speaker2 audio file not found: {prompt_audio_speaker2_path}") - + raise FileNotFoundError( + f"Speaker2 audio file not found: {prompt_audio_speaker2_path}" + ) + # Convert audio to base64 audio1_base64 = audio_to_base64(prompt_audio_speaker1_path) audio2_base64 = audio_to_base64(prompt_audio_speaker2_path) - + # Build reference data references = [ { "audio": f"data:audio/mp3;base64,{audio1_base64}", - "text": f"[S1]{line_data['prompt_text_speaker1']}" + "text": f"[S1]{line_data['prompt_text_speaker1']}", }, { "audio": f"data:audio/mp3;base64,{audio2_base64}", - "text": f"[S2]{line_data['prompt_text_speaker2']}" - } + "text": f"[S2]{line_data['prompt_text_speaker2']}", + }, ] - + # Build output record output_record = { "text": line_data["text"], @@ -110,101 +124,114 @@ def process_single_item(line_data, output_dir, line_num, output_jsonl_path): "prompt_text_speaker1": line_data["prompt_text_speaker1"], "prompt_audio_speaker2": line_data["prompt_audio_speaker2"], "prompt_text_speaker2": line_data["prompt_text_speaker2"], - "output_audio": None # Will be set later + "output_audio": None, # Will be set later } - + elif "prompt_audio" in line_data and "prompt_text" in line_data: # Merged format (new format) - prompt_audio_path = os.path.join(line_data["base_path"], line_data["prompt_audio"]) - + prompt_audio_path = os.path.join( + line_data["base_path"], line_data["prompt_audio"] + ) + # 
Check if audio file exists if not os.path.exists(prompt_audio_path): - raise FileNotFoundError(f"Reference audio file not found: {prompt_audio_path}") - + raise FileNotFoundError( + f"Reference audio file not found: {prompt_audio_path}" + ) + # Convert audio to base64 audio_base64 = audio_to_base64(prompt_audio_path) - + # Build reference data (using single audio and text) references = [ { "audio": f"data:audio/mp3;base64,{audio_base64}", - "text": line_data['prompt_text'] + "text": line_data["prompt_text"], } ] - + # Build output record output_record = { "text": line_data["text"], "prompt_audio": line_data["prompt_audio"], "prompt_text": line_data["prompt_text"], - "output_audio": None # Will be set later + "output_audio": None, # Will be set later } - + else: - raise ValueError("Unsupported data format. Must contain one of the following field sets:\n" - "1. prompt_audio_speaker1, prompt_text_speaker1, prompt_audio_speaker2, prompt_text_speaker2\n" - "2. prompt_audio, prompt_text") - + raise ValueError( + "Unsupported data format. Must contain one of the following field sets:\n" + "1. prompt_audio_speaker1, prompt_text_speaker1, prompt_audio_speaker2, prompt_text_speaker2\n" + "2. prompt_audio, prompt_text" + ) + # Generate output path (using line number as filename) output_filename = f"output_{line_num:04d}.wav" output_path = os.path.join(output_dir, output_filename) output_path = os.path.abspath(output_path) - + # Generate speech generate_speech(input_text, references, output_path, line_num) - + # Set output audio path output_record["output_audio"] = output_path - + # Thread-safe writing to output JSONL file with write_lock: - with open(output_jsonl_path, 'a', encoding='utf-8') as output_f: - output_f.write(json.dumps(output_record, ensure_ascii=False) + '\n') - + with open(output_jsonl_path, "a", encoding="utf-8") as output_f: + output_f.write(json.dumps(output_record, ensure_ascii=False) + "\n") + return f"Line {line_num} processed successfully" - + except Exception as e: error_msg = f"Error processing line {line_num}: {e}" print(f"[ERROR] {error_msg}") return error_msg + def generate_speech(input_text, references, output_path, line_num): """ Generate speech """ # Create request parameters params = dict( - model="fnlp/MOSS-TTSD-v0.5", + model="fnlp/MOSS-TTSD-v0.5", input=input_text, response_format="wav", voice="", extra_body={ "references": references, "max_tokens": 16384, - } + }, ) - + import time + start_time = time.time() - + try: with client.audio.speech.with_streaming_response.create(**params) as response: data = response.read() - + with open(output_path, "wb") as f: f.write(data) - + # Verify if file was written successfully if not os.path.exists(output_path): - print(f"[ERROR] Line {line_num} - File write failed, file does not exist: {output_path}") - + print( + f"[ERROR] Line {line_num} - File write failed, file does not exist: {output_path}" + ) + end_time = time.time() - print(f"Line {line_num} - Audio generation time: {end_time - start_time:.2f} seconds") - + print( + f"Line {line_num} - Audio generation time: {end_time - start_time:.2f} seconds" + ) + except Exception as e: print(f"[ERROR] Line {line_num} - API call failed: {e}") raise + def main(jsonl_file_path, output_dir, max_workers=4): """ Main function: process JSONL file concurrently @@ -212,26 +239,26 @@ def main(jsonl_file_path, output_dir, max_workers=4): print(f"Starting to process JSONL file: {jsonl_file_path}") print(f"Output directory: {output_dir}") print(f"Max workers: {max_workers}") - + # Check if JSONL 
file exists if not os.path.exists(jsonl_file_path): print(f"[ERROR] JSONL file not found: {jsonl_file_path}") return - + # Ensure output directory exists os.makedirs(output_dir, exist_ok=True) - + # Create output JSONL file path output_jsonl_path = os.path.join(output_dir, "output_results.jsonl") - + # Clear output file (if exists) - with open(output_jsonl_path, 'w', encoding='utf-8') as f: + with open(output_jsonl_path, "w", encoding="utf-8") as f: pass - - with open(jsonl_file_path, 'r', encoding='utf-8') as f: + + with open(jsonl_file_path, "r", encoding="utf-8") as f: lines = f.readlines() print(f"Total {len(lines)} lines of data") - + # Prepare all tasks tasks = [] for line_num, line in enumerate(lines, 1): @@ -241,17 +268,23 @@ def main(jsonl_file_path, output_dir, max_workers=4): except json.JSONDecodeError as e: print(f"[ERROR] Line {line_num} JSON parsing failed: {e}") continue - + print(f"Prepared {len(tasks)} valid tasks") - + # Use thread pool for concurrent processing with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all tasks future_to_line = { - executor.submit(process_single_item, line_data, output_dir, line_num, output_jsonl_path): line_num + executor.submit( + process_single_item, + line_data, + output_dir, + line_num, + output_jsonl_path, + ): line_num for line_data, line_num in tasks } - + # Process completed tasks completed = 0 for future in concurrent.futures.as_completed(future_to_line): @@ -262,22 +295,35 @@ def main(jsonl_file_path, output_dir, max_workers=4): print(f"({completed}/{len(tasks)}) {result}") except Exception as exc: print(f"[ERROR] Line {line_num} generated exception: {exc}") - + print(f"\nAll processing completed!") print(f"Output audio files saved in: {output_dir}") print(f"Output JSONL file saved in: {output_jsonl_path}") + if __name__ == "__main__": import argparse - - parser = argparse.ArgumentParser(description='MOSS-TTSD API batch processing tool') - parser.add_argument('--jsonl_file', type=str, default="examples/examples.jsonl", - help='Input JSONL file path (default: examples/examples.jsonl)') - parser.add_argument('--output_dir', type=str, default="api_outputs", - help='Output directory path (default: api_outputs)') - parser.add_argument('--max_workers', type=int, default=8, - help='Maximum number of concurrent workers (default: 8)') - + + parser = argparse.ArgumentParser(description="MOSS-TTSD API batch processing tool") + parser.add_argument( + "--jsonl_file", + type=str, + default="examples/examples.jsonl", + help="Input JSONL file path (default: examples/examples.jsonl)", + ) + parser.add_argument( + "--output_dir", + type=str, + default="api_outputs", + help="Output directory path (default: api_outputs)", + ) + parser.add_argument( + "--max_workers", + type=int, + default=8, + help="Maximum number of concurrent workers (default: 8)", + ) + args = parser.parse_args() - - main(args.jsonl_file, args.output_dir, args.max_workers) \ No newline at end of file + + main(args.jsonl_file, args.output_dir, args.max_workers) From f396a52fa2c8f687cf1aeac3d25767be9f7a9592 Mon Sep 17 00:00:00 2001 From: xiami2019 <435350193@qq.com> Date: Tue, 4 Nov 2025 20:11:15 +0800 Subject: [PATCH 4/4] update links --- README.md | 9 +++++---- README_zh.md | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 87edd1d..a1773bd 100644 --- a/README.md +++ b/README.md @@ -8,11 +8,12 @@

blog paper - Hugging Face Hugging Face Spaces version python mit +
+ MOSS-TTSD-v0.5 # MOSS-TTSD 🪐 @@ -24,7 +25,7 @@ MOSS-TTSD (text to spoken dialogue) is an open-source bilingual spoken dialogue synthesis model that supports both Chinese and English. It can transform dialogue scripts between two speakers into natural, expressive conversational speech. MOSS-TTSD supports voice cloning and long single-session speech generation, making it ideal for AI podcast production, interviews, and chats. - For detailed information about the model and demos, please refer to our [Blog-en](https://www.open-moss.com/en/moss-ttsd/) and [中文博客](https://www.open-moss.com/cn/moss-ttsd/). You can also find the model on [Hugging Face](https://huggingface.co/fnlp/MOSS-TTSD-v0.5) and try it out in the [Spaces demo](https://huggingface.co/spaces/fnlp/MOSS-TTSD). + For detailed information about the model and demos, please refer to our [Blog-en](https://www.open-moss.com/en/moss-ttsd/) and [中文博客](https://www.open-moss.com/cn/moss-ttsd/). You can also find the model on [Hugging Face](https://huggingface.co/fnlp/MOSS-TTSD-v0.7) and try it out in the [Spaces demo](https://huggingface.co/spaces/fnlp/MOSS-TTSD). ## Highlights @@ -36,7 +37,7 @@ MOSS-TTSD supports voice cloning and long single-session speech generation, maki ## News 🚀 - - **[2025-11-01]** MOSS-TTSD v0.7 is released! v0.7 significantly improves audio quality, voice cloning capability, and stability, adds support for 32 kHz high‑quality output, greatly extends single‑pass generation length (960s→1700s), and more reliably generates speech events following speaker tags. We recommend using the v0.7 model by default. + - **[2025-11-01]** MOSS-TTSD v0.7 is released! v0.7 significantly improves audio quality, voice cloning capability, and stability, adds support for 32 kHz high‑quality output, greatly extends single‑pass generation length (960s→1700s). We recommend using the v0.7 model by default. [MOSS-TTSD v0.7 Model Address](https://huggingface.co/fnlp/MOSS-TTSD-v0.7) - **[2025-09-09]** We supported SGLang inference engine to accelerate model inference by up to **16x**. - **[2025-08-25]** We released the 32khz version of XY-Tokenizer. - **[2025-08-12]** We add support for streaming inference in MOSS-TTSD v0.5. @@ -59,7 +60,7 @@ pip install flash-attn ### Download XY-Tokenizer -You also need to download the XY Tokenizer model weights. You can find the weights in the [XY_Tokenizer repository](https://huggingface.co/fnlp/XY_Tokenizer_TTSD_V0_32k). +You also need to download the XY Tokenizer model weights. You can find the weights in the [XY-Tokenizer-TTSD version repository](https://huggingface.co/fnlp/MOSS_TTSD_tokenizer). ```bash mkdir -p XY_Tokenizer/weights diff --git a/README_zh.md b/README_zh.md index ecf5c79..d01b568 100644 --- a/README_zh.md +++ b/README_zh.md @@ -8,11 +8,12 @@

blog paper - Hugging Face Hugging Face Spaces version python mit +
+ MOSS-TTSD-v0.5 # MOSS-TTSD 🪐 @@ -22,7 +23,7 @@ ## 概述 MOSS-TTSD(text to spoken dialogue)是一个开源的中英双语口语对话合成模型,可以将包含两位说话人的对话脚本转换为自然、富有表现力的对话语音。MOSS-TTSD支持双说话人零样本音色克隆与长时间单段语音生成,非常适合播客,访谈,聊天等对话场景。 -详细模型介绍与演示请见我们的[中文博客](https://www.open-moss.com/cn/moss-ttsd/)和[Blog-en](https://www.open-moss.com/en/moss-ttsd/)。模型权重在 [Hugging Face](https://huggingface.co/fnlp/MOSS-TTSD-v0.5) 提供,并可在 [Spaces 演示](https://huggingface.co/spaces/fnlp/MOSS-TTSD) 在线体验。 +详细模型介绍与演示请见我们的[中文博客](https://www.open-moss.com/cn/moss-ttsd/)和[Blog-en](https://www.open-moss.com/en/moss-ttsd/)。模型权重在 [Hugging Face](https://huggingface.co/fnlp/MOSS-TTSD-v0.7) 提供,并可在 [Spaces 演示](https://huggingface.co/spaces/fnlp/MOSS-TTSD) 在线体验。 ## 亮点 @@ -34,7 +35,7 @@ MOSS-TTSD(text to spoken dialogue)是一个开源的中英双语口语对话 ## 最新动态 🚀 -- **[2025-11-01]** 我们发布了 MOSS-TTSD v0.7:显著提升了音质、声音克隆能力与稳定性,支持32khz高音质输出,并大幅拓展了单次生成长度(960s->1700s),更够比较稳定地根据说话人标签生成语音事件。 +- **[2025-11-01]** 我们发布了 MOSS-TTSD v0.7:显著提升了音质、声音克隆能力与稳定性,支持32khz高音质输出,并大幅拓展了单次生成长度(960s->1700s)。我们推荐默认使用MOSS-TTSD v0.7版本。[MOSS-TTSD v0.7 模型地址](https://huggingface.co/fnlp/MOSS-TTSD-v0.7) - **[2025-09-09]** 我们支持了 SGLang 推理引擎加速模型推理,最高可加速**16倍**。 - **[2025-08-25]** 我们发布了 32khz XY-Tokenizer。 - **[2025-08-12]** 我们支持了 MOSS-TTSD v0.5 的流式推理。 @@ -57,7 +58,7 @@ pip install flash-attn ### 下载 XY-Tokenizer 权重 -首先需要下载 XY-Tokenizer 的Codec模型权重,见 [XY_Tokenizer仓库](https://huggingface.co/fnlp/XY_Tokenizer_TTSD_V0_32k)。 +首先需要下载 XY-Tokenizer 的Codec模型权重,见[XY-Tokenizer-TTSD版本仓库](https://huggingface.co/fnlp/MOSS_TTSD_tokenizer)。 ```bash mkdir -p XY_Tokenizer/weights