From 3aec25774f2a137bc7ee9d685907407625a01c4a Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Wed, 24 Dec 2025 22:23:44 +0800 Subject: [PATCH 1/6] Update MIX-LN Update MIX-LN --- configs/acoustic.yaml | 2 + configs/templates/config_acoustic.yaml | 3 + deployment/modules/fastspeech2.py | 21 ++++++- modules/commons/common_layers.py | 82 +++++++++++++++++++++++--- modules/fastspeech/acoustic_encoder.py | 21 ++++--- modules/fastspeech/tts_modules.py | 17 +++--- 6 files changed, 123 insertions(+), 23 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 935d6e16..b8aa4b07 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -51,6 +51,8 @@ use_lang_id: false num_lang: 1 use_spk_id: false num_spk: 1 +use_mix_ln: false +mix_ln_layer: [0, 2] use_energy_embed: false use_breathiness_embed: false use_voicing_embed: false diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 9d63028f..4dad9da4 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -43,6 +43,9 @@ num_lang: 1 use_spk_id: false num_spk: 1 +use_mix_ln: false +mix_ln_layer: [0, 2] + # NOTICE: before enabling variance embeddings, please read the docs at # https://github.com/openvpi/DiffSinger/tree/main/docs/BestPractices.md#choosing-variance-parameters use_energy_embed: false diff --git a/deployment/modules/fastspeech2.py b/deployment/modules/fastspeech2.py index b22590c2..83786281 100644 --- a/deployment/modules/fastspeech2.py +++ b/deployment/modules/fastspeech2.py @@ -18,6 +18,21 @@ f0_mel_max = 1127 * np.log(1 + f0_max / 700) +def uniform_attention_pooling(spk_embed, durations): + B, T_mel, C = spk_embed.shape + T_ph = durations.shape[1] + ph_starts = torch.cumsum(torch.cat([torch.zeros_like(durations[:, :1]), durations[:, :-1]], dim=1), dim=1) + ph_ends = ph_starts + durations + mel_indices = torch.arange(T_mel, device=spk_embed.device).view(1, 1, T_mel) + phoneme_to_mel_mask = (mel_indices >= ph_starts.unsqueeze(-1)) & (mel_indices < ph_ends.unsqueeze(-1)) + uniform_scores = phoneme_to_mel_mask.float() + sum_scores = uniform_scores.sum(dim=2, keepdim=True) + attn_weights = uniform_scores / (sum_scores + (sum_scores == 0).float()) # [B, T_ph, T_mel] + ph_spk_embed = torch.bmm(attn_weights, spk_embed) + + return ph_spk_embed + + def f0_to_coarse(f0): f0_mel = 1127 * (1 + f0 / 700).log() a = (f0_bin - 2) / (f0_mel_max - f0_mel_min) @@ -89,7 +104,11 @@ def forward( extra_embed = dur_embed + lang_embed else: extra_embed = dur_embed - encoded = self.encoder(txt_embed, extra_embed, tokens == PAD_INDEX) + if hparams.get('use_mix_ln', False): + ph_spk_embed = uniform_attention_pooling(spk_embed, durations) + else: + ph_spk_embed = None + encoded = self.encoder(txt_embed, extra_embed, tokens == PAD_INDEX, spk_embed=ph_spk_embed) encoded = F.pad(encoded, (0, 0, 1, 0)) condition = torch.gather(encoded, 1, mel2ph) diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index 0012b99c..690c9b4d 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -174,8 +174,57 @@ def __init__(self, dims): def forward(self, x): return x.transpose(*self.dims) + + +class Mixed_LayerNorm(nn.Module): + def __init__( + self, + channels: int, + condition_channels: int, + beta_distribution_concentration: float = 0.2, + eps: float = 1e-5, + bias: bool = True + ): + super().__init__() + self.channels = channels + self.eps = eps + + self.beta_distribution = torch.distributions.Beta( + 
beta_distribution_concentration, + beta_distribution_concentration + ) + + self.affine = XavierUniformInitLinear(condition_channels, channels * 2, bias=bias) + if self.affine.bias is not None: + self.affine.bias.data[:channels] = 1 + self.affine.bias.data[channels:] = 0 + + def forward( + self, + x: torch.FloatTensor, + condition: torch.FloatTensor # -> shape [Batch, Cond_d] + ) -> torch.FloatTensor: + x = F.layer_norm(x, normalized_shape=(self.channels,), weight=None, bias=None, eps=self.eps) + + affine_params = self.affine(condition) + if affine_params.ndim == 2: + affine_params = affine_params.unsqueeze(1) + betas, gammas = torch.split(affine_params, self.channels, dim=-1) + + if not self.training or x.size(0) == 1: + return gammas * x + betas + + shuffle_indices = torch.randperm(x.size(0), device=x.device) + shuffled_betas = betas[shuffle_indices] + shuffled_gammas = gammas[shuffle_indices] - + beta_samples = self.beta_distribution.sample((x.size(0), 1, 1)).to(x.device) + mixed_betas = beta_samples * betas + (1 - beta_samples) * shuffled_betas + mixed_gammas = beta_samples * gammas + (1 - beta_samples) * shuffled_gammas + + return mixed_gammas * x + mixed_betas + + class TransformerFFNLayer(nn.Module): def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0., act='gelu'): super().__init__() @@ -284,10 +333,19 @@ def forward(self, x, key_padding_mask=None): class EncSALayer(nn.Module): def __init__(self, c, num_heads, dropout, attention_dropout=0.1, - relu_dropout=0.1, kernel_size=9, act='gelu', rotary_embed=None): + relu_dropout=0.1, kernel_size=9, act='gelu', rotary_embed=None, + layer_idx=None, mix_ln_layer=[] + ): super().__init__() self.dropout = dropout - self.layer_norm1 = LayerNorm(c) + if layer_idx is not None: + self.use_mix_ln = (layer_idx in mix_ln_layer) + else: + self.use_mix_ln = False + if self.use_mix_ln: + self.layer_norm1 = Mixed_LayerNorm(c, c) + else: + self.layer_norm1 = LayerNorm(c) if rotary_embed is None: self.self_attn = MultiheadAttention( c, num_heads, dropout=attention_dropout, bias=False, batch_first=False @@ -298,18 +356,23 @@ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, c, num_heads, dropout=attention_dropout, bias=False, rotary_embed=rotary_embed ) self.use_rope = True - self.layer_norm2 = LayerNorm(c) + if self.use_mix_ln: + self.layer_norm1 = Mixed_LayerNorm(c, c) + else: + self.layer_norm1 = LayerNorm(c) self.ffn = TransformerFFNLayer( c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, act=act ) - def forward(self, x, encoder_padding_mask=None, **kwargs): + def forward(self, x, encoder_padding_mask=None, cond=None, **kwargs): layer_norm_training = kwargs.get('layer_norm_training', None) if layer_norm_training is not None: self.layer_norm1.training = layer_norm_training self.layer_norm2.training = layer_norm_training - residual = x - x = self.layer_norm1(x) + if self.use_mix_ln: + x = self.layer_norm1(x, cond) + else: + x = self.layer_norm1(x) if self.use_rope: x = self.self_attn(x, key_padding_mask=encoder_padding_mask) else: @@ -326,7 +389,10 @@ def forward(self, x, encoder_padding_mask=None, **kwargs): x = x * (1 - encoder_padding_mask.float())[..., None] residual = x - x = self.layer_norm2(x) + if self.use_mix_ln: + x = self.layer_norm2(x, cond) + else: + x = self.layer_norm2(x) x = self.ffn(x) x = F.dropout(x, self.dropout, training=self.training) x = residual + x diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py index 868d383f..4435d3a2 100644 --- 
a/modules/fastspeech/acoustic_encoder.py +++ b/modules/fastspeech/acoustic_encoder.py @@ -33,12 +33,18 @@ def __init__(self, vocab_size): self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True) self.dur_embed = Linear(1, hparams['hidden_size']) + self.use_mix_ln = hparams.get('use_mix_ln', False) + if self.use_mix_ln: + self.mix_ln_layer = hparams['mix_ln_layer'] + else: + self.mix_ln_layer = [] self.encoder = FastSpeech2Encoder( hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'], ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'], dropout=hparams['dropout'], num_heads=hparams['num_heads'], use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False), - use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True) + use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True), + mix_ln_layer=self.mix_ln_layer ) self.pitch_embed = Linear(1, hparams['hidden_size']) @@ -119,6 +125,12 @@ def forward( spk_embed_id=None, languages=None, **kwargs ): + if self.use_spk_id: + spk_mix_embed = kwargs.get('spk_mix_embed') + if spk_mix_embed is not None: + spk_embed = spk_mix_embed + else: + spk_embed = self.spk_embed(spk_embed_id)[:, None, :] txt_embed = self.txt_embed(txt_tokens) dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]) if self.use_variance_scaling: @@ -130,7 +142,7 @@ def forward( extra_embed = dur_embed + lang_embed else: extra_embed = dur_embed - encoder_out = self.encoder(txt_embed, extra_embed, txt_tokens == 0) + encoder_out = self.encoder(txt_embed, extra_embed, txt_tokens == 0, spk_embed) encoder_out = F.pad(encoder_out, [0, 0, 1, 0]) mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]]) @@ -150,11 +162,6 @@ def forward( condition = condition + stretch_embed_rnn_out if self.use_spk_id: - spk_mix_embed = kwargs.get('spk_mix_embed') - if spk_mix_embed is not None: - spk_embed = spk_mix_embed - else: - spk_embed = self.spk_embed(spk_embed_id)[:, None, :] condition += spk_embed f0_mel = (1 + f0 / 700).log() diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index cc840aed..ee146b17 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -12,13 +12,14 @@ class TransformerEncoderLayer(nn.Module): - def __init__(self, hidden_size, dropout, kernel_size=None, act='gelu', num_heads=2, rotary_embed=None): + def __init__(self, hidden_size, dropout, kernel_size=None, act='gelu', num_heads=2, rotary_embed=None, layer_idx=None, mix_ln_layer=None): super().__init__() self.op = EncSALayer( hidden_size, num_heads, dropout=dropout, attention_dropout=0.0, relu_dropout=dropout, kernel_size=kernel_size, - act=act, rotary_embed=rotary_embed + act=act, rotary_embed=rotary_embed, + layer_idx=layer_idx, mix_ln_layer=mix_ln_layer ) def forward(self, x, **kwargs): @@ -369,7 +370,8 @@ def mel2ph_to_dur(mel2ph, T_txt, max_dur=None): class FastSpeech2Encoder(nn.Module): def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, ffn_act='gelu', - dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False, rope_interleaved=True): + dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, + use_rope=False, rope_interleaved=True, mix_ln_layer=[]): super().__init__() self.num_layers = num_layers embed_dim = self.hidden_size = hidden_size @@ -383,9 +385,10 @@ def __init__(self, hidden_size, num_layers, TransformerEncoderLayer( self.hidden_size, 
self.dropout, kernel_size=ffn_kernel_size, act=ffn_act, - num_heads=num_heads, rotary_embed=rotary_embed + num_heads=num_heads, rotary_embed=rotary_embed, + layer_idx=i, mix_ln_layer=mix_ln_layer ) - for _ in range(self.num_layers) + for i in range(self.num_layers) ]) self.layer_norm = nn.LayerNorm(embed_dim) @@ -415,7 +418,7 @@ def forward_embedding(self, main_embed, extra_embed=None, padding_mask=None): x = F.dropout(x, p=self.dropout, training=self.training) return x - def forward(self, main_embed, extra_embed, padding_mask, attn_mask=None, return_hiddens=False): + def forward(self, main_embed, extra_embed, padding_mask, spk_embed=None, attn_mask=None, return_hiddens=False): x = self.forward_embedding(main_embed, extra_embed, padding_mask=padding_mask) # [B, T, H] nonpadding_mask_BT = 1 - padding_mask.float()[:, :, None] # [B, T, 1] @@ -435,7 +438,7 @@ def forward(self, main_embed, extra_embed, padding_mask, attn_mask=None, return_ x = x * nonpadding_mask_BT hiddens = [] for layer in self.layers: - x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_BT + x = layer(x, encoder_padding_mask=padding_mask, cond=spk_embed, attn_mask=attn_mask) * nonpadding_mask_BT if return_hiddens: hiddens.append(x) x = self.layer_norm(x) * nonpadding_mask_BT From cac165d7d702b6a9819eff20a2c8e6dbfa9fe4a2 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Wed, 24 Dec 2025 22:54:49 +0800 Subject: [PATCH 2/6] Update MIX-LN --- modules/commons/common_layers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index 690c9b4d..35f666f8 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -357,9 +357,9 @@ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, ) self.use_rope = True if self.use_mix_ln: - self.layer_norm1 = Mixed_LayerNorm(c, c) + self.layer_norm2 = Mixed_LayerNorm(c, c) else: - self.layer_norm1 = LayerNorm(c) + self.layer_norm2 = LayerNorm(c) self.ffn = TransformerFFNLayer( c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, act=act ) @@ -369,6 +369,7 @@ def forward(self, x, encoder_padding_mask=None, cond=None, **kwargs): if layer_norm_training is not None: self.layer_norm1.training = layer_norm_training self.layer_norm2.training = layer_norm_training + residual = x if self.use_mix_ln: x = self.layer_norm1(x, cond) else: From eef2ac96896975b0bf06ba10335a198fb15dc30c Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Wed, 24 Dec 2025 22:23:44 +0800 Subject: [PATCH 3/6] Update MIX-LN Update MIX-LN Update MIX-LN --- configs/acoustic.yaml | 2 + configs/templates/config_acoustic.yaml | 3 + deployment/modules/fastspeech2.py | 21 ++++++- modules/commons/common_layers.py | 81 +++++++++++++++++++++++--- modules/fastspeech/acoustic_encoder.py | 21 ++++--- modules/fastspeech/tts_modules.py | 17 +++--- 6 files changed, 123 insertions(+), 22 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 935d6e16..b8aa4b07 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -51,6 +51,8 @@ use_lang_id: false num_lang: 1 use_spk_id: false num_spk: 1 +use_mix_ln: false +mix_ln_layer: [0, 2] use_energy_embed: false use_breathiness_embed: false use_voicing_embed: false diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 9d63028f..4dad9da4 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -43,6 +43,9 @@ num_lang: 1 
use_spk_id: false num_spk: 1 +use_mix_ln: false +mix_ln_layer: [0, 2] + # NOTICE: before enabling variance embeddings, please read the docs at # https://github.com/openvpi/DiffSinger/tree/main/docs/BestPractices.md#choosing-variance-parameters use_energy_embed: false diff --git a/deployment/modules/fastspeech2.py b/deployment/modules/fastspeech2.py index b22590c2..83786281 100644 --- a/deployment/modules/fastspeech2.py +++ b/deployment/modules/fastspeech2.py @@ -18,6 +18,21 @@ f0_mel_max = 1127 * np.log(1 + f0_max / 700) +def uniform_attention_pooling(spk_embed, durations): + B, T_mel, C = spk_embed.shape + T_ph = durations.shape[1] + ph_starts = torch.cumsum(torch.cat([torch.zeros_like(durations[:, :1]), durations[:, :-1]], dim=1), dim=1) + ph_ends = ph_starts + durations + mel_indices = torch.arange(T_mel, device=spk_embed.device).view(1, 1, T_mel) + phoneme_to_mel_mask = (mel_indices >= ph_starts.unsqueeze(-1)) & (mel_indices < ph_ends.unsqueeze(-1)) + uniform_scores = phoneme_to_mel_mask.float() + sum_scores = uniform_scores.sum(dim=2, keepdim=True) + attn_weights = uniform_scores / (sum_scores + (sum_scores == 0).float()) # [B, T_ph, T_mel] + ph_spk_embed = torch.bmm(attn_weights, spk_embed) + + return ph_spk_embed + + def f0_to_coarse(f0): f0_mel = 1127 * (1 + f0 / 700).log() a = (f0_bin - 2) / (f0_mel_max - f0_mel_min) @@ -89,7 +104,11 @@ def forward( extra_embed = dur_embed + lang_embed else: extra_embed = dur_embed - encoded = self.encoder(txt_embed, extra_embed, tokens == PAD_INDEX) + if hparams.get('use_mix_ln', False): + ph_spk_embed = uniform_attention_pooling(spk_embed, durations) + else: + ph_spk_embed = None + encoded = self.encoder(txt_embed, extra_embed, tokens == PAD_INDEX, spk_embed=ph_spk_embed) encoded = F.pad(encoded, (0, 0, 1, 0)) condition = torch.gather(encoded, 1, mel2ph) diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index 0012b99c..35f666f8 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -174,8 +174,57 @@ def __init__(self, dims): def forward(self, x): return x.transpose(*self.dims) + + +class Mixed_LayerNorm(nn.Module): + def __init__( + self, + channels: int, + condition_channels: int, + beta_distribution_concentration: float = 0.2, + eps: float = 1e-5, + bias: bool = True + ): + super().__init__() + self.channels = channels + self.eps = eps + + self.beta_distribution = torch.distributions.Beta( + beta_distribution_concentration, + beta_distribution_concentration + ) + + self.affine = XavierUniformInitLinear(condition_channels, channels * 2, bias=bias) + if self.affine.bias is not None: + self.affine.bias.data[:channels] = 1 + self.affine.bias.data[channels:] = 0 + + def forward( + self, + x: torch.FloatTensor, + condition: torch.FloatTensor # -> shape [Batch, Cond_d] + ) -> torch.FloatTensor: + x = F.layer_norm(x, normalized_shape=(self.channels,), weight=None, bias=None, eps=self.eps) + + affine_params = self.affine(condition) + if affine_params.ndim == 2: + affine_params = affine_params.unsqueeze(1) + betas, gammas = torch.split(affine_params, self.channels, dim=-1) + + if not self.training or x.size(0) == 1: + return gammas * x + betas + + shuffle_indices = torch.randperm(x.size(0), device=x.device) + shuffled_betas = betas[shuffle_indices] + shuffled_gammas = gammas[shuffle_indices] - + beta_samples = self.beta_distribution.sample((x.size(0), 1, 1)).to(x.device) + mixed_betas = beta_samples * betas + (1 - beta_samples) * shuffled_betas + mixed_gammas = beta_samples * gammas + (1 - 
beta_samples) * shuffled_gammas + + return mixed_gammas * x + mixed_betas + + class TransformerFFNLayer(nn.Module): def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0., act='gelu'): super().__init__() @@ -284,10 +333,19 @@ def forward(self, x, key_padding_mask=None): class EncSALayer(nn.Module): def __init__(self, c, num_heads, dropout, attention_dropout=0.1, - relu_dropout=0.1, kernel_size=9, act='gelu', rotary_embed=None): + relu_dropout=0.1, kernel_size=9, act='gelu', rotary_embed=None, + layer_idx=None, mix_ln_layer=[] + ): super().__init__() self.dropout = dropout - self.layer_norm1 = LayerNorm(c) + if layer_idx is not None: + self.use_mix_ln = (layer_idx in mix_ln_layer) + else: + self.use_mix_ln = False + if self.use_mix_ln: + self.layer_norm1 = Mixed_LayerNorm(c, c) + else: + self.layer_norm1 = LayerNorm(c) if rotary_embed is None: self.self_attn = MultiheadAttention( c, num_heads, dropout=attention_dropout, bias=False, batch_first=False @@ -298,18 +356,24 @@ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, c, num_heads, dropout=attention_dropout, bias=False, rotary_embed=rotary_embed ) self.use_rope = True - self.layer_norm2 = LayerNorm(c) + if self.use_mix_ln: + self.layer_norm2 = Mixed_LayerNorm(c, c) + else: + self.layer_norm2 = LayerNorm(c) self.ffn = TransformerFFNLayer( c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, act=act ) - def forward(self, x, encoder_padding_mask=None, **kwargs): + def forward(self, x, encoder_padding_mask=None, cond=None, **kwargs): layer_norm_training = kwargs.get('layer_norm_training', None) if layer_norm_training is not None: self.layer_norm1.training = layer_norm_training self.layer_norm2.training = layer_norm_training residual = x - x = self.layer_norm1(x) + if self.use_mix_ln: + x = self.layer_norm1(x, cond) + else: + x = self.layer_norm1(x) if self.use_rope: x = self.self_attn(x, key_padding_mask=encoder_padding_mask) else: @@ -326,7 +390,10 @@ def forward(self, x, encoder_padding_mask=None, **kwargs): x = x * (1 - encoder_padding_mask.float())[..., None] residual = x - x = self.layer_norm2(x) + if self.use_mix_ln: + x = self.layer_norm2(x, cond) + else: + x = self.layer_norm2(x) x = self.ffn(x) x = F.dropout(x, self.dropout, training=self.training) x = residual + x diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py index 868d383f..4435d3a2 100644 --- a/modules/fastspeech/acoustic_encoder.py +++ b/modules/fastspeech/acoustic_encoder.py @@ -33,12 +33,18 @@ def __init__(self, vocab_size): self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True) self.dur_embed = Linear(1, hparams['hidden_size']) + self.use_mix_ln = hparams.get('use_mix_ln', False) + if self.use_mix_ln: + self.mix_ln_layer = hparams['mix_ln_layer'] + else: + self.mix_ln_layer = [] self.encoder = FastSpeech2Encoder( hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'], ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'], dropout=hparams['dropout'], num_heads=hparams['num_heads'], use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False), - use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True) + use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True), + mix_ln_layer=self.mix_ln_layer ) self.pitch_embed = Linear(1, hparams['hidden_size']) @@ -119,6 +125,12 @@ def forward( spk_embed_id=None, languages=None, **kwargs ): 
+ if self.use_spk_id: + spk_mix_embed = kwargs.get('spk_mix_embed') + if spk_mix_embed is not None: + spk_embed = spk_mix_embed + else: + spk_embed = self.spk_embed(spk_embed_id)[:, None, :] txt_embed = self.txt_embed(txt_tokens) dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]) if self.use_variance_scaling: @@ -130,7 +142,7 @@ def forward( extra_embed = dur_embed + lang_embed else: extra_embed = dur_embed - encoder_out = self.encoder(txt_embed, extra_embed, txt_tokens == 0) + encoder_out = self.encoder(txt_embed, extra_embed, txt_tokens == 0, spk_embed) encoder_out = F.pad(encoder_out, [0, 0, 1, 0]) mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]]) @@ -150,11 +162,6 @@ def forward( condition = condition + stretch_embed_rnn_out if self.use_spk_id: - spk_mix_embed = kwargs.get('spk_mix_embed') - if spk_mix_embed is not None: - spk_embed = spk_mix_embed - else: - spk_embed = self.spk_embed(spk_embed_id)[:, None, :] condition += spk_embed f0_mel = (1 + f0 / 700).log() diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index cc840aed..ee146b17 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -12,13 +12,14 @@ class TransformerEncoderLayer(nn.Module): - def __init__(self, hidden_size, dropout, kernel_size=None, act='gelu', num_heads=2, rotary_embed=None): + def __init__(self, hidden_size, dropout, kernel_size=None, act='gelu', num_heads=2, rotary_embed=None, layer_idx=None, mix_ln_layer=None): super().__init__() self.op = EncSALayer( hidden_size, num_heads, dropout=dropout, attention_dropout=0.0, relu_dropout=dropout, kernel_size=kernel_size, - act=act, rotary_embed=rotary_embed + act=act, rotary_embed=rotary_embed, + layer_idx=layer_idx, mix_ln_layer=mix_ln_layer ) def forward(self, x, **kwargs): @@ -369,7 +370,8 @@ def mel2ph_to_dur(mel2ph, T_txt, max_dur=None): class FastSpeech2Encoder(nn.Module): def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, ffn_act='gelu', - dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False, rope_interleaved=True): + dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, + use_rope=False, rope_interleaved=True, mix_ln_layer=[]): super().__init__() self.num_layers = num_layers embed_dim = self.hidden_size = hidden_size @@ -383,9 +385,10 @@ def __init__(self, hidden_size, num_layers, TransformerEncoderLayer( self.hidden_size, self.dropout, kernel_size=ffn_kernel_size, act=ffn_act, - num_heads=num_heads, rotary_embed=rotary_embed + num_heads=num_heads, rotary_embed=rotary_embed, + layer_idx=i, mix_ln_layer=mix_ln_layer ) - for _ in range(self.num_layers) + for i in range(self.num_layers) ]) self.layer_norm = nn.LayerNorm(embed_dim) @@ -415,7 +418,7 @@ def forward_embedding(self, main_embed, extra_embed=None, padding_mask=None): x = F.dropout(x, p=self.dropout, training=self.training) return x - def forward(self, main_embed, extra_embed, padding_mask, attn_mask=None, return_hiddens=False): + def forward(self, main_embed, extra_embed, padding_mask, spk_embed=None, attn_mask=None, return_hiddens=False): x = self.forward_embedding(main_embed, extra_embed, padding_mask=padding_mask) # [B, T, H] nonpadding_mask_BT = 1 - padding_mask.float()[:, :, None] # [B, T, 1] @@ -435,7 +438,7 @@ def forward(self, main_embed, extra_embed, padding_mask, attn_mask=None, return_ x = x * nonpadding_mask_BT hiddens = [] for layer in self.layers: - x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_BT + x = layer(x, 
encoder_padding_mask=padding_mask, cond=spk_embed, attn_mask=attn_mask) * nonpadding_mask_BT if return_hiddens: hiddens.append(x) x = self.layer_norm(x) * nonpadding_mask_BT From ab01df47329d549fa57feb9289a5a9985a4159e2 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Wed, 24 Dec 2025 23:07:20 +0800 Subject: [PATCH 4/6] fix 'https://github.com/KakaruHayate/DiffSinger/issues/40' --- basics/base_task.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/basics/base_task.py b/basics/base_task.py index 656893d9..1c940dbb 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -102,6 +102,8 @@ def freeze_params(self) -> None: freeze_key = self.get_need_freeze_state_dict_key(model_state_dict=model_state_dict) for i in freeze_key: + if 'cached_freqs' or 'inv_freq' in name: + continue params=self.get_parameter(i) params.requires_grad = False From 4052654b3c2e387bcc9a41f8e8a3f481efba5c3a Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Wed, 24 Dec 2025 23:19:01 +0800 Subject: [PATCH 5/6] Revert "fix 'https://github.com/KakaruHayate/DiffSinger/issues/40'" This reverts commit ab01df47329d549fa57feb9289a5a9985a4159e2. --- basics/base_task.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index 1c940dbb..656893d9 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -102,8 +102,6 @@ def freeze_params(self) -> None: freeze_key = self.get_need_freeze_state_dict_key(model_state_dict=model_state_dict) for i in freeze_key: - if 'cached_freqs' or 'inv_freq' in name: - continue params=self.get_parameter(i) params.requires_grad = False From d93de12f13c88c312ce8c269f3a305e83da7fcc8 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Wed, 24 Dec 2025 23:25:13 +0800 Subject: [PATCH 6/6] Use MIX-LN default --- configs/templates/config_acoustic.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 4dad9da4..1a2a8a57 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -43,7 +43,7 @@ num_lang: 1 use_spk_id: false num_spk: 1 -use_mix_ln: false +use_mix_ln: true mix_ln_layer: [0, 2] # NOTICE: before enabling variance embeddings, please read the docs at
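
Note on the mechanism these patches add: Mixed_LayerNorm (modules/commons/common_layers.py) is a conditional layer normalization. It normalizes the hidden states without a learned affine, predicts a per-channel scale and shift from the speaker condition, and, only in training and only when the batch holds more than one sample, mixes each sample's scale/shift with those of a randomly permuted batch partner using a coefficient drawn from Beta(0.2, 0.2), which concentrates mass near 0 and 1 so most steps see mostly-own or mostly-partner conditions. The sketch below is a minimal, self-contained illustration of that mixing step, not the repository class: nn.Linear stands in for XavierUniformInitLinear, and the sketch initializes the bias of the multiplicative half of the affine output to 1 and the additive half to 0. (The patch itself unpacks the 1-initialized half as the additive betas and the 0-initialized half as the multiplicative gammas; if the goal is an identity-like transform at initialization, that pairing may be worth double-checking.)

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MixedLayerNormSketch(nn.Module):
        """Minimal sketch of the MIX-LN idea: parameter-free LayerNorm followed by a
        condition-dependent scale/shift, with the scale/shift mixed across the batch
        during training. Illustrative only; not the repository implementation."""

        def __init__(self, channels, condition_channels, concentration=0.2, eps=1e-5):
            super().__init__()
            self.channels = channels
            self.eps = eps
            # Beta(0.2, 0.2) puts most mass near 0 and 1, so a step usually keeps
            # mostly its own condition or mostly its partner's, rarely an even blend.
            self.mix_dist = torch.distributions.Beta(concentration, concentration)
            self.affine = nn.Linear(condition_channels, channels * 2)
            with torch.no_grad():
                self.affine.bias[:channels] = 1.0  # bias of the scale half
                self.affine.bias[channels:] = 0.0  # bias of the shift half

        def forward(self, x, condition):
            # x: [B, T, C]; condition: [B, cond_dim] or [B, T, cond_dim]
            x = F.layer_norm(x, (self.channels,), eps=self.eps)
            params = self.affine(condition)
            if params.ndim == 2:
                params = params.unsqueeze(1)  # broadcast one condition over all timesteps
            scale, shift = torch.split(params, self.channels, dim=-1)
            if not self.training or x.size(0) == 1:
                return scale * x + shift
            # Mix each sample's affine parameters with a random batch partner's.
            perm = torch.randperm(x.size(0), device=x.device)
            lam = self.mix_dist.sample((x.size(0), 1, 1)).to(x.device)
            scale = lam * scale + (1 - lam) * scale[perm]
            shift = lam * shift + (1 - lam) * shift[perm]
            return scale * x + shift

For example, MixedLayerNormSketch(256, 256)(torch.randn(4, 100, 256), torch.randn(4, 256)) returns a [4, 100, 256] tensor; in eval mode the mixing branch is skipped and each sample keeps its own condition, which is what keeps the exported model deterministic with respect to the speaker embedding.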
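
On the deployment side, uniform_attention_pooling (deployment/modules/fastspeech2.py) converts the frame-level speaker-mix embedding back into a phoneme-level condition by averaging it over the mel frames each phoneme covers, with zero-duration phonemes mapped to an all-zero row, so the phoneme encoder can hand Mixed_LayerNorm a per-token condition. The loop below is a readability-oriented reference of the same computation (assuming durations are non-negative and sum to the mel length); pool_speaker_to_phonemes is an illustrative name, not part of the patch.

    import torch

    def pool_speaker_to_phonemes(spk_embed, durations):
        """Loop-based reference for uniform_attention_pooling: mean of the
        frame-level speaker embedding over each phoneme's frames."""
        B, T_mel, C = spk_embed.shape
        out = spk_embed.new_zeros(B, durations.shape[1], C)
        for b in range(B):
            start = 0
            for p, d in enumerate(durations[b].tolist()):
                if d > 0:
                    out[b, p] = spk_embed[b, start:start + d].mean(dim=0)
                start += d  # zero-duration phonemes keep an all-zero row
        return out

    # Toy check: two phonemes covering 3 + 2 mel frames of a ramp embedding.
    spk = torch.arange(5, dtype=torch.float32).view(1, 5, 1)
    durs = torch.tensor([[3, 2]])
    print(pool_speaker_to_phonemes(spk, durs))  # -> [[[1.0], [3.5]]]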
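
One remark on patch 4 (already reverted by patch 5): the added guard in freeze_params(), if 'cached_freqs' or 'inv_freq' in name:, parses as 'cached_freqs' or ('inv_freq' in name) because "or" binds more loosely than "in", so the condition is always truthy; since "or" short-circuits, the undefined name (the loop variable is i) is never evaluated, and the loop silently skips every key instead of freezing anything. A quick demonstration of the precedence, plus a sketch (not the committed code) of the condition that would match the apparent intent of skipping rotary-embedding buffers:

    name = "model.encoder.layers.0.ffn.ffn_1.weight"
    print('cached_freqs' or 'inv_freq' in name)           # 'cached_freqs' -> always truthy
    print('cached_freqs' in name or 'inv_freq' in name)   # False

    # Inside the loop in basics/base_task.py, the intended guard would look like:
    #     if 'cached_freqs' in i or 'inv_freq' in i:
    #         continue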