Skip to content

Conversation

@xiaoxi-s
Copy link
Collaborator

@xiaoxi-s xiaoxi-s commented Jan 2, 2026

Motivation

Support multiple Llama3DecoderLayers in LlamaForCausalLMEagle3.

Modifications

A LlamaForCausalLMEagle3 instance can now hold a arbitrary number (>=1) of Llama3DecoderLayerss configured via num_hidden_layers option of LlamaConfig. The test for the new config is also updated in tests/test_modeling/test_draft/test_llama3.py.

Related Issues

Sub-issue of #374

Accuracy Test

Tested under tests/test_modeling/test_draft/test_llama3.py.

Benchmark & Profiling

N/A

Checklist

@xiaoxi-s xiaoxi-s self-assigned this Jan 2, 2026
@xiaoxi-s xiaoxi-s added the enhancement New feature or request label Jan 2, 2026
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @xiaoxi-s, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the flexibility of the Llama3 draft decoder by enabling the configuration of multiple Multi-Layer Perceptrons (MLPs) within each decoder layer. Previously limited to a single MLP, the model can now be instantiated with a user-defined number of MLPs, allowing for more complex and potentially more powerful draft model architectures. The changes include adding a new command-line option for specifying the number of MLPs, updating the core model classes to handle this configuration, and ensuring comprehensive test coverage for the new functionality.

Highlights

  • Multi-layer MLP Support: The Llama3 draft decoder now supports an arbitrary number of Multi-Layer Perceptrons (MLPs) within its LlamaDecoderLayer, configurable via a new parameter.
  • New Command-Line Argument: A --num-draft-hidden-layers argument has been added to scripts/train_eagle3.py to specify the desired number of MLPs for the draft model decoder.
  • Architectural Updates: The LlamaDecoderLayer and LlamaForCausalLMEagle3 classes have been modified to incorporate and utilize the num_draft_hidden_layers parameter, replacing a single MLP with a nn.Sequential module containing multiple MLPs.
  • Test Coverage: Existing unit tests in tests/test_modeling/test_draft/test_llama3.py have been updated to validate the correct initialization and behavior of the model with the new multi-MLP configuration.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request successfully adds support for a multi-layer MLP in the draft decoder, with corresponding changes to the training script and tests. The implementation is consistent and the new functionality is well-tested. I have provided a couple of suggestions to improve the architecture of the multi-layer MLP for better training stability by incorporating residual connections for each MLP block. Additionally, I've pointed out a minor indentation issue in one of the test files.

I am having trouble creating individual review comments. Click here to see my feedback.

specforge/modeling/draft/llama3_eagle.py (1252)

medium

Using nn.Sequential here creates a deep MLP stack without any intermediate normalization or residual connections. This can lead to training instability, especially if num_draft_hidden_layers is large. A more standard and stable approach is to use nn.ModuleList to treat each MLP as a separate block. I'll add a related suggestion on the forward method to complete this change.

        self.mlps = nn.ModuleList([LlamaMLP(config) for _ in range(num_draft_hidden_layers)])

specforge/modeling/draft/llama3_eagle.py (1308-1311)

medium

To accompany the change to nn.ModuleList for self.mlps, the forward pass should be updated to loop through the MLPs. This applies a residual connection around each MLP block, which is a more standard and stable architecture than a single deep MLP. This implementation reuses the same post_attention_layernorm for each block.

        for mlp in self.mlps:
            residual = hidden_states
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = mlp(hidden_states)
            hidden_states = residual + hidden_states

tests/test_modeling/test_draft/test_llama3.py (63)

medium

This line has incorrect indentation. According to the PEP 8 style guide, indentation should be in multiples of 4 spaces. The line inside the for loop should be indented by 4 more spaces relative to the for statement.

            self.assertIsInstance(model.midlayer.mlps[i], LlamaMLP)

Copy link

@Dogacel Dogacel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess for true multi-layer architecture, we should duplicate the decoder, not MLP. Why you wanted to duplicate the MLP instead?

choices=["sglang", "hf", "custom"],
help="The backend of the target model",
)
model_group.add_argument(
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This variable can be already set using the model config,

"num_hidden_layers": 1,

Which allows us to save & load models directly with the right number of hidden layers.

@xiaoxi-s xiaoxi-s changed the title Multi-layer MLP for Draft Decoder Multi-layer Decoder for Llama3 Draft Model Jan 11, 2026
@sleepcoo
Copy link
Collaborator

@uygnef

@uygnef
Copy link
Collaborator

uygnef commented Jan 15, 2026

@sleepcoo LGTM

@justadogistaken
Copy link
Contributor

hidden_states = torch.cat((input_emb, hidden_states), dim=-1)

Your implementation will do input_emb & hidden_states fusion every layer. May I ask why not just do fuse them in first layer. Them let others handle hidden_states only.

@xiaoxi-s
Copy link
Collaborator Author

Your implementation will do input_emb & hidden_states fusion every layer. May I ask why not just do fuse them in first layer. Them let others handle hidden_states only.

I believe the way of only iterating on hidden_states on later layers exists mostly in traditional transformer architectures. For the decoder layers used in draft model, this implementation keeps the initial token info available across all decoder layers for better speculative decoding instead of only next-token generation.

@Dogacel
Copy link

Dogacel commented Jan 16, 2026

hidden_states = torch.cat((input_emb, hidden_states), dim=-1)

Your implementation will do input_emb & hidden_states fusion every layer. May I ask why not just do fuse them in first layer. Them let others handle hidden_states only.

I agree, I think input embedding should only be injected in the first layer. Original EAGLE only has 1 decoder layer therefore it didn't matter for them where they have done this injection. Injecting the same data to each layer would result in excessive computations.

However changing this means you have to update some code to make sure your input/output shapes match. For example attention takes 2 * hidden_size and outputs hidden_size.

I think it would be nice to configure if you want to map from (2 * hidden_size -> hidden_size) and save compute on deeper layers, or if you want to keep the full 2 * hidden_size shape and be more expressive.

self.q_proj = nn.Linear(
self.hidden_size * 2, self.num_heads * self.head_dim, bias=False
)

@uygnef
Copy link
Collaborator

uygnef commented Jan 21, 2026

@xiaoxi-s Hi, xiaoxi.
A reviewer noted that input_emb is concatenated to hidden_states in every layer. Could you share the rationale behind this design?

It seems fusing them only in the first layer could be more efficient and less redundant. Have you experimented with or compared both approaches? If not, could we test the alternative to verify performance?

Appreciate your input.

@xiaoxi-s
Copy link
Collaborator Author

xiaoxi-s commented Jan 22, 2026

It seems fusing them only in the first layer could be more efficient and less redundant. Have you experimented with or compared both approaches? If not, could we test the alternative to verify performance?

Appreciate your input.

Working on the experiments. Will share after it's done.

#449 contains the implementation of the more traditional decoder layer as the additional layers. Will share their benchmarks on SharedGPT after the experiment is done.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants