diff --git a/LICENSE.txt b/LICENSE.txt
index 3334c19..ecfe679 100644
--- a/LICENSE.txt
+++ b/LICENSE.txt
@@ -2,6 +2,7 @@ MIT License
 
 Copyright (c) 2019 fatchord (https://github.com/fatchord)
 Copyright (c) 2019 mkotha (https://github.com/mkotha)
+Copyright (c) 2019 shaojinding (https://github.com/shaojinding/GroupLatentEmbedding)
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/README.md b/README.md
index fbf48c9..53b5904 100644
--- a/README.md
+++ b/README.md
@@ -1,30 +1,23 @@
-# WaveRNN + VQ-VAE
+# Group Latent Embedding for Vector Quantized Variational Autoencoder in Non-Parallel Voice Conversion
 
-This is a Pytorch implementation of [WaveRNN](
-https://arxiv.org/abs/1802.08435v1). Currently 3 top-level networks are
-provided:
+Code for the paper [Group Latent Embedding for Vector Quantized Variational Autoencoder in Non-Parallel Voice Conversion](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/1198.pdf)
 
-* A [VQ-VAE](https://avdnoord.github.io/homepage/vqvae/) implementation with a
-  WaveRNN decoder. Trained on a multispeaker dataset of speech, it can
-  demonstrate speech reconstruction and speaker conversion.
-* A vocoder implementation. Trained on a single-speaker dataset, it can turn a
-  mel spectrogram into raw waveform.
-* An unconditioned WaveRNN. Trained on a single-speaker dataset, it can generate
-  random speech.
+Shaojin Ding, Ricardo Gutierrez-Osuna
 
-[Audio samples](https://mkotha.github.io/WaveRNN/).
+In INTERSPEECH 2019
 
-It has been tested with the following datasets.
+This is a PyTorch implementation, based on the WaveRNN + VQ-VAE implementation at [https://github.com/mkotha/WaveRNN](https://github.com/mkotha/WaveRNN).
 
-Multispeaker datasets:
+## Dataset
 
 * [VCTK](https://datashare.is.ed.ac.uk/handle/10283/2651)
+  * [Audio samples](https://shaojinding.github.io/samples/gle/gle_demo).
+  * [Trained model](https://drive.google.com/file/d/1W4lA37_susadCY5UQUPPbaqKbRNfNUW7/view?usp=sharing).
 
-Single-speaker datasets:
+## Preparation
 
-* [LJ Speech](https://keithito.com/LJ-Speech-Dataset/)
+The preparation is similar to that at [https://github.com/mkotha/WaveRNN](https://github.com/mkotha/WaveRNN). We repeat it here for convenience.
 
-## Preparation
 
 ### Requirements
 
@@ -52,27 +45,16 @@ You can skip this section if you don't need a multi-speaker dataset.
 3. In `config.py`, set `multi_speaker_data_path` to point to the output
    directory.
 
-### Preparing LJ-Speech
-
-You can skip this section if you don't need a single-speaker dataset.
-
-1. Download and uncompress [the LJ speech dataset](
-   https://keithito.com/LJ-Speech-Dataset/).
-2. `python preprocess16.py /path/to/dataset/LJSpeech-1.1/wavs
-   /path/to/output/directory`
-3. In `config.py`, set `single_speaker_data_path` to point to the output
-   directory.
 
 ## Usage
 
-`wavernn.py` is the entry point:
+To run Group Latent Embedding:
 
 ```
-$ python wavernn.py
+$ python wavernn.py -m vqvae_group --num-group 41 --num-sample 10
 ```
 
-By default, it trains a VQ-VAE model. The `-m` option can be used to tell the
-the script to train a different model.
+The `-m` option can be used to tell the script what model to train. By default, it trains a vanilla VQ-VAE model.
 
 Trained models are saved under the `model_checkpoints` directory.
 
@@ -85,47 +67,21 @@ goes under the `model_outputs` directory.
 When the `-g` option is given, the script produces the output using the saved
 model, rather than training it.
 
-# Deviations from the papers
-
-I deviated from the papers in some details, sometimes because I was lazy, and
-sometimes because I was unable to get good results without it. Below is a
-(probably incomplete) list of deviations.
-
-All models:
-
-* The sampling rate is 22.05kHz.
-
-VQ-VAE:
-
-* I normalize each latent embedding vector, so that it's on the unit 128-
-  dimensional sphere. Without this change, I was unable to get good utilization
-  of the embedding vectors.
-* In the early stage of training, I scale with a small number the penalty term
-  that apply to the input to the VQ layer. Without this, the input very often
-  collapses into a degenerate distribution which always selects the same
-  embedding vector.
-* During training, the target audio signal (which is also the input signal) is
-  translated along the time axis by a random amount, uniformly chosen from
-  [-128, 127] samples. Less importantly, some additive and multiplicative
-  Gaussian noise is also applied to each audio sample. Without these types of
-  noise, the feature captured by the model tended to be very sensitive to small
-  purterbations to the input, and the subjective quality of the model output
-  kept descreasing after a certain point in training.
-* The decoder is based on WaveRNN instead of WaveNet. See the next section for
-  details about this network.
-
-# Context stacks
-
-The VQ-VAE implementation uses a WaveRNN-based decoder instead of a WaveNet-
-based decoder found in the paper. This is a WaveRNN network augmented
-with a context stack to extend the receptive field. This network is
-defined in `layers/overtone.py`.
-
-The network has 6 convolutions with stride 2 to generate 64x downsampled
-'summary' of the waveform, and then 4 layers of upsampling RNNs, the last of
-which is the WaveRNN layer. It also has U-net-like skip connections that
-connect layers with the same operating frequency.
+`--num-group` specifies the number of groups, and `--num-sample` specifies the number of atoms in each group. Note that `num-group` times `num-sample` should equal the total number of atoms in the embedding dictionary (`n_classes` in the `VectorQuantGroup` class in `layers/vector_quant.py`).
 
 # Acknowledgement
 
-The code is based on [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN).
+The code is based on [mkotha/WaveRNN](https://github.com/mkotha/WaveRNN).
+
+# Cite the work
+```
+@inproceedings{Ding2019,
+  author={Shaojin Ding and Ricardo Gutierrez-Osuna},
+  title={{Group Latent Embedding for Vector Quantized Variational Autoencoder in Non-Parallel Voice Conversion}},
+  year=2019,
+  booktitle={Proc. Interspeech 2019},
+  pages={724--728},
+  doi={10.21437/Interspeech.2019-1198},
+  url={http://dx.doi.org/10.21437/Interspeech.2019-1198}
+}
+```
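The `--num-group`/`--num-sample` paragraph in the README above maps directly onto the quantizer added in the next file. The following is an illustrative sketch of that selection rule, not code from the repository: the function name and the flat `(B, vec_len)` shapes are assumptions, and the straight-through estimator and the loss terms returned by the real module are omitted.

```
# Illustrative sketch of the group latent embedding lookup (hypothetical helper,
# simplified shapes): pick the nearest group by mean distance, then take an
# inverse-distance-weighted sum of the num_sample closest atoms in that group.
import torch
import torch.nn.functional as F

def group_latent_embedding(x, embedding, num_group, num_sample):
    """x: (B, vec_len) encoder outputs; embedding: (n_classes, vec_len) dictionary."""
    n_classes, vec_len = embedding.shape
    atoms_per_group = n_classes // num_group

    # Distance from every encoding to every atom: (B, n_classes)
    d = torch.cdist(x, embedding)

    # Average the distances inside each group and pick the nearest group: (B,)
    d_group = d.view(-1, num_group, atoms_per_group).mean(dim=2)
    nearest_group = d_group.argmin(dim=1)

    # Global indices of the atoms belonging to the chosen group: (B, atoms_per_group)
    atom_idx = nearest_group.unsqueeze(1) * atoms_per_group + torch.arange(atoms_per_group)

    # Inverse-distance weights over the num_sample closest atoms of that group
    w, order = (1.0 / d.gather(1, atom_idx)).sort(dim=1, descending=True)
    w = F.normalize(w[:, :num_sample], p=1, dim=1)
    chosen = atom_idx.gather(1, order[:, :num_sample])

    # Output: weighted sum of the chosen atoms, (B, vec_len)
    return (embedding[chosen] * w.unsqueeze(-1)).sum(dim=1)

# e.g. a 410-atom dictionary arranged as 41 groups of 10, as in the README command:
# codes = group_latent_embedding(torch.randn(8, 128), torch.randn(410, 128), 41, 10)
```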
diff --git a/layers/vector_quant.py b/layers/vector_quant.py
index 75b454f..5ecf563 100644
--- a/layers/vector_quant.py
+++ b/layers/vector_quant.py
@@ -3,13 +3,15 @@
 import torch.nn.functional as F
 import math
 import utils.logger as logger
+import numpy as np
+from layers.attention import EmbeddingAttention
 
 class VectorQuant(nn.Module):
     """
         Input: (N, samples, n_channels, vec_len) numeric tensor
         Output: (N, samples, n_channels, vec_len) numeric tensor
     """
-    def __init__(self, n_channels, n_classes, vec_len, normalize=False):
+    def __init__(self, n_channels, n_classes, vec_len, num_group, num_sample, normalize=False):
         super().__init__()
         if normalize:
             target_scale = 0.06
@@ -66,3 +68,111 @@ def after_update(self):
             with torch.no_grad():
                 target_norm = self.embedding_scale * math.sqrt(self.embedding0.size(2))
                 self.embedding0.mul_(target_norm / self.embedding0.norm(dim=2, keepdim=True))
+
+
+class VectorQuantGroup(nn.Module):
+    """
+        Input: (N, samples, n_channels, vec_len) numeric tensor
+        Output: (N, samples, n_channels, vec_len) numeric tensor
+    """
+    def __init__(self, n_channels, n_classes, vec_len, num_group, num_sample, normalize=False):
+        super().__init__()
+        if normalize:
+            target_scale = 0.06
+            self.embedding_scale = target_scale
+            self.normalize_scale = target_scale
+        else:
+            self.embedding_scale = 1e-3
+            self.normalize_scale = None
+
+        self.n_classes = n_classes
+        self._num_group = num_group
+        self._num_sample = num_sample
+        if not self.n_classes % self._num_group == 0:
+            raise ValueError('num of embeddings in each group should be an integer')
+        self._num_classes_per_group = int(self.n_classes / self._num_group)
+
+        self.embedding0 = nn.Parameter(torch.randn(n_channels, n_classes, vec_len, requires_grad=True) * self.embedding_scale)
+        self.offset = torch.arange(n_channels).cuda() * n_classes
+        # self.offset: (n_channels) long tensor
+        self.after_update()
+
+    def forward(self, x0):
+        if self.normalize_scale:
+            target_norm = self.normalize_scale * math.sqrt(x0.size(3))
+            x = target_norm * x0 / x0.norm(dim=3, keepdim=True)
+            embedding = target_norm * self.embedding0 / self.embedding0.norm(dim=2, keepdim=True)
+        else:
+            x = x0
+            embedding = self.embedding0
+        #logger.log(f'std[x] = {x.std()}')
+        x1 = x.reshape(x.size(0) * x.size(1), x.size(2), 1, x.size(3))
+        # x1: (N*samples, n_channels, 1, vec_len) numeric tensor
+
+        # Perform chunking to avoid overflowing GPU RAM.
+        index_chunks = []
+        prob_chunks = []
+        for x1_chunk in x1.split(512, dim=0):
+            d = (x1_chunk - embedding).norm(dim=3)
+
+            # Compute the group-wise distance
+            d_group = torch.zeros(x1_chunk.shape[0], 1, self._num_group).to(torch.device('cuda'))
+            for i in range(self._num_group):
+                d_group[:, :, i] = torch.mean(
+                    d[:, :, i * self._num_classes_per_group: (i + 1) * self._num_classes_per_group], 2)
+
+            # Find the nearest group
+            index_chunk_group = d_group.argmin(dim=2)
+
+            # Generate mask for the nearest group
+            index_chunk_group = index_chunk_group.repeat(1, self._num_classes_per_group)
+            index_chunk_group = torch.mul(self._num_classes_per_group, index_chunk_group)
+            idx_mtx = torch.LongTensor([x for x in range(self._num_classes_per_group)]).unsqueeze(0).cuda()
+            index_chunk_group += idx_mtx
+            encoding_mask = torch.zeros(x1_chunk.shape[0], self.n_classes).cuda()
+            encoding_mask.scatter_(1, index_chunk_group, 1)
+
+            # Compute the weights of the atoms in the group
+            encoding_prob = torch.div(1, d.squeeze())
+
+            # Apply the mask
+            masked_encoding_prob = torch.mul(encoding_mask, encoding_prob)
+            p, idx = masked_encoding_prob.sort(dim=1, descending=True)
+            prob_chunks.append(p[:, :self._num_sample])
+            index_chunks.append(idx[:, :self._num_sample])
+
+
+
+        index = torch.cat(index_chunks, dim=0)
+        prob_dist = torch.cat(prob_chunks, dim=0)
+        prob_dist = F.normalize(prob_dist, p=1, dim=1)
+        # index: (N*samples, n_channels) long tensor
+        if True: # compute the entropy
+            hist = index[:, 0].float().cpu().histc(bins=self.n_classes, min=-0.5, max=self.n_classes - 0.5)
+            prob = hist.masked_select(hist > 0) / len(index)
+            entropy = - (prob * prob.log()).sum().item()
+            #logger.log(f'entropy: {entropy:#.4}/{math.log(self.n_classes):#.4}')
+        else:
+            entropy = 0
+        index1 = (index + self.offset)
+        # index1: (N*samples*n_channels) long tensor
+        output_list = []
+        for i in range(self._num_sample):
+            output_list.append(torch.mul(embedding.view(-1, embedding.size(2)).index_select(dim=0, index=index1[:, i]), prob_dist[:, i].unsqueeze(1).detach()))
+
+        output_cat = torch.stack(output_list, dim=2)
+        output_flat = torch.sum(output_cat, dim=2)
+        # output_flat: (N*samples*n_channels, vec_len) numeric tensor
+        output = output_flat.view(x.size())
+
+        out0 = (output - x).detach() + x
+        out1 = (x.detach() - output).float().norm(dim=3).pow(2)
+        out2 = (x - output.detach()).float().norm(dim=3).pow(2) + (x - x0).float().norm(dim=3).pow(2)
+        #logger.log(f'std[embedding0] = {self.embedding0.view(-1, embedding.size(2)).index_select(dim=0, index=index1).std()}')
+        return (out0, out1, out2, entropy)
+
+    def after_update(self):
+        if self.normalize_scale:
+            with torch.no_grad():
+                target_norm = self.embedding_scale * math.sqrt(self.embedding0.size(2))
+                self.embedding0.mul_(target_norm / self.embedding0.norm(dim=2, keepdim=True))
\ No newline at end of file
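Before the new quantizer is wired into the model (next diff), here is a hypothetical smoke test, not part of the repository: shapes follow the class docstring, and a CUDA device is required because the module calls `.cuda()` internally.

```
# Hypothetical smoke test for VectorQuantGroup (410 atoms = 41 groups x 10 atoms,
# matching the README example); input and output are (N, samples, n_channels, vec_len).
import torch
from layers.vector_quant import VectorQuantGroup

vq = VectorQuantGroup(n_channels=1, n_classes=410, vec_len=128,
                      num_group=41, num_sample=10, normalize=True).cuda()
x = torch.randn(2, 25, 1, 128).cuda()
out, vq_pen, encoder_pen, entropy = vq(x)
print(out.shape)  # torch.Size([2, 25, 1, 128])
```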
diff --git a/models/vqvae.py b/models/vqvae.py
index a3d006a..71df9b2 100644
--- a/models/vqvae.py
+++ b/models/vqvae.py
@@ -10,19 +10,31 @@ import sys
 import time
 from layers.overtone import Overtone
-from layers.vector_quant import VectorQuant
+from layers.vector_quant import *
 from layers.downsampling_encoder import DownsamplingEncoder
 import utils.env as env
 import utils.logger as logger
 import random
+from layers.singular_loss import SingularLoss
+
+__model_factory = {
+    'vqvae': VectorQuant,
+    'vqvae_group': VectorQuantGroup,
+}
+
+def init_vq(name, *args, **kwargs):
+    if name not in list(__model_factory.keys()):
+        raise KeyError("Unknown models: {}".format(name))
+    return __model_factory[name](*args, **kwargs)
 
 
 class Model(nn.Module) :
-    def __init__(self, rnn_dims, fc_dims, global_decoder_cond_dims, upsample_factors, normalize_vq=False,
-            noise_x=False, noise_y=False):
+    def __init__(self, model_type, rnn_dims, fc_dims, global_decoder_cond_dims, upsample_factors, num_group, num_sample,
+                 normalize_vq=False, noise_x=False, noise_y=False):
         super().__init__()
-        self.n_classes = 256
+        # self.n_classes = 256
         self.overtone = Overtone(rnn_dims, fc_dims, 128, global_decoder_cond_dims)
-        self.vq = VectorQuant(1, 512, 128, normalize=normalize_vq)
+        # self.vq = VectorQuant(1, 410, 128, normalize=normalize_vq)
+        self.vq = init_vq(model_type, 1, 410, 128, num_group, num_sample, normalize=normalize_vq)
         self.noise_x = noise_x
         self.noise_y = noise_y
         encoder_layers = [
@@ -125,15 +137,15 @@ def pad_right(self):
     def total_scale(self):
         return self.encoder.total_scale
 
-    def do_train(self, paths, dataset, optimiser, epochs, batch_size, step, lr=1e-4, valid_index=[], use_half=False, do_clip=False):
+    def do_train(self, paths, dataset, optimiser, writer, epochs, test_epochs, batch_size, step, epoch, valid_index=[], use_half=False, do_clip=False, beta=0.):
         if use_half:
             import apex
             optimiser = apex.fp16_utils.FP16_Optimizer(optimiser, dynamic_loss_scale=True)
-        for p in optimiser.param_groups : p['lr'] = lr
+        # for p in optimiser.param_groups : p['lr'] = lr
         criterion = nn.NLLLoss().cuda()
-        k = 0
-        saved_k = 0
+        # k = 0
+        # saved_k = 0
         pad_left = self.pad_left()
         pad_left_encoder = self.pad_left_encoder()
         pad_left_decoder = self.pad_left_decoder()
@@ -145,7 +157,7 @@ def do_train(self, paths, dataset, optimiser, epochs, batch_size, step, lr=1e-4,
         window = 16 * self.total_scale()
         logger.log(f'pad_left={pad_left_encoder}|{pad_left_decoder}, pad_right={pad_right}, total_scale={self.total_scale()}')
 
-        for e in range(epochs) :
+        for e in range(epoch, epochs) :
 
             trn_loader = DataLoader(dataset, collate_fn=lambda batch: env.collate_multispeaker_samples(pad_left, window, pad_right, batch),
                                     batch_size=batch_size, num_workers=2, shuffle=True, pin_memory=True)
@@ -161,7 +173,7 @@ def do_train(self, paths, dataset, optimiser, epochs, batch_size, step, lr=1e-4,
 
             iters = len(trn_loader)
 
-            for i, (speaker, wave16) in enumerate(trn_loader) :
+            for i, (speaker, wave16) in enumerate(trn_loader):
 
                 speaker = speaker.cuda()
                 wave16 = wave16.cuda()
@@ -259,20 +271,124 @@ def do_train(self, paths, dataset, optimiser, epochs, batch_size, step, lr=1e-4,
                 step += 1
                 k = step // 1000
-                logger.status(f'Epoch: {e+1}/{epochs} -- Batch: {i+1}/{iters} -- Loss: c={avg_loss_c:#.4} f={avg_loss_f:#.4} vq={avg_loss_vq:#.4} vqc={avg_loss_vqc:#.4} -- Entropy: {avg_entropy:#.4} -- Grad: {running_max_grad:#.1} {running_max_grad_name} Speed: {speed:#.4} steps/sec -- Step: {k}k ')
+                logger.status(f'[Training] Epoch: {e+1}/{epochs} -- Batch: {i+1}/{iters} -- Loss: c={avg_loss_c:#.4} f={avg_loss_f:#.4} vq={avg_loss_vq:#.4} vqc={avg_loss_vqc:#.4} -- Entropy: {avg_entropy:#.4} -- Grad: {running_max_grad:#.1} {running_max_grad_name} Speed: {speed:#.4} steps/sec -- Step: {k}k ')
+
+                # tensorboard writer
+                writer.add_scalars('Train/loss_group', {'loss_c': loss_c.item(),
+                                                        'loss_f': loss_f.item(),
+                                                        'vq': vq_pen.item(),
+                                                        'vqc': encoder_pen.item(),
+                                                        'entropy': entropy,}, step - 1)
 
             os.makedirs(paths.checkpoint_dir, exist_ok=True)
-            torch.save(self.state_dict(), paths.model_path())
-            np.save(paths.step_path(), step)
+            torch.save({'epoch': e,
+                        'state_dict': self.state_dict(),
+                        'optimiser': optimiser.state_dict(),
+                        'step': step},
+                       paths.model_path())
+            # torch.save(self.state_dict(), paths.model_path())
+            # np.save(paths.step_path(), step)
             logger.log_current_status()
-            logger.log(f' ; w[0][0] = {self.overtone.wavernn.gru.weight_ih_l0[0][0]}')
-            if k > saved_k + 50:
-                torch.save(self.state_dict(), paths.model_hist_path(step))
-                saved_k = k
-                self.do_generate(paths, step, dataset.path, valid_index)
+            # logger.log(f' ; w[0][0] = {self.overtone.wavernn.gru.weight_ih_l0[0][0]}')
+
+            if e % test_epochs == 0:
+                torch.save({'epoch': e,
+                            'state_dict': self.state_dict(),
+                            'optimiser': optimiser.state_dict(),
+                            'step': step},
+                           paths.model_hist_path(step))
+                self.do_test(writer, e, step, dataset.path, valid_index)
+                self.do_test_generate(paths, step, dataset.path, valid_index)
+            # if k > saved_k + 50:
+            #     torch.save({'epoch': e,
+            #                 'state_dict': self.state_dict(),
+            #                 'optimiser': optimiser.state_dict(),
+            #                 'step': step},
+            #                paths.model_hist_path(step))
+            #     # torch.save(self.state_dict(), paths.model_hist_path(step))
+            #     saved_k = k
+            #     self.do_generate(paths, step, dataset.path, valid_index)
+
+    def do_test(self, writer, epoch, step, data_path, test_index):
+        dataset = env.MultispeakerDataset(test_index, data_path)
+        criterion = nn.NLLLoss().cuda()
+        # k = 0
+        # saved_k = 0
+        pad_left = self.pad_left()
+        pad_left_encoder = self.pad_left_encoder()
+        pad_left_decoder = self.pad_left_decoder()
+        extra_pad_right = 0
+        pad_right = self.pad_right() + extra_pad_right
+        window = 16 * self.total_scale()
+
+        test_loader = DataLoader(dataset, collate_fn=lambda batch: env.collate_multispeaker_samples(pad_left, window, pad_right, batch),
+                                 batch_size=16, num_workers=2, shuffle=False, pin_memory=True)
+
+        running_loss_c = 0.
+        running_loss_f = 0.
+        running_loss_vq = 0.
+        running_loss_vqc = 0.
+        running_entropy = 0.
+        running_max_grad = 0.
+        running_max_grad_name = ""
+
+        for i, (speaker, wave16) in enumerate(test_loader):
+            speaker = speaker.cuda()
+            wave16 = wave16.cuda()
+
+            coarse = (wave16 + 2 ** 15) // 256
+            fine = (wave16 + 2 ** 15) % 256
+
+            coarse_f = coarse.float() / 127.5 - 1.
+            fine_f = fine.float() / 127.5 - 1.
+            total_f = (wave16.float() + 0.5) / 32767.5
+
+            noisy_f = total_f
+
+            x = torch.cat([
+                coarse_f[:, pad_left - pad_left_decoder:-pad_right].unsqueeze(-1),
+                fine_f[:, pad_left - pad_left_decoder:-pad_right].unsqueeze(-1),
+                coarse_f[:, pad_left - pad_left_decoder + 1:1 - pad_right].unsqueeze(-1),
+            ], dim=2)
+            y_coarse = coarse[:, pad_left + 1:1 - pad_right]
+            y_fine = fine[:, pad_left + 1:1 - pad_right]
+
+            translated = noisy_f[:, pad_left - pad_left_encoder:]
+
+            p_cf, vq_pen, encoder_pen, entropy = self(speaker, x, translated)
+            p_c, p_f = p_cf
+            loss_c = criterion(p_c.transpose(1, 2).float(), y_coarse)
+            loss_f = criterion(p_f.transpose(1, 2).float(), y_fine)
+            # encoder_weight = 0.01 * min(1, max(0.1, step / 1000 - 1))
+            # loss = loss_c + loss_f + vq_pen + encoder_weight * encoder_pen
+
+            running_loss_c += loss_c.item()
+            running_loss_f += loss_f.item()
+            running_loss_vq += vq_pen.item()
+            running_loss_vqc += encoder_pen.item()
+            running_entropy += entropy
+
+        avg_loss_c = running_loss_c / (i + 1)
+        avg_loss_f = running_loss_f / (i + 1)
+        avg_loss_vq = running_loss_vq / (i + 1)
+        avg_loss_vqc = running_loss_vqc / (i + 1)
+        avg_entropy = running_entropy / (i + 1)
 
-    def do_generate(self, paths, step, data_path, test_index, deterministic=False, use_half=False, verbose=False):
         k = step // 1000
+        logger.log(
+            f'[Testing] Epoch: {epoch} -- Loss: c={avg_loss_c:#.4} f={avg_loss_f:#.4} vq={avg_loss_vq:#.4} vqc={avg_loss_vqc:#.4} -- Entropy: {avg_entropy:#.4} -- Grad: {running_max_grad:#.1} {running_max_grad_name} -- Step: {k}k ')
+
+        # tensorboard writer
+        writer.add_scalars('Test/loss_group', {'loss_c': avg_loss_c,
+                                               'loss_f': avg_loss_f,
+                                               'vq': avg_loss_vq,
+                                               'vqc': avg_loss_vqc,
+                                               'entropy': avg_entropy, }, step - 1)
+
+
+    def do_test_generate(self, paths, step, data_path, test_index, deterministic=False, use_half=False, verbose=False):
+        k = step // 1000
+        test_index = [x[:2] if len(x) > 0 else [] for i, x in enumerate(test_index)]
         dataset = env.MultispeakerDataset(test_index, data_path)
         loader = DataLoader(dataset, shuffle=False)
         data = [x for x in loader]
@@ -284,6 +400,7 @@ def do_generate(self, paths, step, data_path, test_index, deterministic=False, u
         aligned = [torch.cat([torch.FloatTensor(x), torch.zeros(maxlen-len(x))]) for x in extended]
         os.makedirs(paths.gen_path(), exist_ok=True)
         out = self.forward_generate(torch.stack(speakers + list(reversed(speakers)), dim=0).cuda(), torch.stack(aligned + aligned, dim=0).cuda(), verbose=verbose, use_half=use_half)
+        logger.log(f'out: {out.size()}')
 
         for i, x in enumerate(gt) :
             librosa.output.write_wav(f'{paths.gen_path()}/{k}k_steps_{i}_target.wav', x.cpu().numpy(), sr=sample_rate)
@@ -291,3 +408,49 @@ def do_generate(self, paths, step, data_path, test_index, deterministic=False, u
             librosa.output.write_wav(f'{paths.gen_path()}/{k}k_steps_{i}_generated.wav', audio, sr=sample_rate)
             audio_tr = out[n_points+i][:len(x)].cpu().numpy()
             librosa.output.write_wav(f'{paths.gen_path()}/{k}k_steps_{i}_transferred.wav', audio_tr, sr=sample_rate)
+
+
+
+    def do_generate(self, paths, step, data_path, test_index, deterministic=False, use_half=False, verbose=False):
+        k = step // 1000
+        test_index = [x[:10] if len(x) > 0 else [] for i, x in enumerate(test_index)]
+        test_index[0] = []
+        test_index[1] = []
+        test_index[2] = []
+        # test_index[3] = []
+
+        dataset = env.MultispeakerDataset(test_index, data_path)
+        loader = DataLoader(dataset, shuffle=False)
+        data = [x for x in loader]
+        n_points = len(data)
+        gt = [(x[0].float() + 0.5) / (2**15 - 0.5) for speaker, x in data]
+        extended = [np.concatenate([np.zeros(self.pad_left_encoder(), dtype=np.float32), x, np.zeros(self.pad_right(), dtype=np.float32)]) for x in gt]
+        speakers = [torch.FloatTensor(speaker[0].float()) for speaker, x in data]
+
+        vc_speakers = [torch.FloatTensor((np.arange(30) == 1).astype(np.float)) for _ in range(10)]
+        # vc_speakers = [torch.FloatTensor((np.arange(30) == 14).astype(np.float)) for _ in range(20)]
+        # vc_speakers = [torch.FloatTensor((np.arange(30) == 23).astype(np.float)) for _ in range(20)]
+        # vc_speakers = [torch.FloatTensor((np.arange(30) == 4).astype(np.float)) for _ in range(20)]
+        maxlen = max([len(x) for x in extended])
+        aligned = [torch.cat([torch.FloatTensor(x), torch.zeros(maxlen-len(x))]) for x in extended]
+        os.makedirs(paths.gen_path(), exist_ok=True)
+        # out = self.forward_generate(torch.stack(speakers + list(reversed(speakers)), dim=0).cuda(), torch.stack(aligned + aligned, dim=0).cuda(), verbose=verbose, use_half=use_half)
+        out = self.forward_generate(torch.stack(vc_speakers, dim=0).cuda(),
+                                    torch.stack(aligned, dim=0).cuda(), verbose=verbose, use_half=use_half)
+        logger.log(f'out: {out.size()}')
+        # for i, x in enumerate(gt) :
+        #     librosa.output.write_wav(f'{paths.gen_path()}/{k}k_steps_{i}_target.wav', x.cpu().numpy(), sr=sample_rate)
+        #     audio = out[i][:len(x)].cpu().numpy()
+        #     librosa.output.write_wav(f'{paths.gen_path()}/{k}k_steps_{i}_generated.wav', audio, sr=sample_rate)
+        #     audio_tr = out[n_points+i][:len(x)].cpu().numpy()
+        #     librosa.output.write_wav(f'{paths.gen_path()}/{k}k_steps_{i}_transferred.wav', audio_tr, sr=sample_rate)
+
+        for i, x in enumerate(gt):
+            # librosa.output.write_wav(f'{paths.gen_path()}/gsb_{i+1:04d}.wav', x.cpu().numpy(), sr=sample_rate)
+            # librosa.output.write_wav(f'{paths.gen_path()}/gt_gsb_{i+1:03d}.wav', x.cpu().numpy(), sr=sample_rate)
+            # audio = out[i][:len(x)].cpu().numpy()
+            # librosa.output.write_wav(f'{paths.gen_path()}/{k}k_steps_{i}_generated.wav', audio, sr=sample_rate)
+            # audio_tr = out[n_points+i][:len(x)].cpu().numpy()
+            audio_tr = out[i][:self.pad_left_encoder() + len(x)].cpu().numpy()
+            # librosa.output.write_wav(f'{paths.gen_path()}/{k}k_steps_{i}_transferred.wav', audio_tr, sr=sample_rate)
+            librosa.output.write_wav(f'{paths.gen_path()}/gsb_{i + 1:04d}.wav', audio_tr, sr=sample_rate)
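One thing worth noting about the rewritten `do_generate` above: the voice-conversion target is hard-coded, so every selected utterance is converted to speaker 1 out of 30 (the commented-out lines show other targets that were tried). A hypothetical helper for choosing a different target, not part of the repository, could look like this:

```
# Hypothetical helper: build the one-hot speaker conditions that do_generate
# currently hard-codes for speaker index 1.
import numpy as np
import torch

def one_hot_speakers(target_id, num_speakers, n_utterances):
    onehot = (np.arange(num_speakers) == target_id).astype(np.float32)
    return [torch.FloatTensor(onehot) for _ in range(n_utterances)]

# e.g. convert the selected test utterances to speaker 14 instead of speaker 1:
# vc_speakers = one_hot_speakers(14, 30, 10)
```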
diff --git a/preprocess_multispeaker.py b/preprocess_multispeaker.py
index 3b179e2..06e3fc6 100644
--- a/preprocess_multispeaker.py
+++ b/preprocess_multispeaker.py
@@ -9,17 +9,14 @@ DATA_PATH = sys.argv[2]
 
 def get_files(path):
-    next_speaker_id = 0
-    speaker_ids = {}
     filenames = []
-    for filename in glob.iglob(f'{path}/**/*.wav', recursive=True):
-        speaker_name = filename.split('/')[-2]
-        if speaker_name not in speaker_ids:
-            speaker_ids[speaker_name] = next_speaker_id
-            next_speaker_id += 1
-            filenames.append([])
-        filenames[speaker_ids[speaker_name]].append(filename)
-
+    speakers = sorted(os.listdir(path))
+    for speaker in speakers:
+        filenames_speaker = []
+        files = sorted(os.listdir(f'{path}/{speaker}'))
+        for file in files:
+            filenames_speaker.append(os.path.join(path, speaker, file))
+        filenames.append(filenames_speaker)
     return filenames
 
 files = get_files(SEG_PATH)
diff --git a/utils/env.py b/utils/env.py
index 6b90aac..d9543f0 100644
--- a/utils/env.py
+++ b/utils/env.py
@@ -106,15 +106,20 @@ def collate(left_pad, mel_win, right_pad, batch) :
 
     return mels, coarse, fine, coarse_f, fine_f
 
-def restore(path, model):
-    model.load_state_dict(torch.load(path))
-
-    match = re.search(r'_([0-9]+)\.pyt', path)
-    if match:
-        return int(match.group(1))
-
-    step_path = re.sub(r'\.pyt', '_step.npy', path)
-    return np.load(step_path)
+def restore(path, model, optimiser):
+    checkpoint = torch.load(path)
+    model.load_state_dict(checkpoint['state_dict'])
+    optimiser.load_state_dict(checkpoint['optimiser'])
+    step = checkpoint['step']
+    epoch = checkpoint['epoch'] + 1
+
+    # match = re.search(r'_([0-9]+)\.pyt', path)
+    # if match:
+    #     return int(match.group(1))
+    #
+    # step_path = re.sub(r'\.pyt', '_step.npy', path)
+    # return np.load(step_path)
+    return step, epoch
 
 if __name__ == '__main__':
     import pickle
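With `restore()` now returning `(step, epoch)` from a checkpoint dict, resuming also restores the optimiser state; `wavernn.py` (next diff) wires this in. A rough usage sketch follows; the checkpoint name is a placeholder, not an actual file from the repository:

```
# Resume from the default checkpoint in model_checkpoints/ (picked up
# automatically when it exists), restore a specific checkpoint with --load,
# or ignore saved state with --scratch:
$ python wavernn.py -m vqvae_group --num-group 41 --num-sample 10 --load model_checkpoints/<checkpoint>.pyt
$ python wavernn.py -m vqvae_group --num-group 41 --num-sample 10 --scratch
```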
diff --git a/wavernn.py b/wavernn.py
index 66adc54..75c0d7e 100644
--- a/wavernn.py
+++ b/wavernn.py
@@ -18,6 +18,7 @@ import utils.logger as logger
 import time
 import subprocess
 
+from tensorboardX import SummaryWriter
 
 import config
 
@@ -31,6 +32,14 @@ parser.add_argument('--force', action='store_true', help='skip the version check')
 parser.add_argument('--count', '-c', type=int, default=3, help='size of the test set')
 parser.add_argument('--partial', action='append', default=[], help='model to partially load')
+parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, help="initial learning rate")
+parser.add_argument('--weight-decay', default=1e-04, type=float, help="weight decay (default: 1e-04)")
+parser.add_argument('--batch-size', type=int, default=48, metavar='N', help='input batch size for training (default: 48)')
+parser.add_argument('--beta', type=float, default=0., help='the beta of singular loss')
+parser.add_argument('--epochs', type=int, default=1000, help='epochs in training')
+parser.add_argument('--test-epochs', type=int, default=200, help='testing every X epochs')
+parser.add_argument('--num-group', type=int, default=8, help='num of groups in dictionary')
+parser.add_argument('--num-sample', type=int, default=1, help='num of Monte Carlo samples')
 
 args = parser.parse_args()
 
 if args.float and args.half:
@@ -47,9 +56,9 @@
 
 model_name = f'{model_type}.43.upconv'
 
-if model_type == 'vqvae':
-    model_fn = lambda dataset: vqvae.Model(rnn_dims=896, fc_dims=896, global_decoder_cond_dims=dataset.num_speakers(),
-                                           upsample_factors=(4, 4, 4), normalize_vq=True, noise_x=True, noise_y=True).cuda()
+if model_type[:5] == 'vqvae':
+    model_fn = lambda dataset: vqvae.Model(model_type=model_type, rnn_dims=896, fc_dims=896, global_decoder_cond_dims=dataset.num_speakers(),
+                                           upsample_factors=(4, 4, 4), num_group=args.num_group, num_sample=args.num_sample, normalize_vq=True, noise_x=True, noise_y=True).cuda()
     dataset_type = 'multi'
 elif model_type == 'wavernn':
     model_fn = lambda dataset: wr.Model(rnn_dims=896, fc_dims=896, pad=2,
@@ -65,8 +74,8 @@
     data_path = config.multi_speaker_data_path
     with open(f'{data_path}/index.pkl', 'rb') as f:
         index = pickle.load(f)
-    test_index = [x[-1:] if i < 2 * args.count else [] for i, x in enumerate(index)]
-    train_index = [x[:-1] if i < args.count else x for i, x in enumerate(index)]
+    test_index = [x[:30] if i < args.count else [] for i, x in enumerate(index)]
+    train_index = [x[30:] if i < args.count else x for i, x in enumerate(index)]
     dataset = env.MultispeakerDataset(train_index, data_path)
 elif dataset_type == 'single':
     data_path = config.single_speaker_data_path
@@ -88,11 +97,14 @@
 for partial_path in args.partial:
     model.load_state_dict(torch.load(partial_path), strict=False)
 
+optimiser = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+
 paths = env.Paths(model_name, data_path)
 
 if args.scratch or args.load == None and not os.path.exists(paths.model_path()):
     # Start from scratch
     step = 0
+    epoch = 0
 else:
     if args.load:
         prev_model_name = re.sub(r'_[0-9]+$', '', re.sub(r'\.pyt$', '', os.path.basename(args.load)))
@@ -105,11 +117,11 @@
         prev_path = args.load
     else:
         prev_path = paths.model_path()
-    step = env.restore(prev_path, model)
+    step, epoch = env.restore(prev_path, model, optimiser)
 
 #model.freeze_encoder()
 
-optimiser = optim.Adam(model.parameters())
+
 
 if args.generate:
     model.do_generate(paths, step, data_path, test_index, use_half=use_half, verbose=True)#, deterministic=True)
@@ -118,4 +130,12 @@
     logger.log('------------------------------------------------------------')
     logger.log('-- New training session starts here ------------------------')
     logger.log(time.strftime('%c UTC', time.gmtime()))
-    model.do_train(paths, dataset, optimiser, epochs=1000, batch_size=16, step=step, lr=1e-4, use_half=use_half, valid_index=test_index)
+    logger.log('beta={}'.format(args.beta))
+    logger.log('num_group={}'.format(args.num_group))
+    logger.log('count={}'.format(args.count))
+    logger.log('num_sample={}'.format(args.num_sample))
+    writer = SummaryWriter(paths.logfile_path() + '_tensorboard')
+    writer.add_scalars('Params/Train', {'beta': args.beta})
+    writer.add_scalars('Params/Train', {'num_group': args.num_group})
+    writer.add_scalars('Params/Train', {'num_sample': args.num_sample})
+    model.do_train(paths, dataset, optimiser, writer, epochs=args.epochs, test_epochs=args.test_epochs, batch_size=args.batch_size, step=step, epoch=epoch, use_half=use_half, valid_index=test_index, beta=args.beta)
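Taken together, a typical run with the flags introduced in this patch looks roughly like the following. The log directory is a placeholder (it is derived from `paths.logfile_path() + '_tensorboard'`), and `-g` is the generation switch documented in the README above:

```
# Train Group Latent Embedding: 41 groups x 10 atoms = 410 dictionary atoms
$ python wavernn.py -m vqvae_group --num-group 41 --num-sample 10 --batch-size 48 --epochs 1000 --test-epochs 200

# Monitor the scalars written by tensorboardX
$ tensorboard --logdir <log directory>

# Generate / convert from the saved checkpoint instead of training
$ python wavernn.py -m vqvae_group --num-group 41 --num-sample 10 -g
```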