Conversation

@KerwinKai
Contributor

Motivation

In addition to the previously supported Qwen2.5-VL, more multimodal language models have emerged, such as Omni, DS-VL, InternVL, and Qwen3-VL. An abstraction needs to be designed so that draft models for these new models can be trained easily.

Modifications

  1. QwenVLOnlineEagle3Model was removed because, for the draft model, only hidden states and the target are needed for SFT fine-tuning; pixel_values, image_grid_thw, etc. only need to be provided to the target model during the prepare_data stage. Therefore, whether the target is an LLM or an MLLM, the draft model has the same inputs and outputs, a structure that OnlineEagle3Model already represents.
  2. The logic for loading MLLM models is likewise placed in the HFEagle3TargetModel class.
  3. preprocessing.py now supports a process_info function for different MLLM models to obtain input features for image, video, and audio (see the sketch after this list).
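A minimal sketch of the per-processor dispatch described in item 3, assuming the illustrative function name process_info and the branch structure shown here (the real SpecForge code may differ); the qwen_vl_utils and qwen_omni_utils calls follow their documented upstream APIs:

from transformers import ProcessorMixin


def process_info(processor: ProcessorMixin, conversations: list):
    """Return (images, videos, audios) input features for the given processor."""
    name = type(processor).__name__
    if name == "Qwen2_5_VLProcessor":
        # qwen_vl_utils extracts image/video inputs for Qwen2.5-VL conversations.
        from qwen_vl_utils import process_vision_info

        images, videos = process_vision_info(conversations)
        return images, videos, None
    if name == "Qwen3OmniMoeProcessor":
        # qwen_omni_utils additionally extracts audio for the Omni models.
        from qwen_omni_utils import process_mm_info

        audios, images, videos = process_mm_info(conversations, use_audio_in_video=True)
        return images, videos, audios
    # Plain LLMs carry no multimodal features.
    return None, None, None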

Testing of this design is still ongoing. I will try to add support for the aforementioned MLLM models in this PR, and I would very much appreciate guidance on whether this class abstraction is concise enough.

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @KerwinKai, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Eagle3 training framework by introducing a more abstract and generalized approach to supporting Multimodal Large Language Models (MLLMs). The changes consolidate MLLM-specific logic into common components, streamline the loading of various MLLM architectures, and improve data preprocessing capabilities to handle a broader range of multimodal inputs. This refactoring allows for easier integration and training of new MLLM models, moving beyond the previously limited support for Qwen2.5-VL.

Highlights

  • Abstracted MLLM Training: The specialized QwenVLOnlineEagle3Model has been removed, and the generic OnlineEagle3Model now uniformly handles the training of both Large Language Model (LLM) and Multimodal Large Language Model (MLLM) draft models by abstracting their input and output structures.
  • Centralized MLLM Loading: The HFEagle3TargetModel class has been enhanced to manage the loading of various MLLM architectures, including Qwen2.5-VL and Qwen3-Omni-Moe, providing a more unified approach to integrating different target models.
  • Enhanced Multimodal Data Preprocessing: The preprocessing.py module now dynamically supports different MLLM processors, such as Qwen2_5_VLProcessor and Qwen3OmniMoeProcessor, allowing for flexible handling of diverse multimodal input features including images, videos, and audio.
  • New Dependency Added: The qwen-omni-utils package has been added to the project's requirements, enabling support for the Qwen3-Omni-Moe model.

@KerwinKai changed the title from "Add Eagle3 training for more MLLM model" to "[WIP] Add Eagle3 training for more MLLM model" on Nov 15, 2025
Contributor

@gemini-code-assist (bot) left a comment

Code Review

This pull request effectively refactors the training pipeline to support a wider range of Multimodal Language Models (MLLMs) for Eagle3 training. The key changes, such as removing the specialized QwenVLOnlineEagle3Model and generalizing HFEagle3TargetModel and preprocess_vlm_conversations, significantly improve the modularity and maintainability of the codebase. This abstraction makes it much easier to add new MLLMs in the future.

My review includes suggestions to further enhance the scalability of the model loading mechanism and to clean up some minor redundancies in the code. Overall, this is a great step towards a more flexible and extensible MLLM training framework.

Comment on lines 245 to 250
except ImportError:
    process_vision_info = None
    raise ImportError(
        "qwen_vl_utils is required for MLLM preprocessing but is not installed. "
        "Please install it to use MLLM features."
    )

Severity: medium

The assignment process_vision_info = None is redundant as an ImportError is raised on the next line, which will exit the function. You can safely remove this line for cleaner code.

            except ImportError:
                raise ImportError(
                    "qwen_vl_utils is required for MLLM preprocessing but is not installed. "
                    "Please install it to use MLLM features."
                )

Comment on lines 256 to 261
except ImportError:
    process_mm_info = None
    raise ImportError(
        "qwen_omni_utils is required for MLLM preprocessing but is not installed. "
        "Please install it to use MLLM features."
    )

Severity: medium

Similar to the previous comment, the assignment process_mm_info = None is redundant because an ImportError is raised immediately after. This line can be removed to improve code clarity.

            except ImportError:
                raise ImportError(
                    "qwen_omni_utils is required for MLLM preprocessing but is not installed. "
                    "Please install it to use MLLM features."
                )

Comment on lines 145 to 148
_mllm_model_pool = [
    "Qwen2_5_VLForConditionalGeneration",
    "Qwen3OmniMoeForConditionalGeneration",
]

Severity: medium

To make adding new MLLM models more scalable, consider using a dictionary to map architecture names to their loader information instead of a list. This is more descriptive and can be used to simplify the model loading logic, as I'll suggest in another comment. The todo comment can be moved here to keep it close to the relevant model.

Suggested change
_mllm_model_pool = [
    "Qwen2_5_VLForConditionalGeneration",
    "Qwen3OmniMoeForConditionalGeneration",
]
_mllm_loader_map = {
    "Qwen2_5_VLForConditionalGeneration": ("transformers", "Qwen2_5_VLForConditionalGeneration"),
    # todo: change load method from `modelscope` to `transformers` after new version release
    "Qwen3OmniMoeForConditionalGeneration": ("modelscope", "Qwen3OmniMoeThinkerForConditionalGeneration"),
}

Comment on lines 181 to 192
if architecture in cls._mllm_model_pool:
    if architecture == "Qwen2_5_VLForConditionalGeneration":
        from transformers import Qwen2_5_VLForConditionalGeneration

        auto_model_loader = Qwen2_5_VLForConditionalGeneration
    elif architecture == "Qwen3OmniMoeForConditionalGeneration":
        # todo: change load method from `modelscope` to `transformers` after new version release
        from modelscope import Qwen3OmniMoeThinkerForConditionalGeneration

        auto_model_loader = Qwen3OmniMoeThinkerForConditionalGeneration
else:
    auto_model_loader = AutoModelForCausalLM

Severity: medium

Following up on my previous comment, you can now use the _mllm_loader_map to dynamically import and select the model loader. This removes the if/elif chain and makes the code more maintainable and easier to extend with new models, which aligns with the goal of this PR.

Suggested change
if architecture in cls._mllm_model_pool:
    if architecture == "Qwen2_5_VLForConditionalGeneration":
        from transformers import Qwen2_5_VLForConditionalGeneration
        auto_model_loader = Qwen2_5_VLForConditionalGeneration
    elif architecture == "Qwen3OmniMoeForConditionalGeneration":
        # todo: change load method from `modelscope` to `transformers` after new version release
        from modelscope import Qwen3OmniMoeThinkerForConditionalGeneration
        auto_model_loader = Qwen3OmniMoeThinkerForConditionalGeneration
else:
    auto_model_loader = AutoModelForCausalLM

if architecture in cls._mllm_loader_map:
    import importlib
    module_name, class_name = cls._mllm_loader_map[architecture]
    try:
        module = importlib.import_module(module_name)
        auto_model_loader = getattr(module, class_name)
    except ImportError:
        raise ImportError(
            f"Failed to import {class_name} from {module_name}. "
            f"Please ensure the required packages for {architecture} are installed."
        )
else:
    auto_model_loader = AutoModelForCausalLM
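With this map in place, supporting an additional MLLM becomes a single new dictionary entry. For illustration, assuming a hypothetical future entry (the commented line below is not part of this PR):

_mllm_loader_map = {
    "Qwen2_5_VLForConditionalGeneration": ("transformers", "Qwen2_5_VLForConditionalGeneration"),
    "Qwen3OmniMoeForConditionalGeneration": ("modelscope", "Qwen3OmniMoeThinkerForConditionalGeneration"),
    # hypothetical future addition, for illustration only:
    # "InternVLForConditionalGeneration": ("transformers", "InternVLForConditionalGeneration"),
}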

@sleepcoo
Collaborator

Great job!!!!! This enables us to support all MLLM models. @FrankLeeeee

@FrankLeeeee
Collaborator

What is the current status of this PR now?

@KerwinKai marked this pull request as ready for review on November 23, 2025
@KerwinKai changed the title from "[WIP] Add Eagle3 training for more MLLM model" to "Add Eagle3 training for more MLLM model" on Nov 23, 2025
@KerwinKai
Contributor Author

What is the current status of this PR now?

Hi, this PR is now ready for review. Meanwhile, under @sleepcoo's guidance, we are also experimenting with training the draft model for Omni.

@justadogistaken
Contributor

Hi, thanks for your work. I've tried training a draft model for qwen2_5-vl with this PR. I found that there is no place to get position_ids from the target_model. get_rope_index gives 3D position_ids, and it works well while training the Qwen-VL model. If we don't call it, I guess position_ids will be 2D; however, LlamaMutiRotaryEmbedding needs 3D position_ids.

position_ids, rope_deltas = self.target_model.model.get_rope_index(
    input_ids,
    image_grid_thw,
    None,
    second_per_grid_ts=None,
    attention_mask=attention_mask_tensor,
)
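For context, a hedged illustration (not from this PR) of the shape mismatch described above: Qwen2.5-VL's M-RoPE indexes positions along temporal, height, and width axes, so get_rope_index returns position_ids of shape (3, batch, seq_len), whereas a plain arange yields (batch, seq_len):

import torch

batch, seq_len = 2, 16

# 2D position_ids, as produced without get_rope_index:
pos_2d = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)  # (batch, seq_len)

# M-RoPE position_ids: one plane per (temporal, height, width) axis.
# For text-only tokens the three planes are identical; get_rope_index
# fills in the image-aware values from image_grid_thw.
pos_3d = pos_2d.unsqueeze(0).expand(3, -1, -1)  # (3, batch, seq_len)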

@KerwinKai
Contributor Author

Hi, thanks for your work. I've tried training a draft model for qwen2_5-vl with this PR. I found that there is no place to get position_ids from the target_model. get_rope_index gives 3D position_ids, and it works well while training the Qwen-VL model. If we don't call it, I guess position_ids will be 2D; however, LlamaMutiRotaryEmbedding needs 3D position_ids.

position_ids, rope_deltas = self.target_model.model.get_rope_index(
    input_ids,
    image_grid_thw,
    None,
    second_per_grid_ts=None,
    attention_mask=attention_mask_tensor,
)

Hi, I haven't modified the training logic for Qwen2.5-VL in this PR, but I’ll run tests as soon as possible to identify the issue.

@FrankLeeeee
Collaborator

Does this work for other VLMs which are not Qwen?

@FrankLeeeee
Collaborator

There is a conflict with the main branch.

@KerwinKai
Contributor Author

There is a conflict with the main branch.

I’ll resolve the merge conflicts as soon as possible and add support for DS-VL and InternVL based on the current architecture.

From the design of the Omni model (see modeling_qwen3_omni_moe.py#L3976), our Eagle-based acceleration can only be applied to the first-stage thinker model's generate step, since the talker model already incorporates its own MTP module.

Therefore, to effectively integrate Eagle into Omni’s inference pipeline within SGLang, we likely need to decouple the two-stage forward pass (thinker → talker). Given this, I recommend prioritizing the review of https://github.com/sgl-project/SpecForge/pull/251 first.
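For intuition, here is a minimal schematic of the two-stage flow being described, with hypothetical method names and signatures (the real Qwen3-Omni API in modeling_qwen3_omni_moe.py differs):

# Schematic only: Eagle speculative decoding can wrap the thinker's
# autoregressive generate step; the talker brings its own MTP module.
def omni_generate(model, inputs):
    # Stage 1: the thinker generates text tokens (Eagle-acceleratable).
    thinker_out = model.thinker.generate(**inputs)
    # Stage 2: the talker consumes the thinker's outputs to produce audio
    # through its built-in multi-token-prediction head (not Eagle-compatible).
    talker_out = model.talker.generate(thinker_out)
    return thinker_out, talker_out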

