Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions cosyvoice/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import os
import json
import soundfile as sf
import torch
import torchaudio
import logging
Expand All @@ -42,8 +43,17 @@ def read_json_lists(list_file):


def load_wav(wav, target_sr, min_sr=16000):
speech, sample_rate = torchaudio.load(wav, backend='soundfile')
speech = speech.mean(dim=0, keepdim=True)
# Use soundfile directly to avoid the torchcodec dependency introduced
# in torchaudio >= 2.7, where torchaudio.load() routes all backends
# through TorchCodec (which requires FFmpeg 5+ not shipped by Ubuntu 22.04).
# libsndfile reads from the current cursor position; CosyVoice's frontend
# passes the same file-like object to multiple load_wav calls, so reset.
if hasattr(wav, 'seek'):
wav.seek(0)
data, sample_rate = sf.read(wav, dtype='float32', always_2d=False)
if data.ndim > 1:
data = data.mean(axis=1)
speech = torch.from_numpy(data).unsqueeze(0)
if sample_rate != target_sr:
assert sample_rate >= min_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
Expand Down
18 changes: 16 additions & 2 deletions cosyvoice/vllm/cosyvoice2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Optional
from typing import Iterable, Optional, Union
from packaging.version import parse as vparse
import vllm

Expand Down Expand Up @@ -80,7 +80,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model.make_empty_intermediate_tensors)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
if hasattr(self.model, "get_input_embeddings"):
return self.model.get_input_embeddings(input_ids)

return self.model.embed_input_ids(input_ids)

# vLLM >= 0.20 introduced the VllmModelForTextGeneration runtime-checkable
# protocol (vllm/model_executor/models/interfaces_base.py). Its
# _check_vllm_model_embed_input_ids probe looks for embed_input_ids
# specifically; without it, is_text_generation_model() returns False and
# ModelConfig validation raises "This model does not support `--runner
# generate`". The underlying vLLM Qwen2Model exposes embed_input_ids
# (which internally calls self.embed_tokens); there is no
# get_input_embeddings method on vLLM's Qwen2Model.
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)

def forward(
self,
Expand Down