
Commit e0c436d

Merge pull request #53 from jeffxtang/master

updated script and iOS code to use torchaudio 0.9 based wav2vec2 model with no input limit

2 parents: 0f28ced + f5578a4

File tree: 6 files changed, +102 −80 lines


SpeechRecognition/Podfile

Lines changed: 1 addition & 1 deletion

@@ -6,5 +6,5 @@ target 'SpeechRecognition' do
   use_frameworks!

   # Pods for SpeechRecognition
-  pod 'LibTorch', '~>1.8.0'
+  pod 'LibTorch', '~>1.9.0'
 end

SpeechRecognition/README.md

Lines changed: 28 additions & 13 deletions

@@ -2,48 +2,63 @@

 ## Introduction

-Facebook AI's [wav2vec 2.0](https://github.com/pytorch/fairseq/tree/master/examples/wav2vec) is one of the leading models in speech recognition. It is also available in the [Huggingface Transformers](https://github.com/huggingface/transformers) library, which is also used in another PyTorch iOS demo app for [Question Answering](https://github.com/pytorch/ios-demo-app/tree/master/QuestionAnswering).
+Facebook AI's [wav2vec 2.0](https://github.com/pytorch/fairseq/tree/master/examples/wav2vec) is one of the leading models in speech recognition. It is also available in the [Hugging Face Transformers](https://github.com/huggingface/transformers) library, which is used in another PyTorch iOS demo app for [Question Answering](https://github.com/pytorch/ios-demo-app/tree/master/QuestionAnswering).

-In this demo app, we'll show how to quantize, trace, and optimize the wav2vec2 model for mobile and how to use the converted model on an iOS demo app to perform speech recognition.
+In this demo app, we'll show how to quantize, trace, and optimize the wav2vec2 model, powered by the newly released torchaudio 0.9.0, and how to use the converted model in an iOS demo app to perform speech recognition.

 ## Prerequisites

-* PyTorch 1.8.0/1.8.1 (Optional)
+* PyTorch 1.9.0 and torchaudio 0.9.0 (Optional)
 * Python 3.8 or above (Optional)
-* iOS PyTorch pod library 1.8.0
-* Xcode 12 or later
+* iOS PyTorch CocoaPods library LibTorch 1.9.0
+* Xcode 12.4 or later

 ## Quick Start

-### 1. Prepare the Model
+### 1. Get the Repo
+
+Simply run the commands below:

-First, run the following commands on a Terminal:
 ```
 git clone https://github.com/pytorch/ios-demo-app
 cd ios-demo-app/SpeechRecognition
 ```

-If you don't have PyTorch 1.8.1 installed or want to have a quick try of the demo app, you can download the quantized scripted wav2vec2 model file [here](https://drive.google.com/file/d/1RcCy3K3gDVN2Nun5IIdDbpIDbrKD-XVw/view?usp=sharing), then drag and drop to the project, and continue to Step 2.
+If you don't have PyTorch 1.9.0 and torchaudio 0.9.0 installed or want to have a quick try of the demo app, you can download the quantized scripted wav2vec2 model file [here](https://drive.google.com/file/d/1RcCy3K3gDVN2Nun5IIdDbpIDbrKD-XVw/view?usp=sharing), then drag and drop it into the project, and continue to Step 3.
+
+Be aware that the downloadable model file was created with PyTorch 1.9.0 and torchaudio 0.9.0, matching the iOS LibTorch library 1.9.0 specified in the `Podfile`. If you use a different version of PyTorch to create your model by following the instructions below, make sure you specify the same iOS LibTorch version in the `Podfile` to avoid possible errors caused by the version mismatch. Furthermore, if you want to use the latest prototype features in the PyTorch master branch to create the model, follow the steps at [Building PyTorch iOS Libraries from Source](https://pytorch.org/mobile/ios/#build-pytorch-ios-libraries-from-source) on how to use the model in iOS.
+
+
+### 2. Prepare the Model
+
+To install PyTorch 1.9.0 and torchaudio 0.9.0, you can do something like this:

-Be aware that the downloadable model file was created with PyTorch 1.8, matching the iOS LibTorch library 1.8.0 specified in the `Podfile`. If you use a different version of PyTorch to create your model by following the instructions below, make sure you specify the same iOS LibTorch version in the `Podfile` to avoid possible errors caused by the version mismatch. Furthermore, if you want to use the latest prototype features in the PyTorch master branch to create the model, follow the steps at [Building PyTorch iOS Libraries from Source](https://pytorch.org/mobile/ios/#build-pytorch-ios-libraries-from-source) on how to use the model in iOS.
+```
+conda create -n wav2vec2 python=3.8.5
+conda activate wav2vec2
+pip install torch torchaudio
+```
+
+Now with PyTorch 1.9.0 and torchaudio 0.9.0 installed, run the following commands on a Terminal:

-With PyTorch 1.8.1 installed, first install the `soundfile` package by running `pip install pysoundfile`, then install the Huggingface `transformers` by running `pip install transformers` (the version that has been tested is 4.4.2). Finally run `python create_wav2vec2.py`, which creates `wav2vec2.pt` in the project folder. [Dynamic quantization](https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html) is used to quantize the model to reduce its size.
+```
+python create_wav2vec2.py
+```

-Note that the sample `scent_of_a_woman_future.wav` file used to trace the model is about 6 seconds long, so 6 seconds is the limit of the recorded audio for speech recognition in the demo app. If your speech is less than 6 seconds, padding is applied in the iOS code to make the model work correctly.
+This will create the model file `wav2vec2.pt` and save it to the `SpeechRecognition` folder.

 ### 2. Use LibTorch

 Run the commands below:

 ```
-cd SpeechRecognition
 pod install
 open SpeechRecognition.xcworkspace/
 ```

 ### 3. Build and run with Xcode

-After the app runs, tap the Start button and start saying something; after 6 seconds, the model will infer to recognize your speech. Only basic decoding of the recognition result, in the form of an array of floating numbers of logits, to a list of tokens is provided in this demo app, but it is easy to see, without further post-processing, whether the model can recognize your utterances. Some example results are as follows:
+After the app runs, tap the Start button and start saying something; after 12 seconds (you can change `private let AUDIO_LEN_IN_SECOND = 12` in `ViewController.swift` for the recording length), the model will infer to recognize your speech. Some example results are as follows:

 ![](screenshot1.png)
 ![](screenshot2.png)
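
A quick way to sanity-check the exported model described in the README above, outside of Xcode, is to load it back in Python. The snippet below is a minimal sketch and not part of this commit; it assumes the `SpeechRecognition/wav2vec2.pt` path produced by `create_wav2vec2.py` and the sample clip shipped with the repo.

```
import torch
import torchaudio

# Load the quantized, scripted model that create_wav2vec2.py saved.
model = torch.jit.load("SpeechRecognition/wav2vec2.pt")

# wav2vec2-base-960h expects 16 kHz mono audio; the bundled sample clip is used here.
waveform, sample_rate = torchaudio.load("scent_of_a_woman_future.wav")
assert sample_rate == 16000

# forward() returns the decoded transcript as a plain string.
print(model(waveform))
```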

SpeechRecognition/SpeechRecognition/InferenceModule.h

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@ NS_ASSUME_NONNULL_BEGIN
 - (nullable instancetype)initWithFileAtPath:(NSString*)filePath
     NS_SWIFT_NAME(init(fileAtPath:))NS_DESIGNATED_INITIALIZER;
 - (instancetype)init NS_UNAVAILABLE;
-- (nullable NSString*)recognize:(void*)wavBuffer NS_SWIFT_NAME(recognize(wavBuffer:));
+- (nullable NSString*)recognize:(void*)wavBuffer bufLength:(int)bufLength NS_SWIFT_NAME(recognize(wavBuffer:bufLength));
 @end

 NS_ASSUME_NONNULL_END

SpeechRecognition/SpeechRecognition/InferenceModule.mm

Lines changed: 5 additions & 47 deletions

@@ -13,9 +13,6 @@
 #import <AudioToolbox/AudioToolbox.h>


-const int MODEL_INPUT_LENGTH = 65024;
-const NSString *TOKENS[] = {@"<s>", @"<pad>", @"</s>", @"<unk>", @"|", @"E", @"T", @"A", @"O", @"N", @"I", @"H", @"S", @"R", @"D", @"L", @"U", @"M", @"W", @"C", @"F", @"G", @"Y", @"P", @"B", @"V", @"K", @"'", @"X", @"J", @"Q", @"Z"};
-
 @implementation InferenceModule {

 @protected torch::jit::script::Module _impl;
@@ -40,65 +37,26 @@ - (nullable instancetype)initWithFileAtPath:(NSString*)filePath {
     return self;
 }

-- (int)argMax:(NSArray*)array {
-    int maxIdx = 0;
-    float maxVal = -FLT_MAX;
-    for (int j = 0; j < [array count]; j++) {
-        if ([array[j] floatValue] > maxVal) {
-            maxVal = [array[j] floatValue];
-            maxIdx = j;
-        }
-    }
-    return maxIdx;
-}
-

-- (NSString*)recognize:(void*)wavBuffer {
+- (NSString*)recognize:(void*)wavBuffer bufLength:(int)bufLength {
     try {
-        at::Tensor tensorInputs = torch::from_blob((void*)wavBuffer, {1, MODEL_INPUT_LENGTH}, at::kFloat);
+        at::Tensor tensorInputs = torch::from_blob((void*)wavBuffer, {1, bufLength}, at::kFloat);

         float* floatInput = tensorInputs.data_ptr<float>();
         if (!floatInput) {
             return nil;
         }
         NSMutableArray* inputs = [[NSMutableArray alloc] init];
-        for (int i = 0; i < MODEL_INPUT_LENGTH; i++) {
+        for (int i = 0; i < bufLength; i++) {
             [inputs addObject:@(floatInput[i])];
         }

         torch::autograd::AutoGradMode guard(false);
         at::AutoNonVariableTypeMode non_var_type_mode(true);

-        auto outputDict = _impl.forward({ tensorInputs }).toGenericDict();
-
-        auto logitsTensor = outputDict.at("logits").toTensor();
-        float* logitsBuffer = logitsTensor.data_ptr<float>();
-        if (!logitsBuffer) {
-            return nil;
-        }
-
-        NSUInteger TOKEN_LENGTH = (NSUInteger) (sizeof(TOKENS) / sizeof(NSString*));
-        int64_t output_len = logitsTensor.numel();
-        NSMutableArray* logits = [[NSMutableArray alloc] init];
-        NSString *result = @"";
-        for (int i = 0; i < output_len; i++) {
-            // for every 32 output values, get the argmax and its token
-            if (i > 0 && i % TOKEN_LENGTH == 0) {
-                int tid = [self argMax:logits];
-                if (tid > 4)
-                    result = [NSString stringWithFormat:@"%@%@", result, TOKENS[tid]];
-                else if (tid == 4)
-                    result = [NSString stringWithFormat:@"%@ ", result];
+        auto result = _impl.forward({ tensorInputs }).toStringRef();

-                [logits removeAllObjects];
-                [logits addObject:@(logitsBuffer[i])];
-            }
-            else {
-                [logits addObject:@(logitsBuffer[i])];
-            }
-        }
-
-        return result;
+        return [NSString stringWithCString:result.c_str() encoding:[NSString defaultCStringEncoding]];
     }
     catch (const std::exception& exception) {
         NSLog(@"%s", exception.what());

SpeechRecognition/SpeechRecognition/ViewController.swift

Lines changed: 6 additions & 3 deletions

@@ -23,6 +23,9 @@ class ViewController: UIViewController, AVAudioRecorderDelegate {

     private var audioRecorder: AVAudioRecorder!
     private var _recorderFilePath: String!
+
+    private let AUDIO_LEN_IN_SECOND = 12
+    private let SAMPLE_RATE = 16000

     private lazy var module: InferenceModule = {
         if let filePath = Bundle.main.path(forResource:
@@ -55,7 +58,7 @@ class ViewController: UIViewController, AVAudioRecorderDelegate {

         let settings = [
             AVFormatIDKey: Int(kAudioFormatLinearPCM),
-            AVSampleRateKey: 16000,
+            AVSampleRateKey: SAMPLE_RATE,
             AVNumberOfChannelsKey: 1,
             AVLinearPCMBitDepthKey: 16,
             AVLinearPCMIsBigEndianKey: false,
@@ -67,7 +70,7 @@ class ViewController: UIViewController, AVAudioRecorderDelegate {
             _recorderFilePath = NSHomeDirectory().stringByAppendingPathComponent(path: "tmp").stringByAppendingPathComponent(path: "recorded_file.wav")
             audioRecorder = try AVAudioRecorder(url: NSURL.fileURL(withPath: _recorderFilePath), settings: settings)
             audioRecorder.delegate = self
-            audioRecorder.record(forDuration: 6)
+            audioRecorder.record(forDuration: TimeInterval(AUDIO_LEN_IN_SECOND))
         } catch let error {
             tvResult.text = "error:" + error.localizedDescription
         }
@@ -88,7 +91,7 @@ class ViewController: UIViewController, AVAudioRecorderDelegate {

         DispatchQueue.global().async {
             floatArray.withUnsafeMutableBytes {
-                let result = self.module.recognize(wavBuffer: $0.baseAddress!)
+                let result = self.module.recognize($0.baseAddress!, bufLength: Int32(self.AUDIO_LEN_IN_SECOND * self.SAMPLE_RATE))
                 DispatchQueue.main.async {
                     self.tvResult.text = result
                     self.btnStart.setTitle("Start", for: .normal)
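
The Swift code above always hands `recognize` a buffer of exactly `AUDIO_LEN_IN_SECOND * SAMPLE_RATE` floats. Purely as an illustration and not part of the commit, here is one way such a fixed-length buffer could be produced in Python, assuming a `recorded_file.wav` like the one the recorder writes:

```
import torch
import torchaudio

AUDIO_LEN_IN_SECOND = 12
SAMPLE_RATE = 16000
NUM_SAMPLES = AUDIO_LEN_IN_SECOND * SAMPLE_RATE  # 192000 samples

waveform, _ = torchaudio.load("recorded_file.wav")  # shape [1, n]
if waveform.size(1) < NUM_SAMPLES:
    # zero-pad short recordings on the right
    waveform = torch.nn.functional.pad(waveform, (0, NUM_SAMPLES - waveform.size(1)))
else:
    # truncate long ones
    waveform = waveform[:, :NUM_SAMPLES]
print(waveform.shape)  # torch.Size([1, 192000])
```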
SpeechRecognition/create_wav2vec2.py

Lines changed: 61 additions & 15 deletions
@@ -1,19 +1,65 @@
-import soundfile as sf
 import torch
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
+from torch import Tensor
 from torch.utils.mobile_optimizer import optimize_for_mobile
+import torchaudio
+from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
+from transformers import Wav2Vec2ForCTC

-tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
+# Wav2vec2 model emits sequences of probability (logits) distributions over the characters
+# The following class adds steps to decode the transcript (best path)
+class SpeechRecognizer(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.labels = [
+            "<s>", "<pad>", "</s>", "<unk>", "|", "E", "T", "A", "O", "N", "I", "H", "S",
+            "R", "D", "L", "U", "M", "W", "C", "F", "G", "Y", "P", "B", "V", "K", "'", "X",
+            "J", "Q", "Z"]
+
+    def forward(self, waveforms: Tensor) -> str:
+        """Given single-channel speech data, return the transcription.
+
+        Args:
+            waveforms (Tensor): Speech tensor. Shape `[1, num_frames]`.
+
+        Returns:
+            str: The resulting transcript
+        """
+        logits, _ = self.model(waveforms)  # [batch, num_seq, num_label]
+        best_path = torch.argmax(logits[0], dim=-1)  # [num_seq,]
+        prev = ''
+        hypothesis = ''
+        for i in best_path:
+            char = self.labels[i]
+            if char == prev:
+                continue
+            if char == '<s>':
+                prev = ''
+                continue
+            hypothesis += char
+            prev = char
+        return hypothesis.replace('|', ' ')
+
+
+# Load Wav2Vec2 pretrained model from Hugging Face Hub
 model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
-model.eval()
-
-audio_input, _ = sf.read("scent_of_a_woman_future.wav")
-input_values = tokenizer(audio_input, return_tensors="pt").input_values
-logits = model(input_values).logits
-predicted_ids = torch.argmax(logits, dim=-1)
-transcription = tokenizer.batch_decode(predicted_ids)[0]
-
-model_dynamic_quantized = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
-traced_quantized_model = torch.jit.trace(model_dynamic_quantized, input_values, strict=False)
-optimized_traced_quantized_model = optimize_for_mobile(traced_quantized_model)
-optimized_traced_quantized_model.save("wav2vec2.pt")
+# Convert the model to torchaudio format, which supports TorchScript.
+model = import_huggingface_model(model)
+# Remove weight normalization which is not supported by quantization.
+model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
+model = model.eval()
+# Attach decoder
+model = SpeechRecognizer(model)
+
+# Apply quantization / script / optimize for mobile
+quantized_model = torch.quantization.quantize_dynamic(
+    model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
+scripted_model = torch.jit.script(quantized_model)
+optimized_model = optimize_for_mobile(scripted_model)
+
+# Sanity check
+waveform, _ = torchaudio.load('scent_of_a_woman_future.wav')
+print(waveform.size())
+print('Result:', optimized_model(waveform))
+
+optimized_model.save("SpeechRecognition/wav2vec2.pt")
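
Since the `SpeechRecognizer` class added above does its own best-path (greedy CTC) decoding, a tiny standalone illustration may help. This toy example is not part of the commit; it simply replays the same logic on a hand-picked index sequence:

```
labels = [
    "<s>", "<pad>", "</s>", "<unk>", "|", "E", "T", "A", "O", "N", "I", "H", "S",
    "R", "D", "L", "U", "M", "W", "C", "F", "G", "Y", "P", "B", "V", "K", "'", "X",
    "J", "Q", "Z"]

def greedy_decode(best_path):
    # Collapse consecutive repeats, treat '<s>' as the blank, map '|' to a space.
    prev, hypothesis = '', ''
    for i in best_path:
        char = labels[i]
        if char == prev:
            continue
        if char == '<s>':
            prev = ''
            continue
        hypothesis += char
        prev = char
    return hypothesis.replace('|', ' ')

print(greedy_decode([11, 11, 0, 10, 4]))  # 'H', 'H', '<s>', 'I', '|' -> "HI "
```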
