Skip to content
Open

Main #246

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 13 additions & 148 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,161 +1,26 @@
# FastSpeech 2 - PyTorch Implementation

This is a PyTorch implementation of Microsoft's text-to-speech system [**FastSpeech 2: Fast and High-Quality End-to-End Text to Speech**](https://arxiv.org/abs/2006.04558v1).
This project is based on [xcmyz's implementation](https://github.com/xcmyz/FastSpeech) of FastSpeech. Feel free to use/modify the code.
This repository is an extended PyTorch implementation of Microsoft's [**FastSpeech 2: Fast and High-Quality End-to-End Text to Speech**](https://arxiv.org/abs/2006.04558v1), initially based on [xcmyz's implementation](https://github.com/xcmyz/FastSpeech), with the core code structure derived from [ming024's original FastSpeech2 implementation](https://github.com/ming024/FastSpeech2).
We introduce several modifications to enable training and inference using **phonological features** instead of phoneme IDs, supporting cross-lingual and low-resource speech synthesis scenarios. This modification allows more linguistically informed training and better generalization across languages. Using this version, we successfully trained a **German baseline TTS model**, and further performed **transfer learning** with a small amount of English data to train an English model.

There are several versions of FastSpeech 2.
This implementation is more similar to [version 1](https://arxiv.org/abs/2006.04558v1), which uses F0 values as the pitch features.
On the other hand, pitch spectrograms extracted by continuous wavelet transform are used as the pitch features in the [later versions](https://arxiv.org/abs/2006.04558).
Our method is inspired by the concept of using cross-lingual phonological information as described in the paper:
> _"Cross-lingual Transfer of Phonological Features for Low-resource Speech Synthesis"_
> [SSW11 Paper PDF](https://www.pure.ed.ac.uk/ws/portalfiles/portal/215873748/pf_tts_ssw11.pdf)

![](./img/model.png)
We also refer to the [PHOIBLE database](https://phoible.org) for phonological feature definitions and mappings.

# Updates
- 2021/7/8: Release the checkpoint and audio samples of a multi-speaker English TTS model trained on LibriTTS
- 2021/2/26: Support English and Mandarin TTS
- 2021/2/26: Support multi-speaker TTS (AISHELL-3 and LibriTTS)
- 2021/2/26: Support MelGAN and HiFi-GAN vocoder
The overall training and synthesis pipeline still follows the original repository structure [ming024's original FastSpeech2 implementation](https://github.com/ming024/FastSpeech2). However, we have made the following key modifications to support phonological feature-based modeling:

# Audio Samples
Audio samples generated by this implementation can be found [here](https://ming024.github.io/FastSpeech2/).
- **`text/` folder**: contains several modified files to support phonological feature data preparation.
- **`transformer/models.py`**: updated to allow model input as phonological feature vectors instead of phoneme IDs.
- **`synthesis.py`**: modified to support inference using phonological features as input.

# Quickstart

## Dependencies
You can install the Python dependencies with
```
pip3 install -r requirements.txt
```

## Inference

You have to download the [pretrained models](https://drive.google.com/drive/folders/1DOhZGlTLMbbAAFZmZGDdc77kz1PloS7F?usp=sharing) and put them in ``output/ckpt/LJSpeech/``, ``output/ckpt/AISHELL3``, or ``output/ckpt/LibriTTS/``.

For English single-speaker TTS, run
```
python3 synthesize.py --text "YOUR_DESIRED_TEXT" --restore_step 900000 --mode single -p config/LJSpeech/preprocess.yaml -m config/LJSpeech/model.yaml -t config/LJSpeech/train.yaml
```

For Mandarin multi-speaker TTS, try
```
python3 synthesize.py --text "大家好" --speaker_id SPEAKER_ID --restore_step 600000 --mode single -p config/AISHELL3/preprocess.yaml -m config/AISHELL3/model.yaml -t config/AISHELL3/train.yaml
```

For English multi-speaker TTS, run
```
python3 synthesize.py --text "YOUR_DESIRED_TEXT" --speaker_id SPEAKER_ID --restore_step 800000 --mode single -p config/LibriTTS/preprocess.yaml -m config/LibriTTS/model.yaml -t config/LibriTTS/train.yaml
```

The generated utterances will be put in ``output/result/``.

Here is an example of synthesized mel-spectrogram of the sentence "Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition", with the English single-speaker TTS model.
![](./img/synthesized_melspectrogram.png)

## Batch Inference
Batch inference is also supported, try

```
python3 synthesize.py --source preprocessed_data/LJSpeech/val.txt --restore_step 900000 --mode batch -p config/LJSpeech/preprocess.yaml -m config/LJSpeech/model.yaml -t config/LJSpeech/train.yaml
```
to synthesize all utterances in ``preprocessed_data/LJSpeech/val.txt``

## Controllability
The pitch/volume/speaking rate of the synthesized utterances can be controlled by specifying the desired pitch/energy/duration ratios.
For example, one can increase the speaking rate by 20 % and decrease the volume by 20 % by

```
python3 synthesize.py --text "YOUR_DESIRED_TEXT" --restore_step 900000 --mode single -p config/LJSpeech/preprocess.yaml -m config/LJSpeech/model.yaml -t config/LJSpeech/train.yaml --duration_control 0.8 --energy_control 0.8
```

# Training

## Datasets

The supported datasets are

- [LJSpeech](https://keithito.com/LJ-Speech-Dataset/): a single-speaker English dataset consists of 13100 short audio clips of a female speaker reading passages from 7 non-fiction books, approximately 24 hours in total.
- [AISHELL-3](http://www.aishelltech.com/aishell_3): a Mandarin TTS dataset with 218 male and female speakers, roughly 85 hours in total.
- [LibriTTS](https://research.google/tools/datasets/libri-tts/): a multi-speaker English dataset containing 585 hours of speech by 2456 speakers.

We take LJSpeech as an example hereafter.

## Preprocessing

First, run
```
python3 prepare_align.py config/LJSpeech/preprocess.yaml
```
for some preparations.

As described in the paper, [Montreal Forced Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/) (MFA) is used to obtain the alignments between the utterances and the phoneme sequences.
Alignments of the supported datasets are provided [here](https://drive.google.com/drive/folders/1DBRkALpPd6FL9gjHMmMEdHODmkgNIIK4?usp=sharing).
You have to unzip the files in ``preprocessed_data/LJSpeech/TextGrid/``.

After that, run the preprocessing script by
```
python3 preprocess.py config/LJSpeech/preprocess.yaml
```

Alternately, you can align the corpus by yourself.
Download the official MFA package and run
```
./montreal-forced-aligner/bin/mfa_align raw_data/LJSpeech/ lexicon/librispeech-lexicon.txt english preprocessed_data/LJSpeech
```
or
```
./montreal-forced-aligner/bin/mfa_train_and_align raw_data/LJSpeech/ lexicon/librispeech-lexicon.txt preprocessed_data/LJSpeech
```

to align the corpus and then run the preprocessing script.
```
python3 preprocess.py config/LJSpeech/preprocess.yaml
```

## Training

Train your model with
```
python3 train.py -p config/LJSpeech/preprocess.yaml -m config/LJSpeech/model.yaml -t config/LJSpeech/train.yaml
```

The model takes less than 10k steps (less than 1 hour on my GTX1080Ti GPU) of training to generate audio samples with acceptable quality, which is much more efficient than the autoregressive models such as Tacotron2.

# TensorBoard

Use
```
tensorboard --logdir output/log/LJSpeech
```

to serve TensorBoard on your localhost.
The loss curves, synthesized mel-spectrograms, and audios are shown.

![](./img/tensorboard_loss.png)
![](./img/tensorboard_spec.png)
![](./img/tensorboard_audio.png)

# Implementation Issues

- Following [xcmyz's implementation](https://github.com/xcmyz/FastSpeech), I use an additional Tacotron-2-styled Post-Net after the decoder, which is not used in the original FastSpeech 2.
- Gradient clipping is used in the training.
- In my experience, using phoneme-level pitch and energy prediction instead of frame-level prediction results in much better prosody, and normalizing the pitch and energy features also helps. Please refer to ``config/README.md`` for more details.

Please inform me if you find any mistakes in this repo, or any useful tips to train the FastSpeech 2 model.
---

# References
- [FastSpeech 2: Fast and High-Quality End-to-End Text to Speech](https://arxiv.org/abs/2006.04558), Y. Ren, *et al*.
- [xcmyz's FastSpeech implementation](https://github.com/xcmyz/FastSpeech)
- [TensorSpeech's FastSpeech 2 implementation](https://github.com/TensorSpeech/TensorflowTTS)
- [rishikksh20's FastSpeech 2 implementation](https://github.com/rishikksh20/FastSpeech2)

# Citation
```
@INPROCEEDINGS{chien2021investigating,
author={Chien, Chung-Ming and Lin, Jheng-Hao and Huang, Chien-yu and Hsu, Po-chun and Lee, Hung-yi},
booktitle={ICASSP 2021 - 2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
title={Investigating on Incorporating Pretrained and Learnable Speaker Representations for Multi-Speaker Multi-Style Text-to-Speech},
year={2021},
volume={},
number={},
pages={8588-8592},
doi={10.1109/ICASSP39728.2021.9413880}}
```
- [PHOIBLE: Phonological Segment Inventory Database](https://phoible.org)
- [Cross-lingual Transfer of Phonological Features for Low-resource Speech Synthesis (SSW11)](https://www.pure.ed.ac.uk/ws/portalfiles/portal/215873748/pf_tts_ssw11.pdf)
22 changes: 13 additions & 9 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __getitem__(self, idx):
speaker = self.speaker[idx]
speaker_id = self.speaker_map[speaker]
raw_text = self.raw_text[idx]
phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
phone = text_to_sequence(self.text[idx], self.cleaners) # delete np.array
mel_path = os.path.join(
self.preprocessed_path,
"mel",
Expand All @@ -59,11 +59,14 @@ def __getitem__(self, idx):
"{}-duration-{}.npy".format(speaker, basename),
)
duration = np.load(duration_path)

# add for debugging
assert len(pitch) == len(phone), \
f"Pitch length {len(pitch)} != Text length {len(phone)} for {basename}"

sample = {
"id": basename,
"speaker": speaker_id,
"text": phone,
"text": phone, #is feature vectors
"raw_text": raw_text,
"mel": mel,
"pitch": pitch,
Expand Down Expand Up @@ -92,7 +95,7 @@ def process_meta(self, filename):
def reprocess(self, data, idxs):
ids = [data[idx]["id"] for idx in idxs]
speakers = [data[idx]["speaker"] for idx in idxs]
texts = [data[idx]["text"] for idx in idxs]
texts = [data[idx]["text"].float() for idx in idxs]
raw_texts = [data[idx]["raw_text"] for idx in idxs]
mels = [data[idx]["mel"] for idx in idxs]
pitches = [data[idx]["pitch"] for idx in idxs]
Expand All @@ -103,7 +106,8 @@ def reprocess(self, data, idxs):
mel_lens = np.array([mel.shape[0] for mel in mels])

speakers = np.array(speakers)
texts = pad_1D(texts)
#texts = pad_1D(texts)
texts = pad_2D(texts)
mels = pad_2D(mels)
pitches = pad_1D(pitches)
energies = pad_1D(energies)
Expand Down Expand Up @@ -168,7 +172,7 @@ def __getitem__(self, idx):
speaker = self.speaker[idx]
speaker_id = self.speaker_map[speaker]
raw_text = self.raw_text[idx]
phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
phone = text_to_sequence(self.text[idx], self.cleaners) # delete np.array

return (basename, speaker_id, phone, raw_text)

Expand All @@ -189,12 +193,12 @@ def process_meta(self, filename):
def collate_fn(self, data):
ids = [d[0] for d in data]
speakers = np.array([d[1] for d in data])
texts = [d[2] for d in data]
texts = [d[2].astype(np.float32) for d in data]
raw_texts = [d[3] for d in data]
text_lens = np.array([text.shape[0] for text in texts])

texts = pad_1D(texts)

#texts = pad_1D(texts)
texts = pad_2D(texts)
return ids, raw_texts, speakers, texts, text_lens, max(text_lens)


Expand Down
39 changes: 38 additions & 1 deletion synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from utils.tools import to_device, synth_samples
from dataset import TextDataset
from text import text_to_sequence
from text.german_numbers import german_normalize_numbers

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand Down Expand Up @@ -83,6 +84,41 @@ def preprocess_mandarin(text, preprocess_config):

return np.array(sequence)

def preprocess_de(text, preprocess_config):
text = text.rstrip(punctuation).replace("ß","ss")
lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])
def split_alphanum(match):
letters = match.group(1)
numbers = match.group(2)
# split if character+number
return letters + " " + " ".join(list(numbers))

text = re.sub(r'([a-zA-Z]+)(\d+)', split_alphanum, text)

text = german_normalize_numbers(text)

phones = []
words = re.split(r"([,;.\-\?\!\s+])", text)

for w in words:
if w.lower() in lexicon:
phones += lexicon[w.lower()]
elif re.match(r"[,;.\-\?\!]", w):
phones.append("sil")
phones = "{" + "}{".join(phones) + "}"
phones = phones.replace("}{", " ")

# text_to_sequence get features
features = text_to_sequence(
phones,
preprocess_config["preprocessing"]["text"]["text_cleaners"],
)

print("Raw Text Sequence: {}".format(text))
print("Phoneme Sequence: {}".format(phones))
#print("features:{}".format(features))
return np.array(features)


def synthesize(model, step, configs, vocoder, batchs, control_values):
preprocess_config, model_config, train_config = configs
Expand All @@ -109,7 +145,6 @@ def synthesize(model, step, configs, vocoder, batchs, control_values):


if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("--restore_step", type=int, required=True)
parser.add_argument(
Expand Down Expand Up @@ -206,6 +241,8 @@ def synthesize(model, step, configs, vocoder, batchs, control_values):
texts = np.array([preprocess_english(args.text, preprocess_config)])
elif preprocess_config["preprocessing"]["text"]["language"] == "zh":
texts = np.array([preprocess_mandarin(args.text, preprocess_config)])
elif preprocess_config["preprocessing"]["text"]["language"] == "de":
texts = np.array([preprocess_de(args.text, preprocess_config)])
text_lens = np.array([len(texts[0])])
batchs = [(ids, raw_texts, speakers, texts, text_lens, max(text_lens))]

Expand Down
Loading