Add Eagle3 training for more MLLM model #302
Conversation
Summary of Changes
Hello @KerwinKai, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the Eagle3 training framework by introducing a more abstract and generalized approach to supporting Multimodal Large Language Models (MLLMs). The changes consolidate MLLM-specific logic into common components, streamline the loading of various MLLM architectures, and improve data preprocessing capabilities to handle a broader range of multimodal inputs. This refactoring allows for easier integration and training of new MLLM models, moving beyond the previously limited support for Qwen2.5-VL.
Code Review
This pull request effectively refactors the training pipeline to support a wider range of Multimodal Large Language Models (MLLMs) for Eagle3 training. The key changes, such as removing the specialized QwenVLOnlineEagle3Model and generalizing HFEagle3TargetModel and preprocess_vlm_conversations, significantly improve the modularity and maintainability of the codebase. This abstraction makes it much easier to add new MLLMs in the future.
My review includes suggestions to further enhance the scalability of the model loading mechanism and to clean up some minor redundancies in the code. Overall, this is a great step towards a more flexible and extensible MLLM training framework.
specforge/data/preprocessing.py
Outdated
except ImportError:
    process_vision_info = None
    raise ImportError(
        "qwen_vl_utils is required for MLLM preprocessing but is not installed. "
        "Please install it to use MLLM features."
    )
The assignment process_vision_info = None is redundant as an ImportError is raised on the next line, which will exit the function. You can safely remove this line for cleaner code.
except ImportError:
    raise ImportError(
        "qwen_vl_utils is required for MLLM preprocessing but is not installed. "
        "Please install it to use MLLM features."
    )
specforge/data/preprocessing.py
Outdated
except ImportError:
    process_mm_info = None
    raise ImportError(
        "qwen_omni_utils is required for MLLM preprocessing but is not installed. "
        "Please install it to use MLLM features."
    )
Similar to the previous comment, the assignment process_mm_info = None is redundant because an ImportError is raised immediately after. This line can be removed to improve code clarity.
except ImportError:
    raise ImportError(
        "qwen_omni_utils is required for MLLM preprocessing but is not installed. "
        "Please install it to use MLLM features."
    )

_mllm_model_pool = [
    "Qwen2_5_VLForConditionalGeneration",
    "Qwen3OmniMoeForConditionalGeneration",
]
To make adding new MLLM models more scalable, consider using a dictionary to map architecture names to their loader information instead of a list. This is more descriptive and can be used to simplify the model loading logic, as I'll suggest in another comment. The todo comment can be moved here to keep it close to the relevant model.
Suggested change:
- _mllm_model_pool = [
-     "Qwen2_5_VLForConditionalGeneration",
-     "Qwen3OmniMoeForConditionalGeneration",
- ]
+ _mllm_loader_map = {
+     "Qwen2_5_VLForConditionalGeneration": ("transformers", "Qwen2_5_VLForConditionalGeneration"),
+     # todo: change load method from `modelscope` to `transformers` after new version release
+     "Qwen3OmniMoeForConditionalGeneration": ("modelscope", "Qwen3OmniMoeThinkerForConditionalGeneration"),
+ }
if architecture in cls._mllm_model_pool:
    if architecture == "Qwen2_5_VLForConditionalGeneration":
        from transformers import Qwen2_5_VLForConditionalGeneration

        auto_model_loader = Qwen2_5_VLForConditionalGeneration
    elif architecture == "Qwen3OmniMoeForConditionalGeneration":
        # todo: change load method from `modelscope` to `transformers` after new version release
        from modelscope import Qwen3OmniMoeThinkerForConditionalGeneration

        auto_model_loader = Qwen3OmniMoeThinkerForConditionalGeneration
else:
    auto_model_loader = AutoModelForCausalLM
Following up on my previous comment, you can now use the _mllm_loader_map to dynamically import and select the model loader. This removes the if/elif chain and makes the code more maintainable and easier to extend with new models, which aligns with the goal of this PR.
Suggested change:
- if architecture in cls._mllm_model_pool:
-     if architecture == "Qwen2_5_VLForConditionalGeneration":
-         from transformers import Qwen2_5_VLForConditionalGeneration
-         auto_model_loader = Qwen2_5_VLForConditionalGeneration
-     elif architecture == "Qwen3OmniMoeForConditionalGeneration":
-         # todo: change load method from `modelscope` to `transformers` after new version release
-         from modelscope import Qwen3OmniMoeThinkerForConditionalGeneration
-         auto_model_loader = Qwen3OmniMoeThinkerForConditionalGeneration
- else:
-     auto_model_loader = AutoModelForCausalLM
+ if architecture in cls._mllm_loader_map:
+     import importlib
+     module_name, class_name = cls._mllm_loader_map[architecture]
+     try:
+         module = importlib.import_module(module_name)
+         auto_model_loader = getattr(module, class_name)
+     except ImportError:
+         raise ImportError(
+             f"Failed to import {class_name} from {module_name}. "
+             f"Please ensure the required packages for {architecture} are installed."
+         )
+ else:
+     auto_model_loader = AutoModelForCausalLM
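On the design choice: with such a loader map, supporting an additional MLLM becomes a single new dictionary entry instead of another elif branch. A hypothetical illustration (the final entry is an unverified placeholder for a future architecture):

_mllm_loader_map = {
    "Qwen2_5_VLForConditionalGeneration": ("transformers", "Qwen2_5_VLForConditionalGeneration"),
    # todo: change load method from `modelscope` to `transformers` after new version release
    "Qwen3OmniMoeForConditionalGeneration": ("modelscope", "Qwen3OmniMoeThinkerForConditionalGeneration"),
    # hypothetical placeholder entry; a new model would be added the same way
    "InternVLForConditionalGeneration": ("transformers", "InternVLForConditionalGeneration"),
}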
Great job!!!!! This enables us to support all MLLM models. @FrankLeeeee
support qwen3-vl & refactor preprocess
What is the current status of this PR?
Hi, this PR is now ready for review. Meanwhile, under @sleepcoo's guidance, we are also experimenting with training the draft model for Omni.
Hi, thanks for your work. I've tried training a draft model for Qwen2.5-VL with your code. I found that there is no place to get position_ids from the target_model. get_rope_index gives 3D position_ids, and that works well when training a Qwen VL model. If we don't do it, I guess the position_ids will be 2D. However …
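For reference, a minimal sketch of the distinction being described, assuming a recent transformers release where Qwen2_5_VLForConditionalGeneration exposes get_rope_index (the helper functions below are illustrative, not SpecForge code):

import torch

def default_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    # Plain-LLM path: cumulative positions over the attended tokens,
    # giving 2D position_ids of shape (batch, seq_len).
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 0)
    return position_ids

def qwen2_5_vl_position_ids(model, input_ids, image_grid_thw, attention_mask):
    # Qwen2.5-VL path: get_rope_index returns mrope position_ids with a
    # leading (temporal, height, width) axis, i.e. shape (3, batch, seq_len).
    position_ids, _mrope_deltas = model.get_rope_index(
        input_ids=input_ids,
        image_grid_thw=image_grid_thw,
        attention_mask=attention_mask,
    )
    return position_ids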
Hi, I haven't modified the training logic for Qwen2.5-VL in this PR, but I’ll run tests as soon as possible to identify the issue.
Does this work for other VLMs which are not Qwen?
There is a conflict with the main branch.
I’ll resolve the merge conflicts as soon as possible and add support for DS-VL and InternVL based on the current architecture. From the design of the Omni model (see modeling_qwen3_omni_moe.py#L3976), our Eagle-based acceleration can only apply to the first-stage thinker model’s generate step, as the talker model already incorporates its own MTP module. Therefore, to effectively integrate Eagle into Omni’s inference pipeline within SGLang, we likely need to decouple the two-stage forward pass (thinker → talker). Given this, I recommend prioritizing the review of https://github.com/sgl-project/SpecForge/pull/251 first.
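To illustrate the two-stage pipeline described above, here is a purely illustrative sketch with placeholder callables (it is not the Qwen3-Omni or SGLang implementation):

from typing import Callable, List, Tuple

def omni_two_stage_generate(
    thinker_step: Callable[[List[int]], int],       # next text token from current prefix
    talker_step: Callable[[List[int]], List[int]],  # audio tokens from finished text
    prompt: List[int],
    max_new_tokens: int = 8,
) -> Tuple[List[int], List[int]]:
    text = list(prompt)
    for _ in range(max_new_tokens):
        # Stage 1 (thinker): the autoregressive text loop, which is the only
        # step an Eagle draft model could shorten via speculative decoding.
        text.append(thinker_step(text))
    # Stage 2 (talker): consumes the thinker output; it already has its own
    # MTP module, so decoupling the stages lets Eagle wrap stage 1 only.
    audio = talker_step(text)
    return text, audio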
Motivation
In addition to the previously supported Qwen2.5-VL, more multimodal language models have emerged, such as Omni, DS-VL, InternVL, and Qwen3-VL. An abstract solution needs to be designed to easily support training draft models for these new models.
Modifications
- QwenVLOnlineEagle3Model was removed because, for the draft model, only hidden states and the target are needed for SFT fine-tuning; pixel_values, image_grid_thw, etc. only need to be provided to the target model during the prepare_data stage. Therefore, whether the target is an LLM or an MLLM, the draft model has the same inputs and outputs, a structure that can be represented by OnlineEagle3Model.
- MLLM loading is generalized in the HFEagle3TargetModel class.
- A process_info function in preprocessing.py obtains input features for image, video, and audio across the different MLLM models.

Testing of this design is still ongoing. I will try to add support for more of the aforementioned MLLM models in this PR, and I would very much appreciate guidance on whether this class abstraction is concise enough. A sketch of the intended split is shown after this list.
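A hypothetical sketch of this split (class and argument names are illustrative, not the actual SpecForge code): multimodal tensors stop at the target model, while the draft model only ever sees hidden states and token targets, regardless of modality.

import torch
import torch.nn as nn

class TargetWrapper(nn.Module):
    """Runs the frozen target model; multimodal inputs stay on this side."""

    def __init__(self, target_model: nn.Module):
        super().__init__()
        self.target_model = target_model

    @torch.no_grad()
    def forward(self, input_ids, attention_mask, **mm_inputs):
        # mm_inputs may carry pixel_values, image_grid_thw, audio features, etc.
        out = self.target_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **mm_inputs,
        )
        return out.hidden_states[-1]  # (batch, seq_len, hidden)

class DraftStep(nn.Module):
    """Draft-model training step; identical for LLM and MLLM targets."""

    def __init__(self, draft_model: nn.Module):
        super().__init__()
        self.draft_model = draft_model

    def forward(self, hidden_states, target_ids):
        logits = self.draft_model(hidden_states)
        return nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), target_ids.view(-1)
        )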
Checklist