2 changes: 2 additions & 0 deletions configs/acoustic.yaml
@@ -51,6 +51,8 @@ use_lang_id: false
num_lang: 1
use_spk_id: false
num_spk: 1
use_mix_ln: false
mix_ln_layer: [0, 2]
use_energy_embed: false
use_breathiness_embed: false
use_voicing_embed: false
3 changes: 3 additions & 0 deletions configs/templates/config_acoustic.yaml
@@ -43,6 +43,9 @@ num_lang: 1
use_spk_id: false
num_spk: 1

use_mix_ln: true
mix_ln_layer: [0, 2]

# NOTICE: before enabling variance embeddings, please read the docs at
# https://github.com/openvpi/DiffSinger/tree/main/docs/BestPractices.md#choosing-variance-parameters
use_energy_embed: false
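Note: the two new keys (added above with a false default in configs/acoustic.yaml and enabled here in the template) switch on the mixed speaker-conditional layer normalization and select which encoder layers use it. A minimal sketch of how they are consumed, mirroring the hparams.get pattern in modules/fastspeech/acoustic_encoder.py below (hparams stands for the parsed YAML config and is an assumption here):

use_mix_ln = hparams.get('use_mix_ln', False)                 # off by default, so existing configs are unaffected
mix_ln_layer = hparams['mix_ln_layer'] if use_mix_ln else []  # encoder layer indices that get Mixed_LayerNorm, e.g. [0, 2]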
21 changes: 20 additions & 1 deletion deployment/modules/fastspeech2.py
@@ -18,6 +18,21 @@
f0_mel_max = 1127 * np.log(1 + f0_max / 700)


def uniform_attention_pooling(spk_embed, durations):
B, T_mel, C = spk_embed.shape
T_ph = durations.shape[1]
ph_starts = torch.cumsum(torch.cat([torch.zeros_like(durations[:, :1]), durations[:, :-1]], dim=1), dim=1)
ph_ends = ph_starts + durations
mel_indices = torch.arange(T_mel, device=spk_embed.device).view(1, 1, T_mel)
phoneme_to_mel_mask = (mel_indices >= ph_starts.unsqueeze(-1)) & (mel_indices < ph_ends.unsqueeze(-1))
uniform_scores = phoneme_to_mel_mask.float()
sum_scores = uniform_scores.sum(dim=2, keepdim=True)
attn_weights = uniform_scores / (sum_scores + (sum_scores == 0).float()) # [B, T_ph, T_mel]
ph_spk_embed = torch.bmm(attn_weights, spk_embed)

return ph_spk_embed


def f0_to_coarse(f0):
f0_mel = 1127 * (1 + f0 / 700).log()
a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
@@ -89,7 +104,11 @@ def forward(
extra_embed = dur_embed + lang_embed
else:
extra_embed = dur_embed
encoded = self.encoder(txt_embed, extra_embed, tokens == PAD_INDEX)
if hparams.get('use_mix_ln', False):
ph_spk_embed = uniform_attention_pooling(spk_embed, durations)
else:
ph_spk_embed = None
encoded = self.encoder(txt_embed, extra_embed, tokens == PAD_INDEX, spk_embed=ph_spk_embed)
encoded = F.pad(encoded, (0, 0, 1, 0))
condition = torch.gather(encoded, 1, mel2ph)

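Note: uniform_attention_pooling averages the frame-level speaker embedding over the mel frames belonging to each phoneme, producing a phoneme-level embedding aligned with the encoder input. A small shape check with dummy tensors (the values below are assumptions for illustration and assume the function above is in scope):

import torch

spk_embed = torch.randn(2, 10, 4)                 # [B=2, T_mel=10, C=4] frame-level speaker embedding
durations = torch.tensor([[3, 4, 3], [5, 5, 0]])  # [B, T_ph=3] mel frames per phoneme
ph_spk_embed = uniform_attention_pooling(spk_embed, durations)
print(ph_spk_embed.shape)                         # torch.Size([2, 3, 4]); zero-duration phonemes get all-zero rows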
81 changes: 74 additions & 7 deletions modules/commons/common_layers.py
@@ -174,8 +174,57 @@ def __init__(self, dims):

def forward(self, x):
return x.transpose(*self.dims)


class Mixed_LayerNorm(nn.Module):
def __init__(
self,
channels: int,
condition_channels: int,
beta_distribution_concentration: float = 0.2,
eps: float = 1e-5,
bias: bool = True
):
super().__init__()
self.channels = channels
self.eps = eps

self.beta_distribution = torch.distributions.Beta(
beta_distribution_concentration,
beta_distribution_concentration
)

self.affine = XavierUniformInitLinear(condition_channels, channels * 2, bias=bias)
if self.affine.bias is not None:
self.affine.bias.data[:channels] = 1
self.affine.bias.data[channels:] = 0

def forward(
self,
x: torch.FloatTensor,
condition: torch.FloatTensor # -> shape [Batch, Cond_d]
) -> torch.FloatTensor:
x = F.layer_norm(x, normalized_shape=(self.channels,), weight=None, bias=None, eps=self.eps)

affine_params = self.affine(condition)
if affine_params.ndim == 2:
affine_params = affine_params.unsqueeze(1)
gammas, betas = torch.split(affine_params, self.channels, dim=-1)  # first half of the bias is initialised to 1 (scale), second half to 0 (shift)

if not self.training or x.size(0) == 1:
return gammas * x + betas

shuffle_indices = torch.randperm(x.size(0), device=x.device)
shuffled_betas = betas[shuffle_indices]
shuffled_gammas = gammas[shuffle_indices]


beta_samples = self.beta_distribution.sample((x.size(0), 1, 1)).to(x.device)
mixed_betas = beta_samples * betas + (1 - beta_samples) * shuffled_betas
mixed_gammas = beta_samples * gammas + (1 - beta_samples) * shuffled_gammas

return mixed_gammas * x + mixed_betas


class TransformerFFNLayer(nn.Module):
def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0., act='gelu'):
super().__init__()
@@ -284,10 +333,19 @@ def forward(self, x, key_padding_mask=None):

class EncSALayer(nn.Module):
def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
relu_dropout=0.1, kernel_size=9, act='gelu', rotary_embed=None):
relu_dropout=0.1, kernel_size=9, act='gelu', rotary_embed=None,
layer_idx=None, mix_ln_layer=[]
):
super().__init__()
self.dropout = dropout
self.layer_norm1 = LayerNorm(c)
if layer_idx is not None:
self.use_mix_ln = (layer_idx in mix_ln_layer)
else:
self.use_mix_ln = False
if self.use_mix_ln:
self.layer_norm1 = Mixed_LayerNorm(c, c)
else:
self.layer_norm1 = LayerNorm(c)
if rotary_embed is None:
self.self_attn = MultiheadAttention(
c, num_heads, dropout=attention_dropout, bias=False, batch_first=False
@@ -298,18 +356,24 @@ def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
c, num_heads, dropout=attention_dropout, bias=False, rotary_embed=rotary_embed
)
self.use_rope = True
self.layer_norm2 = LayerNorm(c)
if self.use_mix_ln:
self.layer_norm2 = Mixed_LayerNorm(c, c)
else:
self.layer_norm2 = LayerNorm(c)
self.ffn = TransformerFFNLayer(
c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, act=act
)

def forward(self, x, encoder_padding_mask=None, **kwargs):
def forward(self, x, encoder_padding_mask=None, cond=None, **kwargs):
layer_norm_training = kwargs.get('layer_norm_training', None)
if layer_norm_training is not None:
self.layer_norm1.training = layer_norm_training
self.layer_norm2.training = layer_norm_training
residual = x
x = self.layer_norm1(x)
if self.use_mix_ln:
x = self.layer_norm1(x, cond)
else:
x = self.layer_norm1(x)
if self.use_rope:
x = self.self_attn(x, key_padding_mask=encoder_padding_mask)
else:
@@ -326,7 +390,10 @@ def forward(self, x, encoder_padding_mask=None, **kwargs):
x = x * (1 - encoder_padding_mask.float())[..., None]

residual = x
x = self.layer_norm2(x)
if self.use_mix_ln:
x = self.layer_norm2(x, cond)
else:
x = self.layer_norm2(x)
x = self.ffn(x)
x = F.dropout(x, self.dropout, training=self.training)
x = residual + x
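Note: Mixed_LayerNorm is a conditional layer norm whose per-channel scale and shift are projected from the condition (here the speaker embedding); during training, each sample's affine parameters are additionally blended with those of a randomly permuted batch member using a Beta(0.2, 0.2) coefficient, a mixup-style regularizer on the speaker condition. The mixing step in isolation (shapes and values below are assumptions for illustration):

import torch

betas = torch.randn(4, 1, 8)        # per-sample shift, [B, 1, C]
gammas = torch.randn(4, 1, 8)       # per-sample scale, [B, 1, C]
perm = torch.randperm(4)            # pair each sample with a random other sample in the batch
lam = torch.distributions.Beta(0.2, 0.2).sample((4, 1, 1))  # concentrates near 0 or 1
mixed_betas = lam * betas + (1 - lam) * betas[perm]
mixed_gammas = lam * gammas + (1 - lam) * gammas[perm]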
21 changes: 14 additions & 7 deletions modules/fastspeech/acoustic_encoder.py
@@ -33,12 +33,18 @@ def __init__(self, vocab_size):
self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True)

self.dur_embed = Linear(1, hparams['hidden_size'])
self.use_mix_ln = hparams.get('use_mix_ln', False)
if self.use_mix_ln:
self.mix_ln_layer = hparams['mix_ln_layer']
else:
self.mix_ln_layer = []
self.encoder = FastSpeech2Encoder(
hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'],
ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
dropout=hparams['dropout'], num_heads=hparams['num_heads'],
use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False),
use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True)
use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True),
mix_ln_layer=self.mix_ln_layer
)

self.pitch_embed = Linear(1, hparams['hidden_size'])
@@ -119,6 +125,12 @@ def forward(
spk_embed_id=None, languages=None,
**kwargs
):
spk_embed = None  # stays None when speaker IDs are disabled, so the encoder receives no Mix-LN condition
if self.use_spk_id:
    spk_mix_embed = kwargs.get('spk_mix_embed')
    if spk_mix_embed is not None:
        spk_embed = spk_mix_embed
    else:
        spk_embed = self.spk_embed(spk_embed_id)[:, None, :]
txt_embed = self.txt_embed(txt_tokens)
dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1])
if self.use_variance_scaling:
@@ -130,7 +142,7 @@ def forward(
extra_embed = dur_embed + lang_embed
else:
extra_embed = dur_embed
encoder_out = self.encoder(txt_embed, extra_embed, txt_tokens == 0)
encoder_out = self.encoder(txt_embed, extra_embed, txt_tokens == 0, spk_embed)

encoder_out = F.pad(encoder_out, [0, 0, 1, 0])
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
@@ -150,11 +162,6 @@ def forward(
condition = condition + stretch_embed_rnn_out

if self.use_spk_id:
spk_mix_embed = kwargs.get('spk_mix_embed')
if spk_mix_embed is not None:
spk_embed = spk_mix_embed
else:
spk_embed = self.spk_embed(spk_embed_id)[:, None, :]
condition += spk_embed

f0_mel = (1 + f0 / 700).log()
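Note: the speaker embedding is now resolved before the encoder call so it can serve as the Mix-LN condition, and it is still added to the frame-level condition afterwards. With spk_embed_id it stays [B, 1, H], so it broadcasts against frame-level tensors in both places. Shape sketch (sizes are assumptions):

import torch

B, T, H = 2, 100, 256
spk_embed = torch.randn(B, 1, H)    # self.spk_embed(spk_embed_id)[:, None, :]
condition = torch.randn(B, T, H)    # frame-level condition built from the encoder output
condition = condition + spk_embed   # broadcasts over the time axis, as before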
17 changes: 10 additions & 7 deletions modules/fastspeech/tts_modules.py
@@ -12,13 +12,14 @@


class TransformerEncoderLayer(nn.Module):
def __init__(self, hidden_size, dropout, kernel_size=None, act='gelu', num_heads=2, rotary_embed=None):
def __init__(self, hidden_size, dropout, kernel_size=None, act='gelu', num_heads=2, rotary_embed=None, layer_idx=None, mix_ln_layer=None):
super().__init__()
self.op = EncSALayer(
hidden_size, num_heads, dropout=dropout,
attention_dropout=0.0, relu_dropout=dropout,
kernel_size=kernel_size,
act=act, rotary_embed=rotary_embed
act=act, rotary_embed=rotary_embed,
layer_idx=layer_idx, mix_ln_layer=mix_ln_layer
)

def forward(self, x, **kwargs):
@@ -369,7 +370,8 @@ def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
class FastSpeech2Encoder(nn.Module):
def __init__(self, hidden_size, num_layers,
ffn_kernel_size=9, ffn_act='gelu',
dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False, rope_interleaved=True):
dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True,
use_rope=False, rope_interleaved=True, mix_ln_layer=[]):
super().__init__()
self.num_layers = num_layers
embed_dim = self.hidden_size = hidden_size
@@ -383,9 +385,10 @@ def __init__(self, hidden_size, num_layers,
TransformerEncoderLayer(
self.hidden_size, self.dropout,
kernel_size=ffn_kernel_size, act=ffn_act,
num_heads=num_heads, rotary_embed=rotary_embed
num_heads=num_heads, rotary_embed=rotary_embed,
layer_idx=i, mix_ln_layer=mix_ln_layer
)
for _ in range(self.num_layers)
for i in range(self.num_layers)
])
self.layer_norm = nn.LayerNorm(embed_dim)

@@ -415,7 +418,7 @@ def forward_embedding(self, main_embed, extra_embed=None, padding_mask=None):
x = F.dropout(x, p=self.dropout, training=self.training)
return x

def forward(self, main_embed, extra_embed, padding_mask, attn_mask=None, return_hiddens=False):
def forward(self, main_embed, extra_embed, padding_mask, spk_embed=None, attn_mask=None, return_hiddens=False):
x = self.forward_embedding(main_embed, extra_embed, padding_mask=padding_mask) # [B, T, H]
nonpadding_mask_BT = 1 - padding_mask.float()[:, :, None] # [B, T, 1]

@@ -435,7 +438,7 @@ def forward(self, main_embed, extra_embed, padding_mask, attn_mask=None, return_
x = x * nonpadding_mask_BT
hiddens = []
for layer in self.layers:
x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_BT
x = layer(x, encoder_padding_mask=padding_mask, cond=spk_embed, attn_mask=attn_mask) * nonpadding_mask_BT
if return_hiddens:
hiddens.append(x)
x = self.layer_norm(x) * nonpadding_mask_BT
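Note: FastSpeech2Encoder now forwards spk_embed as cond to every EncSALayer, and each layer builds Mixed_LayerNorm only when its index appears in mix_ln_layer. A toy illustration of the selection rule (layer count and indices are assumptions):

mix_ln_layer = [0, 2]
for layer_idx in range(4):
    norm = 'Mixed_LayerNorm' if layer_idx in mix_ln_layer else 'LayerNorm'
    print(f'layer {layer_idx}: {norm}')
# layer 0: Mixed_LayerNorm, layer 1: LayerNorm, layer 2: Mixed_LayerNorm, layer 3: LayerNorm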