-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtranscriptor.py
More file actions
359 lines (292 loc) · 13.7 KB
/
transcriptor.py
File metadata and controls
359 lines (292 loc) · 13.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
import os
from itertools import groupby

import numpy as np
import scipy
import scipy.io.wavfile  # explicit import: a bare "import scipy" does not guarantee scipy.io.wavfile is loaded
import torch
import librosa
from pydub import AudioSegment
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from faster_whisper import WhisperModel

from config import Config
from speaker_recognize import SpeakerVerifier
from speech_enhance import SpeechEnhance
class Transcriptor:
    """Streaming speech-to-text pipeline.

    Chains optional speech enhancement, optional silero-VAD silence removal,
    faster-whisper transcription, speaker verification, and TF-IDF-based
    hallucination filtering. All tunables come from the project ``Config``.
    """

    def __init__(self):
        # Whisper and silero-vad both expect 16 kHz mono float32 audio.
        self.samplerate = 16000
        # Monotonic counter used to name dumped audio files (see dump()).
        self.epoch = 0
        self.load_models(Config.models)
        self.preheat(Config.preheat_audio)

    def load_models(self, models):
        """Instantiate every model the pipeline needs, driven by Config.

        Disabled features leave their attribute set to None so later code can
        test for presence.
        """
        asr_config = models.get("asr")
        vad_config = models.get("vad")
        self.asr_model = WhisperModel(
            model_size_or_path = asr_config["path"],
            device = asr_config["device"],
            local_files_only = False,
            compute_type = asr_config["compute_type"]
        )
        self.speaker_verifier = SpeakerVerifier()
        if Config.vad.get("enable"):
            # NOTE(review): trust_repo=None makes torch.hub emit a prompt/warning;
            # trust_repo=True is the usual choice for a local source — confirm intended.
            self.vad_model, _ = torch.hub.load(
                repo_or_dir = vad_config["path"],
                model = 'silero_vad',
                trust_repo = None,
                source = 'local',
            )
        else:
            self.vad_model = None
        se_config = Config.speech_enhance
        if se_config.get("enable"):
            self.speech_enhance = SpeechEnhance(
                model_name=se_config.get("model_name"),
                target_lufs=se_config.get("target_lufs"),
                true_peak_limit=se_config.get("true_peak_limit"),
                mute_if_too_quiet=se_config.get("mute_if_too_quiet"),
                threshold_dbfs=se_config.get("threshold_dbfs"),
            )
        else:
            self.speech_enhance = None
        if Config.filter_match.get("enable"):
            # Used by filter() for cosine-similarity matching of hallucinations.
            self.vectorizer = TfidfVectorizer()
        else:
            self.vectorizer = None
        self.whisper_config = Config.whisper_config
        if self.whisper_config.get("tradition_to_simple"):
            # Lazy import: opencc is only needed when t2s conversion is enabled.
            import opencc
            self.cc_model = opencc.OpenCC('t2s.json')
        else:
            self.cc_model = None

    def preheat(self, preheat_audio):
        """Run one transcription on a warm-up clip so the first real request
        does not pay model-initialization / CUDA-warm-up latency."""
        preheat_audio_, _ = librosa.load(preheat_audio, sr=self.samplerate, dtype=np.float32)
        self.asr_model.transcribe(
            preheat_audio_,
            beam_size = self.whisper_config.get("beam_size"),
            best_of = self.whisper_config.get("best_of"),
            patience = self.whisper_config.get("patience"),
            suppress_blank = self.whisper_config.get("suppress_blank"),
            repetition_penalty = self.whisper_config.get("repetition_penalty"),
            log_prob_threshold = self.whisper_config.get("log_prob_threshold"),
            no_speech_threshold = self.whisper_config.get("no_speech_threshold"),
            condition_on_previous_text = self.whisper_config.get("condition_on_previous_text"),
            initial_prompt = self.whisper_config.get("initial_prompt"),
            hotwords = self.whisper_config.get("hotwords_text"),
            prefix = self.whisper_config.get("previous_text_prefix"),
            temperature = self.whisper_config.get("temperature"),
        )

    def dump(self, final, audio_buffer):
        """Optionally save *audio_buffer* as a numbered WAV for debugging.

        Config.dump["audio_save"] selects the mode: "all" saves every call,
        "final" saves only finalized utterances; anything else disables saving.
        """
        dump_config = Config.dump
        save_mode = dump_config.get("audio_save")
        if save_mode not in ["all", "final"]:
            return
        if save_mode == "final" and not final:
            return
        audio_dir = dump_config.get("audio_dir")
        if not os.path.exists(audio_dir):
            os.makedirs(audio_dir)
        self.epoch += 1
        audio_path = os.path.join(audio_dir, f"{self.epoch:06d}.wav")
        scipy.io.wavfile.write(audio_path, rate=self.samplerate, data=audio_buffer)

    def vad_rm_silence(self, audio_chunk):
        """Remove silent regions from *audio_chunk* using silero-VAD.

        Returns None when too little speech is present, the original chunk
        when too little silence is present, or a concatenation of the speech
        segments (with a configurable silence margin) otherwise.
        """
        vad_config = Config.vad
        vad_flags = []
        # NOTE(review): the hard-coded 512 here must equal
        # vad_config["sampling_per_chunk"] used for slicing below, otherwise the
        # chunk count and the slice width disagree — confirm they match in Config.
        chunk_num = len(audio_chunk) // 512
        sampling_rate = vad_config.get("sampling_rate")
        sampling_per_chunk = vad_config.get("sampling_per_chunk")
        for i in range(chunk_num):
            chunk = audio_chunk[i*sampling_per_chunk:(i+1)*sampling_per_chunk]
            chunk_torch = torch.tensor(chunk).unsqueeze(0)
            silero_score = self.vad_model(chunk_torch, sampling_rate).item()
            # If the voice-detection probability exceeds the threshold, mark speech.
            if silero_score > vad_config.get("vad_threshold"):
                vad_flags.append(1)
            else:
                vad_flags.append(0)
        # print("vad_flags: ", vad_flags)
        # If total speech (in chunks) is below the minimum, treat as no speech.
        voice_duration = vad_flags.count(1)
        if voice_duration < vad_config.get("min_voice_duration"):
            return None
        # If total silence is below the minimum, return the audio unchanged.
        silence_duration = vad_flags.count(0)
        if silence_duration < vad_config.get("min_silence_duration"):
            return audio_chunk
        # Drop the silent parts, keeping silence_reserve chunks of padding
        # before and after each speech run.
        silence_reserve = vad_config.get("silence_reserve")
        # Collect (start, end) chunk indices of each contiguous speech run.
        indices = []
        for flag, group in groupby(enumerate(vad_flags), lambda x: x[1]):
            if flag == 1:  # speech run
                group = list(group)
                start = group[0][0]
                end = group[-1][0]
                indices.append((start, end))
        # print("indices: ", indices)
        split_chunk = []
        for start, end in indices:
            # Convert chunk indices (with the silence margin) to sample offsets.
            # print("start: ", (start - silence_reserve), "end: ", (end + 1 + silence_reserve))
            start_sample = max(0, (start - silence_reserve) * sampling_per_chunk)
            end_sample = min(len(audio_chunk), (end + 1 + silence_reserve) * sampling_per_chunk)
            split_chunk.extend(audio_chunk[start_sample:end_sample])
        if len(split_chunk) > 0:
            return np.array(split_chunk, dtype=np.float32)
        else:
            return None

    def filter(self, text: str) -> str:
        """Suppress known hallucinated phrases.

        Returns "" when *text* contains any exact-match phrase or is
        sufficiently cosine-similar (TF-IDF) to any fuzzy-match phrase;
        otherwise returns *text* unchanged.
        """
        filter_match = Config.filter_match
        for match_text in filter_match.get("find_match"):
            if text.find(match_text) != -1:
                return ""
        for match_text in filter_match.get("cos_match"):
            tfidf_matrix = self.vectorizer.fit_transform([match_text, text])
            cos_sim = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])
            if cos_sim > filter_match.get("cos_sim"):
                return ""
        return text

    def transcript(self, audio_buffer, last_speaker, last_sentence):
        """Transcribe *audio_buffer* and decide whether an utterance finished.

        Returns (final, speaker, sentence, transcript, new_buffer):
        final      - True when a complete sentence was finalized
        speaker    - verified speaker label for the finalized audio
        sentence   - the finalized sentence text (when final)
        transcript - the still-in-progress partial text
        new_buffer - audio that remains to be transcribed next round
        """
        whisper_config = Config.whisper_config
        # NOTE(review): the "+=" below assumes initial_prompt / hotwords_text are
        # configured as strings (not None) whenever the previous_text_* flags are
        # on — confirm against Config.
        initial_prompt = whisper_config.get("initial_prompt")
        if whisper_config.get("previous_text_prompt"):
            initial_prompt += last_sentence
        hotwords = whisper_config.get("hotwords_text")
        if whisper_config.get("previous_text_hotwords"):
            hotwords += last_sentence
        prefix_text = None
        if whisper_config.get("previous_text_prefix"):
            prefix_text = last_sentence
        interruption_duration = whisper_config.get("interruption_duration")
        segments, info = self.asr_model.transcribe(
            audio_buffer,
            beam_size = whisper_config.get("beam_size"),
            best_of = whisper_config.get("best_of"),
            patience = whisper_config.get("patience"),
            suppress_blank = whisper_config.get("suppress_blank"),
            repetition_penalty = whisper_config.get("repetition_penalty"),
            log_prob_threshold = whisper_config.get("log_prob_threshold"),
            no_speech_threshold = whisper_config.get("no_speech_threshold"),
            condition_on_previous_text = whisper_config.get("condition_on_previous_text"),
            initial_prompt = initial_prompt,
            hotwords = hotwords,
            prefix = prefix_text,
            temperature = whisper_config.get("temperature"),
        )
        # print("transcript info: ", info)
        final = False
        speaker = last_speaker
        sentence = last_sentence
        transcript = ""
        new_buffer = audio_buffer
        # Audio duration in seconds.
        audio_duration = len(audio_buffer) / self.samplerate
        # Materialize the lazy segment generator.
        generated_segments = []
        for segment in segments:
            generated_segments.append(segment)
        num_segments = len(generated_segments)
        if num_segments == 0:
            # No transcription result: return the state unchanged.
            return False, speaker, sentence, transcript, new_buffer
        elif num_segments == 1:
            # Exactly one segment: record it as the in-progress transcript,
            # but only when its confidence clears the log-prob threshold.
            # print("log: ", generated_segments[0].avg_logprob)
            if generated_segments[0].avg_logprob > whisper_config.get("log_prob_threshold"):
                transcript = generated_segments[0].text
            else:
                transcript = ""
            # If the buffer exceeds the maximum interruption duration, force-finalize.
            if audio_duration > interruption_duration:
                print(f"Warning: audio buffer over {interruption_duration} seconds, interrupt")
                speaker = self.speaker_verifier.match_speaker(audio_buffer)
                sentence = transcript
                transcript = ""
                new_buffer = np.array([],dtype=np.float32)
                final = True
            else:
                final = False
            self.dump(final, audio_buffer)
        elif num_segments >= 2:
            # Multiple segments: everything but the last forms a finished
            # sentence; the last segment stays in progress.
            sentence = ""
            for i in range(num_segments - 1):
                sentence += generated_segments[i].text
            # print("log: ", generated_segments[num_segments - 1].avg_logprob)
            if generated_segments[num_segments - 1].avg_logprob > whisper_config.get("log_prob_threshold"):
                transcript = generated_segments[num_segments - 1].text
            else:
                transcript = ""
            # Cut at the end of the second-to-last segment: the finished audio
            # goes to speaker verification, the tail becomes the new buffer.
            cut_point = int(generated_segments[num_segments - 2].end * self.samplerate)
            last_buffer = audio_buffer[:cut_point]
            speaker = self.speaker_verifier.match_speaker(last_buffer)
            new_buffer = audio_buffer[cut_point:]
            final = True
            self.dump(final, last_buffer)
        if whisper_config.get("tradition_to_simple"):
            # Traditional -> simplified Chinese conversion.
            transcript = self.cc_model.convert(transcript)
        return final, speaker, sentence, transcript, new_buffer

    def inference(self, audio_data, last_speaker, last_sentence, last_transcript, last_buffer):
        """One streaming step: enhance, strip silence, transcribe.

        Carries the (speaker, sentence, transcript, buffer) state between
        calls; returns the same 5-tuple shape as transcript().
        """
        if Config.speech_enhance.get("enable"):
            # Speech enhancement (denoise / loudness normalization).
            audio_data = self.speech_enhance.enhance(audio_data, self.samplerate)
        if Config.vad.get("enable"):
            # VAD: strip silent regions.
            audio_data = self.vad_rm_silence(audio_data)
        # If no speech remains, skip transcription.
        if audio_data is None:
            if len(last_buffer) > 0 and len(last_transcript) > 0:
                # Pending audio exists: treat the utterance as finished; the
                # complete sentence is last_transcript, and both the new
                # transcript and the new buffer become empty.
                self.dump(True, last_buffer)
                speaker = self.speaker_verifier.match_speaker(last_buffer)
                new_buffer = np.array([],dtype=np.float32)
                return True, speaker, last_transcript, "", new_buffer
            else:
                # Nothing in flight: not finished, state unchanged.
                return False, last_speaker, last_sentence, last_transcript, last_buffer
        # Append the new chunk to the carried-over buffer.
        audio_buffer = np.concatenate([last_buffer, audio_data])
        # Transcribe; last_sentence may feed the prompt / hotwords / prefix.
        final, speaker, sentence, transcript, new_buffer = self.transcript(audio_buffer, last_speaker, last_sentence)
        # Suppress hallucinated phrases.
        sentence = self.filter(sentence)
        transcript = self.filter(transcript)
        return final, speaker, sentence, transcript, new_buffer
if __name__ == "__main__":
    # Demo driver: stream a WAV file through the transcriber in ~1 s chunks
    # and print partial/final results like a live captioner.
    transcriptor = Transcriptor()
    # Load the example clip and normalize it to the transcriber's expected
    # format: 16 kHz sample rate, 16-bit samples, mono.
    audio = AudioSegment.from_file("./examples/asr_example.wav")
    audio = audio.set_frame_rate(transcriptor.samplerate)
    audio = audio.set_sample_width(2)  # int16
    if audio.channels == 2:
        audio = audio.set_channels(1)  # downmix stereo to mono
    print(f"采样率: {audio.frame_rate} Hz")
    print(f"样本宽度: {audio.sample_width} 字节")
    print(f"音频时长: {len(audio) / 1000} 秒")
    samples = np.array(audio.get_array_of_samples())
    # print(samples.shape)
    # Streaming state carried between chunks.
    last_speaker = "guest"
    last_sentence = ""
    last_transcript = ""
    last_buffer = np.array([], dtype=np.float32)
    # Feed ~1 s at a time (16384 samples at 16 kHz is ~1.02 s, not exactly 1 s).
    audio_size = 16384
    for i in range(0, len(samples), audio_size):
        audio_data = samples[i:i + audio_size]
        # int16 -> float32 in [-1.0, 1.0)
        audio_f32 = audio_data.astype(np.float32) / 32768.0
        final, speaker, sentence, transcript, new_buffer = transcriptor.inference(
            audio_f32, last_speaker, last_sentence, last_transcript, last_buffer)
        print("\r\033[K", end="", flush=True)  # clear the in-progress line
        if final:
            print(f"{speaker}: {sentence}")
        print(transcript, end="", flush=True)
        # Bug fix: propagate the verified speaker into the next iteration;
        # previously last_speaker was never updated and stayed "guest".
        last_speaker = speaker
        last_sentence = sentence
        last_transcript = transcript
        last_buffer = new_buffer
    print("")