From e573ee33d5d3a3c292bbba4c81fd75fd81730946 Mon Sep 17 00:00:00 2001
From: rulerman <1330677461@qq.com>
Date: Fri, 31 Oct 2025 08:43:13 +0000
Subject: [PATCH 1/4] update v0.7
---
README.md | 26 +-
README_zh.md | 25 +-
XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml | 101 +++++
XY_Tokenizer/xy_tokenizer/model.py | 21 +-
generation_utils.py | 411 +++++++++++++------
inference.py | 6 +-
6 files changed, 427 insertions(+), 163 deletions(-)
create mode 100644 XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml
diff --git a/README.md b/README.md
index e026944..fa02cdc 100644
--- a/README.md
+++ b/README.md
@@ -28,20 +28,21 @@ MOSS-TTSD supports voice cloning and long single-session speech generation, maki
## Highlights
-- **Highly Expressive Dialogue Speech**: Built on unified semantic-acoustic neural audio codec, a pre-trained large language model, millions of hours of TTS data, and 400k hours synthetic and real conversational speech, MOSS-TTSD generates highly expressive, human-like dialogue speech with natural conversational prosody.
+- **Highly Expressive Dialogue Speech**: Built on a unified semantic-acoustic neural audio codec, a pre-trained large language model, millions of hours of TTS data, and 600k hours of synthetic and real conversational speech, MOSS-TTSD generates highly expressive, human-like dialogue speech with natural conversational prosody.
- **Two-Speaker Voice Cloning**: MOSS-TTSD supports zero-shot two-speaker voice cloning and can generate conversational speech with accurate speaker switching based on dialogue scripts. Only 10 to 20 seconds of reference audio is needed.
- **Chinese-English Bilingual Support**: MOSS-TTSD enables highly expressive speech generation in both Chinese and English.
-- **Long-Form Speech Generation**: Thanks to low-bitrate codec and training framework optimization, MOSS-TTSD has been trained for long speech generation (Training maximum length is 960s).
+- **Long-Form Speech Generation**: Thanks to low-bitrate codec and training framework optimization, MOSS-TTSD has been trained for long speech generation (Training maximum length is 1700s).
- **Fully Open Source & Commercial-Ready**: MOSS-TTSD and its future updates will be fully open-source and support free commercial use.
## News 🚀
+ - **[2025-10-31]** MOSS-TTSD v0.7 is released! v0.7 significantly improves audio quality, voice cloning capability, and stability, greatly extends single-pass generation length (960s→1700s), and more reliably generates speech events following speaker tags. We recommend using the v0.7 model by default.
- **[2025-09-09]** We supported SGLang inference engine to accelerate model inference by up to **16x**.
- **[2025-08-25]** We released the 32khz version of XY-Tokenizer.
- **[2025-08-12]** We add support for streaming inference in MOSS-TTSD v0.5.
- **[2025-07-29]** We provide the SiliconFlow API interface and usage examples for MOSS-TTSD v0.5.
- **[2025-07-16]** We open-source the fine-tuning code for MOSS-TTSD v0.5, supporting full-parameter fine-tuning, LoRA fine-tuning, and multi-node training.
-- **[2025-07-04]** MOSS-TTSD v0.5 is released! v0.5 has enhanced the accuracy of timbre switching, voice cloning capability, and model stability. We recommend using the v0.5 model by default.
+- **[2025-07-04]** MOSS-TTSD v0.5 is released! v0.5 has enhanced the accuracy of timbre switching, voice cloning capability, and model stability.
- **[2025-06-20]** MOSS-TTSD v0 is released! Moreover, we provide a podcast generation pipeline named Podever, which can automatically convert PDF, URL, or long text files into high-quality podcasts.
## Installation
@@ -62,7 +63,7 @@ You also need to download the XY Tokenizer model weights. You can find the weigh
```bash
mkdir -p XY_Tokenizer/weights
-huggingface-cli download fnlp/XY_Tokenizer_TTSD_V0_32k xy_tokenizer.ckpt --local-dir ./XY_Tokenizer/weights/
+huggingface-cli download fnlp/MOSS_TTSD_tokenizer MOSS_TTSD_tokenizer --local-dir ./XY_Tokenizer/weights/
```
## Usage
@@ -89,16 +90,9 @@ Parameters:
#### JSONL Input Format
-The input JSONL file should contain one JSON object per line. MOSS-TTSD supports multiple input formats:
+The input JSONL file should contain one JSON object per line. MOSS-TTSD supports two input formats:
-**Format 1: Text-only input (No voice cloning, using the model's random timbre)**
-```json
-{
- "text": "[S1]Speaker 1 dialogue content[S2]Speaker 2 dialogue content[S1]..."
-}
-```
-
-**Format 2: Separate speaker audio references**
+**Format 1: Separate speaker audio references**
```json
{
"base_path": "/path/to/audio/files",
@@ -110,7 +104,7 @@ The input JSONL file should contain one JSON object per line. MOSS-TTSD supports
}
```
-**Format 3: Shared audio reference**
+**Format 2: Shared audio reference**
```json
{
"base_path": "/path/to/audio/files",
@@ -126,11 +120,11 @@ The input JSONL file should contain one JSON object per line. MOSS-TTSD supports
- `text`: Dialogue script with speaker tags `[S1]` and `[S2]` indicating speaker turns (required)
- `base_path`: Base directory path for relative file paths (optional)
-**For voice cloning (Format 2):**
+**For voice cloning (Format 1):**
- `prompt_audio_speaker1/2`: Path to reference audio files for voice cloning (relative to `base_path`)
- `prompt_text_speaker1/2`: Reference text corresponding to the audio prompts for better voice matching
-**For shared reference (Format 3):**
+**For shared reference (Format 2):**
- `prompt_audio`: Path to shared reference audio file containing both speakers' voices (relative to `base_path`)
- `prompt_text`: Reference text corresponding to the audio, also using `[S1]` and `[S2]` tags to distinguish speakers
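+
+For reference, a complete Format 1 entry can also be written programmatically. A minimal sketch (file names, paths, and transcripts below are placeholders):
+
+```python
+import json
+
+entry = {
+    "base_path": "/path/to/audio/files",          # optional; audio paths below are relative to it
+    "text": "[S1]Speaker 1 dialogue content[S2]Speaker 2 dialogue content",
+    "prompt_audio_speaker1": "speaker1_ref.wav",  # placeholder 10-20 s reference clip
+    "prompt_text_speaker1": "Transcript of speaker 1's reference clip",
+    "prompt_audio_speaker2": "speaker2_ref.wav",  # placeholder 10-20 s reference clip
+    "prompt_text_speaker2": "Transcript of speaker 2's reference clip",
+}
+
+# inference.py reads one such JSON object per line.
+with open("my_dialogue.jsonl", "w", encoding="utf-8") as f:
+    f.write(json.dumps(entry, ensure_ascii=False) + "\n")
+```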
diff --git a/README_zh.md b/README_zh.md
index e626eab..5b19669 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -26,14 +26,15 @@ MOSS-TTSD(text to spoken dialogue)是一个开源的中英双语口语对话
## 亮点
-- **高表现力对话语音**:基于统一语义-声学神经音频Codec、预训练大语言模型、百万小时TTS数据与约40万小时的真实/合成对话语音数据,MOSS-TTSD能够生成高表现力,高自然度,具有自然对话韵律的拟人对话语音。
+- **高表现力对话语音**:基于统一语义-声学神经音频Codec、预训练大语言模型、百万小时TTS数据与约60万小时的真实/合成对话语音数据,MOSS-TTSD能够生成高表现力,高自然度,具有自然对话韵律的拟人对话语音。
- **双说话人零样本声音克隆**:MOSS-TTSD支持零样本双说话人克隆,按脚本精确进行角色/声线切换。只需要提供10到20秒的参考音频片段。
- **中英双语**:MOSS-TTSD支持中英两种语言的高表现力语音生成。
-- **长音频生成**:得益于低码率Codec与训练框架优化,MOSS-TTSD在长音频生成场景进行了大量训练(训练最大长度达到960s),能够单次生成超长音频。
+- **长音频生成**:得益于低码率Codec与训练框架优化,MOSS-TTSD在长音频生成场景进行了大量训练(训练最大长度达到1700s),能够单次生成超长音频。
- **开源可商用**:当前与后续版本将保持开源,支持免费商用。
## 最新动态 🚀
+- **[2025-10-31]** 我们发布了 MOSS-TTSD v0.7:显著提升了音质、声音克隆能力与稳定性,大幅拓展了单次生成长度(960s->1700s),能够比较稳定地根据说话人标签生成语音事件。
- **[2025-09-09]** 我们支持了 SGLang 推理引擎加速模型推理,最高可加速**16倍**。
- **[2025-08-25]** 我们发布了 32khz XY-Tokenizer。
- **[2025-08-12]** 我们支持了 MOSS-TTSD v0.5 的流式推理。
@@ -60,7 +61,7 @@ pip install flash-attn
```bash
mkdir -p XY_Tokenizer/weights
-huggingface-cli download fnlp/XY_Tokenizer_TTSD_V0_32k xy_tokenizer.ckpt --local-dir ./XY_Tokenizer/weights/
+huggingface-cli download fnlp/MOSS_TTSD_tokenizer MOSS_TTSD_tokenizer --local-dir ./XY_Tokenizer/weights/
```
## 使用方法
@@ -87,17 +88,9 @@ python inference.py --jsonl examples/examples.jsonl --output_dir outputs --seed
#### JSONL 输入格式
-MOSS-TTSD支持多种输入格式:
+MOSS-TTSD 支持两种输入格式:
-**格式1:仅文本(不进行声音克隆,使用模型随机音色)**
-
-```json
-{
- "text": "[S1]说话人1的内容[S2]说话人2的内容[S1]..."
-}
-```
-
-**格式2:分别提供两位说话人的参考音频**
+**格式1:分别提供两位说话人的参考音频**
```json
{
@@ -110,7 +103,7 @@ MOSS-TTSD支持多种输入格式:
}
```
-**格式3:共享参考音频(一个参考音频包含两个说话人的内容)**
+**格式2:共享参考音频(一个参考音频包含两个说话人的内容)**
```json
{
@@ -128,12 +121,12 @@ MOSS-TTSD支持多种输入格式:
- `text`:带 `[S1]`、`[S2]` 说话人标签的对话脚本(必填)
- `base_path`:相对路径的基准目录(可选)
-**用于声音克隆(格式2):**
+**用于声音克隆(格式1):**
- `prompt_audio_speaker1/2`:两位说话人的参考音频(可相对 `base_path`)
- `prompt_text_speaker1/2`:对应参考音频的文本,有助于更好匹配音色
-**用于共享参考(格式3):**
+**用于共享参考(格式2):**
- `prompt_audio`:包含两位说话人的共享参考音频(可相对 `base_path`)
- `prompt_text`:对应的参考文本,亦使用 `[S1]`、`[S2]` 区分
diff --git a/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml b/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml
new file mode 100644
index 0000000..ce539c3
--- /dev/null
+++ b/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml
@@ -0,0 +1,101 @@
+generator_params:
+ input_sample_rate: 16000
+ output_sample_rate: 32000
+
+ # Codec / model architecture (inference required)
+ semantic_encoder_kwargs: # 100hz -> 50hz
+ num_mel_bins: 80
+ sampling_rate: 16000
+ hop_length: 160
+ stride_size: 2
+ kernel_size: 3
+ d_model: 768
+ scale_embedding: false
+ max_audio_seconds: 30
+ encoder_layers: 12
+ encoder_attention_heads: 12
+ encoder_ffn_dim: 3072
+ activation_function: "gelu"
+
+ semantic_encoder_adapter_kwargs: # 50hz
+ input_dim: 768
+ output_dim: 768
+ d_model: 768
+ max_source_positions: 1500
+ encoder_layers: 4
+ encoder_attention_heads: 12
+ encoder_ffn_dim: 3072
+
+ acoustic_encoder_kwargs: # 100hz -> 50hz
+ num_mel_bins: 80
+ sampling_rate: 16000
+ hop_length: 160
+ stride_size: 2
+ kernel_size: 3
+ d_model: 768
+ scale_embedding: false
+ max_audio_seconds: 30
+ encoder_layers: 12
+ encoder_attention_heads: 12
+ encoder_ffn_dim: 3072
+ activation_function: "gelu"
+
+ pre_rvq_adapter_kwargs: # 50hz
+ input_dim: 1536
+ output_dim: 768
+ d_model: 768
+ max_source_positions: 1500
+ encoder_layers: 4
+ encoder_attention_heads: 12
+ encoder_ffn_dim: 3072
+
+ downsample_kwargs: # 50hz -> 12.5hz
+ d_model: 768
+ avg_pooler: 4
+
+ quantizer_kwargs: # 12.5hz
+ input_dim: 3072
+ rvq_dim: 512
+ output_dim: 3072
+ num_quantizers: 8
+ codebook_size: 1024
+ codebook_dim: 512
+ quantizer_dropout: 0.0
+ commitment: 1
+
+ post_rvq_adapter_kwargs: # 12.5hz
+ input_dim: 3072
+ output_dim: 3072
+ d_model: 768
+ max_source_positions: 375
+ encoder_layers: 4
+ encoder_attention_heads: 12
+ encoder_ffn_dim: 3072
+
+ upsample_kwargs: # 12.5hz -> 50hz
+ d_model: 768
+ stride: 4
+
+ acoustic_decoder_kwargs: # 50hz -> 100hz
+ num_mel_bins: 80
+ sampling_rate: 16000
+ hop_length: 160
+ stride_size: 2
+ kernel_size: 3
+ d_model: 768
+ scale_embedding: false
+ max_audio_seconds: 30
+ decoder_layers: 12
+ decoder_attention_heads: 12
+ decoder_ffn_dim: 3072
+ activation_function: "gelu"
+
+ vocos_kwargs: # 100hz -> 32khz
+ input_channels: 80
+ dim: 512
+ intermediate_dim: 4096
+ num_layers: 30
+ n_fft: 1280
+ hop_size: 320
+ padding: "same"
+
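+# Net effect of the settings above: 16 kHz waveform in, 32 kHz waveform out;
+# hop_length 160 x stride_size 2 x avg_pooler 4 = 1280 input samples per code frame,
+# i.e. 16000 / 1280 = 12.5 frames per second, with 8 RVQ codebooks (codebook_size 1024) per frame.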
diff --git a/XY_Tokenizer/xy_tokenizer/model.py b/XY_Tokenizer/xy_tokenizer/model.py
index fa1a672..4c51d5d 100644
--- a/XY_Tokenizer/xy_tokenizer/model.py
+++ b/XY_Tokenizer/xy_tokenizer/model.py
@@ -17,8 +17,8 @@ def __init__(self, generator_params):
self.input_sample_rate = generator_params['input_sample_rate']
self.output_sample_rate = generator_params['output_sample_rate']
- self.encoder_downsample_rate = generator_params['encoder_downsample_rate']
- self.decoder_upsample_rate = generator_params['decoder_upsample_rate']
+ self.encoder_downsample_rate = 1280
+ self.decoder_upsample_rate = int(self.output_sample_rate / 12.5)
self.code_dim = generator_params['quantizer_kwargs']['input_dim']
## Codec part
@@ -49,7 +49,22 @@ def __init__(self, generator_params):
self.enhanced_vocos = Vocos(**generator_params['vocos_kwargs'])
## Feature extractor
- self.feature_extractor = MelFeatureExtractor(**generator_params['feature_extractor_kwargs'])
+ default_feature_extractor_kwargs = {
+ 'chunk_length': 30,
+ 'feature_size': 80,
+ 'hop_length': 160,
+ 'n_fft': 400,
+ 'n_samples': 480000,
+ 'nb_max_frames': 3000,
+ 'padding_side': 'right',
+ 'padding_value': 0.0,
+ 'return_attention_mask': False,
+ 'sampling_rate': self.input_sample_rate,
+ }
+ fe_kwargs = generator_params.get('feature_extractor_kwargs', {})
+ merged_fe_kwargs = {**default_feature_extractor_kwargs, **fe_kwargs}
+ merged_fe_kwargs['sampling_rate'] = self.input_sample_rate
+ self.feature_extractor = MelFeatureExtractor(**merged_fe_kwargs)
@torch.inference_mode()
def inference_tokenize(self, x, input_lengths):
diff --git a/generation_utils.py b/generation_utils.py
index 2d69f5a..b7c2d09 100644
--- a/generation_utils.py
+++ b/generation_utils.py
@@ -8,6 +8,66 @@
MAX_CHANNELS = 8
+def pad_or_truncate_to_seconds(wav: torch.Tensor, target_seconds: float, sr: int) -> torch.Tensor:
+ """Pad or truncate a mono waveform to target length in seconds.
+
+ Args:
+ wav: (1, T) or (T,) tensor
+ target_seconds: target duration in seconds
+ sr: sample rate
+ Returns:
+ (1, T_target) tensor
+ """
+ if wav.dim() == 2 and wav.shape[0] == 1:
+ wav_1d = wav.squeeze(0)
+ else:
+ wav_1d = wav.reshape(-1)
+ target_len = int(round(target_seconds * sr))
+ cur_len = wav_1d.shape[-1]
+ if cur_len == target_len:
+ out = wav_1d
+ elif cur_len > target_len:
+ out = wav_1d[:target_len]
+ else:
+ pad_len = target_len - cur_len
+ out = torch.cat(
+ [wav_1d, torch.zeros(pad_len, dtype=wav_1d.dtype, device=wav_1d.device)], dim=-1
+ )
+ return out.unsqueeze(0)
+
+
+def crossfade_concat(segments: list, sample_rate: int, crossfade_seconds: float = 0.1) -> torch.Tensor:
+ """Concatenate segments with linear crossfade.
+
+ Args:
+ segments: list of (1, T) tensors
+ sample_rate: sampling rate
+ crossfade_seconds: overlap time for crossfade
+ Returns:
+ (1, T_total) tensor
+ """
+ if len(segments) == 0:
+ return torch.zeros(1, 0)
+ if len(segments) == 1:
+ return segments[0]
+ out = segments[0]
+ cf_len_target = int(round(crossfade_seconds * sample_rate))
+ for k in range(1, len(segments)):
+ nxt = segments[k]
+ if cf_len_target <= 0:
+ out = torch.cat([out, nxt], dim=-1)
+ continue
+ cf_len = min(cf_len_target, out.shape[-1], nxt.shape[-1])
+ if cf_len <= 0:
+ out = torch.cat([out, nxt], dim=-1)
+ continue
+ fade_out = torch.linspace(1.0, 0.0, steps=cf_len, dtype=out.dtype, device=out.device)
+ fade_in = torch.linspace(0.0, 1.0, steps=cf_len, dtype=nxt.dtype, device=nxt.device)
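+        # Complementary linear ramps (they sum to 1 at every position), so the
+        # overlapped region stays at a roughly constant level across the seam.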
+ overlap = out[0, -cf_len:] * fade_out + nxt[0, :cf_len] * fade_in
+ out = torch.cat([out[:, :-cf_len], overlap.unsqueeze(0), nxt[:, cf_len:]], dim=-1)
+ return out
+
+
def load_model(
model_path,
spt_config_path,
@@ -34,71 +94,72 @@ def load_model(
def process_jsonl_item(item):
- """Process JSONL data items and extract audio and text information according to the new format"""
- base_path = item.get("base_path", "")
+ """Parse a JSONL item enforcing prompt requirement.
+
+ Only supports Format 1 (separate speaker refs) and Format 2 (shared ref),
+ consistent with the updated README. If `base_path` is missing/empty, any
+ string paths must be absolute. Text-only input is not supported and will raise.
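+
+    Illustrative Format 2 item (hypothetical absolute path):
+        {"text": "[S1]...[S2]...", "prompt_audio": "/abs/ref.wav", "prompt_text": "[S1]ref one[S2]ref two"}
+    resolves to {"text": ..., "prompt_text": ..., "prompt_audio": "/abs/ref.wav"};
+    a relative "prompt_audio" with no base_path raises ValueError.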
+ """
+ base_path = item.get("base_path", "") or ""
text = item.get("text", "")
+ def _resolve_path(p: str) -> str:
+ if not isinstance(p, str) or not p:
+ return p
+ if base_path:
+ return os.path.join(base_path, p)
+ # base_path missing: require absolute path
+ if not os.path.isabs(p):
+ raise ValueError(
+ "When base_path is omitted, audio paths must be absolute. Got: " + p
+ )
+ return p
+
+ # Try Format 2 first: shared audio reference
prompt_audio = None
prompt_text = ""
-
- # Process prompt audio and text
- if "prompt_audio" in item and "prompt_text" in item:
- print("Using prompt_audio and prompt_text directly from item.")
- # If prompt_audio and prompt_text exist, use them directly
- prompt_audio_val = item["prompt_audio"]
- if prompt_audio_val: # Only assign if not empty
+ if "prompt_audio" in item:
+ prompt_audio_val = item.get("prompt_audio")
+ if not prompt_audio_val:
+ raise ValueError("Format 2 requires non-empty 'prompt_audio'.")
+ if isinstance(prompt_audio_val, str):
+ prompt_audio = _resolve_path(prompt_audio_val)
+ else:
+ # allow tuple form for backward-compatibility
prompt_audio = prompt_audio_val
- prompt_text = item["prompt_text"]
-
- # Only perform path joining when prompt_audio is a string path
- if isinstance(prompt_audio, str) and base_path and prompt_audio:
- prompt_audio = os.path.join(base_path, prompt_audio)
- else:
- # Otherwise, merge speaker1 and speaker2 information
- prompt_audio_speaker1 = item.get("prompt_audio_speaker1", "")
- prompt_text_speaker1 = item.get("prompt_text_speaker1", "")
- prompt_audio_speaker2 = item.get("prompt_audio_speaker2", "")
- prompt_text_speaker2 = item.get("prompt_text_speaker2", "")
-
- has_speaker1_audio = (
- isinstance(prompt_audio_speaker1, str) and prompt_audio_speaker1
- ) or isinstance(prompt_audio_speaker1, tuple)
- has_speaker2_audio = (
- isinstance(prompt_audio_speaker2, str) and prompt_audio_speaker2
- ) or isinstance(prompt_audio_speaker2, tuple)
-
- if has_speaker1_audio or has_speaker2_audio:
- print("Using speaker1 and speaker2 information for prompt audio and text.")
- # Process audio: if it's a string path, perform path joining; if it's a tuple, use directly
- if isinstance(prompt_audio_speaker1, str):
- speaker1_audio = (
- os.path.join(base_path, prompt_audio_speaker1)
- if base_path and prompt_audio_speaker1
- else prompt_audio_speaker1
- )
- else:
- speaker1_audio = prompt_audio_speaker1 # Use tuple directly
-
- if isinstance(prompt_audio_speaker2, str):
- speaker2_audio = (
- os.path.join(base_path, prompt_audio_speaker2)
- if base_path and prompt_audio_speaker2
- else prompt_audio_speaker2
- )
- else:
- speaker2_audio = prompt_audio_speaker2 # Use tuple directly
-
- prompt_audio = {"speaker1": speaker1_audio, "speaker2": speaker2_audio}
-
- # Merge text
- temp_prompt_text = ""
- if prompt_text_speaker1:
- temp_prompt_text += f"[S1]{prompt_text_speaker1}"
- if prompt_text_speaker2:
- temp_prompt_text += f"[S2]{prompt_text_speaker2}"
- prompt_text = temp_prompt_text.strip()
-
- return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
+ prompt_text = item.get("prompt_text", "")
+ return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
+
+ # Try Format 1: separate speaker references
+ s1 = item.get("prompt_audio_speaker1", "")
+ s2 = item.get("prompt_audio_speaker2", "")
+ has_s1 = ((isinstance(s1, str) and s1) or isinstance(s1, tuple))
+ has_s2 = ((isinstance(s2, str) and s2) or isinstance(s2, tuple))
+
+ if has_s1 and has_s2:
+ if isinstance(s1, str) and s1:
+ s1_resolved = _resolve_path(s1)
+ else:
+ s1_resolved = s1
+ if isinstance(s2, str) and s2:
+ s2_resolved = _resolve_path(s2)
+ else:
+ s2_resolved = s2
+ # Build merged prompt audio dict
+ prompt_audio = {"speaker1": s1_resolved, "speaker2": s2_resolved}
+ # Merge texts
+ pt1 = item.get("prompt_text_speaker1", "")
+ pt2 = item.get("prompt_text_speaker2", "")
+ merged = ""
+ if pt1:
+ merged += f"[S1]{pt1}"
+ if pt2:
+ merged += f"[S2]{pt2}"
+ prompt_text = merged.strip()
+ return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
+
+ # Otherwise, no supported prompt found → reject (text-only unsupported)
+ raise ValueError("Input must include prompt (Format 1 or 2). Text-only is not supported.")
def load_audio_data(prompt_audio, target_sample_rate=16000):
@@ -288,25 +349,22 @@ def normalize_text(text: str) -> str:
Normalize multi-speaker script.
1. Don't preserve line breaks.
- 2. Remove brackets for non-speaker tags (if [] doesn't contain S1/S2...Sx format, remove the brackets themselves).
- 3. Remove decorative symbols: 【】《》()『』「」"-“” .
- 4. Internal punctuation !;:、 → ,;only allow ? and ,。
+ 2. Preserve bracketed segments like [] () <> even when they are not speaker tags.
+ 3. Remove decorative symbols: 【】《》()『』「」~~-_.
+ 4. Internal punctuation ;:、 → ,;question and exclamation marks are kept.
5. Multiple 。 keep only the last one, others → ,。
6. Replace consecutive "哈" (>=2) with "(笑)".
7. Auto-recognize [S1] / [S2] … tags; if missing, treat as whole segment.
8. Merge adjacent identical speaker tags.
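+
+ Example (illustrative):
+   "[S1]你好!今天天气不错。哦对了;明天见。" -> "[S1]你好!今天天气不错,哦对了,明天见。"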
"""
# Replace [1], [2] etc. format with [S1], [S2] etc. format
- text = re.sub(r"\[(\d+)\]", r"[S\1]", text)
+ text = re.sub(r'\[(\d+)\]', r'[S\1]', text)
# Remove decorative characters
- remove_chars = "【】《》()『』「」" '"-_“”~~'
-
- # Remove brackets for non-speaker tags (keep content, only remove brackets themselves)
- text = re.sub(r"\[(?!S\d+\])([^\]]*)\]", r"\1", text)
+ remove_chars = "【】《》()『』「」~~-_"
# Use positive lookahead to split text by speaker tags (tags themselves are still preserved)
- segments = re.split(r"(?=\[S\d+\])", text.replace("\n", " "))
+ segments = re.split(r'(?=\[S\d+\])', text.replace("\n", " "))
processed_parts = []
for seg in segments:
@@ -315,70 +373,58 @@ def normalize_text(text: str) -> str:
continue
# Extract tags
- m = re.match(r"^(\[S\d+\])\s*(.*)", seg)
- tag, content = m.groups() if m else ("", seg)
+ m = re.match(r'^(\[S\d+\])\s*(.*)', seg)
+ tag, content = m.groups() if m else ('', seg)
# Remove irrelevant symbols
content = re.sub(f"[{re.escape(remove_chars)}]", "", content)
# Handle consecutive "哈" characters: replace 2 or more with "(笑)"
- content = re.sub(r"哈{2,}", "(笑)", content)
+ content = re.sub(r'哈{2,}', '[笑]', content)
# Handle English laughter (e.g., "haha", "ha ha")
- content = re.sub(r"\b(ha(\s*ha)+)\b", "(laughs)", content, flags=re.IGNORECASE)
+ content = re.sub(r'\b(ha(\s*ha)+)\b', '[laugh]', content, flags=re.IGNORECASE)
# First handle multi-character punctuation marks
- content = content.replace("——", ",")
- content = content.replace("……", ",")
+ content = content.replace('——', ',')
+ content = content.replace('……', ',')
# Handle single-character internal punctuation marks
- internal_punct_map = str.maketrans(
- {
- "!": ",",
- "!": ",",
- ";": ",",
- ";": ",",
- ":": ",",
- ":": ",",
- "、": ",",
- "?": ",",
- "?": ",",
- }
- )
+ internal_punct_map = str.maketrans({
+ ';': ',', ';': ',',
+ ':': ',', ':': ',',
+ '、': ','
+ })
content = content.translate(internal_punct_map)
content = content.strip()
# Keep only the final period
if len(content) > 1:
- last_ch = (
- "。"
- if content[-1] == ","
- else ("." if content[-1] == "," else content[-1])
- )
- body = content[:-1].replace("。", ",")
+ last_ch = "。" if content[-1] == "," else ("." if content[-1] == "," else content[-1])
+ body = content[:-1].replace('。', ',')
content = body + last_ch
- processed_parts.append({"tag": tag, "content": content})
+ processed_parts.append({'tag': tag, 'content': content})
if not processed_parts:
return ""
# Merge consecutive same speakers
merged_lines = []
- current_tag = processed_parts[0]["tag"]
- current_content = [processed_parts[0]["content"]]
+ current_tag = processed_parts[0]['tag']
+ current_content = [processed_parts[0]['content']]
for part in processed_parts[1:]:
- if part["tag"] == current_tag and current_tag:
- current_content.append(part["content"])
+ if part['tag'] == current_tag and current_tag:
+ current_content.append(part['content'])
else:
merged_lines.append(f"{current_tag}{''.join(current_content)}".strip())
- current_tag = part["tag"]
- current_content = [part["content"]]
+ current_tag = part['tag']
+ current_content = [part['content']]
merged_lines.append(f"{current_tag}{''.join(current_content)}".strip())
-
- return "".join(merged_lines).replace("‘", "'").replace("’", "'")
+
+ return "".join(merged_lines).replace('‘', "'").replace('’', "'")
def process_batch(
@@ -507,28 +553,143 @@ def process_batch(
f"Speech token shape for sample {start_idx + i}: {this_speech_id.shape}"
)
- # Decode generated audio
- with torch.no_grad():
- codes_list = [
- this_speech_id.permute(1, 0)
- ] # Convert to SPT expected format
- decode_result = spt.decode(codes_list, overlap_seconds=10)
- audio_result = decode_result["syn_wav_list"][0].cpu().detach()
-
- if audio_result.ndim == 1: # If 1D [samples]
- audio_result = audio_result.unsqueeze(
- 0
- ) # Convert to 2D [1, samples]
-
- # Save audio data instead of file path
- audio_results.append(
- {
- "audio_data": audio_result,
- "sample_rate": spt.output_sample_rate,
- "index": start_idx + i,
- }
- )
- print(f"Audio generation completed: sample {start_idx + i}")
+ # Prompt-Augmented Decode (rvq8-style); fall back to original decode if no prompt
+ prompt_audio = prompt_audios[i]
+ if prompt_audio is None:
+ # Fallback to original decode
+ with torch.no_grad():
+ codes_list = [this_speech_id.permute(1, 0)]
+ decode_result = spt.decode(codes_list, overlap_seconds=10)
+ audio_out = decode_result["syn_wav_list"][0].cpu().detach()
+ if audio_out.ndim == 1:
+ audio_out = audio_out.unsqueeze(0)
+ audio_results.append(
+ {
+ "audio_data": audio_out,
+ "sample_rate": spt.output_sample_rate,
+ "index": start_idx + i,
+ }
+ )
+ print(f"Audio generation completed (orig): sample {start_idx + i}")
+ else:
+ # 1) Load prompt at SPT input sr and force to 20s
+ ref_sr_in = (
+ getattr(spt, "input_sample_rate", None)
+ or getattr(spt, "sampling_rate", None)
+ or 24000
+ )
+ ref_wav = load_audio_data(prompt_audio, target_sample_rate=ref_sr_in)
+ if ref_wav is None:
+ # If ref missing, use original decode
+ with torch.no_grad():
+ codes_list = [this_speech_id.permute(1, 0)]
+ decode_result = spt.decode(codes_list, overlap_seconds=10)
+ audio_out = decode_result["syn_wav_list"][0].cpu().detach()
+ if audio_out.ndim == 1:
+ audio_out = audio_out.unsqueeze(0)
+ audio_results.append(
+ {
+ "audio_data": audio_out,
+ "sample_rate": spt.output_sample_rate,
+ "index": start_idx + i,
+ }
+ )
+ print(f"Audio generation completed (orig no-ref): sample {start_idx + i}")
+ else:
+ # Encode 20s reference to tokens
+ ref_wav_20s = pad_or_truncate_to_seconds(ref_wav, 20.0, ref_sr_in).to(device)
+ with torch.no_grad():
+ enc = spt.encode([ref_wav_20s.squeeze(0)])
+ ref_codes = enc["codes_list"][0].to(device).long() # (nq, T_ref)
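+                            # The 20 s of reference tokens is prepended to every decode window
+                            # below; the matching 20 s of synthesized audio is stripped again
+                            # after detokenization, leaving only the window's own output.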
+
+ # Prepare token-to-sample mapping and windowing params
+ out_sr = (
+ getattr(spt, "output_sample_rate", None)
+ or getattr(spt, "sample_rate", None)
+ or 24000
+ )
+ tokens_per_second = float(ref_sr_in) / float(spt.encoder_downsample_rate)
+ tokens_per_chunk = int(round(10.0 * tokens_per_second))
+ stride_tokens = 85
+ keep_tokens = 85
+ left_ctx_tokens = 20
+ total_tokens = this_speech_id.shape[0]
+ samples_per_token = int(round(out_sr / tokens_per_second))
+ crossfade_seconds = 0.1
+ crossfade_samples = int(round(crossfade_seconds * out_sr))
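+                        # With the MOSS_TTSD_tokenizer config (16 kHz in, 32 kHz out,
+                        # 1280-sample encoder hop): 16000 / 1280 = 12.5 tokens/s, so a
+                        # 10 s window is 125 tokens, the 85-token stride/keep region is
+                        # 6.8 s, the 20-token left context is 1.6 s, and one token maps
+                        # to 32000 / 12.5 = 2560 output samples.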
+
+ kept_segments = []
+ chunk_idx = 0
+ while True:
+ st_tok = chunk_idx * stride_tokens
+ if st_tok >= total_tokens:
+ break
+ ed_tok = min(st_tok + tokens_per_chunk, total_tokens)
+ gen_chunk = this_speech_id[st_tok:ed_tok] # (len, C)
+ if gen_chunk.shape[0] == 0:
+ break
+
+ # Concatenate reference tokens with current window tokens
+ combined_codes = torch.cat(
+ [ref_codes, gen_chunk.permute(1, 0).long()], dim=1
+ ).to(device) # (nq, T_ref + T_chunk)
+ codes_lengths = torch.tensor(
+ [combined_codes.shape[-1]], dtype=torch.long, device=device
+ )
+ combined_codes_batched = combined_codes.unsqueeze(1) # (nq, 1, T)
+
+ with torch.no_grad():
+ detok = spt.inference_detokenize(combined_codes_batched, codes_lengths)
+ y = detok["y"][0, 0] # (T_samples)
+
+ # Remove 20s reference portion (in samples)
+ ref_samples = int(round(20.0 * out_sr))
+ if y.shape[-1] <= ref_samples:
+ chunk_idx += 1
+ continue
+ chunk_y = y[ref_samples:]
+
+ # Determine kept region within current window
+ window_len = gen_chunk.shape[0]
+ remains = total_tokens - st_tok
+ is_first = chunk_idx == 0
+ is_last = ed_tok >= total_tokens
+
+ if is_first:
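+                        # Keep only the newly generated part of each window: the first
+                        # window keeps everything up to left_ctx + keep tokens; middle
+                        # windows drop the 20-token left context (already emitted by the
+                        # previous window) and keep the next 85 tokens; a short final
+                        # window (fewer than 105 tokens remaining) keeps through the end.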
+ keep_start_tok = 0
+ keep_end_tok = min(keep_tokens + left_ctx_tokens, window_len)
+ elif is_last and remains < 105:
+ keep_start_tok = 0 if is_first else min(left_ctx_tokens, window_len)
+ keep_end_tok = window_len
+ else:
+ keep_start_tok = min(left_ctx_tokens, window_len)
+ keep_end_tok = min(left_ctx_tokens + keep_tokens, window_len)
+
+ keep_start_smps = keep_start_tok * samples_per_token
+ keep_end_smps = keep_end_tok * samples_per_token
+ left_margin = 0
+ right_margin = crossfade_samples if not is_last else 0
+ seg_start = max(0, keep_start_smps - left_margin)
+ seg_end = min(chunk_y.shape[-1], keep_end_smps + right_margin)
+ if seg_end > seg_start:
+ kept_segments.append(chunk_y[seg_start:seg_end].detach().cpu().unsqueeze(0))
+
+ chunk_idx += 1
+
+ # Concatenate with crossfade; if empty, return tiny silence
+ if len(kept_segments) == 0:
+ audio_out = torch.zeros(1, int(0.01 * out_sr))
+ else:
+ audio_out = crossfade_concat(kept_segments, out_sr, crossfade_seconds=crossfade_seconds)
+
+ audio_results.append(
+ {
+ "audio_data": audio_out,
+ "sample_rate": out_sr,
+ "index": start_idx + i,
+ }
+ )
+ print(f"Audio generation completed (prompt-aug): sample {start_idx + i}")
except Exception as e:
print(f"Error processing sample {start_idx + i}: {str(e)}, skipping...")
diff --git a/inference.py b/inference.py
index 31f5a30..df4b8de 100644
--- a/inference.py
+++ b/inference.py
@@ -7,10 +7,10 @@
from generation_utils import load_model, process_batch
-MODEL_PATH = "fnlp/MOSS-TTSD-v0.5"
+MODEL_PATH = "fnlp/MOSS-TTSD-v0.7"
SYSTEM_PROMPT = "You are a speech synthesizer that generates natural, realistic, and human-like conversational audio from dialogue text."
-SPT_CONFIG_PATH = "XY_Tokenizer/config/xy_tokenizer_32k_config.yaml"
-SPT_CHECKPOINT_PATH = "XY_Tokenizer/weights/xy_tokenizer.ckpt"
+SPT_CONFIG_PATH = "XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml"
+SPT_CHECKPOINT_PATH = "XY_Tokenizer/weights/MOSS_TTSD_tokenizer"
MAX_CHANNELS = 8
def main():
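For a quick sanity check of the new tokenizer paths above, a minimal sketch that mirrors XY_Tokenizer/inference.py (assumed to be run from inside XY_Tokenizer/ after downloading the weights per the README; the 1 s dummy input and the frame counts in the comments are illustrative):

```python
import torch

from xy_tokenizer.model import XY_Tokenizer

# Load the codec from the v0.7 config and checkpoint.
codec = XY_Tokenizer.load_from_checkpoint(
    config_path="./config/MOSS_TTSD_tokenizer.yaml",
    ckpt_path="./weights/MOSS_TTSD_tokenizer",
).eval()

# 16 kHz in, 32 kHz out, 12.5 codec frames per second with 8 RVQ codebooks per frame.
wav_16k = torch.zeros(16000)  # 1 s dummy waveform at the input sample rate
with torch.no_grad():
    codes = codec.encode([wav_16k], overlap_seconds=10)["codes_list"][0]   # (8, ~12)
    audio = codec.decode([codes], overlap_seconds=10)["syn_wav_list"][0]   # ~32000 samples
```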
From b5fc87d400eaddae99bdb4380e8ee394deb52072 Mon Sep 17 00:00:00 2001
From: rulerman <1330677461@qq.com>
Date: Fri, 31 Oct 2025 16:03:18 +0000
Subject: [PATCH 2/4] update v0.7
---
README.md | 4 ++--
README_zh.md | 4 ++--
XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml | 14 ++++++++++++++
XY_Tokenizer/xy_tokenizer/model.py | 20 +++-----------------
generation_utils.py | 2 +-
5 files changed, 22 insertions(+), 22 deletions(-)
diff --git a/README.md b/README.md
index fa02cdc..87edd1d 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,7 @@ MOSS-TTSD supports voice cloning and long single-session speech generation, maki
## Highlights
-- **Highly Expressive Dialogue Speech**: Built on a unified semantic-acoustic neural audio codec, a pre-trained large language model, millions of hours of TTS data, and 600k hours of synthetic and real conversational speech, MOSS-TTSD generates highly expressive, human-like dialogue speech with natural conversational prosody.
+- **Highly Expressive Dialogue Speech**: Built on a unified semantic-acoustic neural audio codec, a pre-trained large language model, and millions of hours of TTS and conversational speech data, MOSS-TTSD generates highly expressive, human-like dialogue speech with natural conversational prosody.
- **Two-Speaker Voice Cloning**: MOSS-TTSD supports zero-shot two-speaker voice cloning and can generate conversational speech with accurate speaker switching based on dialogue scripts. Only 10 to 20 seconds of reference audio is needed.
- **Chinese-English Bilingual Support**: MOSS-TTSD enables highly expressive speech generation in both Chinese and English.
- **Long-Form Speech Generation**: Thanks to low-bitrate codec and training framework optimization, MOSS-TTSD has been trained for long speech generation (Training maximum length is 1700s).
@@ -36,7 +36,7 @@ MOSS-TTSD supports voice cloning and long single-session speech generation, maki
## News 🚀
- - **[2025-10-31]** MOSS-TTSD v0.7 is released! v0.7 significantly improves audio quality, voice cloning capability, and stability, greatly extends single-pass generation length (960s→1700s), and more reliably generates speech events following speaker tags. We recommend using the v0.7 model by default.
+ - **[2025-11-01]** MOSS-TTSD v0.7 is released! v0.7 significantly improves audio quality, voice cloning capability, and stability, adds support for 32 kHz high-quality output, greatly extends single-pass generation length (960s→1700s), and more reliably generates speech events following speaker tags. We recommend using the v0.7 model by default.
- **[2025-09-09]** We supported SGLang inference engine to accelerate model inference by up to **16x**.
- **[2025-08-25]** We released the 32khz version of XY-Tokenizer.
- **[2025-08-12]** We add support for streaming inference in MOSS-TTSD v0.5.
diff --git a/README_zh.md b/README_zh.md
index 5b19669..ecf5c79 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -26,7 +26,7 @@ MOSS-TTSD(text to spoken dialogue)是一个开源的中英双语口语对话
## 亮点
-- **高表现力对话语音**:基于统一语义-声学神经音频Codec、预训练大语言模型、百万小时TTS数据与约60万小时的真实/合成对话语音数据,MOSS-TTSD能够生成高表现力,高自然度,具有自然对话韵律的拟人对话语音。
+- **高表现力对话语音**:基于统一语义-声学神经音频Codec、预训练大语言模型、百万小时TTS数据与对话语音数据,MOSS-TTSD能够生成高表现力,高自然度,具有自然对话韵律的拟人对话语音。
- **双说话人零样本声音克隆**:MOSS-TTSD支持零样本双说话人克隆,按脚本精确进行角色/声线切换。只需要提供10到20秒的参考音频片段。
- **中英双语**:MOSS-TTSD支持中英两种语言的高表现力语音生成。
- **长音频生成**:得益于低码率Codec与训练框架优化,MOSS-TTSD在长音频生成场景进行了大量训练(训练最大长度达到1700s),能够单次生成超长音频。
@@ -34,7 +34,7 @@ MOSS-TTSD(text to spoken dialogue)是一个开源的中英双语口语对话
## 最新动态 🚀
-- **[2025-10-31]** 我们发布了 MOSS-TTSD v0.7:显著提升了音质、声音克隆能力与稳定性,大幅拓展了单次生成长度(960s->1700s),能够比较稳定地根据说话人标签生成语音事件。
+- **[2025-11-01]** 我们发布了 MOSS-TTSD v0.7:显著提升了音质、声音克隆能力与稳定性,支持32khz高音质输出,并大幅拓展了单次生成长度(960s->1700s),能够比较稳定地根据说话人标签生成语音事件。
- **[2025-09-09]** 我们支持了 SGLang 推理引擎加速模型推理,最高可加速**16倍**。
- **[2025-08-25]** 我们发布了 32khz XY-Tokenizer。
- **[2025-08-12]** 我们支持了 MOSS-TTSD v0.5 的流式推理。
diff --git a/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml b/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml
index ce539c3..76b7618 100644
--- a/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml
+++ b/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml
@@ -1,6 +1,20 @@
generator_params:
input_sample_rate: 16000
output_sample_rate: 32000
+ encoder_downsample_rate: 1280
+ decoder_upsample_rate: 2560
+
+ feature_extractor_kwargs:
+ chunk_length: 30
+ feature_size: 80
+ hop_length: 160
+ n_fft: 400
+ n_samples: 480000
+ nb_max_frames: 3000
+ padding_side: right
+ padding_value: 0.0
+ return_attention_mask: false
+ sampling_rate: 16000
# Codec / model architecture (inference required)
semantic_encoder_kwargs: # 100hz -> 50hz
diff --git a/XY_Tokenizer/xy_tokenizer/model.py b/XY_Tokenizer/xy_tokenizer/model.py
index 4c51d5d..971ec93 100644
--- a/XY_Tokenizer/xy_tokenizer/model.py
+++ b/XY_Tokenizer/xy_tokenizer/model.py
@@ -17,8 +17,8 @@ def __init__(self, generator_params):
self.input_sample_rate = generator_params['input_sample_rate']
self.output_sample_rate = generator_params['output_sample_rate']
- self.encoder_downsample_rate = 1280
- self.decoder_upsample_rate = int(self.output_sample_rate / 12.5)
+ self.encoder_downsample_rate = generator_params['encoder_downsample_rate']
+ self.decoder_upsample_rate = generator_params['decoder_upsample_rate']
self.code_dim = generator_params['quantizer_kwargs']['input_dim']
## Codec part
@@ -49,22 +49,8 @@ def __init__(self, generator_params):
self.enhanced_vocos = Vocos(**generator_params['vocos_kwargs'])
## Feature extractor
- default_feature_extractor_kwargs = {
- 'chunk_length': 30,
- 'feature_size': 80,
- 'hop_length': 160,
- 'n_fft': 400,
- 'n_samples': 480000,
- 'nb_max_frames': 3000,
- 'padding_side': 'right',
- 'padding_value': 0.0,
- 'return_attention_mask': False,
- 'sampling_rate': self.input_sample_rate,
- }
fe_kwargs = generator_params.get('feature_extractor_kwargs', {})
- merged_fe_kwargs = {**default_feature_extractor_kwargs, **fe_kwargs}
- merged_fe_kwargs['sampling_rate'] = self.input_sample_rate
- self.feature_extractor = MelFeatureExtractor(**merged_fe_kwargs)
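+        # Feature-extractor settings now come from the config's 'feature_extractor_kwargs'
+        # section (added to MOSS_TTSD_tokenizer.yaml in this patch).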
+ self.feature_extractor = MelFeatureExtractor(**fe_kwargs)
@torch.inference_mode()
def inference_tokenize(self, x, input_lengths):
diff --git a/generation_utils.py b/generation_utils.py
index b7c2d09..2746a70 100644
--- a/generation_utils.py
+++ b/generation_utils.py
@@ -361,7 +361,7 @@ def normalize_text(text: str) -> str:
text = re.sub(r'\[(\d+)\]', r'[S\1]', text)
# Remove decorative characters
- remove_chars = "【】《》()『』「」~~-_"
+ remove_chars = "【】《》()『』「」" '"-_“”~~'
# Use positive lookahead to split text by speaker tags (tags themselves are still preserved)
segments = re.split(r'(?=\[S\d+\])', text.replace("\n", " "))
From 7d567c1b8b862fe50322e0d2311b0d4647e08ada Mon Sep 17 00:00:00 2001
From: xiami2019 <435350193@qq.com>
Date: Mon, 3 Nov 2025 11:44:18 +0800
Subject: [PATCH 3/4] format
---
.gitignore | 2 +-
LICENSE | 2 +-
XY_Tokenizer/.gitignore | 2 +-
XY_Tokenizer/README.md | 2 +-
XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml | 1 -
.../config/xy_tokenizer_32k_config.yaml | 12 +-
XY_Tokenizer/config/xy_tokenizer_config.yaml | 12 +-
XY_Tokenizer/inference.py | 85 ++-
XY_Tokenizer/requirements.txt | 2 +-
XY_Tokenizer/utils/helpers.py | 108 ++--
XY_Tokenizer/xy_tokenizer/model.py | 385 +++++++-----
.../xy_tokenizer/nn/feature_extractor.py | 82 ++-
XY_Tokenizer/xy_tokenizer/nn/modules.py | 577 +++++++++++-------
XY_Tokenizer/xy_tokenizer/nn/quantizer.py | 332 ++++++----
examples/example.txt | 2 +-
examples/examples.jsonl | 2 +-
examples/examples_only_text.jsonl | 2 +-
examples/examples_single_reference.jsonl | 2 +-
finetune/data_preprocess.py | 307 ++++++----
finetune/finetune.py | 269 +++++---
finetune/finetune_config.yaml | 4 +-
finetune/finetune_workflow.py | 104 ++--
finetune/lora_config.yaml | 2 +-
finetune/requirements_finetune.txt | 2 +-
finetune/training_config.yaml | 2 +-
generation_utils.py | 150 +++--
gradio_demo.py | 318 ++++++----
inference.py | 130 ++--
modeling_asteroid.py | 277 ++++++---
podcast_generate.py | 211 ++++---
requirements.txt | 2 +-
streamer.py | 248 +++++---
use_api.py | 222 ++++---
33 files changed, 2460 insertions(+), 1400 deletions(-)
diff --git a/.gitignore b/.gitignore
index b286de5..beed4f7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,2 @@
*.pyc
-outputs/*
\ No newline at end of file
+outputs/*
diff --git a/LICENSE b/LICENSE
index daab1b6..afe8540 100644
--- a/LICENSE
+++ b/LICENSE
@@ -198,4 +198,4 @@
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
- limitations under the License.
\ No newline at end of file
+ limitations under the License.
diff --git a/XY_Tokenizer/.gitignore b/XY_Tokenizer/.gitignore
index 7804221..d60efb5 100644
--- a/XY_Tokenizer/.gitignore
+++ b/XY_Tokenizer/.gitignore
@@ -190,4 +190,4 @@ backup*
output_wavs/
*.pt
*.pth
-output.log
\ No newline at end of file
+output.log
diff --git a/XY_Tokenizer/README.md b/XY_Tokenizer/README.md
index 1aecd1a..2387a53 100644
--- a/XY_Tokenizer/README.md
+++ b/XY_Tokenizer/README.md
@@ -69,4 +69,4 @@ The model processes audio through several stages:
## License
-[Specify your license here]
\ No newline at end of file
+[Specify your license here]
diff --git a/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml b/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml
index 76b7618..b2de9dd 100644
--- a/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml
+++ b/XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml
@@ -112,4 +112,3 @@ generator_params:
n_fft: 1280
hop_size: 320
padding: "same"
-
diff --git a/XY_Tokenizer/config/xy_tokenizer_32k_config.yaml b/XY_Tokenizer/config/xy_tokenizer_32k_config.yaml
index 141cdc6..8a33dcc 100644
--- a/XY_Tokenizer/config/xy_tokenizer_32k_config.yaml
+++ b/XY_Tokenizer/config/xy_tokenizer_32k_config.yaml
@@ -62,7 +62,7 @@ generator_params:
activation_function: "gelu"
- ## semantic & acoustic shared parameters
+ ## semantic & acoustic shared parameters
pre_rvq_adapter_kwargs: # 50hz
input_dim: 1536
output_dim: 768
@@ -71,7 +71,7 @@ generator_params:
encoder_layers: 4
encoder_attention_heads: 12
encoder_ffn_dim: 3072
-
+
downsample_kwargs: # 50hz -> 12.5hz
d_model: 768
avg_pooler: 4
@@ -85,7 +85,7 @@ generator_params:
codebook_dim: 512
quantizer_dropout: 0.0
commitment: 1
-
+
post_rvq_adapter_kwargs: # 12.5hz
input_dim: 3072
output_dim: 3072
@@ -98,7 +98,7 @@ generator_params:
upsample_kwargs: # 12.5hz -> 50hz
d_model: 768
stride: 4
-
+
## acoustic channel
acoustic_decoder_kwargs: # 50hz -> 100hz
num_mel_bins: 80
@@ -113,7 +113,7 @@ generator_params:
decoder_attention_heads: 12
decoder_ffn_dim: 3072
activation_function: "gelu"
-
+
vocos_kwargs: # 100hz -> 32khz
input_channels: 80
dim: 512
@@ -121,4 +121,4 @@ generator_params:
num_layers: 30
n_fft: 1280
hop_size: 320
- padding: "same"
\ No newline at end of file
+ padding: "same"
diff --git a/XY_Tokenizer/config/xy_tokenizer_config.yaml b/XY_Tokenizer/config/xy_tokenizer_config.yaml
index 331c5d7..7232d15 100644
--- a/XY_Tokenizer/config/xy_tokenizer_config.yaml
+++ b/XY_Tokenizer/config/xy_tokenizer_config.yaml
@@ -62,7 +62,7 @@ generator_params:
activation_function: "gelu"
- ## semantic & acoustic shared parameters
+ ## semantic & acoustic shared parameters
pre_rvq_adapter_kwargs: # 50hz
input_dim: 1536
output_dim: 768
@@ -71,7 +71,7 @@ generator_params:
encoder_layers: 4
encoder_attention_heads: 12
encoder_ffn_dim: 3072
-
+
downsample_kwargs: # 50hz -> 12.5hz
d_model: 768
avg_pooler: 4
@@ -85,7 +85,7 @@ generator_params:
codebook_dim: 512
quantizer_dropout: 0.0
commitment: 1
-
+
post_rvq_adapter_kwargs: # 12.5hz
input_dim: 3072
output_dim: 3072
@@ -98,7 +98,7 @@ generator_params:
upsample_kwargs: # 12.5hz -> 50hz
d_model: 768
stride: 4
-
+
## acoustic channel
acoustic_decoder_kwargs: # 50hz -> 100hz
num_mel_bins: 80
@@ -113,7 +113,7 @@ generator_params:
decoder_attention_heads: 12
decoder_ffn_dim: 3072
activation_function: "gelu"
-
+
vocos_kwargs: # 100hz -> 24khz
input_channels: 80
dim: 512
@@ -121,4 +121,4 @@ generator_params:
num_layers: 30
n_fft: 960
hop_size: 240
- padding: "same"
\ No newline at end of file
+ padding: "same"
diff --git a/XY_Tokenizer/inference.py b/XY_Tokenizer/inference.py
index 689893a..d01ea7a 100644
--- a/XY_Tokenizer/inference.py
+++ b/XY_Tokenizer/inference.py
@@ -1,27 +1,40 @@
-import os
import argparse
import logging
-import torch
+import os
-from utils.helpers import set_logging, waiting_for_debug, load_audio, save_audio, find_audio_files
+import torch
+from utils.helpers import (
+ find_audio_files,
+ load_audio,
+ save_audio,
+ set_logging,
+ waiting_for_debug,
+)
from xy_tokenizer.model import XY_Tokenizer
if __name__ == "__main__":
set_logging()
-
+
parser = argparse.ArgumentParser()
- parser.add_argument("--config_path", type=str, default="./config/xy_tokenizer_32k_config.yaml")
- parser.add_argument("--checkpoint_path", type=str, default="./weights/xy_tokenizer.ckpt")
+ parser.add_argument(
+ "--config_path", type=str, default="./config/xy_tokenizer_32k_config.yaml"
+ )
+ parser.add_argument(
+ "--checkpoint_path", type=str, default="./weights/xy_tokenizer.ckpt"
+ )
parser.add_argument("--device", type=str, default="cuda")
-
+
parser.add_argument("--input_dir", type=str, required=True)
parser.add_argument("--output_dir", type=str, required=True)
-
-
+
parser.add_argument("--debug_ip", type=str)
parser.add_argument("--debug_port", type=int)
- parser.add_argument("--debug", default=0, type=int, nargs="?",
- help='whether debug or not',
+ parser.add_argument(
+ "--debug",
+ default=0,
+ type=int,
+ nargs="?",
+ help="whether debug or not",
)
args = parser.parse_args()
if args.debug == 1:
@@ -30,42 +43,66 @@
device = torch.device(args.device)
## Load codec model
- generator = XY_Tokenizer.load_from_checkpoint(config_path=args.config_path, ckpt_path=args.checkpoint_path).to(device).eval()
-
+ generator = (
+ XY_Tokenizer.load_from_checkpoint(
+ config_path=args.config_path, ckpt_path=args.checkpoint_path
+ )
+ .to(device)
+ .eval()
+ )
+
## Find audios
audio_paths = find_audio_files(input_dir=args.input_dir)
-
+
## Create output directory if not exists
os.makedirs(args.output_dir, exist_ok=True)
- logging.info(f"Processing {len(audio_paths)} audio files, output will be saved to {args.output_dir}")
+ logging.info(
+ f"Processing {len(audio_paths)} audio files, output will be saved to {args.output_dir}"
+ )
with torch.no_grad():
## Process audios in batches
batch_size = 8
for i in range(0, len(audio_paths), batch_size):
- batch_paths = audio_paths[i:i + batch_size]
- logging.info(f"Processing batch {i // batch_size + 1}/{len(audio_paths) // batch_size + 1}, files: {batch_paths}")
+ batch_paths = audio_paths[i : i + batch_size]
+ logging.info(
+ f"Processing batch {i // batch_size + 1}/{len(audio_paths) // batch_size + 1}, files: {batch_paths}"
+ )
# Load audio files
- wav_list = [load_audio(path, target_sample_rate=generator.input_sample_rate).squeeze().to(device) for path in batch_paths]
- logging.info(f"Successfully loaded {len(wav_list)} audio files with lengths {[len(wav) for wav in wav_list]} samples")
+ wav_list = [
+ load_audio(path, target_sample_rate=generator.input_sample_rate)
+ .squeeze()
+ .to(device)
+ for path in batch_paths
+ ]
+ logging.info(
+ f"Successfully loaded {len(wav_list)} audio files with lengths {[len(wav) for wav in wav_list]} samples"
+ )
# Encode
encode_result = generator.encode(wav_list, overlap_seconds=10)
codes_list = encode_result["codes_list"] # B * (nq, T)
- logging.info(f"Encoding completed, code lengths: {[codes.shape[-1] for codes in codes_list]}")
+ logging.info(
+ f"Encoding completed, code lengths: {[codes.shape[-1] for codes in codes_list]}"
+ )
logging.info(f"{codes_list = }")
# Decode
decode_result = generator.decode(codes_list, overlap_seconds=10)
syn_wav_list = decode_result["syn_wav_list"] # B * (T,)
- logging.info(f"Decoding completed, generated waveform lengths: {[len(wav) for wav in syn_wav_list]} samples")
+ logging.info(
+ f"Decoding completed, generated waveform lengths: {[len(wav) for wav in syn_wav_list]} samples"
+ )
# Save generated audios
for path, syn_wav in zip(batch_paths, syn_wav_list):
output_path = os.path.join(args.output_dir, os.path.basename(path))
- save_audio(output_path, syn_wav.cpu().reshape(1, -1), sample_rate=generator.output_sample_rate)
+ save_audio(
+ output_path,
+ syn_wav.cpu().reshape(1, -1),
+ sample_rate=generator.output_sample_rate,
+ )
logging.info(f"Saved generated audio to {output_path}")
-
- logging.info("All audio processing completed")
\ No newline at end of file
+ logging.info("All audio processing completed")
diff --git a/XY_Tokenizer/requirements.txt b/XY_Tokenizer/requirements.txt
index fba07a1..c16cb4f 100644
--- a/XY_Tokenizer/requirements.txt
+++ b/XY_Tokenizer/requirements.txt
@@ -20,4 +20,4 @@ stopes
s3prl
onnxscript
jiwer
-orjson
\ No newline at end of file
+orjson
diff --git a/XY_Tokenizer/utils/helpers.py b/XY_Tokenizer/utils/helpers.py
index 9b144a4..e7ef1c8 100644
--- a/XY_Tokenizer/utils/helpers.py
+++ b/XY_Tokenizer/utils/helpers.py
@@ -1,61 +1,72 @@
+import glob
import logging
-import torchaudio
import os
+import re
import sys
-import glob
+
import debugpy
-import torch
import numpy as np
-import re
+import torch
+import torchaudio
+
def count_params_by_module(model_name, model):
logging.info(f"Counting num_parameters of {model_name}:")
-
+
param_stats = {}
total_params = 0 # Count total parameters
total_requires_grad_params = 0 # Count parameters with requires_grad=True
total_no_grad_params = 0 # Count parameters with requires_grad=False
-
+
for name, param in model.named_parameters():
- module_name = name.split('.')[0]
+ module_name = name.split(".")[0]
if module_name not in param_stats:
- param_stats[module_name] = {'total': 0, 'requires_grad': 0, 'no_grad': 0}
-
+ param_stats[module_name] = {"total": 0, "requires_grad": 0, "no_grad": 0}
+
param_num = param.numel()
- param_stats[module_name]['total'] += param_num
+ param_stats[module_name]["total"] += param_num
total_params += param_num
-
+
if param.requires_grad:
- param_stats[module_name]['requires_grad'] += param_num
+ param_stats[module_name]["requires_grad"] += param_num
total_requires_grad_params += param_num
else:
- param_stats[module_name]['no_grad'] += param_num
+ param_stats[module_name]["no_grad"] += param_num
total_no_grad_params += param_num
-
+
# Calculate maximum width for each column
max_module_name_length = max(len(module) for module in param_stats)
- max_param_length = max(len(f"{stats['total'] / 1e6:.2f}M") for stats in param_stats.values())
-
+ max_param_length = max(
+ len(f"{stats['total'] / 1e6:.2f}M") for stats in param_stats.values()
+ )
+
# Output parameter statistics for each module
for module, stats in param_stats.items():
- logging.info(f"\t{module:<{max_module_name_length}}: "
- f"Total: {stats['total'] / 1e6:<{max_param_length}.2f}M, "
- f"Requires Grad: {stats['requires_grad'] / 1e6:<{max_param_length}.2f}M, "
- f"No Grad: {stats['no_grad'] / 1e6:<{max_param_length}.2f}M")
-
+ logging.info(
+ f"\t{module:<{max_module_name_length}}: "
+ f"Total: {stats['total'] / 1e6:<{max_param_length}.2f}M, "
+ f"Requires Grad: {stats['requires_grad'] / 1e6:<{max_param_length}.2f}M, "
+ f"No Grad: {stats['no_grad'] / 1e6:<{max_param_length}.2f}M"
+ )
+
# Output total parameter statistics
logging.info(f"\tTotal parameters: {total_params / 1e6:.2f}M parameters")
- logging.info(f"\tRequires Grad parameters: {total_requires_grad_params / 1e6:.2f}M parameters")
+ logging.info(
+ f"\tRequires Grad parameters: {total_requires_grad_params / 1e6:.2f}M parameters"
+ )
logging.info(f"\tNo Grad parameters: {total_no_grad_params / 1e6:.2f}M parameters")
logging.info(f"################################################################")
def load_and_resample_audio(audio_path, target_sample_rate):
- wav, raw_sample_rate = torchaudio.load(audio_path) # (1, T) tensor
- if raw_sample_rate != target_sample_rate:
- wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate) # tensor
+ wav, raw_sample_rate = torchaudio.load(audio_path) # (1, T) tensor
+ if raw_sample_rate != target_sample_rate:
+ wav = torchaudio.functional.resample(
+ wav, raw_sample_rate, target_sample_rate
+ ) # tensor
return wav.squeeze()
+
def set_logging():
rank = os.environ.get("RANK", 0)
logging.basicConfig(
@@ -63,55 +74,64 @@ def set_logging():
stream=sys.stdout,
format=f"%(asctime)s [RANK {rank}] (%(module)s:%(lineno)d) %(levelname)s : %(message)s",
)
-
+
+
def waiting_for_debug(ip, port):
rank = os.environ.get("RANK", "0")
- debugpy.listen((ip, port)) # Replace localhost with cluster node IP
+ debugpy.listen((ip, port)) # Replace localhost with cluster node IP
logging.info(f"[rank = {rank}] Waiting for debugger attach...")
debugpy.wait_for_client()
logging.info(f"[rank = {rank}] Debugger attached")
-
+
+
def load_audio(audio_path, target_sample_rate):
# Load audio file, wav shape: (channels, time)
wav, raw_sample_rate = torchaudio.load(audio_path)
-
+
# If multi-channel, convert to mono by averaging across channels
if wav.shape[0] > 1:
- wav = torch.mean(wav, dim=0, keepdim=True) # Average across channels, keep channel dim
-
+ wav = torch.mean(
+ wav, dim=0, keepdim=True
+ ) # Average across channels, keep channel dim
+
# Resample if necessary
if raw_sample_rate != target_sample_rate:
wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate)
-
+
# Convert to numpy, add channel dimension, then back to tensor with desired shape
wav = np.expand_dims(wav.squeeze(0).numpy(), axis=1) # Shape: (time, 1)
wav = torch.tensor(wav).reshape(1, 1, -1) # Shape: (1, 1, time)
-
+
return wav
+
def save_audio(audio_outpath, audio_out, sample_rate):
torchaudio.save(
- audio_outpath,
- audio_out,
- sample_rate=sample_rate,
- encoding='PCM_S',
- bits_per_sample=16
+ audio_outpath,
+ audio_out,
+ sample_rate=sample_rate,
+ encoding="PCM_S",
+ bits_per_sample=16,
)
logging.info(f"Successfully saved audio at {audio_outpath}")
-
+
+
def find_audio_files(input_dir):
- audio_extensions = ['*.flac', '*.mp3', '*.wav']
+ audio_extensions = ["*.flac", "*.mp3", "*.wav"]
audios_input = []
for ext in audio_extensions:
- audios_input.extend(glob.glob(os.path.join(input_dir, '**', ext), recursive=True))
+ audios_input.extend(
+ glob.glob(os.path.join(input_dir, "**", ext), recursive=True)
+ )
logging.info(f"Found {len(audios_input)} audio files in {input_dir}")
return sorted(audios_input)
+
def normalize_text(text):
# Remove all punctuation (including English and Chinese punctuation)
- text = re.sub(r'[^\w\s\u4e00-\u9fff]', '', text)
+ text = re.sub(r"[^\w\s\u4e00-\u9fff]", "", text)
# Convert to lowercase (effective for English, no effect on Chinese)
text = text.lower()
# Remove extra spaces
- text = ' '.join(text.split())
- return text
\ No newline at end of file
+ text = " ".join(text.split())
+ return text
diff --git a/XY_Tokenizer/xy_tokenizer/model.py b/XY_Tokenizer/xy_tokenizer/model.py
index 971ec93..5b73d0a 100644
--- a/XY_Tokenizer/xy_tokenizer/model.py
+++ b/XY_Tokenizer/xy_tokenizer/model.py
@@ -1,148 +1,198 @@
# -*- coding: utf-8 -*-
-import yaml
import logging
+
import torch
import torch.nn as nn
import torch.nn.functional as F
-
+import yaml
from .nn.feature_extractor import MelFeatureExtractor
-from .nn.modules import OmniAudioEncoder, OmniAudioDecoder, ResidualDownConv, UpConv, Transformer, Vocos
+from .nn.modules import (
+ OmniAudioDecoder,
+ OmniAudioEncoder,
+ ResidualDownConv,
+ Transformer,
+ UpConv,
+ Vocos,
+)
from .nn.quantizer import ResidualVQ
+
class XY_Tokenizer(nn.Module):
def __init__(self, generator_params):
super().__init__()
# Basic parameters
- self.input_sample_rate = generator_params['input_sample_rate']
- self.output_sample_rate = generator_params['output_sample_rate']
-
- self.encoder_downsample_rate = generator_params['encoder_downsample_rate']
- self.decoder_upsample_rate = generator_params['decoder_upsample_rate']
- self.code_dim = generator_params['quantizer_kwargs']['input_dim']
-
+ self.input_sample_rate = generator_params["input_sample_rate"]
+ self.output_sample_rate = generator_params["output_sample_rate"]
+
+ self.encoder_downsample_rate = generator_params["encoder_downsample_rate"]
+ self.decoder_upsample_rate = generator_params["decoder_upsample_rate"]
+ self.code_dim = generator_params["quantizer_kwargs"]["input_dim"]
+
## Codec part
## Semantic channel
- self.semantic_encoder = OmniAudioEncoder(**generator_params['semantic_encoder_kwargs'])
-
- self.semantic_encoder_adapter = Transformer(**generator_params['semantic_encoder_adapter_kwargs'])
-
+ self.semantic_encoder = OmniAudioEncoder(
+ **generator_params["semantic_encoder_kwargs"]
+ )
+
+ self.semantic_encoder_adapter = Transformer(
+ **generator_params["semantic_encoder_adapter_kwargs"]
+ )
+
## Acoustic channel
- self.acoustic_encoder = OmniAudioEncoder(**generator_params['acoustic_encoder_kwargs'])
-
+ self.acoustic_encoder = OmniAudioEncoder(
+ **generator_params["acoustic_encoder_kwargs"]
+ )
+
## Semantic & acoustic shared parameters
- self.pre_rvq_adapter = Transformer(**generator_params['pre_rvq_adapter_kwargs'])
-
- self.downsample = ResidualDownConv(**generator_params['downsample_kwargs'])
-
- self.quantizer = ResidualVQ(**generator_params['quantizer_kwargs'])
- self.nq = generator_params['quantizer_kwargs']['num_quantizers']
-
- self.post_rvq_adapter = Transformer(**generator_params['post_rvq_adapter_kwargs'])
-
+ self.pre_rvq_adapter = Transformer(**generator_params["pre_rvq_adapter_kwargs"])
+
+ self.downsample = ResidualDownConv(**generator_params["downsample_kwargs"])
+
+ self.quantizer = ResidualVQ(**generator_params["quantizer_kwargs"])
+ self.nq = generator_params["quantizer_kwargs"]["num_quantizers"]
+
+ self.post_rvq_adapter = Transformer(
+ **generator_params["post_rvq_adapter_kwargs"]
+ )
+
## Acoustic channel
- self.upsample = UpConv(**generator_params['upsample_kwargs'])
+ self.upsample = UpConv(**generator_params["upsample_kwargs"])
- self.acoustic_decoder = OmniAudioDecoder(**generator_params['acoustic_decoder_kwargs'])
+ self.acoustic_decoder = OmniAudioDecoder(
+ **generator_params["acoustic_decoder_kwargs"]
+ )
- self.enhanced_vocos = Vocos(**generator_params['vocos_kwargs'])
+ self.enhanced_vocos = Vocos(**generator_params["vocos_kwargs"])
## Feature extractor
- fe_kwargs = generator_params.get('feature_extractor_kwargs', {})
+ fe_kwargs = generator_params.get("feature_extractor_kwargs", {})
self.feature_extractor = MelFeatureExtractor(**fe_kwargs)
@torch.inference_mode()
def inference_tokenize(self, x, input_lengths):
"""
- Input:
- x: Waveform tensor # (B, 1, T), T <= 30s * sample_rate
- input_lengths: Valid length for each sample # (B,)
- Output:
- dict: Contains the following key-value pairs
- "zq": Quantized embeddings # (B, D, T)
- "codes": Quantization codes # (nq, B, T)
- "codes_lengths": Quantization code lengths # (B,)
+ Input:
+ x: Waveform tensor # (B, 1, T), T <= 30s * sample_rate
+ input_lengths: Valid length for each sample # (B,)
+ Output:
+ dict: Contains the following key-value pairs
+ "zq": Quantized embeddings # (B, D, T)
+ "codes": Quantization codes # (nq, B, T)
+ "codes_lengths": Quantization code lengths # (B,)
"""
- list_x = [xi[:, :x_len].reshape(-1).cpu().numpy() for xi, x_len in zip(x, input_lengths)]
+ list_x = [
+ xi[:, :x_len].reshape(-1).cpu().numpy()
+ for xi, x_len in zip(x, input_lengths)
+ ]
features = self.feature_extractor(
list_x,
sampling_rate=self.input_sample_rate,
return_tensors="pt",
- return_attention_mask=True
+ return_attention_mask=True,
)
- input_mel = features['input_features'].to(x.device).to(x.dtype) # (B, D, 3000)
- audio_attention_mask = features['attention_mask'].to(x.device) # (B, 3000)
-
+ input_mel = features["input_features"].to(x.device).to(x.dtype) # (B, D, 3000)
+ audio_attention_mask = features["attention_mask"].to(x.device) # (B, 3000)
+
# Get batch size and sequence length of the input
- mel_output_length = torch.sum(audio_attention_mask, dim=-1).long() # (B,)
-
+ mel_output_length = torch.sum(audio_attention_mask, dim=-1).long() # (B,)
+
# Semantic channel
- semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(input_mel, mel_output_length) # (B, D, T), 100hz -> 50hz
-
- semantic_encoder_adapter_output, semantic_encoder_adapter_output_length = self.semantic_encoder_adapter(semantic_encoder_output, semantic_encoder_output_length) # (B, D, T), 50hz
-
+ semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(
+ input_mel, mel_output_length
+ ) # (B, D, T), 100hz -> 50hz
+
+ semantic_encoder_adapter_output, semantic_encoder_adapter_output_length = (
+ self.semantic_encoder_adapter(
+ semantic_encoder_output, semantic_encoder_output_length
+ )
+ ) # (B, D, T), 50hz
+
# Acoustic channel
- acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder(input_mel, mel_output_length) # (B, D, T), 100hz -> 50hz
-
+ acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder(
+ input_mel, mel_output_length
+ ) # (B, D, T), 100hz -> 50hz
+
# Semantic & acoustic mixing
- concated_semantic_acoustic_channel = torch.concat([semantic_encoder_adapter_output, acoustic_encoder_output], dim=1) # (B, D, T)
+ concated_semantic_acoustic_channel = torch.concat(
+ [semantic_encoder_adapter_output, acoustic_encoder_output], dim=1
+ ) # (B, D, T)
concated_semantic_acoustic_channel_length = acoustic_encoder_output_length
-
- pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(concated_semantic_acoustic_channel, concated_semantic_acoustic_channel_length) # (B, D, T), 50hz
-
- downsample_output, downsample_output_length = self.downsample(pre_rvq_adapter_output, pre_rvq_adapter_output_length) # (B, D, T), 50hz -> 12.5hz
- zq, codes, vq_loss, _, quantizer_output_length = self.quantizer(downsample_output, downsample_output_length) # (B, D, T), (nq, B, T), (nq,), (nq, B, D, T), (B,)
+ pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(
+ concated_semantic_acoustic_channel,
+ concated_semantic_acoustic_channel_length,
+ ) # (B, D, T), 50hz
+
+ downsample_output, downsample_output_length = self.downsample(
+ pre_rvq_adapter_output, pre_rvq_adapter_output_length
+ ) # (B, D, T), 50hz -> 12.5hz
+
+ zq, codes, vq_loss, _, quantizer_output_length = self.quantizer(
+ downsample_output, downsample_output_length
+ ) # (B, D, T), (nq, B, T), (nq,), (nq, B, D, T), (B,)
return {
- "zq": zq, # (B, D, T)
- "codes": codes, # (nq, B, T)
- "codes_lengths": quantizer_output_length # (B,)
+ "zq": zq, # (B, D, T)
+ "codes": codes, # (nq, B, T)
+ "codes_lengths": quantizer_output_length, # (B,)
}
-
- @torch.inference_mode()
+
+ @torch.inference_mode()
def inference_detokenize(self, codes, codes_lengths):
"""
- Input:
- codes: Quantization codes # (nq, B, T)
- codes_lengths: Quantization code lengths for each sample # (B,)
- Output:
- dict: Contains the following key-value pairs
- "y": Synthesized audio waveform # (B, 1, T)
- "output_length": Output lengths # (B,)
+ Input:
+ codes: Quantization codes # (nq, B, T)
+ codes_lengths: Quantization code lengths for each sample # (B,)
+ Output:
+ dict: Contains the following key-value pairs
+ "y": Synthesized audio waveform # (B, 1, T)
+ "output_length": Output lengths # (B,)
"""
- zq = self.quantizer.decode_codes(codes) # (B, D, T)
-
- post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter(zq, codes_lengths) # (B, D, T), 12.5hz
-
- # Acoustic channel
- upsample_output, upsample_output_length = self.upsample(post_rvq_adapter_output, post_rvq_adapter_output_length) # (B, D, T), 12.5hz -> 50hz
+ zq = self.quantizer.decode_codes(codes) # (B, D, T)
+
+ post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter(
+ zq, codes_lengths
+ ) # (B, D, T), 12.5hz
- acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder(upsample_output, upsample_output_length) # (B, D, T), 50hz -> 100hz
+ # Acoustic channel
+ upsample_output, upsample_output_length = self.upsample(
+ post_rvq_adapter_output, post_rvq_adapter_output_length
+ ) # (B, D, T), 12.5hz -> 50hz
+
+ acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder(
+ upsample_output, upsample_output_length
+ ) # (B, D, T), 50hz -> 100hz
+
+ y, vocos_output_length = self.enhanced_vocos(
+ acoustic_decoder_output, acoustic_decoder_output_length
+ ) # (B, 1, T), 100hz -> 16khz
- y, vocos_output_length = self.enhanced_vocos(acoustic_decoder_output, acoustic_decoder_output_length) # (B, 1, T), 100hz -> 16khz
-
return {
- "y": y, # (B, 1, T)
- "output_length": vocos_output_length, # (B,)
+ "y": y, # (B, 1, T)
+ "output_length": vocos_output_length, # (B,)
}
-
+
@torch.inference_mode()
def encode(self, wav_list, overlap_seconds=10):
"""
- Input:
- wav_list: List of audio waveforms, each with potentially different length, may exceed 30 seconds # B * (T,)
- overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
- Output:
- dict: Contains the following key-value pairs
- "codes_list": List of quantization codes # B * (nq, T)
+ Input:
+ wav_list: List of audio waveforms, each with potentially different length, may exceed 30 seconds # B * (T,)
+ overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
+ Output:
+ dict: Contains the following key-value pairs
+ "codes_list": List of quantization codes # B * (nq, T)
"""
device = wav_list[0].device
duration_seconds = 30 - overlap_seconds
- chunk_size = int(30 * self.input_sample_rate) # Maximum samples per chunk
- duration_size = int(duration_seconds * self.input_sample_rate) # Valid output samples per chunk
- code_duration_length = duration_size // self.encoder_downsample_rate # Valid code length per chunk
+ chunk_size = int(30 * self.input_sample_rate) # Maximum samples per chunk
+ duration_size = int(
+ duration_seconds * self.input_sample_rate
+ ) # Valid output samples per chunk
+ code_duration_length = (
+ duration_size // self.encoder_downsample_rate
+ ) # Valid code length per chunk
# Get maximum waveform length
max_length = max(len(wav) for wav in wav_list)
@@ -150,8 +200,8 @@ def encode(self, wav_list, overlap_seconds=10):
wav_tensor = torch.zeros(batch_size, 1, max_length, device=device)
input_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
for i, wav in enumerate(wav_list):
- wav_tensor[i, 0, :len(wav)] = wav
- input_lengths[i] = len(wav) # (B,)
+ wav_tensor[i, 0, : len(wav)] = wav
+ input_lengths[i] = len(wav) # (B,)
# Calculate number of chunks needed
max_chunks = (max_length + duration_size - 1) // duration_size
@@ -161,122 +211,161 @@ def encode(self, wav_list, overlap_seconds=10):
for chunk_idx in range(max_chunks):
start = chunk_idx * duration_size
end = min(start + chunk_size, max_length)
- chunk = wav_tensor[:, :, start:end] # (B, 1, T')
- chunk_lengths = torch.clamp(input_lengths - start, 0, end - start) # (B,)
+ chunk = wav_tensor[:, :, start:end] # (B, 1, T')
+ chunk_lengths = torch.clamp(input_lengths - start, 0, end - start) # (B,)
# Skip empty chunks
if chunk_lengths.max() == 0:
continue
# Encode
- result = self.inference_tokenize(chunk, chunk_lengths) # {"zq": (B, D, T'), "codes": (nq, B, T'), "codes_lengths": (B,)}
- chunk_codes = result["codes"] # (nq, B, T')
- chunk_code_lengths = result["codes_lengths"] # (B,)
+ result = self.inference_tokenize(
+ chunk, chunk_lengths
+ ) # {"zq": (B, D, T'), "codes": (nq, B, T'), "codes_lengths": (B,)}
+ chunk_codes = result["codes"] # (nq, B, T')
+ chunk_code_lengths = result["codes_lengths"] # (B,)
# Extract valid portion
- valid_code_lengths = torch.clamp(chunk_code_lengths, 0, code_duration_length) # (B,)
- valid_chunk_codes = torch.zeros(self.nq, batch_size, code_duration_length, device=device, dtype=chunk_codes.dtype)
+ valid_code_lengths = torch.clamp(
+ chunk_code_lengths, 0, code_duration_length
+ ) # (B,)
+ valid_chunk_codes = torch.zeros(
+ self.nq,
+ batch_size,
+ code_duration_length,
+ device=device,
+ dtype=chunk_codes.dtype,
+ )
for b in range(batch_size):
if valid_code_lengths[b] > 0:
- valid_chunk_codes[:, b, :valid_code_lengths[b]] = chunk_codes[:, b, :valid_code_lengths[b]] # (nq, B, valid_code_length)
+ valid_chunk_codes[:, b, : valid_code_lengths[b]] = chunk_codes[
+ :, b, : valid_code_lengths[b]
+ ] # (nq, B, valid_code_length)
- codes_list.append(valid_chunk_codes) # (nq, B, valid_code_length)
+ codes_list.append(valid_chunk_codes) # (nq, B, valid_code_length)
# Concatenate all chunks
if codes_list:
- codes_tensor = torch.cat(codes_list, dim=-1) # (nq, B, T_total)
- codes_list = [codes_tensor[:, i, :input_lengths[i] // self.encoder_downsample_rate] for i in range(batch_size)] # B * (nq, T)
+ codes_tensor = torch.cat(codes_list, dim=-1) # (nq, B, T_total)
+ codes_list = [
+ codes_tensor[:, i, : input_lengths[i] // self.encoder_downsample_rate]
+ for i in range(batch_size)
+ ] # B * (nq, T)
else:
- codes_list = [torch.zeros(self.nq, 0, device=device, dtype=torch.long) for _ in range(batch_size)] # B * (nq, 0)
+ codes_list = [
+ torch.zeros(self.nq, 0, device=device, dtype=torch.long)
+ for _ in range(batch_size)
+ ] # B * (nq, 0)
+
+ return {"codes_list": codes_list} # B * (nq, T)
- return {
- "codes_list": codes_list # B * (nq, T)
- }
-
@torch.inference_mode()
def decode(self, codes_list, overlap_seconds=10):
"""
- Input:
- codes_list: List of quantization codes # B * (nq, T)
- overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
- Output:
- dict: Contains the following key-value pairs
- "syn_wav_list": List of synthesized audio waveforms # B * (T,)
+ Input:
+ codes_list: List of quantization codes # B * (nq, T)
+ overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
+ Output:
+ dict: Contains the following key-value pairs
+ "syn_wav_list": List of synthesized audio waveforms # B * (T,)
"""
device = codes_list[0].device
duration_seconds = 30 - overlap_seconds
- chunk_code_length = int(30 * self.input_sample_rate // self.encoder_downsample_rate) # Maximum code length per chunk
- duration_code_length = int(duration_seconds * self.input_sample_rate // self.encoder_downsample_rate) # Valid code length per chunk
- duration_wav_length = duration_code_length * self.decoder_upsample_rate # Valid waveform length per chunk
+ chunk_code_length = int(
+ 30 * self.input_sample_rate // self.encoder_downsample_rate
+ ) # Maximum code length per chunk
+ duration_code_length = int(
+ duration_seconds * self.input_sample_rate // self.encoder_downsample_rate
+ ) # Valid code length per chunk
+ duration_wav_length = (
+ duration_code_length * self.decoder_upsample_rate
+ ) # Valid waveform length per chunk
# Get maximum code length
max_code_length = max(codes.shape[-1] for codes in codes_list)
batch_size = len(codes_list)
- codes_tensor = torch.zeros(self.nq, batch_size, max_code_length, device=device, dtype=torch.long)
+ codes_tensor = torch.zeros(
+ self.nq, batch_size, max_code_length, device=device, dtype=torch.long
+ )
code_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
for i, codes in enumerate(codes_list):
- codes_tensor[:, i, :codes.shape[-1]] = codes.to(device)
- code_lengths[i] = codes.shape[-1] # (B,)
+ codes_tensor[:, i, : codes.shape[-1]] = codes.to(device)
+ code_lengths[i] = codes.shape[-1] # (B,)
# Calculate number of chunks needed
- max_chunks = (max_code_length + duration_code_length - 1) // duration_code_length
+ max_chunks = (
+ max_code_length + duration_code_length - 1
+ ) // duration_code_length
wav_list = []
# Process the entire batch in chunks
for chunk_idx in range(max_chunks):
start = chunk_idx * duration_code_length
end = min(start + chunk_code_length, max_code_length)
- chunk_codes = codes_tensor[:, :, start:end] # (nq, B, T')
- chunk_code_lengths = torch.clamp(code_lengths - start, 0, end - start) # (B,)
+ chunk_codes = codes_tensor[:, :, start:end] # (nq, B, T')
+ chunk_code_lengths = torch.clamp(
+ code_lengths - start, 0, end - start
+ ) # (B,)
# Skip empty chunks
if chunk_code_lengths.max() == 0:
continue
# Decode
- result = self.inference_detokenize(chunk_codes, chunk_code_lengths) # {"y": (B, 1, T'), "output_length": (B,)}
- chunk_wav = result["y"] # (B, 1, T')
- chunk_wav_lengths = result["output_length"] # (B,)
+ result = self.inference_detokenize(
+ chunk_codes, chunk_code_lengths
+ ) # {"y": (B, 1, T'), "output_length": (B,)}
+ chunk_wav = result["y"] # (B, 1, T')
+ chunk_wav_lengths = result["output_length"] # (B,)
# Extract valid portion
- valid_wav_lengths = torch.clamp(chunk_wav_lengths, 0, duration_wav_length) # (B,)
- valid_chunk_wav = torch.zeros(batch_size, 1, duration_wav_length, device=device)
+ valid_wav_lengths = torch.clamp(
+ chunk_wav_lengths, 0, duration_wav_length
+ ) # (B,)
+ valid_chunk_wav = torch.zeros(
+ batch_size, 1, duration_wav_length, device=device
+ )
for b in range(batch_size):
if valid_wav_lengths[b] > 0:
- valid_chunk_wav[b, :, :valid_wav_lengths[b]] = chunk_wav[b, :, :valid_wav_lengths[b]] # (B, 1, valid_wav_length)
+ valid_chunk_wav[b, :, : valid_wav_lengths[b]] = chunk_wav[
+ b, :, : valid_wav_lengths[b]
+ ] # (B, 1, valid_wav_length)
- wav_list.append(valid_chunk_wav) # (B, 1, valid_wav_length)
+ wav_list.append(valid_chunk_wav) # (B, 1, valid_wav_length)
# Concatenate all chunks
if wav_list:
- wav_tensor = torch.cat(wav_list, dim=-1) # (B, 1, T_total)
- syn_wav_list = [wav_tensor[i, 0, :code_lengths[i] * self.decoder_upsample_rate] for i in range(batch_size)] # B * (T,)
+ wav_tensor = torch.cat(wav_list, dim=-1) # (B, 1, T_total)
+ syn_wav_list = [
+ wav_tensor[i, 0, : code_lengths[i] * self.decoder_upsample_rate]
+ for i in range(batch_size)
+ ] # B * (T,)
else:
- syn_wav_list = [torch.zeros(0, device=device) for _ in range(batch_size)] # B * (0,)
-
- return {
- "syn_wav_list": syn_wav_list # B * (T,)
- }
-
+ syn_wav_list = [
+ torch.zeros(0, device=device) for _ in range(batch_size)
+ ] # B * (0,)
+
+ return {"syn_wav_list": syn_wav_list} # B * (T,)
+
@classmethod
def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
# Load model from configuration file and checkpoint
logging.info(f"Loading model from {config_path} and {ckpt_path}")
-
+
# Load configuration
- with open(config_path, 'r') as f:
+ with open(config_path, "r") as f:
config = yaml.safe_load(f)
-
+
# Create model instance
- model = cls(config['generator_params'])
-
+ model = cls(config["generator_params"])
+
# Load checkpoint
- checkpoint = torch.load(ckpt_path, map_location='cpu')
-
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
+
# Check if checkpoint contains 'generator' key
- if 'generator' in checkpoint:
- model.load_state_dict(checkpoint['generator'])
+ if "generator" in checkpoint:
+ model.load_state_dict(checkpoint["generator"])
else:
model.load_state_dict(checkpoint)
-
- return model
\ No newline at end of file
+
+ return model
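A minimal round-trip sketch of the chunked codec API above (load_from_checkpoint, encode, decode). Only the method signatures and return keys come from the code in this patch; the class name, import path, and checkpoint filename are assumptions for illustration, and the audio is dummy data.

import torch
from xy_tokenizer.model import XY_Tokenizer  # assumed import path and class name

# The config file is added by this patch; the checkpoint filename is a placeholder.
codec = XY_Tokenizer.load_from_checkpoint(
    "XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml",
    "xy_tokenizer.ckpt",
).eval()

# encode() takes a list of 1-D waveforms at input_sample_rate. Long inputs are processed
# in 30 s windows that start every (30 - overlap_seconds) seconds; only the first
# (30 - overlap_seconds) seconds of codes from each window are kept, so at 16 kHz with
# overlap_seconds=10 each chunk contributes 20 s (320000 samples) of valid output.
wavs = [torch.randn(16000 * 45), torch.randn(16000 * 8)]  # 45 s and 8 s dummy clips
codes_list = codec.encode(wavs, overlap_seconds=10)["codes_list"]      # each entry: (nq, T_codes)
recon = codec.decode(codes_list, overlap_seconds=10)["syn_wav_list"]   # each entry: (T,)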
diff --git a/XY_Tokenizer/xy_tokenizer/nn/feature_extractor.py b/XY_Tokenizer/xy_tokenizer/nn/feature_extractor.py
index 4d397b0..ccb14a6 100644
--- a/XY_Tokenizer/xy_tokenizer/nn/feature_extractor.py
+++ b/XY_Tokenizer/xy_tokenizer/nn/feature_extractor.py
@@ -1,16 +1,17 @@
+from typing import List, Optional, Union
+
import numpy as np
import torch
-
-from typing import Union, List, Optional
+from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import TensorType, logging
from transformers.utils.import_utils import is_torch_available
-from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
+
class MelFeatureExtractor(SequenceFeatureExtractor):
model_input_names = ["input_features"]
-
+
def __init__(
self,
feature_size=80,
@@ -38,7 +39,9 @@ def __init__(
self.nb_max_frames = self.n_samples // hop_length
self.sampling_rate = sampling_rate
self.dither = dither
- self.max_frequency = max_frequency if max_frequency is not None else sampling_rate / 2
+ self.max_frequency = (
+ max_frequency if max_frequency is not None else sampling_rate / 2
+ )
self.mel_filters = mel_filter_bank(
num_frequency_bins=1 + n_fft // 2,
num_mel_filters=feature_size,
@@ -49,7 +52,9 @@ def __init__(
mel_scale="slaney",
)
- def _np_extract_fbank_features(self, waveform_batch: np.array, device: str) -> np.ndarray:
+ def _np_extract_fbank_features(
+ self, waveform_batch: np.array, device: str
+ ) -> np.ndarray:
if device != "cpu":
raise ValueError(
f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator "
@@ -75,7 +80,9 @@ def _np_extract_fbank_features(self, waveform_batch: np.array, device: str) -> n
log_spec_batch = np.array(log_spec_batch)
return log_spec_batch
- def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu") -> np.ndarray:
+ def _torch_extract_fbank_features(
+ self, waveform: np.array, device: str = "cpu"
+ ) -> np.ndarray:
"""
Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching,
yielding results similar to cpu computing with 1e-5 tolerance.
@@ -84,9 +91,13 @@ def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu")
window = torch.hann_window(self.n_fft, device=device)
if self.dither != 0.0:
- waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device)
+ waveform += self.dither * torch.randn(
+ waveform.shape, dtype=waveform.dtype, device=waveform.device
+ )
- stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
+ stft = torch.stft(
+ waveform, self.n_fft, self.hop_length, window=window, return_complex=True
+ )
magnitudes = stft[..., :-1].abs() ** 2
mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
@@ -105,7 +116,9 @@ def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu")
@staticmethod
def zero_mean_unit_var_norm(
- input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
+ input_values: List[np.ndarray],
+ attention_mask: List[np.ndarray],
+ padding_value: float = 0.0,
) -> List[np.ndarray]:
"""
Every array in the list is normalized to have zero mean and unit variance
@@ -115,13 +128,17 @@ def zero_mean_unit_var_norm(
normed_input_values = []
for vector, length in zip(input_values, attention_mask.sum(-1)):
- normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
+ normed_slice = (vector - vector[:length].mean()) / np.sqrt(
+ vector[:length].var() + 1e-7
+ )
if length < normed_slice.shape[0]:
normed_slice[length:] = padding_value
normed_input_values.append(normed_slice)
else:
- normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
+ normed_input_values = [
+ (x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values
+ ]
return normed_input_values
@@ -177,18 +194,27 @@ def __call__(
f"extractor's sampling rate to ensure correct feature extraction."
)
- is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+ is_batched_numpy = (
+ isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+ )
if is_batched_numpy and len(raw_speech.shape) > 2:
- raise ValueError(f"Only mono-channel audio is supported for input to {self}")
+ raise ValueError(
+ f"Only mono-channel audio is supported for input to {self}"
+ )
is_batched = is_batched_numpy or (
- isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
+ isinstance(raw_speech, (list, tuple))
+ and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
)
if is_batched:
- raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
+ raw_speech = [
+ np.asarray([speech], dtype=np.float32).T for speech in raw_speech
+ ]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech, dtype=np.float32)
- elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
+ elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(
+ np.float64
+ ):
raw_speech = raw_speech.astype(np.float32)
if not is_batched:
@@ -211,27 +237,37 @@ def __call__(
attention_mask=padded_inputs["attention_mask"],
padding_value=self.padding_value,
)
- padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0)
+ padded_inputs["input_features"] = np.stack(
+ padded_inputs["input_features"], axis=0
+ )
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
extract_fbank_features = (
- self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
+ self._torch_extract_fbank_features
+ if is_torch_available()
+ else self._np_extract_fbank_features
)
input_features = extract_fbank_features(input_features[0], device)
if isinstance(input_features[0], List):
- padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
+ padded_inputs["input_features"] = [
+ np.asarray(feature, dtype=np.float32) for feature in input_features
+ ]
else:
padded_inputs["input_features"] = input_features
if return_attention_mask:
- padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
+ padded_inputs["attention_mask"] = padded_inputs["attention_mask"][
+ :, :: self.hop_length
+ ]
if return_token_timestamps is not None:
- padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech]
+ padded_inputs["num_frames"] = [
+ len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech
+ ]
if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
- return padded_inputs
\ No newline at end of file
+ return padded_inputs
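A small usage sketch of MelFeatureExtractor as model.py drives it in inference_tokenize. The import path is assumed from the file location, and the remaining constructor defaults are assumed to match the tokenizer config; the waveforms are dummy data.

import numpy as np
from xy_tokenizer.nn.feature_extractor import MelFeatureExtractor  # assumed import path

fe = MelFeatureExtractor(sampling_rate=16000)  # other defaults (e.g. feature_size=80) assumed compatible
batch = [
    np.random.randn(16000 * 3).astype(np.float32),  # 3 s clip at 16 kHz
    np.random.randn(16000 * 1).astype(np.float32),  # 1 s clip at 16 kHz
]
feats = fe(batch, sampling_rate=16000, return_tensors="pt", return_attention_mask=True)
mel = feats["input_features"]   # (B, feature_size, frames), padded to the 30 s window (3000 frames per model.py)
mask = feats["attention_mask"]  # (B, frames), already subsampled by hop_length as in __call__ above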
diff --git a/XY_Tokenizer/xy_tokenizer/nn/modules.py b/XY_Tokenizer/xy_tokenizer/nn/modules.py
index cc186d9..eabc01b 100644
--- a/XY_Tokenizer/xy_tokenizer/nn/modules.py
+++ b/XY_Tokenizer/xy_tokenizer/nn/modules.py
@@ -1,24 +1,21 @@
-import torch
-import torch.distributed
-import numpy as np
+import copy
import logging
import math
-import copy
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import librosa
import numpy as np
import scipy
import torch
-import librosa
-
-from typing import Optional, Tuple
-from torch import nn, view_as_real, view_as_complex
-from torch import nn
+import torch.distributed
+from torch import nn, view_as_complex, view_as_real
from torch.nn import functional as F
-from torch.nn.utils import weight_norm, remove_weight_norm
+from torch.nn.utils import remove_weight_norm, weight_norm
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
+from transformers import WhisperModel
from transformers.activations import ACT2FN
-from dataclasses import dataclass
from transformers.modeling_outputs import ModelOutput
-from transformers import WhisperModel
# Define function to generate positional embeddings using sine and cosine functions to represent sequence position information
@@ -30,6 +27,7 @@ def sinusoids(length, channels, max_timescale=10000):
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
# Generate sequence mask to distinguish valid sequence and padding parts
def get_sequence_mask(inputs, inputs_length):
if inputs.dim() == 3:
@@ -37,9 +35,12 @@ def get_sequence_mask(inputs, inputs_length):
else:
bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
- sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
+ sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(
+ bsz, tgt_len, 1
+ )
return sequence_mask
+
# Define RMSNorm layer for normalizing hidden states and stabilizing training process
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
@@ -54,6 +55,7 @@ def forward(self, hidden_states):
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
+
# Modified variable-length attention mechanism, supporting FP32 with unified interface
class VarLenAttention(nn.Module):
def __init__(self, embed_dim, num_heads, causal=False, dropout=0.0):
@@ -73,7 +75,7 @@ def __init__(self, embed_dim, num_heads, causal=False, dropout=0.0):
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.causal = causal
self.dropout = nn.Dropout(dropout)
- self.scaling = self.head_dim ** -0.5 # Scaling factor
+ self.scaling = self.head_dim**-0.5 # Scaling factor
# Linear projection layers for Q, K, V and output
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
@@ -104,18 +106,29 @@ def _create_attention_mask(self, seq_len, max_len, device, dtype):
# Mark valid positions (less than seq_len)
valid_mask = seq_indices < seq_len_expanded.unsqueeze(-1) # [bsz, 1, max_len]
- mask = mask * (valid_mask.unsqueeze(2) & valid_mask.unsqueeze(3)).to(dtype) # [bsz, 1, max_len, max_len]
+ mask = mask * (valid_mask.unsqueeze(2) & valid_mask.unsqueeze(3)).to(
+ dtype
+ ) # [bsz, 1, max_len, max_len]
# If causal attention, add upper triangular mask
if self.causal:
- causal_mask = torch.triu(torch.ones(max_len, max_len, device=device, dtype=torch.bool), diagonal=1)
- mask = mask * (~causal_mask.unsqueeze(0).unsqueeze(1)).to(dtype) # Keep only lower triangular part
+ causal_mask = torch.triu(
+ torch.ones(max_len, max_len, device=device, dtype=torch.bool),
+ diagonal=1,
+ )
+ mask = mask * (~causal_mask.unsqueeze(0).unsqueeze(1)).to(
+ dtype
+ ) # Keep only lower triangular part
# Set invalid positions (0) to dtype's minimum value
- mask = mask + (1.0 - mask) * torch.finfo(dtype).min # Valid positions unchanged, invalid positions to minimum value
+ mask = (
+ mask + (1.0 - mask) * torch.finfo(dtype).min
+ ) # Valid positions unchanged, invalid positions to minimum value
return mask
- def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
+ def forward(
+ self, hidden_states: torch.Tensor, seq_len: torch.Tensor
+ ) -> torch.Tensor:
"""
Forward propagation, input and output are [bsz, max_len, embed_dim].
@@ -130,43 +143,73 @@ def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.T
# Project to Q, K, V
query = self.q_proj(hidden_states) * self.scaling # [bsz, max_len, embed_dim]
- key = self.k_proj(hidden_states) # [bsz, max_len, embed_dim]
- value = self.v_proj(hidden_states) # [bsz, max_len, embed_dim]
+ key = self.k_proj(hidden_states) # [bsz, max_len, embed_dim]
+ value = self.v_proj(hidden_states) # [bsz, max_len, embed_dim]
# Reshape to multi-head form
- query = query.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2) # [bsz, num_heads, max_len, head_dim]
- key = key.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2) # [bsz, num_heads, max_len, head_dim]
- value = value.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2) # [bsz, num_heads, max_len, head_dim]
+ query = query.view(bsz, max_len, self.num_heads, self.head_dim).transpose(
+ 1, 2
+ ) # [bsz, num_heads, max_len, head_dim]
+ key = key.view(bsz, max_len, self.num_heads, self.head_dim).transpose(
+ 1, 2
+ ) # [bsz, num_heads, max_len, head_dim]
+ value = value.view(bsz, max_len, self.num_heads, self.head_dim).transpose(
+ 1, 2
+ ) # [bsz, num_heads, max_len, head_dim]
# Calculate attention scores
- attn_scores = torch.matmul(query, key.transpose(-1, -2)) # [bsz, num_heads, max_len, max_len]
+ attn_scores = torch.matmul(
+ query, key.transpose(-1, -2)
+ ) # [bsz, num_heads, max_len, max_len]
# Generate attention mask
- attn_mask = self._create_attention_mask(seq_len, max_len, hidden_states.device, attn_scores.dtype) # [bsz, 1, max_len, max_len]
+ attn_mask = self._create_attention_mask(
+ seq_len, max_len, hidden_states.device, attn_scores.dtype
+ ) # [bsz, 1, max_len, max_len]
# Apply mask (additive form, consistent with HubertEncoder)
- attn_scores = attn_scores + attn_mask # Invalid positions set to very small value
+ attn_scores = (
+ attn_scores + attn_mask
+ ) # Invalid positions set to very small value
# Softmax calculate attention weights
- attn_weights = F.softmax(attn_scores, dim=-1) # [bsz, num_heads, max_len, max_len]
+ attn_weights = F.softmax(
+ attn_scores, dim=-1
+ ) # [bsz, num_heads, max_len, max_len]
attn_weights = self.dropout(attn_weights)
# Calculate attention output
- attn_output = torch.matmul(attn_weights, value) # [bsz, num_heads, max_len, head_dim]
- attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, max_len, self.embed_dim) # [bsz, max_len, embed_dim]
+ attn_output = torch.matmul(
+ attn_weights, value
+ ) # [bsz, num_heads, max_len, head_dim]
+ attn_output = (
+ attn_output.transpose(1, 2).contiguous().view(bsz, max_len, self.embed_dim)
+ ) # [bsz, max_len, embed_dim]
# Output projection
attn_output = self.out_proj(attn_output) # [bsz, max_len, embed_dim]
return attn_output
+
# Define Transformer layer containing attention mechanism and feedforward network for feature extraction and transformation
class OmniWhisperTransformerLayer(nn.Module):
- def __init__(self, activation_function="gelu", d_model=1280, attention_heads=20, ffn_dim=5120, causal=False, ln_type="LayerNorm", attn_type="varlen"):
+ def __init__(
+ self,
+ activation_function="gelu",
+ d_model=1280,
+ attention_heads=20,
+ ffn_dim=5120,
+ causal=False,
+ ln_type="LayerNorm",
+ attn_type="varlen",
+ ):
super().__init__()
self.embed_dim = d_model
# Only keep varlen attention mechanism
if attn_type != "varlen":
- raise ValueError(f"Unknown attn_type: {attn_type}. Only 'varlen' is supported.")
+ raise ValueError(
+ f"Unknown attn_type: {attn_type}. Only 'varlen' is supported."
+ )
self.self_attn = VarLenAttention(self.embed_dim, attention_heads, causal)
if ln_type == "LayerNorm":
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
@@ -184,85 +227,105 @@ def __init__(self, activation_function="gelu", d_model=1280, attention_heads=20,
else:
raise ValueError(f"Unknown ln_type: {ln_type}")
- def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
+ def forward(
+ self, hidden_states: torch.Tensor, seq_len: torch.Tensor
+ ) -> torch.Tensor:
residual = hidden_states # [bsz, max_len, embed_dim]
hidden_states = self.self_attn_layer_norm(hidden_states)
# from torch.cuda.amp import autocast
# print(f"{residual.dtype = }")
# print(f"Autocast enabled: {torch.is_autocast_enabled():}")
# print(f"after layernorm {hidden_states.dtype = }")
- hidden_states = self.self_attn(hidden_states, seq_len) # [bsz, max_len, embed_dim]
+ hidden_states = self.self_attn(
+ hidden_states, seq_len
+ ) # [bsz, max_len, embed_dim]
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
- if (hidden_states.dtype == torch.float16 or hidden_states.dtype == torch.bfloat16) and \
- (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
+ if (
+ hidden_states.dtype == torch.float16
+ or hidden_states.dtype == torch.bfloat16
+ ) and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
- hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+ hidden_states = torch.clamp(
+ hidden_states, min=-clamp_value, max=clamp_value
+ )
return hidden_states
+
# Define audio encoder to convert input audio features to hidden state representation
class OmniAudioEncoder(nn.Module):
def __init__(
- self,
- num_mel_bins=128, # Input feature Mel band number, usually the dimension of Mel spectrogram
- sampling_rate=16000, # Audio sampling rate, unit Hz
- hop_length=160, # Frame shift length (sample number) when calculating Mel spectrogram
- stride_size=2, # Convolution layer step, used for downsampling
- kernel_size=3, # Convolution kernel size, controlling receptive field
- d_model=1280, # Model's hidden state dimension (embedding dimension)
- scale_embedding=True, # Whether to scale embedding (usually used for stabilizing training)
- max_audio_seconds=30, # Maximum audio duration supported (seconds)
- encoder_layers=32, # Transformer encoder layer number
- encoder_attention_heads=20, # Attention head number for each Transformer layer
- encoder_ffn_dim=5120, # Intermediate dimension for feedforward network
- activation_function="gelu", # Activation function type, default GELU
- attn_type="varlen" # New parameter, select attention mechanism type
- ):
+ self,
+ num_mel_bins=128, # Number of input mel bands (mel-spectrogram feature dimension)
+ sampling_rate=16000, # Audio sampling rate in Hz
+ hop_length=160, # Hop length in samples used when computing the mel spectrogram
+ stride_size=2, # Convolution stride used for downsampling
+ kernel_size=3, # Convolution kernel size (receptive field)
+ d_model=1280, # Hidden state (embedding) dimension
+ scale_embedding=True, # Whether to scale embeddings (helps stabilize training)
+ max_audio_seconds=30, # Maximum supported audio duration in seconds
+ encoder_layers=32, # Number of Transformer encoder layers
+ encoder_attention_heads=20, # Number of attention heads per Transformer layer
+ encoder_ffn_dim=5120, # Intermediate dimension of the feed-forward network
+ activation_function="gelu", # Activation function (default GELU)
+ attn_type="varlen", # Attention mechanism type
+ ):
super().__init__()
# Calculate maximum sequence length: Convert sampling rate to frame number after considering downsampling step
- self.max_source_positions = (max_audio_seconds * sampling_rate // hop_length) // stride_size
+ self.max_source_positions = (
+ max_audio_seconds * sampling_rate // hop_length
+ ) // stride_size
# Embedding scaling factor, if enabled sqrt(d_model), otherwise 1.0
self.embed_scale = math.sqrt(d_model) if scale_embedding else 1.0
self.num_mel_bins = num_mel_bins # Save Mel band number
self.d_model = d_model # Save hidden state dimension
self.stride_size = stride_size
-
+
# First convolution layer: Convert Mel spectrogram features (num_mel_bins) to hidden dimension (d_model)
- self.conv1 = nn.Conv1d(num_mel_bins, d_model, kernel_size=kernel_size, padding=1)
+ self.conv1 = nn.Conv1d(
+ num_mel_bins, d_model, kernel_size=kernel_size, padding=1
+ )
# Second convolution layer: Apply downsampling with stride_size
- self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=kernel_size, stride=stride_size, padding=1)
-
+ self.conv2 = nn.Conv1d(
+ d_model, d_model, kernel_size=kernel_size, stride=stride_size, padding=1
+ )
+
# Register positional embedding buffer, using sine function to generate, shape (max_source_positions, d_model)
- self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model))
-
+ self.register_buffer(
+ "positional_embedding", sinusoids(self.max_source_positions, d_model)
+ )
+
# Create Transformer encoder layer list, each layer contains attention mechanism and feedforward network
- self.layers = nn.ModuleList([
- OmniWhisperTransformerLayer(
- activation_function=activation_function,
- d_model=d_model,
- attention_heads=encoder_attention_heads,
- ffn_dim=encoder_ffn_dim,
- causal=False, # Encoder does not need causal attention
- attn_type=attn_type # Pass attention type
- ) for _ in range(encoder_layers)
- ])
-
+ self.layers = nn.ModuleList(
+ [
+ OmniWhisperTransformerLayer(
+ activation_function=activation_function,
+ d_model=d_model,
+ attention_heads=encoder_attention_heads,
+ ffn_dim=encoder_ffn_dim,
+ causal=False, # Encoder does not need causal attention
+ attn_type=attn_type, # Pass attention type
+ )
+ for _ in range(encoder_layers)
+ ]
+ )
+
# Last layer normalization for stable output
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, input_features, input_length, output_hidden_states=False):
"""
Forward propagation function to convert input audio features to hidden state representation
-
+
Parameters:
input_features (torch.Tensor): Input Mel spectrogram features, shape [bsz, num_mel_bins, seq_len]
input_length (torch.Tensor): Input sequence length for each sample, shape [bsz]
output_hidden_states (bool, optional): Whether to return hidden states for each layer, default False
-
+
Returns:
if output_hidden_states is False:
hidden_states (torch.Tensor): Encoded hidden states, shape [bsz, d_model, tgt_len]
@@ -274,160 +337,187 @@ def forward(self, input_features, input_length, output_hidden_states=False):
"""
# Ensure input feature data type consistent with convolution layer weights
input_features = input_features.to(self.conv1.weight.dtype) # (B, D, T)
-
+
# First layer convolution + GELU activation, Convert Mel spectrogram to hidden states
inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (B, D, T)
-
+
# Second layer convolution + GELU activation, Apply downsampling with stride_size
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (B, D, T)
-
+
# Calculate output length: Result after downsampling with stride_size
output_length = (input_length // self.stride_size).long() # (B,)
-
+
# Adjust dimension order to [bsz, seq_len, d_model] for Transformer input
hidden_states = inputs_embeds.permute(0, 2, 1) # (B, T, D)
-
+
# Get batch size and target sequence length
bsz, tgt_len, _ = hidden_states.size()
-
+
# According to current sequence length, take or use complete positional embedding
if tgt_len < self.positional_embedding.shape[0]:
current_positional_embedding = self.positional_embedding[:tgt_len]
else:
current_positional_embedding = self.positional_embedding
-
+
# Add input embedding to positional embedding, convert to float to avoid precision issues
- hidden_states = (hidden_states.to(torch.float32) + current_positional_embedding).to(hidden_states.dtype)
-
+ hidden_states = (
+ hidden_states.to(torch.float32) + current_positional_embedding
+ ).to(hidden_states.dtype)
+
# Generate sequence mask for processing variable-length sequence
- attention_mask = get_sequence_mask(hidden_states, output_length) # [bsz, tgt_len, 1]
-
+ attention_mask = get_sequence_mask(
+ hidden_states, output_length
+ ) # [bsz, tgt_len, 1]
+
# Initialize hidden states list for storing output for each layer (if needed)
hidden_states_all_layers = () if output_hidden_states else None
-
+
# Process hidden states through Transformer encoder layer by layer
for encoder_layer in self.layers:
if output_hidden_states:
hidden_states_all_layers = hidden_states_all_layers + (hidden_states,)
- hidden_states = encoder_layer(hidden_states, output_length) # [bsz, tgt_len, d_model]
-
+ hidden_states = encoder_layer(
+ hidden_states, output_length
+ ) # [bsz, tgt_len, d_model]
+
# Normalize hidden states
hidden_states = self.layer_norm(hidden_states) # [bsz, tgt_len, d_model]
if output_hidden_states:
hidden_states_all_layers = hidden_states_all_layers + (hidden_states,)
-
+
# Use mask to zero out padding parts and ensure output only retains valid data
- hidden_states = torch.where(attention_mask, hidden_states, 0) # [bsz, tgt_len, d_model]
+ hidden_states = torch.where(
+ attention_mask, hidden_states, 0
+ ) # [bsz, tgt_len, d_model]
hidden_states = hidden_states.transpose(1, 2) # [bsz, d_model, tgt_len]
-
+
if not output_hidden_states:
- return hidden_states, output_length
+ return hidden_states, output_length
else:
return hidden_states, output_length, hidden_states_all_layers
-
+
+
# Define audio decoder to convert hidden states to Mel spectrogram
class OmniAudioDecoder(nn.Module):
def __init__(
- self,
- num_mel_bins=128,
- sampling_rate=16000,
- hop_length=160,
- stride_size=2,
- kernel_size=3,
- d_model=1280,
- scale_embedding=True,
- max_audio_seconds=30,
- decoder_layers=32,
- decoder_attention_heads=20,
- decoder_ffn_dim=5120,
- activation_function="gelu",
- attn_type="varlen" # New parameter, select attention mechanism type
- ):
+ self,
+ num_mel_bins=128,
+ sampling_rate=16000,
+ hop_length=160,
+ stride_size=2,
+ kernel_size=3,
+ d_model=1280,
+ scale_embedding=True,
+ max_audio_seconds=30,
+ decoder_layers=32,
+ decoder_attention_heads=20,
+ decoder_ffn_dim=5120,
+ activation_function="gelu",
+ attn_type="varlen", # New parameter, select attention mechanism type
+ ):
super().__init__()
- self.max_source_positions = (max_audio_seconds * sampling_rate // hop_length) // stride_size
+ self.max_source_positions = (
+ max_audio_seconds * sampling_rate // hop_length
+ ) // stride_size
self.embed_scale = math.sqrt(d_model) if scale_embedding else 1.0
self.num_mel_bins = num_mel_bins
self.d_model = d_model
self.stride_size = stride_size
-
+
# Correct transpose convolution layer to ensure output length close to stride_size times
self.deconv1 = nn.ConvTranspose1d(
- d_model,
- d_model,
- kernel_size=kernel_size,
- stride=stride_size,
+ d_model,
+ d_model,
+ kernel_size=kernel_size,
+ stride=stride_size,
padding=0, # Do not fill input side
- output_padding=0 # Can be adjusted to precisely control length
+ output_padding=0, # Can be adjusted to precisely control length
)
self.deconv2 = nn.ConvTranspose1d(
- d_model,
- num_mel_bins,
- kernel_size=kernel_size,
+ d_model,
+ num_mel_bins,
+ kernel_size=kernel_size,
stride=1, # Only convert channels, do not change length
- padding=0
+ padding=0,
)
-
+
# Positional embedding remains consistent
- self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model)) # (T, D)
-
+ self.register_buffer(
+ "positional_embedding", sinusoids(self.max_source_positions, d_model)
+ ) # (T, D)
+
# Transformer decoder layer
- self.layers = nn.ModuleList([
- OmniWhisperTransformerLayer(
- activation_function=activation_function,
- d_model=d_model,
- attention_heads=decoder_attention_heads,
- ffn_dim=decoder_ffn_dim,
- causal=False, # Decoder uses causal attention
- attn_type=attn_type # Pass attention type
- ) for _ in range(decoder_layers)
- ])
+ self.layers = nn.ModuleList(
+ [
+ OmniWhisperTransformerLayer(
+ activation_function=activation_function,
+ d_model=d_model,
+ attention_heads=decoder_attention_heads,
+ ffn_dim=decoder_ffn_dim,
+ causal=False, # Decoder also uses non-causal attention
+ attn_type=attn_type, # Pass attention type
+ )
+ for _ in range(decoder_layers)
+ ]
+ )
self.layer_norm = nn.LayerNorm(d_model)
- def forward(self, hidden_states, input_length): # (B, D, T)
+ def forward(self, hidden_states, input_length): # (B, D, T)
# Input is hidden state output from encoder
- hidden_states = hidden_states.transpose(1, 2) # (B, T, D)
+ hidden_states = hidden_states.transpose(1, 2) # (B, T, D)
bsz, tgt_len, _ = hidden_states.size()
-
+
# Add positional embedding
if tgt_len < self.positional_embedding.shape[0]:
- current_positional_embedding = self.positional_embedding[:tgt_len] # (T, D)
+ current_positional_embedding = self.positional_embedding[:tgt_len] # (T, D)
else:
current_positional_embedding = self.positional_embedding
- hidden_states = (hidden_states.to(torch.float32) + current_positional_embedding).to(hidden_states.dtype) # (B, T, D)
-
+ hidden_states = (
+ hidden_states.to(torch.float32) + current_positional_embedding
+ ).to(
+ hidden_states.dtype
+ ) # (B, T, D)
+
# Generate sequence mask
- attention_mask = get_sequence_mask(hidden_states, input_length) # [bsz, tgt_len, 1]
-
+ attention_mask = get_sequence_mask(
+ hidden_states, input_length
+ ) # [bsz, tgt_len, 1]
+
# Process through decoder layer
for decoder_layer in self.layers:
- hidden_states = decoder_layer(hidden_states, input_length) # [bsz, tgt_len, d_model]
-
+ hidden_states = decoder_layer(
+ hidden_states, input_length
+ ) # [bsz, tgt_len, d_model]
+
# Final layer normalization
hidden_states = self.layer_norm(hidden_states) # [bsz, tgt_len, d_model]
-
+
# Use mask to zero out padding parts
- hidden_states = torch.where(attention_mask, hidden_states, 0) # [bsz, tgt_len, d_model]
-
+ hidden_states = torch.where(
+ attention_mask, hidden_states, 0
+ ) # [bsz, tgt_len, d_model]
+
# Process through transpose convolution layer to reconstruct audio features
hidden_states = hidden_states.permute(0, 2, 1) # (B, D, T)
- output_features = nn.functional.gelu(self.deconv1(hidden_states)) # (B, D, T)
- output_features = nn.functional.gelu(self.deconv2(output_features)) # (B, D, T)
-
+ output_features = nn.functional.gelu(self.deconv1(hidden_states)) # (B, D, T)
+ output_features = nn.functional.gelu(self.deconv2(output_features)) # (B, D, T)
+
# If strictly stride_size times length is needed, can trim extra parts
expected_length = tgt_len * self.stride_size
if output_features.size(2) > expected_length:
output_features = output_features[:, :, :expected_length]
-
+
output_length = input_length * self.stride_size
# Output shape: [bsz, num_mel_bins, seq_len]
return output_features, output_length
+
# The following part remains unchanged
class ResidualDownConv(nn.Module):
def __init__(self, d_model=1280, avg_pooler=4):
"""
Downsampling module containing residual connection and convolution operation
-
+
Parameters:
d_model (int): Input and output hidden dimension
avg_pooler (int): Downsampling factor (convolution step)
@@ -436,52 +526,58 @@ def __init__(self, d_model=1280, avg_pooler=4):
self.d_model = d_model
self.avg_pooler = avg_pooler
self.intermediate_dim = d_model * avg_pooler
-
+
# Convolution layer for downsampling
- self.gate_proj = nn.Conv1d(d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False)
- self.up_proj = nn.Conv1d(d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False)
-
+ self.gate_proj = nn.Conv1d(
+ d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False
+ )
+ self.up_proj = nn.Conv1d(
+ d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False
+ )
+
# Downsampled linear projection
- self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False)
-
+ self.down_proj = nn.Linear(
+ self.intermediate_dim, self.intermediate_dim, bias=False
+ )
+
# Activation function and layer normalization
- self.act_fn = ACT2FN['silu']
+ self.act_fn = ACT2FN["silu"]
self.layer_norm = nn.LayerNorm(self.intermediate_dim)
def forward(self, x, input_length):
"""
Forward propagation, execute downsampling and residual processing
-
+
Parameters:
x (torch.Tensor): Input tensor, shape [B, D, T]
-
+
Returns:
res (torch.Tensor): Downsampled feature, shape [B, intermediate_dim, seq_len // avg_pooler]
output_length (torch.Tensor): Valid output lengths after downsampling
"""
output_length = input_length // self.avg_pooler
- x = x.transpose(1, 2) # (B, T, D)
- batch_size, seq_len, _ = x.shape # (B, T, D)
+ x = x.transpose(1, 2) # (B, T, D)
+ batch_size, seq_len, _ = x.shape # (B, T, D)
if seq_len % self.avg_pooler != 0:
pad_size = self.avg_pooler - seq_len % self.avg_pooler
x = F.pad(x, (0, pad_size), "constant", 0)
-
- xt = x.permute(0, 2, 1) # (B, D, T)
+
+ xt = x.permute(0, 2, 1) # (B, D, T)
g = self.gate_proj(xt).permute(0, 2, 1) # (B, T, D)
- u = self.up_proj(xt).permute(0, 2, 1) # (B, T, D)
+ u = self.up_proj(xt).permute(0, 2, 1) # (B, T, D)
x = x.reshape(batch_size, -1, self.intermediate_dim) # (B, T, D)
- c = self.down_proj(self.act_fn(g) * u) # (B, T, D)
- res = self.layer_norm(c + x) # (B, T, D)
- res = res.transpose(1, 2) # (B, D, T)
- return res, output_length # (B, D, T)
-
-
+ c = self.down_proj(self.act_fn(g) * u) # (B, T, D)
+ res = self.layer_norm(c + x) # (B, T, D)
+ res = res.transpose(1, 2) # (B, D, T)
+ return res, output_length # (B, D, T)
+
+
class UpConv(nn.Module):
def __init__(self, d_model=1280, stride=4):
"""
Simple upsampling module using transpose convolution
-
+
Parameters:
d_model (int): Input and output hidden dimension
stride (int): Upsampling factor (transpose convolution step)
@@ -489,23 +585,23 @@ def __init__(self, d_model=1280, stride=4):
super().__init__()
self.d_model = d_model
self.stride = stride
-
+
# Simple transpose convolution layer to keep channel number consistent
self.up_conv = nn.ConvTranspose1d(
- self.stride * d_model,
- d_model,
- kernel_size=stride,
- stride=stride,
- bias=False
+ self.stride * d_model,
+ d_model,
+ kernel_size=stride,
+ stride=stride,
+ bias=False,
)
def forward(self, x, input_length):
"""
Forward propagation, execute upsampling
-
+
Parameters:
x (torch.Tensor): Input tensor, shape [B, D * stride, T]
-
+
Returns:
res (torch.Tensor): Upsampled feature, shape [B, D, T * stride]
"""
@@ -513,21 +609,21 @@ def forward(self, x, input_length):
res = self.up_conv(x)
output_length = input_length * self.stride
return res, output_length
-
+
# Define Transformer encoder containing multiple Transformer layers for feature extraction and transformation
class Transformer(nn.Module):
def __init__(
- self,
- input_dim=1280, # Input feature dimension
- d_model=1280, # Model's hidden state dimension (embedding dimension)
- output_dim=1280, # Output feature dimension
- max_source_positions=1500, # Maximum sequence length for positional embedding
- encoder_layers=32, # Transformer encoder layer number
- encoder_attention_heads=20, # Attention head number for each Transformer layer
- encoder_ffn_dim=5120, # Intermediate dimension for feedforward network
- activation_function="gelu", # Activation function type, default GELU
- attn_type="varlen" # Attention mechanism type
+ self,
+ input_dim=1280, # Input feature dimension
+ d_model=1280, # Model's hidden state dimension (embedding dimension)
+ output_dim=1280, # Output feature dimension
+ max_source_positions=1500, # Maximum sequence length for positional embedding
+ encoder_layers=32, # Number of Transformer encoder layers
+ encoder_attention_heads=20, # Number of attention heads per Transformer layer
+ encoder_ffn_dim=5120, # Intermediate dimension for feedforward network
+ activation_function="gelu", # Activation function type, default GELU
+ attn_type="varlen", # Attention mechanism type
):
super().__init__()
self.input_dim = input_dim # Save input dimension
@@ -542,20 +638,25 @@ def __init__(
self.proj = None # No need for input projection layer
# Register positional embedding buffer, using sine function to generate, shape (max_source_positions, d_model)
- self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model))
-
+ self.register_buffer(
+ "positional_embedding", sinusoids(self.max_source_positions, d_model)
+ )
+
# Create Transformer encoder layer list, each layer contains attention mechanism and feedforward network
- self.layers = nn.ModuleList([
- OmniWhisperTransformerLayer(
- activation_function=activation_function,
- d_model=d_model,
- attention_heads=encoder_attention_heads,
- ffn_dim=encoder_ffn_dim,
- causal=False, # Encoder does not need causal attention
- attn_type=attn_type # Pass attention type
- ) for _ in range(encoder_layers)
- ])
-
+ self.layers = nn.ModuleList(
+ [
+ OmniWhisperTransformerLayer(
+ activation_function=activation_function,
+ d_model=d_model,
+ attention_heads=encoder_attention_heads,
+ ffn_dim=encoder_ffn_dim,
+ causal=False, # Encoder does not need causal attention
+ attn_type=attn_type, # Pass attention type
+ )
+ for _ in range(encoder_layers)
+ ]
+ )
+
# Last layer normalization for stable output
self.layer_norm = nn.LayerNorm(d_model)
@@ -565,15 +666,20 @@ def __init__(
else:
self.out_proj = None # No need for output projection layer
- def forward(self, input_features: torch.Tensor, input_length: torch.Tensor, output_hidden_states: bool = False):
+ def forward(
+ self,
+ input_features: torch.Tensor,
+ input_length: torch.Tensor,
+ output_hidden_states: bool = False,
+ ):
"""
Forward propagation function to convert input features through Transformer layer to hidden state representation
-
+
Parameters:
input_features (torch.Tensor): Input features, shape [bsz, input_dim, seq_len] (B, input_dim, T)
input_length (torch.Tensor): Input sequence length for each sample, shape [bsz]
output_hidden_states (bool, optional): Whether to return hidden states for each layer, default False
-
+
Returns:
if output_hidden_states is False:
hidden_states (torch.Tensor): Encoded hidden states, shape [bsz, output_dim, seq_len] (B, output_dim, T)
@@ -588,57 +694,79 @@ def forward(self, input_features: torch.Tensor, input_length: torch.Tensor, outp
# If there is input projection layer, map input features from input_dim to d_model
if self.proj is not None:
- hidden_states = self.proj(input_features.permute(0, 2, 1)).permute(0, 2, 1) # [bsz, d_model, seq_len] (B, D, T)
+ hidden_states = self.proj(input_features.permute(0, 2, 1)).permute(
+ 0, 2, 1
+ ) # [bsz, d_model, seq_len] (B, D, T)
else:
hidden_states = input_features # [bsz, d_model, seq_len] (B, D, T)
# Adjust input dimension order to [bsz, seq_len, d_model] for Transformer input
- hidden_states = hidden_states.permute(0, 2, 1) # [bsz, seq_len, d_model] (B, T, D)
-
+ hidden_states = hidden_states.permute(
+ 0, 2, 1
+ ) # [bsz, seq_len, d_model] (B, T, D)
+
# Get batch size and target sequence length
bsz, tgt_len, _ = hidden_states.size()
-
+
# According to current sequence length, take or use complete positional embedding
if tgt_len < self.positional_embedding.shape[0]:
- current_positional_embedding = self.positional_embedding[:tgt_len] # [tgt_len, d_model]
+ current_positional_embedding = self.positional_embedding[
+ :tgt_len
+ ] # [tgt_len, d_model]
else:
- current_positional_embedding = self.positional_embedding # [max_source_positions, d_model]
-
+ current_positional_embedding = (
+ self.positional_embedding
+ ) # [max_source_positions, d_model]
+
# Add input features to positional embedding, convert to float to avoid precision issues
- hidden_states = (hidden_states.to(torch.float32) + current_positional_embedding).to(hidden_states.dtype) # [bsz, seq_len, d_model]
-
+ hidden_states = (
+ hidden_states.to(torch.float32) + current_positional_embedding
+ ).to(
+ hidden_states.dtype
+ ) # [bsz, seq_len, d_model]
+
# Generate sequence mask for processing variable-length sequence
- attention_mask = get_sequence_mask(hidden_states, output_length) # [bsz, tgt_len, 1]
-
+ attention_mask = get_sequence_mask(
+ hidden_states, output_length
+ ) # [bsz, tgt_len, 1]
+
# Initialize hidden states list for storing output for each layer (if needed)
hidden_states_all_layers = () if output_hidden_states else None
-
+
# Process hidden states through Transformer encoder layer by layer
for encoder_layer in self.layers:
if output_hidden_states:
hidden_states_all_layers = hidden_states_all_layers + (hidden_states,)
- hidden_states = encoder_layer(hidden_states, output_length) # [bsz, seq_len, d_model]
-
+ hidden_states = encoder_layer(
+ hidden_states, output_length
+ ) # [bsz, seq_len, d_model]
+
# Normalize hidden states
hidden_states = self.layer_norm(hidden_states) # [bsz, seq_len, d_model]
if output_hidden_states:
hidden_states_all_layers = hidden_states_all_layers + (hidden_states,)
-
+
# Use mask to zero out padding parts and ensure output only retains valid data
- hidden_states = torch.where(attention_mask, hidden_states, 0) # [bsz, seq_len, d_model]
-
+ hidden_states = torch.where(
+ attention_mask, hidden_states, 0
+ ) # [bsz, seq_len, d_model]
+
# Adjust dimension order to [bsz, d_model, seq_len]
- hidden_states = hidden_states.transpose(1, 2) # [bsz, d_model, seq_len] (B, D, T)
+ hidden_states = hidden_states.transpose(
+ 1, 2
+ ) # [bsz, d_model, seq_len] (B, D, T)
# If there is output projection layer, map hidden states from d_model to output_dim
if self.out_proj is not None:
- hidden_states = self.out_proj(hidden_states.permute(0, 2, 1)).permute(0, 2, 1) # [bsz, output_dim, seq_len] (B, output_dim, T)
+ hidden_states = self.out_proj(hidden_states.permute(0, 2, 1)).permute(
+ 0, 2, 1
+ ) # [bsz, output_dim, seq_len] (B, output_dim, T)
if not output_hidden_states:
return hidden_states, output_length
else:
return hidden_states, output_length, hidden_states_all_layers
-
+
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
"""
@@ -1477,4 +1605,3 @@ def forward(self, x, input_length):
x = self.head(x)
output_length = input_length * self.hop_size
return x[:, None, :], output_length
-
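To make the frame-rate bookkeeping in these modules concrete, here is a small shape-check sketch that chains them the way model.py does: the encoder halves the 100 Hz mel frame rate, ResidualDownConv divides it by 4 down to the 12.5 Hz code rate, and UpConv plus the decoder bring it back up. The import path is assumed, and the layer sizes are deliberately tiny; the real config uses the larger defaults shown above.

import torch
from xy_tokenizer.nn.modules import (  # assumed import path
    OmniAudioDecoder, OmniAudioEncoder, ResidualDownConv, UpConv,
)

enc = OmniAudioEncoder(num_mel_bins=80, d_model=256, encoder_layers=2,
                       encoder_attention_heads=4, encoder_ffn_dim=512)  # stride_size=2: 100 Hz -> 50 Hz
down = ResidualDownConv(d_model=256, avg_pooler=4)                      # 50 Hz -> 12.5 Hz, dim 256 -> 1024
up = UpConv(d_model=256, stride=4)                                      # 12.5 Hz -> 50 Hz, dim 1024 -> 256
dec = OmniAudioDecoder(num_mel_bins=80, d_model=256, decoder_layers=2,
                       decoder_attention_heads=4, decoder_ffn_dim=512)  # stride_size=2: 50 Hz -> 100 Hz

mel = torch.randn(1, 80, 1000)   # 10 s of mel frames at 100 Hz
lengths = torch.tensor([1000])
h, l = enc(mel, lengths)         # (1, 256, 500),  l = 500
h, l = down(h, l)                # (1, 1024, 125), l = 125  (12.5 Hz code rate)
h, l = up(h, l)                  # (1, 256, 500),  l = 500
out, l = dec(h, l)               # (1, 80, 1000),  l = 1000 (back to 100 Hz mel frames)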
diff --git a/XY_Tokenizer/xy_tokenizer/nn/quantizer.py b/XY_Tokenizer/xy_tokenizer/nn/quantizer.py
index a7d28b9..79e2044 100644
--- a/XY_Tokenizer/xy_tokenizer/nn/quantizer.py
+++ b/XY_Tokenizer/xy_tokenizer/nn/quantizer.py
@@ -1,18 +1,21 @@
import logging
+
import torch
+import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
-import torch.distributed as dist
-
from einops import rearrange
from torch.nn.utils import weight_norm
+
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
+
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
def sample_vectors(samples, num):
# samples: (N, D), num_samples: N, feature dim: D
num_samples, device = samples.shape[0], samples.device
@@ -22,35 +25,49 @@ def sample_vectors(samples, num):
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices].float() # (num, D), ensure fp32
+
def kmeans(samples, num_clusters, num_iters=10):
# samples: (N, D), N samples with D dimensions
dim, dtype = samples.shape[-1], torch.float32 # Force fp32
- means = sample_vectors(samples, num_clusters).float() # (num_clusters, D), ensure fp32
-
+ means = sample_vectors(
+ samples, num_clusters
+ ).float() # (num_clusters, D), ensure fp32
+
for _ in range(num_iters):
- dists = -(samples.float().pow(2).sum(1, keepdim=True) - # (N, 1), ensure fp32
- 2 * samples.float() @ means.t() + # (N, num_clusters), ensure fp32
- means.t().float().pow(2).sum(0, keepdim=True)) # (1, num_clusters), ensure fp32
+ dists = -(
+ samples.float().pow(2).sum(1, keepdim=True) # (N, 1), ensure fp32
+ - 2 * samples.float() @ means.t() # (N, num_clusters), ensure fp32
+ + means.t().float().pow(2).sum(0, keepdim=True)
+ ) # (1, num_clusters), ensure fp32
# dists: (N, num_clusters)
buckets = dists.max(dim=-1).indices # (N)
bins = torch.bincount(buckets, minlength=num_clusters) # (num_clusters)
zero_mask = bins == 0 # (num_clusters)
bins_min_clamped = bins.masked_fill(zero_mask, 1) # (num_clusters)
-
- new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32) # (num_clusters, D), ensure fp32
- new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples.float()) # (num_clusters, D), ensure fp32
+
+ new_means = buckets.new_zeros(
+ num_clusters, dim, dtype=torch.float32
+ ) # (num_clusters, D), ensure fp32
+ new_means.scatter_add_(
+ 0, buckets.unsqueeze(1).expand(-1, dim), samples.float()
+ ) # (num_clusters, D), ensure fp32
new_means = new_means / bins_min_clamped[..., None] # (num_clusters, D)
means = torch.where(zero_mask[..., None], means, new_means) # (num_clusters, D)
-
+
# Final cluster assignments for returning cluster sizes
- dists = -(samples.float().pow(2).sum(1, keepdim=True) -
- 2 * samples.float() @ means.t() +
- means.t().float().pow(2).sum(0, keepdim=True)) # (N, num_clusters), ensure fp32
+ dists = -(
+ samples.float().pow(2).sum(1, keepdim=True)
+ - 2 * samples.float() @ means.t()
+ + means.t().float().pow(2).sum(0, keepdim=True)
+ ) # (N, num_clusters), ensure fp32
buckets = dists.max(dim=-1).indices # (N)
- bins = torch.bincount(buckets, minlength=num_clusters).float() # (num_clusters), ensure fp32
-
+ bins = torch.bincount(
+ buckets, minlength=num_clusters
+ ).float() # (num_clusters), ensure fp32
+
return means, bins # (num_clusters, D), (num_clusters)
+
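
For readers checking the dists expression above: it is the standard expansion of the squared Euclidean distance, negated so that argmax picks the nearest centroid. A small self-check on toy tensors (not project code):

import torch

x = torch.randn(16, 8)       # (N, D) samples
mu = torch.randn(4, 8)       # (num_clusters, D) centroids

# ||x - mu||^2 = ||x||^2 - 2 x.mu^T + ||mu||^2, computed without forming x - mu
expanded = (
    x.pow(2).sum(1, keepdim=True)         # (N, 1)
    - 2 * x @ mu.t()                      # (N, num_clusters)
    + mu.pow(2).sum(1, keepdim=True).t()  # (1, num_clusters)
)
assert torch.allclose(expanded, torch.cdist(x, mu).pow(2), atol=1e-3)
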
class VectorQuantize(nn.Module):
def __init__(
self,
@@ -58,11 +75,11 @@ def __init__(
codebook_size,
codebook_dim,
commitment=1.0,
- decay=0.99, # EMA decay
- epsilon=1e-5, # Laplace smoothing epsilon
- threshold_ema_dead=2, # Dead code threshold
- kmeans_init=True, # Use kmeans initialization
- kmeans_iters=10, # Kmeans iterations
+ decay=0.99, # EMA decay
+ epsilon=1e-5, # Laplace smoothing epsilon
+ threshold_ema_dead=2, # Dead code threshold
+ kmeans_init=True, # Use kmeans initialization
+ kmeans_iters=10, # Kmeans iterations
):
super().__init__()
self.input_dim = input_dim
@@ -74,20 +91,32 @@ def __init__(
self.threshold_ema_dead = threshold_ema_dead
self.kmeans_init = kmeans_init
self.kmeans_iters = kmeans_iters
-
+
if self.input_dim != self.codebook_dim:
- self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1) # (B, D, T) -> (B, D', T)
- self.out_project = WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1) # (B, D', T) -> (B, D, T)
+ self.in_project = WNConv1d(
+ self.input_dim, self.codebook_dim, kernel_size=1
+ ) # (B, D, T) -> (B, D', T)
+ self.out_project = WNConv1d(
+ self.codebook_dim, self.input_dim, kernel_size=1
+ ) # (B, D', T) -> (B, D, T)
else:
self.in_project = nn.Identity()
self.out_project = nn.Identity()
# Initialize codebook and EMA buffers
init_fn = torch.zeros if kmeans_init else lambda x, y: torch.randn(x, y)
- self.register_buffer("codebook", init_fn(codebook_size, codebook_dim).float()) # (codebook_size, D'), ensure fp32
- self.register_buffer("inited", torch.tensor([not kmeans_init], dtype=torch.bool)) # (1)
- self.register_buffer("cluster_size", torch.zeros(codebook_size).float()) # (codebook_size), ensure fp32
- self.register_buffer("embed_avg", self.codebook.clone().float()) # (codebook_size, D'), ensure fp32
+ self.register_buffer(
+ "codebook", init_fn(codebook_size, codebook_dim).float()
+ ) # (codebook_size, D'), ensure fp32
+ self.register_buffer(
+ "inited", torch.tensor([not kmeans_init], dtype=torch.bool)
+ ) # (1)
+ self.register_buffer(
+ "cluster_size", torch.zeros(codebook_size).float()
+ ) # (codebook_size), ensure fp32
+ self.register_buffer(
+ "embed_avg", self.codebook.clone().float()
+ ) # (codebook_size, D'), ensure fp32
def ema_update(self, encodings, embed_onehot):
# encodings: (B*T, D'), embed_onehot: (B*T, codebook_size)
@@ -96,56 +125,72 @@ def ema_update(self, encodings, embed_onehot):
embed_onehot = embed_onehot.float() # Ensure fp32
cluster_size_new = embed_onehot.sum(0) # (codebook_size)
embed_sum = encodings.t() @ embed_onehot # (D', codebook_size)
-
+
# Distributed reduction
if dist.is_initialized():
dist.all_reduce(cluster_size_new, op=dist.ReduceOp.SUM)
dist.all_reduce(embed_sum, op=dist.ReduceOp.SUM)
-
+
ema_inplace(self.cluster_size, cluster_size_new, self.decay) # (codebook_size)
ema_inplace(self.embed_avg, embed_sum.t(), self.decay) # (codebook_size, D')
-
+
# Laplace smoothing
- cluster_size = (self.cluster_size + self.epsilon) / (self.cluster_size.sum() + self.codebook_size * self.epsilon) # (codebook_size)
+ cluster_size = (self.cluster_size + self.epsilon) / (
+ self.cluster_size.sum() + self.codebook_size * self.epsilon
+ ) # (codebook_size)
cluster_size = cluster_size * self.cluster_size.sum() # (codebook_size)
- self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1)) # (codebook_size, D')
+ self.codebook.copy_(
+ self.embed_avg / cluster_size.unsqueeze(1)
+ ) # (codebook_size, D')
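
The EMA arithmetic being re-wrapped here can be sanity-checked in isolation; the toy step below mirrors it (decay and epsilon values are illustrative only, not taken from a real config):

import torch

codebook_size, dim, decay, epsilon = 4, 3, 0.99, 1e-5
cluster_size = torch.zeros(codebook_size)
embed_avg = torch.randn(codebook_size, dim)

counts_new = torch.tensor([5.0, 0.0, 2.0, 1.0])  # assignments observed in this batch
embed_sum = torch.randn(codebook_size, dim)      # sum of encodings assigned to each code

cluster_size.mul_(decay).add_(counts_new, alpha=1 - decay)   # EMA of usage counts
embed_avg.mul_(decay).add_(embed_sum, alpha=1 - decay)       # EMA of per-code sums

# Laplace smoothing keeps rarely used codes from dividing by (near) zero
smoothed = (cluster_size + epsilon) / (cluster_size.sum() + codebook_size * epsilon)
smoothed = smoothed * cluster_size.sum()
codebook = embed_avg / smoothed.unsqueeze(1)                 # (codebook_size, dim)
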
def replace_dead_codes(self, encodings):
# encodings: (B*T, D')
"""Replace dead codes with random samples from current batch"""
if self.threshold_ema_dead == 0:
return
-
+
dead_mask = self.cluster_size < self.threshold_ema_dead # (codebook_size)
if dead_mask.any():
if dist.is_initialized() and dist.get_rank() == 0:
- samples = sample_vectors(encodings.float(), self.codebook_size) # (codebook_size, D'), ensure fp32
+ samples = sample_vectors(
+ encodings.float(), self.codebook_size
+ ) # (codebook_size, D'), ensure fp32
else:
- samples = torch.zeros_like(self.codebook).float() # Placeholder, ensure fp32
-
+ samples = torch.zeros_like(
+ self.codebook
+ ).float() # Placeholder, ensure fp32
+
# Broadcast samples
if dist.is_initialized():
dist.broadcast(samples, src=0)
-
- self.codebook[dead_mask] = samples[:dead_mask.sum()].to(self.codebook.dtype) # Update dead codes
+
+ self.codebook[dead_mask] = samples[: dead_mask.sum()].to(
+ self.codebook.dtype
+ ) # Update dead codes
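
The dead-code refresh above reduces to the following in the single-process case (toy tensors; the rank-0 sampling and broadcast only matter under torch.distributed):

import torch

codebook = torch.randn(8, 4)                                # (codebook_size, D')
cluster_size = torch.tensor([10., 0., 3., 0., 7., 1., 0., 9.])
threshold_ema_dead = 2
encodings = torch.randn(100, 4)                             # current batch, flattened to (B*T, D')

dead = cluster_size < threshold_ema_dead                    # codes that are (almost) never used
samples = encodings[torch.randint(0, encodings.shape[0], (codebook.shape[0],))]
codebook[dead] = samples[: dead.sum()]                      # refresh dead codes with batch samples
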
def init_codebook(self, encodings):
# encodings: (B*T, D')
"""Initialize codebook with k-means and update cluster_size"""
if self.inited.item():
return
-
+
if dist.is_initialized() and dist.get_rank() == 0:
- embed, cluster_sizes = kmeans(encodings.float(), self.codebook_size, self.kmeans_iters) # (codebook_size, D'), (codebook_size), ensure fp32
+ embed, cluster_sizes = kmeans(
+ encodings.float(), self.codebook_size, self.kmeans_iters
+ ) # (codebook_size, D'), (codebook_size), ensure fp32
else:
- embed = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device).float() # ensure fp32
- cluster_sizes = torch.zeros(self.codebook_size, device=encodings.device, dtype=torch.float32) # ensure fp32
-
+ embed = torch.zeros(
+ self.codebook_size, self.codebook_dim, device=encodings.device
+ ).float() # ensure fp32
+ cluster_sizes = torch.zeros(
+ self.codebook_size, device=encodings.device, dtype=torch.float32
+ ) # ensure fp32
+
# Broadcast results
if dist.is_initialized():
dist.broadcast(embed, src=0)
dist.broadcast(cluster_sizes, src=0)
-
+
self.codebook.copy_(embed) # (codebook_size, D')
self.embed_avg.copy_(embed.clone()) # (codebook_size, D')
self.cluster_size.copy_(cluster_sizes.float()) # (codebook_size)
@@ -155,32 +200,39 @@ def forward(self, z): # z: (B, D, T)
# logging.info(f"{self.cluster_size = }, {self.codebook = }, {self.embed_avg = }, {self.inited = }")
z = z.float() # Ensure fp32
z_e = self.in_project(z).float() # (B, D', T), ensure fp32
-
+
# Rearrange for quantization
encodings = rearrange(z_e, "b d t -> (b t) d").float() # (B*T, D'), ensure fp32
-
+
# Initialize codebook if needed
if self.kmeans_init and not self.inited.item():
self.init_codebook(encodings)
# Quantization
- dist = (encodings.pow(2).sum(1, keepdim=True) - # (B*T, 1)
- 2 * encodings @ self.codebook.float().t() + # (B*T, codebook_size)
- self.codebook.float().pow(2).sum(1, keepdim=True).t()) # (1, codebook_size)
+ dist = (
+ encodings.pow(2).sum(1, keepdim=True) # (B*T, 1)
+ - 2 * encodings @ self.codebook.float().t() # (B*T, codebook_size)
+ + self.codebook.float().pow(2).sum(1, keepdim=True).t()
+ ) # (1, codebook_size)
# dist: (B*T, codebook_size)
-
+
indices = (-dist).max(1)[1] # (B*T)
indices = rearrange(indices, "(b t) -> b t", b=z.size(0)) # (B, T)
-
+
# Get quantized vectors
z_q = self.decode_code(indices).float() # (B, D', T), ensure fp32
-
+
# Commitment loss
- commit_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) * self.commitment # (B)
-
+ commit_loss = (
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+ * self.commitment
+ ) # (B)
+
# EMA updates and dead code replacement during training
if self.training and torch.is_grad_enabled():
- embed_onehot = F.one_hot(indices.view(-1), self.codebook_size).float() # (B*T, codebook_size), ensure fp32
+ embed_onehot = F.one_hot(
+ indices.view(-1), self.codebook_size
+ ).float() # (B*T, codebook_size), ensure fp32
self.ema_update(encodings, embed_onehot)
self.replace_dead_codes(encodings)
@@ -188,20 +240,29 @@ def forward(self, z): # z: (B, D, T)
z_q = z_e + (z_q - z_e).detach() # (B, D', T)
z_q = self.out_project(z_q).float() # (B, D, T), ensure fp32
- return z_q, commit_loss, torch.tensor(0.0, device=z.device, dtype=torch.float32), indices, z # (B, D, T), (B), scalar, (B, T), (B, D', T)
+ return (
+ z_q,
+ commit_loss,
+ torch.tensor(0.0, device=z.device, dtype=torch.float32),
+ indices,
+ z,
+ ) # (B, D, T), (B), scalar, (B, T), (B, D', T)
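
The line z_q = z_e + (z_q - z_e).detach() kept a few lines above this return is the straight-through estimator; a toy check that gradients bypass the quantization step (torch.round stands in for the codebook lookup):

import torch

z_e = torch.randn(2, 8, 5, requires_grad=True)  # encoder output (B, D', T)
z_q = torch.round(z_e)                          # stand-in for the non-differentiable lookup

z_q_st = z_e + (z_q - z_e).detach()             # forward value is z_q, backward pass is identity
z_q_st.sum().backward()
assert torch.allclose(z_e.grad, torch.ones_like(z_e))  # gradient flows straight into z_e
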
def decode_code(self, embed_id): # embed_id: (B, T)
- return F.embedding(embed_id, self.codebook).transpose(1, 2).float() # (B, D', T), ensure fp32
+ return (
+ F.embedding(embed_id, self.codebook).transpose(1, 2).float()
+ ) # (B, D', T), ensure fp32
+
class ResidualVQ(nn.Module):
def __init__(
self,
- input_dim: int = 1280, # Input dimension, unrelated to RVQ
- rvq_dim = None, # RVQ dimension. If different from input_dim/output_dim, will add input_dim->rvq_dim/rvq_dim->output_dim projection
- output_dim: int = None, # Output dimension, unrelated to RVQ
+ input_dim: int = 1280, # Input dimension, unrelated to RVQ
+ rvq_dim=None, # RVQ dimension. If different from input_dim/output_dim, will add input_dim->rvq_dim/rvq_dim->output_dim projection
+ output_dim: int = None, # Output dimension, unrelated to RVQ
num_quantizers: int = 32,
codebook_size: int = 1024,
- codebook_dim: int = 8, # Dimension of each codebook. If different from rvq_dim, will add rvq_dim->codebook_dim and codebook_dim->rvq_dim projections
+ codebook_dim: int = 8, # Dimension of each codebook. If different from rvq_dim, will add rvq_dim->codebook_dim and codebook_dim->rvq_dim projections
quantizer_dropout: float = 0.5,
decay=0.99,
epsilon=1e-5,
@@ -213,16 +274,24 @@ def __init__(
):
super().__init__()
self.input_dim = input_dim
-
+
self.num_quantizers = num_quantizers
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.quantizer_dropout = quantizer_dropout
self.skip_rvq_ratio = skip_rvq_ratio # Store skip probability
self.rvq_dim = rvq_dim
-
- self.input_proj = WNConv1d(input_dim, rvq_dim, kernel_size=1) if input_dim != rvq_dim else nn.Identity()
- self.output_proj = WNConv1d(rvq_dim, output_dim, kernel_size=1) if rvq_dim != output_dim else nn.Identity()
+
+ self.input_proj = (
+ WNConv1d(input_dim, rvq_dim, kernel_size=1)
+ if input_dim != rvq_dim
+ else nn.Identity()
+ )
+ self.output_proj = (
+ WNConv1d(rvq_dim, output_dim, kernel_size=1)
+ if rvq_dim != output_dim
+ else nn.Identity()
+ )
self.quantizers = nn.ModuleList(
[
@@ -241,14 +310,22 @@ def __init__(
]
)
- def forward(self, z, input_length, n_quantizers: int = None): # z: (B, D, T), input_length: (B)
+ def forward(
+ self, z, input_length, n_quantizers: int = None
+ ): # z: (B, D, T), input_length: (B)
z = self.input_proj(z)
- with torch.autocast('cuda', enabled = False):
+ with torch.autocast("cuda", enabled=False):
batch_size, _, max_time = z.shape
- mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1) # (B, T)
-
- quantized_out = torch.zeros_like(z, dtype=torch.float32) # (B, D, T), ensure fp32
+ mask = torch.arange(max_time, device=z.device).expand(
+ batch_size, max_time
+ ) < input_length.unsqueeze(
+ 1
+ ) # (B, T)
+
+ quantized_out = torch.zeros_like(
+ z, dtype=torch.float32
+ ) # (B, D, T), ensure fp32
residual = z.clone().float() # (B, D, T), ensure fp32
all_commit_losses = []
@@ -261,36 +338,66 @@ def forward(self, z, input_length, n_quantizers: int = None): # z: (B, D, T), i
skip_mask = None
if self.training and torch.is_grad_enabled() and self.skip_rvq_ratio > 0:
# Generate random mask with skip_rvq_ratio probability
- skip_mask = torch.rand(batch_size, device=z.device) < self.skip_rvq_ratio # (B,)
+ skip_mask = (
+ torch.rand(batch_size, device=z.device) < self.skip_rvq_ratio
+ ) # (B,)
# If all samples are skipped, force the first sample to be unskipped
if skip_mask.all():
- skip_mask[0] = False # Ensure at least one sample (index 0) is not skipped
+ skip_mask[0] = (
+ False # Ensure at least one sample (index 0) is not skipped
+ )
if self.training and torch.is_grad_enabled():
- n_quantizers_tensor = torch.ones((z.shape[0],), dtype=torch.float32, device=z.device) * self.num_quantizers + 1 # (B)
- dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],), dtype=torch.float32, device=z.device) # (B)
+ n_quantizers_tensor = (
+ torch.ones((z.shape[0],), dtype=torch.float32, device=z.device)
+ * self.num_quantizers
+ + 1
+ ) # (B)
+ dropout = torch.randint(
+ 1,
+ self.num_quantizers + 1,
+ (z.shape[0],),
+ dtype=torch.float32,
+ device=z.device,
+ ) # (B)
n_dropout = int(z.shape[0] * self.quantizer_dropout)
n_quantizers_tensor[:n_dropout] = dropout[:n_dropout] # (B)
else:
- n_quantizers_tensor = torch.full((z.shape[0],), n_quantizers, dtype=torch.float32, device=z.device) # (B)
+ n_quantizers_tensor = torch.full(
+ (z.shape[0],), n_quantizers, dtype=torch.float32, device=z.device
+ ) # (B)
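
The n_quantizers_tensor construction above implements quantizer dropout; a compact sketch of the schedule it produces (values are made up for illustration):

import torch

batch_size, num_quantizers, quantizer_dropout = 8, 4, 0.5
# By default every sample uses all quantizers (num_quantizers + 1 so the "i < n" test always passes)
n_active = torch.full((batch_size,), float(num_quantizers + 1))
dropout = torch.randint(1, num_quantizers + 1, (batch_size,), dtype=torch.float32)
n_dropout = int(batch_size * quantizer_dropout)
n_active[:n_dropout] = dropout[:n_dropout]
# Quantizer i later updates sample b only where i < n_active[b]
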
for i, quantizer in enumerate(self.quantizers):
if not self.training and i >= n_quantizers:
break
masked_residual = residual * mask.unsqueeze(1) # (B, D, T)
-
+
# If RVQ is skipped for a sample, pass the input value through unchanged
if self.training and skip_mask is not None and skip_mask.any():
- z_q_i = torch.zeros_like(masked_residual, dtype=torch.float32) # (B, D, T), ensure fp32
- commit_loss_i = torch.zeros(batch_size, device=z.device, dtype=torch.float32) # (B), ensure fp32
- indices_i = torch.zeros(batch_size, max_time, device=z.device, dtype=torch.long) # (B, T)
- z_e_i = torch.zeros_like(masked_residual, dtype=torch.float32) # (B, D, T), ensure fp32
-
+ z_q_i = torch.zeros_like(
+ masked_residual, dtype=torch.float32
+ ) # (B, D, T), ensure fp32
+ commit_loss_i = torch.zeros(
+ batch_size, device=z.device, dtype=torch.float32
+ ) # (B), ensure fp32
+ indices_i = torch.zeros(
+ batch_size, max_time, device=z.device, dtype=torch.long
+ ) # (B, T)
+ z_e_i = torch.zeros_like(
+ masked_residual, dtype=torch.float32
+ ) # (B, D, T), ensure fp32
+
# Quantize non-skipped samples
non_skipped_mask = ~skip_mask # (B)
if non_skipped_mask.any():
- z_q_i_non_skipped, commit_loss_i_non_skipped, _, indices_i_non_skipped, z_e_i_non_skipped = quantizer(
+ (
+ z_q_i_non_skipped,
+ commit_loss_i_non_skipped,
+ _,
+ indices_i_non_skipped,
+ z_e_i_non_skipped,
+ ) = quantizer(
masked_residual[non_skipped_mask].float() # Ensure fp32
)
z_q_i[non_skipped_mask] = z_q_i_non_skipped
@@ -298,29 +405,48 @@ def forward(self, z, input_length, n_quantizers: int = None): # z: (B, D, T), i
indices_i[non_skipped_mask] = indices_i_non_skipped
z_e_i[non_skipped_mask] = z_e_i_non_skipped
else:
- z_q_i, commit_loss_i, _, indices_i, z_e_i = quantizer(masked_residual.float()) # (B, D, T), (B), scalar, (B, T), (B, D', T), ensure fp32
+ z_q_i, commit_loss_i, _, indices_i, z_e_i = quantizer(
+ masked_residual.float()
+ ) # (B, D, T), (B), scalar, (B, T), (B, D', T), ensure fp32
+
+ quantizer_mask = (
+ torch.full((z.shape[0],), i, device=z.device, dtype=torch.float32)
+ < n_quantizers_tensor
+ ) # (B)
+ update_mask = (mask & quantizer_mask.unsqueeze(-1)).unsqueeze(
+ 1
+ ) # (B, 1, T)
- quantizer_mask = (torch.full((z.shape[0],), i, device=z.device, dtype=torch.float32) < n_quantizers_tensor) # (B)
- update_mask = (mask & quantizer_mask.unsqueeze(-1)).unsqueeze(1) # (B, 1, T)
-
# For skipped samples, the output is simply the input
if skip_mask is not None:
- skip_mask_expanded = skip_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1)
- z_q_i = torch.where(skip_mask_expanded, masked_residual, z_q_i) # (B, D, T)
- commit_loss_i = torch.where(skip_mask, torch.zeros_like(commit_loss_i), commit_loss_i) # (B)
+ skip_mask_expanded = skip_mask.unsqueeze(1).unsqueeze(
+ 2
+ ) # (B, 1, 1)
+ z_q_i = torch.where(
+ skip_mask_expanded, masked_residual, z_q_i
+ ) # (B, D, T)
+ commit_loss_i = torch.where(
+ skip_mask, torch.zeros_like(commit_loss_i), commit_loss_i
+ ) # (B)
quantized_out = quantized_out + z_q_i * update_mask # (B, D, T)
residual_fp32 = residual.to(dtype=torch.float32) # (B, D, T)
z_q_i_fp32 = z_q_i.to(dtype=torch.float32) # (B, D, T)
residual_fp32 = residual_fp32 - z_q_i_fp32 * update_mask # (B, D, T)
- residual = residual_fp32.to(dtype=torch.float32) # (B, D, T), ensure fp32
+ residual = residual_fp32.to(
+ dtype=torch.float32
+ ) # (B, D, T), ensure fp32
valid_mask = mask & quantizer_mask.unsqueeze(-1) # (B, T)
if valid_mask.any():
- commit_loss_i = (commit_loss_i * quantizer_mask).sum() / quantizer_mask.sum() # scalar
+ commit_loss_i = (
+ commit_loss_i * quantizer_mask
+ ).sum() / quantizer_mask.sum() # scalar
else:
- commit_loss_i = torch.tensor(0.0, device=z.device, dtype=torch.float32) # scalar, ensure fp32
+ commit_loss_i = torch.tensor(
+ 0.0, device=z.device, dtype=torch.float32
+ ) # scalar, ensure fp32
all_commit_losses.append(commit_loss_i) # scalar
all_indices.append(indices_i) # (B, T)
@@ -335,16 +461,16 @@ def forward(self, z, input_length, n_quantizers: int = None): # z: (B, D, T), i
quantized_out = self.output_proj(quantized_out)
return (
- quantized_out, # (B, D, T)
- all_indices, # (N, B, T)
- all_commit_losses,# (N)
- all_quantized, # (N, B, D, T)
- output_length, # (B)
+ quantized_out, # (B, D, T)
+ all_indices, # (N, B, T)
+ all_commit_losses, # (N)
+ all_quantized, # (N, B, D, T)
+ output_length, # (B)
)
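
Underneath the re-wrapping, the forward pass above is the usual greedy residual-VQ loop: quantize the residual, accumulate the output, subtract, repeat. A toy version with a nearest-neighbour lookup standing in for the trained VectorQuantize layers:

import torch

def nearest(x, codebook):  # x: (B, D, T), codebook: (K, D)
    flat = x.permute(0, 2, 1).reshape(-1, x.shape[1])   # (B*T, D)
    idx = torch.cdist(flat, codebook).argmin(dim=1)     # (B*T,)
    return codebook[idx].reshape(x.shape[0], x.shape[2], -1).permute(0, 2, 1)

codebooks = [torch.randn(16, 8) for _ in range(4)]      # toy per-stage codebooks
z = torch.randn(2, 8, 5)                                # (B, D, T)

quantized_out, residual = torch.zeros_like(z), z.clone()
for cb in codebooks:
    z_q_i = nearest(residual, cb)
    quantized_out = quantized_out + z_q_i               # accumulate per-stage outputs
    residual = residual - z_q_i                         # later stages refine the remaining error
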
def decode_codes(self, codes): # codes: (nq, B, T)
"""Decode codes from multiple quantizers to embeddings.
-
+
Args:
codes: Tensor of shape (nq, B, T) containing code indices for each quantizer.
@@ -353,7 +479,9 @@ def decode_codes(self, codes): # codes: (nq, B, T)
"""
nq, B, T = codes.shape
device = codes.device
- emb = torch.zeros(B, self.rvq_dim, T, device=device, dtype=torch.float32) # (B, D, T)
+ emb = torch.zeros(
+ B, self.rvq_dim, T, device=device, dtype=torch.float32
+ ) # (B, D, T)
for i, quantizer in enumerate(self.quantizers[:nq]):
code_i = codes[i] # (B, T)
@@ -361,10 +489,10 @@ def decode_codes(self, codes): # codes: (nq, B, T)
emb += quantized_i # Accumulate quantized embeddings
emb = self.output_proj(emb) # (B, D, T), apply output projection
- return emb # (B, D, T)
+ return emb # (B, D, T)
def ema_inplace(moving_avg, new, decay):
# moving_avg: (codebook_size) or (codebook_size, D'), new: same as moving_avg
"""Update exponential moving average in-place"""
- moving_avg.data.mul_(decay).add_(new.float(), alpha=(1 - decay)) # ensure fp32
\ No newline at end of file
+ moving_avg.data.mul_(decay).add_(new.float(), alpha=(1 - decay)) # ensure fp32
diff --git a/examples/example.txt b/examples/example.txt
index 0d4b3bd..db4779f 100644
--- a/examples/example.txt
+++ b/examples/example.txt
@@ -1 +1 @@
-MOSS-TTSD (text to spoken dialogue) is an open-source bilingual spoken dialogue synthesis model that supports both Chinese and English. It can transform dialogue scripts between two speakers into natural, expressive conversational speech. MOSS-TTSD supports voice cloning and single-session speech generation of up to 960 seconds, making it ideal for AI podcast production.
\ No newline at end of file
+MOSS-TTSD (text to spoken dialogue) is an open-source bilingual spoken dialogue synthesis model that supports both Chinese and English. It can transform dialogue scripts between two speakers into natural, expressive conversational speech. MOSS-TTSD supports voice cloning and single-session speech generation of up to 960 seconds, making it ideal for AI podcast production.
diff --git a/examples/examples.jsonl b/examples/examples.jsonl
index c22a263..4a7b346 100644
--- a/examples/examples.jsonl
+++ b/examples/examples.jsonl
@@ -1,2 +1,2 @@
{"base_path":"examples","text": "[S1]诶,我最近看了一篇讲人工智能的文章,还挺有意思的,想跟你聊聊。[S2]哦?是吗,关于啥的啊?又是哪个公司发了什么逆天的新模型吗?[S1]那倒不是,是一个咱们国内的教授,复旦大学的邱锡鹏教授,他提了一个新概念,叫什么,呃,叫情境扩展,Context Scaling。[S2]Context Scaling?情境扩展?听起来有点,呃,有点玄乎啊,这是个啥意思?[S1]对,我一开始也觉得有点抽象,但你看完就觉得,诶,特别有道理。他大概意思就是说啊,咱们现在对人工智能的追求,不能光是把它做得更大,你知道吧,就是不能光堆参数,喂数据。[S2]嗯,是,这个我懂。就好像之前大家都在比谁的模型参数多,几千亿,上万亿的。[S1]对对对,就是那个意思。他说那个时代,算是第一幕,就是模型规模化的胜利,靠堆料,堆出了像这个ChatGPT这样厉害的通用模型。[S2]嗯,是的。[S1]然后呢,现在差不多是第二幕,就是大家发现光堆料好像不行了,收益越来越小,就开始搞一些,呃,后训练的优化。[S2]哦,后训练优化?比如呢?[S1]比如让AI学会用工具,或者搞那个什么思维链,让它像人一样一步一步思考问题,还有就是强化学习,让它自己在游戏里或者写代码的时候自己跟自己玩,然后越变越强。[S2]哦,明白了。就是让它不光是知识多,还得会用,变得更聪明,是这个意思吧。[S1]没错,就是这个理儿。但是呢,这两步走完了,就遇到了新的瓶颈。邱教授就觉得,关键问题出在了这个情境,也就是Context上。[S2]嗯?情境?[S1]是。很多时候AI做不出正确的决定,不是因为它笨,而是因为它没搞明白现在到底是个什么情况,就是我们给它的任务或者情境描述得不够清楚。[S2]哦,原来是这样。[S1]对。所以他觉得下一幕,也就是第三幕,就应该是这个情境扩展,Context Scaling。核心就是要让AI能真正理解那种特别复杂、多变,甚至有点模糊的真实世界的情境。[S2]嗯,这个听起来就高级了。那他说的这个情境,到底指什么啊?不就是我们输进去的那段话吗?[S1]诶,不是那么简单的。他说的情境啊,是个特别立体的东西。它不光包括你说了什么,还包括,呃,时间,地点,谁在说话,目的是什么,甚至还包括咱们文化里那种只可意会不可言传的规则,还有人与人之间的那种默契。[S2]哇,那这个范围可太广了。[S1]是吧。他举了个例子,就比如说,当一个人说不要的时候。[S2]嗯。[S1]你想想,在不同的情况下,这个不要的意思完全不一样。有可能是真的拒绝,也可能是在开玩笑,甚至可能是一种反向的请求,就是我们说的口是心非,嘴上说不要,身体很诚实那种。[S2]哈哈,确实确实,这个AI可怎么判断啊。[S1]对啊,所以就需要理解整个情境。他说这个里面最关键的,就是要捕获一种叫暗知识的东西。[S2]暗知识?听着像武林秘籍一样。[S1]有点那个意思。就是指我们人类会,但是很难用语言清楚地讲出来的那些能力。[S2]哦?比如说?[S1]比如说社交的智慧啊,怎么通过一个眼神,一个停顿,或者语气变化来理解对方的意思。[S2]嗯,是的。[S1]还有就是文化适应能力,在不同的文化里,有些事能做,有些事不能做,这些都没人写在书上,但我们就是知道。[S2]没错。[S1]他说AI要是能学会这些,那才算是真正的智能突破了。这也能解决一些AI安全的问题,比如那个很有名的回形针悖论。[S2]哦,那个我知道,就是你让AI造回形针,它最后可能会为了这个目标把整个世界都给占了,用来造回形针。[S1]对。但如果它有情境智能,它就能理解我们人类社会的复杂性,知道有些事是不能做的,就算你没有明确下指令禁止它。[S2]有道理,有道理。那,那技术上要怎么实现呢?听起来好难啊。[S1]是挺难的。他提了三大技术支柱。第一个叫强交互性。就是AI要不停地跟环境,特别是跟人来互动学习,不光要知道怎么做,还要知道为什么这么做。[S2]嗯,在互动里学习。[S1]第二个叫具身性。就是AI得有个身体,不一定是真的人形机器人,在虚拟世界里也行,总之它得能感知和行动。[S2]哦,明白。[S1]第三个,我觉得是最特别的,叫拟人化。[S2]拟人化?[S1]对,就是说AI需要有类似人类的情感共鸣能力。不是假装有感情,而是真正理解人的情绪,人的偏好,懂得社交的距离感,什么时候该关心,什么时候该保持沉默。[S2]哇,这个要求可太高了,感觉比养个孩子还难。[S1]哈哈,可不是嘛。所以他说这事儿吧,不是要去替代其他的技术路线,而是把它们都整合起来。像什么推理增强啊,多模态啊,强化学习啊,最终都是为了服务于一个目标,就是让AI能深刻地理解情境。[S2]嗯。[S1]所以总的来说,就是别光在已有的路上卷了,要去定义一些大家都觉得重要,但没人说清楚的问题。[S2]确实。听你这么一说,感觉这个情境扩展,Context Scaling,确实给人工智能指了一个挺不一样的方向。不再是冷冰冰的计算,而是要变得更懂我们人类,更懂这个复杂的世界。[S1]是啊,他说这可能是通往通用人工智能,就是AGI,非常关键的一步。[S2]嗯,听起来未来可期啊。","prompt_audio_speaker1":"zh_spk1_moon.wav","prompt_text_speaker1": "周一到周五,每天早晨七点半到九点半的直播片段。言下之意呢,就是废话有点多,大家也别嫌弃,因为这都是直播间最真实的状态了。","prompt_audio_speaker2":"zh_spk2_moon.wav","prompt_text_speaker2": "如果大家想听到更丰富更及时的直播内容,记得在周一到周五准时进入直播间,和大家一起畅聊新消费新科技新趋势。"}
-{"base_path":"examples","text": "[S1]Hey, did you hear about that company called MoSi AI? [S2]MoSi AI? Yeah, I think I've heard of them. Aren't they the ones doing AI stuff? What new thing have they come up with now? [S1]Yeah, that's them! They recently launched this super hot new product called, um, Asteroid. [S2]Asteroid. That's a pretty cool name. Does it mean like the space rock? [S1]Yeah, I think that's what it means. Let me tell you, this thing is incredible. They say it's currently the most realistic, human-like conversational TTS model out there. [S2]Oh, TTS technology? You mean the text-to-speech thing? Aren't there already a lot of those on the market? What makes this one so special? [S1]Well, it's completely different. They say the voice produced by Asteroid sounds almost exactly like a real person talking. And it's super smooth and natural. Not at all like, you know, that stiff robotic tone. [S2]I see. Some voice assistants do still have that mechanical feel, especially during multi-turn conversations. So how amazing is this Asteroid exactly? [S1]I heard they internally call Asteroid China's own version of NotebookLM. [S2]NotebookLM? Oh, I know that one. Isn't that the personal AI that Google made? The one that helps organize notes and answers all kinds of questions? So Asteroid has similar functions? [S1]Right. That's probably what they mean. It's not just that the voice sounds incredibly human. The intelligence level is also really high. It can have these really logical, contextual, in-depth conversations with you. It's just like chatting with a real person. [S2]Wow, that sounds amazing. If they can really achieve that... [S1]Yeah, it's basically like having a personal assistant that's both articulate and really understands you. [S2]Hmm. That does sound appealing. [S1]And some people are saying it's like the, what's it called again in the voice technology circle? Oh right, DeepSeek. [S2]DeepSeek? Isn't that the company making large language models? Their models are pretty popular now. That's high praise. So they're saying Asteroid is top-tier technology? [S1]Yeah, I think that's what they mean. It's like they've reached a whole new level in voice synthesis. Similar to the impact DeepSeek has had in natural language processing. It could be that kind of groundbreaking technology. [S2]If Asteroid is really that impressive, where could it be used? I feel like there must be huge potential there. [S1]Absolutely. Just imagine future smart customer service, audiobook reading, and those virtual livestreamers that are so popular now. The quality would improve dramatically. We might even have personal assistants using Asteroid to talk to us directly. How natural would that be? [S2]Yeah. That does sound exciting. When can we actually try it out? Are there any demos available? [S1]I haven't looked into that carefully yet. But since they've already announced it, I'm guessing it won't be long. I'm really eager to try it and see just how human-like it is. [S2]Yeah, yeah. If it can really deliver what they're promising, getting information and interacting with machines will be so much more convenient. The experience will be much better too. [S1]Exactly, exactly. We're just waiting for MoSi AI to give us this big surprise.", "prompt_audio_speaker1":"m1.wav","prompt_text_speaker1": "How much do you know about her?","prompt_audio_speaker2":"m2.wav","prompt_text_speaker2": "Well, we know this much about her. You've been with her constantly since the first day you met her. 
And we followed you while you went dining, dancing, and sailing. And last night, I happened to be there when you were having dinner with her at Le Petit Tableau."}
\ No newline at end of file
+{"base_path":"examples","text": "[S1]Hey, did you hear about that company called MoSi AI? [S2]MoSi AI? Yeah, I think I've heard of them. Aren't they the ones doing AI stuff? What new thing have they come up with now? [S1]Yeah, that's them! They recently launched this super hot new product called, um, Asteroid. [S2]Asteroid. That's a pretty cool name. Does it mean like the space rock? [S1]Yeah, I think that's what it means. Let me tell you, this thing is incredible. They say it's currently the most realistic, human-like conversational TTS model out there. [S2]Oh, TTS technology? You mean the text-to-speech thing? Aren't there already a lot of those on the market? What makes this one so special? [S1]Well, it's completely different. They say the voice produced by Asteroid sounds almost exactly like a real person talking. And it's super smooth and natural. Not at all like, you know, that stiff robotic tone. [S2]I see. Some voice assistants do still have that mechanical feel, especially during multi-turn conversations. So how amazing is this Asteroid exactly? [S1]I heard they internally call Asteroid China's own version of NotebookLM. [S2]NotebookLM? Oh, I know that one. Isn't that the personal AI that Google made? The one that helps organize notes and answers all kinds of questions? So Asteroid has similar functions? [S1]Right. That's probably what they mean. It's not just that the voice sounds incredibly human. The intelligence level is also really high. It can have these really logical, contextual, in-depth conversations with you. It's just like chatting with a real person. [S2]Wow, that sounds amazing. If they can really achieve that... [S1]Yeah, it's basically like having a personal assistant that's both articulate and really understands you. [S2]Hmm. That does sound appealing. [S1]And some people are saying it's like the, what's it called again in the voice technology circle? Oh right, DeepSeek. [S2]DeepSeek? Isn't that the company making large language models? Their models are pretty popular now. That's high praise. So they're saying Asteroid is top-tier technology? [S1]Yeah, I think that's what they mean. It's like they've reached a whole new level in voice synthesis. Similar to the impact DeepSeek has had in natural language processing. It could be that kind of groundbreaking technology. [S2]If Asteroid is really that impressive, where could it be used? I feel like there must be huge potential there. [S1]Absolutely. Just imagine future smart customer service, audiobook reading, and those virtual livestreamers that are so popular now. The quality would improve dramatically. We might even have personal assistants using Asteroid to talk to us directly. How natural would that be? [S2]Yeah. That does sound exciting. When can we actually try it out? Are there any demos available? [S1]I haven't looked into that carefully yet. But since they've already announced it, I'm guessing it won't be long. I'm really eager to try it and see just how human-like it is. [S2]Yeah, yeah. If it can really deliver what they're promising, getting information and interacting with machines will be so much more convenient. The experience will be much better too. [S1]Exactly, exactly. We're just waiting for MoSi AI to give us this big surprise.", "prompt_audio_speaker1":"m1.wav","prompt_text_speaker1": "How much do you know about her?","prompt_audio_speaker2":"m2.wav","prompt_text_speaker2": "Well, we know this much about her. You've been with her constantly since the first day you met her. 
And we followed you while you went dining, dancing, and sailing. And last night, I happened to be there when you were having dinner with her at Le Petit Tableau."}
diff --git a/examples/examples_only_text.jsonl b/examples/examples_only_text.jsonl
index 9255952..1758785 100644
--- a/examples/examples_only_text.jsonl
+++ b/examples/examples_only_text.jsonl
@@ -1 +1 @@
-{"text": "[S1]Hey, did you hear about that company called MoSi AI? [S2]MoSi AI? Yeah, I think I've heard of them. Aren't they the ones doing AI stuff? What new thing have they come up with now? [S1]Yeah, that's them! They recently launched this super hot new product called, um, Asteroid. [S2]Asteroid. That's a pretty cool name. Does it mean like the space rock? [S1]Yeah, I think that's what it means. Let me tell you, this thing is incredible. They say it's currently the most realistic, human-like conversational TTS model out there. [S2]Oh, TTS technology? You mean the text-to-speech thing? Aren't there already a lot of those on the market? What makes this one so special? [S1]Well, it's completely different. They say the voice produced by Asteroid sounds almost exactly like a real person talking. And it's super smooth and natural. Not at all like, you know, that stiff robotic tone."}
\ No newline at end of file
+{"text": "[S1]Hey, did you hear about that company called MoSi AI? [S2]MoSi AI? Yeah, I think I've heard of them. Aren't they the ones doing AI stuff? What new thing have they come up with now? [S1]Yeah, that's them! They recently launched this super hot new product called, um, Asteroid. [S2]Asteroid. That's a pretty cool name. Does it mean like the space rock? [S1]Yeah, I think that's what it means. Let me tell you, this thing is incredible. They say it's currently the most realistic, human-like conversational TTS model out there. [S2]Oh, TTS technology? You mean the text-to-speech thing? Aren't there already a lot of those on the market? What makes this one so special? [S1]Well, it's completely different. They say the voice produced by Asteroid sounds almost exactly like a real person talking. And it's super smooth and natural. Not at all like, you know, that stiff robotic tone."}
diff --git a/examples/examples_single_reference.jsonl b/examples/examples_single_reference.jsonl
index 6f7fe10..64a3ff7 100644
--- a/examples/examples_single_reference.jsonl
+++ b/examples/examples_single_reference.jsonl
@@ -1 +1 @@
-{"base_path":"examples","text": "[S1]诶,跟你说个事儿啊,我最近听了不少那种AI生成的播客,不知道你有没有听过。[S2]哦,听过一些。怎么了,感觉怎么样?[S1]就是……怎么说呢,单听一句话,你觉得,哇,好像跟真人没啥区别。[S2]嗯。[S1]但是,你只要让它说上一段完整的对话,比如俩人聊天那种,那个感觉就立马不对了。[S2]对对对,我懂你的意思。就是那个所谓的“恐怖谷”效应,是吧?听着有点瘆人,感觉特别假,没有那个交流感。[S1]就是这个词儿,恐怖谷。结果你猜怎么着,这个事儿最近好像有救了。[S2]哦?什么情况?[S1]复旦大学那个邱锡鹏教授的团队,就是那个OpenMOSS团队,他们搞了个新东西,叫Moss TTSD。[S2]是吗?专门解决这个问题的?[S1]对,就是专门搞对话语音合成的。他们说,这个模型是基于一百万小时的音频训练出来的,号称要打破AI播客的恐怖谷魔咒。[S2]一百万小时,我的天,这数据量也太吓人了。[S1]是吧。而且最牛的是,这个模型,Moss TTSD,代码和模型权重,全都开源了,还能免费商用。[S2]哇,那可太棒了。意思就是,它不是生成一句话一句话的,而是把一整段多人对话的稿子给它,它直接生成一整段特逼真的对话录音?[S1]没错,就是这个意思。它能把对话里那种节奏啊、语气的变化啊,都给模拟出来。[S2]那效果到底怎么样?有对比吗?[S1]有啊。他们拿现在一个挺火的商业模型,叫豆包的播客生成功能,做了个对比。[S2]哦,豆包我知道,挺多人在用。[S1]结果呢,人家Moss TTSD作为一个开源模型,在情感丰富度啊、语气自然度这些方面,跟豆包比起来,基本上不相上下。[S2]真的假的?一个开源的能做到跟商业模型一个水平,那可太厉害了。[S1]是啊。我听了他们放出来的几段对比音频,确实很惊艳。[S2]那它技术上到底是怎么实现的?肯定有啥独门秘诀吧。[S1]嘿,你还真说对了。他们的核心创新,是一个叫XY Tokenizer的东西。[S2]XY Tokenizer?这是个啥玩意儿?[S1]它算是一个特别聪明的语音编码器。[S2]嗯。[S1]就是说,它在处理声音的时候,不光能明白你说的内容是啥意思,也就是语义信息,还能同时把你说话的那个味儿,那个声学信息,一块儿给编码进去。[S2]哦,我好像有点明白了。就是说,它不光知道“说的是什么”,还知道“是怎么说的”。[S1]对!而且它把这些信息压缩得特别特别小,比特率只有一kbps。这样一来呢,大语言模型学起来就轻松多了,能把声音里那些特别细微的特征都学到。[S2]原来是这样。那数据处理呢?你说那一百万小时的音频,肯定不能直接用吧,得要特别干净的数据才行。[S1]这个他们也下功夫了。他们自己搞了一套效率特别高的数据处理流水线。[S2]嗯。[S1]先用自己内部的一个模型,把一段录音里不同的人声给分开。他们说这个模型比现在市面上开源的、甚至一些商用的都要好。[S2]哇。[S1]分开之后呢,再用一个叫DNSMOS的工具给语音质量打分,只保留那些分数在二点八分以上的高质量片段。[S2]这筛选标准还挺严格的。[S1]可不是嘛。而且针对中文数据,他们还专门训练了一个模型,叫Whisper d,用来做文本转录。这个模型特别牛,连那种俩人说话声音叠在一起的地方,它都能准确地转写出来。[S2]这个确实是痛点,很多模型都搞不定这个。[S1]所以啊,料下足了,最后出来的东西效果就好。测试结果说,它跟顶尖的闭源模型豆包播客模型性能差不多。[S2]嗯。[S1]然后跟另一个开源模型MoonCast比呢,它的韵律和表现力都更自然。[S2]是的。[S1]最关键的一点来了, follow豆包的播客功能比,它除了效果差不多,还支持一个叫“零样本音色克隆”的功能。[S2]零样本音色克隆?这是说……我给它一段我的声音,它就能用我的声音去生成对话了?[S1]bingo。完全正确。就是这个意思。这样一来,可定制的程度就特别高了。[S2]我的天,那这应用场景可就广了去了。[S1]那可不。还有一个特别大的优点,它能一次性生成特别长的音频,最长能到九百六十秒。[S2]九百六十秒,那不就是十六分钟?[S1]对。一次生成十六分钟,就再也不用一小段一小段生成完了再拼起来了。那个拼接造成的停顿感和不自然,就都解决了。[S2]确实,这个对于做长音频,比如播客啊、影视配音啊、有声书什么的,简直是刚需。[S1]是啊。所以说,这个Moss TTSD一出来,感觉以后做播客、搞长篇访谈,甚至数字人直播带货这些事儿,门槛又得降一大截了。[S2]感觉AI播客真的要从“能听”变成“好听”了。对我们内容创作者来说,这绝对是个好工具啊。", "prompt_audio": "single_reference.wav", "prompt_text": "[S1]我意思是说通用计算机是指我们所有人都能用的这种普通计算机的意思啊?[S2]不是,不是。现在通用计算机就,比如说现在买了一台电脑嘛。[S1]嗯。[S2]电脑里面你可以装各种各样的软件,对吧?你可以用来聊微信,你可以用来-[S1]明白,可以做快机用的发票机啊-[S2]对对对对 [S1]...连接各种设备啊什么的。[S2]你只要放不同的软件,[S1]嗯。[S2]它都可以去完成那个软件所定义的相对应的工作。"}
\ No newline at end of file
+{"base_path":"examples","text": "[S1]诶,跟你说个事儿啊,我最近听了不少那种AI生成的播客,不知道你有没有听过。[S2]哦,听过一些。怎么了,感觉怎么样?[S1]就是……怎么说呢,单听一句话,你觉得,哇,好像跟真人没啥区别。[S2]嗯。[S1]但是,你只要让它说上一段完整的对话,比如俩人聊天那种,那个感觉就立马不对了。[S2]对对对,我懂你的意思。就是那个所谓的“恐怖谷”效应,是吧?听着有点瘆人,感觉特别假,没有那个交流感。[S1]就是这个词儿,恐怖谷。结果你猜怎么着,这个事儿最近好像有救了。[S2]哦?什么情况?[S1]复旦大学那个邱锡鹏教授的团队,就是那个OpenMOSS团队,他们搞了个新东西,叫Moss TTSD。[S2]是吗?专门解决这个问题的?[S1]对,就是专门搞对话语音合成的。他们说,这个模型是基于一百万小时的音频训练出来的,号称要打破AI播客的恐怖谷魔咒。[S2]一百万小时,我的天,这数据量也太吓人了。[S1]是吧。而且最牛的是,这个模型,Moss TTSD,代码和模型权重,全都开源了,还能免费商用。[S2]哇,那可太棒了。意思就是,它不是生成一句话一句话的,而是把一整段多人对话的稿子给它,它直接生成一整段特逼真的对话录音?[S1]没错,就是这个意思。它能把对话里那种节奏啊、语气的变化啊,都给模拟出来。[S2]那效果到底怎么样?有对比吗?[S1]有啊。他们拿现在一个挺火的商业模型,叫豆包的播客生成功能,做了个对比。[S2]哦,豆包我知道,挺多人在用。[S1]结果呢,人家Moss TTSD作为一个开源模型,在情感丰富度啊、语气自然度这些方面,跟豆包比起来,基本上不相上下。[S2]真的假的?一个开源的能做到跟商业模型一个水平,那可太厉害了。[S1]是啊。我听了他们放出来的几段对比音频,确实很惊艳。[S2]那它技术上到底是怎么实现的?肯定有啥独门秘诀吧。[S1]嘿,你还真说对了。他们的核心创新,是一个叫XY Tokenizer的东西。[S2]XY Tokenizer?这是个啥玩意儿?[S1]它算是一个特别聪明的语音编码器。[S2]嗯。[S1]就是说,它在处理声音的时候,不光能明白你说的内容是啥意思,也就是语义信息,还能同时把你说话的那个味儿,那个声学信息,一块儿给编码进去。[S2]哦,我好像有点明白了。就是说,它不光知道“说的是什么”,还知道“是怎么说的”。[S1]对!而且它把这些信息压缩得特别特别小,比特率只有一kbps。这样一来呢,大语言模型学起来就轻松多了,能把声音里那些特别细微的特征都学到。[S2]原来是这样。那数据处理呢?你说那一百万小时的音频,肯定不能直接用吧,得要特别干净的数据才行。[S1]这个他们也下功夫了。他们自己搞了一套效率特别高的数据处理流水线。[S2]嗯。[S1]先用自己内部的一个模型,把一段录音里不同的人声给分开。他们说这个模型比现在市面上开源的、甚至一些商用的都要好。[S2]哇。[S1]分开之后呢,再用一个叫DNSMOS的工具给语音质量打分,只保留那些分数在二点八分以上的高质量片段。[S2]这筛选标准还挺严格的。[S1]可不是嘛。而且针对中文数据,他们还专门训练了一个模型,叫Whisper d,用来做文本转录。这个模型特别牛,连那种俩人说话声音叠在一起的地方,它都能准确地转写出来。[S2]这个确实是痛点,很多模型都搞不定这个。[S1]所以啊,料下足了,最后出来的东西效果就好。测试结果说,它跟顶尖的闭源模型豆包播客模型性能差不多。[S2]嗯。[S1]然后跟另一个开源模型MoonCast比呢,它的韵律和表现力都更自然。[S2]是的。[S1]最关键的一点来了, follow豆包的播客功能比,它除了效果差不多,还支持一个叫“零样本音色克隆”的功能。[S2]零样本音色克隆?这是说……我给它一段我的声音,它就能用我的声音去生成对话了?[S1]bingo。完全正确。就是这个意思。这样一来,可定制的程度就特别高了。[S2]我的天,那这应用场景可就广了去了。[S1]那可不。还有一个特别大的优点,它能一次性生成特别长的音频,最长能到九百六十秒。[S2]九百六十秒,那不就是十六分钟?[S1]对。一次生成十六分钟,就再也不用一小段一小段生成完了再拼起来了。那个拼接造成的停顿感和不自然,就都解决了。[S2]确实,这个对于做长音频,比如播客啊、影视配音啊、有声书什么的,简直是刚需。[S1]是啊。所以说,这个Moss TTSD一出来,感觉以后做播客、搞长篇访谈,甚至数字人直播带货这些事儿,门槛又得降一大截了。[S2]感觉AI播客真的要从“能听”变成“好听”了。对我们内容创作者来说,这绝对是个好工具啊。", "prompt_audio": "single_reference.wav", "prompt_text": "[S1]我意思是说通用计算机是指我们所有人都能用的这种普通计算机的意思啊?[S2]不是,不是。现在通用计算机就,比如说现在买了一台电脑嘛。[S1]嗯。[S2]电脑里面你可以装各种各样的软件,对吧?你可以用来聊微信,你可以用来-[S1]明白,可以做快机用的发票机啊-[S2]对对对对 [S1]...连接各种设备啊什么的。[S2]你只要放不同的软件,[S1]嗯。[S2]它都可以去完成那个软件所定义的相对应的工作。"}
diff --git a/finetune/data_preprocess.py b/finetune/data_preprocess.py
index a584e2a..f1faba6 100644
--- a/finetune/data_preprocess.py
+++ b/finetune/data_preprocess.py
@@ -1,13 +1,16 @@
-import json
-import torch
-import numpy as np
import argparse
+import json
import os
import pickle
import sys
+
+import numpy as np
+import torch
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from generation_utils import process_jsonl_item, normalize_text, load_audio_data
from transformers import AutoTokenizer
+
+from generation_utils import load_audio_data, normalize_text, process_jsonl_item
from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer
MODEL_PATH = "fnlp/MOSS-TTSD-v0.5"
@@ -17,34 +20,53 @@
MAX_CHANNELS = 8
SILENCE_DURATION = 0.0 # Fixed silence duration: 0 seconds
+
def load_tokenizer(model_path, spt_config_path, spt_checkpoint_path):
tokenizer = AutoTokenizer.from_pretrained(model_path)
- spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
+ spt = XY_Tokenizer.load_from_checkpoint(
+ config_path=spt_config_path, ckpt_path=spt_checkpoint_path
+ )
spt.eval()
return tokenizer, spt
-def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, reference_audio=None, main_audio=None, max_channels=8, pad_token=1024):
+
+def process_inputs(
+ tokenizer,
+ spt,
+ prompt,
+ text,
+ device,
+ audio_data=None,
+ reference_audio=None,
+ main_audio=None,
+ max_channels=8,
+ pad_token=1024,
+):
# Decompose template into multiple parts
# 1. Style prompt part
seg1 = f"<|begin_of_style|>{prompt}<|end_of_style|>\n<|begin_of_text|>"
inputs1 = np.array(tokenizer.encode(seg1))
inputs_expanded1 = np.full((len(inputs1), max_channels), pad_token)
inputs_expanded1[:, 0] = inputs1
- labels1 = np.full(inputs_expanded1.shape, -100) # Style prompt does not compute loss
-
+ labels1 = np.full(
+ inputs_expanded1.shape, -100
+ ) # Style prompt does not compute loss
+
# 2. Text part
text_tokens = tokenizer.encode(text, add_special_tokens=False)
inputs2 = np.array(text_tokens)
inputs_expanded2 = np.full((len(inputs2), max_channels), pad_token)
inputs_expanded2[:, 0] = inputs2
labels2 = np.full(inputs_expanded2.shape, -100) # Text does not compute loss
-
+
# 3. Text end/speech begin part
seg3 = f"<|end_of_text|>\n<|begin_of_speech|>"
inputs3 = np.array(tokenizer.encode(seg3))
inputs_expanded3 = np.full((len(inputs3), max_channels), pad_token)
inputs_expanded3[:, 0] = inputs3
- labels3 = np.full(inputs_expanded3.shape, -100) # Start marker does not compute loss
+ labels3 = np.full(
+ inputs_expanded3.shape, -100
+ ) # Start marker does not compute loss
# 4. Audio processing part
audio_token = None
@@ -54,32 +76,42 @@ def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, refere
# Add silence to the end of each audio
silence_samples = int(SILENCE_DURATION * 16000)
silence = torch.zeros(1, silence_samples)
-
- # Ensure audio has correct shape [1, samples]
+
+ # Ensure audio has correct shape [1, samples]
if len(reference_audio.shape) == 1:
reference_audio = reference_audio.unsqueeze(0)
if len(main_audio.shape) == 1:
main_audio = main_audio.unsqueeze(0)
-
+
# Add silence to each audio
ref_audio_with_silence = torch.cat([reference_audio, silence], dim=1)
main_audio_with_silence = torch.cat([main_audio, silence], dim=1)
with torch.no_grad():
# Encode two audio files separately
- ref_encode_result = spt.encode([ref_audio_with_silence.squeeze().to(device)])
- main_encode_result = spt.encode([main_audio_with_silence.squeeze().to(device)])
-
- ref_audio_token = ref_encode_result["codes_list"][0].permute(1, 0).cpu().numpy()
- main_audio_token = main_encode_result["codes_list"][0].permute(1, 0).cpu().numpy()
-
+ ref_encode_result = spt.encode(
+ [ref_audio_with_silence.squeeze().to(device)]
+ )
+ main_encode_result = spt.encode(
+ [main_audio_with_silence.squeeze().to(device)]
+ )
+
+ ref_audio_token = (
+ ref_encode_result["codes_list"][0].permute(1, 0).cpu().numpy()
+ )
+ main_audio_token = (
+ main_encode_result["codes_list"][0].permute(1, 0).cpu().numpy()
+ )
+
# Concatenate at token level
- audio_token = np.concatenate([ref_audio_token, main_audio_token], axis=0)
+ audio_token = np.concatenate(
+ [ref_audio_token, main_audio_token], axis=0
+ )
except Exception as e:
print(f"Error processing two audio files: {e}")
raise
-
+
elif audio_data is not None:
# Original format: single audio processing
try:
@@ -102,19 +134,19 @@ def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, refere
if audio_token is not None:
# Add offset (only for the first layer)
audio_token[:, 0] = audio_token[:, 0] + 151665
-
+
# Channel count alignment processing
if audio_token.shape[1] > max_channels:
audio_token = audio_token[:, :max_channels]
elif audio_token.shape[1] < max_channels:
padded = np.full((audio_token.shape[0], max_channels), pad_token)
- padded[:, :audio_token.shape[1]] = audio_token
+ padded[:, : audio_token.shape[1]] = audio_token
audio_token = padded
-
+
labels4 = audio_token.copy() # Audio tokens need to compute loss
else:
raise ValueError("No audio data provided")
-
+
# 5. Speech end part
seg5 = "<|end_of_speech|>"
inputs5 = np.array(tokenizer.encode(seg5))
@@ -124,21 +156,25 @@ def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, refere
labels5[:, 0] = inputs_expanded5[:, 0] # End marker needs to be learned
# Concatenate all parts
- input_ids = np.concatenate([
- inputs_expanded1, # Style prompt
- inputs_expanded2, # Text
- inputs_expanded3, # Speech start marker
- audio_token, # Speech tokens (first layer with offset added)
- inputs_expanded5 # End marker
- ])
-
- labels = np.concatenate([
- labels1, # Style prompt (no loss computation)
- labels2, # Text (no loss computation)
- labels3, # Start marker (no loss computation)
- labels4, # Speech tokens (compute loss)
- labels5 # End marker (compute loss)
- ])
+ input_ids = np.concatenate(
+ [
+ inputs_expanded1, # Style prompt
+ inputs_expanded2, # Text
+ inputs_expanded3, # Speech start marker
+ audio_token, # Speech tokens (first layer with offset added)
+ inputs_expanded5, # End marker
+ ]
+ )
+
+ labels = np.concatenate(
+ [
+ labels1, # Style prompt (no loss computation)
+ labels2, # Text (no loss computation)
+ labels3, # Start marker (no loss computation)
+ labels4, # Speech tokens (compute loss)
+ labels5, # End marker (compute loss)
+ ]
+ )
# Calculate length information
total_length = input_ids.shape[0] # Total token length
@@ -146,23 +182,25 @@ def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, refere
return input_ids, labels, total_length, audio_length
+
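
process_inputs above ultimately concatenates five fixed-width segments and masks everything except the audio tokens and the end-of-speech marker with -100; the dummy-token sketch below shows that layout (the real code additionally adds a 151665 offset to the first layer of the audio tokens):

import numpy as np

max_channels, pad_token = 8, 1024

def text_segment(token_ids, learn_first_layer=False):
    ids = np.full((len(token_ids), max_channels), pad_token)
    ids[:, 0] = token_ids                      # text tokens only occupy layer 0
    labels = np.full(ids.shape, -100)          # -100 = ignored by the loss
    if learn_first_layer:
        labels[:, 0] = ids[:, 0]
    return ids, labels

style_ids, style_labels = text_segment([11, 12, 13])              # style prompt, no loss
text_ids, text_labels = text_segment([21, 22, 23])                # dialogue script, no loss
start_ids, start_labels = text_segment([31])                      # speech start marker, no loss
audio_ids = np.random.randint(0, 1024, size=(6, max_channels))    # speech tokens, loss on all layers
audio_labels = audio_ids.copy()
end_ids, end_labels = text_segment([41], learn_first_layer=True)  # speech end marker, layer-0 loss

input_ids = np.concatenate([style_ids, text_ids, start_ids, audio_ids, end_ids])
labels = np.concatenate([style_labels, text_labels, start_labels, audio_labels, end_labels])
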
def process_data(
- jsonl: str ,
- model_path: str ,
- output_dir: str ,
- data_name: str = "processd_data",
- use_normalize: bool = True):
+ jsonl: str,
+ model_path: str,
+ output_dir: str,
+ data_name: str = "processd_data",
+ use_normalize: bool = True,
+):
# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
-
+
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
-
+
# Load models
print("Loading models...")
tokenizer, spt = load_tokenizer(model_path, SPT_CONFIG_PATH, SPT_CHECKPOINT_PATH)
spt = spt.to(device)
-
+
# Load the items from the JSONL file
try:
with open(jsonl, "r") as f:
@@ -179,139 +217,192 @@ def process_data(
all_data = []
offsets = []
tokens_lengths = [] # Added: store total token length
- tims_lengths = [] # Added: store audio token length
-
+ tims_lengths = [] # Added: store audio token length
+
for idx, item in enumerate(items):
# Support two JSONL formats
# Format 1: {"file_path": "path/to/audio.wav", "full_transcript": "speech content..."}
# Format 2: {"reference_audio": "path1", "reference_text": "text1", "audio": "path2", "text": "text2"}
-
+
if "file_path" in item and "full_transcript" in item:
# Original format
file_path = item["file_path"]
full_text = item["full_transcript"]
-
+
# Check if audio file exists
if not file_path:
print(f"Warning: Item {idx} has empty file_path, skipping...")
continue
-
+
if not os.path.exists(file_path):
- print(f"Warning: Audio file not found: {file_path}, skipping item {idx}...")
+ print(
+ f"Warning: Audio file not found: {file_path}, skipping item {idx}..."
+ )
continue
-
+
try:
# load_audio_data already includes 16kHz mono conversion functionality
audio_data = load_audio_data(file_path)
except Exception as e:
- print(f"Warning: Failed to load audio from {file_path}: {e}, skipping item {idx}...")
+ print(
+ f"Warning: Failed to load audio from {file_path}: {e}, skipping item {idx}..."
+ )
continue
-
+
# Apply text normalization based on parameter
if use_normalize:
full_text = normalize_text(full_text)
-
+
# Replace speaker tags
- final_text = full_text.replace("[S1]", "