diff --git a/README.md b/README.md index baefb76..5c1afc7 100644 --- a/README.md +++ b/README.md @@ -233,6 +233,275 @@ print(output_text) ``` +## Ming SDK + +Ming SDK provides a simple and easy-to-use Python API for quickly integrating the multimodal capabilities of Ming-flash-omni 2.0. + +### SDK Features + +- **Unified API Interface**: Supports text generation, speech synthesis, image generation/editing, and more +- **Streaming Output Support**: Supports streaming generation for text and speech, suitable for real-time interaction scenarios +- **Flexible Device Configuration**: Supports multi-GPU deployment and memory optimization +- **Complete Usage Statistics**: Provides detailed statistics on token usage, audio duration, etc. + +### SDK Installation + +#### Install VLLM +pip install vllm-0.8.5.post3.dev90+gffc0d5a3f.ant-cp310-cp310-linux_x86_64.whl + +#### Install Ming SDK +#### Option 1: Build from Source + +```bash +# 1. Clone the repository +git clone https://github.com/inclusionAI/Ming.git +cd Ming + +# 2. Install dependencies +pip install -r requirements.txt + +python ming_sdk/setup.py bdist_wheel + +pip3 install dist/ming_sdk-1.0.0-py3-none-any.whl +``` + +### SDK Usage Examples + +#### Initialize SDK + +```python +from ming_sdk import Ming + +# Configuration parameters +model_path = "your model path" # Model path +device = "0,1,2,3" # GPU devices, supports multi-GPU parallelism +gpu_memory_utilization = {"moe": 0.8, "talker": 0.17} # GPU memory utilization +device_map = {"talker": ["cuda:0"]} # Module device mapping + +# Initialize Ming instance +ming = Ming( + model_path=model_path, + device=device, + gpu_memory_utilization=gpu_memory_utilization, + device_map=device_map, + speaker="DB30", # TTS speaker ID + with_async = True, + use_talker = True +) +``` + +#### Text Generation + +```python +# Non-streaming text generation +text, usage = ming.generate(text="介绍一下杭州") +print(f"text:{text}") +print(f"usage:{usage}") +assert text is not None + + +# Streaming text generation +all_text = "" +request_id = "" +for text, request_id, usage in ming.generate_stream( + text="介绍一下杭州", max_new_tokens=128 +): + all_text += text +print(f"request_id:{request_id},text={all_text},usage={usage}") +assert text is not None +print(f"\nFull text: {all_text}") +``` +#### Speech QA +```python +# Speech QA +output_audio_path = "test.wav" +waveform, gen_text, usage = ming.generate( + text="介绍一下杭州", output_type="speech", max_new_tokens=128 +) +sr = 44100 +torchaudio.save(output_audio_path, waveform, sr) +assert os.path.exists(output_audio_path) +print(f"request_id:{gen_text},usage={usage}") + + +# Streaming speech QA +all_wavs = [] +all_text = "" +request_id = "" +output_audio_path = "test_stream.wav" +for data_type, data_content in ming.generate_stream( + text="介绍一下杭州", output_type="speech", max_new_tokens=128 +): + if data_type == "text_data": + text, usage = data_content + elif data_type == "text_audio_data": + tts_speech, text, meta_info, session_id, usage = data_content + all_text += text + all_wavs.append(tts_speech) +waveform = torch.cat(all_wavs, dim=-1) +sr = 44100 +torchaudio.save(output_audio_path, waveform, sr) +print( + f"request_id:{request_id},audio:{output_audio_path},text={all_text},usage={usage}" +) +assert os.path.exists(output_audio_path) + + +# Streaming speech QA with interruption +all_wavs = [] +all_text = "" +request_id = "" +output_audio_path = "test_stream.wav" +for data_type, data_content in ming.generate_stream( + text="介绍一下杭州", output_type="speech", max_new_tokens=128 +): + if data_type == "text_data": + text, usage = data_content + elif data_type == "text_audio_data": + tts_speech, text, meta_info, session_id, usage = data_content + all_text += text + all_wavs.append(tts_speech) + if len(all_text) > 20: + ming.generate_interrupt(request_id) +waveform = torch.cat(all_wavs, dim=-1) +sr = 44100 +torchaudio.save(output_audio_path, waveform, sr) +print(f"request_id:{request_id},audio:{output_audio_path},text={all_text}") +assert os.path.exists(output_audio_path) + +``` + +#### ASR Task +```python +# ASR +asr_result, usage = ming.generate( + text="Please recognize the language of this speech and transcribe it. Format: oral.", + audio="https://example.com/audio.wav", +) +print(f"asr_result:{asr_result},usage={usage}") +assert asr_result is not None +``` + +#### Text-to-Speech (TTS) + +```python +import torchaudio + +# Non-streaming TTS +waveform, usage = ming.generate( + text="我爱北京故宫", + output_type="speech" +) +torchaudio.save("output_tts.wav", waveform, 44100) + +# Streaming TTS +all_wavs = [] +all_text = "" +for data_type, data_content in ming.generate_stream( + text="我爱北京故宫", + output_type="speech" +): + if data_type == "text_audio_data": + tts_speech, sentence, meta_info, session_id, usage = data_content + all_text += sentence + all_wavs.append(tts_speech) + +# Save audio +waveform = torch.cat(all_wavs, dim=-1) +torchaudio.save("output_tts_stream.wav", waveform, 44100) +``` + +#### Speech-to-Speech + +```python +# Non-streaming speech-to-speech +waveform, gen_text, usage = ming.generate( + audio="https://example.com/audio.wav", + output_type="speech", + max_new_tokens=128 +) +torchaudio.save("output_speech.wav", waveform, 44100) + +# Streaming speech-to-speech +all_wavs = [] +all_text = "" +for data_type, data_content in ming.generate_stream( + audio="https://example.com/audio.wav", + output_type="speech", + max_new_tokens=128 +): + if data_type == "text_data": + text, usage = data_content + elif data_type == "text_audio_data": + tts_speech, text, meta_info, session_id, usage = data_content + all_text += text + all_wavs.append(tts_speech) + +waveform = torch.cat(all_wavs, dim=-1) +torchaudio.save("output_speech_stream.wav", waveform, 16000) +``` + + +#### Video Understanding + +```python +# Video QA +text, usage = ming.generate( + text="详细描述一下这段视频", + video="test.mp4", + output_type="text" +) +print(f"Video description: {text}") +``` + +#### Request Interruption + +```python +# You can interrupt the request during streaming generation +msg_request_id = "your-request-id" +for data_type, data_content in ming.generate_stream( + text="介绍一下杭州", + output_type="speech", + msg_request_id=msg_request_id +): + # Interrupt when condition is met + if some_condition: + ming.generate_interrupt(msg_request_id) + break +``` + +### Parameter Reference + +#### Ming Initialization Parameters + +| Parameter | Type | Default | Description | +|------|------|--------|------| +| `model_path` | str | Required | Model weights path, must contain config.json and am.mvn | +| `sys_prompt` | str | "" | System prompt, prepended to all conversations | +| `device` | str | "0" | GPU device IDs, comma-separated for multi-GPU, e.g., "0,1,2,3" | +| `gpu_memory_utilization` | dict | {"moe": 0.6, "talker": 0.1} | GPU memory utilization for each module | +| `device_map` | dict | {"talker": ["cuda:0"], "image": "cuda:0"} | Mapping from modules to GPUs | +| `speaker` | str | "DB30" | TTS speaker ID | +| `quantization` | str \| None | None | Quantization configuration | +| `use_talker` | bool | True | Whether to load TTS module | + +#### generate Method Parameters + +| Parameter | Type | Default | Description | +|------|------|--------|------| +| `text` | str \| None | None | Input text | +| `audio` | str \| bytes \| List | None | Audio input (file path/binary/list) | +| `video` | str \| bytes \| List | None | Video input (file path/binary/list) | +| `image` | str \| bytes \| List | None | Image input (file path/binary/PIL Image/list) | +| `history` | list | [] | Conversation history | +| `output_type` | str | "text" | Output type: text/speech/image/tts | +| `max_new_tokens` | int | 512 | Maximum number of tokens to generate | + +### Complete Examples + +For more complete examples, please refer to [ming_sdk/ming_test.py](ming_sdk/ming_test.py). + + ## Citation If you find our work helpful, feel free to give us a cite. diff --git a/bailingmm_utils_video.py b/bailingmm_utils_video.py index 277e423..61889f1 100644 --- a/bailingmm_utils_video.py +++ b/bailingmm_utils_video.py @@ -286,22 +286,21 @@ def v1_smart_nframes( int: the number of frames for video used for model inputs. """ assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`" - - min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) - max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR) - + max_frames = max( + 1, + floor_by_factor( + ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR + ), + ) if "nframes" in ele: nframes = min(total_frames, round_by_factor(ele["nframes"], FRAME_FACTOR), max_frames) else: fps = ele.get("max_video_fps", FPS) - nframes = total_frames / video_fps * fps + nframes = max(1, total_frames / video_fps * fps) if nframes > total_frames: logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]") - nframes = min(min(max(nframes, min_frames), max_frames), total_frames) - nframes = floor_by_factor(nframes, FRAME_FACTOR) - if not (FRAME_FACTOR <= nframes <= total_frames): - raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.") - return nframes + nframes = min(min(nframes, max_frames), total_frames) + return int(nframes) def v1_sample_video(video_fps, total_frames, ele: dict) -> List[int]: @@ -367,8 +366,9 @@ def v1_fetch_video( return_metadata: bool = False, ) -> torch.Tensor | list[Image.Image]: if isinstance(ele["video"], str): - video, smp_fps = load_video(ele["video"], sampler=v2_sample_video) - + video, smp_fps = load_video( + ele["video"], sampler=partial(v1_sample_video, ele=ele) + ) if "resized_height" in ele and "resized_width" in ele: resized_height, resized_width = smart_resize( ele["resized_height"], diff --git a/configuration_bailing_moe_v2.py b/configuration_bailing_moe_v2.py index 4caf222..09d782e 100644 --- a/configuration_bailing_moe_v2.py +++ b/configuration_bailing_moe_v2.py @@ -54,6 +54,7 @@ def __init__( self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.use_qkv_bias = use_qkv_bias + self.use_qk_norm = use_qk_norm self.use_bias = use_bias self.norm_head = norm_head self.rms_norm_eps = rms_norm_eps diff --git a/configuration_bailingmm.py b/configuration_bailingmm.py new file mode 100644 index 0000000..fa1c571 --- /dev/null +++ b/configuration_bailingmm.py @@ -0,0 +1,41 @@ +# coding=utf-8 +# Copyright 2024 ANT Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import PretrainedConfig +from configuration_bailing_moe import BailingMoeConfig +from audio_tokenizer.configuration_audio_vae import AudioVAEconfig + + +class BailingMMConfig(PretrainedConfig): + model_type = "bailingmm" + + def __init__( + self, + llm_config: BailingMoeConfig = None, + audio_tokenizer_config: AudioVAEconfig = None, + ditar_config: dict = None, + aggregator_config: dict = None, + model_type: str = None, + **kwargs + ): + self.model_type = model_type + if self.model_type == 'dense': + self.llm_config = llm_config + else: + self.llm_config = BailingMoeConfig(**llm_config) if isinstance(llm_config, dict) else llm_config + self.audio_tokenizer_config = AudioVAEconfig(**audio_tokenizer_config) if isinstance(audio_tokenizer_config, dict) else audio_tokenizer_config + self.ditar_config = ditar_config + self.aggregator_config = aggregator_config + super().__init__(**kwargs) \ No newline at end of file diff --git a/diffusion/pipeline_z_image.py b/diffusion/pipeline_z_image.py index 22c0c31..04b551d 100644 --- a/diffusion/pipeline_z_image.py +++ b/diffusion/pipeline_z_image.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Portions of the implementations are adapted from https://github.com/Tongyi-MAI/Z-Image/blob/main/src/zimage/pipeline.py. -# Based on this code, we made modifications and extensions, including adding Image-Editing,Classifier-free guidance functionality, to better support training for Ming-Omni image generation. +# Portions of the implementations are adapted from https://github.com/Tongyi-MAI/Z-Image/blob/main/src/zimage/pipeline.py. +# Based on this code, we made modifications and extensions, including adding Image-Editing,Classifier-free guidance functionality, to better support training for Ming-Omni image generation. # All rights and credit for the original implementation remain with the original authors and contributors, and this project complies with the applicable open-source license terms of the referenced repository. import inspect @@ -185,6 +185,7 @@ def __init__( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.taylor_cache = False def encode_prompt( self, @@ -554,22 +555,31 @@ def __call__( prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds timestep_model_input = timestep.repeat(2) ref_hidden_states_input = ref_hidden_states.repeat(2, 1, 1, 1) if ref_hidden_states is not None else None + if ref_hidden_states_input is not None: + ref_hidden_states_input = ref_hidden_states_input.to(latent_model_input.dtype) else: latent_model_input = latents.to(self.transformer.dtype) prompt_embeds_model_input = prompt_embeds timestep_model_input = timestep ref_hidden_states_input = ref_hidden_states*1.0 if ref_hidden_states is not None else None + if ref_hidden_states_input is not None: + ref_hidden_states_input = ref_hidden_states_input.to(latent_model_input.dtype) latent_model_input = latent_model_input.unsqueeze(2) latent_model_input_list = list(latent_model_input.unbind(dim=0)) - + if ref_hidden_states_input is not None: ref_hidden_states_input = ref_hidden_states_input.unsqueeze(2) ref_hidden_states_input = list(ref_hidden_states_input.unbind(dim=0)) - model_out_list = self.transformer( - latent_model_input_list, timestep_model_input, prompt_embeds_model_input, ref_hidden_states=ref_hidden_states_input, return_dict=False - )[0] + if self.taylor_cache: + model_out_list = self.transformer( + latent_model_input_list, timestep_model_input, prompt_embeds_model_input, ref_hidden_states=ref_hidden_states_input, return_dict=False, step=i, + )[0] + else: + model_out_list = self.transformer( + latent_model_input_list, timestep_model_input, prompt_embeds_model_input, ref_hidden_states=ref_hidden_states_input, return_dict=False + )[0] if apply_cfg: # Perform CFG @@ -635,3 +645,6 @@ def __call__( return (image,) return ZImagePipelineOutput(images=image) + + def set_taylor_cache(self): + self.taylor_cache = True diff --git a/diffusion/transformer_z_image.py b/diffusion/transformer_z_image.py index 72ef570..63303c2 100644 --- a/diffusion/transformer_z_image.py +++ b/diffusion/transformer_z_image.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Portions of the implementations are adapted from https://github.com/Tongyi-MAI/Z-Image/blob/main/src/zimage/transformer.py. -# Based on this code, we made modifications and extensions, including adding Image-Editing functionality, to better support training for Ming-Omni image generation. +# Portions of the implementations are adapted from https://github.com/Tongyi-MAI/Z-Image/blob/main/src/zimage/transformer.py. +# Based on this code, we made modifications and extensions, including adding Image-Editing functionality, to better support training for Ming-Omni image generation. # All rights and credit for the original implementation remain with the original authors and contributors, and this project complies with the applicable open-source license terms of the referenced repository. import math @@ -550,6 +550,7 @@ def forward( f_patch_size=1, ref_x=None, return_dict: bool = True, + step=None, ): assert patch_size in self.all_patch_size assert f_patch_size in self.all_f_patch_size @@ -645,15 +646,61 @@ def forward( for i, seq_len in enumerate(unified_item_seqlens): unified_attn_mask[i, :seq_len] = 1 - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.layers: - unified = self._gradient_checkpointing_func( - layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input - ) + + if step is not None: + if step == 0: + self.saved_steps = [] + self.saved_features = [] + + #[0, 1, 2, 3, 4, 5, 7, 9, 11, 14, 17, 20, 22, 24, 26, 27, 28, 29]: + if step in [0, 1, 2, 3, 4, 5, 7, 9, 12, 15, 18, 21, 23, 25, 26, 27, 28, 29]: + for layer in self.layers: + unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) + + self.saved_steps.append(step) + self.saved_features.append(unified) + else: + assert len(self.saved_features) >= 3 + assert len(self.saved_steps) >= 3 + # 1. 提取最近三个点的数据 + f1, f2, f3 = self.saved_features[-1], self.saved_features[-2], self.saved_features[-3] + t1, t2, t3 = self.saved_steps[-1], self.saved_steps[-2], self.saved_steps[-3] + + # 2. 计算步长 + dt1 = t1 - t2 + dt2 = t2 - t3 + dt_next = 1.0 # 或者是你想要预测的未来跨度 + + # 3. 计算一阶导数 (速度) + v1 = (f1 - f2) / dt1 + v2 = (f2 - f3) / dt2 + + # 4. 计算二阶导数 (加速度) + # 注意:这里分母是时间中心点的差值 + a = (v1 - v2) / ((t1 - t3) / 2) + + # 5. 二阶泰勒展开预测 + unified = f1 + v1 * dt_next + 0.5 * a * (dt_next ** 2) + else: + # if torch.is_grad_enabled() and self.gradient_checkpointing: + # for layer in self.layers: + # unified = self._gradient_checkpointing_func( + # layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input + # ) + # else: for layer in self.layers: unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) + # if torch.is_grad_enabled() and self.gradient_checkpointing: + # for layer in self.layers: + # unified = self._gradient_checkpointing_func( + # layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input + # ) + # else: + # for layer in self.layers: + # unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) + unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) unified = list(unified.unbind(dim=0)) x = self.unpatchify(unified, x_size, patch_size, f_patch_size) diff --git a/diffusion/transformer_z_image_xfuser.py b/diffusion/transformer_z_image_xfuser.py new file mode 100644 index 0000000..386e3a7 --- /dev/null +++ b/diffusion/transformer_z_image_xfuser.py @@ -0,0 +1,331 @@ +import torch +import math +from torch.nn.utils.rnn import pad_sequence +from typing import List, Optional + +#from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel +from .transformer_z_image import ZImageTransformer2DModel +from diffusers.models.attention_processor import Attention +from diffusers.models.modeling_outputs import Transformer2DModelOutput + + +from xfuser.model_executor.layers.usp import USP + +from xfuser.core.distributed import ( + get_sequence_parallel_world_size, + get_sequence_parallel_rank, + get_sp_group, +) + +import torch.nn.functional as F + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 + +class xFuserZSingleStreamAttnProcessor: + """ + Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the + original Z-ImageAttention module. + """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + # Apply Norms + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE + def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) + + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + + # Transpose for attention + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Compute joint attention + hidden_states = USP( + query, + key, + value, + dropout_p=0.0, + is_causal=False, + ) + + # Transpose back to original shape + hidden_states = hidden_states.transpose(1, 2) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + + output = attn.to_out[0](hidden_states) + if len(attn.to_out) > 1: # dropout + output = attn.to_out[1](output) + + return output + +class xFuserZImageTransformer2DWrapper(ZImageTransformer2DModel): + + def __init__( + self, + **kwargs + ): + super().__init__( + **kwargs + ) + for layer in self.layers: + layer.attention.processor = xFuserZSingleStreamAttnProcessor() + + + def _chunk_and_pad_sequence(self, x: torch.Tensor, sp_world_rank: int, sp_world_size: int, pad_amount: int, dim: int) -> torch.Tensor: + if pad_amount > 0: + if dim < 0: + dim = x.ndim + dim + pad_shape = list(x.shape) + pad_shape[dim] = pad_amount + x = torch.cat([x, + torch.zeros( + pad_shape, + dtype=x.dtype, + device=x.device, + )], dim=dim) + x = torch.chunk(x, + sp_world_size, + dim=dim)[sp_world_rank] + return x + + def _gather_and_unpad(self, x: torch.Tensor, pad_amount: int, dim: int) -> torch.Tensor: + x = get_sp_group().all_gather(x, dim=dim) + size = x.size(dim) + return x.narrow(dim=dim, start=0, length=size - pad_amount) + + def forward( + self, + x: List[torch.Tensor], + t, + cap_feats: List[torch.Tensor], + patch_size=2, + f_patch_size=1, + ref_hidden_states=None, + return_dict: bool = True, + step=None, + ): + assert patch_size in self.all_patch_size + assert f_patch_size in self.all_f_patch_size + sp_world_rank = get_sequence_parallel_rank() + sp_world_size = get_sequence_parallel_world_size() + + cap_feats = [F.normalize(i, dim=-1) * 1000.0 for i in cap_feats] + + + bsz = len(x) + device = x[0].device + t = t * self.t_scale + t = self.t_embedder(t) + + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size, ref_hidden_states) + + # x embed & refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + # Match t_embedder output dtype to x for layerwise casting compatibility + adaln_input = t.type_as(x) + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x = list(x.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)) + + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors + x_freqs_cis = x_freqs_cis[:, : x.shape[1]] + + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + # SP support + # Leaving these here for posteriority - these would make the model slightly faster + # but causes slight numerical differences. + + # pad_amount = (sp_world_size - (x.shape[1] % sp_world_size)) % sp_world_size + # x = self._chunk_and_pad_sequence(x, sp_world_rank, sp_world_size, pad_amount, dim=-2) + # x_attn_mask = self._chunk_and_pad_sequence(x_attn_mask, sp_world_rank, sp_world_size, pad_amount, dim=-1) + # x_freqs_cis_chunked = self._chunk_and_pad_sequence(x_freqs_cis, sp_world_rank, sp_world_size, pad_amount, dim=-2) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.noise_refiner: + x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) + else: + for layer in self.noise_refiner: + x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) + + # Gather SP outputs and remove padding + #x = self._gather_and_unpad(x, pad_amount, dim=-2) + + # cap embed & refine + cap_item_seqlens = [len(_) for _ in cap_feats] + assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens) + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors + cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]] + + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + + # SP support + # Same as above comments about speed vs numerical differences apply here. + + # pad_amount = (sp_world_size - (cap_feats.shape[1] % sp_world_size)) % sp_world_size + # cap_feats = self._chunk_and_pad_sequence(cap_feats, sp_world_rank, sp_world_size, pad_amount, dim=-2) + # cap_attn_mask = self._chunk_and_pad_sequence(cap_attn_mask, sp_world_rank, sp_world_size, pad_amount, dim=-1) + # cap_freqs_cis_chunked = self._chunk_and_pad_sequence(cap_freqs_cis, sp_world_rank, sp_world_size, pad_amount, dim=-2) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.context_refiner: + cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) + else: + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + + # Gather SP outputs and remove padding + # cap_feats = self._gather_and_unpad(cap_feats, pad_amount, dim=-2) + + + # unified + unified = [] + unified_freqs_cis = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) + unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + assert unified_item_seqlens == [len(_) for _ in unified] + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + # SP support + pad_amount = (sp_world_size - (unified.shape[1] % sp_world_size)) % sp_world_size + unified = self._chunk_and_pad_sequence(unified, sp_world_rank, sp_world_size, pad_amount, dim=-2) + unified_attn_mask = self._chunk_and_pad_sequence(unified_attn_mask, sp_world_rank, sp_world_size, pad_amount, dim=-1) + unified_freqs_cis = self._chunk_and_pad_sequence(unified_freqs_cis, sp_world_rank, sp_world_size, pad_amount, dim=-2) + + if step is not None: + if step == 0: + self.saved_steps = [] + self.saved_features = [] + + #[0, 1, 2, 3, 4, 5, 7, 9, 11, 14, 17, 20, 22, 24, 26, 27, 28, 29]: + if step in [0, 1, 2, 3, 4, 5, 7, 9, 12, 15, 18, 21, 23, 25, 26, 27, 28, 29]: + for layer in self.layers: + unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) + + self.saved_steps.append(step) + self.saved_features.append(unified) + else: + assert len(self.saved_features) >= 3 + assert len(self.saved_steps) >= 3 + # 1. 提取最近三个点的数据 + f1, f2, f3 = self.saved_features[-1], self.saved_features[-2], self.saved_features[-3] + t1, t2, t3 = self.saved_steps[-1], self.saved_steps[-2], self.saved_steps[-3] + + # 2. 计算步长 + dt1 = t1 - t2 + dt2 = t2 - t3 + dt_next = 1.0 # 或者是你想要预测的未来跨度 + + # 3. 计算一阶导数 (速度) + v1 = (f1 - f2) / dt1 + v2 = (f2 - f3) / dt2 + + # 4. 计算二阶导数 (加速度) + # 注意:这里分母是时间中心点的差值 + a = (v1 - v2) / ((t1 - t3) / 2) + + # 5. 二阶泰勒展开预测 + unified = f1 + v1 * dt_next + 0.5 * a * (dt_next ** 2) + + else: + # if torch.is_grad_enabled() and self.gradient_checkpointing: + # for layer in self.layers: + # unified = self._gradient_checkpointing_func( + # layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input + # ) + # else: + for layer in self.layers: + unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) + + # Gather SP outputs and remove padding + unified = self._gather_and_unpad(unified, pad_amount, dim=-2) + + unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) + unified = list(unified.unbind(dim=0)) + x = self.unpatchify(unified, x_size, patch_size, f_patch_size) + + if not return_dict: + return (x,) + + return Transformer2DModelOutput(sample=x) + diff --git a/image_gen_dit_server.py b/image_gen_dit_server.py new file mode 100644 index 0000000..d2c1389 --- /dev/null +++ b/image_gen_dit_server.py @@ -0,0 +1,539 @@ +import os +import time +import torch +import ray +import io +import logging +import base64 +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from typing import Optional +import argparse + +from xfuser import ( + xFuserArgs, +) +# from xfuser.model_executor.models.transformers.transformer_z_image import xFuserZImageTransformer2DWrapper +from diffusers import DiffusionPipeline +from xfuser.core.distributed import ( + get_world_group, + #get_runtime_state, + initialize_runtime_state, + #is_dp_last_group, +) +from bailingmm_utils import process_ratio +from PIL import Image +import base64 +import pickle +import numpy as np +from diffusers import AutoencoderKL + +def dict_to_base64(obj) -> str: + def encode(x): + if isinstance(x, torch.Tensor): + t = x.detach().cpu().contiguous() + arr = t.numpy() + return { + "__type__": "torch_tensor", + "dtype": str(arr.dtype), + "shape": arr.shape, + "data": arr.tobytes(), # raw bytes + } + elif isinstance(x, dict): + return {k: encode(v) for k, v in x.items()} + elif isinstance(x, (list, tuple)): + y = [encode(v) for v in x] + return {"__type__": "tuple", "items": y} if isinstance(x, tuple) else y + else: + return x + + packed = encode(obj) + blob = pickle.dumps(packed, protocol=pickle.HIGHEST_PROTOCOL) + return base64.b64encode(blob).decode("utf-8") + + +def base64_to_dict(s: str): + def decode(x): + if isinstance(x, dict) and x.get("__type__") == "torch_tensor": + arr = np.frombuffer(x["data"], dtype=np.dtype(x["dtype"])).reshape(x["shape"]) + return torch.from_numpy(arr) + elif isinstance(x, dict) and x.get("__type__") == "tuple": + return tuple(decode(v) for v in x["items"]) + elif isinstance(x, dict): + return {k: decode(v) for k, v in x.items()} + elif isinstance(x, list): + return [decode(v) for v in x] + else: + return x + + blob = base64.b64decode(s.encode("utf-8")) + packed = pickle.loads(blob) + return decode(packed) + + +def run_pipe(pipe: DiffusionPipeline, gen_param_b64, logger): #prompt, steps, seed): + # Pipe implementation currently encodes the prompt in-place, + # causing any subsequent calls to use the already encoded prompt as prompt, + # causing cascading encodings unless we provide a new list each time. + #prompt = str(input_config.prompt) + + print("run_pipe") + + #is_last_process = get_world_group().rank == get_world_group().world_size - 1 + # if is_last_process: + # import datetime + # ts = datetime.datetime.now().strftime("%Y%m%d%H%M%S") # 不含空格/冒号 + # filename = os.path.join("personal_gen_param_b64", f"{ts}.txt") + # with open(filename, "a", encoding="utf-8") as f: + # f.write(f"{gen_param_b64}\n") + + gen_param = base64_to_dict(gen_param_b64) + #"gen_param_b64": dict_to_base64(task), + condition_embeds = gen_param["image_gen_condition_embeds"].to(pipe.transformer.device) + negative_condition_embeds = gen_param["image_gen_negative_condition_embeds"].to(pipe.transformer.device) + image_gen_pixel_values_reference = gen_param["image_gen_pixel_values_reference"] if "image_gen_pixel_values_reference" in gen_param else None + image_gen_seed = gen_param["image_gen_seed"] if "image_gen_seed" in gen_param else None + image_gen_cfg = gen_param["image_gen_cfg"] if "image_gen_cfg" in gen_param else 2.0 + + image_gen_height = gen_param["image_gen_height"] + image_gen_width = gen_param["image_gen_width"] + image_gen_highres = gen_param["image_gen_highres"] + if image_gen_height is None or image_gen_width is None: + if isinstance(image_gen_highres, int): + image_gen_height, image_gen_width = [image_gen_highres] * condition_embeds.shape[0], [image_gen_highres] * condition_embeds.shape[0] + elif image_gen_highres is True: + image_gen_height, image_gen_width = [1024] * condition_embeds.shape[0], [1024] * condition_embeds.shape[0] + else: + image_gen_height, image_gen_width = [512] * condition_embeds.shape[0], [512] * condition_embeds.shape[0] + elif isinstance(image_gen_height, torch.Tensor) or isinstance(image_gen_width, torch.Tensor): + assert isinstance(image_gen_height, torch.Tensor), image_gen_height + assert isinstance(image_gen_width, torch.Tensor), image_gen_width + image_gen_height = image_gen_height.cpu().tolist() + image_gen_width = image_gen_width.cpu().tolist() + assert len(image_gen_height) == condition_embeds.shape[0] + assert len(image_gen_width) == condition_embeds.shape[0] + elif isinstance(image_gen_height, int) or isinstance(image_gen_width, int): + assert isinstance(image_gen_height, int), image_gen_height + assert isinstance(image_gen_width, int), image_gen_width + image_gen_height = [image_gen_height] * condition_embeds.shape[0] + image_gen_width = [image_gen_width] * condition_embeds.shape[0] + else: + assert isinstance(image_gen_height, list), image_gen_height + assert isinstance(image_gen_width, list), image_gen_width + assert len(image_gen_height) == condition_embeds.shape[0] + assert len(image_gen_width) == condition_embeds.shape[0] + + + image_gen_height_diffusion_list = [] + image_gen_width_diffusion_list = [] + image_gen_output_resize_height = [] + image_gen_output_resize_width = [] + for height, width in zip(image_gen_height, image_gen_width): + closest_size, resize_size = process_ratio(ori_h=height, ori_w=width, highres=image_gen_highres) + height, width = closest_size + image_gen_height_diffusion_list.append(height) + image_gen_width_diffusion_list.append(width) + height, width = resize_size + image_gen_output_resize_height.append(height) + image_gen_output_resize_width.append(width) + + image_gen_height = image_gen_height_diffusion_list[0] + assert all([i == image_gen_height for i in image_gen_height_diffusion_list]) + image_gen_width = image_gen_width_diffusion_list[0] + assert all([i == image_gen_width for i in image_gen_width_diffusion_list]) + + if image_gen_pixel_values_reference is not None: + assert (image_gen_height, image_gen_width) == (image_gen_pixel_values_reference.shape[-2], image_gen_pixel_values_reference.shape[-1]) + + if image_gen_seed is None or image_gen_seed < 0: + from datetime import datetime + image_gen_seed = datetime.now().microsecond % 1000 + + logger.info(f"condition_embeds.shape {condition_embeds.shape}") + logger.info(f"negative_condition_embeds.shape {negative_condition_embeds.shape}") + logger.info(f"height {image_gen_height}") + logger.info(f"height {image_gen_width}") + logger.info(f"guidance_scale {image_gen_cfg}") + logger.info(f"seed {image_gen_seed}") + + image = pipe( + prompt_embeds=list(condition_embeds.unbind(0)), + negative_prompt_embeds=list(negative_condition_embeds.unbind(0)), + height=image_gen_height, + width=image_gen_width, + num_inference_steps=30, # Recommended value + guidance_scale=image_gen_cfg, # Recommended value + generator=torch.manual_seed(image_gen_seed), + max_sequence_length=512, + ref_hidden_states=image_gen_pixel_values_reference, + ).images + + image = [i.resize((w, h)) for i, w, h in zip(image, image_gen_output_resize_width, image_gen_output_resize_height)] + + return image + + # prompt_embeds=encoder_hidden_states, + # negative_prompt_embeds=[en*0 for en in encoder_hidden_states], + # guidance_scale=cfg, + # #image_guidance_scale=image_cfg, + # #guidance_scale_mode=cfg_mode, + # generator=torch.manual_seed(seed), + # num_inference_steps=steps, + # height=height, + # width=width, + # max_sequence_length=512, + # device=self.device, + # #extra_vit_input=extra_vit_input, + # ref_hidden_states=ref_x, + # #use_dynamic_shifting=use_dynamic_shifting + + +# Define request model +class GenerateRequest(BaseModel): + # prompt: str + # num_inference_steps: Optional[int] = 50 + # seed: Optional[int] = 42 + # cfg: Optional[float] = 7.5 + # save_disk_path: Optional[str] = None + # height: Optional[int] = 1024 + # width: Optional[int] = 1024 + gen_param_b64: str + + # # Add input validation + # class Config: + # json_schema_extra = { + # "example": { + # "prompt": "a beautiful landscape", + # "num_inference_steps": 50, + # "seed": 42, + # "cfg": 7.5, + # "height": 1024, + # "width": 1024 + # } + # } + +app = FastAPI() + +@ray.remote(num_gpus=1) +class ImageGenerator: + def __init__(self, xfuser_args: xFuserArgs, rank: int, world_size: int, use_taylor_cache=False): + # Set PyTorch distributed environment variables + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" + + self.rank = rank + self.setup_logger() + self.initialize_model(xfuser_args, use_taylor_cache=use_taylor_cache) + + def setup_logger(self): + self.logger = logging.getLogger(__name__) + # Add console handler if not already present + if not self.logger.handlers: + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + console_handler.setFormatter(formatter) + self.logger.addHandler(console_handler) + self.logger.setLevel(logging.INFO) + + def initialize_model(self, xfuser_args : xFuserArgs, use_taylor_cache=False): + + # init distributed environment in create_config + self.engine_config, self.input_config = xfuser_args.create_config() + print(self.engine_config) + + local_rank = get_world_group().local_rank + + print("self.engine_config.model_config.model", self.engine_config.model_config.model) + model_name_or_path = self.engine_config.model_config.model + + import sys + sys.path.insert(0, model_name_or_path) + + from diffusion.transformer_z_image_xfuser import xFuserZImageTransformer2DWrapper + from diffusion.pipeline_z_image import ZImagePipeline + from diffusers import FlowMatchEulerDiscreteScheduler + + + #"/nativemm/share/cpfs/weilong.cwl/checkpoints/flash_v2_xpo_final_20260205_hf_metax_ais16893664" + #"/nativemm/share/cpfs/weilong.cwl/checkpoints/bailing_native_moe_ming_flash_v2.0_xpo_final_20260205_vllm_new" + + zimage_model_path = os.path.join(model_name_or_path, "pipeline") + #zimage_model_path = "/nativemm/share/cpfs/weilong.cwl/checkpoints/Z-Image-Turbo" + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_name_or_path, subfolder="scheduler") + noise_scheduler.config['use_dynamic_shifting'] = True + + + transformer = xFuserZImageTransformer2DWrapper.from_pretrained( + model_name_or_path, subfolder="transformer", + torch_dtype=torch.bfloat16, + ) + + vae = AutoencoderKL.from_pretrained( + model_name_or_path, + subfolder="vae", + torch_dtype=torch.bfloat16, + ) + + self.pipe = ZImagePipeline.from_pretrained( + zimage_model_path, + transformer=transformer, + text_encoder=None, + tokenizer=None, + scheduler=noise_scheduler, + vae=vae, + ) + if use_taylor_cache: + print("use_taylor_cache") + self.pipe.set_taylor_cache() + + # from modeling_bailingmm2 import BailingMM2NativeForConditionalGeneration + # import time + # model = BailingMM2NativeForConditionalGeneration.from_pretrained( + # self.engine_config.model_config.model, + # torch_dtype=torch.bfloat16, + # attn_implementation="flash_attention_2", + # load_image_gen=True, + # load_image_gen_others=False, + # load_vlm=False, + # device_map=local_rank, + # image_gen_seq_parallel=True, + # ).to(dtype=torch.bfloat16) + # model.eval() + # model.diffusion_loss.pipelines.transformer.config.num_attention_heads = model.diffusion_loss.pipelines.transformer.config.n_heads + # model.diffusion_loss.pipelines.transformer.config.patch_size = model.diffusion_loss.pipelines.transformer.config.all_patch_size + # model.diffusion_loss.pipelines.transformer.config.attention_head_dim = model.diffusion_loss.pipelines.transformer.config.axes_dims[-1] + + + # #print(pipeline.transformer.config) + # self.pipe = model.diffusion_loss.pipelines + + + #is_last_process = get_world_group().rank == get_world_group().world_size - 1 + + # transformer = xFuserZImageTransformer2DWrapper.from_pretrained( + # self.engine_config.model_config.model, + # torch_dtype=torch.bfloat16, + # subfolder="transformer", + # ) + # self.pipe = ZImagePipeline.from_pretrained( + # pretrained_model_name_or_path=self.engine_config.model_config.model, + # engine_config=self.engine_config, + # transformer=transformer, + # torch_dtype=torch.bfloat16, + # ) + + # self.pipe = self.pipe.to(f"cuda:{local_rank}") + #parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + + self.pipe.transformer.config.num_attention_heads = self.pipe.transformer.config.n_heads + self.pipe.transformer.config.patch_size = self.pipe.transformer.config.all_patch_size + self.pipe.transformer.config.attention_head_dim = self.pipe.transformer.config.axes_dims[-1] + + print(self.pipe._execution_device) + local_rank = get_world_group().local_rank + self.pipe = self.pipe.to(f"cuda:{local_rank}") + print(self.pipe._execution_device) + + initialize_runtime_state(self.pipe, self.engine_config) + + # model_name = self.engine_config.model_config.model.split("/")[-1] + # pipeline_map = { + # "PixArt-XL-2-1024-MS": xFuserPixArtAlphaPipeline, + # "PixArt-Sigma-XL-2-2K-MS": xFuserPixArtSigmaPipeline, + # "stable-diffusion-3-medium-diffusers": xFuserStableDiffusion3Pipeline, + # "stabilityai__stable-diffusion-3-medium-diffusers": xFuserStableDiffusion3Pipeline, + # "HunyuanDiT-v1.2-Diffusers": xFuserHunyuanDiTPipeline, + # "FLUX.1-schnell": xFuserFluxPipeline, + # "FLUX.1-dev": xFuserFluxPipeline, + # } + + # PipelineClass = pipeline_map.get(model_name) + # if PipelineClass is None: + # raise NotImplementedError(f"{model_name} is currently not supported!") + + # self.logger.info(f"Initializing model {model_name} from {xfuser_args.model}") + + # self.pipe = PipelineClass.from_pretrained( + # pretrained_model_name_or_path=xfuser_args.model, + # engine_config=self.engine_config, + # torch_dtype=torch.float16, + # ).to("cuda") + + # self.pipe.prepare_run(self.input_config) + self.logger.info("Model initialization completed") + + def generate(self, request: GenerateRequest): + # try: + # start_time = time.time() + # print("generate", len(request.gen_param_b64)) + # output = run_pipe(self.pipe, request.gen_param_b64) + # #, request.num_inference_steps, request.seed) + + # # output = self.pipe( + # # height=request.height, + # # width=request.width, + # # prompt=request.prompt, + # # num_inference_steps=request.num_inference_steps, + # # output_type="pil", + # # generator=torch.Generator(device="cuda").manual_seed(request.seed), + # # guidance_scale=request.cfg, + # # max_sequence_length=self.input_config.max_sequence_length + # # ) + # elapsed_time = time.time() - start_time + + # is_last_process = get_world_group().rank == get_world_group().world_size - 1 + + # #if self.pipe.is_dp_last_group(): + # if is_last_process: + # # if request.save_disk_path: + # # timestamp = time.strftime("%Y%m%d-%H%M%S") + # # filename = f"generated_image_{timestamp}.png" + # # file_path = os.path.join(request.save_disk_path, filename) + # # os.makedirs(request.save_disk_path, exist_ok=True) + # # output[0].save(file_path) + # # return { + # # "message": "Image generated successfully", + # # "elapsed_time": f"{elapsed_time:.2f} sec", + # # "output": file_path, + # # "save_to_disk": True + # # } + # # else: + # # Convert to base64 + # buffered = io.BytesIO() + # output[0].save(buffered, format="PNG") + # img_str = base64.b64encode(buffered.getvalue()).decode() + # return { + # "message": "Image generated successfully", + # "elapsed_time": f"{elapsed_time:.2f} sec", + # "output": img_str, + # "save_to_disk": False + # } + # return None + + # except Exception as e: + # self.logger.error(f"Error generating image: {str(e)}") + # raise HTTPException(status_code=500, detail=str(e)) + + start_time = time.time() + print("generate", len(request.gen_param_b64)) + output = run_pipe(self.pipe, request.gen_param_b64, logger=self.logger) + #, request.num_inference_steps, request.seed) + + # output = self.pipe( + # height=request.height, + # width=request.width, + # prompt=request.prompt, + # num_inference_steps=request.num_inference_steps, + # output_type="pil", + # generator=torch.Generator(device="cuda").manual_seed(request.seed), + # guidance_scale=request.cfg, + # max_sequence_length=self.input_config.max_sequence_length + # ) + elapsed_time = time.time() - start_time + + is_last_process = get_world_group().rank == get_world_group().world_size - 1 + + #if self.pipe.is_dp_last_group(): + if is_last_process: + # if request.save_disk_path: + # timestamp = time.strftime("%Y%m%d-%H%M%S") + # filename = f"generated_image_{timestamp}.png" + # file_path = os.path.join(request.save_disk_path, filename) + # os.makedirs(request.save_disk_path, exist_ok=True) + # output[0].save(file_path) + # return { + # "message": "Image generated successfully", + # "elapsed_time": f"{elapsed_time:.2f} sec", + # "output": file_path, + # "save_to_disk": True + # } + # else: + # Convert to base64 + buffered = io.BytesIO() + output[0].save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode() + return { + "message": "Image generated successfully", + "elapsed_time": f"{elapsed_time:.2f} sec", + "output": img_str, + "save_to_disk": False + } + return None + +class Engine: + def __init__(self, world_size: int, xfuser_args: xFuserArgs, use_taylor_cache=False): + # Ensure Ray is initialized + if not ray.is_initialized(): + ray.init() + + num_workers = world_size + self.workers = [ + ImageGenerator.remote(xfuser_args, rank=rank, world_size=world_size, use_taylor_cache=use_taylor_cache) + for rank in range(num_workers) + ] + + async def generate(self, request: GenerateRequest): + results = ray.get([ + worker.generate.remote(request) + for worker in self.workers + ]) + + return next(path for path in results if path is not None) + +@app.post("/generate") +async def generate_image(request: GenerateRequest): + try: + # Add input validation + # if not request.prompt: + # raise HTTPException(status_code=400, detail="Prompt cannot be empty") + # if request.height <= 0 or request.width <= 0: + # raise HTTPException(status_code=400, detail="Height and width must be positive") + # if request.num_inference_steps <= 0: + # raise HTTPException(status_code=400, detail="num_inference_steps must be positive") + print(len(request.gen_param_b64)) + + result = await engine.generate(request) + return result + except Exception as e: + if isinstance(e, HTTPException): + raise e + raise HTTPException(status_code=500, detail=str(e)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='xDiT HTTP Service') + parser.add_argument('--model_path', type=str, help='Path to the model', required=True) + parser.add_argument('--world_size', type=int, default=1, help='Number of parallel workers') + parser.add_argument('--pipefusion_parallel_degree', type=int, default=1, help='Degree of pipeline fusion parallelism') + parser.add_argument('--ulysses_parallel_degree', type=int, default=1, help='Degree of Ulysses parallelism') + parser.add_argument('--ring_degree', type=int, default=1, help='Degree of ring parallelism') + parser.add_argument('--save_disk_path', type=str, default='output', help='Path to save generated images') + parser.add_argument('--use_cfg_parallel', action='store_true', help='Whether to use CFG parallel') + parser.add_argument('--use_taylor_cache', action='store_true', help='Whether to use taylor cache') + args = parser.parse_args() + + xfuser_args = xFuserArgs( + model=args.model_path, + trust_remote_code=True, + warmup_steps=1, + use_parallel_vae=False, + use_torch_compile=False, + ulysses_degree=args.ulysses_parallel_degree, + pipefusion_parallel_degree=args.pipefusion_parallel_degree, + use_cfg_parallel=args.use_cfg_parallel, + dit_parallel_size=0, + ) + + engine = Engine( + world_size=args.world_size, + xfuser_args=xfuser_args, + use_taylor_cache=args.use_taylor_cache, + ) + + # Start the server + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=6000) + \ No newline at end of file diff --git a/image_processing_bailingmm2.py b/image_processing_bailingmm2.py index 20ebac4..d4f4bed 100644 --- a/image_processing_bailingmm2.py +++ b/image_processing_bailingmm2.py @@ -20,7 +20,7 @@ """Image processor class for Qwen2-VL.""" import math -from typing import Optional, Union +from typing import Dict, List, Optional, Union import numpy as np @@ -46,8 +46,12 @@ ) import torch +from torchvision.transforms import functional as F +from torchvision.transforms import InterpolationMode from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence +from PIL import Image +from transformers.image_utils import is_valid_image try: from transformers.image_utils import VideoInput except: @@ -58,6 +62,39 @@ logger = logging.get_logger(__name__) +def make_batched_videos_torch(videos, device="cpu") -> List[VideoInput]: + if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): + return [torch.stack([torch.as_tensor(t, device=device) for t in ts]) for ts in videos] + + elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): + if isinstance(videos[0], Image.Image): + return torch.as_tensor([videos], device=device) + elif len(videos[0].shape) == 4: + return [torch.as_tensor(video, device=device) for video in videos] + + elif is_valid_image(videos) and len(videos.shape) == 4: + return [torch.as_tensor(videos, device=device)] + + raise ValueError(f"Could not make batched video from {videos}") + + +def resize_torchvision(image, size, resample): + resample_method = { + PILImageResampling.NEAREST: InterpolationMode.NEAREST, + PILImageResampling.BILINEAR:InterpolationMode.BILINEAR, + PILImageResampling.BICUBIC: InterpolationMode.BICUBIC, + } + interpolation = resample_method.get(resample, InterpolationMode.BICUBIC) + return F.resize(img=image, size=size, interpolation=interpolation) + + +def rescale_torchvision(image, scale, dtype=torch.float32): + return (image * scale).to(dtype) + + +def normalize_torchvision(image, mean, std): + return F.normalize(image, mean=mean, std=std) + class Qwen2VLImageProcessorKwargs(ImagesKwargs, total=False): r""" @@ -200,6 +237,69 @@ def __init__( self.merge_size = merge_size self.do_convert_rgb = do_convert_rgb + def _preprocess_torch( + self, + images: Union[ImageInput, VideoInput], + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + patch_size: Optional[int] = None, + temporal_patch_size: Optional[int] = None, + merge_size: Optional[int] = None, + do_convert_rgb: Optional[bool] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + height, width = get_image_size(images[0], channel_dim=input_data_format) + resized_height, resized_width = height, width + + if input_data_format == ChannelDimension.LAST: + images = images.permute(0, 3, 1, 2) # to NCHW + if do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=patch_size * merge_size, + min_pixels=size["shortest_edge"], + max_pixels=size["longest_edge"], + ) + images = resize_torchvision(images, size=(resized_height, resized_width), resample=resample) + if do_rescale: + images = rescale_torchvision(images, scale=rescale_factor) + if do_normalize: + images = normalize_torchvision(images, mean=image_mean, std=image_std) + if images.shape[0] == 1: + images = torch.tile(images, (self.temporal_patch_size, 1, 1, 1)) + patches = images + channel = patches.shape[1] + grid_t = patches.shape[0] // self.temporal_patch_size + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + patches = patches.reshape( + grid_t, + temporal_patch_size, + channel, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, + channel * temporal_patch_size * patch_size * patch_size, + ) + return flatten_patches, (grid_t, grid_h, grid_w) + def _preprocess( self, images: Union[ImageInput, VideoInput], @@ -359,6 +459,7 @@ def preprocess( data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, videos_timestamps_seconds=None, + device="cpu", ): """ Args: @@ -501,11 +602,11 @@ def preprocess( # kept for BC only and should be removed after v5.0 if videos is not None: - videos = make_batched_videos(videos) + videos = make_batched_videos_torch(videos, device=device) pixel_values_videos, vision_grid_thws_videos = [], [] video_timestamps_seconds = [] for video_idx, images in enumerate(videos): - patches, video_grid_thw = self._preprocess( + patches, video_grid_thw = self._preprocess_torch( images, do_resize=do_resize, size=size, @@ -522,7 +623,7 @@ def preprocess( do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, ) - pixel_values_videos.extend(patches) + pixel_values_videos.append(patches) vision_grid_thws_videos.append(video_grid_thw) if videos_timestamps_seconds is not None: @@ -543,10 +644,12 @@ def preprocess( ] assert len(aligned_timestamps_seconds) == video_grid_thw[0] video_timestamps_seconds.append(aligned_timestamps_seconds) - + pixel_values_videos = torch.cat(pixel_values_videos, dim=0) + if device == "cpu": + pixel_values_videos = pixel_values_videos.cpu().numpy() data.update( { - "pixel_values_videos": np.array(pixel_values_videos), + "pixel_values_videos": pixel_values_videos, "video_grid_thw": np.array(vision_grid_thws_videos), } ) diff --git a/ming_sdk/__init__.py b/ming_sdk/__init__.py new file mode 100644 index 0000000..259695d --- /dev/null +++ b/ming_sdk/__init__.py @@ -0,0 +1,5 @@ +from .ming import Ming +from .ming_img import MingImg +from .ming_moe import MingMOE +from .ming_talker import MingTalker +from ming_sdk.ming_moe_async import MingMOEAsync diff --git a/ming_sdk/ming.py b/ming_sdk/ming.py new file mode 100644 index 0000000..1f0133a --- /dev/null +++ b/ming_sdk/ming.py @@ -0,0 +1,713 @@ +import os +import sys +import time +import queue +import shutil +import logging +import time +from PIL import Image +import threading +import multiprocessing +import queue +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union, Generator + +from ming_sdk.usage import Usage +from ming_sdk.ming_moe import MingMOE +from ming_sdk.ming_utils import MingUtils +from ming_sdk.ming_talker import MingTalker +from ming_sdk.ming_moe_async import MingMOEAsync +from ming_sdk.monitoring.request_metrics import ( + metrics_text, + metrics_image, + metrics_speech, + metrics_tts, + metrics_speech_text_audio +) +from ming_sdk.ming_img import MingImg, ratio_extraction_fromat, rewrite_fromat, rewrite_edit_fromat, image_gen_indent_format, auto_balance_saturation_exposure, DEFAUL_PROMPT_FOR_NO_INTENT +from ming_sdk.ming_utils import ThreadSafeCache + +logger = logging.getLogger() +warnings.filterwarnings("ignore") +current_file_path = os.path.abspath(__file__) +current_dir_path = os.path.dirname(current_file_path) +sys.path.insert(0, current_dir_path) + + +class Ming(object): + """ + Initialize the class with model components and configuration. + + Args: + model_path (str): + Path to the root model directory. Must contain: + - `config.json`: Model architecture and tokenizer config + - `am.mvn`: Audio normalization stats for TTS frontend + - Subdirectories: `talker/`, `talker/vae`, and optionally diffusion checkpoints + + sys_prompt (str, optional): + System-level instruction prepended to all conversations (e.g., "You are a helpful AI"). + If not provided, no system prompt is used. Defaults to "". + + device (str, optional): + GPU device IDs for tensor parallelism in the LLM (vLLM backend). + Format: comma-separated integers, e.g., "0", "0,1", "0,1,2,3". + Determines `tensor_parallel_size`. Defaults to "0" (single GPU). + + gpu_memory_utilization (dict, optional): + Fraction of GPU memory to allocate for specific modules. + Helps prevent OOM errors on resource-limited devices. + Supported keys: + - "moe": for LLM (default: 0.6) + - "talker": for TTS model (default: 0.1) + Example: {"moe": 0.8, "talker": 0.2} + + device_map (dict, optional): + Assign different modules to different CUDA devices for heterogeneous deployment. + Supported keys: + - "talker": TTS model (default: "cuda:0") + - "image": Image generation model (default: "cuda:0") + Example: {"talker": "cuda:1", "image": "cuda:0"} enables cross-GPU deployment. + with_async (bool, optional): + Enabling async may change the public API surface to coroutine-based methods; callers should run within an asyncio event loop. + """ + + def __init__( + self, + model_path: str, + sys_prompt: str = "", + device: str = "0", + gpu_memory_utilization: dict = {"moe": 0.6, "talker": 0.1}, + limit_mm_per_prompt: dict = {"image": 10, "video": 2}, + device_map: dict = {"talker": ["cuda:0"], "image": "cuda:0"}, + with_async: bool = False, + speaker: str = "DB30", + quantization: str | None = None, + use_talker: bool = True, + use_image_gen: bool = False + ): + logger.info( + f"gpu_memory_utilization={gpu_memory_utilization},model_path={model_path},limit_mm_per_prompt={limit_mm_per_prompt},device={device},device_map={device_map}" + ) + tensor_parallel_size = ( + len(device.split(",")) if len(device.split(",")) > 0 else 1 + ) + shutil.copy(model_path + "/config.json", current_dir_path) + am_path = os.path.join(model_path, "am.mvn") + shutil.copy(am_path, ".") + + # 1. Initialize talker (TTS module) + self.utils = MingUtils(model_path=current_dir_path, sys_prompt=sys_prompt) + if use_talker: + self.talker = MingTalker( + model_path=model_path, + device_list=device_map["talker"], + ) + + # 2. Initialize MOE (Mixture of Experts LLM) + os.environ["VLLM_USE_V1"] = "0" + if with_async: + self.moe = MingMOEAsync( + model_path, + tensor_parallel_size=tensor_parallel_size, + gpu_memory_utilization=gpu_memory_utilization["moe"], + limit_mm_per_prompt=limit_mm_per_prompt, + quantization=quantization, + ) + else: + self.moe = MingMOE( + model_path, + tensor_parallel_size=tensor_parallel_size, + gpu_memory_utilization=gpu_memory_utilization["moe"], + limit_mm_per_prompt=limit_mm_per_prompt, + quantization=quantization, + ) + + # 3. Initialize image generation module + if use_image_gen: + self.img = MingImg(model_path, device=device_map["image"]) + self.device_map = device_map + self.speaker = speaker + self.queue_manager = multiprocessing.Manager() + self.info_cache_default_time = 1200 + self.info_cache = ThreadSafeCache(max_size=500, default_ttl=self.info_cache_default_time) + + def _generate_text( + self, + prompt: str, + audio: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + video: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + image: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + history: list = [], + **kwargs, + ) -> Any: + """ + Generate text output based on the input prompt. + + Args: + prompt (str): User input text. + audio (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Audio data (e.g., file path or binary or list). + video (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Video data (e.g., file path or binary or list). + image (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Image data (file path, binary, or PIL Image or list). + history (list, optional): Conversation history. Defaults to empty list. + **kwargs: Additional parameters for the model. + Returns: + str: Generated text output. + """ + msg_request_id = kwargs.get("msg_request_id", None) + state = metrics_text.create_state(stream_mode=False) + text_token_count = self.utils.compute_text_input_tokens(prompt, **kwargs) + inputs, image_token_count, video_token_count, audio_token_count = self.utils.build_prompt( + prompt=prompt, + audio=audio, + video=video, + image=image, + history=history, + **kwargs, + ) + + gen_text, usage = self.moe.generate(inputs, **kwargs) + usage = Usage.update_usage_by_processor(usage, text_token_count, image_token_count, video_token_count, audio_token_count) + state.input_token_length = usage["prompt_tokens"] + state.output_token_length = usage["generated_tokens"] + state.finish("success", msg_request_id=msg_request_id) + + return gen_text, usage + + def _rewrite_for_image_gen( + self, + prompt: str, + ): + + logger.info(f"In _rewrite_for_image_gen, user input {prompt}") + rewrite_input = rewrite_fromat.format(prompt) + prompt = self._generate_text(rewrite_input)[0] + return prompt + + def _is_image_gen_intent( + self, + prompt: str, + ): + + logger.info(f"_is_image_gen_intent, user input {prompt}") + indent_input = image_gen_indent_format.format(prompt) + prompt = self._generate_text(indent_input)[0] + return prompt.lower().strip() != "no" + + def _rewrite_for_image_edit( + self, + prompt: str, + ): + + logger.info(f"_rewrite_for_image_edit, user input {prompt}") + rewrite_input = rewrite_edit_fromat.format(prompt) + prompt = self._generate_text(rewrite_input)[0] + return prompt + + def _extract_aspect_ratio_for_image_gen( + self, + prompt: str, + ): + + t1 = time.time() + ratio_extraction_input = ratio_extraction_fromat.format(prompt) + ratio_text = self._generate_text(ratio_extraction_input)[0] + logger.info(f"Extract_aspect_ratio_for_image_gen, cost {time.time() - t1}s") + ratio = None + if ":" in ratio_text: + split_char = ":" + if ratio_text.count(split_char) == 1: + try: + ratio = float(ratio_text.split(split_char)[0]) / float( + ratio_text.split(split_char)[1] + ) + except: + pass + elif ":" in ratio_text: + split_char = ":" + if ratio_text.count(split_char) == 1: + try: + ratio = float(ratio_text.split(split_char)[0]) / float( + ratio_text.split(split_char)[1] + ) + except: + pass + elif "x" in ratio_text: + split_char = "x" + if ratio_text.count(split_char) == 1: + try: + ratio = float(ratio_text.split(split_char)[0]) / float( + ratio_text.split(split_char)[1] + ) + except: + pass + + return ratio + + def _generate_image( + self, + prompt: str, + image: Optional[Union[str, bytes, Image.Image]] = None, + image_gen_highres: int = 672, + **kwargs, + ) -> Image.Image: + """ + Generate an image or edit an existing one based on the input prompt. + + Args: + prompt (str): User input text. + image (Optional[Union[str, bytes, Image.Image]]): Input image (for editing). Defaults to None. + **kwargs: Additional parameters for image generation. + + Returns: + Image.Image: Generated or edited image. + """ + msg_request_id = kwargs.get("msg_request_id", None) + is_t2i = (image is None) or (isinstance(image, list) and len(image) == 0) + + if kwargs is not None and "history" in kwargs: + del kwargs["history"] + + request_id = self.moe.create_request_id() + state = metrics_image.create_state(stream_mode=False, request_id=request_id) + user_aspect_ratio = None + user_prompt = prompt + text_token_count = self.utils.compute_text_input_tokens(user_prompt, **kwargs) + + # Prompt rewriting for better image generation quality + logger.info(f"In _generate_image, before input rewriten {prompt}") + + if is_t2i: + t1 = time.time() + if self._is_image_gen_intent(user_prompt): + prompt = self._rewrite_for_image_gen(user_prompt) + else: + prompt = DEFAUL_PROMPT_FOR_NO_INTENT + logger.info(f"Rewrite_for_image_gen, cost {time.time() - t1}s") + else: + t1 = time.time() + + if kwargs.get('is_segmentation', False): + user_prompt = f"Given the following instructions: {user_prompt}; please perform referring segmentation on this image." + prompt = self._rewrite_for_image_edit(user_prompt) + logger.info(f"Rewrite_for_image_edit, cost {time.time() - t1}s") + + logger.info("In _generate_image, after input rewriten") + # Extract user-specified aspect ratio from prompt + user_aspect_ratio = self._extract_aspect_ratio_for_image_gen(user_prompt) + + inputs, image_token_count, video_token_count, audio_token_count = self.utils.build_img_prompt(prompt=prompt, image=image, **kwargs) + kwargs["compute_input_tokens_flag"] = False + image_gen_llm_hidden_states = None + + os.environ["IMAGE_GEN_MODE"] = "T2I" if image is None else "EDIT" + # Generate hidden states from LLM for diffusion model + image_gen_llm_hidden_states, usage = self.moe.generate( + requests=inputs, + with_hidden_status=True, + max_new_tokens=1, + return_hidden_states=True, + ) + usage = Usage.update_usage_by_processor(usage, text_token_count, image_token_count, video_token_count, audio_token_count) + image_gen_negative_llm_hidden_states = None + inputs_img_gen_negative = None + if is_t2i: + negative_prompt = "mutations, deformities, tilted heads, bad fingers, bad eyes, extra limbs, excess arms, deformed limbs, deformed legs, ugly, watermarks, text, NSFW" + inputs, _, _, _ = self.utils.build_img_prompt( + prompt=negative_prompt, image=image, **kwargs + ) + os.environ["IMAGE_GEN_MODE"] = "T2I" if image is None else "EDIT" + # Generate negative prompt hidden states for classifier-free guidance + image_gen_negative_llm_hidden_states, usage_neg = self.moe.generate( + requests=inputs, + with_hidden_status=True, + max_new_tokens=1, + return_hidden_states=True, + ) + image_gen_negative_llm_hidden_states = ( + image_gen_negative_llm_hidden_states.unsqueeze(0) + ) + inputs_img_gen_negative = self.utils.build_img_gen_prompt( + prompt=negative_prompt, image=image, image_gen_highres=image_gen_highres + ) + + os.environ["IMAGE_GEN_MODE"] = "" + # device = self.device_map["image"] + inputs = self.utils.build_img_gen_prompt( + prompt=prompt, + image=image, + image_gen_highres=image_gen_highres, + image_gen_aspect_ratio=user_aspect_ratio, + ) + if inputs_img_gen_negative is not None: + inputs["image_gen_negative_input_ids"] = inputs_img_gen_negative[ + "input_ids" + ] + inputs["image_gen_negative_attention_mask"] = inputs_img_gen_negative[ + "attention_mask" + ] + + logger.info("In generate_image, begin diffusion") + inputs['input_ids'] = inputs['input_ids'].to(self.device_map['image']) + inputs['attention_mask'] = inputs['attention_mask'].to(self.device_map['image']) + inputs['image_gen_height'] = inputs['image_gen_height'].to(self.device_map['image']) + inputs['image_gen_width'] = inputs['image_gen_width'].to(self.device_map['image']) + if "image_gen_negative_input_ids" in inputs: + inputs['image_gen_negative_input_ids'] = inputs['image_gen_negative_input_ids'].to(self.device_map['image']) + if "image_gen_negative_attention_mask" in inputs: + inputs['image_gen_negative_attention_mask'] = inputs['image_gen_negative_attention_mask'].to(self.device_map['image']) + if type(image_gen_llm_hidden_states) is tuple: + image_gen_llm_hidden_states = image_gen_llm_hidden_states[0] + image_gen_llm_hidden_states = image_gen_llm_hidden_states.to(self.device_map['image']) + if image_gen_negative_llm_hidden_states is not None: + image_gen_negative_llm_hidden_states = image_gen_negative_llm_hidden_states.to(self.device_map['image']) + image = self.img.model_diffusion.generate( + **inputs, + image_gen_llm_hidden_states=image_gen_llm_hidden_states.unsqueeze(0), + image_gen_negative_llm_hidden_states=image_gen_negative_llm_hidden_states, + image_gen=True, + image_gen_cfg=5.5 if image is None else 5.0, + ) + if is_t2i: + image = auto_balance_saturation_exposure(image) + + # Update usage statistics + usage = Usage.update_image_usage_by_length( + usage=usage, image_gen_highres=image_gen_highres + ) + + # Record metrics + state.input_token_length = usage["prompt_tokens"] + state.output_token_length = usage["generated_tokens"] + state.finish("success", msg_request_id=msg_request_id) + + return image, usage + + def _generate_audio( + self, + prompt: str, + audio: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + video: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + image: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + history: list = [], + **kwargs, + ) -> Union[bytes, Generator[bytes, None, None]]: + """ + Generate audio (text-to-speech or speech-to-speech). + + Args: + prompt (str): User input text. + audio (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Audio data (e.g., file path or binary or list). + video (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Video data (e.g., file path or binary or list). + image (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Image data (file path, binary, or PIL Image or list). + history (list, optional): Conversation history. Defaults to empty list. + **kwargs: Additional parameters for the model. + + Returns: + Union[bytes, Generator[bytes, None, None]]: Generated audio data or a stream generator. + """ + msg_request_id = kwargs.get("msg_request_id", None) + request_id = self.moe.create_request_id() + state = metrics_speech.create_state(stream_mode=False, request_id=request_id) + text_token_count = self.utils.compute_text_input_tokens(prompt, **kwargs) + # Build prompt and generate text response + inputs, image_token_count, video_token_count, audio_token_count = self.utils.build_prompt( + prompt=prompt, + audio=audio, + video=video, + image=image, + history=history, + **kwargs, + ) + gen_text, usage = self.moe.generate(inputs, **kwargs) + usage = Usage.update_usage_by_processor(usage, text_token_count, image_token_count, video_token_count, audio_token_count) + audio, duration = self.talker.generate(text=gen_text, speaker=self.speaker, request_id=msg_request_id) + + # Update audio usage statistics + usage = Usage.update_audio_usage_by_duration(usage=usage, duration=duration) + + # Record metrics + state.input_token_length = usage["prompt_tokens"] + state.output_token_length = usage["generated_tokens"] + state.finish("success", msg_request_id=msg_request_id) + + return audio, gen_text, usage + + def _generate_tts(self, text: str, **kwargs) -> Union[bytes, Generator[bytes, None, None]]: + # Generate TTS audio from text + msg_request_id = kwargs.get("msg_request_id", None) + request_id = self.moe.create_request_id() + prompt_tokens = len(text) + usage = Usage.create_usage_default(prompt_tokens=prompt_tokens) + state = metrics_tts.create_state(stream_mode=False, request_id=request_id) + audio, duration = self.talker.generate(text=text, speaker=self.speaker, request_id=msg_request_id) + if duration == 0: + duration = audio.shape[-1]/16000 + usage = Usage.update_audio_usage_by_duration(usage=usage, duration=duration) + state.input_token_length = usage["prompt_tokens"] + state.output_token_length = usage["generated_tokens"] + state.finish("success", msg_request_id=msg_request_id) + usage['finish_reason'] = 'stop' + return audio, usage + + def generate( + self, + text: Optional[str] = None, + audio: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + video: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + image: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + history: list = [], + output_type: str = "text", + **kwargs, + ) -> Union[str, Image.Image, bytes, Generator]: + """ + Generate content based on the specified output type. + + Args: + text (Optional[str]): User input text. + audio (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Audio data (e.g., file path or binary or list). + video (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Video data (e.g., file path or binary or list). + image (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Image data (file path, binary, or PIL Image or list). + history (list, optional): Conversation history. Defaults to empty list. + output_type (str, optional): Output type ("text", "speech", "image", "tts"). Defaults to "text". + **kwargs: Additional parameters for the model. + + Returns: + Union[str, Image.Image, bytes, Generator]: Generated content (text, image, or audio). + + Raises: + ValueError: If `output_type` is not supported. + """ + if output_type == "text": + return self._generate_text( + prompt=text, + audio=audio, + video=video, + image=image, + history=history, + **kwargs, + ) + + elif output_type == "speech": + return self._generate_audio( + prompt=text, + audio=audio, + video=video, + image=image, + history=history, + **kwargs, + ) + + elif output_type == "image": + return self._generate_image( + prompt=text, + audio=audio, + video=video, + image=image, + history=history, + **kwargs, + ) + + elif output_type == "tts": + return self._generate_tts(text=text, **kwargs) + + else: + raise Exception("not support output_type") + + def generate_stream( + self, + text: Optional[str] = None, + audio: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + video: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + image: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + history: list = [], + output_type: str = "text", + **kwargs, + ) -> Generator[Tuple[Union[bytes, str], str], None, None]: + """ + Stream generated content (text or speech). + + Args: + text (Optional[str]): User input text. + audio (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Audio data (e.g., file path or binary or list). + video (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Video data (e.g., file path or binary or list). + image (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Image data (file path, binary, or PIL Image or list). + history (list, optional): Conversation history. Defaults to empty list. + output_type (str, optional): Output type ("text", "speech", "TTS"). Defaults to "text". + **kwargs: Additional parameters for the model. + + Yields: + Tuple[Union[bytes, str], str]: Generated content (text or audio) and request ID. + + Raises: + ValueError: If `output_type` is not supported for streaming. + """ + msg_request_id = kwargs.get("msg_request_id", None) + text_token_count = self.utils.compute_text_input_tokens(text, **kwargs) + if output_type == "text": + request_id = self.moe.create_request_id() + self.info_cache.set(f"{msg_request_id}", request_id, self.info_cache_default_time) + state = metrics_text.create_state(stream_mode=True, request_id=request_id) + inputs, image_token_count, video_token_count, audio_token_count = self.utils.build_prompt( + prompt=text, + audio=audio, + video=video, + image=image, + history=history, + **kwargs, + ) + for text in self.moe.generate_stream( + requests=inputs, request_id=request_id, **kwargs + ): + state.record_first_token() + usage = self.moe.usage.get_stream_usage_by_request_id( + request_id=request_id + ) + usage = Usage.update_usage_by_processor(usage, text_token_count, image_token_count, video_token_count, audio_token_count) + state.record_input_tokens(usage["prompt_tokens"]) + state.output_token_length = usage["generated_tokens"] + yield text, request_id, usage + state.finish("success", msg_request_id=msg_request_id) + + elif output_type == "speech": + request_id = self.moe.create_request_id() + self.info_cache.set(f"{msg_request_id}", request_id, self.info_cache_default_time) + state_text = metrics_speech.create_state(stream_mode=True, request_id=request_id) + state_text_audio = metrics_speech_text_audio.create_state(stream_mode=True, request_id=request_id) + + inputs, image_token_count, video_token_count, audio_token_count = self.utils.build_prompt( + prompt=text, + audio=audio, + video=video, + image=image, + history=history, + **kwargs, + ) + text_generator = self.moe.generate_stream( + requests=inputs, request_id=request_id, **kwargs + ) + + def _produce_text_to_queue(text_generator, talker_input_queue, result_queue): + def producer(): + try: + for chunk in text_generator: + state_text.record_first_token() + + talker_input_queue.put(chunk) + usage = self.moe.usage.get_stream_usage_by_request_id( + request_id=request_id + ) + usage = Usage.update_usage_by_processor(usage, text_token_count, image_token_count, video_token_count, audio_token_count) + state_text.record_input_tokens(usage["prompt_tokens"]) + state_text.output_token_length = usage["generated_tokens"] + result_queue.put(("text_data", (chunk, usage))) + talker_input_queue.put(None, None) # End signal + state_text.finish("success", msg_request_id=msg_request_id) + except Exception as e: + logger.error(f"Error in text producer: {e}") + talker_input_queue.put(None, None) + + producer_thread = threading.Thread(target=producer) + producer_thread.daemon = True + producer_thread.start() + return producer_thread + + def warpper(text_queue): + while True: + text = text_queue.get() + if text is None: + break + yield text + + def thread_talker_task(talker, moe, text_input_queue, speaker, result_queue): + def _thread_talker_task(talker, moe, text_input_queue, speaker, result_queue): + try: + duration = 0 + for tts_speech, sentence, meta_info in talker.generate_stream( + text=warpper(text_input_queue), speaker=speaker, request_id=msg_request_id + ): + state_text_audio.record_first_token() + # update audio usage + usage = moe.usage.get_stream_usage_by_request_id( + request_id=request_id + ) + usage = Usage.update_usage_by_processor(usage, text_token_count, image_token_count, video_token_count, audio_token_count) + duration += meta_info["duration"] + + usage = Usage.update_audio_usage_by_duration(usage, duration) + + state_text_audio.record_input_tokens(usage["prompt_tokens"]) + state_text_audio.output_token_length = usage["generated_tokens"] + result_queue.put(("text_audio_data", (tts_speech, sentence, meta_info, request_id, usage))) + finally: + result_queue.put((None, None)) + state_text_audio.finish("success", msg_request_id=msg_request_id) + + talker_thread = threading.Thread(target=_thread_talker_task, args=(talker, moe, text_input_queue, speaker, result_queue)) + talker_thread.daemon = True + talker_thread.start() + return talker_thread + + talker_input_queue = self.queue_manager.Queue() + result_queue = queue.Queue() + + producer_thread = _produce_text_to_queue(text_generator, talker_input_queue, result_queue) + talker_thread = thread_talker_task(self.talker, self.moe, talker_input_queue, self.speaker, result_queue) + + while True: + data_type, data_content = result_queue.get() + if data_type is None: + break + yield data_type, data_content + + # Ensure both threads have completed processing + producer_thread.join() + talker_thread.join() + + elif output_type == "tts": + request_id = self.moe.create_request_id() + self.info_cache.set(f"{msg_request_id}", request_id, self.info_cache_default_time) + state = metrics_tts.create_state(stream_mode=True, request_id=request_id) + prompt_tokens = len(text) + + duration = 0 + for tts_speech, sentence, meta_info in self.talker.generate_stream( + text=text, speaker=self.speaker, request_id=msg_request_id, + ): + usage = Usage.create_usage_default(prompt_tokens=prompt_tokens) + state.record_first_token() + if meta_info["duration"] == 0: + duration += tts_speech.shape[-1]/16000 + else: + duration += meta_info["duration"] + + usage = Usage.update_audio_usage_by_duration( + usage=usage, duration=duration + ) + usage = Usage.update_usage_by_processor(usage, text_token_count=text_token_count) + state.record_input_tokens(usage["prompt_tokens"]) + state.output_token_length = usage["generated_tokens"] + yield tts_speech, sentence, meta_info, request_id, usage + state.finish("success", msg_request_id=msg_request_id) + else: + raise Exception("not support output_type") + + def generate_interrupt(self, msg_request_id: str) -> None: + """ + Interrupt a specific request. + + Args: + request_id (str): ID of the request to interrupt. + + Raises: + ValueError: If `request_id` is empty. + """ + + vllm_infer_request_id = self.info_cache.get(f"{msg_request_id}") + if vllm_infer_request_id: + self.moe.generate_interrupt(vllm_infer_request_id) + logger.info(f"Generate_interrupt success, msg_request_id: {msg_request_id}, vllm infer request_id: {vllm_infer_request_id}") + else: + logger.info(f"Generate_interrupt failed, msg_request_id: {msg_request_id} is invalid") + + self.talker.generate_interrupt(msg_request_id) diff --git a/ming_sdk/ming_img.py b/ming_sdk/ming_img.py new file mode 100644 index 0000000..96e7376 --- /dev/null +++ b/ming_sdk/ming_img.py @@ -0,0 +1,492 @@ +""" +MingImg: Image Generation Module + +This module provides image generation and editing capabilities using +the BailingMM2 diffusion model. It supports text-to-image generation +and image editing with various aspect ratios and resolutions. + +Key Features: + - Text-to-image generation with prompt optimization + - Image editing with segmentation mask support + - Automatic saturation and exposure balancing + - Multiple aspect ratio support + +Usage: + >>> img_gen = MingImg(model_path="/path/to/model", device="cuda:0") +""" + +import os +import cv2 +import torch +import numpy as np +import torch.nn as nn +from PIL import Image +from PIL import ImageEnhance +import torchvision.transforms as transforms + + +def auto_balance_saturation_exposure(pil_img: Image.Image) -> Image.Image: + """ + Automatically detect and adjust saturation and exposure to prevent + over-saturation and over-exposure. + + This function uses fixed internal thresholds and applies adjustments + only when values exceed safe limits. + + Args: + pil_img (Image.Image): Input PIL image to process. + + Returns: + Image.Image: Processed image with balanced saturation and exposure. + """ + # Convert to HSV color space for saturation/value analysis + hsv_img = pil_img.convert('HSV') + hsv_array = np.array(hsv_img) + + # Extract saturation and brightness channels + saturation_channel = hsv_array[:, :, 1] / 255.0 # Range: 0~1 + value_channel = hsv_array[:, :, 2] / 255.0 # Range: 0~1 + + # Calculate mean saturation and brightness + mean_sat = saturation_channel.mean() + mean_val = value_channel.mean() + + # Fixed internal thresholds + max_sat_threshold = 0.5 # Over-saturation threshold + max_val_threshold = 0.75 # Over-exposure threshold + + adjusted_img = pil_img + + print("sat: ", mean_sat) + print("val: ", mean_val) + + if mean_sat > max_sat_threshold: + # Reduce saturation to threshold ratio + ratio = max_sat_threshold / mean_sat + adjusted_img = ImageEnhance.Color(adjusted_img).enhance(ratio) + + if mean_val > max_val_threshold: + # Reduce brightness to threshold ratio + ratio = max_val_threshold / mean_val + adjusted_img = ImageEnhance.Brightness(adjusted_img).enhance(ratio) + + return adjusted_img + + +def gaussian_kernel(size=5, sigma=1.0): + x, y = torch.meshgrid( + torch.linspace(-size / 2, size / 2, size), + torch.linspace(-size / 2, size / 2, size), + ) + d = x**2 + y**2 + g = torch.exp(-d / (2.0 * sigma**2)) + g = g / torch.sum(g) + return g + + +def get_filter_conv(kernel_size: int = 3, sigma: float = 1.0) -> nn.Conv2d: + """ + Create a 2D Gaussian filter convolution layer. + + Args: + kernel_size (int): Size of the Gaussian kernel. Defaults to 3. + sigma (float): Standard deviation of the Gaussian. Defaults to 1.0. + + Returns: + nn.Conv2d: Convolution layer with fixed Gaussian weights. + """ + gaussian_kernel_2d = gaussian_kernel(kernel_size, sigma) + gaussian_kernel_3d = gaussian_kernel_2d.expand(3, kernel_size, kernel_size) + conv_layer = nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=kernel_size, groups=3, bias=False + ) + + conv_layer.weight.data = gaussian_kernel_3d.unsqueeze(1) + conv_layer.weight.requires_grad = False + + return conv_layer + + +def remove_noise_opencv(binary_image: np.ndarray, radius: int = 5) -> np.ndarray: + """ + Remove noise from a binary image using median blur. + + Args: + binary_image (np.ndarray): Input binary image (0-1 or 0-255 range). + radius (int): Kernel size for median blur. Defaults to 5. + + Returns: + np.ndarray: Denoised binary image. + """ + if binary_image.max() <= 1: + binary_image = (binary_image * 255).astype(np.uint8) + else: + binary_image = binary_image.astype(np.uint8) + + denoised = cv2.medianBlur(binary_image, radius) + + return denoised + + +def get_mask(ref_img: Image.Image, pre_img: Image.Image, + kernel_size: int = 3, sigma: float = 1.0, + radius: int = 5, device: str = "cpu") -> np.ndarray: + """ + Generate a segmentation mask by comparing two images. + + This function computes the difference between a reference image and a + previous frame to identify changed regions, useful for video editing + and object segmentation. + + Args: + ref_img (Image.Image): Reference image (current frame). + pre_img (Image.Image): Previous image for comparison. + kernel_size (int): Gaussian kernel size for smoothing. Defaults to 3. + sigma (float): Gaussian sigma for smoothing. Defaults to 1.0. + radius (int): Denoising radius for morphological operations. Defaults to 5. + device (str): Device for tensor operations. Defaults to "cpu". + + Returns: + np.ndarray: Binary mask indicating changed regions. + """ + transform_img = transforms.Compose( + [ + transforms.ToTensor(), + ] + ) + conv_layer = get_filter_conv(kernel_size, sigma) + + image1_tensor = transform_img(ref_img).unsqueeze(0) + image2_tensor = transform_img(pre_img).unsqueeze(0) + + image1_tensor = image1_tensor.to(device) + image2_tensor = image2_tensor.to(device) + conv_layer = conv_layer.to(device) + + filtered_image1 = conv_layer(image1_tensor) + filtered_image2 = conv_layer(image2_tensor) + + filtered_image2 = (filtered_image2 - 0.5 * filtered_image1) * 2 + + diff = filtered_image1 - filtered_image2 + abs_diff = torch.abs(diff) + abs_diff = conv_layer(abs_diff) + # thresh = float(os.getenv('SEG_THRESH', 0.15)) # 0.1 + # radius = int(os.getenv('SEG_DENOISE', 5)) + thresh = 0.15 + radius = 11 + abs_diff = (abs_diff.mean(dim=1)[0] > thresh).float() + res = remove_noise_opencv(abs_diff.numpy(), radius) + + return res + + +def get_cutout(ref_img: Image.Image, pre_mask: np.ndarray) -> Image.Image: + """ + Create a cutout image with transparent background using a mask. + + Args: + ref_img (Image.Image): Source image to cut out from. + pre_mask (np.ndarray): Binary mask for the cutout region. + + Returns: + Image.Image: RGBA image with transparent background outside mask. + """ + rgba = ref_img.convert("RGBA") + alpha = Image.fromarray(pre_mask.astype(np.uint8), mode="L").resize(ref_img.size) + rgba.putalpha(alpha) + + return rgba + + +ratio_extraction_fromat = """你是一名专业的AI图像提示词处理专家,请依据以下规则处理输入: + +1.提取用户提示词中关于输出图像长宽比例的信息,如捕捉16:9, 1比1, 等信息 +2.输出必须为一行纯文本, 即为长宽比信息(格式为 number:number),不要包含任何前缀后缀; +3.如果没有发现长宽比例的信息,输出 "None" + +输出示例: +输入: 中国小女孩自拍,衣服上写着美丽二字, 按照16:10比例进行生成 +输出: 16:10 + +输入: In a 1:1 ratio, generate a man. +输出: 1:1 + +输入: generate a man, in a 1920x1080 +输出: 1920:1080 + +输入: draw a beautiful girl +输出: None + + +现在处理以下输入:{}""" + + +rewrite_fromat = """你是一位专业的文生图提示词工程师,擅长将用户简短的图像描述转化为细节丰富、视觉准确、适合主流图像模型生成的自然语言提示。你的任务是在忠实还原用户核心意图的前提下,自动补全必要视觉信息,并输出一行可用于直接生成高质量图像的描述。 + +请严格遵循以下原则: + +🔹 输入语言决定输出语言 +- 若用户输入为**中文**,则输出为**中文单句描述** +- 若用户输入为**英文**,则输出为**英文单句描述** +- 不要翻译、不要混用、不要添加解释性语句 + +🔹 核心意图优先 +始终将**图像的主要用途或载体类型**置于描述最前方,确保模型优先理解画面功能属性。例如: +- “PPT background with a woman” 先写 "PPT background design" +- “a running man” 转换为静态后写为 "a man standing on a running track" +- “手机锁屏” 写为 “手机屏幕界面” + +避免将人物外貌、环境细节等次要信息前置导致主题偏移。 + +🔹 绝对静态化处理 +禁止使用任何暗示运动或趋势的姿态词汇(如“正在”“摆动”“抬起”)。将动作类描述转换为静止状态: +- “打篮球” -> “抱着篮球站在篮球场” +- “踢足球” -> “脚踩足球站在足球场” +- “跳舞” -> “站在舞台上” + +仅保留主体与对象的共现关系。 + +🔹 零文本增生原则(全局禁止) +除非用户**明确使用“写着”“印着”“显示”“刻着”“挂着”“贴着”“涂鸦”“屏幕上显示”等动词 + 具体内容**,否则不得在图像中添加任何形式的文字元素,包括: +- 品牌标识(如 `"LV"`、`"Nike"`) +- 店名招牌(如 `"渔火"`) +- 屏幕显示内容 +- 服装印花文字 + +即使该品牌通常带有标志性文字(如 LV 包),也不予渲染。 + +🔹 文字标注规范(仅当显式提及) +若用户明确描述某物表面有可读文字(如 “the sign says 'Open'” 或 “T恤上印着Hello”),则提取该文字并用英文双引号 " " 包裹。其他情况一律不添加。 + +🔹 智能风格推断(自动适配最优风格) +若未指定艺术风格,则根据主体类型自动匹配最合理风格: +- 真实人物、静物、自然景观 -> 写实摄影风格 / photorealistic +- 拟人化动物、卡通角色 -> 卡通或插画风格 / cartoon, illustration +- 虚构生物、未来世界 -> 数字艺术 / concept art, digital painting +- 抽象概念 -> 超现实主义 / surrealism +- PPT/网页/界面 -> 扁平设计 / flat design, modern UI + +🔹 视觉细节增强 +在不偏离主次的前提下,补充以下维度信息: +- 主体外观:颜色、服装、发型、物品样式 +- 材质纹理:布料、金属、玻璃、皮革光泽 +- 光照氛围:自然光、室内灯、晴天/黄昏 +- 时间季节:清晨、秋日、雪景等典型情境 +- 环境背景:室内外场景、天气、空间层次 +- 构图建议:背景虚化、中心对称、留白区域 + +🔹 多余图片信息去除 +若用户输入中包含生成图像的比例,或者分辨率,则在输出中删除相关表述,例如“按照1:1生成”、“16:9”、“1920x1080” + +🔸 输出格式要求 +- 仅输出一行自然语言描述,作为完整句子呈现 +- 不换行、不编号、不加标题、不解释、不推荐参数 +- 语言简洁准确,避免抽象比喻、情感渲染或文学化修辞 +- 关键信息前置:**用途/载体, 场景, 主体, 风格** +- 禁用所有动态趋势词(如“正在”“即将”“欲”) +- 输出应符合文生图模型常用表达方式(关键词适度密集,语法自然) + +📌 示例参考(中英混合): +输入:一个女性背景的ppt, 16:9 +输出:PPT背景设计,浅蓝色渐变底纹搭配简约线条图案,右侧有一位优雅站立的女性剪影,整体为现代商务风格,扁平化视觉效果 + +输入:a girl dancing on the beach +输出:a girl standing on a sandy beach, facing the ocean, wearing a white dress, seagulls flying in the distance, soft sunlight, realistic photography style + +输入:海滨餐馆 +输出:一家临海的小型木结构餐馆坐落在岩石岸边,大落地窗面向大海,屋顶有茅草遮阳顶,门口摆放着几张木质桌椅,背景是波光粼粼的海面和晚霞,整体为写实摄影风格 + +输入:the restaurant, with a wooden sign hanging above the door that says Hello 渔火 +输出:a small wooden seaside restaurant with a thatched roof, large windows facing the ocean, and a hand-carved wooden sign above the entrance displaying the Chinese characters "Hello 渔火", gentle waves in the background, golden hour lighting, photorealistic style + +输入:an lv bag +输出:a brown monogram pattern handbag placed on a light gray marble surface, fine leather texture with visible weave, sturdy handles, soft ambient lighting, photorealistic product photography style + +输入:一个男人 +输出:一个男人站在城市街道旁,身穿深蓝色风衣,黑色短发,面容清晰,背景为虚化的行人和建筑,自然光照射,写实摄影风格 + +输入:哆啦A梦 +输出:卡通风格的哆啦A梦站立在明亮的室内场景中,圆润的蓝色机身,白色腹部,红色项圈配黄色铃铛,大眼睛直视前方,双手自然下垂,背景为浅色木纹地板与柔和的暖光照明,整体为彩色插画风格 + +输入:黑板报上面写着蚂蚁Ming-Omni +输出:黑板报上用白色粉笔字写着"蚂蚁Ming-Omni",深灰色木质边框的黑板置于教室墙面上,周围贴有彩色剪纸和学生绘画作品,阳光从左侧窗户斜射入内,形成柔和光斑与粉尘光束,粉笔字迹清晰带有轻微阴影,背景为浅黄色旧墙,整体为写实摄影风格 + + +现在,请优化以下输入: +{}""" + + +rewrite_edit_fromat = """你是一位专业的图像编辑提示词工程师,擅长将用户的编辑图像描述转化为固定的格式。你的任务是在忠实还原用户核心意图的前提下,自动转换格式与语言,并输出一行可用于直接生成图像编辑任务的指令。 + +请严格遵循以下原则: + +一. 一般指令输出英文, 输出必须为一行纯文本, 即为优化的结果,不要包含任何前缀后缀 +中文或中英文混合,要翻译成英文的指令: +- "将背景换成海滩" 应改为 "Change the background to a beach" +- "把eyeglasses去掉" 应改为 "Remove the eyeglasses" +- "举手" 应改为 "Raise the hand up" +不要将用户想在画面中改的字进行翻译: +- "把封面上的\"你好\"改成\"今天吃了吗\"" 应改为 "Change the text "你好" on the cover to "今天吃了吗"" +- "写上一句话orange真好吃", 应改为 "Add the text "orange真好吃"" + +二. 有关 segmentation 或者分割的指令,按照固定格式"Given the following instructions: [target]; please perform referring segmentation on this image with [color] mask."输出 +- [color]默认使用"green" , 如果已经指定掩码颜色,则使用用户指定的颜色 +- "Separate the little girl on the left" 应改为 "Given the following instructions: the little girl on the left; please perform referring segmentation on this image with green mask." +- "把塔分割出来" 应改为 "Given the following instructions: tower; please perform referring segmentation on this image with green mask." +- "用橙色把塔分割出来" 应改为 "Given the following instructions: tower; please perform referring segmentation on this image with orange mask." +- "perform segmentation on the tiger" 应改为 "Given the following instructions: tiger; please perform referring segmentation on this image with green mask." +- "please segment the cats" 应改为 "Given the following instructions: cats; please perform referring segmentation on this image with green mask." +- "请扣出图中的塔" 应改为 "Given the following instructions: tower; please perform referring segmentation on this image with green mask." +- "把右边的男人抠出来" 应改为 "Given the following instructions: the man on the right; please perform referring segmentation on this image with green mask." +- "分割出戴帽子的女人" 应改为 "Given the following instructions: woman wearing a hat; please perform referring segmentation on this image with green mask." +- "把站起来的小狗分割出来" 应改为 "Given the following instructions: the puppy standing up; please perform referring segmentation on this image with green mask." +- "把"hello 世界"抠出来" 应改为 "Given the following instructions: the text hello 世界; please perform referring segmentation on this image with green mask." + + +三. 探测是否有文字渲染需求 +如果用户指令中存在用户想要渲染的文字内容时将需要生成的文字改为使用""(英文双引号)括起来,比如在描述文本内容、书本封面、海报、广告牌、黑板、印刷等场景,对于英文提示词重点看是否有 "text" 或者 "word" 的暗示: +- Change the word unification in the book to promote +- "把封面上的\"hello\"改成\"你好\"" 应改为 "Change the text "hello" on the cover to "你好"". +- "添加文本北京欢迎您", 应改为 "Add the text "北京欢迎您"". +- "删除题目中"坏孩子"", 应改为 "Remove the text "坏孩子"". +如果用户没有明确的文字生成意愿, 则不要使用双引号进行改写,也不要额外增加渲染文字内容: +- "将背景换成 Arc de Triomphe" 应改为 "Change the background to Arc de Triomphe" +- "add a cat on the rock" 应改为 "add a cat on the rock" +特别注意,有关 segmentation 或者分割的指令,应去掉输出指令中的所有的双引号(“”,"") +- "把“你好”抠出来", 应改为 "Given the following instructions: text 你好; please perform referring segmentation on this image with green mask." +- "抠出标题中的"hello 世界"", 应改为 "Given the following instructions: the text hello 世界; please perform referring segmentation on this image with green mask." +- "please segment the text "great again"", 应改为 "Given the following instructions: the text great again; please perform referring segmentation on this image with green mask." +- "分割出 the word happy", 应改为 "Given the following instructions: the word happy; please perform referring segmentation on this image with green mask." + +四. 删除关于画面比例的表述,如1:1, 16:9这种 +- "生成证件照,16:9" 应改为 "Translate into a standard ID photo" + + +现在处理以下输入: +{}""" + +image_gen_indent_format = """你是一个意图判断器。 + +你的任务是:判断用户的输入是否满足以下任一条件: +1. 描述为具象的实体(场景、人物、动物、景色、物体、颜色、细节、要被渲染的文字任意一项即可)。 +2. 表达了明确的生成图片的意图(如“画一下”“生成图片”“给我一张…”)。 + +满足任一条件 → 输出 "yes" +否则 → 输出 "no" + +输出规则: +- 只能输出小写的yes或no。 +- 不要解释,不要输出其它符号或文字。 + +以下是示例: +用户输入: "画一只蓝色的鲸在天空中飞" +输出: yes + +用户输入: "帮我生成一张东京街头夜景的照片" +输出: yes + +用户输入: "在森林里,一只狐狸坐在溪边" +输出: yes + +用户输入: "故宫" +输出: yes + +用户输入: "戴眼镜" +输出: yes + +用户输入: "一行字 hello world" +输出: yes + +用户输入: "女孩" +输出: yes + +用户输入: "generate a boy" +输出: yes + +用户输入: "a boy" +输出: yes + +用户输入: "draw people" +输出: yes + +用户输入: "清晨薄雾笼罩的青翠山间小径,白发少年逆光奔跑,动态模糊捕捉疾驰瞬间;少年身穿白色运动服,手持一本摊开的书,柔光侧照突显发丝细节,远景层叠山峦与晨曦光晕,8K超清写实风格,浅景深突出主体" +输出: yes + +用户输入: "goodbye" +输出: no + +用户输入: "Hi" +输出: no + +用户输入: "What can you do for me" +输出: no + +用户输入: "你可以做什么?" +输出: no + +用户输入: "我们公司在北京" +输出: no + +用户输入: "描述你昨天的天气" +输出: no + +用户输入: "你好" +输出: no + +用户输入: "一只老虎" +输出: yes + +用户输入: "海滩" +输出: yes + +现在,请判断以下用户输入: +{}""" + +DEFAUL_PROMPT_FOR_NO_INTENT = "纯色背景, 写着\"Please input prompt for image generation\"" + + +class MingImg(object): + """ + Image generation and editing module using BailingMM2 diffusion model. + + This class provides high-level interfaces for text-to-image generation + and image editing operations with support for various aspect ratios + and resolutions. + + Attributes: + model_diffusion: The diffusion model for image generation. + """ + def __init__( + self, + model_path: str, + device: str = "cuda:0", + **kwargs, + ) -> None: + """ + Initialize the MingImg module. + + Args: + model_path (str): Path to the model directory containing the diffusion model. + device (str): GPU device to load the model on. Defaults to "cuda:0". + **kwargs: Additional arguments for model initialization. + """ + os.environ["IMAGE_GEN_MODE"] = "None" + from modeling_bailingmm2 import BailingMM2NativeForConditionalGeneration + + current_device = torch.cuda.current_device() + torch.cuda.set_device(device) + model_diffusion = ( + BailingMM2NativeForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, # Use bfloat16 for memory efficiency + attn_implementation="flash_attention_2", + load_image_gen=True, + low_cpu_mem_usage=True, # Minimize CPU memory during loading + load_vlm=False, # No VLM, only diffusion + ) + .to(device) + .to(torch.bfloat16) + ) + torch.cuda.set_device(current_device) + self.model_diffusion = model_diffusion diff --git a/ming_sdk/ming_moe.py b/ming_sdk/ming_moe.py new file mode 100644 index 0000000..3f9794b --- /dev/null +++ b/ming_sdk/ming_moe.py @@ -0,0 +1,303 @@ +""" +MingMOE: Mixture of Experts Language Model Module + +This module provides a synchronous wrapper around vLLM for text generation +using the Ming Mixture of Experts model. It supports both streaming and +non-streaming generation modes. + +Key Features: + - Tensor parallelism for multi-GPU inference + - Configurable GPU memory utilization + - Rich sampling parameter support (temperature, top_p, top_k, etc.) + - Request interruption support + +Usage: + >>> moe = MingMOE( + ... model_path="/path/to/model", + ... tensor_parallel_size=2, + ... gpu_memory_utilization=0.6 + ... ) + >>> output, usage = moe.generate([{"prompt": "Hello, world!"}]) +""" + +import sys +import uuid +import logging + +from vllm import LLM, SamplingParams +from vllm.inputs import TextPrompt as LLMInputs + +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Tuple, Union, Generator +from ming_sdk.usage import Usage +from ming_sdk.ming_utils import load_json_or_str + +logger = logging.getLogger() + + +class MingMOE(object): + """ + Synchronous Mixture of Experts LLM wrapper using vLLM backend. + + This class provides a high-level interface for text generation using + the Ming MOE model with vLLM as the inference engine. + + Attributes: + max_new_tokens (int): Maximum number of tokens to generate per request. + llm (LLM): The vLLM engine instance. + usage (Usage): Token usage tracker instance. + """ + + max_new_tokens = 10240 + + def __init__( + self, + model_path: str, + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.5, + limit_mm_per_prompt: dict = {"image": 10, "video": 2}, + quantization: str | None = None, + ) -> None: + """ + Initialize the BailigMOE instance. + + Args: + model_path (str): Path to the LLM model directory or Hugging Face repository. + tensor_parallel_size (int, optional): Number of GPU devices for tensor parallelism. Defaults to 1. + gpu_memory_utilization (float, optional): Fraction of GPU memory to use (0.0-1.0). Defaults to 0.6. + sys_prompt (str, optional): System-level prompt to prepend to user inputs. Defaults to empty. + """ + logger.info("using MingMOE") + if quantization == "fp8": + logger.info("using fp8") + else: + quantization = None + logger.info("using bf16") + self.llm = LLM( + model=model_path, + trust_remote_code=True, + enforce_eager=False, + disable_custom_all_reduce=False, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tensor_parallel_size, + limit_mm_per_prompt=limit_mm_per_prompt, + quantization=quantization, + max_seq_len_to_capture=32768, + disable_mm_preprocessor_cache=True, + ) + self.usage = Usage() + + def create_request_id(self) -> str: + """ + Generate a unique request ID for tracking. + + Returns: + str: A UUID4 string representing a unique request identifier. + """ + return str(uuid.uuid4()) + + def build_sampling_params(self, **kwargs) -> SamplingParams: + """ + Build sampling parameters for the LLM. + + Args: + **kwargs: Additional parameters for sampling (e.g., max_new_tokens). + + Returns: + SamplingParams: Configured sampling parameters. + """ + temperature, presence_penalty, repetition_penalty, return_hidden_states = ( + 0.6, + 0, + 1, + False, + ) + max_new_tokens = self.max_new_tokens + top_p, top_k, frequency_penalty, seed, stop = None, None, None, None, None + min_p, stop_token_ids, ignore_eos, logprobs, prompt_logprobs = ( + None, + None, + False, + None, + None, + ) + for key, value in kwargs.items(): + if key == "max_new_tokens" and value is not None: + max_new_tokens = value + if key == "temperature" and value is not None: + temperature = value + if key == "presence_penalty" and value is not None: + presence_penalty = value + if key == "repetition_penalty" and value is not None: + repetition_penalty = value + if key == "return_hidden_states" and value is not None: + return_hidden_states = value + if key == "top_p" and value is not None and isinstance(value, float): + top_p = value + if key == "top_k" and value is not None and isinstance(value, int): + top_k = value + if ( + key == "frequency_penalty" + and value is not None + and isinstance(value, float) + ): + frequency_penalty = value + if key == "seed" and value is not None and isinstance(value, int): + seed = value + if key == "stop" and value is not None: + value = load_json_or_str(value) + if ( + isinstance(value, list) + and all(isinstance(item, str) for item in value) + ) or isinstance(value, str): + stop = value + if key == "min_p" and value is not None and isinstance(value, float): + min_p = value + if key == "stop_token_ids" and value is not None: + if isinstance(value, list) and all(isinstance(v, int) for v in value): + stop_token_ids = value + if key == "ignore_eos" and value is not None and isinstance(value, bool): + ignore_eos = value + if ( + key == "logprobs" + and value is not None + and isinstance(value, int) + and value > 0 + ): + logprobs = value + if ( + key == "prompt_logprobs" + and value is not None + and isinstance(value, int) + and value > 0 + ): + prompt_logprobs = value + + sampling_params_kwargs = { + "temperature": temperature, + "max_tokens": max_new_tokens, + "presence_penalty": presence_penalty, + "repetition_penalty": repetition_penalty, + "return_hidden_states": return_hidden_states, + } + optional_params = { + "top_p": top_p, + "top_k": top_k, + "frequency_penalty": frequency_penalty, + "seed": seed, + "stop": stop, + "min_p": min_p, + "stop_token_ids": stop_token_ids, + "ignore_eos": ignore_eos, + "logprobs": logprobs, + "prompt_logprobs": prompt_logprobs, + } + for param, value in optional_params.items(): + if value is not None: + sampling_params_kwargs[param] = value + + sampling_params = SamplingParams(**sampling_params_kwargs) + return sampling_params + + def generate( + self, requests: List[LLMInputs], with_hidden_status: bool = False, **kwargs + ) -> Any: + """ + Generate text responses from the LLM. + + Args: + requests (List[LLMInputs]): List of input prompts for generation. + with_hidden_status (bool, optional): Whether to return hidden states from the LLM. Defaults to False. + **kwargs: Additional parameters for sampling (e.g., max_new_tokens, temperature). + + Returns: + Any: Generated text or hidden states, depending on `with_hidden_status`. + """ + sampling_params = self.build_sampling_params(**kwargs) + logger.info(f"In ming_moe generate, kwargs: {kwargs}") + logger.info(f"In vllm generate, sampling_params: {sampling_params}") + request_id = self.create_request_id() + logger.info(f"In vllm generate request_id: {request_id}") + inputs = [ + ( + request_id, + requests[0], + sampling_params, + ) + ] + req_id, prompt_text, sampling_params = inputs.pop(0) + llm_engine = self.llm.llm_engine + llm_engine.add_request(str(req_id), prompt_text, sampling_params) + logger.info("start to inference llm") + + output = [] + while llm_engine.has_unfinished_requests(): + try: + output = llm_engine.step() + if len(output) == 0: + continue + except Exception as e: + logger.error(f"[Unexpected Error] {e}") + sys.exit(1) + if len(output) == 0: + raise Exception("llm inference failed") + usage_output = output[0] + usage = self.usage.create_usage_by_requests_id(usage_output, request_id=req_id) + + if with_hidden_status: + return output[0].prefill_hidden_states, usage + return output[0].outputs[0].text, usage + + def generate_stream( + self, requests: List[LLMInputs], request_id: int = 0, **kwargs + ) -> Generator[str, None, None]: + """ + Args: + requests (List[LLMInputs]): List of input prompts for generation. + request_id (int, optional): Unique identifier for the request. Defaults to 0. + **kwargs: Additional parameters for sampling (e.g., max_new_tokens). + + Yields: + str: Incremental text output as it is generated. + """ + logger.info(f"In ming_moe generate stream, kwargs: {kwargs}") + sampling_params = self.build_sampling_params(**kwargs) + logger.info(f"In vllm generate_stream, sampling_params: {sampling_params}") + logger.info(f"In vllm generate_stream request_id: {request_id}") + inputs = [ + ( + request_id, + requests[0], + sampling_params, + ) + ] + req_id, prompt_text, sampling_params = inputs.pop(0) + llm_engine = self.llm.llm_engine + llm_engine.add_request(str(req_id), prompt_text, sampling_params) + logger.info("start to inference llm") + + history_sentence_index = 0 + while llm_engine.has_unfinished_requests(): + try: + request_outputs = llm_engine.step() + if len(request_outputs) == 0: + continue + usage_output = request_outputs[0] + self.usage.create_usage_by_requests_id(usage_output, request_id=req_id) + new_sentence = request_outputs[0].outputs[0].text + sentence = new_sentence[history_sentence_index:] + history_sentence_index = len(new_sentence) + yield sentence + except Exception as e: + logger.error(f"[Unexpected Error] {e}") + sys.exit(1) + + def generate_interrupt(self, request_id: str) -> None: + """ + Interrupt an ongoing request. + + Args: + request_id (str): Unique identifier of the request to abort. + """ + llm_engine = self.llm.llm_engine + llm_engine.abort_request(str(request_id)) diff --git a/ming_sdk/ming_moe_async.py b/ming_sdk/ming_moe_async.py new file mode 100644 index 0000000..437f66e --- /dev/null +++ b/ming_sdk/ming_moe_async.py @@ -0,0 +1,325 @@ +""" +MingMOEAsync: Asynchronous Mixture of Experts Language Model Module + +This module provides an asynchronous wrapper around vLLM for text generation +using the Ming Mixture of Experts model. It is designed for high-throughput +concurrent request scenarios. + +Key Features: + - Asynchronous inference using AsyncLLMEngine + - Dedicated event loop running in a separate thread + - Non-blocking streaming generation + - Request interruption support + +Usage: + >>> moe_async = MingMOEAsync( + ... model_path="/path/to/model", + ... tensor_parallel_size=2, + ... gpu_memory_utilization=0.6 + ... ) + >>> output, usage = await moe_async.generate([{"prompt": "Hello!"}]) +""" + +import sys +import time +import uuid +import asyncio +import logging +import threading + +from vllm import AsyncLLMEngine +from vllm import SamplingParams +from vllm.inputs import TextPrompt as LLMInputs +from vllm.engine.arg_utils import AsyncEngineArgs +from typing import Any, List, Optional, AsyncGenerator + +from ming_sdk.usage import Usage +from ming_sdk.ming_utils import StreamGenerator, load_json_or_str + + +logger = logging.getLogger() + + +class MingMOEAsync(object): + """ + Asynchronous Mixture of Experts LLM wrapper using vLLM AsyncLLMEngine backend. + + This class provides a high-level asynchronous interface for text generation + using the Ming MOE model. It runs a dedicated asyncio event loop in a + separate thread to enable non-blocking concurrent inference. + + Attributes: + max_new_tokens (int): Maximum number of tokens to generate per request. + loop (asyncio.AbstractEventLoop): Dedicated event loop for async operations. + thread (threading.Thread): Thread running the event loop. + engine (AsyncLLMEngine): The vLLM async engine instance. + usage (Usage): Token usage tracker instance. + """ + + max_new_tokens = 10240 + + def __init__( + self, + model_path: str, + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.5, + limit_mm_per_prompt: dict = {"image": 10, "video": 2}, + quantization: str | None = None, + ) -> None: + """ + Initialize the MingMOEAsync instance. + + Args: + model_path (str): Path to the LLM model directory or Hugging Face repository. + tensor_parallel_size (int, optional): Number of GPU devices for tensor parallelism. Defaults to 1. + gpu_memory_utilization (float, optional): Fraction of GPU memory to use (0.0-1.0). Defaults to 0.5. + limit_mm_per_prompt (dict, optional): Max multimedia items per prompt. Defaults to {"image": 10, "video": 2}. + quantization (str | None, optional): Quantization method ("fp8" or None for bf16). Defaults to None. + """ + logger.info("using MingMOEAsync") + if quantization == "fp8": + logger.info("using fp8") + else: + quantization = None + logger.info("using bf16") + engine_args = AsyncEngineArgs( + model=model_path, + trust_remote_code=True, + enforce_eager=False, + max_num_seqs=32, + disable_custom_all_reduce=False, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tensor_parallel_size, + limit_mm_per_prompt=limit_mm_per_prompt, + quantization=quantization, + max_seq_len_to_capture=32768, + disable_mm_preprocessor_cache=True, + ) + + self.loop, self.thread = self.new_and_run_event_loop() + self.engine = AsyncLLMEngine.from_engine_args(engine_args) + + self.usage = Usage() + + def new_and_run_event_loop(self) -> tuple: + """ + Create and start a new asyncio event loop in a dedicated thread. + + This method sets up an isolated event loop for async LLM inference, + preventing conflicts with the main thread's event loop. + + Returns: + tuple: A tuple containing (event_loop, thread) where thread is + running the event loop in daemon mode. + """ + new_loop = asyncio.new_event_loop() + thread_name = "moe_thread" + thread = threading.Thread( + target=new_loop.run_forever, daemon=True, name=thread_name + ) + thread.start() + + # Wait for the event loop to start running + while not new_loop.is_running(): + time.sleep(0.1) + + return new_loop, thread + + def create_request_id(self) -> str: + """ + Generate a unique request ID for tracking. + + Returns: + str: A UUID4 string representing a unique request identifier. + """ + return str(uuid.uuid4()) + + def build_sampling_params(self, **kwargs) -> SamplingParams: + """ + Build sampling parameters for the LLM. + + Args: + **kwargs: Additional parameters for sampling (e.g., max_new_tokens). + + Returns: + SamplingParams: Configured sampling parameters. + """ + temperature, presence_penalty, repetition_penalty, return_hidden_states = ( + 0.6, + 0, + 1, + False, + ) + max_new_tokens = self.max_new_tokens + top_p, top_k, frequency_penalty, seed, stop = None, None, None, None, None + min_p, stop_token_ids, ignore_eos, logprobs, prompt_logprobs = ( + None, + None, + False, + None, + None, + ) + for key, value in kwargs.items(): + if key == "max_new_tokens" and value is not None: + max_new_tokens = value + if key == "temperature" and value is not None: + temperature = value + if key == "presence_penalty" and value is not None: + presence_penalty = value + if key == "repetition_penalty" and value is not None: + repetition_penalty = value + if key == "return_hidden_states" and value is not None: + return_hidden_states = value + if key == "top_p" and value is not None and isinstance(value, float): + top_p = value + if key == "top_k" and value is not None and isinstance(value, int): + top_k = value + if ( + key == "frequency_penalty" + and value is not None + and isinstance(value, float) + ): + frequency_penalty = value + if key == "seed" and value is not None and isinstance(value, int): + seed = value + if key == "stop" and value is not None: + value = load_json_or_str(value) + if ( + isinstance(value, list) + and all(isinstance(item, str) for item in value) + ) or isinstance(value, str): + stop = value + if key == "min_p" and value is not None and isinstance(value, float): + min_p = value + if key == "stop_token_ids" and value is not None: + if isinstance(value, list) and all(isinstance(v, int) for v in value): + stop_token_ids = value + if key == "ignore_eos" and value is not None and isinstance(value, bool): + ignore_eos = value + if ( + key == "logprobs" + and value is not None + and isinstance(value, int) + and value > 0 + ): + logprobs = value + if ( + key == "prompt_logprobs" + and value is not None + and isinstance(value, int) + and value > 0 + ): + prompt_logprobs = value + + sampling_params_kwargs = { + "temperature": temperature, + "max_tokens": max_new_tokens, + "presence_penalty": presence_penalty, + "repetition_penalty": repetition_penalty, + "return_hidden_states": return_hidden_states, + } + optional_params = { + "top_p": top_p, + "top_k": top_k, + "frequency_penalty": frequency_penalty, + "seed": seed, + "stop": stop, + "min_p": min_p, + "stop_token_ids": stop_token_ids, + "ignore_eos": ignore_eos, + "logprobs": logprobs, + "prompt_logprobs": prompt_logprobs, + } + for param, value in optional_params.items(): + if value is not None: + sampling_params_kwargs[param] = value + + sampling_params = SamplingParams(**sampling_params_kwargs) + return sampling_params + + def generate( + self, requests: List[LLMInputs], with_hidden_status: bool = False, **kwargs + ) -> Any: + """ + Generate text responses from the LLM. + + Args: + requests (List[LLMInputs]): List of input prompts for generation. + with_hidden_status (bool, optional): Whether to return hidden states from the LLM. Defaults to False. + **kwargs: Additional parameters for sampling (e.g., max_new_tokens, temperature). + + Returns: + Any: Generated text or hidden states, depending on `with_hidden_status`. + """ + sampling_params = self.build_sampling_params(**kwargs) + request_id = self.create_request_id() + + async def _inner(): + final = None + async for output in self.engine.generate( + prompt=requests[0], + sampling_params=sampling_params, + request_id=request_id, + ): + final = output + return final + + future = asyncio.run_coroutine_threadsafe(_inner(), self.loop) + output = None + usage = self.usage.create_usage_default(prompt_tokens=0) + try: + output = future.result() + if isinstance(output, dict) and "error" in output: + logger.error(f"[Error] {output['type']}: {output['error']}") + except Exception as e: + logger.error(f"[Unexpected Error] {e}") + sys.exit(1) + if output is None: + return None, usage + if with_hidden_status: + return output.prefill_hidden_states, self.usage.create_usage(output) + return output.outputs[0].text, self.usage.create_usage(output) + + async def _generate_stream_async( + self, requests: List[LLMInputs], request_id: int = 0, **kwargs + ) -> AsyncGenerator[str, None]: + sampling_params = self.build_sampling_params(**kwargs) + history_sentence_index = 0 + async for request_outputs in self.engine.generate( + prompt=requests[0], + sampling_params=sampling_params, + request_id=request_id, + ): + self.usage.create_usage_by_requests_id( + request_outputs, request_id=request_id + ) + new_sentence = request_outputs.outputs[0].text + sentence = new_sentence[history_sentence_index:] + history_sentence_index = len(new_sentence) + yield sentence + + def generate_stream(self, requests: List[LLMInputs], request_id: int = 0, **kwargs): + async def _inner_stream(): + try: + async for output in self._generate_stream_async( + requests, request_id, **kwargs + ): + yield output + except Exception as e: + logger.error(f"[Unexpected Error] {e}") + sys.exit(1) + + return StreamGenerator(self.loop, _inner_stream()) + + def generate_interrupt(self, request_id: str) -> None: + """ + Interrupt an ongoing request. + + Args: + request_id (str): Unique identifier of the request to abort. + """ + + async def _inner(): + await self.engine.abort(str(request_id)) + + asyncio.run_coroutine_threadsafe(_inner(), self.loop) diff --git a/ming_sdk/ming_talker.py b/ming_sdk/ming_talker.py new file mode 100644 index 0000000..a2c21c6 --- /dev/null +++ b/ming_sdk/ming_talker.py @@ -0,0 +1,489 @@ +""" +MingTalker: Text-to-Speech Module + +This module provides a multiprocessing-based TTS (Text-to-Speech) system +using the BailingTalker model. It supports both streaming and non-streaming +audio generation with multi-GPU load balancing. + +Key Features: + - Process pool for parallel TTS generation across multiple GPUs + - Streaming audio output for real-time speech synthesis + - Chinese text detection for proper tokenization + - Request cancellation support + +Architecture: + - Main process: Handles request queuing and result collection + - Worker processes: Each hosts a Talker model on a specific GPU + - Inter-process communication via multiprocessing.Queue + +Usage: + >>> talker = MingTalker( + ... model_path="/path/to/model", + ... device_list=["cuda:0", "cuda:1"] + ... ) + >>> audio, duration = talker.generate("Hello, world!") +""" + +import re +import os +import time +import torch +import logging +import multiprocessing +import threading +from concurrent.futures import ProcessPoolExecutor, as_completed + +from typing import Any, List, Tuple, Generator, Union +import threading + +logger = logging.getLogger() +import queue +import uuid + + +def contains_chinese(text: str) -> bool: + """ + Check if the input text contains Chinese characters. + + Args: + text (str): Input text to check. + + Returns: + bool: True if Chinese characters are present, False otherwise. + """ + return bool(re.search(r"[\u4e00-\u9fff]", text)) + +class ProcessTalkerInstance: + """ + Singleton Talker instance for each worker process. + + Each worker process maintains a single Talker model instance + to avoid repeated model loading overhead. + """ + def __init__(self, model_path: str, device: str): + from modeling_bailing_talker import BailingTalker2 + from AudioVAE.modeling_audio_vae import AudioVAE + + logger.info(f"Loading talker model on {device}...") + with torch.cuda.device(device): + dtype = torch.bfloat16 + self.talker = BailingTalker2.from_pretrained( + f"{model_path}/talker").to( + dtype=dtype, + device=device + ) + self.talker_vae = AudioVAE.from_pretrained( + f"{model_path}/talker/vae").to( + dtype=dtype, + device=device + ) + self.talker.eval() + self.talker_vae.eval() + + self.device = device + logger.info(f"Talker model loaded successfully on {device}") + + +# Global Talker instance per process (singleton pattern) +_process_talker_instance = None + +def get_process_talker(model_path: str, device: str) -> ProcessTalkerInstance: + """ + Get or create the Talker instance for the current process (singleton). + + Args: + model_path (str): Path to the model directory. + device (str): GPU device to load the model on. + + Returns: + ProcessTalkerInstance: The singleton Talker instance for this process. + """ + global _process_talker_instance + if _process_talker_instance is None: + _process_talker_instance = ProcessTalkerInstance(model_path, device) + + logger.info(f'get_process_talker: {device}, actual model device: {_process_talker_instance.device}') + return _process_talker_instance + + +def process_generate(args) -> torch.Tensor: + """ + Worker function for non-streaming audio generation. + + This function runs in a worker process and performs the actual + TTS generation using the process-local Talker instance. + + Args: + args (tuple): Contains (text, speaker, model_path, device, cancel_flag). + + Returns: + Tuple[torch.Tensor, float]: Generated audio waveform and duration in seconds. + """ + text, speaker, model_path, device, cancel_flag = args + try: + talker_instance = get_process_talker(model_path, device) + actual_talker_device = talker_instance.device + talker = talker_instance.talker + talker_vae = talker_instance.talker_vae + + is_chinese = contains_chinese(text) + if not is_chinese: + text = text.split() + + all_wavs = [] + duration_ = 0 + last_time = time.perf_counter() + for ( + tts_speech, + text_list, + word_postion, + duration, + ) in talker.omni_audio_generation( + tts_text=text, + voice_name=speaker, + audio_detokenizer=talker_vae, + stream=False, + ): + if cancel_flag.is_set(): + break + all_wavs.append(tts_speech) + this_time = time.perf_counter() + # logging.info(f"chunk time cost: {this_time - last_time:.3f}s") + last_time = this_time + if duration is None: + duration = 0 + duration_ += duration + + waveform = torch.cat(all_wavs, dim=-1) + + # logging.info("finish process_generate func.") + return waveform, duration_ + + except Exception as e: + logger.error(f"Error in process_generate on {actual_talker_device}: {e}") + raise + + +def process_stream_generate(args) -> Generator[Tuple[torch.Tensor, List[str], dict], None, None]: + """ + Worker function for streaming audio generation. + + This function runs in a worker process and generates audio chunks + incrementally, yielding results through a multiprocessing Queue. + + Args: + args (tuple): Contains (text_queue, speaker, model_path, device, + result_queue, cancel_flag). + + Yields: + Tuple[torch.Tensor, List[str], dict]: Audio chunk, text segments, and metadata. + """ + text_queue, speaker, model_path, device, result_queue, cancel_flag = args + actual_talker_device = None + try: + talker_instance = get_process_talker(model_path, device) + actual_talker_device = talker_instance.device + talker = talker_instance.talker + talker_vae = talker_instance.talker_vae + + def text_wrapper(input_queue): + while True: + try: + if cancel_flag.is_set(): + break + text = input_queue.get(timeout=300) # 5分钟超时 + if text is None: # 结束信号 + break + yield text + except queue.Empty: + logger.warning("Text queue timeout, exiting...") + break + + logger.info(f"Starting stream generation on {device}") + last_time = time.perf_counter() + for ( + tts_speech, + text_list, + word_postion, + duration, + ) in talker.omni_audio_generation( + tts_text=text_wrapper(text_queue), + audio_detokenizer=talker_vae, + voice_name=speaker, + stream=True, + ): + if cancel_flag.is_set(): + break + this_time = time.perf_counter() + # logging.info(f"chunk time cost: {this_time - last_time:.3f}s") + last_time = this_time + + star_index, end_index, duration_ = 0, 0, 0 + if word_postion and len(word_postion) > 0 and word_postion[0] is not None: + star_index = word_postion[0] + if word_postion and len(word_postion) > 1 and word_postion[1] is not None: + end_index = word_postion[1] + if duration is not None and duration > 0: + duration_ = duration + + meta_info = { + "star_index": star_index, + "end_index": end_index, + "duration": duration_, + } + # Send results to the main process via queue + result_queue.put((tts_speech, text_list, meta_info)) + except Exception as e: + logger.error(f"Error in process_stream_generate on {actual_talker_device}: {e}") + raise + finally: + # Signal end of stream to the main process + result_queue.put(None) + logger.info(f"Stream generation finished on {actual_talker_device}") + logging.info(f"Stream generation finished on {actual_talker_device}") + + +class MingTalker(object): + """ + Main TTS interface with multi-GPU process pool support. + + This class provides a high-level interface for text-to-speech generation + with automatic load balancing across multiple GPUs using a process pool. + + Attributes: + model_path (str): Path to the model directory. + device_list (list): List of GPU devices for TTS workers. + process_pool (ProcessPoolExecutor): Pool of worker processes. + manager (multiprocessing.Manager): Manager for inter-process communication. + """ + def __init__( + self, + model_path: str, + tensor_parallel_size: int = 1, + device_list: list = ["cuda:0"], + ) -> None: + """ + Initialize the MingTalker instance. + + Args: + model_path (str): Path to the model directory containing TTS components. + tensor_parallel_size (int, optional): Number of GPU devices for tensor parallelism. Defaults to 1. + device_list (list, optional): List of GPU devices to use. Defaults to ["cuda:0"]. + """ + super().__init__() + self.model_path = model_path + self.device_list = device_list + multiprocessing.set_start_method('spawn', force=True) + self.manager = multiprocessing.Manager() + + # Create process pool, one worker per GPU device + self.process_pool = ProcessPoolExecutor( + max_workers=len(device_list) + ) + + # Preload models on all devices to ensure readiness + logging.info("Preloading models on all devices...") + futures = [] + for device in device_list: + future = self.process_pool.submit(process_generate, + ("test", "DB30", model_path, device, self.manager.Event())) + futures.append(future) + + # Wait for preload to complete (ignore test results) + for future in as_completed(futures): + try: + future.result(timeout=120) # 2-minute timeout for model loading + except Exception as e: + logger.warning(f"Preload may have failed (expected for test text): {e}") + + self.lock = threading.Lock() + self.task_count = 0 + self.cancel_flag_dict = dict() + self.futures_dict = dict() + self.future_lock = threading.Lock() + + logging.info(f"MingTalker initialized successfully with {len(device_list)} devices") + + def logging_tasks(self, request_id, future): + time.sleep(0.01) + running_keys = [] + pendding_keys = [] + done_count = 0 + with self.future_lock: + self.futures_dict[request_id] = future + all_keys = list(self.futures_dict.keys()) + done_keys = [] + for key in all_keys: + future = self.futures_dict[key] + if future.done(): + done_keys.append(key) + elif future.running(): + running_keys.append(key) + else: + pendding_keys.append(key) + + for key in done_keys: + self.futures_dict.pop(key) + + logging.info(f"Running Count: {len(running_keys)}, Pending Count: {len(pendding_keys)}") + if len(pendding_keys) > 0: + logging.info(f"Pending Keys: {pendding_keys}, Running Keys: {running_keys}") + + + def generate(self, text: str, speaker: str = "DB30", request_id: str = str(uuid.uuid4()),) -> torch.Tensor: + """ + Generate audio from text using the TTS model. + + Args: + text (str): Input text to convert to speech. + speaker (str, optional): Speaker identifier. Defaults to "DB30". + request_id (str, optional): Unique request identifier. Defaults to new UUID. + + Returns: + Tuple[torch.Tensor, float]: Generated audio waveform and duration in seconds. + """ + # Submit task to process pool; pool scheduler assigns to available worker + try: + cancel_flag = self.manager.Event() + with self.lock: + self.task_count += 1 + logging.info(f"Running Task Count: {self.task_count}, with generate.") + self.cancel_flag_dict[request_id] = cancel_flag + + future = self.process_pool.submit(process_generate, + (text, speaker, self.model_path, None, cancel_flag)) + self.logging_tasks(request_id=request_id, future=future) + return future.result() + + finally: + with self.lock: + self.task_count -= 1 + if request_id in self.cancel_flag_dict: + self.cancel_flag_dict[request_id].set() + self.cancel_flag_dict.pop(request_id) + + + def generate_stream( + self, text: Union[str, Generator[str, None, None]], speaker: str = "DB30", request_id: str = str(uuid.uuid4()), + ) -> Generator[Tuple[torch.Tensor, List[str], dict], None, None]: + """ + Stream audio generation from text incrementally. + + Args: + text: Input text, can be string or text generator for streaming input. + speaker (str, optional): Speaker identifier. Defaults to "DB30". + request_id (str, optional): Unique request identifier. Defaults to new UUID. + + Yields: + Tuple[torch.Tensor, List[str], dict]: Audio segment, text segments, and metadata. + """ + try: + cancel_flag = self.manager.Event() + with self.lock: + self.task_count += 1 + logging.info(f"Request_id: {request_id}, Running Task Count: {self.task_count}, with generate_stream.") + self.cancel_flag_dict[request_id] = cancel_flag + + # Create inter-process communication queues + text_queue = self.manager.Queue() + result_queue = self.manager.Queue() + + # Start streaming generation task + future = self.process_pool.submit(process_stream_generate, + (text_queue, speaker, self.model_path, None, result_queue, cancel_flag)) + + # Feed text into the queue + self._produce_text_to_queue(text, text_queue) + + self.logging_tasks(request_id=request_id, future=future) + + # Consume streaming results + for result in self._consume_stream_results(result_queue, future): + yield result + finally: + with self.lock: + self.task_count -= 1 + if request_id in self.cancel_flag_dict: + self.cancel_flag_dict[request_id].set() + self.cancel_flag_dict.pop(request_id) + + def _produce_text_to_queue(self, text: Union[str, Generator[str, None, None]], + text_queue: multiprocessing.Queue): + """ + Produce text data into the queue for streaming generation. + + This method runs a background thread to feed text chunks into + the worker process's input queue. + + Args: + text: Input text (string or generator). + text_queue: Multiprocessing queue for text chunks. + """ + def producer(): + try: + if isinstance(text, str): + for text_str in text: + text_queue.put(text_str) + else: + for chunk in text: + text_queue.put(chunk) + text_queue.put(None) # End signal + except Exception as e: + logger.error(f"Error in text producer: {e}") + text_queue.put(None) + + # Run producer in background thread + import threading + producer_thread = threading.Thread(target=producer) + producer_thread.daemon = True + producer_thread.start() + + def _consume_stream_results(self, result_queue: multiprocessing.Queue, future): + """ + Consume streaming results from the worker process queue. + + Args: + result_queue: Multiprocessing queue for audio results. + future: Future object for the worker process task. + + Yields: + Tuple[torch.Tensor, List[str], dict]: Audio chunk, text segments, and metadata. + """ + try: + while True: + try: + result = result_queue.get(timeout=300) # 5-minute timeout + if result is None: # End signal + break + + tts_speech, text_list, meta_info = result + + yield tts_speech, text_list, meta_info + + except queue.Empty: + logger.warning("Result queue timeout") + break + + finally: + # Ensure the task has completed + try: + future.result() + except: + logger.warning("Stream generation future may have timed out") + + def generate_interrupt(self, request_id: str) -> None: + """ + Interrupt an ongoing TTS generation request. + + Args: + request_id (str): ID of the request to cancel. + """ + with self.lock: + if request_id in self.cancel_flag_dict: + cancel_flag = self.cancel_flag_dict[request_id] + cancel_flag.set() + + def __del__(self): + """Clean up resources and shutdown process pool.""" + if hasattr(self, 'process_pool'): + self.process_pool.shutdown(wait=True) \ No newline at end of file diff --git a/ming_sdk/ming_test.py b/ming_sdk/ming_test.py new file mode 100644 index 0000000..04b5b41 --- /dev/null +++ b/ming_sdk/ming_test.py @@ -0,0 +1,378 @@ +""" +Ming SDK Test Examples + +This module demonstrates various usage patterns of the Ming SDK, including: +- Text generation (streaming and non-streaming) +- Speech synthesis (TTS) +- Speech-to-speech conversation +- Image understanding and generation +- Audio understanding (ASR) +- Video understanding + +Before running: +1. Install dependencies: pip install -r requirements.txt +2. Download model weights and configure model_path +3. Ensure sufficient GPU memory (4x A100/H20 recommended) + +Author: qiaozhuo +""" + +import os +import torch +import torchaudio +from ming_sdk.ming import Ming + + +# ============================================================================ +# Text Generation Tests +# ============================================================================ + +def test_text_generate(): + """ + Test non-streaming text generation. + + Use case: Batch text processing, scenarios without real-time feedback. + + Returns: + tuple: (generated_text, usage_statistics) + """ + text, usage = ming.generate(text="Introduce Hangzhou") + print(f"Generated text: {text}") + print(f"Usage: {usage}") + assert text is not None, "Text generation failed" + + +def test_text_generate_stream(): + """ + Test streaming text generation. + + Use case: Real-time dialogue systems, progressive content display. + + Yields: + tuple: (text_chunk, request_id, usage) for each streaming chunk + """ + all_text = "" + request_id = "" + for text, request_id, usage in ming.generate_stream( + text="Introduce Hangzhou", + max_new_tokens=128 + ): + all_text += text + print(text, end="", flush=True) # Real-time output + + print(f"\nRequest ID: {request_id}") + print(f"Full text: {all_text}") + print(f"Usage: {usage}") + assert all_text, "Streaming text generation failed" + + +# ============================================================================ +# Speech Generation Tests +# ============================================================================ + +def test_audio_generate(): + """ + Test non-streaming speech generation (Speech-to-Speech). + + Flow: + 1. Generate text response from input text + 2. Convert response text to speech + + Use case: Voice assistants, conversational AI systems. + + Returns: + tuple: (waveform, generated_text, usage) + """ + output_audio_path = "test_speech.wav" + waveform, gen_text, usage = ming.generate( + text="Introduce Hangzhou", + output_type="speech", + max_new_tokens=128 + ) + + sr = 16000 # Sample rate: 16kHz + torchaudio.save(output_audio_path, waveform, sr) + + print(f"Generated text: {gen_text}") + print(f"Audio saved to: {output_audio_path}") + print(f"Usage: {usage}") + assert os.path.exists(output_audio_path), "Audio file generation failed" + + +def test_audio_generate_stream(): + """ + Test streaming speech generation (Speech-to-Speech). + + Flow: + 1. Stream generate text response + 2. Convert each text chunk to speech in real-time + + Features: + - Lower first-token latency + - Better user experience + + Use case: Real-time voice conversation, phone customer service. + + Yields: + tuple: (data_type, data_content) where data_content varies by type + """ + all_wavs = [] + all_text = "" + request_id = "" + output_audio_path = "test_speech_stream.wav" + + for data_type, data_content in ming.generate_stream( + text="Introduce Hangzhou", + output_type="speech", + max_new_tokens=128 + ): + if data_type == "text_data": + # Pure text chunk (intermediate output) + text, usage = data_content + elif data_type == "text_audio_data": + # Text + audio chunk + tts_speech, text, meta_info, session_id, usage = data_content + all_text += text + all_wavs.append(tts_speech) + print(f"Chunk: {text}") + + # Concatenate all audio chunks + waveform = torch.cat(all_wavs, dim=-1) + sr = 16000 + torchaudio.save(output_audio_path, waveform, sr) + + print(f"Full text: {all_text}") + print(f"Audio saved to: {output_audio_path}") + print(f"Usage: {usage}") + assert os.path.exists(output_audio_path), "Streaming audio generation failed" + + +def test_audio_generate_stream_interrupt(): + """ + Test streaming speech generation with interruption capability. + + Use case: User-initiated interruption, timeout control. + + The generation will be interrupted when text length exceeds threshold. + """ + all_wavs = [] + all_text = "" + request_id = "test-interrupt-001" + output_audio_path = "test_speech_interrupt.wav" + + for data_type, data_content in ming.generate_stream( + text="Introduce Hangzhou", + output_type="speech", + max_new_tokens=128, + msg_request_id=request_id + ): + if data_type == "text_data": + text, usage = data_content + elif data_type == "text_audio_data": + tts_speech, text, meta_info, session_id, usage = data_content + all_text += text + all_wavs.append(tts_speech) + + # Interrupt when text length exceeds 20 characters + if len(all_text) > 20: + print(f"Interrupting at: {all_text}") + ming.generate_interrupt(request_id) + break + + if all_wavs: + waveform = torch.cat(all_wavs, dim=-1) + sr = 16000 + torchaudio.save(output_audio_path, waveform, sr) + + print(f"Interrupted text: {all_text}") + print(f"Audio saved to: {output_audio_path}") + assert os.path.exists(output_audio_path), "Interrupt test failed" + + +# ============================================================================ +# TTS Tests +# ============================================================================ + +def test_tts(): + """ + Test pure text-to-speech (TTS) conversion. + + Use case: Text reading, voice announcement systems. + + Note: Unlike speech generation, TTS directly converts input text + to speech without generating intermediate response text. + """ + output_audio_path = "test_tts.wav" + waveform, usage = ming.generate( + text="I love the Forbidden City in Beijing", + output_type="tts" + ) + + sr = 16000 # Sample rate: 16kHz + torchaudio.save(output_audio_path, waveform, sr) + + print(f"Audio saved to: {output_audio_path}") + print(f"Duration: {waveform.shape[-1] / sr:.2f} seconds") + print(f"Usage: {usage}") + assert os.path.exists(output_audio_path), "TTS generation failed" + + +# ============================================================================ +# Image Tests +# ============================================================================ + +def test_image_qa(): + """ + Test image understanding (Image QA). + + Use case: Image description, visual question answering. + + Args: + image: Path to the image file (supports jpg, png, etc.) + """ + image_path = "test.png" + + # Check if test image exists + if not os.path.exists(image_path): + print(f"Skipping test: Image file {image_path} not found") + return + + text, usage = ming.generate( + text="Describe this image in detail", + image=image_path, + output_type="text" + ) + + print(f"Image description: {text}") + print(f"Usage: {usage}") + assert text is not None, "Image QA failed" + + +# ============================================================================ +# Audio Understanding Tests +# ============================================================================ + +def test_audio_task(): + """ + Test audio understanding (ASR/Audio QA). + + Supported tasks: + - Automatic Speech Recognition (ASR) + - Audio content understanding + - Audio event detection + + Args: + audio: Path to audio file or URL + """ + audio_path = "test.wav" + + # Check if test audio exists + if not os.path.exists(audio_path): + print(f"Skipping test: Audio file {audio_path} not found") + return + + asr_result, usage = ming.generate( + text="Please recognize the language of this speech and transcribe it. Format: oral.", + audio=audio_path, + ) + + print(f"ASR result: {asr_result}") + print(f"Usage: {usage}") + assert asr_result is not None, "Audio task failed" + + +# ============================================================================ +# Video Understanding Tests +# ============================================================================ + +def test_video(): + """ + Test video understanding. + + Supported tasks: + - Video content description + - Video QA + - Video summarization + + Args: + video: Path to video file (supports mp4, etc.) + """ + video_path = "test.mp4" + + # Check if test video exists + if not os.path.exists(video_path): + print(f"Skipping test: Video file {video_path} not found") + return + + text, usage = ming.generate( + text="Describe this video in detail", + video=video_path, + output_type="text" + ) + + print(f"Video description: {text}") + print(f"Usage: {usage}") + assert text is not None, "Video understanding failed" + + +# ============================================================================ +# Main Entry Point +# ============================================================================ + +if __name__ == "__main__": + # ========================================================================= + # Configuration + # ========================================================================= + + # Model path - update this to your actual model path + MODEL_PATH = "your_model_path" + + # GPU device mapping for different modules + DEVICE_MAP = {"talker": ["cuda:0"]} + + # Initialize Ming SDK + print("=" * 60) + print("Initializing Ming SDK...") + print("=" * 60) + + ming = Ming( + model_path=MODEL_PATH, + device="0,1,2,3", # GPU devices for LLM (comma-separated) + gpu_memory_utilization={"moe": 0.8, "talker": 0.17}, + device_map=DEVICE_MAP, + speaker="DB30", # TTS speaker ID + with_async=True, + use_talker=True + ) + + print("Ming SDK initialized successfully!\n") + + # ========================================================================= + # Run Tests + # ========================================================================= + + # Text generation tests + print("\n" + "=" * 60) + print("Running Text Generation Tests") + print("=" * 60) + test_text_generate() + test_text_generate_stream() + + # Multimodal tests + print("\n" + "=" * 60) + print("Running Multimodal Tests") + print("=" * 60) + test_image_qa() + test_audio_task() + test_video() + + # Speech generation tests + print("\n" + "=" * 60) + print("Running Speech Generation Tests") + print("=" * 60) + test_audio_generate() + test_audio_generate_stream() + + print("\n" + "=" * 60) + print("All tests completed!") + print("=" * 60) \ No newline at end of file diff --git a/ming_sdk/ming_utils.py b/ming_sdk/ming_utils.py new file mode 100644 index 0000000..26e7f28 --- /dev/null +++ b/ming_sdk/ming_utils.py @@ -0,0 +1,902 @@ +""" +MingUtils: Utility Functions and Classes for Ming SDK + +This module provides essential utilities for the Ming SDK, including: + - Multimedia processing (image, video, audio) + - Prompt building and tokenization + - File download utilities + - Caching mechanisms + - Streaming generation helpers + +Key Components: + - MingUtils: Main utility class for prompt processing + - DownloadUtils: HTTP download with metadata extraction + - SimpleCache / ThreadSafeCache: TTL-based caching with LRU eviction + - StreamGenerator: Async-to-sync stream wrapper +""" + +import os +import json +import time +import torch +import asyncio +import logging +import requests +import threading +import subprocess +from PIL import Image +from queue import Queue, Empty + +from collections import OrderedDict +from typing import Iterable, TypeVar, Iterator + +from transformers import AutoProcessor, AutoTokenizer +from vllm.inputs import TextPrompt as LLMInputs +from typing import Any, Dict, Optional, Tuple, Union, List, AsyncGenerator + +T = TypeVar("T") +logger = logging.getLogger() +from enum import IntEnum, unique + + +@unique +class MingStatus(IntEnum): + """Status codes for Ming SDK operations.""" + OK = 200 + ParametersIllegal = 401 + DonwloadFail = 402 + DonwloadTimeout = 403 + VideoSizeLimit = 404 + AudioSizeLimit = 405 + ImageSizeLimit = 406 + + +class DownloadUtils(object): + """HTTP download utility with metadata extraction support.""" + + def __init__(self): + pass + + def get_meta_info(self, url: str) -> Dict[str, Any]: + """ + Extract metadata (dimensions, duration, codec, etc.) from a media URL. + + Uses ffprobe to analyze the media file without downloading it. + + Args: + url (str): URL of the media file. + + Returns: + Dict[str, Any]: Metadata including width, height, fps, duration, + codec_type, and size. + """ + + meta_info = { + "width": 0, + "height": 0, + "fps": 0, + "duration": 0, + "codec_type": "", + "size": 0, + } + if "https://" in url: + url = url.replace("https://", "http://") + + ffprobe_cmd = f"ffprobe -v error -show_entries format=duration,size -show_entries stream=duration,codec_type,width,height,avg_frame_rate -i '{url}' -of json" + ret = subprocess.run( + [ffprobe_cmd], shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + if ret.returncode != 0: + logger.warning(f"get meta info fial {url}") + return meta_info + output = ret.stdout.decode("utf8") + info = json.loads(output) + + if "streams" in info: + for stream in info["streams"]: + if "codec_type" not in stream: + continue + if stream["codec_type"] == "video": + meta_info["width"] = stream["width"] + meta_info["height"] = stream["height"] + meta_info["size"] = int(info["format"]["size"]) / ( + 1024 * 1024 + ) # MB + if "duration" in stream: + meta_info["codec_type"] = "video" + meta_info["duration"] = float(stream["duration"]) + meta_info["fps"] = float(stream["avg_frame_rate"]) + else: + meta_info["codec_type"] = "image" + break + if stream["codec_type"] == "audio": + meta_info["duration"] = float(stream["duration"]) + meta_info["size"] = int(info["format"]["size"]) / (1024 * 1024) + meta_info["codec_type"] = "audio" + return meta_info + + def download( + self, url: str, target_path: str, filename, timeout: Tuple[int, int] = (10, 180) + ) -> Dict[int, Any]: + """ + Download a file from the given URL and save it to the specified local path. + + This function uses streaming to efficiently download large files without loading them + entirely into memory. It ensures atomicity by writing to a temporary file first and + then renaming it upon completion. + + Args: + url (str): The URL of the file to download. + target_path (str): The local directory where the file will be saved. + filename (str): The name to save the file as. + timeout (Tuple[int, int]): Connection and read timeout in seconds. Default is (10, 180). + + Returns: + Tuple[int, str | None]: A tuple containing: + - status code: 0 for success, negative values for specific failures + - file path if successful, otherwise None + """ + STATUS_SUCCESS = 0 + STATUS_DOWNLOAD_FAILED = -1 + try: + response = requests.get(url, stream=True, timeout=timeout) + response.raise_for_status() + if not os.path.exists(target_path): + os.makedirs(target_path) + target_path = os.path.join(target_path, filename) + with open(target_path, "wb") as file: + for chunk in response.iter_content(chunk_size=1024): + file.write(chunk) + logger.info(f"Download success: {target_path}") + return STATUS_SUCCESS, target_path + except requests.exceptions.RequestException as e: + logger.error(f"Download failed: {e}, url {url}") + except Exception as e: + logger.error(f"Save file failed: {e}, url {url}") + return STATUS_DOWNLOAD_FAILED, None + +def load_json_or_str(s: str) -> object: + """ + Attempt to parse a string as JSON; return the original value on failure. + + Args: + s (str): Input string to parse. + + Returns: + object: Parsed JSON object or original string. + """ + if not isinstance(s, str): + return s + try: + return json.loads(s) + except (json.JSONDecodeError, TypeError): + return s + + +class MingUtils(object): + """ + Main utility class for prompt building and multimedia processing. + + This class handles: + - Tokenization and prompt template application + - Multimedia (image/video/audio) input processing + - Message filtering based on token limits + - History management for multi-turn conversations + + Attributes: + processor: HuggingFace processor for multimodal inputs. + tokenizer: HuggingFace tokenizer. + limit_images (int): Maximum images per prompt. + limit_videos (int): Maximum videos per prompt. + sample_rate (int): Audio sample rate for processing. + max_frames (int): Maximum frames for video processing. + """ + def __init__( + self, + model_path: str, + limit_mm_per_prompt={"image": 10, "video": 2}, + sample_rate=16000, + sys_prompt=None, + ): + from processing_bailingmm2 import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_FRAME_PATCH_TOKEN, DEFAULT_AUDIO_PATCH_TOKEN + self.image_path_token = DEFAULT_IMAGE_PATCH_TOKEN + self.frame_path_token = DEFAULT_FRAME_PATCH_TOKEN + self.audio_path_token = DEFAULT_AUDIO_PATCH_TOKEN + self.processor = AutoProcessor.from_pretrained( + model_path, trust_remote_code=True + ) + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True + ) + self.limit_mm_per_prompt = limit_mm_per_prompt + if "image" in limit_mm_per_prompt: + self.limit_images = limit_mm_per_prompt["image"] + else: + self.limit_images = None + + if "video" in limit_mm_per_prompt: + self.limit_videos = limit_mm_per_prompt["video"] + else: + self.limit_videos = None + self.sys_prompt = sys_prompt + self.sample_rate = sample_rate + self.max_frames = 40 + + def filter_message(self, data: list, limit_images: int = 10, + limit_videos: int = 2, limit_audios: int = 1) -> list: + """ + Filter conversation messages to enforce multimedia limits. + + This method ensures the total number of images, videos, and audios + in the conversation history does not exceed specified limits. + + Args: + data (list): List of conversation messages. + limit_images (int): Maximum allowed images. Defaults to 10. + limit_videos (int): Maximum allowed videos. Defaults to 2. + limit_audios (int): Maximum allowed audios. Defaults to 1. + + Returns: + list: Filtered list of messages respecting the limits. + """ + total_image_count = 0 + total_video_count = 0 + total_audio_count = 0 + last_item_audios = 0 + + filtered_data = [] + last_item = data[-1] if data else None + + if last_item and last_item["role"] == "HUMAN": + last_item_images = sum( + 1 for content in last_item["content"] if content["type"] == "image" + ) + last_item_videos = sum( + 1 for content in last_item["content"] if content["type"] == "video" + ) + last_item_audios = sum( + 1 for content in last_item["content"] if content["type"] == "audio" + ) + + if ( + total_image_count + last_item_images <= limit_images + and total_video_count + last_item_videos <= limit_videos + and total_audio_count + last_item_audios <= limit_audios + ): + filtered_data.append(last_item) + total_image_count += last_item_images + total_video_count += last_item_videos + total_audio_count += last_item_audios + + temp_human = None + temp_assistant = None + for entry in reversed(data[:-1]): + if entry["role"] == "HUMAN": + temp_human = entry + + if temp_human and temp_assistant: + human_images = sum( + 1 + for content in temp_human["content"] + if content["type"] == "image" + ) + human_videos = sum( + 1 + for content in temp_human["content"] + if content["type"] == "video" + ) + human_audios = sum( + 1 + for content in temp_human["content"] + if content["type"] == "audio" + ) + assistant_images = sum( + 1 + for content in temp_assistant["content"] + if content["type"] == "image" + ) + assistant_videos = sum( + 1 + for content in temp_assistant["content"] + if content["type"] == "video" + ) + assistant_audios = sum( + 1 + for content in temp_assistant["content"] + if content["type"] == "audio" + ) + + new_image_count = ( + total_image_count + human_images + assistant_images + ) + new_video_count = ( + total_video_count + human_videos + assistant_videos + ) + new_audio_count = ( + total_audio_count + human_audios + assistant_audios + ) + + if ( + new_image_count > limit_images + or new_video_count > limit_videos + or new_audio_count > limit_audios + ): + temp_human = None + temp_assistant = None + continue + elif last_item_audios > 0 and human_audios + assistant_audios > 0: + temp_human = None + temp_assistant = None + continue + else: + filtered_data.append(temp_assistant) + filtered_data.append(temp_human) + total_image_count = new_image_count + total_video_count = new_video_count + total_audio_count = new_audio_count + + temp_human = None + temp_assistant = None + + elif entry["role"] == "ASSISTANT": + temp_assistant = entry + + return filtered_data[::-1] + + def compute_text_input_tokens(self, text: str, **kwargs) -> Optional[int]: + """ + Compute the number of input tokens for a text prompt. + + Args: + text (str): Input text to tokenize. + **kwargs: Additional arguments (compute_input_tokens_flag, system_prompt). + + Returns: + Optional[int]: Token count if compute_input_tokens_flag is True, else None. + """ + compute_input_tokens_flag = kwargs.get("compute_input_tokens_flag", False) + if not compute_input_tokens_flag: + return None + t1 = time.time() + system_prompt = kwargs.get("system_prompt", None) + if system_prompt and isinstance(system_prompt, str): + text += system_prompt + if text and len(text): + input_text_ids = self.tokenizer.encode(text, add_special_tokens=True) + text_token_count = len(input_text_ids) + else: + text_token_count = 0 + t2 = time.time() + logger.info(f"Compute text input tokens, text_token_count: {text_token_count}, cost time: {t2-t1}s") + return text_token_count + + def compute_image_audio_video_input_tokens(self, prompt: str, image_inputs: list, + video_inputs: list, audio_inputs: list, + **kwargs) -> Tuple[Optional[int], Optional[int], Optional[int]]: + """ + Compute token counts for image, video, and audio inputs. + + This method processes multimodal inputs through the processor and + counts the number of patch tokens for each modality. + + Args: + prompt (str): Text prompt. + image_inputs (list): List of image inputs. + video_inputs (list): List of video inputs. + audio_inputs (list): List of audio inputs. + **kwargs: Additional arguments (compute_input_tokens_flag). + + Returns: + Tuple[Optional[int], Optional[int], Optional[int]]: Token counts for + (image, video, audio) if compute_input_tokens_flag is True. + """ + compute_input_tokens_flag = kwargs.get("compute_input_tokens_flag", False) + if not compute_input_tokens_flag: + return None, None, None + t1 = time.time() + inputs_processor = self.processor( + text=[prompt], + images=image_inputs, + videos=video_inputs, + audios=audio_inputs, + audio_kwargs={"use_whisper_encoder": True}, + return_tensors="pt", + ) + image_patch_id = self.tokenizer.convert_tokens_to_ids(self.image_path_token) + image_token_count = (inputs_processor['input_ids'] == image_patch_id).sum().item() + + #DEFAULT_FRAME_PATCH_TOKEN = "" + video_patch_id = self.tokenizer.convert_tokens_to_ids(self.frame_path_token) + video_token_count = (inputs_processor['input_ids'] == video_patch_id).sum().item() + + #DEFAULT_AUDIO_PATCH_TOKEN = "" + audio_patch_id = self.tokenizer.convert_tokens_to_ids(self.audio_path_token) + audio_token_count = (inputs_processor['input_ids'] == audio_patch_id).sum().item() + t2 = time.time() + logger.info(f"Compute image/video/audio tokens, cost time: {t2-t1}s") + return image_token_count, video_token_count, audio_token_count + + def build_prompt( + self, + prompt: str, + audio: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + video: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + image: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None, + history: list = [], + **kwargs, + ) -> Dict[str, Any]: + """ + Build a prompt input for the model (common logic for text/audio/image generation). + + Args: + prompt (str): User input text. + audio (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Audio data (e.g., file path or binary or list). + video (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Video data (e.g., file path or binary or list). + image (Optional[Union[str, bytes, List[Union[str, bytes]]]]): Image data (file path, binary, or PIL Image or list). + history (list, optional): Conversation history. Defaults to empty list. + **kwargs: Additional parameters for prompt building. + + Returns: + Dict[str, Any]: A dictionary containing the built prompt. + """ + + os.environ["IMAGE_GEN_MODE"] = "" + current_sys_prompt = self.sys_prompt + use_cot_system_prompt = False + for key, value in kwargs.items(): + if key == "audio": + audio = value + if key == "video": + video = value + if key == "image": + image = value + if key == "system_prompt": + current_sys_prompt = value + if key == "use_cot" and isinstance(value, bool): + use_cot_system_prompt = value + + messages_video_and_audio = [{"role": "HUMAN", "content": []}] + if current_sys_prompt is not None and current_sys_prompt != "": + current_sys_prompt = current_sys_prompt + else: + current_sys_prompt = None + + if video is not None: + if isinstance(video, list): + videos = video + else: + videos = [video] + logger.info("llm activate video input") + + for single_video in videos: + messages_video_and_audio[0]["content"].append( + { + "type": "video", + "video": single_video, + "sample": "uniform", + "max_frames": self.max_frames, + } + ) + + if image is not None: + if isinstance(image, list): + images = image + else: + images = [image] + logger.info("llm activate image input") + messages_video_and_audio[0]["content"].append( + {"type": "image", "image": images} + ) + + if audio is not None: + if isinstance(audio, list): + audios = audio + else: + audios = [audio] + logger.info("llm activate audio input") + # audio = torch.from_numpy(audio).unsqueeze(0) + for single_audio in audios: + messages_video_and_audio[0]["content"].append( + { + "type": "audio", + "audio": single_audio, + "sample_rate": self.sample_rate, + } + ) + if prompt is not None: + logger.info("llm activate text input") + messages_video_and_audio[0]["content"].append( + {"type": "text", "text": prompt} + ) + # self.manage_history_query_message() + if len(history) > 0: + messages_video_and_audio = history + messages_video_and_audio + + logger.info("In ming_sdk, prompt: " + str(messages_video_and_audio)) + if self.limit_images and self.limit_videos: + messages_video_and_audio = self.filter_message( + messages_video_and_audio, self.limit_images, self.limit_videos + ) + logger.info(f"After filter, prompt: {messages_video_and_audio}, current_sys_prompt: {current_sys_prompt}, use_cot_system_prompt: {use_cot_system_prompt}") + + prompt = self.processor.apply_chat_template( + messages_video_and_audio, + sys_prompt_exp = current_sys_prompt, + use_cot_system_prompt = use_cot_system_prompt + ) + image_inputs, video_inputs, audio_inputs = self.processor.process_vision_info( + messages_video_and_audio + ) + + compute_input_tokens_flag = kwargs.get("compute_input_tokens_flag", False) + logger.info(f"In build_prompt, compute_input_tokens_flag: {compute_input_tokens_flag}") + image_token_count, video_token_count, audio_token_count = None, None, None + if compute_input_tokens_flag: + image_token_count, video_token_count, audio_token_count = self.compute_image_audio_video_input_tokens(prompt, image_inputs, video_inputs, audio_inputs, **kwargs) + logger.info(f"In build_prompt, image_token_count: {image_token_count}, video_token_count: {video_token_count}, audio_token_count: {audio_token_count}") + + requests = [] + inputs = LLMInputs( + { + "prompt": prompt, + } + ) + """" + "image": image_inputs, + "video": video_inputs, + "audio": audio_inputs, + """ + if image is not None or image_inputs is not None: + if "multi_modal_data" in inputs.keys(): + inputs["multi_modal_data"]["image"] = image_inputs + else: + inputs["multi_modal_data"] = {"image": image_inputs} + if video is not None or video_inputs is not None: + if "multi_modal_data" in inputs.keys(): + inputs["multi_modal_data"]["video"] = video_inputs + else: + inputs["multi_modal_data"] = {"video": video_inputs} + if audio is not None or audio_inputs is not None: + if "multi_modal_data" in inputs.keys(): + inputs["multi_modal_data"]["audio"] = audio_inputs + else: + inputs["multi_modal_data"] = {"audio": audio_inputs} + requests.append(inputs) + return requests, image_token_count, video_token_count, audio_token_count + + def build_img_prompt( + self, + prompt: str, + image: Optional[Union[str, bytes, Image.Image]] = None, + **kwargs, + ) -> Dict[str, Any]: + """ + Build a prompt for image generation based on input text and optional image. + + Args: + text (str): The text prompt for image generation. + image (Optional[Union[str, bytes, Image.Image]]): Optional input image (for editing mode). + **kwargs: Additional keyword arguments (unused in this method). + + Returns: + List[LLMInputs]: A list of LLM input objects containing the generated prompt and image data. + + Description: + - Constructs a message structure for the model. If no image is provided, a dummy image is used. + - The message order depends on whether an image is provided: + - If `image is None`: [Text, Dummy Image] + - Else: [Image, Text] + - Applies the chat template to generate a text prompt. + - Processes vision-related information (e.g., image inputs). + - Returns LLM input objects with the prompt and multi-modal data. + """ + if image is None: + messages = [ + { + "role": "HUMAN", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image", "image": Image.new("RGB", (1, 1), (0, 0, 0))}, + ], + } + ] + else: + if isinstance(image, str): + messages = [ + { + "role": "HUMAN", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": prompt}, + ], + }, + ] + elif isinstance(image, list): + messages = [ + { + "role": "HUMAN", + "content": [], + } + ] + for img in image: + messages[0]["content"].append({"type": "image", "image": img}) + messages[0]["content"].append({"type": "text", "text": prompt}) + + logger.info("Image task, prompt: " + str(messages)) + text = self.processor.apply_chat_template( + messages + ) + + image_inputs, video_inputs, audio_inputs = self.processor.process_vision_info( + messages + ) + + compute_input_tokens_flag = kwargs.get("compute_input_tokens_flag", False) + logger.info(f"In build_img_prompt, compute_input_tokens_flag: {compute_input_tokens_flag}") + image_token_count, video_token_count, audio_token_count = None, None, None + if compute_input_tokens_flag: + image_token_count, video_token_count, audio_token_count = self.compute_image_audio_video_input_tokens(text, image_inputs, video_inputs, audio_inputs, **kwargs) + logger.info(f"In build_img_prompt, image_token_count: {image_token_count}, video_token_count: {video_token_count}, audio_token_count: {audio_token_count}") + + requests = [ + LLMInputs({"prompt": text, "multi_modal_data": {"image": image_inputs}}), + ] + return requests, image_token_count, video_token_count, audio_token_count + + def build_img_gen_prompt( + self, + prompt: str, + image: Optional[Union[str, bytes, Image.Image]] = None, + device: str = "cuda:0", + image_gen_highres = False, + image_gen_aspect_ratio = None, + **kwargs, + ) -> Dict[str, Any]: + """ + Prepare input data for image generation or editing. + + Args: + text (str): The text prompt for image generation. + image (Optional[Union[str, bytes, Image.Image]]): Optional input image (for editing mode). + **kwargs: Additional keyword arguments (unused in this method). + + Returns: + Dict[str, torch.Tensor]: A dictionary of processed inputs in tensor format, including text and multi-modal data. + + Description: + - Constructs a message structure for the model. If no image is provided, only the text is included. + - Applies the chat template to generate a text prompt. + - Processes vision-related information (e.g., image inputs). + - Converts the inputs into PyTorch tensors and moves them to the GPU. + - Converts specific tensor types (e.g., pixel values) to `torch.bfloat16` for efficient inference. + """ + if image is None: + messages = [ + { + "role": "HUMAN", + "content": [{"type": "text", "text": prompt}], + } + ] + else: + if isinstance(image, str): + messages = [ + { + "role": "HUMAN", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": prompt}, + ], + }, + ] + elif isinstance(image, list): + messages = [ + { + "role": "HUMAN", + "content": [], + } + ] + for img in image: + messages[0]["content"].append({"type": "image", "image": img}) + messages[0]["content"].append({"type": "text", "text": prompt}) + text = self.processor.apply_chat_template(messages) + image_inputs, video_inputs, audio_inputs = self.processor.process_vision_info( + messages + ) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + audios=audio_inputs, + return_tensors="pt", + image_gen_highres=image_gen_highres, + image_gen_aspect_ratio=image_gen_aspect_ratio, + ).to(device) + + for k in inputs.keys(): + if k in [ + "pixel_values", + "pixel_values_videos", + "audio_feats", + "pixel_values_reference", + ]: + inputs[k] = inputs[k].to(dtype=torch.bfloat16) + return inputs + + def _check_and_download_message(self, messages) -> Dict[str, int]: + """ + Checks each message in the input list for media content and downloads it if necessary. + """ + pass + + +class StreamGenerator: + """ + Wrapper to convert an async generator to a synchronous iterator. + + This class enables iteration over async generators from synchronous + code by running the async operations in a separate event loop. + + Args: + loop: The asyncio event loop to use. + async_generator: The async generator to wrap. + """ + def __init__(self, loop, async_generator): + self.loop = loop + self.async_generator = async_generator + + def __iter__(self): + return self + + def __next__(self): + try: + future = asyncio.run_coroutine_threadsafe( + self.async_generator.__anext__(), self.loop + ) + return future.result() + except StopAsyncIteration: + raise StopIteration + + +class SimpleCache: + """ + A simple local cache with TTL expiration and LRU eviction. + + Features: + - Set and get values with optional TTL + - Automatic expiration checking on access + - LRU eviction when max_size is reached + - Thread-unsafe (use ThreadSafeCache for concurrent access) + + Args: + max_size (int): Maximum number of cached items. Defaults to 128. + default_ttl (int): Default time-to-live in seconds. Defaults to 300. + """ + + def __init__(self, max_size: int = 128, default_ttl: int = 300): + """ + Initialize the cache. + + Args: + max_size: Maximum number of cached items (for LRU eviction). + default_ttl: Default time-to-live in seconds. + """ + self.max_size = max_size + self.default_ttl = default_ttl + self._cache: OrderedDict[str, tuple] = OrderedDict() # key -> (value, expire_time) + + def _is_expired(self, expire_time: float) -> bool: + """Check if a cache entry has expired.""" + return time.time() > expire_time + + def get(self, key: str) -> Optional[Any]: + """ + Get a cached value, automatically cleaning up expired items. + + Args: + key (str): Cache key. + + Returns: + Optional[Any]: Cached value if exists and not expired, else None. + """ + if key not in self._cache: + return None + + value, expire_time = self._cache[key] + if self._is_expired(expire_time): + del self._cache[key] + return None + + # Move to end (mark as recently used for LRU) + self._cache.move_to_end(key) + return value + + def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: + """ + Set a value in the cache. + + Args: + key (str): Cache key. + value (Any): Value to cache. + ttl (Optional[int]): Time-to-live in seconds. Uses default_ttl if None. + """ + ttl = ttl or self.default_ttl + expire_time = time.time() + ttl + + # If already exists, delete first (to avoid order issues) + if key in self._cache: + del self._cache[key] + + # Check capacity and evict the oldest item if needed + if len(self._cache) >= self.max_size: + self._cache.popitem(last=False) # Pop the oldest (FIFO for LRU) + + self._cache[key] = (value, expire_time) + + def delete(self, key: str) -> bool: + """ + Delete a specific key from the cache. + + Args: + key (str): Cache key to delete. + + Returns: + bool: True if key existed and was deleted, False otherwise. + """ + return self._cache.pop(key, None) is not None + + def clear(self) -> None: + """Clear all items from the cache.""" + self._cache.clear() + + def size(self) -> int: + """Return the current number of cached items.""" + return len(self._cache) + + def keys(self) -> list: + """ + Get all valid keys (excluding expired ones). + + Returns: + list: List of non-expired cache keys. + """ + valid_keys = [] + for k, (_, expire_time) in self._cache.items(): + if self._is_expired(expire_time): + continue + valid_keys.append(k) + # Synchronously remove expired keys + for k in [k for k in self._cache if self._is_expired(self._cache[k][1])]: + del self._cache[k] + return valid_keys + + +class ThreadSafeCache(SimpleCache): + """ + Thread-safe version of SimpleCache using RLock. + + This class wraps SimpleCache with thread-safe operations for + concurrent access scenarios. + + Args: + max_size (int): Maximum number of cached items. Defaults to 128. + default_ttl (int): Default time-to-live in seconds. Defaults to 300. + """ + def __init__(self, max_size: int = 128, default_ttl: int = 300): + super().__init__(max_size, default_ttl) + self._lock = threading.RLock() + + def get(self, key: str) -> Optional[Any]: + with self._lock: + return super().get(key) + + def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: + with self._lock: + super().set(key, value, ttl) + + def delete(self, key: str) -> bool: + with self._lock: + return super().delete(key) + + def clear(self) -> None: + with self._lock: + super().clear() diff --git a/ming_sdk/monitoring/README.md b/ming_sdk/monitoring/README.md new file mode 100644 index 0000000..92741c4 --- /dev/null +++ b/ming_sdk/monitoring/README.md @@ -0,0 +1,59 @@ +# 请求指标监控使用说明 + +## 概述 +为三个推理服务(text、speech、image)添加了请求粒度的指标监控,通过日志输出。 + +## 支持指标 +每个请求都会记录以下指标: +- **ttft**:首个token响应时间(毫秒) +- **tpot**:每个token平均生成时间(毫秒) +- **e2e_latency**:端到端总延迟(毫秒) +- **input_token_length**:输入token长度 +- **output_token_length**:输出token长度 +- **status**:请求状态(success/fail) + +## 日志格式 +``` +2024-01-15 14:30:45,123 - ming_sdk.monitoring.request_metrics - INFO - [REQUEST_METRICS] service=[speech],request_id=[f47ac10b-58cc-4372-a567-0e02b2c3d479],timestamp=[2024-01-15T06:30:45.123456],status=[success],e2e_latency_ms=[1234.56],ttft_ms=[234.78],tpot_ms=[45.67],input_token_length=[25],output_token_length=[22050],speaker=[luna] +``` + +## 使用方法 +当前支持 moe、talker、img。如需新增可参考如下代码: + +```python +#!/usr/bin/env python3 + +def demonstrate_new_usage(): + from ming_sdk.monitoring.request_metrics import metrics_speech + + # 1. 创建状态对象 + state = metrics_speech.create_state() + + # 2. 设置初始信息 + state.input_token_length = len("输入文本") + + try: + # 3. 请求处理中... + state.record_first_token() # 记录首token时间 + state.increment_output_tokens(100) # 累计token数 + + # 4. 成功完成 + state.finish("success", speaker="luna") + + except Exception as e: + # 5. 失败完成 + state.finish("fail", error=str(e), speaker="luna") + + +if __name__ == "__main__": + demonstrate_new_usage() +``` + +### 查看日志 +使用关键字过滤日志: +```bash +grep "\[REQUEST_METRICS\]" application.log +``` + +### 日志级别 +所有监控日志使用INFO级别输出。 diff --git a/ming_sdk/monitoring/__init__.py b/ming_sdk/monitoring/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ming_sdk/monitoring/request_metrics.py b/ming_sdk/monitoring/request_metrics.py new file mode 100644 index 0000000..34ea668 --- /dev/null +++ b/ming_sdk/monitoring/request_metrics.py @@ -0,0 +1,159 @@ +import logging +import time +import os +from dataclasses import dataclass, field +from datetime import datetime +from typing import Dict, Any, Optional + +REQUEST_METRICS_LOG_FILE = "/home/admin/logs/request_metrics.log" + + +def setup_request_metrics_logger(): + metrics_logger = logging.getLogger("request_metrics") + + if not metrics_logger.handlers: + log_dir = os.path.dirname(REQUEST_METRICS_LOG_FILE) + if not os.path.exists(log_dir): + try: + os.makedirs(log_dir, exist_ok=True) + except OSError as e: + print(f"Warning: Cannot create log directory {log_dir}: {e}") + return metrics_logger + + try: + from logging.handlers import RotatingFileHandler + + file_handler = RotatingFileHandler( + REQUEST_METRICS_LOG_FILE, + maxBytes=100 * 1024 * 1024, # 100MB + backupCount=5, + encoding="utf-8", + ) + + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + file_handler.setFormatter(formatter) + + metrics_logger.addHandler(file_handler) + metrics_logger.setLevel(logging.INFO) + except Exception as e: + print(f"Warning: Cannot setup file logging: {e}") + + return metrics_logger + + +# 获取独立配置的logger +logger = setup_request_metrics_logger() + + +@dataclass +class ReqState: + """请求状态统计类""" + + request_id: str + service_name: str + created_time: float = field(default_factory=time.time) + finished_time: float = 0.0 + first_token_time: float = 0.0 + last_time: float = 0.0 + input_token_length: int = 0 + output_token_length: int = 0 + finish_reason: str = "" + status: str = "success" + stream_mode: bool = False + extra_data: Dict[str, Any] = field(default_factory=dict) + is_record_input_token_length: bool = False + + def record_first_token(self): + """记录首token时间""" + if self.first_token_time == 0.0: + self.first_token_time = time.time() + + def record_input_tokens(self, count: int = 0): + """记录输入token数量""" + if not self.is_record_input_token_length: + self.input_token_length = count + self.is_record_input_token_length = True + + def increment_input_tokens(self, count: int = 0): + """增加token计数""" + self.input_token_length += count + + def increment_output_tokens(self, count: int = 0): + """增加token计数""" + self.output_token_length += count + + def finish(self, status: str = "success", **kwargs): + """完成请求统计""" + self.status = status + self.finished_time = time.time() + self.extra_data.update(kwargs) + self._log_metrics() + + def _log_metrics(self): + """打印日志""" + e2e_latency = (self.finished_time - self.created_time) * 1000 + + # 流式请求记录TTFT,非流式默认为0 + ttft = 0.0 + + if self.stream_mode and self.first_token_time: + ttft = (self.first_token_time - self.created_time) * 1000 + + # 根据流式/非流式计算不同的TPOT + if self.output_token_length > 0: + if not self.stream_mode: + # 非流式:tpot = e2e / output_token_length + tpot = e2e_latency / self.output_token_length + else: + # 流式:tpot = (e2e - ttft) / (output_token_length - 1) + # 注意:当output_token_length=1时,避免除以0 + if self.output_token_length > 1 and ttft > 0: + tpot = (e2e_latency - ttft) / (self.output_token_length - 1) + else: + tpot = 0.0 + else: + tpot = 0.0 + + log_parts = [ + f"service=[{self.service_name}]", + f"request_id=[{self.request_id}]", + f"timestamp=[{datetime.utcfromtimestamp(self.created_time).isoformat()}]", + f"status=[{self.status}]", + f"stream_mode=[{self.stream_mode}]", + f"e2e_latency_ms=[{round(e2e_latency, 2)}]", + f"ttft_ms=[{round(ttft, 2)}]", # 非流式显示0.0 + f"tpot_ms=[{round(tpot, 2)}]", + f"input_token_length=[{self.input_token_length}]", + f"output_token_length=[{self.output_token_length}]", + ] + + # 添加额外数据 + for key, value in self.extra_data.items(): + if value is not None: + log_parts.append(f"{key}=[{value}]") + + logger.info("[REQUEST_METRICS] " + ",".join(log_parts)) + + +class RequestMetrics: + def __init__(self, service_name: str): + self.service_name = service_name + + def create_state( + self, request_id: Optional[str] = "0", stream_mode: bool = False + ) -> ReqState: + """创建新的请求状态对象""" + return ReqState( + request_id=request_id, + service_name=self.service_name, + stream_mode=stream_mode, + ) + + +metrics_text = RequestMetrics("text") +metrics_image = RequestMetrics("image") +metrics_speech = RequestMetrics("speech") +metrics_tts = RequestMetrics("tts") +metrics_speech_text_audio = RequestMetrics("speech_audio") \ No newline at end of file diff --git a/ming_sdk/requirements.txt b/ming_sdk/requirements.txt new file mode 100644 index 0000000..16fad7a --- /dev/null +++ b/ming_sdk/requirements.txt @@ -0,0 +1,15 @@ +torch==2.7.1 +torchvision==0.22.1 +torchaudio==2.7.1 +flash-attn==2.8.0.post2 +diffusers==0.36.0 +tokenizers==0.22.2 +transformers==4.57.1 +decord==0.6.0 +onnxruntime==1.22.1 +inflect==7.5.0 +conformer==0.3.2 +lightning==2.5.2 +gdown==5.2.0 +openai-whisper==20240930 +numpy==1.26.4 \ No newline at end of file diff --git a/ming_sdk/setup.py b/ming_sdk/setup.py new file mode 100644 index 0000000..2707845 --- /dev/null +++ b/ming_sdk/setup.py @@ -0,0 +1,118 @@ +import os +import shutil +import os.path as osp +from setuptools import setup, find_packages + +__version__ = "1.0.0" # +requirement = open("ming_sdk/requirements.txt").readlines() + +__setup_name__ = "ming_sdk" + + +def fetch_installed_data(model_dir): + root_dir = os.path.dirname(os.path.abspath(__file__)) + root_dir = os.path.join(root_dir, model_dir) + data_files = [] + # root_dir = os.path.dirname(os.path.abspath(__file__)) + for root, dirs, files in os.walk(root_dir, topdown=False): + for f in files: + file_path = os.path.join(root, f) + # package c source & header files, and swig files + if ( + f.endswith(".cc") + or f.endswith(".h") + or f.endswith(".i") + or f.endswith(".yml") + or f.endswith(".yaml") + or f.endswith("*.md") + or f.endswith(".so") + or f.endswith(".py") + or f.endswith(".dylib") + or f.endswith(".engine") + or f.endswith(".txt") + or f.endswith(".jpg") + or f.endswith(".gz") + or f.endswith(".bk") + or f.endswith(".json") + or f.endswith(".bk") + or f.endswith(".cfg") + or f.endswith("cfg") + or f.endswith("model") + or f.endswith("moves") + or f.endswith("keyrow") + or f.endswith("tokenizer") + or f.endswith("vectors") + or f.endswith("patterns") + or f.endswith(".bin") + or f.endswith(".model") + or f.endswith(".pt") + or f.endswith(".wav") + or f.endswith(".fst") + ): + data_files.append(file_path) + return data_files + + +file_list = [ + "audio_processing_bailingmm2.py", + "bailingmm_utils.py", + "bailingmm_utils_video.py", + "chat_format.py", + "config.json", + "configuration_audio.py", + "configuration_bailingmm2.py", + "configuration_bailing_moe_v2.py", + "configuration_bailing_talker.py", + "configuration_whisper_encoder.py", + "cookbook.ipynb", + "image_processing_bailingmm2.py", + "modeling_bailingmm2.py", + "modeling_bailing_moe_v2.py", + "modeling_bailing_talker.py", + "modeling_utils.py", + "modeling_whisper_encoder.py", + "preprocessor_config.json", + "processing_bailingmm2.py", + "qwen3_moe_vit.py", + "s3bpe_tokenizer.py", + "special_tokens_map.json", + "tokenization_bailing.py", + "tokenizer_config.json", + "tokenizer.json", +] + +dir_list = [ + "data/", + "diffusion/", + "talker_tn/", + "talker_module", + "AudioVAE", + "bizgen", + "front" +] + +for i in file_list: + shutil.copy(i, "ming_sdk") + +for i in dir_list: + shutil.copytree(i, "ming_sdk/" + i, dirs_exist_ok=True) + +setup( + name="ming_sdk", + version=__version__, + author="qiaozhuo", + author_email="shoukui.xsk@antgroup.com", + url="", + classifiers=[ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: Implementation :: CPython", + "Operating System :: OS Independent", + ], + package_data={"ming_sdk": fetch_installed_data("")}, + description="Ming Multimodal sdk", + keywords="Ming Multimodal sdk", + packages=["ming_sdk"], + python_requires=">=3.9.0", +) diff --git a/ming_sdk/usage.py b/ming_sdk/usage.py new file mode 100644 index 0000000..2156543 --- /dev/null +++ b/ming_sdk/usage.py @@ -0,0 +1,254 @@ +import copy +from ming_sdk.ming_utils import ThreadSafeCache + +""" + +usage = { + "prompt_tokens": 0, + "generated_tokens": 0, + "total_tokens": 0, + "completion_tokens_details": { + "text_tokens": 0, + "audio_tokens": 0, + "image_tokens": 0, + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "image_tokens": 0, + "video_tokens": 0, + "text_tokens": 0, + }, + } + +""" + +usage_useless_value = [0, None] +extra_key = ["finish_reason"] + +def remove_zero_values(usage): + def check_useless(item): + for i in usage_useless_value: + if item == i: + return True + return False + def remove_extra_key(usage): + for i in extra_key: + if i in usage: + del usage[i] + + if usage is None: + return {} + new_usage = copy.deepcopy(usage) + nested_keys = [k for k, v in new_usage.items() if isinstance(v, dict)] + for key in nested_keys: + keys_to_remove = [k for k, v in new_usage[key].items() if check_useless(v)] + for k in keys_to_remove: + del new_usage[key][k] + if not new_usage[key]: + del new_usage[key] + + keys_to_remove = [k for k, v in new_usage.items() if check_useless(v)] + for key in keys_to_remove: + del new_usage[key] + + remove_extra_key(new_usage) + + return new_usage + + +class Usage(object): + + def __init__(self): + self.cache = ThreadSafeCache(max_size=500, default_ttl=1800) + + def create_usage(self, output): + usage = { + "prompt_tokens": 0, + "generated_tokens": 0, + "total_tokens": 0, + "completion_tokens_details": { + "text_tokens": 0, + "audio_tokens": 0, + "image_tokens": 0, + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "image_tokens": 0, + "video_tokens": 0, + "text_tokens": 0, + }, + "finish_reason": None + } + prompt_tokens = len(output.prompt_token_ids) + generated_tokens = len(output.outputs[0].token_ids) + usage["prompt_tokens"] = prompt_tokens + usage["generated_tokens"] = generated_tokens + usage["total_tokens"] = prompt_tokens + usage["generated_tokens"] + usage["finish_reason"] = output.outputs[0].finish_reason + return usage + + def create_usage_by_requests_id(self, output, request_id): + usage = { + "prompt_tokens": 0, + "generated_tokens": 0, + "total_tokens": 0, + "completion_tokens_details": { + "text_tokens": 0, + "audio_tokens": 0, + "image_tokens": 0, + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "image_tokens": 0, + "video_tokens": 0, + "text_tokens": 0, + }, + "finish_reason": None + } + prompt_tokens = len(output.prompt_token_ids) + generated_tokens = len(output.outputs[0].token_ids) + usage["prompt_tokens"] = prompt_tokens + usage["generated_tokens"] = generated_tokens + usage["total_tokens"] = prompt_tokens + usage["generated_tokens"] + usage["finish_reason"] = output.outputs[0].finish_reason + self.cache.set(f"{request_id}", usage, 1800) + return usage + + def get_stream_usage_by_request_id(self, request_id: int = 0): + usage = { + "prompt_tokens": 0, + "generated_tokens": 0, + "total_tokens": 0, + "completion_tokens_details": { + "text_tokens": 0, + "audio_tokens": 0, + "image_tokens": 0, + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "image_tokens": 0, + "video_tokens": 0, + "text_tokens": 0, + }, + "finish_reason": None + } + usage_res = self.cache.get(f"{request_id}") + if usage_res == None: + return usage + return copy.deepcopy(usage_res) + + @staticmethod + def update_audio_usage_by_duration(usage, duration): + if usage is None: + usage = { + "prompt_tokens": 0, + "generated_tokens": 0, + "total_tokens": 0, + "completion_tokens_details": { + "text_tokens": 0, + "audio_tokens": 0, + "image_tokens": 0, + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "image_tokens": 0, + "video_tokens": 0, + "text_tokens": 0, + }, + "finish_reason": None + } + audio_tokens = int(duration * 50) + text_tokens = usage["generated_tokens"] + + usage["generated_tokens"] += audio_tokens + usage["total_tokens"] += audio_tokens + usage["completion_tokens_details"]["audio_tokens"] = audio_tokens + usage["completion_tokens_details"]["text_tokens"] = text_tokens + return usage + + @staticmethod + def update_usage_by_processor(usage, text_token_count=None, image_token_count=None, video_token_count=None, audio_token_count=None): + if usage is None: + usage = { + "prompt_tokens": 0, + "generated_tokens": 0, + "total_tokens": 0, + "completion_tokens_details": { + "text_tokens": 0, + "audio_tokens": 0, + "image_tokens": 0, + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "image_tokens": 0, + "video_tokens": 0, + "text_tokens": 0, + }, + "finish_reason": None + } + input_tokens_sum = 0 + if image_token_count and isinstance(image_token_count, int): + usage["prompt_tokens_details"]["image_tokens"] = image_token_count + input_tokens_sum += image_token_count + if video_token_count and isinstance(video_token_count, int): + usage["prompt_tokens_details"]["video_tokens"] = video_token_count + input_tokens_sum += video_token_count + if audio_token_count and isinstance(audio_token_count, int): + usage["prompt_tokens_details"]["audio_tokens"] = audio_token_count + input_tokens_sum += audio_token_count + if text_token_count and isinstance(text_token_count, int): + usage["prompt_tokens_details"]["text_tokens"] = text_token_count + input_tokens_sum += text_token_count + if input_tokens_sum > 0: + usage["prompt_tokens"] = input_tokens_sum + usage["total_tokens"] = usage["prompt_tokens"] + usage["generated_tokens"] + return usage + + @staticmethod + def update_image_usage_by_length(usage, image_gen_highres): + if usage is None: + usage = { + "prompt_tokens": 0, + "generated_tokens": 0, + "total_tokens": 0, + "completion_tokens_details": { + "text_tokens": 0, + "audio_tokens": 0, + "image_tokens": 0, + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "image_tokens": 0, + "video_tokens": 0, + "text_tokens": 0, + }, + "finish_reason": None + } + image_tokens = int(image_gen_highres * image_gen_highres / 16 / 16) + text_tokens = usage["generated_tokens"] + usage["generated_tokens"] += image_tokens + usage["total_tokens"] += image_tokens + usage["completion_tokens_details"]["image_tokens"] = image_tokens + usage["completion_tokens_details"]["text_tokens"] = text_tokens + return usage + + @staticmethod + def create_usage_default(prompt_tokens=0): + usage = { + "prompt_tokens": prompt_tokens, + "generated_tokens": 0, + "total_tokens": prompt_tokens, + "completion_tokens_details": { + "text_tokens": 0, + "audio_tokens": 0, + "image_tokens": 0, + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "image_tokens": 0, + "video_tokens": 0, + "text_tokens": 0, + }, + "finish_reason": None + } + return usage diff --git a/modeling_bailing_moe.py b/modeling_bailing_moe.py new file mode 100644 index 0000000..a198065 --- /dev/null +++ b/modeling_bailing_moe.py @@ -0,0 +1,1808 @@ +# coding=utf-8 +# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch BailingMoE model.""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.init as init + +from tqdm import tqdm +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.generation import GenerationMixin +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_1_13, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available + +from configuration_bailing_moe import BailingMoeConfig + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "BailingMoeConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + warnings.warn( + "Calling `transformers.models.BailingMoe.modeling_BailingMoe._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask" + ) + return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + warnings.warn( + "Calling `transformers.models.BailingMoe.modeling_BailingMoe._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.BailingMoe.modeling_BailingMoe.AttentionMaskConverter._make_causal_mask" + ) + return AttentionMaskConverter._make_causal_mask( + input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length + ) + + +class BailingMoeRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + BailingMoeRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(BailingMoeRMSNorm) + + +class BailingMoeRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + self.max_seq_len_cached = None + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->BailingMoe +class BailingMoeLinearScalingRotaryEmbedding(BailingMoeRotaryEmbedding): + """BailingMoeRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->BailingMoe +class BailingMoeDynamicNTKScalingRotaryEmbedding(BailingMoeRotaryEmbedding): + """BailingMoeRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class BailingMoeYarnRotaryEmbedding(BailingMoeRotaryEmbedding): + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + freq_inter = 1.0 / ( + self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False) + self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False) + + +class BailingMoe3DRotaryEmbedding(BailingMoeRotaryEmbedding): + def forward(self, x, position_ids): + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + inv_freq_expand = inv_freq[None, None, :, None].expand(3, 1, -1, 1) + position_ids_expand = position_ids[:, :, None, :].float() + + with torch.autocast(device_type=x.device.type, enabled=False): + freqs = (inv_freq_expand.to(x.device).float() @ position_ids_expand.to(x.device).float()).transpose(2, 3) + assert freqs.dtype == torch.float32 + emb = torch.cat((freqs, freqs), dim=-1) + cos_cached = emb.cos() + sin_cached = emb.sin() + return (cos_cached, sin_cached) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section=[16, 24, 24], unsqueeze_dim=1): + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class BailingMoeGroupedLinear(nn.Module): + def __init__(self, num_groups: int, in_features: int, out_features: int, bias: bool = False): + super().__init__() + self.num_groups = num_groups + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.empty(num_groups, out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.empty(num_groups, out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(self.bias, -bound, bound) + + def forward(self, x): + raise NotImplementedError("`GroupedLinear` is a weight container for use with specialized kernels.") + + def extra_repr(self) -> str: + return f"num_groups={self.num_groups}, in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" + + +class BailingMoeGroupedMLP(nn.Module): + def __init__(self, config: BailingMoeConfig, intermediate_size: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size + + self.gate_proj = BailingMoeGroupedLinear(self.config.num_experts, self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = BailingMoeGroupedLinear(self.config.num_experts, self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = BailingMoeGroupedLinear(self.config.num_experts, self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + raise NotImplementedError("The MoE computation is performed in the parent `DeepseekV3MoE` module.") + + +class BailingMoeMLP(nn.Module): + def __init__(self, config: BailingMoeConfig, intermediate_size: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class BailingMoeGate(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim))) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states, sort=False): + bsz, seq_len, h = hidden_states.shape + # compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states, self.weight, None) + scores = logits.softmax(dim=-1, dtype=torch.float32) + + # select top-k experts + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1) + + # norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + topk_weight = topk_weight / denominator + + return topk_idx, topk_weight, logits + + +class BailingMoeSparseMoeBlock(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config: BailingMoeConfig): + super().__init__() + self.config = config + self.use_grouped_gemm = config.use_grouped_gemm + self.num_experts_per_tok = config.num_experts_per_tok + self._setup_experts() + self.multi_gate = config.multi_gate + if self.multi_gate: + self.image_gate = BailingMoeGate(config) + self.audio_gate = BailingMoeGate(config) + self.gate = BailingMoeGate(config) + if config.num_shared_experts is not None: + self.shared_experts = BailingMoeMLP( + config=config, intermediate_size=config.moe_intermediate_size * config.num_shared_experts + ) + + def _setup_experts(self): + if self.use_grouped_gemm: + self.experts = BailingMoeGroupedMLP(config=self.config, intermediate_size=self.config.moe_intermediate_size) + else: + self.experts = nn.ModuleList( + [ + BailingMoeMLP(config=self.config, intermediate_size=self.config.moe_intermediate_size) + for _ in range(self.config.num_experts) + ] + ) + + def create_mask(self, device, start_indices, end_indices, indices): + start_indices = torch.tensor(start_indices, device=device).view(-1, 1) + end_indices = torch.tensor(end_indices, device=device).view(-1, 1) + return (indices > start_indices) & (indices < end_indices) + + def forward( + self, + hidden_states: torch.Tensor, + image_mask: Optional[torch.Tensor] = None, + audio_mask: Optional[torch.Tensor] = None, + ): + identity = hidden_states + bsz, seq_len, h = hidden_states.shape + + if self.multi_gate: + # Get base text router results + topk_idx, topk_weight, router_logits = self.gate(hidden_states) + + # Verify mask consistency when both modalities exist + if image_mask is not None and audio_mask is not None: + assert torch.logical_and(image_mask, audio_mask).sum() == 0 + + # Process image modality + if image_mask is not None: + image_topk_idx, image_topk_weight, image_router_logits = self.image_gate(hidden_states) + image_mask = image_mask.reshape(bsz * seq_len, 1) + + topk_idx = topk_idx * ~image_mask + image_topk_idx * image_mask + topk_weight = topk_weight * ~image_mask + image_topk_weight * image_mask + router_logits = router_logits * ~image_mask + image_router_logits * image_mask + + # Process audio modality + if audio_mask is not None: + audio_topk_idx, audio_topk_weight, audio_router_logits = self.audio_gate(hidden_states) + audio_mask = audio_mask.reshape(bsz * seq_len, 1) + + topk_idx = topk_idx * ~audio_mask + audio_topk_idx * audio_mask + topk_weight = topk_weight * ~audio_mask + audio_topk_weight * audio_mask + router_logits = router_logits * ~audio_mask + audio_router_logits * audio_mask + + else: + topk_idx, topk_weight, router_logits = self.gate(hidden_states) + # hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + # flat_topk_idx = topk_idx.view(-1) + + if self.use_grouped_gemm: + import grouped_gemm.ops as ops + residuals = hidden_states + orig_shape = hidden_states.shape + flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + topk_idx = topk_idx.to(torch.int32) + batch_sizes = torch.bincount(topk_idx.flatten().cpu(), minlength=self.config.num_experts) + permuted_hidden_states, row_id_map = ops.permute(flat_hidden_states, topk_idx) + permuted_hidden_states = permuted_hidden_states.to(torch.bfloat16) + gate_out = ops.gmm(permuted_hidden_states, self.experts.gate_proj.weight, batch_sizes, trans_b=True) + up_out = ops.gmm(permuted_hidden_states, self.experts.up_proj.weight, batch_sizes, trans_b=True) + intermediate_out = self.experts.act_fn(gate_out) * up_out + expert_out = ops.gmm(intermediate_out, self.experts.down_proj.weight, batch_sizes, trans_b=True) + y = ops.unpermute(expert_out, row_id_map, topk_weight) + y = y.view(*orig_shape) + else: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.training: + hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0) + y = torch.empty_like(hidden_states) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.to(hidden_states.dtype).view(bsz, seq_len, h) + else: + y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(bsz, seq_len, h) + + if self.config.num_shared_experts is not None: + y = y + self.shared_experts(identity) + return y, (router_logits.view(bsz, seq_len, -1), topk_idx.view(bsz, seq_len, -1)) + + @torch.no_grad() + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + sorted_tokens_shape = sorted_tokens.shape + tokens_per_expert = tokens_per_expert.cpu().numpy() + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->BailingMoe +class BailingMoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: BailingMoeConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim or self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + self.query_key_value = nn.Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, + bias=config.use_qkv_bias, + ) + self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = BailingMoeRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = BailingMoeLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = BailingMoeDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = BailingMoeYarnRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + elif scaling_type == "3D": + self.rotary_emb = BailingMoe3DRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + bsz, q_len, _ = hidden_states.size() + + qkv = self.query_key_value(hidden_states) + qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) + + query_states, key_states, value_states = qkv.split( + [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2 + ) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + if self.config.rope_scaling is not None and self.config.rope_scaling["type"] == "3D": + cos, sin = self.rotary_emb(value_states, position_ids=position_ids) + query_states, key_states = apply_multimodal_rotary_pos_emb(query_states, key_states, cos, sin) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states / math.sqrt(self.head_dim), key_states.transpose(2, 3)) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.dense(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->BailingMoe +class BailingMoeFlashAttention2(BailingMoeAttention): + """ + BailingMoe flash attention module. This module inherits from `BailingMoeAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # BailingMoeFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + + qkv = self.query_key_value(hidden_states) + qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) + + query_states, key_states, value_states = qkv.split( + [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2 + ) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + if self.config.rope_scaling is not None and self.config.rope_scaling["type"] == "3D": + cos, sin = self.rotary_emb(value_states, position_ids=position_ids) + query_states, key_states = apply_multimodal_rotary_pos_emb(query_states, key_states, cos, sin) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently cast in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slow down training & inference so it is recommended to not cast the LayerNorms + # in fp32. (BailingMoeRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.dense(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + query_length (`int`): + The length of the query sequence in terms of tokens. This represents the number of tokens in the + `query_states` tensor along the sequence dimension. It is used to determine the effective sequence + length for attention computations. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in BailingMoeFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->BailingMoe +class BailingMoeSdpaAttention(BailingMoeAttention): + """ + BailingMoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `BailingMoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from BailingMoeAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "BailingMoeModel is using BailingMoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + qkv = self.query_key_value(hidden_states) + qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) + + query_states, key_states, value_states = qkv.split( + [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2 + ) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + if self.config.rope_scaling is not None and self.config.rope_scaling["type"] == "3D": + cos, sin = self.rotary_emb(value_states, position_ids=position_ids) + query_states, key_states = apply_multimodal_rotary_pos_emb(query_states, key_states, cos, sin) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.dense(attn_output) + + return attn_output, None, past_key_value + + +BAILING_MOE_ATTENTION_CLASSES = { + "eager": BailingMoeAttention, + "flash_attention_2": BailingMoeFlashAttention2, + "sdpa": BailingMoeSdpaAttention, +} + + +class BailingMoeDecoderLayer(nn.Module): + def __init__(self, config: BailingMoeConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attention = BAILING_MOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = ( + BailingMoeSparseMoeBlock(config) + if (config.num_experts is not None and layer_idx >= config.first_k_dense_replace) + else BailingMoeMLP(config=config, intermediate_size=config.intermediate_size) + ) + self.input_layernorm = BailingMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = BailingMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + image_mask: Optional[torch.Tensor] = None, + audio_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): + cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, image_mask, audio_mask) + if isinstance(hidden_states, tuple): + hidden_states, router_logits = hidden_states + else: + router_logits = None + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +BAILINGMOE_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BailingMoeConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare BailingMoe Model outputting raw hidden-states without any specific head on top.", + BAILINGMOE_START_DOCSTRING, +) +class BailingMoePreTrainedModel(PreTrainedModel): + config_class = BailingMoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BailingMoeDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def __init__(self, config: BailingMoeConfig): + super().__init__(config) + + +BAILINGMOE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare BailingMoe Model outputting raw hidden-states without any specific head on top.", + BAILINGMOE_START_DOCSTRING, +) +class BailingMoeModel(BailingMoePreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BailingMoeDecoderLayer`] + + Args: + config: BailingMoeConfig + """ + + def __init__(self, config: BailingMoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [BailingMoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._use_sdpa = config._attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = BailingMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, value): + self.word_embeddings = value + + def _fuse_experts(self): + for layer in tqdm(self.layers, desc="Fusing experts"): + if isinstance(layer.mlp, BailingMoeSparseMoeBlock) and not layer.mlp.use_grouped_gemm: + grouped_experts = BailingMoeGroupedMLP(config=layer.mlp.config, intermediate_size=layer.mlp.config.moe_intermediate_size) + + gate_weights = torch.stack([expert.gate_proj.weight for expert in layer.mlp.experts]) + grouped_experts.gate_proj.weight.data = gate_weights + del gate_weights + + up_weights = torch.stack([expert.up_proj.weight for expert in layer.mlp.experts]) + grouped_experts.up_proj.weight.data = up_weights + del up_weights + + down_weights = torch.stack([expert.down_proj.weight for expert in layer.mlp.experts]) + grouped_experts.down_proj.weight.data = down_weights + del down_weights + + layer.mlp.experts = grouped_experts + layer.mlp.use_grouped_gemm = True + + def _unfuse_experts(self): + for layer in tqdm(self.layers, desc="Unfusing experts"): + if isinstance(layer.mlp, BailingMoeSparseMoeBlock) and layer.mlp.use_grouped_gemm: + grouped_experts = layer.mlp.experts + gate_weights = grouped_experts.gate_proj.weight.data + up_weights = grouped_experts.up_proj.weight.data + down_weights = grouped_experts.down_proj.weight.data + + config = layer.mlp.config + num_experts = config.num_experts + experts = nn.ModuleList( + [BailingMoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(num_experts)] + ) + + for i in range(num_experts): + experts[i].gate_proj.weight.data = gate_weights[i] + experts[i].up_proj.weight.data = up_weights[i] + experts[i].down_proj.weight.data = down_weights[i] + + layer.mlp.experts = experts + layer.mlp.use_grouped_gemm = False + + @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + image_mask: Optional[torch.Tensor] = None, + audio_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." + ) + use_cache = False + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + image_mask, + audio_mask, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + image_mask=image_mask, + audio_mask=audio_mask, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits and layer_outputs[-1] is not None: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +class BailingMoeForCausalLM(BailingMoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: BailingMoeConfig): + super().__init__(config) + self.model = BailingMoeModel(config) + self.vocab_size = config.vocab_size + self.norm_head = config.norm_head + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.word_embeddings + + def set_input_embeddings(self, value): + self.model.word_embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def compute_logit(self, hidden_states): + if self.norm_head: + if self.training: + norm_weight = ( + self.lm_head.weight / (torch.norm(self.lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7).detach() + ) + logits = F.linear(hidden_states, norm_weight, None) + else: + self.lm_head.weight.data = ( + self.lm_head.weight.data.float() + / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7) + ).to(hidden_states.dtype) + logits = F.linear(hidden_states, self.lm_head.weight.data, None) + self.norm_head = False + else: + logits = self.lm_head(hidden_states) + return logits + + def fuse_experts(self): + import importlib.util + + if importlib.util.find_spec("grouped_gemm") is None: + raise ImportError( + "Please install grouped_gemm to use use_grouped_gemm=True. " + "You can install it with `pip install git+https://github.com/fanshiqing/grouped_gemm@main`" + ) + self.model._fuse_experts() + + def unfuse_experts(self): + self.model._unfuse_experts() + + @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + image_mask: Optional[torch.Tensor] = None, + audio_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer + + >>> model = BailingMoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + image_mask=image_mask, + audio_mask=audio_mask, + **kwargs, + ) + + hidden_states = outputs[0] + + logits = self.compute_logit(hidden_states=hidden_states) + logits = logits.float() + + loss = None + aux_loss = None + + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + position_ids=None, + inputs_embeds=None, + cache_position=None, + image_mask=None, + audio_mask=None, + rope_deltas=None, + **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = ( + past_key_values.get_max_length() + if hasattr(past_key_values, "get_max_length") + else past_key_values.get_max_cache_shape() + ) + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + if inputs_embeds is not None: + input_ids = input_ids[:, -cache_position.shape[0]:] + elif input_ids.shape[1] != cache_position.shape[0]: + input_ids = input_ids[:, cache_position] + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: + model_inputs = {"inputs_embeds": inputs_embeds} + if rope_deltas is not None: + self.rope_deltas = rope_deltas + else: + model_inputs = {"input_ids": input_ids} + image_mask = None + audio_mask = None + if rope_deltas is not None: + batch_size, seq_length = input_ids.shape + if past_key_values and self.rope_deltas: + delta = past_key_values[0][1].shape[2] + self.rope_deltas + elif past_key_values: + delta = torch.tensor(past_key_values[0][1].shape[2]).to(input_ids.device) + else: + delta = torch.tensor(0).to(input_ids.device) + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "image_mask": image_mask, + "audio_mask": audio_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past \ No newline at end of file diff --git a/test_infer_imagegen_taylor.py b/test_infer_imagegen_taylor.py new file mode 100644 index 0000000..e9b89c3 --- /dev/null +++ b/test_infer_imagegen_taylor.py @@ -0,0 +1,206 @@ +import os +import torch +import time +import numpy as np +from bisect import bisect_left + +from tqdm import tqdm + +from transformers import ( + AutoProcessor, +) + +from modeling_bailingmm2 import BailingMM2NativeForConditionalGeneration + +import warnings + +warnings.filterwarnings("ignore") + +from IPython import embed + +import json +from PIL import Image + + +# def split_model(): +# device_map = {} +# world_size = torch.cuda.device_count() +# num_layers = 32 +# layer_per_gpu = num_layers // world_size +# layer_per_gpu = [i * layer_per_gpu - 1 for i in range(1, world_size + 1)] +# for i in range(num_layers): +# device_map[f'model.model.layers.{i}'] = bisect_left(layer_per_gpu, i) + +# device_map['vision'] = 0 +# device_map['audio'] = 0 +# device_map['linear_proj'] = 0 +# device_map['linear_proj_audio'] = 0 +# device_map['model.model.word_embeddings.weight'] = 0 +# device_map['model.model.norm.weight'] = 0 +# device_map['model.lm_head.weight'] = 0 +# device_map['model.model.norm'] = 0 +# device_map[f'model.model.layers.{num_layers - 1}'] = 0 +# return device_map + +def split_model(): + device_map = {} + world_size = torch.cuda.device_count() - 1 + print(world_size) + num_layers = 32 + layer_per_gpu = num_layers // world_size + layer_per_gpu = [i * layer_per_gpu - 1 for i in range(1, world_size + 1)] + for i in range(num_layers): + device_id = bisect_left(layer_per_gpu, i) + 1 + #print(device_id) + if device_id > world_size: + device_id = i % world_size + 1 + + print(device_id) + + device_map[f'model.model.layers.{i}'] = device_id + + device_map['vision'] = 0 + device_map['audio'] = 0 + device_map['linear_proj'] = 0 + device_map['linear_proj_audio'] = 0 + device_map['model.model.word_embeddings.weight'] = 0 + device_map['model.model.norm.weight'] = 0 + device_map['model.lm_head.weight'] = 0 + device_map['model.model.norm'] = 0 + device_map[f'model.model.layers.{num_layers - 1}'] = 0 + return device_map + + +if __name__ == '__main__': + + model_name_or_path = "/input/yushen.ys/checkpoints/bailing_native_moe_ming_flash_v2.0_xpo_final_20260205" + #"/nativemm/share/cpfs/yuxuzheng.yxz/release/bailing_native_moe_ming_flash_v2.0_xpo_final_20260205_hf_metax_ais16863699" + #"/nativemm/share/cpfs/weilong.cwl/checkpoints/megatron_flashv2.0_sft1_hf_metax/" #"." + code_path = "." + processor = AutoProcessor.from_pretrained(code_path, trust_remote_code=True) + save_dir = "./generated_imgs" + if not os.path.exists(save_dir): + os.mkdir(save_dir) + + + model = BailingMM2NativeForConditionalGeneration.from_pretrained( + model_name_or_path, + dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map=split_model(), + load_image_gen=True, + ).to(dtype=torch.bfloat16) + + model.diffusion_loss.pipelines.set_taylor_cache() + + prompt = "Draw a beautiful girl with short black hair and red dress." + messages = [ + { + "role": "HUMAN", + "content": [ + {"type": "text", "text": prompt}, + ], + } + ] + + text = processor.apply_chat_template(messages, add_generation_prompt=True) + image_inputs, _, _ = processor.process_vision_info(messages) + + inputs = processor( + text=[text], + images=image_inputs, + return_tensors="pt", + ).to(model.device) + + for k in inputs.keys(): + if k in ["pixel_values", "pixel_values_videos", "audio_feats", "pixel_values_reference"]: + inputs[k] = inputs[k].to(dtype=torch.bfloat16) + + + print(f"Instruction: {prompt}") + # set `image_gen=True` to enable image generation + image = model.generate( + **inputs, + image_gen=True, + image_gen_seed=42, + ) + save_path = os.path.join(save_dir, "./t2i_girl_taylor_cache_18.jpg") + image.save(save_path) + print(f"saved to {save_path}") + + + prompt = "背景换成沙滩, 动作是拿手机自拍." + messages = [ + { + "role": "HUMAN", + "content": [ + {"type": "image", "image": save_path}, + {"type": "text", "text": prompt}, + ], + } + ] + + text = processor.apply_chat_template(messages, add_generation_prompt=True) + image_inputs, _, _ = processor.process_vision_info(messages) + + ref_image_inputs = processor.process_reference_vision_info(messages) + inputs = processor( + text=[text], + images=image_inputs, + return_tensors="pt", + image_gen_ref_images=ref_image_inputs, + ) + + inputs = inputs.to(model.device) + + for k in inputs.keys(): + if k in ["pixel_values", "pixel_values_videos", "audio_feats", "pixel_values_reference"]: + inputs[k] = inputs[k].to(dtype=torch.bfloat16) + + print(f"Instruction: {prompt}; Input image: {save_path}") + # set `image_gen=True` to enable image generation + image = model.generate( + **inputs, + image_gen=True, + image_gen_seed=43, + ) + save_path = os.path.join(save_dir, "./edit_girl_taylor_cache_18.jpg") + image.save(save_path) + print(f"saved to {save_path}") + + + prompt = "A whimsical comic-style illustration of a cozy bookstore entrance on a sunny afternoon. The storefront features warm brick walls and large glass windows filled with stacked books and potted ferns. Above the wooden door hangs a hand-painted signboard with bold, stylized Chinese characters reading “理解与生成统一” accented with curling vines and tiny stars. Sunlight casts playful shadows on the cobblestone path leading to the door, where a vintage lantern in a sunbeam add charm. The linework is clean, colors vibrant yet soft, evoking a friendly, storybook atmosphere. No people or vehicles are present, emphasizing quiet serenity." + messages = [ + { + "role": "HUMAN", + "content": [ + {"type": "text", "text": prompt}, + ], + } + ] + + text = processor.apply_chat_template(messages, add_generation_prompt=True) + image_inputs, _, _ = processor.process_vision_info(messages) + + inputs = processor( + text=[text], + images=image_inputs, + return_tensors="pt", + ).to(model.device) + + for k in inputs.keys(): + if k in ["pixel_values", "pixel_values_videos", "audio_feats", "pixel_values_reference"]: + inputs[k] = inputs[k].to(dtype=torch.bfloat16) + + + print(f"Instruction: {prompt}") + # set `image_gen=True` to enable image generation + image = model.generate( + **inputs, + image_gen=True, + image_gen_seed=42, + ) + save_path = os.path.join(save_dir, "./t2i_text_taylor_cache_18.jpg") + image.save(save_path) + print(f"saved to {save_path}") + \ No newline at end of file diff --git a/tokenization_bailing.py b/tokenization_bailing.py index 4bfbb1d..b52f306 100644 --- a/tokenization_bailing.py +++ b/tokenization_bailing.py @@ -240,6 +240,7 @@ def apply_chat_template( return_dict=return_dict, return_assistant_tokens_mask=return_assistant_tokens_mask, tokenizer_kwargs=tokenizer_kwargs, + **kwargs, ) # 非chat_template方式后续将不再支持。 diff --git a/tokenizer_config.json b/tokenizer_config.json index 57aeb74..26fd34e 100644 --- a/tokenizer_config.json +++ b/tokenizer_config.json @@ -2321,7 +2321,7 @@ "" ], "bos_token": "<|startoftext|>", - "chat_template": "{% set thinking_option = 'off' %}\n{{- 'SYSTEM' }}\n{%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n{%- endif %}\n{%- if tools %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n\\n\" }}\n{%- endif %}\n{{- 'detailed thinking ' + thinking_option + '<|role_end|>' }}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if message.role == \"user\" %}\n {{- 'HUMAN' + message.content + '<|role_end|>' }}\n {%- elif message.role == \"system\" and not loop.first %}\n {{- 'SYSTEM' + message.content + '<|role_end|>' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in content %}\n {%- set reasoning_content = content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- set content = content.split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if reasoning_content %}\n {{- 'ASSISTANT' + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- 'ASSISTANT' + content }}\n {%- endif %}\n {%- else %}\n {{- 'ASSISTANT' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|role_end|>' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- 'OBSERVATION' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|role_end|>' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- 'ASSISTANT' }}\n{%- endif %}", + "chat_template": "{{- 'SYSTEM' }}\n{%- if messages and messages[0].role == 'system' and messages[0].content is string %}\n {{- messages[0].content }}\n{%- else %}\n {{- '你是一个友好的AI助手。' }}\n{%- endif %}\n\n{{- '\n\ndetailed thinking ' }}\n{{- 'on' if enable_thinking else 'off' }}\n{{- '<|role_end|>' }}\n\n{%- for message in messages %}\n {%- set role_lower = message.role | lower %}\n \n {%- if role_lower in ['user', 'human'] -%}\n HUMAN\n {%- set contents = message.content if message.content is iterable and message.content is not string else [message.content] %}\n {%- for part in contents %}\n {%- if part.type is defined %}\n {%- if part.type == 'image' or 'image_url' in part -%}\n \n {%- elif part.type == 'video' or 'video_url' in part -%}\n