diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py index 6d397cc99..f8d75d92b 100644 --- a/cosyvoice/cli/frontend.py +++ b/cosyvoice/cli/frontend.py @@ -17,13 +17,13 @@ import onnxruntime import torch import numpy as np -import whisper from typing import Callable import torchaudio.compliance.kaldi as kaldi import os import re import inflect from cosyvoice.utils.file_utils import logging, load_wav +from cosyvoice.utils.audio_utils import log_mel_spectrogram from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation @@ -95,7 +95,7 @@ def _extract_text_token_generator(self, text_generator): def _extract_speech_token(self, prompt_wav): speech = load_wav(prompt_wav, 16000) assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s' - feat = whisper.log_mel_spectrogram(speech, n_mels=128) + feat = log_mel_spectrogram(speech, n_mels=128) speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(), diff --git a/cosyvoice/dataset/processor.py b/cosyvoice/dataset/processor.py index deba209ba..ce1491f79 100644 --- a/cosyvoice/dataset/processor.py +++ b/cosyvoice/dataset/processor.py @@ -17,12 +17,12 @@ import pyarrow.parquet as pq from io import BytesIO import numpy as np -import whisper import torch import torchaudio from torch.nn.utils.rnn import pad_sequence import torch.nn.functional as F import pyworld as pw +from cosyvoice.utils.audio_utils import log_mel_spectrogram from cosyvoice.utils.onnx import embedding_extractor, online_feature AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'} @@ -193,7 +193,7 @@ def compute_whisper_fbank(data, num_frames=-1, mode='train'): if num_frames != -1: assert sample['speech'].shape[1] % num_frames == 0, 'speech length is not aligned with speech_token' sample['speech_16k'] = 
torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech']) - sample['whisper_feat'] = whisper.log_mel_spectrogram(sample['speech_16k'], n_mels=128).squeeze(dim=0).transpose(0, 1) + sample['whisper_feat'] = log_mel_spectrogram(sample['speech_16k'], n_mels=128).squeeze(dim=0).transpose(0, 1) yield sample diff --git a/cosyvoice/tokenizer/tokenizer.py b/cosyvoice/tokenizer/tokenizer.py index 6ecf4ae84..fab635231 100644 --- a/cosyvoice/tokenizer/tokenizer.py +++ b/cosyvoice/tokenizer/tokenizer.py @@ -4,8 +4,6 @@ from typing import Optional import torch from transformers import AutoTokenizer -from whisper.tokenizer import Tokenizer - import tiktoken LANGUAGES = { @@ -213,7 +211,7 @@ def get_tokenizer( num_languages: int = 99, language: Optional[str] = None, task: Optional[str] = None, # Literal["transcribe", "translate", None] -) -> Tokenizer: +): if language is not None: language = language.lower() if language not in LANGUAGES: @@ -233,6 +231,7 @@ def get_tokenizer( encoding = get_encoding(name=encoding_name, num_languages=num_languages) + from whisper.tokenizer import Tokenizer return Tokenizer( encoding=encoding, num_languages=num_languages, language=language, task=task ) diff --git a/cosyvoice/utils/audio_utils.py b/cosyvoice/utils/audio_utils.py new file mode 100644 index 000000000..b7330cf28 --- /dev/null +++ b/cosyvoice/utils/audio_utils.py @@ -0,0 +1,56 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def log_mel_spectrogram(audio, n_mels=128, n_fft=400, hop_length=160, sample_rate=16000):
    """Compute a Whisper-style log-mel spectrogram from a waveform tensor.

    Tensor-only replacement for ``whisper.log_mel_spectrogram`` built on
    ``torch`` and ``torchaudio``, so feature extraction does not need the
    ``openai-whisper`` package. Unlike the Whisper function it does not
    accept file paths or NumPy arrays, and it has no ``padding`` / ``device``
    arguments. With the defaults (n_fft=400, hop_length=160,
    sample_rate=16000) it follows the same steps as the Whisper reference.

    Args:
        audio: 1-D ``(samples,)`` or 2-D ``(batch, samples)`` floating-point
            tensor of raw audio at *sample_rate* Hz.
        n_mels: Number of mel-frequency bins.
        n_fft: FFT window size.
        hop_length: Hop length for STFT.
        sample_rate: Expected sample rate of *audio*.

    Returns:
        Tensor of shape ``(n_mels, n_frames)`` (if 1-D input) or
        ``(batch, n_mels, n_frames)`` (if 2-D input). The dynamic-range
        floor is relative to the maximum over the *whole* tensor, so for
        batched input each item's values depend on the loudest item.
    """
    window = torch.hann_window(n_fft).to(audio.device)
    stft = torch.stft(audio, n_fft, hop_length, window=window, return_complex=True)
    # Drop the final STFT frame, as the Whisper reference implementation does.
    power = stft[..., :-1].abs() ** 2

    # melscale_fbanks returns a float32 (n_freqs, n_mels) matrix on the CPU.
    # Align both device and dtype with the power spectrum: matmul requires
    # matching dtypes, so without the cast any non-float32 spectrum (e.g. from
    # float64 audio) would fail here. For float32 input the cast is a no-op.
    mel_filters = torchaudio.functional.melscale_fbanks(
        n_freqs=n_fft // 2 + 1,
        f_min=0.0,
        f_max=sample_rate / 2.0,
        n_mels=n_mels,
        sample_rate=sample_rate,
        norm="slaney",
        mel_scale="slaney",
    ).to(device=power.device, dtype=power.dtype)

    # (n_mels, n_freqs) @ (..., n_freqs, n_frames) -> (..., n_mels, n_frames)
    mel_spec = mel_filters.T @ power

    # Floor at 1e-10 so log10 never sees zero.
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    # Limit the dynamic range to 8 log10 units below the global maximum.
    log_spec = torch.maximum(log_spec, log_spec.amax() - 8.0)
    # Affine rescale used by Whisper's feature normalisation.
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
index 976a23b48..e1404174c 100755 --- a/tools/extract_speech_token.py +++ b/tools/extract_speech_token.py @@ -20,7 +20,7 @@ import onnxruntime import numpy as np import torchaudio -import whisper +from cosyvoice.utils.audio_utils import log_mel_spectrogram def single_job(utt): @@ -34,7 +34,7 @@ def single_job(utt): logging.warning('do not support extract speech token for audio longer than 30s') speech_token = [] else: - feat = whisper.log_mel_spectrogram(audio, n_mels=128) + feat = log_mel_spectrogram(audio, n_mels=128) speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(), ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() return utt, speech_token