1 | | -import soundfile as sf |
2 | 1 | import torch |
3 | | -from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer |
| 2 | +from torch import Tensor |
4 | 3 | from torch.utils.mobile_optimizer import optimize_for_mobile |
| 4 | +import torchaudio |
| 5 | +from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model |
| 6 | +from transformers import Wav2Vec2ForCTC |
5 | 7 |
6 | | -tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h") |
| 8 | +# The Wav2Vec2 model emits a sequence of probability distributions (logits) over the characters.
| 9 | +# The following class wraps it and adds greedy (best-path) decoding to produce the transcript.
| 10 | +class SpeechRecognizer(torch.nn.Module):
| 11 | +    def __init__(self, model):
| 12 | +        super().__init__()
| 13 | +        self.model = model
| 14 | +        self.labels = [
| 15 | +            "<s>", "<pad>", "</s>", "<unk>", "|", "E", "T", "A", "O", "N", "I", "H", "S",
| 16 | +            "R", "D", "L", "U", "M", "W", "C", "F", "G", "Y", "P", "B", "V", "K", "'", "X",
| 17 | +            "J", "Q", "Z"]
| 18 | +
| 19 | +    def forward(self, waveforms: Tensor) -> str:
| 20 | +        """Given single-channel speech data, return the transcription.
| 21 | +
| 22 | +        Args:
| 23 | +            waveforms (Tensor): Speech tensor. Shape `[1, num_frames]`.
| 24 | +
| 25 | +        Returns:
| 26 | +            str: The resulting transcript
| 27 | +        """
| 28 | +        logits, _ = self.model(waveforms)  # [batch, num_seq, num_label]
| 29 | +        best_path = torch.argmax(logits[0], dim=-1)  # [num_seq,]
| 30 | +        prev = ''
| 31 | +        hypothesis = ''
| 32 | +        for i in best_path:
| 33 | +            char = self.labels[i]
| 34 | +            if char == prev:  # collapse consecutive repeats of the same label
| 35 | +                continue
| 36 | +            if char == '<s>':  # label 0 doubles as the CTC blank; drop it and reset prev
| 37 | +                prev = ''
| 38 | +                continue
| 39 | +            hypothesis += char
| 40 | +            prev = char
| 41 | +        return hypothesis.replace('|', ' ')  # '|' marks word boundaries
| 42 | + |
| 43 | + |
| 44 | +# Load Wav2Vec2 pretrained model from Hugging Face Hub |
7 | 45 | model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") |
8 | | -model.eval() |
9 | | - |
10 | | -audio_input, _ = sf.read("scent_of_a_woman_future.wav") |
11 | | -input_values = tokenizer(audio_input, return_tensors="pt").input_values |
12 | | -logits = model(input_values).logits |
13 | | -predicted_ids = torch.argmax(logits, dim=-1) |
14 | | -transcription = tokenizer.batch_decode(predicted_ids)[0] |
15 | | - |
16 | | -model_dynamic_quantized = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8) |
17 | | -traced_quantized_model = torch.jit.trace(model_dynamic_quantized, input_values, strict=False) |
18 | | -optimized_traced_quantized_model = optimize_for_mobile(traced_quantized_model) |
19 | | -optimized_traced_quantized_model.save("wav2vec2.pt") |
| 46 | +# Convert the model to torchaudio format, which supports TorchScript. |
| 47 | +model = import_huggingface_model(model) |
| 48 | +# Remove weight normalization, which is not supported by quantization.
| 49 | +model.encoder.transformer.pos_conv_embed.__prepare_scriptable__() |
| 50 | +model = model.eval() |
| 51 | +# Attach decoder |
| 52 | +model = SpeechRecognizer(model) |
| 53 | + |
| 54 | +# Apply quantization, script, and optimize for mobile
| 55 | +quantized_model = torch.quantization.quantize_dynamic( |
| 56 | +    model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
| 57 | +scripted_model = torch.jit.script(quantized_model) |
| 58 | +optimized_model = optimize_for_mobile(scripted_model) |
| 59 | + |
| 60 | +# Sanity check |
| 61 | +waveform, _ = torchaudio.load('scent_of_a_woman_future.wav')
| 62 | +print(waveform.size()) |
| 63 | +print('Result:', optimized_model(waveform)) |
| 64 | + |
| 65 | +optimized_model.save("SpeechRecognition/wav2vec2.pt") |
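
For a quick end-to-end check outside this script, the exported TorchScript file can be reloaded and run on an arbitrary clip. The lines below are a minimal sketch, not part of the commit: they assume the export above succeeded, reuse the same sample file name, and resample only if the clip is not already 16 kHz (the rate wav2vec2-base-960h was trained on).

# Sketch: reload the exported model and transcribe a clip (assumes the export above ran).
import torch
import torchaudio

loaded = torch.jit.load("SpeechRecognition/wav2vec2.pt")
waveform, sample_rate = torchaudio.load('scent_of_a_woman_future.wav')
if sample_rate != 16000:
    # The model expects 16 kHz audio, so resample anything else.
    waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
print('Reloaded result:', loaded(waveform[:1]))  # keep one channel: shape [1, num_frames]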