Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions configs/qwen3-omni-eagle3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"architectures": [
"LlamaForCausalLMEagle3"
],
"bos_token_id": 151672,
"eos_token_id": 151673,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 6144,
"max_position_embeddings": 65536,
"model_type": "llama",
"num_attention_heads": 32,
"num_key_value_heads": 4,
"num_hidden_layers": 1,
"pad_token_id": 0,
"rms_norm_eps": 1e-06,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.28.1",
"use_cache": true,
"vocab_size": 152064,
"draft_vocab_size": 32000
}
29 changes: 29 additions & 0 deletions examples/run_qwen3_omni_eagle3_online.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/bin/bash
# Launch online EAGLE3 draft-model training for Qwen3-Omni-30B-A3B-Instruct.
# Usage: run_qwen3_omni_eagle3_online.sh [NUM_GPUS]   (default: 1 GPU, TP=1)

# Fail fast: abort on any command error, unset variable, or pipeline failure,
# so a broken path resolution cannot silently launch training with wrong args.
set -euo pipefail

# Resolve the repository root from this script's own location.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
# Quote expansions so paths containing spaces do not word-split.
ROOT_DIR=$(dirname "$SCRIPT_DIR")

# support tp1 train eagle3 for Qwen3-Omni-30B-A3B-Instruct
NUM_GPUS=${1:-1}

torchrun \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$ROOT_DIR/scripts/train_eagle3_online.py" \
    --target-model-path Qwen/Qwen3-Omni-30B-A3B-Instruct \
    --draft-model-config "$ROOT_DIR/configs/qwen3-omni-eagle3.json" \
    --train-data-path "$ROOT_DIR/cache/dataset/allava4v_train.jsonl" \
    --output-dir "$ROOT_DIR/outputs/Qwen3-Omni-30B-A3B-eagle3" \
    --num-epochs 10 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 8192 \
    --dist-timeout 360 \
    --chat-template qwen3-omni \
    --cache-dir "$ROOT_DIR/cache" \
    --embedding-key thinker.model.embed_tokens.weight \
    --tp-size 1 \
    --is-vlm \
    --target-model-backend hf \
    --min-pixels 50176 \
    --max-pixels 200704
1 change: 1 addition & 0 deletions requirements-rocm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ torchaudio==2.8.0+rocm6.3
torchvision==0.23.0+rocm6.3
transformers==4.57.1
qwen-vl-utils==0.0.11
qwen-omni-utils==0.0.8
datasets
setuptools
tqdm
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ torchaudio==2.8.0
torchvision==0.23.0
transformers==4.57.1
qwen-vl-utils==0.0.11
qwen-omni-utils==0.0.8
datasets
setuptools
tqdm
Expand Down
113 changes: 38 additions & 75 deletions scripts/train_eagle3_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@
from tqdm import tqdm
from transformers import AutoProcessor, AutoTokenizer

from specforge import (
AutoDraftModelConfig,
AutoEagle3DraftModel,
OnlineEagle3Model,
QwenVLOnlineEagle3Model,
)
from specforge import AutoDraftModelConfig, AutoEagle3DraftModel, OnlineEagle3Model
from specforge.data import (
build_eagle3_dataset,
generate_vocab_mapping_file,
Expand Down Expand Up @@ -239,29 +234,13 @@ def build_target_model(
Returns:
The target model.
"""
if (
args.is_vlm
and draft_model_config.target_model_type == "qwen2_5_vl"
and args.tp_size == 1
):
from transformers import Qwen2_5_VLForConditionalGeneration

target_model = (
Qwen2_5_VLForConditionalGeneration.from_pretrained(
pretrained_model_name_or_path=args.target_model_path,
torch_dtype=torch.bfloat16,
)
.eval()
.cuda()
)
else:
target_model = get_eagle3_target_model(
pretrained_model_name_or_path=args.target_model_path,
backend=args.target_model_backend,
torch_dtype=torch.bfloat16,
device="cuda",
cache_dir=args.cache_dir,
)
target_model = get_eagle3_target_model(
pretrained_model_name_or_path=args.target_model_path,
backend=args.target_model_backend,
torch_dtype=torch.bfloat16,
device="cuda",
cache_dir=args.cache_dir,
)

# set the aux hidden states layers
if (
Expand Down Expand Up @@ -462,36 +441,31 @@ def run_forward(
data: dict,
target_model: Optional[Eagle3TargetModel] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
kwargs = {
"input_ids": data["input_ids"].cuda(),
"attention_mask": data["attention_mask"].cuda(),
"loss_mask": data["loss_mask"].cuda(),
"is_vlm": args.is_vlm,
}
if args.is_vlm:
plosses, _, acces = eagle3_model(
input_ids=data["input_ids"].cuda(),
attention_mask=data["attention_mask"].cuda(),
loss_mask=data["loss_mask"].cuda(),
pixel_values=data["pixel_values"].cuda(),
image_grid_thw=data["image_grid_thw"].cuda(),
)
else:
eagle3_data = target_model.generate_eagle3_data(
input_ids=data["input_ids"].cuda(),
attention_mask=data["attention_mask"].cuda(),
loss_mask=data["loss_mask"].cuda(),
)

eagle3_data.input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids)
eagle3_data.attention_mask = get_dp_data_shard_from_tp(
eagle3_data.attention_mask
)
eagle3_data.loss_mask = get_dp_data_shard_from_tp(eagle3_data.loss_mask)
eagle3_data.target = get_dp_data_shard_from_tp(eagle3_data.target)
eagle3_data.hidden_states = get_dp_data_shard_from_tp(eagle3_data.hidden_states)

plosses, _, acces = eagle3_model(
input_ids=eagle3_data.input_ids,
attention_mask=eagle3_data.attention_mask,
loss_mask=eagle3_data.loss_mask,
target=eagle3_data.target,
hidden_states=eagle3_data.hidden_states,
)
if "pixel_values" in data:
kwargs["pixel_values"] = data["pixel_values"].cuda()
if "image_grid_thw" in data:
kwargs["image_grid_thw"] = data["image_grid_thw"].cuda()
eagle3_data = target_model.generate_eagle3_data(**kwargs)
eagle3_data.input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids)
eagle3_data.attention_mask = get_dp_data_shard_from_tp(eagle3_data.attention_mask)
eagle3_data.loss_mask = get_dp_data_shard_from_tp(eagle3_data.loss_mask)
eagle3_data.target = get_dp_data_shard_from_tp(eagle3_data.target)
eagle3_data.hidden_states = get_dp_data_shard_from_tp(eagle3_data.hidden_states)

plosses, _, acces = eagle3_model(
input_ids=eagle3_data.input_ids,
attention_mask=eagle3_data.attention_mask,
loss_mask=eagle3_data.loss_mask,
target=eagle3_data.target,
hidden_states=eagle3_data.hidden_states,
)
return plosses, acces


Expand Down Expand Up @@ -596,23 +570,12 @@ def main():
# ================================================
# 4. Build Eagle3 model
# ================================================
if (
args.is_vlm
and getattr(draft_model_config, "target_model_type", None) == "qwen2_5_vl"
):
eagle3_model = QwenVLOnlineEagle3Model(
target_model=target_model,
draft_model=draft_model,
processor=processor,
length=args.ttt_length,
attention_backend=args.attention_backend,
)
else:
eagle3_model = OnlineEagle3Model(
draft_model=draft_model,
length=args.ttt_length,
attention_backend=args.attention_backend,
)

eagle3_model = OnlineEagle3Model(
draft_model=draft_model,
length=args.ttt_length,
attention_backend=args.attention_backend,
)

eagle3_model = FSDP(
eagle3_model,
Expand Down
4 changes: 2 additions & 2 deletions specforge/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .eagle3 import OfflineEagle3Model, OnlineEagle3Model, QwenVLOnlineEagle3Model
from .eagle3 import OfflineEagle3Model, OnlineEagle3Model

__all__ = ["OnlineEagle3Model", "OfflineEagle3Model", "QwenVLOnlineEagle3Model"]
__all__ = ["OnlineEagle3Model", "OfflineEagle3Model"]
Loading