From fdb1b485651199715d4dac835c3955e7fb433a23 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Wed, 25 Mar 2026 11:18:18 +0900 Subject: [PATCH 01/16] EXAONE-4.5 support Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- tensorrt_llm/_torch/models/__init__.py | 2 + .../checkpoints/hf/exaone4_5_weight_mapper.py | 26 ++ .../_torch/models/modeling_exaone4_5.py | 224 ++++++++++++++++++ .../_torch/models/modeling_qwen2vl.py | 30 ++- 4 files changed, 273 insertions(+), 9 deletions(-) create mode 100644 tensorrt_llm/_torch/models/checkpoints/hf/exaone4_5_weight_mapper.py create mode 100644 tensorrt_llm/_torch/models/modeling_exaone4_5.py diff --git a/tensorrt_llm/_torch/models/__init__.py b/tensorrt_llm/_torch/models/__init__.py index aabd0d50483..157cd56e605 100644 --- a/tensorrt_llm/_torch/models/__init__.py +++ b/tensorrt_llm/_torch/models/__init__.py @@ -6,6 +6,7 @@ from .modeling_cohere2 import Cohere2ForCausalLM from .modeling_deepseekv3 import DeepseekV3ForCausalLM from .modeling_exaone4 import Exaone4ForCausalLM +from .modeling_exaone4_5 import Exaone4_5_ForConditionalGeneration from .modeling_exaone_moe import ExaoneMoeForCausalLM from .modeling_gemma3 import Gemma3ForCausalLM from .modeling_gemma3vl import Gemma3VLM @@ -48,6 +49,7 @@ "CLIPVisionModel", "DeepseekV3ForCausalLM", "Exaone4ForCausalLM", + "Exaone4_5_ForConditionalGeneration", "ExaoneMoeForCausalLM", "Gemma3ForCausalLM", "Gemma3VLM", diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/exaone4_5_weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/hf/exaone4_5_weight_mapper.py new file mode 100644 index 00000000000..1bbbd8909b4 --- /dev/null +++ b/tensorrt_llm/_torch/models/checkpoints/hf/exaone4_5_weight_mapper.py @@ -0,0 +1,26 @@ +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import ConsumableWeightsDict +from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper +from tensorrt_llm._torch.models.modeling_utils import register_mapper + + +@register_mapper("HF", "Exaone4_5_ForConditionalGeneration") +class Exaone4_5HfWeightMapper(HfWeightMapper): + def __init__(self): + super().__init__() + + def preprocess_weights(self, weights: dict): + """Rename HF checkpoint prefixes; supports plain dict and ConsumableWeightsDict.""" + is_consumable = isinstance(weights, ConsumableWeightsDict) + renamed = {} + for key, value in weights.items(): + if key.startswith("model.visual."): + new_key = key.replace("model.visual.", "visual.") + renamed[new_key] = value + elif key.startswith("model.language_model."): + new_key = key.replace("model.language_model.", "model.") + renamed[new_key] = value + else: + renamed[key] = value + if is_consumable: + return ConsumableWeightsDict(renamed) + return renamed diff --git a/tensorrt_llm/_torch/models/modeling_exaone4_5.py b/tensorrt_llm/_torch/models/modeling_exaone4_5.py new file mode 100644 index 00000000000..c778ca0528d --- /dev/null +++ b/tensorrt_llm/_torch/models/modeling_exaone4_5.py @@ -0,0 +1,224 @@ +# Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import copy +from typing import List, Optional, Tuple + +import torch +from transformers import AutoTokenizer, PretrainedConfig, PreTrainedModel + +from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper +from tensorrt_llm._torch.models.modeling_multimodal_utils import _is_disagg + +from ...inputs import ( + ExtraProcessedInputs, + MultimodalPlaceholderMetadata, + MultimodalPlaceholderPlacement, + TextPrompt, + register_input_processor, +) +from ...sampling_params import SamplingParams +from ..attention_backend import AttentionMetadata +from .checkpoints.hf.exaone4_5_weight_mapper import Exaone4_5HfWeightMapper +from .modeling_auto import AutoModelForCausalLM +from .modeling_multimodal_utils import ( + find_input_mm_embeds, + fuse_input_embeds, + get_multimodal_embeddings, +) +from .modeling_qwen2vl import ( + Qwen2_5_VisionModel, + Qwen2VisionModelBase, + Qwen2VLInputProcessorBase, + Qwen2VLModelBase, +) +from .modeling_utils import ModelConfig, register_auto_model, register_vision_encoder + + +class Exaone4_5InputProcessor(Qwen2VLInputProcessorBase): + def __init__( + self, + model_path: str, + config: PretrainedConfig, + tokenizer: AutoTokenizer, + trust_remote_code: bool = True, + **kwargs, + ): + super().__init__( + model_path=model_path, + config=config, + tokenizer=tokenizer, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + @torch.inference_mode() + def __call__( + self, + inputs: TextPrompt, + sampling_params: SamplingParams, + ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + text_prompt, mm_data, mm_processor_kwargs = ( + inputs.get("prompt"), + inputs.get("multi_modal_data", {}), + inputs.get("mm_processor_kwargs", {}), + ) + processed_inputs = self._preprocess(text_prompt, mm_data, mm_processor_kwargs) + + multimodal_data = {} + pixel_values = processed_inputs.get("pixel_values", None) + if pixel_values is not None: + multimodal_data["image"] = { + "pixel_values": pixel_values.to(self.dtype), + "image_grid_thw": processed_inputs.get("image_grid_thw"), + } + + pixel_values_videos = processed_inputs.get("pixel_values_videos", None) + if pixel_values_videos is not None: + multimodal_data["video"] = { + "pixel_values_videos": pixel_values_videos.to(self.dtype), + "video_grid_thw": processed_inputs.get("video_grid_thw"), + } + fused_input_ids = processed_inputs["input_ids"][0] + if mm_data: + fused_input_ids = self._postprocess(fused_input_ids) + + return fused_input_ids.to(torch.int32).tolist(), { + "multimodal_data": multimodal_data, + } + + +class Exaone4_5_VisionModel(Qwen2VisionModelBase): + def __init__( + self, model_config: ModelConfig[PretrainedConfig], model_class: type[Qwen2_5_VisionModel] + ): + super().__init__(model_config, model_class=model_class) + self.config.tie_word_embeddings = False + + +class Exaone4_5_VLModel(Qwen2VLModelBase): + def __init__( + self, + model_config: ModelConfig[PretrainedConfig], + *args, + **kwargs, + ) -> None: + self.original_arch = model_config.pretrained_config.architectures[0] + + # model_config.pretrained_config.rope_scaling['type'] = 
'mrope' + config = model_config.pretrained_config + + self._supports_sdpa = True + PreTrainedModel.__init__(self, config) + + self.model_config = model_config + self.config = model_config.pretrained_config + + if model_config.attn_backend != "TRTLLM": + raise ValueError("Exaone4.5 only supports TRTLLM backend") + + llm_model_config = copy.deepcopy(model_config) + # llm_model_config.pretrained_config.architectures = ["Exaone4ForCausalLM"] + llm_model_config.pretrained_config = llm_model_config.pretrained_config.text_config + self.llm = AutoModelForCausalLM.from_config(llm_model_config) + + if not _is_disagg(): + mm_encoder_config = copy.deepcopy(model_config) + self.mm_encoder = Exaone4_5_VisionModel( + mm_encoder_config, + kwargs.get("vision_model_class", Qwen2_5_VisionModel), + ) + else: + self.mm_encoder = None + + def infer_max_seq_len(self) -> int: + return self.llm.infer_max_seq_len() + + @torch.inference_mode() + def forward( + self, + attn_metadata: AttentionMetadata, + input_ids: Optional[torch.IntTensor] = None, + position_ids: Optional[torch.IntTensor] = None, + input_embeds: Optional[torch.Tensor] = None, + return_context_logits: bool = False, + **kwargs, + ) -> torch.Tensor: + multimodal_params = kwargs.get("multimodal_params", []) + mm_embeds = [] + + if len(multimodal_params) > 0: + if not _is_disagg(): + mm_embeds = get_multimodal_embeddings( + encoder_forward_fn=self.mm_encoder.forward, + multimodal_params=multimodal_params, + ) + else: + raise NotImplementedError( + "Exaone4.5-VL does not support disaggregated inference yet. " + "Unset TLLM_MULTIMODAL_DISAGGREGATED or set it to '0'." + ) + mm_embeds = find_input_mm_embeds(mm_embeds, multimodal_params) + + input_ids, input_embeds = fuse_input_embeds( + self.llm.model.embed_tokens, + input_ids, + mm_embeds, + **kwargs, + ) + + output_prob = self.llm.forward( + attn_metadata=attn_metadata, + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=input_embeds, + return_context_logits=return_context_logits, + ) + return output_prob + + +@register_vision_encoder(Exaone4_5_VLModel, vlm_base_model=Qwen2_5_VisionModel) +@register_auto_model("Exaone4_5_ForConditionalGeneration") +@register_input_processor( + Exaone4_5InputProcessor, + model_type="exaone4_5", + placeholder_metadata=MultimodalPlaceholderMetadata( + placeholder_map={ + "image": "<|vision_start|><|image_pad|><|vision_end|>", + "video": "<|vision_start|><|video_pad|><|vision_end|>", + }, + placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT, + ), +) +class Exaone4_5_ForConditionalGeneration(Exaone4_5_VLModel): + def __init__(self, model_config: ModelConfig[PretrainedConfig], *args, **kwargs): + kwargs["vision_model_class"] = Qwen2_5_VisionModel + super().__init__(model_config, *args, **kwargs) + + @property + def multimodal_data_device_paths(self) -> List[str]: + return [ + "image.pixel_values", + "video.pixel_values_videos", + "multimodal_embedding", + ] + + def load_weights(self, weights, weight_mapper: BaseWeightMapper): + assert isinstance(weight_mapper, Exaone4_5HfWeightMapper) + weights = weight_mapper.preprocess_weights(weights) + if not _is_disagg(): + self.mm_encoder.load_weights(weights) + self.llm.load_weights(weights, weight_mapper) diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index 453f28cac41..951cd95e6ca 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -79,6 +79,10 @@ def __init__(self, 
self.temporal_patch_size = getattr(self.config.vision_config, 'temporal_patch_size', 1) + def get_vocab_size(self) -> int: + """Return the vocab size of the model.""" + return self.config.text_config.vocab_size + @property def config(self) -> PretrainedConfig: return self._config @@ -511,7 +515,8 @@ def __init__(self, super().__init__( hidden_size=config.hidden_size, num_attention_heads=config.num_heads, - num_key_value_heads=config.num_heads, + num_key_value_heads=config.num_key_value_heads + if config.num_key_value_heads is not None else config.num_heads, max_position_embeddings=model_config.pretrained_config. max_position_embeddings, bias=True, @@ -523,6 +528,19 @@ def __init__(self, reduce_output=reduce_output, ) + def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + position_embeddings: Tuple[torch.Tensor, torch.Tensor]): + seq_len, _ = q.size() + q = q.view(seq_len, -1, self.head_dim) + k = k.view(seq_len, -1, self.head_dim) + v = v.view(seq_len, -1, self.head_dim) + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + q, k, v = q.reshape(seq_len, -1), k.reshape(seq_len, + -1), v.reshape(seq_len, -1) + return q, k, v + def forward( self, hidden_states: torch.Tensor, @@ -538,16 +556,10 @@ def forward( qkv = self.qkv_proj(hidden_states) q, k, v = qkv, None, None q, k, v = self.split_qkv(q, k, v) - seq_length = hidden_states.shape[0] - q, k, v = (qkv.reshape(seq_length, 3, self.num_heads, - -1).permute(1, 0, 2, 3).unbind(0)) - cos, sin = position_embeddings - q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) - q, k, v = q.reshape(seq_length, - -1), k.reshape(seq_length, - -1), v.reshape(seq_length, -1) + q, k, v = self.apply_rope(q, k, v, position_embeddings) q, k, v = self.convert_qkv(q, k, v) + output = self.forward_impl(q=q, k=k, v=v, From a8ec75a48c8570827d58091917dca4a5b413bc58 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Wed, 25 Mar 2026 11:19:20 +0900 Subject: [PATCH 02/16] address model_type issue Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- tensorrt_llm/serve/chat_utils.py | 12 ++++++------ tensorrt_llm/serve/openai_server.py | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py index c4982407df9..6edfbc7dfe0 100644 --- a/tensorrt_llm/serve/chat_utils.py +++ b/tensorrt_llm/serve/chat_utils.py @@ -286,8 +286,8 @@ def parse_chat_messages_coroutines( """Parse multiple chat messages and return conversation and coroutine.""" conversation = [] mm_placeholder_counts = [] - mm_data_tracker = MultimodalDataTracker(model_config.model_type, - multimodal_server_config) + mm_data_tracker = MultimodalDataTracker( + type(model_config).model_type, multimodal_server_config) # Determine content format to decide placeholder strategy. # @@ -304,7 +304,7 @@ def parse_chat_messages_coroutines( # See also: `_resolve_content_format` (inputs/utils.py) for the full resolution used downstream. model_type = model_config.model_type registry_format = MULTIMODAL_PLACEHOLDER_REGISTRY.get_content_format( - model_type) + type(model_config).model_type) if registry_format is not None: content_format = registry_format else: @@ -333,14 +333,14 @@ def parse_chat_messages_coroutines( # prepend/append according to placeholder_placement. 
content_parts = parsed_msg.get("content_parts") interleave = MULTIMODAL_PLACEHOLDER_REGISTRY.get_interleave_placeholders( - model_type) + type(model_config).model_type) if content_parts and interleave: parsed_msg["content"] = interleave_mm_placeholders( - model_type, content_parts, msg_placeholder_counts, + type(model_config).model_type, content_parts, msg_placeholder_counts, mm_data_tracker.placeholder_modalities()) else: parsed_msg["content"] = add_multimodal_placeholders( - model_type, parsed_msg["content"], msg_placeholder_counts) + type(model_config).model_type, parsed_msg["content"], msg_placeholder_counts) mm_placeholder_counts.append(msg_placeholder_counts) return conversation, mm_data_tracker.retrieve_all_async( diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index bebf897bf3d..9349acbbb48 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -358,13 +358,13 @@ def _init_llm(self, chat_template: Optional[str] = None): if disable_harmony or self.model_config is None: self.use_harmony = False else: - self.use_harmony = (self.model_config.model_type == "gpt_oss") + self.use_harmony = (type(self.model_config).model_type == "gpt_oss") self.tool_call_id_type = "random" # default tool call id type is random if self.model_config is not None: - if self.model_config.model_type == "kimi_k2": + if type(self.model_config).model_type == "kimi_k2": self.tool_call_id_type = "kimi_k2" - elif self.model_config.model_type == "deepseek_v32": + elif type(self.model_config).model_type == "deepseek_v32": self.tool_call_id_type = "deepseek_v32" if self.generator.args.return_perf_metrics: @@ -958,7 +958,7 @@ async def chat_stream_generator( prompt = request.prompt_token_ids else: prompt: str = apply_chat_template( - model_type=self.model_config.model_type, + model_type=type(self.model_config).model_type, tokenizer=self.tokenizer, processor=self.processor, conversation=conversation, @@ -1098,7 +1098,7 @@ async def create_mm_embedding_response(promise: RequestOutput): prompt = request.prompt_token_ids else: prompt: str = apply_chat_template( - model_type=self.model_config.model_type, + model_type=type(self.model_config).model_type, tokenizer=self.tokenizer, processor=self.processor, conversation=conversation, From 304b61369ac9d7e1c9fa631789555b1c4eeaac4a Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Wed, 25 Mar 2026 17:36:24 +0900 Subject: [PATCH 03/16] transformer 5.0.0 compatibility Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- .../checkpoints/hf/exaone4_5_weight_mapper.py | 14 +++++++-- .../_torch/models/modeling_exaone4_5.py | 29 ++++++++++++++++--- .../_torch/models/modeling_qwen2vl.py | 8 +++-- 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/exaone4_5_weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/hf/exaone4_5_weight_mapper.py index 1bbbd8909b4..9b418c577c7 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/exaone4_5_weight_mapper.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/exaone4_5_weight_mapper.py @@ -1,12 +1,20 @@ +from typing import Union + +from torch import nn + +from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.checkpoints.base_weight_loader import ConsumableWeightsDict from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper -from tensorrt_llm._torch.models.modeling_utils import 
register_mapper +from tensorrt_llm._torch.models.modeling_utils import DecoderModelForCausalLM, register_mapper @register_mapper("HF", "Exaone4_5_ForConditionalGeneration") class Exaone4_5HfWeightMapper(HfWeightMapper): - def __init__(self): - super().__init__() + def init_model_and_config( + self, model: Union[nn.Module, DecoderModelForCausalLM], config: ModelConfig + ): + super().init_model_and_config(model, config) + self.model.config.tie_word_embeddings = False def preprocess_weights(self, weights: dict): """Rename HF checkpoint prefixes; supports plain dict and ConsumableWeightsDict.""" diff --git a/tensorrt_llm/_torch/models/modeling_exaone4_5.py b/tensorrt_llm/_torch/models/modeling_exaone4_5.py index c778ca0528d..8ad4088f6d3 100644 --- a/tensorrt_llm/_torch/models/modeling_exaone4_5.py +++ b/tensorrt_llm/_torch/models/modeling_exaone4_5.py @@ -15,10 +15,10 @@ # SPDX-License-Identifier: Apache-2.0 import copy -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch -from transformers import AutoTokenizer, PretrainedConfig, PreTrainedModel +from transformers import AutoConfig, AutoTokenizer, PretrainedConfig, PreTrainedModel from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper from tensorrt_llm._torch.models.modeling_multimodal_utils import _is_disagg @@ -48,6 +48,27 @@ from .modeling_utils import ModelConfig, register_auto_model, register_vision_encoder +class Exaone4_5Config(PretrainedConfig): + """VLM config: nested ``text_config`` / ``vision_config`` from JSON become real sub-configs.""" + + model_type = "exaone4_5" + + def __init__( + self, + text_config: Optional[Union[PretrainedConfig, dict]] = None, + vision_config: Optional[Union[PretrainedConfig, dict]] = None, + **kwargs, + ): + if isinstance(text_config, dict): + text_config = PretrainedConfig.from_dict(copy.deepcopy(text_config)) + if isinstance(vision_config, dict): + vision_config = PretrainedConfig.from_dict(copy.deepcopy(vision_config)) + super().__init__(text_config=text_config, vision_config=vision_config, **kwargs) + + +AutoConfig.register(Exaone4_5Config.model_type, Exaone4_5Config) + + class Exaone4_5InputProcessor(Qwen2VLInputProcessorBase): def __init__( self, @@ -118,7 +139,6 @@ def __init__( ) -> None: self.original_arch = model_config.pretrained_config.architectures[0] - # model_config.pretrained_config.rope_scaling['type'] = 'mrope' config = model_config.pretrained_config self._supports_sdpa = True @@ -131,8 +151,8 @@ def __init__( raise ValueError("Exaone4.5 only supports TRTLLM backend") llm_model_config = copy.deepcopy(model_config) - # llm_model_config.pretrained_config.architectures = ["Exaone4ForCausalLM"] llm_model_config.pretrained_config = llm_model_config.pretrained_config.text_config + llm_model_config.pretrained_config.tie_word_embeddings = False self.llm = AutoModelForCausalLM.from_config(llm_model_config) if not _is_disagg(): @@ -221,4 +241,5 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper): weights = weight_mapper.preprocess_weights(weights) if not _is_disagg(): self.mm_encoder.load_weights(weights) + print(self.llm.config) self.llm.load_weights(weights, weight_mapper) diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index 951cd95e6ca..5138b907cc8 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -299,9 +299,11 @@ def _preprocess(self, text: Dict[str, any], mm_data: 
Dict[str, any], **mm_processor_kwargs) def _postprocess(self, input_ids: torch.IntTensor) -> torch.IntTensor: - masks = (input_ids == self.config.image_token_id) | ( - input_ids == self.config.vision_token_id) | ( - input_ids == self.config.video_token_id) + masks = torch.zeros_like(input_ids, dtype=torch.bool) + for attr in ("image_token_id", "vision_token_id", "video_token_id"): + token_id = getattr(self.config, attr, None) + if token_id is not None: + masks |= input_ids == token_id input_ids[masks] = self.tllm_multimodal_token_id return input_ids From a2ec83fccb219a212b22d9d72283aaa57b6ea88e Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Wed, 25 Mar 2026 18:35:18 +0900 Subject: [PATCH 04/16] remove print Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- tensorrt_llm/_torch/models/modeling_exaone4_5.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorrt_llm/_torch/models/modeling_exaone4_5.py b/tensorrt_llm/_torch/models/modeling_exaone4_5.py index 8ad4088f6d3..4a75f0c0e6a 100644 --- a/tensorrt_llm/_torch/models/modeling_exaone4_5.py +++ b/tensorrt_llm/_torch/models/modeling_exaone4_5.py @@ -241,5 +241,4 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper): weights = weight_mapper.preprocess_weights(weights) if not _is_disagg(): self.mm_encoder.load_weights(weights) - print(self.llm.config) self.llm.load_weights(weights, weight_mapper) From 2b670ac2b84eb0424f783cdc65bf81a1f8ad198c Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Thu, 2 Apr 2026 19:02:00 +0900 Subject: [PATCH 05/16] optimize Qwen2.5 Vision Encoder Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- .../_torch/models/modeling_exaone4_5.py | 15 +- .../_torch/models/modeling_qwen2vl.py | 327 ++++++++++++------ 2 files changed, 213 insertions(+), 129 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_exaone4_5.py b/tensorrt_llm/_torch/models/modeling_exaone4_5.py index 4a75f0c0e6a..57f396912a2 100644 --- a/tensorrt_llm/_torch/models/modeling_exaone4_5.py +++ b/tensorrt_llm/_torch/models/modeling_exaone4_5.py @@ -1,18 +1,5 @@ -# Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# # SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. import copy from typing import List, Optional, Tuple, Union diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index 5138b907cc8..95880bf3cc5 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -1,5 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ import copy import re +from functools import lru_cache from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -8,9 +12,7 @@ from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel) from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding, - Qwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VLVisionBlock, - apply_rotary_pos_emb_vision) + Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionTransformerPretrainedModel) from transformers.models.qwen2_vl.modeling_qwen2_vl import \ Qwen2VisionTransformerPretrainedModel @@ -39,8 +41,9 @@ from ..attention_backend import AttentionMetadata from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams from ..attention_backend.utils import get_attention_backend +from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE from ..modules.gated_mlp import GatedMLP -from ..modules.rotary_embedding import MRotaryEmbedding +from ..modules.rotary_embedding import MRotaryEmbedding, RotaryEmbedding from .modeling_auto import AutoModelForCausalLM from .modeling_multimodal_utils import (find_input_mm_embeds, fuse_input_embeds, get_multimodal_embeddings) @@ -484,6 +487,7 @@ def _parse_and_batch_multimodal_data( return mm_content_dict, mm_extra_data + @nvtx_range("Qwen2VisionModelBase forward()") @torch.inference_mode() def forward(self, multimodal_params: List[MultimodalParams]): @@ -517,8 +521,9 @@ def __init__(self, super().__init__( hidden_size=config.hidden_size, num_attention_heads=config.num_heads, - num_key_value_heads=config.num_key_value_heads - if config.num_key_value_heads is not None else config.num_heads, + num_key_value_heads=config.num_key_value_heads if getattr( + config, "num_key_value_heads", + None) is not None else config.num_heads, max_position_embeddings=model_config.pretrained_config. max_position_embeddings, bias=True, @@ -531,46 +536,76 @@ def __init__(self, ) def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor], - v: Optional[torch.Tensor], + v: Optional[torch.Tensor], position_ids: torch.IntTensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor]): seq_len, _ = q.size() + cos, sin = position_embeddings + + # FlashInfer fused RoPE assumes head_size is a multiple of 64 (see + # auto_deploy custom op rope docs / flashinfer tests). Qwen2.5-VL vision + # uses head_dim=80 (e.g. 1280 hidden / 16 heads), so use PyTorch RoPE. 
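As a standalone sketch (sizes taken only from the comment above, not from any checkpoint), the check that follows reduces to a divisibility test on the head size, so the quoted Qwen2.5-VL vision tower lands on the PyTorch fallback:

```python
# Hypothetical sizes quoted in the comment above (1280 hidden, 16 heads).
hidden_size, num_heads = 1280, 16
head_dim = hidden_size // num_heads   # 80
print(head_dim % 64 == 0)             # False -> PyTorch RoPE fallback, not FlashInfer
```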
+ if IS_FLASHINFER_AVAILABLE and self.head_dim % 64 == 0: + try: + from ..custom_ops import \ + flashinfer_apply_rope_with_cos_sin_cache_inplace + cos_sin_cache = torch.cat([cos, sin], dim=-1).contiguous() + flashinfer_apply_rope_with_cos_sin_cache_inplace( + position_ids, + q, + k, + self.head_dim, + cos_sin_cache, + is_neox=True, + ) + return q, k, v + except RuntimeError as err: + logger.warning( + "Qwen2.5-VL vision RoPE: FlashInfer failed (%s); " + "falling back to PyTorch RotaryEmbedding.apply_rotary_pos_emb.", + err, + ) + + cos = cos.to(dtype=q.dtype) + sin = sin.to(dtype=q.dtype) q = q.view(seq_len, -1, self.head_dim) k = k.view(seq_len, -1, self.head_dim) v = v.view(seq_len, -1, self.head_dim) - cos, sin = position_embeddings - q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + q = RotaryEmbedding.apply_rotary_pos_emb(q, cos, sin) + k = RotaryEmbedding.apply_rotary_pos_emb(k, cos, sin) q, k, v = q.reshape(seq_len, -1), k.reshape(seq_len, -1), v.reshape(seq_len, -1) return q, k, v def forward( self, + position_ids: torch.IntTensor, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]], **kwargs, ) -> torch.Tensor: - # NOTE: Need separate Attention forward() for Qwen2.5-VL for multiple reasons - # 1. We don't have the route for handing over position_embeddings to the Attention forward() - # 2. Could not override the apply_rope() as we don't have the position_ids in the Vision Attention's rotary embedding. - # (TODO: yechank-nvidia) Make OOTB path more modular and reusable for Attention's Rotary Embedding. + # NOTE: Qwen2.5-VL vision attention needs a custom forward: the generic + # Attention path does not accept precomputed (cos, sin) position_embeddings, + # and vision RoPE may use FlashInfer with explicit position_ids. 
qkv = self.qkv_proj(hidden_states) q, k, v = qkv, None, None q, k, v = self.split_qkv(q, k, v) - q, k, v = self.apply_rope(q, k, v, position_embeddings) + q, k, v = self.apply_rope(q, k, v, position_ids, position_embeddings) q, k, v = self.convert_qkv(q, k, v) - output = self.forward_impl(q=q, - k=k, - v=v, - attn_metadata=attn_metadata, - attention_mask=PredefinedAttentionMask.FULL, - attention_window_size=None, - attention_mask_data=None, - mrope_config=None, - attention_sinks=None) + output = self.forward_impl( + q=q, + k=k, + v=v, + attn_metadata=attn_metadata, + attention_mask=PredefinedAttentionMask.FULL, + attention_window_size=None, + attention_mask_data=None, + mrope_config=None, + attention_sinks=None, + ) attn_output = self.o_proj(output, layer_idx=self.layer_idx) return attn_output @@ -609,6 +644,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], @torch.inference_mode() def forward( self, + position_ids: torch.IntTensor, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, rotary_pos_emb: Optional[torch.Tensor] = None, @@ -619,6 +655,7 @@ def forward( residual = hidden_states hidden_states = self.norm1(hidden_states) hidden_states = residual + self.attn( + position_ids=position_ids, hidden_states=hidden_states, attn_metadata=attn_metadata, rotary_pos_emb=rotary_pos_emb, @@ -691,8 +728,18 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): embed_dim=self.config.hidden_size, ) - head_dim = self.config.hidden_size // self.config.num_heads - self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + self.config.max_position_embeddings = 8192 + self.config.partial_rotary_factor = 0.5 + self.head_dim = self.config.hidden_size // self.config.num_heads + self.pos_embd_params = PositionalEmbeddingParams( + type=PositionEmbeddingType.rope_gpt_neox, + rope=RopeParams.from_config(self.config), + ) + self.rotary_pos_emb = RotaryEmbedding( + self.pos_embd_params.rope, + head_dim=self.head_dim, + is_neox=self.pos_embd_params.is_neox, + ) self.blocks = torch.nn.ModuleList([ Qwen2_5_VLVisionBlock(model_config, layer_idx=layer_idx) @@ -713,78 +760,108 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): kv_cache_manager=None, ) - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - - def get_window_index(self, grid_thw): - window_index: List[torch.Tensor] = [] - seq_lens = [] + def get_rotary_pos_emb_window_data( + self, grid_rows: List[List[int]] + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], + List[int]]: window_index_id = 0 - vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + 
rotary_pos_emb_cos: List[torch.Tensor] = [] + rotary_pos_emb_sin: List[torch.Tensor] = [] + window_indices: List[torch.Tensor] = [] + window_seq_lens: List[int] = [] + for row in grid_rows: + t, h, w = int(row[0]), int(row[1]), int(row[2]) + llm_h = h // self.spatial_merge_size + llm_w = w // self.spatial_merge_size + (cos_thw, sin_thw, window_index_thw, + window_seq_lens_thw) = self.get_rope_and_window_index_by_thw( + t, h, w) + + window_indices.append(window_index_thw + window_index_id) + window_index_id += t * llm_h * llm_w + + rotary_pos_emb_cos.append(cos_thw) + rotary_pos_emb_sin.append(sin_thw) + + window_seq_lens.extend(window_seq_lens_thw) + + return (rotary_pos_emb_cos, rotary_pos_emb_sin, window_indices, + window_seq_lens) + + def get_window_index_by_thw(self, grid_t: int, grid_h: int, + grid_w: int) -> Tuple[torch.Tensor, List[int]]: + vit_merger_window_size = (self.window_size // self.spatial_merge_size // + self.patch_size) + llm_grid_h = grid_h // self.spatial_merge_size + llm_grid_w = grid_w // self.spatial_merge_size + index = torch.arange(grid_t * llm_grid_h * llm_grid_w, + dtype=torch.long).reshape(grid_t, llm_grid_h, + llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", PAD_INDEX) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != PAD_INDEX).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != PAD_INDEX] + seqlens = seqlens * self.spatial_merge_unit + return index_new, seqlens.tolist() + + @lru_cache(maxsize=1024) # noqa: B019 + def get_rope_and_window_index_by_thw( + self, t: int, h: int, w: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[int, ...]]: + """CPU (cos, sin, window_idx, seqlens) in window order; cached per ``(t, h, w)``.""" + hpos_ids = torch.arange(h, dtype=torch.long).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w, dtype=torch.long).unsqueeze(0).expand(h, -1) + hpos_ids = (hpos_ids.reshape(h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size).permute( + 0, 2, 1, 3).flatten()) + wpos_ids = (wpos_ids.reshape(h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size).permute( + 0, 2, 1, 3).flatten()) + pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1) + max_grid_size = max(h, w) + cos_sin = self.rotary_pos_emb.rotary_cos_sin[:max_grid_size] + cos, sin = cos_sin[:, 0, :], cos_sin[:, 1, :] + cos_flattened = cos[pos_ids].flatten(1) + sin_flattened = sin[pos_ids].flatten(1) + + cos_thw = cos_flattened.reshape( + cos_flattened.shape[0] // self.spatial_merge_unit, + self.spatial_merge_unit, + -1, + ) + sin_thw = sin_flattened.reshape( + sin_flattened.shape[0] // self.spatial_merge_unit, + self.spatial_merge_unit, + -1, + ) - for grid_t, grid_h, grid_w in grid_thw: - llm_grid_h, llm_grid_w = ( - grid_h // self.spatial_merge_size, - grid_w // self.spatial_merge_size, - ) - index = 
torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( - grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", - PAD_INDEX) - index_padded = index_padded.reshape( - grid_t, - num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size, - ) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, - num_windows_h * num_windows_w, - vit_merger_window_size, - vit_merger_window_size, - ) - seqlens = (index_padded != PAD_INDEX).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != PAD_INDEX] - window_index.append(index_new + window_index_id) - seqlens = seqlens * self.spatial_merge_unit - seq_lens.extend(seqlens.tolist()) - window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() - window_index = torch.cat(window_index, dim=0) + window_index_thw, seq_lens_thw = self.get_window_index_by_thw(t, h, w) - return window_index, seq_lens + cos_thw = cos_thw[window_index_thw, :, :].reshape(-1, cos_thw.shape[-1]) + sin_thw = sin_thw[window_index_thw, :, :].reshape(-1, sin_thw.shape[-1]) + + return cos_thw, sin_thw, window_index_thw, tuple(seq_lens_thw) def prepare_attn_metadata(self, seq_lens, attn_metadata: AttentionMetadata): batch_size = 1 # NOTE: Qwen2/2.5-VL concats all the pixel_values into a single tensor, so batch_size is 1 @@ -802,45 +879,65 @@ def prepare_attn_metadata(self, seq_lens, attn_metadata: AttentionMetadata): attn_metadata.prepare() return attn_metadata + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + @torch.inference_mode() def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: - window_index, window_seq_lens = self.get_window_index(grid_thw) + + hidden_states = self.patch_embed(pixel_values) + + seq_len, _ = hidden_states.size() + rope_position_ids = torch.arange(seq_len, + dtype=torch.int32, + pin_memory=prefer_pinned()) + grid_rows = grid_thw.detach().cpu().tolist() + + (rotary_pos_emb_cos, rotary_pos_emb_sin, window_indices, + window_seq_lens) = self.get_rotary_pos_emb_window_data(grid_rows) + + window_index = torch.cat(window_indices).to(device=self.device, + non_blocking=True) + + # Scatter sort: window_index maps original token order -> window order. 
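For intuition, a minimal self-contained sketch (toy permutation, assumed sizes) of what the scatter in the next lines computes: it builds the inverse permutation of `window_index` directly, matching the `torch.argsort(window_index)` call it replaces:

```python
import torch

window_index = torch.tensor([2, 0, 3, 1])          # toy permutation
reverse_indices = torch.empty_like(window_index)
reverse_indices[window_index] = torch.arange(window_index.numel())
assert torch.equal(reverse_indices, torch.argsort(window_index))  # both give [1, 3, 0, 2]
```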
+ reverse_indices = torch.empty_like(window_index) + reverse_indices[window_index] = torch.arange( + window_index.numel(), + device=self.device, + dtype=window_index.dtype, + ) + + cos = torch.cat(rotary_pos_emb_cos).to(device=self.device, + non_blocking=True) + sin = torch.cat(rotary_pos_emb_sin).to(device=self.device, + non_blocking=True) + position_embeddings = (cos, sin) + + rope_position_ids = rope_position_ids.to(device=self.device, + dtype=torch.int32, + non_blocking=True) seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).tolist() - reverse_indices = torch.argsort(window_index) - # Getting positional embedding - rotary_pos_emb = self.rot_pos_emb(grid_thw) + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :].reshape(seq_len, -1) full_attn_metadata = self.prepare_attn_metadata(seq_lens, self.full_attn_metadata) window_attn_metadata = self.prepare_attn_metadata( window_seq_lens, self.window_attn_metadata) - # From this point, pure GPU operation - hidden_states = self.patch_embed(pixel_values) - seq_len, _ = hidden_states.size() - hidden_states = hidden_states.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - hidden_states = hidden_states[window_index, :, :] - hidden_states = hidden_states.reshape(seq_len, -1) - - rotary_pos_emb = rotary_pos_emb.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - rotary_pos_emb = rotary_pos_emb[window_index, :, :] - rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - position_embeddings = (emb.cos(), emb.sin()) - for layer_num, block in enumerate(self.blocks): - if layer_num in self.fullatt_block_indexes: attn_metadata = full_attn_metadata else: attn_metadata = window_attn_metadata hidden_states = block( - hidden_states, + position_ids=rope_position_ids, + hidden_states=hidden_states, attn_metadata=attn_metadata, position_embeddings=position_embeddings, ) From bb918051d3bcdcecf0bc4c74c398c1c6dfe5a4a0 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Thu, 2 Apr 2026 19:11:00 +0900 Subject: [PATCH 06/16] add EXAONE-4.5 to the README.md Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- examples/models/core/exaone/README.md | 29 ++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/examples/models/core/exaone/README.md b/examples/models/core/exaone/README.md index ad5aac64802..4db9cd6804b 100644 --- a/examples/models/core/exaone/README.md +++ b/examples/models/core/exaone/README.md @@ -14,9 +14,11 @@ This document shows how to build and run [EXAONE](https://huggingface.co/LGAI-EX - [EXAONE-3.0](#exaone-30) - [EXAONE-Deep](#exaone-deep) - [EXAONE-4.0](#exaone-40) + - [EXAONE-4.5](#exaone-45) - [K-EXAONE](#k-exaone) - [PyTorch flow](#pytorch-flow) - [Running EXAONE-4.0](#running-exaone-40) + - [Running EXAONE-4.5](#running-exaone-45) - [Running K-EXAONE](#running-k-exaone) - [MoE Backend Options](#moe-backend-options) - [PyTorch flow Quantization](#pytorch-flow-quantization) @@ -45,6 +47,7 @@ This document shows how to build and run [EXAONE](https://huggingface.co/LGAI-EX * FP16 * BF16 * Tensor Parallel (TP) + * Multimodal (EXAONE-4.5 only) * Expert Parallel (EP) (K-EXAONE only) * Attention Data Parallel (ADP) (K-EXAONE only) * Disaggregated Serving @@ -59,7 +62,7 @@ This document shows how to 
build and run [EXAONE](https://huggingface.co/LGAI-EX **Note:** - **EXAONE-3.0** & **EXAONE-Deep** are supported using the [TRT Flow](#trt-flow). -- **EXAONE-4.0** & **K-EXAONE** are supported using the [PyTorch flow](#pytorch-flow). +- **EXAONE-4.0**, **EXAONE-4.5**, & **K-EXAONE** are supported using the [PyTorch flow](#pytorch-flow). Please refer to the corresponding sections below for usage instructions and examples for each model. @@ -90,6 +93,17 @@ export HF_MODEL_DIR=hf_models/exaone4 git clone https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B $HF_MODEL_DIR ``` +### EXAONE-4.5 + +EXAONE-4.5 is a multimodal model. It is supported only via the [PyTorch flow](#pytorch-flow). + +Download the HuggingFace checkpoint for your EXAONE-4.5 variant from the [LGAI-EXAONE](https://huggingface.co/LGAI-EXAONE) organization. The example below uses a placeholder repo name; replace it with the model card you use. + +```bash +export HF_MODEL_DIR=hf_models/exaone4_5 +git clone https://huggingface.co/LGAI-EXAONE/ $HF_MODEL_DIR +``` + ### K-EXAONE K-EXAONE is a Mixture of Experts (MoE) model based on the EXAONE architecture. It features a hybrid architecture with both dense and MoE layers, sliding window attention, and supports FP8 and NVFP4 quantization for efficient inference. @@ -117,6 +131,19 @@ The output will be like: [2] Prompt: 'The future of AI is', Generated text: ' not just about technology but also about how we choose to use it. We must ensure that AI is developed and deployed in a way that benefits all of humanity, not just a select few. This means prioritizing ethical considerations, transparency, and accountability in AI development. It also means involving diverse stakeholders in the conversation about AI' ``` +### Running EXAONE-4.5 + +To quickly run EXAONE-4.5 models, you can use [examples/llm-api/quickstart_multimodal.py](../../../llm-api/quickstart_multimodal.py): + +```bash +python ../../../llm-api/quickstart_multimodal.py --model_dir $HF_MODEL_DIR +``` + +The output will be like: +```bash +TODO: FILL +``` + ### Running K-EXAONE K-EXAONE is a Mixture of Experts model that benefits from multiple parallelism strategies. You can run it with tensor parallelism (TP), expert parallelism (EP), and attention data parallelism (ADP): From 6e67024e17da750d6c9bac7bebe3ada21eee1f55 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Thu, 2 Apr 2026 20:27:58 +0900 Subject: [PATCH 07/16] add unittest Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- .../modeling/test_modeling_exaone4_5.py | 254 ++++++++++++++++++ 1 file changed, 254 insertions(+) create mode 100644 tests/unittest/_torch/modeling/test_modeling_exaone4_5.py diff --git a/tests/unittest/_torch/modeling/test_modeling_exaone4_5.py b/tests/unittest/_torch/modeling/test_modeling_exaone4_5.py new file mode 100644 index 00000000000..0fef90b786a --- /dev/null +++ b/tests/unittest/_torch/modeling/test_modeling_exaone4_5.py @@ -0,0 +1,254 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import os +from dataclasses import dataclass +from typing import List + +import torch +from test_modeling_multimodal import MultimodalScenario, TestModelingMultimodal +from transformers import PreTrainedModel + +from tensorrt_llm._torch.models.checkpoints.hf.exaone4_5_weight_mapper import ( + Exaone4_5HfWeightMapper, +) +from tensorrt_llm._torch.models.modeling_exaone4_5 import ( + Exaone4_5_ForConditionalGeneration, + Exaone4_5Config, +) +from tensorrt_llm._utils import get_sm_version + +EXAONE_4_5_TEST_CONFIG = { + "architectures": ["Exaone4_5_ForConditionalGeneration"], + "attention_dropout": 0.0, + "bos_token_id": 1, + "dtype": "bfloat16", + "eos_token_id": 53, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 27392, + "max_position_embeddings": 131072, + "max_window_layers": 64, + "model_type": "exaone4_5", + "num_attention_heads": 40, + "num_hidden_layers": 64, + "num_key_value_heads": 8, + "reorder_qk_norm": True, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "factor": 16.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3", + }, + "rope_theta": 1000000.0, + "sliding_window": None, + "text_config": { + "architectures": ["Exaone4ForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 1, + "dtype": "bfloat16", + "eos_token_id": 53, + "hidden_act": "silu", + "hidden_size": 5120, + "image_token_id": 67, + "initializer_range": 0.02, + "intermediate_size": 27392, + "layer_types": [ + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + ], + "max_position_embeddings": 131072, + "max_window_layers": 64, + "model_type": "exaone4_vl_text", + "num_attention_heads": 40, + "num_hidden_layers": 64, + "num_key_value_heads": 8, + "num_kv_heads": 8, + "reorder_qk_norm": True, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "factor": 16.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3", + }, + "rope_theta": 1000000.0, + "sliding_window": 4096, + "sliding_window_pattern": "LLLG", + "use_cache": True, + "video_token_id": 
None, + "vision_end_token_id": 74, + "vision_start_token_id": 73, + "vision_token_id": 67, + "vocab_size": 153600, + }, + "transformers_version": "5.0.0.dev0", + "use_cache": True, + "video_token_id": 68, + "vision_config": { + "depth": 28, + "dtype": "bfloat16", + "fullatt_block_indexes": [6, 13, 20, 27], + "hidden_act": "silu", + "hidden_size": 2048, + "in_channels": 3, + "in_chans": 3, + "initializer_range": 0.02, + "intermediate_size": 5120, + "model_type": "exaone4_5_vision", + "num_heads": 32, + "num_key_value_heads": 8, + "out_hidden_size": 5120, + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2, + "tokens_per_second": 2, + "torch_dtype": "bfloat16", + "window_size": 112, + }, + "vision_end_token_id": 74, + "vision_start_token_id": 73, + "vision_token_id": 67, + "vocab_size": 153600, + "_name_or_path": str( + os.path.join("/code/yechan-models", "exaone45_beta_2026-03-19_bf16") + ), # str(os.path.join(llm_models_root(), "Qwen2.5-VL-7B-Instruct")) +} + + +@dataclass(repr=False) +class TestExaone4_5Scenario(MultimodalScenario): + """Scenario config (name avoids pytest collecting as Test* class).""" + + pass + + +class TestExaone4_5(TestModelingMultimodal): + """ + Smoke tests for Exaone4.5. + + ``get_hf_model_class`` returns bare ``PreTrainedModel`` (no official HF + ``ForConditionalGeneration`` in pinned transformers). That yields an empty + ``state_dict``, so weight loading into TRT-LLM always fails unless you + skip HF and use ``load_weights=False`` (see ``skip_hf_inference``). + """ + + # TODO: Remove this once we have a proper transformers version for Exaone4.5 + @property + def skip_hf_inference(self) -> bool: + return True + + @property + def trust_remote_code(self) -> bool: + return True + + def get_model_config(self): + return EXAONE_4_5_TEST_CONFIG + + def get_trtllm_model_class(self): + return Exaone4_5_ForConditionalGeneration + + def get_hf_model_class(self): + # TODO: Change to EXAONE4_5ForConditionalGeneration + return PreTrainedModel + + def get_weight_mapper_class(self): + return Exaone4_5HfWeightMapper + + def get_model_type(self): + return "exaone4_5" + + def get_model_config_class(self): + return Exaone4_5Config + + def get_scenarios(self) -> List[TestExaone4_5Scenario]: + scenarios: List[TestExaone4_5Scenario] = [ + TestExaone4_5Scenario( + modality="image", use_cuda_graph=False, chunked_prefill=False, kv_cache_reuse=False + ), + TestExaone4_5Scenario( + modality="image", use_cuda_graph=True, chunked_prefill=False, kv_cache_reuse=False + ), + TestExaone4_5Scenario( + modality="image", use_cuda_graph=False, chunked_prefill=True, kv_cache_reuse=False + ), + ] + # Paged context + cache_reuse matches production but TRTLLM-GEN FMHA coverage + # on Blackwell (SM100) can differ from Hopper; run this scenario on Hopper only. 
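As a small illustration (GPU names and SM values assumed to follow the usual compute-capability convention), the gate that follows adds the extra scenario only on SM90 parts:

```python
# Sketch only: SM 90 corresponds to Hopper (e.g. H100); SM 100 to Blackwell.
for gpu, sm in [("A100", 80), ("H100", 90), ("B200", 100)]:
    print(gpu, sm, sm == 90)   # only the Hopper entry enables the kv_cache_reuse scenario
```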
+ if torch.cuda.is_available() and get_sm_version() == 90: + scenarios.append( + TestExaone4_5Scenario( + modality="image", + use_cuda_graph=False, + chunked_prefill=False, + kv_cache_reuse=True, + ) + ) + return scenarios From 923db4f72dd80986a32334187c915e4434c2d367 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Fri, 3 Apr 2026 11:12:32 +0900 Subject: [PATCH 08/16] update PLACEHOLDERS Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- .../_torch/models/modeling_exaone4_5.py | 4 +-- .../modeling/test_modeling_exaone4_5.py | 28 ++++++++++++++++--- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_exaone4_5.py b/tensorrt_llm/_torch/models/modeling_exaone4_5.py index 57f396912a2..a23bbf447c6 100644 --- a/tensorrt_llm/_torch/models/modeling_exaone4_5.py +++ b/tensorrt_llm/_torch/models/modeling_exaone4_5.py @@ -204,8 +204,8 @@ def forward( model_type="exaone4_5", placeholder_metadata=MultimodalPlaceholderMetadata( placeholder_map={ - "image": "<|vision_start|><|image_pad|><|vision_end|>", - "video": "<|vision_start|><|video_pad|><|vision_end|>", + "image": "<|image_pad|>", + "video": "<|video_pad|>", }, placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT, ), diff --git a/tests/unittest/_torch/modeling/test_modeling_exaone4_5.py b/tests/unittest/_torch/modeling/test_modeling_exaone4_5.py index 0fef90b786a..bc8217fafa6 100644 --- a/tests/unittest/_torch/modeling/test_modeling_exaone4_5.py +++ b/tests/unittest/_torch/modeling/test_modeling_exaone4_5.py @@ -1,11 +1,14 @@ # SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import importlib import os from dataclasses import dataclass from typing import List import torch +import transformers +from packaging.version import Version from test_modeling_multimodal import MultimodalScenario, TestModelingMultimodal from transformers import PreTrainedModel @@ -183,6 +186,15 @@ } +def _transformers_version_at_most_5_3() -> bool: + """True when installed transformers is at most 5.3.x (no Exaone4.5 in HF yet).""" + ver = Version(transformers.__version__) + if not ver.release: + return True + major, minor = ver.release[0], ver.release[1] if len(ver.release) > 1 else 0 + return (major < 5) or (major == 5 and minor <= 3) + + @dataclass(repr=False) class TestExaone4_5Scenario(MultimodalScenario): """Scenario config (name avoids pytest collecting as Test* class).""" @@ -200,10 +212,9 @@ class TestExaone4_5(TestModelingMultimodal): skip HF and use ``load_weights=False`` (see ``skip_hf_inference``). 
""" - # TODO: Remove this once we have a proper transformers version for Exaone4.5 @property def skip_hf_inference(self) -> bool: - return True + return _transformers_version_at_most_5_3() @property def trust_remote_code(self) -> bool: @@ -216,8 +227,17 @@ def get_trtllm_model_class(self): return Exaone4_5_ForConditionalGeneration def get_hf_model_class(self): - # TODO: Change to EXAONE4_5ForConditionalGeneration - return PreTrainedModel + if _transformers_version_at_most_5_3(): + return PreTrainedModel + hf_cls = getattr(transformers, "Exaone4_5ForConditionalGeneration", None) + if hf_cls is not None: + return hf_cls + try: + mod = importlib.import_module("transformers.models.exaone4_5.modeling_exaone4_5") + hf_cls = getattr(mod, "Exaone4_5ForConditionalGeneration", None) + return hf_cls if hf_cls is not None else PreTrainedModel + except ImportError: + return PreTrainedModel def get_weight_mapper_class(self): return Exaone4_5HfWeightMapper From 12292fdfddffc644772738e0eb6e6cbb8f5540f3 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Fri, 3 Apr 2026 17:32:11 +0900 Subject: [PATCH 09/16] modify Qwen3-VL Vision Encoder Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- .../_torch/models/modeling_qwen2vl.py | 47 +-- .../_torch/models/modeling_qwen3vl.py | 293 +++++++++++------- 2 files changed, 198 insertions(+), 142 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index 95880bf3cc5..3fc75e6cfe6 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -535,16 +535,20 @@ def __init__(self, reduce_output=reduce_output, ) - def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor], - v: Optional[torch.Tensor], position_ids: torch.IntTensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor]): + def apply_rope(self, + q: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + position_ids: Optional[torch.IntTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None): seq_len, _ = q.size() cos, sin = position_embeddings # FlashInfer fused RoPE assumes head_size is a multiple of 64 (see # auto_deploy custom op rope docs / flashinfer tests). Qwen2.5-VL vision # uses head_dim=80 (e.g. 1280 hidden / 16 heads), so use PyTorch RoPE. 
- if IS_FLASHINFER_AVAILABLE and self.head_dim % 64 == 0: + if IS_FLASHINFER_AVAILABLE and self.head_dim % 64 == 0 and position_ids is not None: try: from ..custom_ops import \ flashinfer_apply_rope_with_cos_sin_cache_inplace @@ -578,10 +582,10 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor], def forward( self, - position_ids: torch.IntTensor, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]], + position_ids: Optional[torch.IntTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> torch.Tensor: # NOTE: Qwen2.5-VL vision attention needs a custom forward: the generic @@ -644,10 +648,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], @torch.inference_mode() def forward( self, - position_ids: torch.IntTensor, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, - rotary_pos_emb: Optional[torch.Tensor] = None, + position_ids: Optional[torch.IntTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> torch.Tensor: @@ -655,10 +658,9 @@ def forward( residual = hidden_states hidden_states = self.norm1(hidden_states) hidden_states = residual + self.attn( - position_ids=position_ids, hidden_states=hidden_states, attn_metadata=attn_metadata, - rotary_pos_emb=rotary_pos_emb, + position_ids=position_ids, position_embeddings=position_embeddings, **kwargs, ) @@ -863,19 +865,19 @@ def get_rope_and_window_index_by_thw( return cos_thw, sin_thw, window_index_thw, tuple(seq_lens_thw) - def prepare_attn_metadata(self, seq_lens, attn_metadata: AttentionMetadata): - batch_size = 1 # NOTE: Qwen2/2.5-VL concats all the pixel_values into a single tensor, so batch_size is 1 - prompt_lens = seq_lens - seq_lens = torch.tensor(seq_lens, - dtype=torch.int, - pin_memory=prefer_pinned()) + def prepare_attn_metadata(self, batch_size: int, seq_lens: List[int], + attn_metadata: AttentionMetadata): + batch_size = len(seq_lens) + seq_lens_torch = torch.tensor(seq_lens, + dtype=torch.int, + pin_memory=prefer_pinned()) request_ids = list(range(1, batch_size + 1)) attn_metadata.num_contexts = len(seq_lens) attn_metadata.request_ids = request_ids - attn_metadata.prompt_lens = prompt_lens - attn_metadata.seq_lens = seq_lens - attn_metadata.max_seq_len = seq_lens.max().item() + attn_metadata.prompt_lens = seq_lens + attn_metadata.seq_lens = seq_lens_torch + attn_metadata.max_seq_len = max(seq_lens) attn_metadata.prepare() return attn_metadata @@ -925,10 +927,11 @@ def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :].reshape(seq_len, -1) - full_attn_metadata = self.prepare_attn_metadata(seq_lens, + full_attn_metadata = self.prepare_attn_metadata(len(grid_rows), + seq_lens, self.full_attn_metadata) window_attn_metadata = self.prepare_attn_metadata( - window_seq_lens, self.window_attn_metadata) + len(grid_rows), window_seq_lens, self.window_attn_metadata) for layer_num, block in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: @@ -936,10 +939,10 @@ def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, else: attn_metadata = window_attn_metadata hidden_states = block( - position_ids=rope_position_ids, hidden_states=hidden_states, attn_metadata=attn_metadata, position_embeddings=position_embeddings, + position_ids=rope_position_ids, ) hidden_states = 
self.merger(hidden_states) hidden_states = hidden_states[reverse_indices, :] diff --git a/tensorrt_llm/_torch/models/modeling_qwen3vl.py b/tensorrt_llm/_torch/models/modeling_qwen3vl.py index 5a418af6729..c6b4d921b2c 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3vl.py @@ -1,7 +1,9 @@ import copy import re +from functools import lru_cache from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np import torch import torch.nn as nn from transformers import AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel @@ -9,9 +11,6 @@ from transformers.models.qwen3_vl.modeling_qwen3_vl import ( Qwen3VLVisionPatchEmbed as HFQwen3VLVisionPatchEmbed, ) -from transformers.models.qwen3_vl.modeling_qwen3_vl import ( - Qwen3VLVisionRotaryEmbedding as HFQwen3VLVisionRotaryEmbedding, -) from tensorrt_llm._torch.models.modeling_multimodal_utils import _is_disagg from tensorrt_llm.functional import PositionEmbeddingType @@ -38,7 +37,7 @@ from ..modules.layer_norm import LayerNorm from ..modules.linear import Linear, TensorParallelMode from ..modules.mlp import MLP -from ..modules.rotary_embedding import MRotaryEmbedding +from ..modules.rotary_embedding import MRotaryEmbedding, RotaryEmbedding from .checkpoints.base_weight_mapper import BaseWeightMapper from .checkpoints.hf.qwen3vl_weight_mapper import Qwen3VLHfWeightMapper from .modeling_auto import AutoModelForCausalLM @@ -565,6 +564,77 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states +# Referenced from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_vl.py#L668 +def pos_embed_interpolate_native( + embed_weight: torch.Tensor, + t: int, + h: int, + w: int, + num_grid_per_side: int, + m_size: int, + dtype: torch.dtype, +) -> torch.Tensor: + """Eager PyTorch bilinear position-embedding interpolation. + + Returns a tensor of shape ``(t * h * w, hidden_dim)`` with the + bilinearly-interpolated position embeddings in spatial-merge order. 
+ """ + assert h % m_size == 0 and w % m_size == 0, ( + f"h={h} and w={w} must be divisible by m_size={m_size}" + ) + hidden_dim = embed_weight.shape[1] + device = embed_weight.device + + h_idxs = torch.linspace( + 0, + num_grid_per_side - 1, + h, + dtype=torch.float32, + device=device, + ) + w_idxs = torch.linspace( + 0, + num_grid_per_side - 1, + w, + dtype=torch.float32, + device=device, + ) + + h_floor = h_idxs.to(torch.long) + w_floor = w_idxs.to(torch.long) + h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1) + w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1) + + dh = h_idxs - h_floor + dw = w_idxs - w_floor + + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij") + h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij") + h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij") + + w11 = dh_grid * dw_grid + w10 = dh_grid - w11 + w01 = dw_grid - w11 + w00 = 1 - dh_grid - w01 + + h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid]) + w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid]) + h_grid_idx = h_grid * num_grid_per_side + + indices = (h_grid_idx + w_grid).reshape(4, -1) + weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1) + weights = weights.to(dtype=dtype) + + embeds = embed_weight[indices] + embeds *= weights + combined = embeds.sum(dim=0) + + combined = combined.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim) + combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim) + repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim) + return repeated.to(dtype=dtype) + + class Qwen3VisionModel(torch.nn.Module): def __init__(self, model_config: ModelConfig[PretrainedConfig]): super().__init__() @@ -582,8 +652,19 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): self.pos_embed = nn.Embedding(self.config.num_position_embeddings, self.config.hidden_size) self.num_grid_per_side = int(self.config.num_position_embeddings**0.5) - head_dim = self.config.hidden_size // self.config.num_heads - self.rotary_pos_emb = HFQwen3VLVisionRotaryEmbedding(head_dim // 2) + self.config.max_position_embeddings = 8192 + self.config.partial_rotary_factor = 0.5 + self.config.num_attention_heads = self.config.num_heads + self.head_dim = self.config.hidden_size // self.config.num_heads + self.pos_embd_params = PositionalEmbeddingParams( + type=PositionEmbeddingType.rope_gpt_neox, + rope=RopeParams.from_config(self.config), + ) + self.rotary_pos_emb = RotaryEmbedding( + self.pos_embd_params.rope, + head_dim=self.head_dim, + is_neox=self.pos_embd_params.is_neox, + ) self.blocks = nn.ModuleList( [ @@ -613,118 +694,89 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): kv_cache_manager=None, ) - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - merge_size = self.spatial_merge_size - - max_hw = int(grid_thw[:, 1:].max().item()) - freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) - device = freq_table.device - - total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw: - merged_h, merged_w = height // merge_size, width // merge_size - - block_rows = torch.arange(merged_h, device=device) # block row indices - block_cols = torch.arange(merged_w, device=device) # block col indices - intra_row = torch.arange(merge_size, device=device) # intra-block row offsets - 
intra_col = torch.arange(merge_size, device=device) # intra-block col offsets - - # Compute full-resolution positions - row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] - col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] - - row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - - coords = torch.stack((row_idx, col_idx), dim=-1) - - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset : offset + num_tokens] = coords - offset += num_tokens - - embeddings = freq_table[pos_ids] # lookup rotary embeddings - embeddings = embeddings.flatten(1) - return embeddings - - def fast_pos_embed_interpolate(self, grid_thw): - grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - - for t, h, w in zip(grid_ts, grid_hs, grid_ws): - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - - h_idxs_floor = h_idxs.int() - w_idxs_floor = w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side - - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] - - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) - - idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) - weight_tensor = torch.tensor( - weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + @staticmethod + @lru_cache(maxsize=1024) + def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor: + hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w)) + h_div = h // spatial_merge_size + w_div = w // spatial_merge_size + hpos_ids = hpos_ids.reshape( + h_div, + spatial_merge_size, + w_div, + spatial_merge_size, + ) + hpos_ids = hpos_ids.transpose(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w)) + wpos_ids = wpos_ids.reshape( + h_div, + spatial_merge_size, + w_div, + spatial_merge_size, ) - pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] - patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - - patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) - - patch_pos_embeds_permute = [] - merge_size = self.config.spatial_merge_size - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) - pos_embed = ( - pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) - .permute(0, 1, 3, 2, 4, 5) - 
.flatten(0, 4) + wpos_ids = wpos_ids.transpose(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + + return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1)) + + def rot_pos_emb(self, grid_thw: list[list[int]]): + max_grid_size = max(max(h, w) for _, h, w in grid_thw) + pos_ids = [ + self.rot_pos_ids(h, w, self.spatial_merge_size) + if t == 1 + else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1) + for t, h, w in grid_thw + ] + pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True) + + # Use pre-computed cos_sin_cache from RotaryEmbedding + cos_sin = self.rotary_pos_emb.rotary_cos_sin[:max_grid_size] + cos, sin = cos_sin[:, 0, :], cos_sin[:, 1, :] + cos_combined = cos[pos_ids].flatten(1) + sin_combined = sin[pos_ids].flatten(1) + + return (cos_combined, sin_combined) + + # Referenced from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_vl.py#L668 + def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: + interpolate_fn = pos_embed_interpolate_native + outputs = [] + for t, h, w in grid_thw: + outputs.append( + interpolate_fn( + self.pos_embed.weight, + t, + h, + w, + self.num_grid_per_side, + self.spatial_merge_size, + self.dtype, + ) ) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds + return torch.cat(outputs, dim=0) - def prepare_attn_metadata(self, seq_lens, attn_metadata: AttentionMetadata): - # NOTE: The single prompt is divided into multiple seq_lens, so pretending have many batch_sizes. + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + def prepare_attn_metadata( + self, batch_size: int, seq_lens: List[int], attn_metadata: AttentionMetadata + ): batch_size = len(seq_lens) - prompt_lens = seq_lens - seq_lens = torch.tensor(seq_lens, dtype=torch.int, pin_memory=prefer_pinned()) + seq_lens_torch = torch.tensor(seq_lens, dtype=torch.int, pin_memory=prefer_pinned()) request_ids = list(range(1, batch_size + 1)) - attn_metadata.num_contexts = batch_size + attn_metadata.num_contexts = len(seq_lens) attn_metadata.request_ids = request_ids - attn_metadata.prompt_lens = prompt_lens - attn_metadata.seq_lens = seq_lens - attn_metadata.max_seq_len = seq_lens.max().item() + attn_metadata.prompt_lens = seq_lens + attn_metadata.seq_lens = seq_lens_torch + attn_metadata.max_seq_len = max(seq_lens) attn_metadata.prepare() return attn_metadata @@ -733,28 +785,29 @@ def forward( self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs ) -> Tuple[torch.Tensor, List[torch.Tensor]]: seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).tolist() - attn_metadata = self.prepare_attn_metadata(seq_lens, self.attn_metadata) + grid_rows = grid_thw.detach().cpu().tolist() + attn_metadata = self.prepare_attn_metadata(len(grid_thw), seq_lens, self.attn_metadata) # Getting positional embedding rotary_pos_emb = self.rot_pos_emb(grid_thw) - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) - # From this point, pure GPU operation + pos_embeds = self.fast_pos_embed_interpolate(grid_rows) hidden_states = self.patch_embed(pixel_values) hidden_states = hidden_states + pos_embeds seq_len, _ = hidden_states.size() + rope_position_ids = torch.arange(seq_len, dtype=torch.int32, pin_memory=prefer_pinned()) + rope_position_ids = rope_position_ids.to( + device=self.device, dtype=torch.int32, non_blocking=True + ) hidden_states = hidden_states.reshape(seq_len, -1) - 
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - position_embeddings = (emb.cos(), emb.sin()) - deepstack_feature_lists = [] for layer_num, block in enumerate(self.blocks): hidden_states = block( - hidden_states, + position_ids=rope_position_ids, + hidden_states=hidden_states, attn_metadata=attn_metadata, - position_embeddings=position_embeddings, + position_embeddings=rotary_pos_emb, ) if layer_num in self.deepstack_visual_indexes: deepstack_feature = self.deepstack_merger_list[ From 17a68b5bca5f24e0aa5add4a8910d7d93e467b98 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Mon, 6 Apr 2026 14:35:58 +0900 Subject: [PATCH 10/16] refine test Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- tensorrt_llm/_torch/models/modeling_exaone4_5.py | 2 +- tensorrt_llm/_torch/models/modeling_exaone_moe.py | 2 +- .../_torch/modeling/test_modeling_exaone4_5.py | 15 +++++++++++++++ .../_torch/modeling/test_modeling_multimodal.py | 13 +++++++++++++ 4 files changed, 30 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_exaone4_5.py b/tensorrt_llm/_torch/models/modeling_exaone4_5.py index a23bbf447c6..eab7cb0fd04 100644 --- a/tensorrt_llm/_torch/models/modeling_exaone4_5.py +++ b/tensorrt_llm/_torch/models/modeling_exaone4_5.py @@ -53,7 +53,7 @@ def __init__( super().__init__(text_config=text_config, vision_config=vision_config, **kwargs) -AutoConfig.register(Exaone4_5Config.model_type, Exaone4_5Config) +AutoConfig.register(Exaone4_5Config.model_type, Exaone4_5Config, exist_ok=True) class Exaone4_5InputProcessor(Qwen2VLInputProcessorBase): diff --git a/tensorrt_llm/_torch/models/modeling_exaone_moe.py b/tensorrt_llm/_torch/models/modeling_exaone_moe.py index fe420178558..621065ccb47 100644 --- a/tensorrt_llm/_torch/models/modeling_exaone_moe.py +++ b/tensorrt_llm/_torch/models/modeling_exaone_moe.py @@ -53,7 +53,7 @@ class ExaoneMoEConfig(PretrainedConfig): "Register ExaoneMoEConfig to mimic the ExaoneMoE model.", key="EXAONE_MOE_REGISTER_WARNING" ) -AutoConfig.register(ExaoneMoEConfig.model_type, ExaoneMoEConfig) +AutoConfig.register(ExaoneMoEConfig.model_type, ExaoneMoEConfig, exist_ok=True) # End of the config register. 
# fmt: on diff --git a/tests/unittest/_torch/modeling/test_modeling_exaone4_5.py b/tests/unittest/_torch/modeling/test_modeling_exaone4_5.py index bc8217fafa6..5afe5b6c047 100644 --- a/tests/unittest/_torch/modeling/test_modeling_exaone4_5.py +++ b/tests/unittest/_torch/modeling/test_modeling_exaone4_5.py @@ -216,6 +216,21 @@ class TestExaone4_5(TestModelingMultimodal): def skip_hf_inference(self) -> bool: return _transformers_version_at_most_5_3() + @property + def skip_test(self) -> bool: + path = EXAONE_4_5_TEST_CONFIG.get("_name_or_path") + if not path: + return True + return not os.path.exists(path) + + @property + def skip_test_reason(self) -> str: + path = EXAONE_4_5_TEST_CONFIG.get("_name_or_path") + return ( + "Exaone4.5 multimodal test requires local weights at " + f"config _name_or_path (missing or path not found): {path!r}" + ) + @property def trust_remote_code(self) -> bool: return True diff --git a/tests/unittest/_torch/modeling/test_modeling_multimodal.py b/tests/unittest/_torch/modeling/test_modeling_multimodal.py index b65dfe85377..89cb49a8baf 100644 --- a/tests/unittest/_torch/modeling/test_modeling_multimodal.py +++ b/tests/unittest/_torch/modeling/test_modeling_multimodal.py @@ -105,6 +105,16 @@ def trust_remote_code(self) -> bool: """Return whether to trust remote code. Will override when using custom config and model classes.""" return False + @property + def skip_test(self) -> bool: + """Return whether to skip the entire test class (e.g. missing local weights).""" + return False + + @property + def skip_test_reason(self) -> str: + """Message passed to unittest.skipTest when skip_test is True.""" + return "skip_test is True for this class (see skip_test property)." + @property def skip_hf_inference(self) -> bool: """Return whether to skip HuggingFace inference.""" @@ -658,6 +668,9 @@ def get_scenarios(self) -> List[MultimodalScenario]: def setUp(self): """Initialize models and configurations for testing.""" + if self.skip_test: + self.skipTest(self.skip_test_reason) + torch.random.manual_seed(0) # TODO: Add multi-GPU support From 403c34949c4cb5d6e636378e31b8b3df858b0757 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Tue, 7 Apr 2026 13:36:19 +0900 Subject: [PATCH 11/16] fix weight loading Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- .../_torch/models/modeling_exaone4_5.py | 2 + .../_torch/models/modeling_qwen2vl.py | 53 +++++++++++++++++-- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_exaone4_5.py b/tensorrt_llm/_torch/models/modeling_exaone4_5.py index eab7cb0fd04..5e410215efc 100644 --- a/tensorrt_llm/_torch/models/modeling_exaone4_5.py +++ b/tensorrt_llm/_torch/models/modeling_exaone4_5.py @@ -11,6 +11,7 @@ from tensorrt_llm._torch.models.modeling_multimodal_utils import _is_disagg from ...inputs import ( + ContentFormat, ExtraProcessedInputs, MultimodalPlaceholderMetadata, MultimodalPlaceholderPlacement, @@ -208,6 +209,7 @@ def forward( "video": "<|video_pad|>", }, placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT, + content_format=ContentFormat.STRING, ), ) class Exaone4_5_ForConditionalGeneration(Exaone4_5_VLModel): diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index 3fc75e6cfe6..3dadca21e4f 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -404,6 +404,49 @@ def __init__(self, 
model_config: ModelConfig[PretrainedConfig], raise NotImplementedError( f"Model class {model_class} not implemented") + def _split_fused_vision_qkv_tensor( + self, tensor: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Split HF fused ``attn.qkv`` along output dim (dim 0 for Linear). + + Qwen2.5-VL vision is **MHA** (``num_key_value_heads == num_heads``): Q, K, and V + each occupy ``num_heads * head_dim`` — three equal blocks. + + EXAONE-4.5 vision is **GQA**: Q uses ``num_heads * head_dim``, K and V each use + ``num_key_value_heads * head_dim`` (asymmetric split). + """ + cfg = self.config + num_heads = cfg.num_heads + num_kv_heads = getattr(cfg, "num_key_value_heads", None) + if num_kv_heads is None: + num_kv_heads = num_heads + head_dim, rem = divmod(cfg.hidden_size, num_heads) + if rem != 0: + raise ValueError( + f"vision hidden_size {cfg.hidden_size} not divisible by " + f"num_heads {num_heads}") + q_dim = num_heads * head_dim + kv_dim = num_kv_heads * head_dim + # Fused Linear out_features = Q + K + V along dim 0 of the weight (or bias). + fused_out_features = q_dim + 2 * kv_dim + leading_dim = tensor.shape[0] + if leading_dim == fused_out_features: + # GQA (e.g. EXAONE-4.5 vision) or MHA with fused length matching config. + return (tensor[:q_dim], tensor[q_dim:q_dim + kv_dim], + tensor[q_dim + kv_dim:]) + if num_kv_heads == num_heads and leading_dim % 3 == 0: + # MHA (e.g. Qwen2.5-VL vision): three equal Q/K/V blocks; used if fused + # leading dim is a triple split but does not match ``fused_out_features``. + dim_shape = leading_dim // 3 + return (tensor[:dim_shape], tensor[dim_shape:2 * dim_shape], + tensor[2 * dim_shape:]) + raise ValueError( + f"Fused vision qkv leading dim is {leading_dim}, " + f"want {fused_out_features} from config (q_dim={q_dim}, kv_dim={kv_dim}) " + f"or for MHA a length divisible by 3; " + f"num_heads={num_heads}, num_key_value_heads={num_kv_heads}, " + f"head_dim={head_dim}.") + def load_weights(self, weights: Dict): visual_weights = filter_weights("visual", weights) converted_weights = dict() @@ -421,11 +464,11 @@ def load_weights(self, weights: Dict): q_name = f"{prefix}attn.q_proj.{suffix}" k_name = f"{prefix}attn.k_proj.{suffix}" v_name = f"{prefix}attn.v_proj.{suffix}" - dim_shape = visual_weights[name].shape[0] // 3 - converted_weights[q_name] = visual_weights[name][:dim_shape] - converted_weights[k_name] = visual_weights[name][dim_shape:2 * - dim_shape] - converted_weights[v_name] = visual_weights[name][2 * dim_shape:] + q_part, k_part, v_part = self._split_fused_vision_qkv_tensor( + visual_weights[name]) + converted_weights[q_name] = q_part + converted_weights[k_name] = k_part + converted_weights[v_name] = v_part else: converted_weights[name] = visual_weights[name] pattern_mapping = { From d1c8f6713579e749d5b24da45ddaa2f5459d9db7 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Thu, 9 Apr 2026 11:50:04 +0900 Subject: [PATCH 12/16] compatibility matching to latest model HF Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- requirements.txt | 2 +- tensorrt_llm/_torch/models/modeling_clip.py | 10 ++++------ tensorrt_llm/_torch/models/modeling_exaone4_5.py | 13 ++++++++++--- tensorrt_llm/_torch/models/modeling_llama.py | 2 +- tensorrt_llm/_torch/models/modeling_qwen2vl.py | 11 +++++++---- tensorrt_llm/_torch/models/modeling_siglip.py | 10 ++++------ .../_torch/visual_gen/models/wan/transformer_wan.py | 5 ++--- tensorrt_llm/models/gpt/convert.py | 4 
++-- tensorrt_llm/tools/multimodal_builder.py | 6 +++--- 9 files changed, 34 insertions(+), 29 deletions(-) diff --git a/requirements.txt b/requirements.txt index b76e28208bd..bfe364fde97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,7 +30,7 @@ nvidia-modelopt[torch]~=0.37.0 # torch 2.10.0+cu130 depends on nvidia-nccl-cu13==2.28.9 nvidia-nccl-cu13>=2.28.9,<=2.29.2 nvidia-cuda-nvrtc -transformers==4.57.3 +transformers==5.3.0 prometheus_client prometheus_fastapi_instrumentator pydantic>=2.9.1 diff --git a/tensorrt_llm/_torch/models/modeling_clip.py b/tensorrt_llm/_torch/models/modeling_clip.py index 1e203eda8b7..d50eb55dc31 100644 --- a/tensorrt_llm/_torch/models/modeling_clip.py +++ b/tensorrt_llm/_torch/models/modeling_clip.py @@ -4,8 +4,6 @@ import torch.nn as nn from transformers.activations import ACT2FN from transformers.modeling_outputs import BaseModelOutput -from transformers.modeling_utils import (get_parameter_device, - get_parameter_dtype) from transformers.models.clip.configuration_clip import CLIPVisionConfig from transformers.models.clip.modeling_clip import CLIPVisionEmbeddings @@ -219,12 +217,12 @@ def prepare_attn_metadata(self, batch_size): return self.attn_metadata @property - def dtype(self): - return get_parameter_dtype(self) + def dtype(self) -> torch.dtype: + return self.vision_model.embeddings.patch_embedding.weight.dtype @property - def device(self): - return get_parameter_device(self) + def device(self) -> torch.device: + return self.vision_model.embeddings.patch_embedding.weight.device @torch.inference_mode() def forward(self, diff --git a/tensorrt_llm/_torch/models/modeling_exaone4_5.py b/tensorrt_llm/_torch/models/modeling_exaone4_5.py index 5e410215efc..6d19f863390 100644 --- a/tensorrt_llm/_torch/models/modeling_exaone4_5.py +++ b/tensorrt_llm/_torch/models/modeling_exaone4_5.py @@ -114,6 +114,9 @@ class Exaone4_5_VisionModel(Qwen2VisionModelBase): def __init__( self, model_config: ModelConfig[PretrainedConfig], model_class: type[Qwen2_5_VisionModel] ): + model_config.pretrained_config.max_position_embeddings = ( + model_config.pretrained_config.text_config.max_position_embeddings + ) super().__init__(model_config, model_class=model_class) self.config.tie_word_embeddings = False @@ -132,9 +135,6 @@ def __init__( self._supports_sdpa = True PreTrainedModel.__init__(self, config) - self.model_config = model_config - self.config = model_config.pretrained_config - if model_config.attn_backend != "TRTLLM": raise ValueError("Exaone4.5 only supports TRTLLM backend") @@ -152,6 +152,13 @@ def __init__( else: self.mm_encoder = None + model_config.pretrained_config.text_config.architectures = ( + model_config.pretrained_config.architectures + ) + model_config.pretrained_config = model_config.pretrained_config.text_config + self.model_config = model_config + self.config = model_config.pretrained_config + def infer_max_seq_len(self) -> int: return self.llm.infer_max_seq_len() diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index b13c2e3de91..4e8c4ae4aeb 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -7,8 +7,8 @@ from torch import nn from transformers import (AutoProcessor, AutoTokenizer, Llama4Config, Llama4VisionModel, LlamaConfig, PretrainedConfig) -from transformers.modeling_utils import load_sharded_checkpoint from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector +from transformers.trainer_utils import 
load_sharded_checkpoint from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, MoEAllReduce) diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index 3dadca21e4f..24ada711235 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -380,7 +380,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], type[torch.nn.Module]]): super().__init__() self.model_config = model_config - self.model_dtype = self.model_config.pretrained_config.torch_dtype + self.model_dtype = self.model_config.pretrained_config.torch_dtype or self.model_config.pretrained_config.vision_config.dtype self.config = self.model_config.pretrained_config.vision_config self.config.num_attention_heads = self.config.num_heads @@ -680,10 +680,12 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], super().__init__() config = model_config.pretrained_config.vision_config self.norm1 = RMSNorm(hidden_size=config.hidden_size, - eps=model_config.pretrained_config.rms_norm_eps, + eps=getattr(model_config.pretrained_config, + "rms_norm_eps", 1e-6), dtype=model_config.pretrained_config.torch_dtype) self.norm2 = RMSNorm(hidden_size=config.hidden_size, - eps=model_config.pretrained_config.rms_norm_eps, + eps=getattr(model_config.pretrained_config, + "rms_norm_eps", 1e-6), dtype=model_config.pretrained_config.torch_dtype) self.attn = Qwen2_5_VLVisionAttention(model_config, layer_idx) self.mlp = Qwen2_5_VLMLP(model_config, layer_idx) @@ -725,7 +727,8 @@ def __init__(self, context_dim = config.hidden_size self.hidden_size = context_dim * (spatial_merge_size**2) self.ln_q = RMSNorm(hidden_size=context_dim, - eps=model_config.pretrained_config.rms_norm_eps, + eps=getattr(model_config.pretrained_config, + "rms_norm_eps", 1e-6), dtype=model_config.pretrained_config.torch_dtype) self.mlp = torch.nn.Sequential( Linear(in_features=self.hidden_size, diff --git a/tensorrt_llm/_torch/models/modeling_siglip.py b/tensorrt_llm/_torch/models/modeling_siglip.py index e4ed6d462b8..261c972611a 100644 --- a/tensorrt_llm/_torch/models/modeling_siglip.py +++ b/tensorrt_llm/_torch/models/modeling_siglip.py @@ -2,8 +2,6 @@ import torch import torch.nn as nn -from transformers.modeling_utils import (get_parameter_device, - get_parameter_dtype) from transformers.models.siglip.configuration_siglip import SiglipVisionConfig from transformers.models.siglip.modeling_siglip import (SiglipVisionConfig, SiglipVisionEmbeddings) @@ -112,12 +110,12 @@ def prepare_attn_metadata(self, batch_size): return self.attn_metadata @property - def dtype(self): - return get_parameter_dtype(self) + def dtype(self) -> torch.dtype: + return self.vision_model.embeddings.patch_embedding.weight.dtype @property - def device(self): - return get_parameter_device(self) + def device(self) -> torch.device: + return self.vision_model.embeddings.patch_embedding.weight.device @torch.inference_mode() def forward(self, pixel_values, attn_metadata: AttentionMetadata): diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py index 49e56c4d23d..e1a05752bbb 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py @@ -6,7 +6,6 @@ import torch.nn.functional as F from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from tqdm import 
tqdm -from transformers.modeling_utils import get_parameter_device from tensorrt_llm._torch.modules.layer_norm import LayerNorm from tensorrt_llm._torch.modules.linear import Linear @@ -544,8 +543,8 @@ def __init__( self.__post_init__() @property - def device(self): - return get_parameter_device(self) + def device(self) -> torch.device: + return self.patch_embedding.weight.device def __post_init__(self): self.apply_quant_config_exclude_modules() diff --git a/tensorrt_llm/models/gpt/convert.py b/tensorrt_llm/models/gpt/convert.py index 1e2bc4b999d..b8a7ffda56c 100644 --- a/tensorrt_llm/models/gpt/convert.py +++ b/tensorrt_llm/models/gpt/convert.py @@ -29,7 +29,7 @@ import torch.nn as nn import yaml from tqdm import tqdm -from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq, +from transformers import (AutoModelForCausalLM, AutoModelForImageTextToText, AutoTokenizer) from transformers.models.gpt2.modeling_gpt2 import GPT2Block from transformers.pytorch_utils import Conv1D @@ -910,7 +910,7 @@ def quantize(hf_model_dir: str, def load_hf_gpt(model_dir: str, load_model_on_cpu: bool = False): if 'kosmos-2' in model_dir: - hf_model = AutoModelForVision2Seq.from_pretrained( + hf_model = AutoModelForImageTextToText.from_pretrained( model_dir, trust_remote_code=True) else: hf_model = AutoModelForCausalLM.from_pretrained( diff --git a/tensorrt_llm/tools/multimodal_builder.py b/tensorrt_llm/tools/multimodal_builder.py index bf948eb2506..2adbd5109fa 100644 --- a/tensorrt_llm/tools/multimodal_builder.py +++ b/tensorrt_llm/tools/multimodal_builder.py @@ -15,7 +15,7 @@ from tensorrt_llm.builder import Builder from tensorrt_llm.logger import logger from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, - AutoModelForVision2Seq, AutoProcessor, + AutoModelForImageTextToText, AutoProcessor, Blip2ForConditionalGeneration, Blip2Processor, FuyuForCausalLM, FuyuProcessor, LlavaForConditionalGeneration, NougatProcessor, @@ -953,8 +953,8 @@ def forward(self, images): img_features, _ = self.connector(img_features) return img_features - model = AutoModelForVision2Seq.from_pretrained(args.model_path, - dtype=torch.float16) + model = AutoModelForImageTextToText.from_pretrained(args.model_path, + dtype=torch.float16) wrapper = VisionEncoderWrapper( model.vision_model.to(args.device), model.image_to_text_projection.to(args.device)) From 2c1ae85b58a69165fff6ebef65310050d4e4344b Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Mon, 13 Apr 2026 17:26:56 +0900 Subject: [PATCH 13/16] fix yampf Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- tensorrt_llm/serve/chat_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py index 6edfbc7dfe0..63cd542bc90 100644 --- a/tensorrt_llm/serve/chat_utils.py +++ b/tensorrt_llm/serve/chat_utils.py @@ -302,7 +302,6 @@ def parse_chat_messages_coroutines( # path calls `_build_openai_content`, which reconstructs `conv["content"]` from # `content_parts` - overwriting any STRING-style placeholders inserted here. # See also: `_resolve_content_format` (inputs/utils.py) for the full resolution used downstream. 
- model_type = model_config.model_type registry_format = MULTIMODAL_PLACEHOLDER_REGISTRY.get_content_format( type(model_config).model_type) if registry_format is not None: @@ -336,11 +335,13 @@ def parse_chat_messages_coroutines( type(model_config).model_type) if content_parts and interleave: parsed_msg["content"] = interleave_mm_placeholders( - type(model_config).model_type, content_parts, msg_placeholder_counts, + type(model_config).model_type, content_parts, + msg_placeholder_counts, mm_data_tracker.placeholder_modalities()) else: parsed_msg["content"] = add_multimodal_placeholders( - type(model_config).model_type, parsed_msg["content"], msg_placeholder_counts) + type(model_config).model_type, parsed_msg["content"], + msg_placeholder_counts) mm_placeholder_counts.append(msg_placeholder_counts) return conversation, mm_data_tracker.retrieve_all_async( From eb46c0515abcceae5f4f1551cf3b92fe9dc0093d Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Mon, 13 Apr 2026 18:41:22 +0900 Subject: [PATCH 14/16] remove NemotronH from Autoconfig register Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- tensorrt_llm/_torch/models/modeling_nemotron_h.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_h.py b/tensorrt_llm/_torch/models/modeling_nemotron_h.py index 623195da94a..9e78a508f2a 100644 --- a/tensorrt_llm/_torch/models/modeling_nemotron_h.py +++ b/tensorrt_llm/_torch/models/modeling_nemotron_h.py @@ -23,7 +23,7 @@ from tensorrt_llm.llmapi.llm_args import TorchLlmArgs from torch import nn -from transformers import AutoConfig, PretrainedConfig +from transformers import PretrainedConfig from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \ BaseWeightMapper @@ -1071,6 +1071,3 @@ def forward( lora_params=lora_params, ) return hidden_states - - -AutoConfig.register(NemotronHConfig.model_type, NemotronHConfig) From 97c4792307b5694561f8498b9b4c9e4b6ffdcb66 Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Mon, 13 Apr 2026 19:53:38 +0900 Subject: [PATCH 15/16] remove qwen3.5 moe from Autoconfig Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- .../_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py index e227bc7ebec..a0110e74c81 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py @@ -32,7 +32,6 @@ from PIL import Image from torch import nn from torch.export import Dim -from transformers import AutoConfig from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from transformers.generation import GenerationMixin @@ -2854,8 +2853,5 @@ def init_input_processor(self, base): # Registration # ============================================================================= -AutoConfig.register("qwen3_5_moe", Qwen3_5MoeConfig) -AutoConfig.register("qwen3_5_moe_text", Qwen3_5MoeTextConfig) - AutoModelForCausalLMFactory.register_custom_model_cls("Qwen3_5MoeTextConfig", Qwen3_5MoeForCausalLM) Qwen3_5MoeFactory.register_custom_model_cls("Qwen3_5MoeConfig", Qwen3_5MoeForConditionalGeneration) From 8d029d5b61f0b8b351bf1c026217e9af43b580e7 
Mon Sep 17 00:00:00 2001 From: yechank <161688079+yechank-nvidia@users.noreply.github.com> Date: Mon, 13 Apr 2026 23:32:20 +0900 Subject: [PATCH 16/16] remove HybridCache and fix Gemma3 Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com> --- tensorrt_llm/_torch/attention_backend/interface.py | 11 +++++++---- tensorrt_llm/_torch/models/modeling_gemma3.py | 7 ++++++- tensorrt_llm/functional.py | 3 +++ .../unittest/_torch/modeling/test_modeling_cohere2.py | 10 ++-------- .../unittest/_torch/modeling/test_modeling_exaone4.py | 8 ++------ .../unittest/_torch/modeling/test_modeling_gemma3.py | 8 ++------ 6 files changed, 22 insertions(+), 25 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index 600f655bc51..b35c97288cc 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -485,10 +485,13 @@ def from_config(config) -> "RopeParams": hf_rope_parameters = getattr(config, 'rope_parameters', None) if hf_rope_parameters is not None: - assert not set(hf_rope_parameters.keys()).issubset( - ALLOWED_ATTENTION_LAYER_TYPES), ( - "Per-layer-type RoPE configuration is not supported yet.") - config.update(hf_rope_parameters) + if not set(hf_rope_parameters.keys()).issubset( + ALLOWED_ATTENTION_LAYER_TYPES): + # Flat rope_parameters dict: merge into config directly. + config.update(hf_rope_parameters) + # Per-layer-type rope_parameters (e.g. Gemma3 in transformers>=5.x) + # are handled by model-specific logic (e.g. rope_local_base_freq), + # so skip merging here. # get rotary parameters. hidden_size = config.hidden_size diff --git a/tensorrt_llm/_torch/models/modeling_gemma3.py b/tensorrt_llm/_torch/models/modeling_gemma3.py index 24ba665afbf..f9513f7bccd 100644 --- a/tensorrt_llm/_torch/models/modeling_gemma3.py +++ b/tensorrt_llm/_torch/models/modeling_gemma3.py @@ -65,7 +65,12 @@ def __init__( rope_params = RopeParams.from_config(config) self.attention_window_size = None if is_sliding: - rope_params.theta = config.rope_local_base_freq + # transformers>=5.x moved rope_local_base_freq into rope_parameters + rope_params.theta = getattr( + config, 'rope_local_base_freq', + None) or (config.rope_parameters.get( + 'sliding_attention', {}).get('rope_theta', 10000) if + getattr(config, 'rope_parameters', None) else 10000) rope_params.scale_type = RotaryScalingType.none rope_params.scale = 1.0 self.attention_window_size = config.sliding_window diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index 5dd99755dc6..cc61ba08bb4 100755 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -680,6 +680,9 @@ class RotaryScalingType(IntEnum): @staticmethod def from_string(s): + # 'default' or None means standard RoPE with no scaling (transformers>=5.x) + if s is None or s == 'default': + return RotaryScalingType.none try: return RotaryScalingType[s] except KeyError: diff --git a/tests/unittest/_torch/modeling/test_modeling_cohere2.py b/tests/unittest/_torch/modeling/test_modeling_cohere2.py index 20c2e88fe69..29608be217a 100644 --- a/tests/unittest/_torch/modeling/test_modeling_cohere2.py +++ b/tests/unittest/_torch/modeling/test_modeling_cohere2.py @@ -3,7 +3,7 @@ import torch from transformers import Cohere2Config from transformers import Cohere2ForCausalLM as HFCohere2ForCausalLM -from transformers.cache_utils import HybridCache +from transformers.cache_utils import DynamicCache import tensorrt_llm from 
tensorrt_llm._torch.attention_backend.utils import get_attention_backend @@ -161,13 +161,7 @@ def test_cohere2_allclose_to_hf(self) -> None: # Initialize the hugging face model hf_cohere2 = HFCohere2ForCausalLM(cohere2_config).to(dtype).to(device).eval() - hf_cache = HybridCache( - config=cohere2_config, - max_batch_size=batch_size, - max_cache_len=10, - device=device, - dtype=dtype, - ) + hf_cache = DynamicCache(config=cohere2_config) # Initialize the TRT-LLM model model_config = ModelConfig(pretrained_config=cohere2_config) diff --git a/tests/unittest/_torch/modeling/test_modeling_exaone4.py b/tests/unittest/_torch/modeling/test_modeling_exaone4.py index 931828be848..7f85cdac2a3 100644 --- a/tests/unittest/_torch/modeling/test_modeling_exaone4.py +++ b/tests/unittest/_torch/modeling/test_modeling_exaone4.py @@ -26,7 +26,7 @@ class Exaone4Config(PretrainedConfig): SKIP_EXAONE4_HF_ACCURACY_TEST = True from _torch.helpers import create_mock_cuda_graph_runner -from transformers.cache_utils import HybridCache +from transformers.cache_utils import DynamicCache from utils.util import getSMVersion import tensorrt_llm @@ -248,11 +248,7 @@ def test_exaone4_allclose_to_hf(self, scenario: Scenario) -> None: num_kv_heads = exaone4.config.num_key_value_heads max_seq_len = num_blocks * tokens_per_block batch_size = 1 - hf_cache = HybridCache(config=exaone4_config, - max_batch_size=batch_size, - max_cache_len=max_seq_len, - device=device, - dtype=dtype) + hf_cache = DynamicCache(config=exaone4_config) if dtype == torch.half: kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF elif dtype == torch.bfloat16: diff --git a/tests/unittest/_torch/modeling/test_modeling_gemma3.py b/tests/unittest/_torch/modeling/test_modeling_gemma3.py index 6b532b9b1c6..eef132d3004 100644 --- a/tests/unittest/_torch/modeling/test_modeling_gemma3.py +++ b/tests/unittest/_torch/modeling/test_modeling_gemma3.py @@ -7,7 +7,7 @@ from transformers import Gemma3Config from transformers import Gemma3ForCausalLM as HFGemma3ForCausalLM from transformers import Gemma3TextConfig -from transformers.cache_utils import HybridCache +from transformers.cache_utils import DynamicCache import tensorrt_llm from tensorrt_llm._torch.attention_backend import (AttentionMetadata, @@ -285,11 +285,7 @@ def test_gemma3_allclose_to_hf(self, scenario: Scenario) -> None: hf_gemma3 = HFGemma3ForCausalLM(gemma3_config).to(dtype).to( device).eval() - hf_cache = HybridCache(config=gemma3_config, - max_batch_size=batch_size, - max_cache_len=10, - device=device, - dtype=dtype) + hf_cache = DynamicCache(config=gemma3_config) model_config = ModelConfig(pretrained_config=gemma3_config, attn_backend=backend)
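
For readers double-checking the Gemma3 RoPE change in PATCH 16/16 against transformers 5.x configs, the nested conditional added in modeling_gemma3.py can be hard to parse. Below is a standalone Python sketch (not part of the patch) of the intended resolution order for the local-attention theta; the OldStyle/NewStyle namespaces and the numeric values are hypothetical stand-ins for the two config layouts, not real transformers classes.

    from types import SimpleNamespace


    def resolve_local_rope_theta(config) -> float:
        # Same precedence as the patched Gemma3 attention __init__: a truthy legacy
        # `rope_local_base_freq` wins; otherwise fall back to the per-layer-type
        # rope_parameters["sliding_attention"]["rope_theta"]; otherwise 10000.
        legacy = getattr(config, "rope_local_base_freq", None)
        if legacy:
            return legacy
        rope_parameters = getattr(config, "rope_parameters", None)
        if rope_parameters:
            return rope_parameters.get("sliding_attention", {}).get("rope_theta", 10000)
        return 10000


    # transformers<5.x style: flat attribute on the text config.
    old_style = SimpleNamespace(rope_local_base_freq=10000.0)
    # transformers>=5.x style: per-layer-type rope_parameters dict.
    new_style = SimpleNamespace(
        rope_parameters={
            "sliding_attention": {"rope_theta": 10000.0},
            "full_attention": {"rope_theta": 1_000_000.0},
        }
    )

    assert resolve_local_rope_theta(old_style) == 10000.0
    assert resolve_local_rope_theta(new_style) == 10000.0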