Commit 8af3048

hiworldwzj, sangchengmeng, wangzaijun, and shihaobai authored
add-qwen3-omni-thinker (#1208)
Co-authored-by: sangchengmeng <sangchengmeng@sensetime.com>
Co-authored-by: wangzaijun <wangzaijun@sensetime.com>
Co-authored-by: shihaobai <1798930569@qq.com>
1 parent a2a61ae commit 8af3048

20 files changed

Lines changed: 1318 additions & 9 deletions

lightllm/common/basemodel/layer_weights/base_layer_weight.py

Lines changed: 5 additions & 1 deletion
@@ -33,7 +33,11 @@ def verify_load(self):
         for attr_name in dir(self):
             attr = getattr(self, attr_name)
             if isinstance(attr, BaseWeight):
-                assert attr.verify_load(), f"Loading {attr_name} of layers {self.layer_num_} fails."
+                if hasattr(self, "layer_num_"):
+                    layer_num = self.layer_num_
+                else:
+                    layer_num = None
+                assert attr.verify_load(), f"Loading {attr_name} of layers {layer_num} fails."
 
     def _cuda(self, cpu_tensor):
         return cpu_tensor.contiguous().to(self.data_type_).cuda(get_current_device_id())
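
Side note: the new hasattr/else guard is equivalent to a getattr lookup with a default; a minimal illustrative sketch (not part of the diff):

# Equivalent to the hasattr/else branch above: fall back to None when the weight
# holder has no layer_num_ attribute (e.g. weights not tied to a specific layer).
layer_num = getattr(self, "layer_num_", None)
assert attr.verify_load(), f"Loading {attr_name} of layers {layer_num} fails."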

lightllm/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -37,4 +37,5 @@
     Tarsier2LlamaTpPartModel,
 )
 from lightllm.models.gpt_oss.model import GptOssTpPartModel
+from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel
 from .registry import get_model, get_model_class

lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py

Lines changed: 2 additions & 2 deletions
@@ -28,13 +28,13 @@ def _get_mrope_position_triton(
         local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
         image_start_idx = start_loc + local_image_start_idx - cache_len
         image_len = tl.load(b_image_len + image_start_num + i)
-        image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
+        # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
         image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2)
         for j in range(0, image_len, BLOCK_SIZE):
             off = j + tl.arange(0, BLOCK_SIZE)
             # Video is not handled yet, so t is always 0.
             t_pos = local_image_start_idx + off * 0
-            h_pos = local_image_start_idx + off // image_h
+            h_pos = local_image_start_idx + off // image_w
             w_pos = local_image_start_idx + off % image_w
             tl.store(
                 position_ids + off + image_start_idx,
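
The fix above appears to align the kernel with the usual row-major decomposition of a flat patch index: the height position is the index divided by the grid width and the width position is the remainder, so both use image_w. A small worked check (plain Python, not Triton; illustrative only):

# Row-major decomposition of flat patch indices into (h, w) for a grid of width 4.
image_w = 4
for off in range(8):
    h_pos = off // image_w  # rows:    0 0 0 0 1 1 1 1
    w_pos = off % image_w   # columns: 0 1 2 3 0 1 2 3
    print(off, h_pos, w_pos)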

lightllm/models/qwen3_omni_moe_thinker/__init__.py

Whitespace-only changes.
Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
import torch
import numpy as np
from typing import TYPE_CHECKING, Any, Optional, Union, Tuple
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import TensorType


class WhisperFeatureExtractor(SequenceFeatureExtractor):

    model_input_names = ["input_features"]

    def __init__(
        self,
        feature_size=80,
        sampling_rate=16000,
        hop_length=160,
        chunk_length=30,
        n_fft=400,
        padding_value=0.0,
        dither=0.0,
        return_attention_mask=False,  # pad inputs to max length with silence token (zero) and no attention mask
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            return_attention_mask=return_attention_mask,
            **kwargs,
        )
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.chunk_length = chunk_length
        self.n_samples = chunk_length * sampling_rate
        self.nb_max_frames = self.n_samples // hop_length
        self.sampling_rate = sampling_rate
        self.dither = dither
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=1 + n_fft // 2,
            num_mel_filters=feature_size,
            min_frequency=0.0,
            max_frequency=8000.0,
            sampling_rate=sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )

    def _torch_extract_fbank_features(self, waveform: np.ndarray, device: str = "cpu") -> np.ndarray:
        waveform = torch.from_numpy(waveform).to(device, torch.float32)
        window = torch.hann_window(self.n_fft, device=device)

        if self.dither != 0.0:
            waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device)

        stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
        magnitudes = stft[..., :-1].abs() ** 2

        mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
        mel_spec = mel_filters.T @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        if waveform.dim() == 2:
            max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
            log_spec = torch.maximum(log_spec, max_val - 8.0)
        else:
            log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        if device != "cpu":
            log_spec = log_spec.detach().cpu()
        return log_spec.numpy()

    # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.
    # Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
    def zero_mean_unit_var_norm(
        self, input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0
    ) -> list[np.ndarray]:
        if attention_mask is not None:
            attention_mask = np.array(attention_mask, np.int32)
            normed_input_values = []

            for vector, length in zip(input_values, attention_mask.sum(-1)):
                normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
                if length < normed_slice.shape[0]:
                    normed_slice[length:] = padding_value

                normed_input_values.append(normed_slice)
        else:
            normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]

        return normed_input_values

    def _preprocess(
        self,
        raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
        truncation: bool = True,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: Optional[bool] = None,
        padding: Optional[str] = "longest",  # "max_length" means pad up to max_length
        max_length: Optional[int] = None,
        sampling_rate: Optional[int] = 16000,
        do_normalize: Optional[bool] = None,
        device: Optional[str] = "cpu",
        return_token_timestamps: Optional[bool] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
            raw_speech = raw_speech.astype(np.float32)

        # always return batch
        if not is_batched:
            raw_speech = [np.asarray([raw_speech]).T]

        batched_speech = BatchFeature({"input_features": raw_speech})

        # convert into correct format for padding

        padded_inputs = self.pad(
            batched_speech,
            padding=padding,
            max_length=max_length if max_length else self.n_samples,
            truncation=truncation,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask or do_normalize,
        )

        # zero-mean and unit-variance normalization
        if do_normalize:
            padded_inputs["input_features"] = self.zero_mean_unit_var_norm(
                padded_inputs["input_features"],
                attention_mask=padded_inputs["attention_mask"],
                padding_value=self.padding_value,
            )
            padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0)

        # make sure list is in array format
        input_features = padded_inputs.get("input_features").transpose(2, 0, 1)

        input_features = self._torch_extract_fbank_features(input_features[0], device)

        if isinstance(input_features[0], list):
            padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]

        else:
            padded_inputs["input_features"] = input_features

        if return_attention_mask:
            # rescale from sample (48000) to feature (3000)
            rescaled_attention_mask = padded_inputs["attention_mask"][:, :: self.hop_length]

            # The STFT computation produces L//hop_length + 1 frames,
            # but we skip the last frame (see `_torch_extract_fbank_features`).
            # This means we need to trim the rescaled attention mask to match
            # the actual number of frames (L//hop_length) when the input length
            # is not perfectly divisible by the hop length.
            if padded_inputs["attention_mask"].shape[1] % self.hop_length != 0:
                rescaled_attention_mask = rescaled_attention_mask[:, :-1]
            padded_inputs["attention_mask"] = rescaled_attention_mask

        if return_token_timestamps is not None:
            padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech]

        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
        input_features = torch.from_numpy(np.asarray(padded_inputs["input_features"], dtype=np.float32)).to(
            device="cuda", dtype=torch.bfloat16
        )
        attention_mask = torch.from_numpy(np.asarray(padded_inputs["attention_mask"], dtype=np.float32)).to(
            device="cuda", dtype=torch.int32
        )
        return input_features, attention_mask
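
For orientation, a minimal usage sketch of the extractor above (illustrative only; it assumes a CUDA device is available, since _preprocess moves its outputs to cuda, and uses the default 16 kHz / 30 s Whisper settings):

import numpy as np

# Feed 2 seconds of 16 kHz mono audio through the extractor defined above.
extractor = WhisperFeatureExtractor()
waveform = np.zeros(2 * 16000, dtype=np.float32)
input_features, attention_mask = extractor._preprocess(
    waveform,
    padding="max_length",        # pad to n_samples = 30 s * 16000 samples
    return_attention_mask=True,  # needed so pad() also produces the mask
)
# Expected: input_features roughly (1, 80, 3000) bfloat16 on cuda,
#           attention_mask roughly (1, 3000) int32 on cuda.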
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo


class Qwen3OmniMOEInferStateInfo(Qwen3VLInferStateInfo):
    def __init__(self):
        super().__init__()
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
import torch
from lightllm.models.qwen3_vl_moe.layer_infer.transformer_layer_infer import Qwen3VLMOETransformerLayerInfer
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class Qwen3OmniMOETransformerLayerInfer(Qwen3VLMOETransformerLayerInfer):
    def __init__(self, layer_num, network_config):
        super().__init__(layer_num, network_config)
        self.head_dim_ = network_config["head_dim"]
        self.mrope_section = torch.tensor(
            network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device="cuda"
        )
        return
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight


class Qwen3OmniMOEThinkerPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
    def __init__(self, data_type, network_config):
        super().__init__(data_type, network_config)

        hidden_size = network_config["hidden_size"]
        vocab_size = network_config["vocab_size"]
        self.wte_weight_ = EmbeddingWeight(
            dim=hidden_size,
            vocab_size=vocab_size,
            weight_name="thinker.model.embed_tokens.weight",
            data_type=self.data_type_,
        )
        tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
        self.lm_head_weight_ = LMHeadWeight(
            dim=hidden_size,
            vocab_size=vocab_size,
            weight_name="thinker.lm_head.weight",
            data_type=self.data_type_,
            embedding_weight=self.wte_weight_ if tie_word_embeddings else None,
        )
        self.final_norm_weight_ = RMSNormWeight(
            dim=hidden_size,
            weight_name="thinker.model.norm.weight",
            data_type=self.data_type_,
        )
        return
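
For context, when the config sets tie_word_embeddings, LMHeadWeight is handed the embedding weight above, presumably so the output projection reuses the same matrix. A minimal sketch of what that sharing means numerically (plain PyTorch, illustrative only, not LightLLM's weight classes):

import torch

hidden_size, vocab_size = 8, 32
embed = torch.nn.Embedding(vocab_size, hidden_size)
hidden_states = torch.randn(2, hidden_size)
logits = hidden_states @ embed.weight.T  # the embedding matrix doubles as lm_head.weight
print(logits.shape)                      # torch.Size([2, 32])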
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
import os
from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight


class Qwen3OmniMOEThinkerTransformerLayerWeight(Qwen3MOETransformerLayerWeight):
    def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
        super().__init__(layer_num, data_type, network_config, quant_cfg)
        return

    def _init_weight_names(self):
        self._q_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.q_proj.weight"
        self._q_norm_name = f"thinker.model.layers.{self.layer_num_}.self_attn.q_norm.weight"
        self._q_bias_name = None
        self._k_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.k_proj.weight"
        self._k_norm_name = f"thinker.model.layers.{self.layer_num_}.self_attn.k_norm.weight"
        self._k_bias_name = None
        self._v_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.v_proj.weight"
        self._v_bias_name = None
        self._kv_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.kv_proj.weight"
        self._kv_bias_name = None
        self._o_weight_name = f"thinker.model.layers.{self.layer_num_}.self_attn.o_proj.weight"
        self._o_bias_name = None
        self._att_norm_weight_name = f"thinker.model.layers.{self.layer_num_}.input_layernorm.weight"
        self._att_norm_bias_name = None
        self._ffn_norm_weight_name = f"thinker.model.layers.{self.layer_num_}.post_attention_layernorm.weight"
        self._ffn_norm_bias_name = None

    def _init_moe(self):
        moe_intermediate_size = self.network_config_["moe_intermediate_size"]
        self.moe_gate = ROWMMWeight(
            in_dim=self.network_config_["hidden_size"],
            out_dims=[self.n_routed_experts],
            weight_names=f"thinker.model.layers.{self.layer_num_}.mlp.gate.weight",
            data_type=self.data_type_,
            quant_method=None,
            tp_rank=0,
            tp_world_size=1,
        )
        self.experts = FusedMoeWeight(
            gate_proj_name="gate_proj",
            down_proj_name="down_proj",
            up_proj_name="up_proj",
            e_score_correction_bias_name="",
            weight_prefix=f"thinker.model.layers.{self.layer_num_}.mlp.experts",
            n_routed_experts=self.n_routed_experts,
            hidden_size=self.network_config_["hidden_size"],
            moe_intermediate_size=moe_intermediate_size,
            data_type=self.data_type_,
            quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"),
            layer_num=self.layer_num_,
            network_config=self.network_config_,
        )
