Conversation


@ardenma ardenma commented Jan 18, 2026

Motivation

Finishes #251, adding Eagle3 support for Qwen3-VL. I haven't extensively tried to optimize benchmark performance, but the accuracy/performance numbers at the bottom show that training works.

Modifications

@dcw02 was kind enough to do most of the plumbing and mrope work. I just added a few fixes, plus one important change: auxiliary hidden state layers now default to starting at layer 3 (e.g., [3, 17, 32] for the 8B model) to avoid Qwen3-VL's deepstack layers (0-2), where the timing of vision feature injection differs between training and serving. At inference time, sglang captures auxiliary hidden states after the residual addition but before the deepstack features are added (https://github.com/sgl-project/sglang/blob/53609e5e5b2aa00eb60c9a9a61d9b31b38aa0067/python/sglang/srt/models/qwen3_vl.py#L603-L627), whereas SpecForge training with HF captures the layer outputs, which contain hidden_states + residual + deepstack.
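
For reference, here is a minimal sketch of how such a deepstack-aware default could be computed. The helper name and the exact middle/late layer choices are illustrative assumptions rather than SpecForge's actual API (the PR's default for the 8B model is [3, 17, 32]):

def default_aux_hidden_state_layers(num_hidden_layers: int, num_deepstack_layers: int = 3) -> list[int]:
    # Pick early/middle/late auxiliary layers for Eagle3, starting the early
    # layer after Qwen3-VL's deepstack layers (0-2) so training (HF) and
    # serving (sglang) capture comparable hidden states.
    early = num_deepstack_layers       # first post-deepstack layer, e.g. 3
    middle = num_hidden_layers // 2    # a mid-depth layer
    late = num_hidden_layers - 4       # a layer near (but not at) the top
    return [early, middle, late]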

Also pins nvidia-cudnn-cu12==9.16.0.29 because of a performance regression affecting Conv3d in PyTorch 2.9.1 (pytorch/pytorch#168167).

Related Issues

Requires these sglang changes: sgl-project/sglang#17276

Accuracy Test

Wandb metrics for Qwen3VL 8B trained on 100k samples of allava4v and DP=8:
[wandb training metric screenshots]

Benchmark & Profiling

Sglang (with changes from sgl-project/sglang#17276) command: python3 -m sglang.launch_server --model Qwen/Qwen3-VL-8B-Instruct --speculative-algorithm EAGLE3 --speculative-draft-model-path ../SpecForge/outputs/Qwen3-VL-8B-eagle3/epoch_11_step_140000 --speculative-num-steps 6 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --mem-fraction-static 0.75 --cuda-graph-max-bs 1 --tp 1 --trust-remote-code --host 0.0.0.0 --port 30000 --dtype bfloat16
Benchmark command: python3 benchmarks/bench_eagle3.py --model-path Qwen/Qwen3-VL-8B-Instruct --port 30000 --config-list 1,0,0,0 1,3,1,4 --benchmark-list mmstar:100 --skip-launch-server

{
    "mmstar": [
        {
            "batch_size": 1,
            "steps": null,
            "topk": null,
            "num_draft_tokens": null,
            "metrics": [
                {
                    "latency": 108.34728950168937,
                    "output_throughput": 172.75928254493292,
                    "accept_length": 2.1631803998613197,
                    "accuracy": 0.24,
                    "num_questions": 100,
                    "num_valid_predictions": 56,
                    "categorical_performance": null
                }
            ],
            "num_samples": 100
        }
    ]
}

Checklist

@gemini-code-assist

Summary of Changes

Hello @ardenma, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces full support for Qwen3-VL and Qwen3-VL-Moe models within the Eagle3 training and inference pipeline. The changes encompass new model configurations, dedicated training scripts, and crucial adaptations to the core Eagle3 logic to correctly handle Qwen3-VL's multimodal architecture, including its unique vision feature injection and positional encoding. Additionally, a specific CUDA library version is pinned to address a performance issue.

Highlights

  • Qwen3VL Eagle3 Integration: Added comprehensive support for Qwen3-VL and Qwen3-VL-Moe models within the Eagle3 framework, including new configuration files and training scripts.
  • Auxiliary Hidden State Layer Adjustment: Modified the selection of auxiliary hidden state layers for Qwen3-VL models to start from layer 3, specifically avoiding the deepstack layers (0-2) where vision feature injection timing differs between training and serving.
  • Interleaved Multi-Rotary Position Embedding (m-RoPE): Implemented support for interleaved m-RoPE, a specialized positional encoding scheme, for Qwen3-VL models to handle their unique multimodal input structure (a conceptual sketch follows this list).
  • CUDA Dependency Pinning: Pinned the nvidia-cudnn-cu12 version to 9.16.0.29 to mitigate a known performance regression affecting Conv3d operations in PyTorch 2.9.1.
  • Enhanced VLM Data Handling: Updated the QwenVLOnlineEagle3Model to correctly process video-related arguments (pixel_values_videos, video_grid_thw, second_per_grid_ts) for Qwen3-VL models, ensuring proper data preparation for multimodal inputs.
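
To make the interleaved m-RoPE item above concrete, here is a conceptual sketch of the frequency layout, assuming per-axis rotary angles have already been computed. The function name and section sizes are illustrative assumptions, not the code in this PR:

import torch

def interleave_mrope_freqs(freqs_thw: torch.Tensor, mrope_section=(24, 20, 20)) -> torch.Tensor:
    # freqs_thw: (3, ..., dim/2) rotary angles computed from the temporal,
    # height, and width position ids. Chunked m-RoPE (Qwen2-VL style)
    # concatenates contiguous blocks [t...|h...|w...]; the interleaved variant
    # strides the axes across dimensions instead: t, h, w, t, h, w, ...
    h_len = mrope_section[1] * 3
    w_len = mrope_section[2] * 3
    out = freqs_thw[0].clone()                          # start from temporal angles
    out[..., 1:h_len:3] = freqs_thw[1][..., 1:h_len:3]  # every 3rd dim, offset 1: height
    out[..., 2:w_len:3] = freqs_thw[2][..., 2:w_len:3]  # every 3rd dim, offset 2: width
    return out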



@ardenma ardenma changed the title from "Add Qwen3VL Eagle3 Support" to "Add Qwen3VL Eagle3 Training Support" on Jan 18, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for Qwen3-VL models (both dense and MoE variants) to the Eagle3 speculative decoding framework. The changes are comprehensive, including new model configurations, example training scripts, and necessary modifications to the core training and modeling logic to accommodate the specifics of Qwen3-VL, such as its video capabilities and interleaved m-RoPE. The addition of a pinned nvidia-cudnn-cu12 dependency to address a known performance regression is also a thoughtful inclusion. My review focuses on a few areas for improvement, including a typo in a script path, hardcoded paths in an example script, and several opportunities for code simplification and cleanup by removing redundant logic and unused code. Overall, this is a solid contribution that can be made even better by addressing these minor issues.

Comment on lines 277 to 306
elif (
    args.is_vlm
    and draft_model_config.target_model_type == "qwen3_vl"
    and args.tp_size == 1
):
    from transformers import Qwen3VLForConditionalGeneration

    target_model = (
        Qwen3VLForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path=args.target_model_path,
            dtype=torch.bfloat16,
        )
        .eval()
        .cuda()
    )
elif (
    args.is_vlm
    and draft_model_config.target_model_type == "qwen3_vl_moe"
    and args.tp_size == 1
):
    from transformers import Qwen3VLMoeForConditionalGeneration

    target_model = (
        Qwen3VLMoeForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path=args.target_model_path,
            dtype=torch.bfloat16,
        )
        .eval()
        .cuda()
    )

Severity: medium

The logic for loading different Qwen3-VL models is duplicated across two elif blocks. While this works, you could refactor this to reduce code duplication. For instance, you could map the target_model_type to the corresponding model class and then have a single block for loading.
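
A sketch of that refactor, reusing the names from the snippet above (illustrative only, not the PR's actual code):

import torch

if args.is_vlm and args.tp_size == 1:
    from transformers import (
        Qwen3VLForConditionalGeneration,
        Qwen3VLMoeForConditionalGeneration,
    )

    vlm_classes = {
        "qwen3_vl": Qwen3VLForConditionalGeneration,
        "qwen3_vl_moe": Qwen3VLMoeForConditionalGeneration,
    }
    model_cls = vlm_classes.get(draft_model_config.target_model_type)
    if model_cls is not None:
        # Single loading path shared by both Qwen3-VL variants.
        target_model = (
            model_cls.from_pretrained(
                pretrained_model_name_or_path=args.target_model_path,
                dtype=torch.bfloat16,
            )
            .eval()
            .cuda()
        )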

Comment on lines 328 to 331
filtered_target_kwargs = {}
for key, value in target_kwargs.items():
if key in {"input_ids", "attention_mask", "output_hidden_states", "use_cache"} or value is not None:
filtered_target_kwargs[key] = value

Severity: medium

This loop for filtering None values from the keyword arguments can be written more concisely using a dictionary comprehension, which would improve readability.

Suggested change
-filtered_target_kwargs = {}
-for key, value in target_kwargs.items():
-    if key in {"input_ids", "attention_mask", "output_hidden_states", "use_cache"} or value is not None:
-        filtered_target_kwargs[key] = value
+filtered_target_kwargs = {
+    key: value
+    for key, value in target_kwargs.items()
+    if value is not None or key in {"input_ids", "attention_mask", "output_hidden_states", "use_cache"}
+}

Comment on lines +503 to +507
if self.target_model_type in {"qwen3_vl", "qwen3_vl_moe"}:
    get_rope_kwargs["video_grid_thw"] = video_grid_thw
else:
    get_rope_kwargs["video_grid_thw"] = video_grid_thw
    get_rope_kwargs["second_per_grid_ts"] = second_per_grid_ts

Severity: medium

The argument video_grid_thw is added to get_rope_kwargs in both the if and else branches, which is redundant. You can simplify the code by moving this assignment outside the conditional block.

Suggested change
-if self.target_model_type in {"qwen3_vl", "qwen3_vl_moe"}:
-    get_rope_kwargs["video_grid_thw"] = video_grid_thw
-else:
-    get_rope_kwargs["video_grid_thw"] = video_grid_thw
-    get_rope_kwargs["second_per_grid_ts"] = second_per_grid_ts
+get_rope_kwargs["video_grid_thw"] = video_grid_thw
+if self.target_model_type not in {"qwen3_vl", "qwen3_vl_moe"}:
+    get_rope_kwargs["second_per_grid_ts"] = second_per_grid_ts

Comment on lines 18 to 19
Qwen3VLConfig,
Qwen3VLMoeConfig,

Severity: medium

The imports for Qwen3VLConfig and Qwen3VLMoeConfig appear to be unused in this file. To maintain code cleanliness, it's best to remove any unused imports.

Comment on lines +578 to +580
# can optionally be rewritten as:
# if position_ids.ndim == 2:
# position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

Severity: medium

This commented-out code block appears to be a developer note. It should be removed from the final version of the code to improve clarity.

@ardenma ardenma force-pushed the arden/qwen3-vl-eagle3 branch from 55b9cdb to 8ab4467 Compare January 18, 2026 00:28
@ardenma ardenma marked this pull request as ready for review January 18, 2026 00:45
@dcw02 dcw02 mentioned this pull request Jan 18, 2026
@ardenma

ardenma commented Jan 19, 2026

@sleepcoo fixed the lint issues, sorry about that.

@ardenma ardenma force-pushed the arden/qwen3-vl-eagle3 branch from c5fc2a1 to eb08c3a Compare January 19, 2026 16:50
@ardenma

ardenma commented Jan 19, 2026

@sleepcoo thanks for bearing with me as I figure out the SpecForge review/testing process. I fixed the uv lock issue, ran pre-commit run --all-files
[screenshot: pre-commit output]
and ran the unit tests with python -m unittest:
[screenshot: unit test output]

@narutolhy

Hi @ardenma, thank you for your work. I'm also training the Qwen3-VL spec model. Could you share your environment and scripts? Using the example scripts I hit out-of-memory errors with tp=1, and it only runs normally with tp=4, but accuracy is only 0.3 on the same dataset. Training is also very slow: only 300 steps in 10 hours with 4 GPUs and tp=4.

@ardenma

ardenma commented Jan 20, 2026

@narutolhy as mentioned in one of the comments above, the cudnn version that ships with torch==2.9.1 has a regression that causes slowdowns and large memory usage; try pip install nvidia-cudnn-cu12==9.16.0.29. I was running with DP=8, TP=1 (on big GPUs).
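
If it helps, a quick sanity check (a sketch) that the pinned cuDNN is the one PyTorch actually loads; the version integer encodes major*10000 + minor*100 + patch:

import torch

print(torch.__version__)               # e.g. 2.9.1
print(torch.backends.cudnn.version())  # expect 91600 for cuDNN 9.16.0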
