Changes from all commits (25 commits)
c0ebfcf
implement multi-layer mlp for llama3 draft
xiaoxi-s Jan 1, 2026
fb5f389
update old hidden_layers comment in config utils.py
xiaoxi-s Jan 1, 2026
ae050ee
update unittest for draft llama3 for the num_draft_hidden_layers param
xiaoxi-s Jan 1, 2026
17e8e6f
reformat code with pre-commit hooks
xiaoxi-s Jan 2, 2026
8c886a1
revert old changes based on num_draft_hidden_layers
xiaoxi-s Jan 4, 2026
88620b4
llama3 draft model with multi-layer decoders implementation and tests
xiaoxi-s Jan 11, 2026
c02bc6d
precommit format
xiaoxi-s Jan 11, 2026
dff80df
Merge branch 'main' into llama3-mlp-for-draft-decoder
xiaoxi-s Jan 11, 2026
71e21cb
fix signature of past_key_values
xiaoxi-s Jan 11, 2026
e1f7fe6
fix specforge core online eagle3 model and QWen model
xiaoxi-s Jan 11, 2026
b9c228d
Merge branch 'main' into llama3-mlp-for-draft-decoder
xiaoxi-s Jan 14, 2026
e1e9695
Merge branch 'main' into llama3-mlp-for-draft-decoder
xiaoxi-s Jan 14, 2026
70d3cdc
add basic decoder layer for experiment
xiaoxi-s Jan 22, 2026
5c35e80
Merge branch 'main' into llama3-mlp-for-draft-decoder
xiaoxi-s Jan 23, 2026
4497a2d
Merge branch 'llama3-mlp-for-draft-decoder' into llama3-basic-multi-l…
xiaoxi-s Jan 26, 2026
7439d0d
use llamadecoderlayer for both fused and not fused input
xiaoxi-s Feb 7, 2026
12b4b99
reformat file
xiaoxi-s Feb 7, 2026
2a3969c
Merge branch 'main' into llama3-mlp-for-draft-decoder
xiaoxi-s Feb 7, 2026
95ab95a
simplify basic layer's implementation to work with sglang's loading l…
xiaoxi-s Feb 9, 2026
e4bfba0
rm redundant class
xiaoxi-s Feb 9, 2026
fd82bf2
rename eagle3 specific classes for clarity
xiaoxi-s Feb 9, 2026
1dea3cf
Merge branch 'llama3-mlp-for-draft-decoder' into llama3-basic-multi-l…
xiaoxi-s Feb 9, 2026
8912516
Merge pull request #449 from sgl-project/llama3-basic-multi-layer-dec…
xiaoxi-s Feb 28, 2026
47a6844
Merge branch 'main' into llama3-mlp-for-draft-decoder
xiaoxi-s Feb 28, 2026
d12fffb
fix ci errors
xiaoxi-s Feb 28, 2026
6 changes: 1 addition & 5 deletions scripts/train_dflash.py
@@ -33,11 +33,7 @@
from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead
from specforge.optimizer import BF16Optimizer
from specforge.tracker import create_tracker
from specforge.utils import (
get_last_checkpoint,
print_on_rank0,
print_with_rank,
)
from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank


def parse_args():
22 changes: 12 additions & 10 deletions specforge/core/eagle3.py
@@ -136,7 +136,7 @@ def forward(
target: torch.Tensor,
loss_mask: torch.Tensor,
hidden_states: torch.Tensor,
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
position_ids: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.Tensor] = None,
is_vlm: bool = False,
@@ -166,6 +166,7 @@ def forward(
batch_size, seq_length, _ = hidden_states.shape
seq_length_with_past = seq_length
past_key_values_length = 0
draft_num_hidden_layers = self.draft_model.num_hidden_layers

# Step 2: project the concatenated hidden states to the target hidden size
hidden_states = self.draft_model.project_hidden_states(hidden_states)
@@ -208,11 +209,11 @@ def forward(
# for sequence paralle, position mask and input ids will split by sequence dim, need to keep origin for ttt shift
global_input_ids = input_ids
if self.attention_backend in ["sdpa", "fa", "usp"]:
cache_hidden = [[], []]
caches_hidden = [[[], []] for _ in range(draft_num_hidden_layers)]
past_key_values = None
elif self.attention_backend == "flex_attention":
cache_hidden = None
past_key_values = DynamicCache()
caches_hidden = None
past_key_values = [DynamicCache() for _ in range(draft_num_hidden_layers)]
else:
raise ValueError(f"Unknown attention backend: {self.attention_backend}")

@@ -239,7 +240,7 @@ def forward(
hidden_states_out = self.draft_model.backbone(
input_embeds=inputs_embeds,
hidden_states=state.hidden_states,
cache_hidden=cache_hidden,
caches_hidden=caches_hidden,
attention_mask=state.attention_mask,
position_ids=state.position_ids,
past_key_values=past_key_values,
@@ -416,7 +417,7 @@ def forward(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
position_ids: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.Tensor] = None,
@@ -451,13 +452,14 @@ def forward(
batch_size, seq_length, _ = hidden_states.shape
seq_length_with_past = seq_length
past_key_values_length = 0
draft_num_hidden_layers = self.draft_model.num_hidden_layers

# Step 2: project the concatenated hidden states to the target hidden size
hidden_states = self.draft_model.project_hidden_states(hidden_states)

# Step 3: process kv cache, position ids and position ids
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
past_key_values_length = past_key_values[0][0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length

if position_ids is None:
@@ -507,11 +509,11 @@ def forward(
vlosses = []
acces = []
if self.attention_backend in ["sdpa", "fa"]:
cache_hidden = [[], []]
caches_hidden = [[[], []] for _ in range(draft_num_hidden_layers)]
past_key_values = None
elif self.attention_backend == "flex_attention":
cache_hidden = None
past_key_values = DynamicCache()
past_key_values = [DynamicCache() for _ in range(draft_num_hidden_layers)]
else:
raise ValueError(f"Unknown attention backend: {self.attention_backend}")

@@ -528,7 +530,7 @@ def forward(
hidden_states_out = self.draft_model.backbone(
input_embeds=inputs_embeds,
hidden_states=hidden_states,
cache_hidden=cache_hidden,
caches_hidden=caches_hidden,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
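A note on the cache handling above: the single `cache_hidden = [[], []]` pair, which the attention backends use to accumulate key and value states across TTT steps, becomes one pair per draft decoder layer, and the flex_attention path likewise moves from a single `DynamicCache` to one per layer. A minimal sketch of that per-layer setup, assuming the same backend names as in the diff (the helper function itself is illustrative and not part of the PR):

```python
from transformers.cache_utils import DynamicCache


def init_draft_caches(attention_backend: str, num_layers: int):
    """Illustrative helper: build one cache container per draft decoder layer."""
    if attention_backend in ("sdpa", "fa", "usp"):
        # One [cached_keys, cached_values] pair of plain Python lists per layer.
        caches_hidden = [[[], []] for _ in range(num_layers)]
        past_key_values = None
    elif attention_backend == "flex_attention":
        # flex_attention uses HuggingFace caches instead: one DynamicCache per layer.
        caches_hidden = None
        past_key_values = [DynamicCache() for _ in range(num_layers)]
    else:
        raise ValueError(f"Unknown attention backend: {attention_backend}")
    return caches_hidden, past_key_values


caches_hidden, past_key_values = init_draft_caches("sdpa", num_layers=3)
assert len(caches_hidden) == 3 and past_key_values is None
```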
6 changes: 3 additions & 3 deletions specforge/modeling/draft/base.py
@@ -24,7 +24,7 @@
import json
import os
from abc import ABC, abstractmethod
from typing import Optional
from typing import List, Optional

import torch
from huggingface_hub import snapshot_download
@@ -98,10 +98,10 @@ def backbone(
self,
input_embeds: torch.Tensor,
hidden_states: torch.Tensor,
cache_hidden: torch.Tensor,
caches_hidden: List[List[List[torch.Tensor]]],
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: Optional[Cache] = None,
past_key_values: Optional[List[Cache]] = None,
use_cache: bool = True,
) -> torch.Tensor:
"""
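The triply nested annotation `List[List[List[torch.Tensor]]]` in the updated signature is easiest to read from the outside in. The short sketch below spells out the indexing convention implied by the rest of the PR; the key/value interpretation follows the attention implementations, and the example value is a placeholder:

```python
num_hidden_layers = 3  # example value only

# caches_hidden[layer_idx]     -> cache container for one draft decoder layer
# caches_hidden[layer_idx][0]  -> list of cached key tensors for that layer
# caches_hidden[layer_idx][1]  -> list of cached value tensors for that layer
caches_hidden = [[[], []] for _ in range(num_hidden_layers)]

# past_key_values mirrors the same per-layer layout: one Cache object per layer,
# so an implementation can hand past_key_values[layer_idx] to layer layer_idx.
```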
119 changes: 87 additions & 32 deletions specforge/modeling/draft/llama3_eagle.py
@@ -511,7 +511,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config):
def __init__(self, config, fused_input=True):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
@@ -524,14 +524,19 @@ def __init__(self, config):
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings

if fused_input:
in_features_size = self.hidden_size * 2
else:
in_features_size = self.hidden_size

self.q_proj = nn.Linear(
self.hidden_size * 2, self.num_heads * self.head_dim, bias=False
in_features_size, self.num_heads * self.head_dim, bias=False
)
self.k_proj = nn.Linear(
self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False
in_features_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.v_proj = nn.Linear(
self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False
in_features_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
@@ -972,8 +977,8 @@ class LlamaUSPFlashAttention(LlamaAttention):
LlamaUSPFlashAttention with Trainable Ring Attention & Correct Eagle3 Branch Merging.
"""

def __init__(self, config):
super().__init__(config)
def __init__(self, config, fused_input=True):
super().__init__(config, fused_input=fused_input)
assert (
dist.is_initialized()
), f"LlamaUSPAttention requires torch.distributed; call init_distributed first."
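The new `fused_input` flag controls the input width of the q/k/v projections: the fuse layer concatenates the token embedding with the projected hidden state, so its projections see `2 * hidden_size` input features, while the plain decoder layers added by this PR see `hidden_size`. A small self-contained sketch of that sizing; the dimensions are made-up example values, not taken from any real config:

```python
import torch.nn as nn

hidden_size, num_heads, head_dim = 64, 4, 16  # example values only


def make_q_proj(fused_input: bool) -> nn.Linear:
    # The fuse layer consumes cat([input_emb, hidden_states], dim=-1),
    # so its projections take twice as many input features.
    in_features = hidden_size * 2 if fused_input else hidden_size
    return nn.Linear(in_features, num_heads * head_dim, bias=False)


fuse_q = make_q_proj(fused_input=True)    # in_features == 128
plain_q = make_q_proj(fused_input=False)  # in_features == 64
assert fuse_q.in_features == 2 * plain_q.in_features
```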
@@ -1220,28 +1225,38 @@ def forward(self, hidden_states):
return self.weight * hidden_states.to(input_dtype)


class LlamaDecoderLayer(nn.Module):
def __init__(self, config, attention_backend: str = "sdpa"):
class Eagle3LlamaDecoderLayer(nn.Module):
def __init__(self, config, attention_backend: str = "sdpa", fused_input=True):
super().__init__()
self.hidden_size = config.hidden_size
self.fused_input = fused_input

if attention_backend == "sdpa":
self.self_attn = LlamaAttention(config=config)
self.self_attn = LlamaAttention(config=config, fused_input=self.fused_input)
elif attention_backend == "flex_attention":
print_with_rank("Using flex attention on draft model training!")
self.self_attn = LlamaFlexAttention(config=config)
self.self_attn = LlamaFlexAttention(
config=config, fused_input=self.fused_input
)
elif attention_backend == "fa":
self.self_attn = LlamaFlashAttention(config=config)
self.self_attn = LlamaFlashAttention(
config=config, fused_input=self.fused_input
)
elif attention_backend == "usp":
self.self_attn = LlamaUSPFlashAttention(config=config)
self.self_attn = LlamaUSPFlashAttention(
config=config, fused_input=self.fused_input
)
else:
raise ValueError(f"Unknown attention backend {attention_backend}")

self.attention_backend = attention_backend
self.mlp = LlamaMLP(config)
# self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size)
self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if self.fused_input:
self.input_layernorm = LlamaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
# if self.index!=0:

self.post_attention_layernorm = LlamaRMSNorm(
@@ -1250,8 +1265,8 @@ def __init__(self, config, attention_backend: str = "sdpa"):

def forward(
self,
input_emb: torch.Tensor,
hidden_states: torch.Tensor,
input_emb: Optional[torch.Tensor] = None,
cache_hidden: List[List[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -1278,9 +1293,11 @@ def forward(
residual = hidden_states

hidden_states = self.hidden_norm(hidden_states)
input_emb = self.input_layernorm(input_emb)

hidden_states = torch.cat((input_emb, hidden_states), dim=-1)
if self.fused_input:
input_emb = self.input_layernorm(input_emb)
hidden_states = torch.cat((input_emb, hidden_states), dim=-1)

# Self Attention
hidden_states = self.self_attn(
cache_hidden=cache_hidden,
@@ -1317,7 +1334,24 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, config.pad_token_id
)
self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend)
self.num_hidden_layers = config.num_hidden_layers

# Multi-layer decoder for Eagle3 draft model
# First being the embeds + hidden_states fuse layer
self.fuse_layer = Eagle3LlamaDecoderLayer(
config, attention_backend=attention_backend, fused_input=True
)
# the rests are the traditional decoder layers with only hidden_states as inputs
self.additional_layers = None
if self.num_hidden_layers > 1:
self.additional_layers = nn.ModuleList(
[
Eagle3LlamaDecoderLayer(
config, attention_backend=attention_backend, fused_input=False
)
for _ in range(self.num_hidden_layers - 1)
]
)

if hasattr(config, "target_hidden_size"):
self.fc = torch.nn.Linear(
@@ -1355,11 +1389,15 @@ def forward(
position_ids (`torch.LongTensor`, *optional*): position ids of shape `(batch, seq_len)`
"""
if ttt_length == 1:
print_with_rank("using ttt_length 1, no need to cache hidden states")
cache_hidden = None
print_with_rank(
"using ttt_length 1, no need to cache hidden states for decoder layer(s)"
)
caches_hidden = None
else:
print_with_rank(f"using ttt_length {ttt_length}, caching hidden states")
cache_hidden = [[], []]
print_with_rank(
f"using ttt_length {ttt_length}, caching hidden states for decoder layer(s)"
)
caches_hidden = [[[], []] for _ in range(self.num_hidden_layers)]

batch_size, seq_length, _ = hidden_states.size()

@@ -1379,14 +1417,14 @@ def forward(

# fc
hidden_states = self.fc(hidden_states)
hidden_states = self.midlayer(
input_emb=inputs_embeds,

hidden_states = self.backbone(
input_embeds=inputs_embeds,
hidden_states=hidden_states,
cache_hidden=cache_hidden,
caches_hidden=caches_hidden,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
output_attentions=False,
use_cache=False,
)

@@ -1411,19 +1449,36 @@ def backbone(
self,
input_embeds: torch.Tensor,
hidden_states: torch.Tensor,
cache_hidden: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: Optional[Cache] = None,
caches_hidden: Optional[List[List[List[torch.Tensor]]]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[Cache]] = None,
use_cache: bool = True,
) -> torch.Tensor:
return self.midlayer(
hidden_states = self.fuse_layer(
input_emb=input_embeds,
hidden_states=hidden_states,
cache_hidden=cache_hidden,
cache_hidden=caches_hidden[0] if caches_hidden is not None else None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
past_key_values=(
past_key_values[0] if past_key_values is not None else None
),
output_attentions=False,
use_cache=False,
)

if self.num_hidden_layers > 1:
for i, layer in enumerate(self.additional_layers):
hidden_states = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=(
past_key_values[i + 1] if past_key_values is not None else None
),
output_attentions=False,
use_cache=False,
)

return hidden_states
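Taken together, the new `backbone` runs a single fuse layer (token embeddings plus projected hidden states) followed by `num_hidden_layers - 1` plain decoder layers that refine the hidden states alone. The toy below mirrors only that control flow; `ToyLayer` is a stand-in for `Eagle3LlamaDecoderLayer`, and all shapes and sizes are made up for the example:

```python
import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    """Stand-in for Eagle3LlamaDecoderLayer: only the fused/non-fused input handling."""

    def __init__(self, hidden_size: int, fused_input: bool):
        super().__init__()
        self.fused_input = fused_input
        in_features = hidden_size * 2 if fused_input else hidden_size
        self.proj = nn.Linear(in_features, hidden_size, bias=False)

    def forward(self, hidden_states, input_emb=None):
        if self.fused_input:
            # Fuse layer: concatenate token embeddings with projected hidden states.
            hidden_states = torch.cat((input_emb, hidden_states), dim=-1)
        return self.proj(hidden_states)


hidden_size, num_hidden_layers = 16, 3
fuse_layer = ToyLayer(hidden_size, fused_input=True)
additional_layers = nn.ModuleList(
    [ToyLayer(hidden_size, fused_input=False) for _ in range(num_hidden_layers - 1)]
)

input_embeds = torch.randn(2, 5, hidden_size)   # (batch, seq, hidden)
hidden_states = torch.randn(2, 5, hidden_size)  # output of the fc projection

hidden_states = fuse_layer(hidden_states, input_emb=input_embeds)
for layer in additional_layers:
    hidden_states = layer(hidden_states)
print(hidden_states.shape)  # torch.Size([2, 5, 16])
```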
19 changes: 11 additions & 8 deletions tests/test_modeling/test_draft/test_llama3.py
@@ -5,6 +5,7 @@
from unittest.mock import patch

import torch
import torch.nn as nn
from transformers import LlamaConfig

from specforge.modeling.draft.llama3_eagle import (
@@ -35,7 +36,7 @@ def setUp(self):
"model_type": "llama",
"num_attention_heads": 32,
"num_key_value_heads": 8,
"num_hidden_layers": 1,
"num_hidden_layers": 3,
"pad_token_id": 0,
"rms_norm_eps": 1e-05,
"tie_word_embeddings": False,
@@ -53,13 +54,15 @@ def tearDown(self):

def test_model_initialization(self):
model = LlamaForCausalLMEagle3(self.config)

self.assertIsInstance(model.midlayer.self_attn, LlamaAttention)
self.assertIsInstance(model.midlayer.mlp, LlamaMLP)
self.assertIsInstance(model.midlayer.hidden_norm, LlamaRMSNorm)
self.assertIsInstance(model.midlayer.input_layernorm, LlamaRMSNorm)
self.assertIsInstance(model.midlayer.post_attention_layernorm, LlamaRMSNorm)
self.assertEqual(model.midlayer.hidden_size, self.config.hidden_size)
self.assertEqual(model.num_hidden_layers, self.config.num_hidden_layers)
self.assertIsInstance(model.midlayers, nn.ModuleList)
for layer in model.midlayers:
self.assertIsInstance(layer.self_attn, LlamaAttention)
self.assertIsInstance(layer.mlp, LlamaMLP)
self.assertIsInstance(layer.hidden_norm, LlamaRMSNorm)
self.assertIsInstance(layer.input_layernorm, LlamaRMSNorm)
self.assertIsInstance(layer.post_attention_layernorm, LlamaRMSNorm)
self.assertEqual(layer.hidden_size, self.config.hidden_size)

def test_save_pretrained(self):
"""Test the model's save_pretrained functionality."""
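From the user's side, the change is driven by a single config knob: a draft-model config with `num_hidden_layers > 1` now yields one fuse layer plus additional plain decoder layers. A small sketch of such a config, using `transformers.LlamaConfig` as the test above does; every value other than `num_hidden_layers` is a placeholder rather than a recommended setting:

```python
from transformers import LlamaConfig

draft_config = LlamaConfig(
    hidden_size=512,            # placeholder
    intermediate_size=1024,     # placeholder
    num_attention_heads=32,
    num_key_value_heads=8,
    num_hidden_layers=3,        # one fuse layer + two additional decoder layers
    rms_norm_eps=1e-05,
    vocab_size=32000,           # placeholder
    pad_token_id=0,
)
```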