1 change: 1 addition & 0 deletions LICENSE.txt
@@ -2,6 +2,7 @@ MIT License

Copyright (c) 2019 fatchord (https://github.com/fatchord)
Copyright (c) 2019 mkotha (https://github.com/mkotha)
Copyright (c) 2019 shaojinding (https://github.com/shaojinding/GroupLatentEmbedding)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
100 changes: 28 additions & 72 deletions README.md
@@ -1,30 +1,23 @@
# WaveRNN + VQ-VAE
# Group Latent Embedding for Vector Quantized Variational Autoencoder in Non-Parallel Voice Conversion

This is a PyTorch implementation of [WaveRNN](
https://arxiv.org/abs/1802.08435v1). Currently 3 top-level networks are
provided:
Code for the paper [Group Latent Embedding for Vector Quantized Variational Autoencoder in Non-Parallel Voice Conversion](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/1198.pdf)

* A [VQ-VAE](https://avdnoord.github.io/homepage/vqvae/) implementation with a
WaveRNN decoder. Trained on a multispeaker dataset of speech, it can
demonstrate speech reconstruction and speaker conversion.
* A vocoder implementation. Trained on a single-speaker dataset, it can turn a
mel spectrogram into raw waveform.
* An unconditioned WaveRNN. Trained on a single-speaker dataset, it can generate
random speech.
Shaojin Ding, Ricardo Gutierrez-Osuna

[Audio samples](https://mkotha.github.io/WaveRNN/).
In INTERSPEECH 2019

It has been tested with the following datasets.
This is a PyTorch implementation, based on the WaveRNN + VQ-VAE implementation at [https://github.com/mkotha/WaveRNN](https://github.com/mkotha/WaveRNN).

Multispeaker datasets:
## Dataset:

* [VCTK](https://datashare.is.ed.ac.uk/handle/10283/2651)
* [Audio samples](https://shaojinding.github.io/samples/gle/gle_demo).
* [Trained model](https://drive.google.com/file/d/1W4lA37_susadCY5UQUPPbaqKbRNfNUW7/view?usp=sharing).

Single-speaker datasets:
## Preparation

* [LJ Speech](https://keithito.com/LJ-Speech-Dataset/)
The preparation is similar to that at [https://github.com/mkotha/WaveRNN](https://github.com/mkotha/WaveRNN). We repeat it here for convenience.

## Preparation

### Requirements

@@ -52,27 +45,16 @@ You can skip this section if you don't need a multi-speaker dataset.
3. In `config.py`, set `multi_speaker_data_path` to point to the output
directory.
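
For reference, the corresponding line in `config.py` might look like the sketch below (the path is a placeholder for your own preprocessing output directory):

```
# config.py -- placeholder path; point this at your preprocessing output directory
multi_speaker_data_path = '/path/to/output/directory'
```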

### Preparing LJ-Speech

You can skip this section if you don't need a single-speaker dataset.

1. Download and uncompress [the LJ speech dataset](
https://keithito.com/LJ-Speech-Dataset/).
2. `python preprocess16.py /path/to/dataset/LJSpeech-1.1/wavs
/path/to/output/directory`
3. In `config.py`, set `single_speaker_data_path` to point to the output
directory.

## Usage

`wavernn.py` is the entry point:
To run Group Latent Embedding:

```
$ python wavernn.py
$ python wavernn.py -m vqvae_group --num-group 41 --num-sample 10
```

By default, it trains a VQ-VAE model. The `-m` option can be used to tell
the script to train a different model.
The `-m` option can be used to tell the script which model to train. By default, it trains a vanilla VQ-VAE model.

Trained models are saved under the `model_checkpoints` directory.

@@ -85,47 +67,21 @@ goes under the `model_outputs` directory.
When the `-g` option is given, the script produces the output using the saved
model, rather than training it.
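
For example, generating output from the most recent saved Group Latent Embedding checkpoint could be invoked roughly as follows (a sketch; any additional flags, such as which checkpoint or utterances to use, depend on the script version):

```
$ python wavernn.py -m vqvae_group -g
```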

# Deviations from the papers

I deviated from the papers in some details, sometimes because I was lazy, and
sometimes because I was unable to get good results otherwise. Below is a
(probably incomplete) list of deviations.

All models:

* The sampling rate is 22.05kHz.

VQ-VAE:

* I normalize each latent embedding vector, so that it's on the unit 128-
dimensional sphere. Without this change, I was unable to get good utilization
of the embedding vectors.
* In the early stage of training, I scale the penalty term that applies to the
input of the VQ layer by a small factor. Without this, the input very often
collapses into a degenerate distribution which always selects the same
embedding vector.
* During training, the target audio signal (which is also the input signal) is
translated along the time axis by a random amount, uniformly chosen from
[-128, 127] samples. Less importantly, some additive and multiplicative
Gaussian noise is also applied to each audio sample. Without these types of
noise, the features captured by the model tended to be very sensitive to small
perturbations of the input, and the subjective quality of the model output
kept decreasing after a certain point in training. A sketch of this
augmentation is given after this list.
* The decoder is based on WaveRNN instead of WaveNet. See the next section for
details about this network.
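
The following is a minimal sketch of the input augmentation described above, in PyTorch. The function name, noise levels, and the use of a circular shift are illustrative assumptions; the actual implementation lives in this repository's training code.

```
import torch

def augment_waveform(wav, max_shift=128, mult_std=0.02, add_std=0.001):
    # wav: 1-D float tensor of audio samples.
    # Random time shift, uniform over [-128, 127] samples (circular here for brevity).
    shift = int(torch.randint(-max_shift, max_shift, (1,)))
    wav = torch.roll(wav, shifts=shift, dims=0)
    # Multiplicative and additive Gaussian noise; the std values are illustrative guesses.
    wav = wav * (1.0 + mult_std * torch.randn_like(wav))
    wav = wav + add_std * torch.randn_like(wav)
    return wav
```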

# Context stacks

The VQ-VAE implementation uses a WaveRNN-based decoder instead of the
WaveNet-based decoder found in the paper. This is a WaveRNN network augmented
with a context stack to extend the receptive field. This network is
defined in `layers/overtone.py`.

The network has 6 convolutions with stride 2 to generate 64x downsampled
'summary' of the waveform, and then 4 layers of upsampling RNNs, the last of
which is the WaveRNN layer. It also has U-net-like skip connections that
connect layers with the same operating frequency.
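
A minimal sketch of the downsampling path described above is given below; it is illustrative only (the channel width and kernel size are assumptions), and the real network in `layers/overtone.py` adds the upsampling RNNs, the WaveRNN layer, and the skip connections.

```
import torch
from torch import nn

# Six stride-2 convolutions: each halves the time resolution, so the output is a
# 2**6 = 64x downsampled 'summary' of the input waveform.
layers = []
channels = 64  # illustrative width
for i in range(6):
    layers += [nn.Conv1d(1 if i == 0 else channels, channels,
                         kernel_size=4, stride=2, padding=1), nn.ReLU()]
downsample = nn.Sequential(*layers)

x = torch.randn(1, 1, 4096)      # (batch, channels, samples)
print(downsample(x).shape)       # torch.Size([1, 64, 64])
```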
`--num-group` specifies the number of groups, and `--num-sample` specifies the number of atoms in each group. Note that `num-group` times `num-sample` should equal the total number of atoms in the embedding dictionary (`n_classes` in the `VectorQuantGroup` class in `layers/vector_quant.py`).
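
As a concrete check, the example command above uses 41 groups with 10 atoms each, i.e. 41 x 10 = 410 atoms in total, which must match `n_classes`. A small illustrative snippet of the same relationship:

```
num_group, num_sample = 41, 10        # values from the example command above
n_classes = num_group * num_sample    # total atoms in the embedding dictionary: 410

# VectorQuantGroup also requires n_classes to be divisible by num_group,
# so that every group holds the same number of atoms.
assert n_classes % num_group == 0
```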

# Acknowledgement

The code is based on [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN).
The code is based on [mkotha/WaveRNN](https://github.com/mkotha/WaveRNN).

# Cite the work
```
@inproceedings{Ding2019,
author={Shaojin Ding and Ricardo Gutierrez-Osuna},
title={{Group Latent Embedding for Vector Quantized Variational Autoencoder in Non-Parallel Voice Conversion}},
year=2019,
booktitle={Proc. Interspeech 2019},
pages={724--728},
doi={10.21437/Interspeech.2019-1198},
url={http://dx.doi.org/10.21437/Interspeech.2019-1198}
}
```
112 changes: 111 additions & 1 deletion layers/vector_quant.py
@@ -3,13 +3,15 @@
import torch.nn.functional as F
import math
import utils.logger as logger
import numpy as np
from layers.attention import EmbeddingAttention

class VectorQuant(nn.Module):
"""
Input: (N, samples, n_channels, vec_len) numeric tensor
Output: (N, samples, n_channels, vec_len) numeric tensor
"""
def __init__(self, n_channels, n_classes, vec_len, normalize=False):
def __init__(self, n_channels, n_classes, vec_len, num_group, num_sample, normalize=False):
super().__init__()
if normalize:
target_scale = 0.06
@@ -66,3 +68,111 @@ def after_update(self):
with torch.no_grad():
target_norm = self.embedding_scale * math.sqrt(self.embedding0.size(2))
self.embedding0.mul_(target_norm / self.embedding0.norm(dim=2, keepdim=True))


class VectorQuantGroup(nn.Module):
"""
Input: (N, samples, n_channels, vec_len) numeric tensor
Output: (N, samples, n_channels, vec_len) numeric tensor
"""
def __init__(self, n_channels, n_classes, vec_len, num_group, num_sample, normalize=False):
super().__init__()
if normalize:
target_scale = 0.06
self.embedding_scale = target_scale
self.normalize_scale = target_scale
else:
self.embedding_scale = 1e-3
self.normalize_scale = None

self.n_classes = n_classes
self._num_group = num_group
self._num_sample = num_sample
if self.n_classes % self._num_group != 0:
raise ValueError('n_classes must be divisible by num_group, so that each group holds an integer number of embeddings')
self._num_classes_per_group = int(self.n_classes / self._num_group)

self.embedding0 = nn.Parameter(torch.randn(n_channels, n_classes, vec_len, requires_grad=True) * self.embedding_scale)
self.offset = torch.arange(n_channels).cuda() * n_classes
# self.offset: (n_channels) long tensor
self.after_update()

def forward(self, x0):
if self.normalize_scale:
target_norm = self.normalize_scale * math.sqrt(x0.size(3))
x = target_norm * x0 / x0.norm(dim=3, keepdim=True)
embedding = target_norm * self.embedding0 / self.embedding0.norm(dim=2, keepdim=True)
else:
x = x0
embedding = self.embedding0
#logger.log(f'std[x] = {x.std()}')
x1 = x.reshape(x.size(0) * x.size(1), x.size(2), 1, x.size(3))
# x1: (N*samples, n_channels, 1, vec_len) numeric tensor

# Perform chunking to avoid overflowing GPU RAM.
index_chunks = []
prob_chunks = []
for x1_chunk in x1.split(512, dim=0):
d = (x1_chunk - embedding).norm(dim=3)

# Compute the group-wise distance
d_group = torch.zeros(x1_chunk.shape[0], 1, self._num_group).to(torch.device('cuda'))
for i in range(self._num_group):
d_group[:, :, i] = torch.mean(
d[:, :, i * self._num_classes_per_group: (i + 1) * self._num_classes_per_group], 2)

# Find the nearest group
index_chunk_group = d_group.argmin(dim=2)

# Generate mask for the nearest group
index_chunk_group = index_chunk_group.repeat(1, self._num_classes_per_group)
index_chunk_group = torch.mul(self._num_classes_per_group, index_chunk_group)
idx_mtx = torch.arange(self._num_classes_per_group, dtype=torch.long).unsqueeze(0).cuda()
index_chunk_group += idx_mtx
encoding_mask = torch.zeros(x1_chunk.shape[0], self.n_classes).cuda()
encoding_mask.scatter_(1, index_chunk_group, 1)

# Compute inverse-distance weights for the atoms
encoding_prob = torch.div(1, d.squeeze())

# Apply the mask
masked_encoding_prob = torch.mul(encoding_mask, encoding_prob)
p, idx = masked_encoding_prob.sort(dim=1, descending=True)
prob_chunks.append(p[:, :self._num_sample])
index_chunks.append(idx[:, :self._num_sample])



index = torch.cat(index_chunks, dim=0)
prob_dist = torch.cat(prob_chunks, dim=0)
prob_dist = F.normalize(prob_dist, p=1, dim=1)
# index: (N*samples, num_sample) long tensor
if True: # compute the entropy
hist = index[:, 0].float().cpu().histc(bins=self.n_classes, min=-0.5, max=self.n_classes - 0.5)
prob = hist.masked_select(hist > 0) / len(index)
entropy = - (prob * prob.log()).sum().item()
#logger.log(f'entropy: {entropy:#.4}/{math.log(self.n_classes):#.4}')
else:
entropy = 0
index1 = (index + self.offset)
# index1: (N*samples, num_sample) long tensor
output_list = []
for i in range(self._num_sample):
output_list.append(torch.mul(embedding.view(-1, embedding.size(2)).index_select(dim=0, index=index1[:, i]), prob_dist[:, i].unsqueeze(1).detach()))

output_cat = torch.stack(output_list, dim=2)
output_flat = torch.sum(output_cat, dim=2)
# output_flat: (N*samples*n_channels, vec_len) numeric tensor
output = output_flat.view(x.size())

out0 = (output - x).detach() + x
out1 = (x.detach() - output).float().norm(dim=3).pow(2)
out2 = (x - output.detach()).float().norm(dim=3).pow(2) + (x - x0).float().norm(dim=3).pow(2)
#logger.log(f'std[embedding0] = {self.embedding0.view(-1, embedding.size(2)).index_select(dim=0, index=index1).std()}')
return (out0, out1, out2, entropy)

def after_update(self):
if self.normalize_scale:
with torch.no_grad():
target_norm = self.embedding_scale * math.sqrt(self.embedding0.size(2))
self.embedding0.mul_(target_norm / self.embedding0.norm(dim=2, keepdim=True))
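
For orientation, here is a minimal usage sketch of `VectorQuantGroup`, following the tensor shapes documented in its docstring. The concrete sizes are illustrative assumptions, and a CUDA device is required because the module allocates its buffers on the GPU; in this repository the module is constructed and driven by the WaveRNN training code rather than used standalone.

```
import torch
from layers.vector_quant import VectorQuantGroup

# A 410-atom dictionary split into 41 groups, with 10 atoms weighted per group.
vq = VectorQuantGroup(n_channels=1, n_classes=410, vec_len=128,
                      num_group=41, num_sample=10, normalize=True).cuda()

x = torch.randn(2, 100, 1, 128).cuda()   # (N, samples, n_channels, vec_len)
out, vq_pen, encoder_pen, entropy = vq(x)
print(out.shape)                          # torch.Size([2, 100, 1, 128])
```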