Conversation

@zyksir (Collaborator) commented Nov 7, 2025

Motivation

(See #117.) Training Llama-3.1 models (8B and 70B) in offline mode with long context lengths (e.g., 8K, 16K, or 32K) currently fails with Out-of-Memory (OOM) errors, even on multi-GPU setups.

Modifications

Thanks to @yd-oom for the previous PR.

Related Issues

#112

Accuracy Test

Benchmark & Profiling

Checklist

@gemini-code-assist (Contributor)

Summary of Changes

Hello @zyksir, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces comprehensive support for Tensor Parallelism (TP) in the training of long-context draft models, particularly addressing Out-of-Memory (OOM) issues encountered with large Llama-3.1 models. It achieves this through a major refactoring of the training infrastructure, including a new class-based trainer, decoupled distributed training groups for target and draft models, and a reorganized module structure for improved scalability and maintainability.

Highlights

  • Distributed Training Enhancement: Introduced distinct Tensor Parallelism (TP) and Data Parallelism (DP) groups for both draft and target models, allowing for more flexible and efficient distributed training configurations tailored for long-context models (a minimal sketch follows this list).
  • Unified Training Script: Consolidated the train_eagle3_online.py and train_eagle3_sgl_online.py scripts into a single, refactored train_eagle3_online.py with a class-based trainer (Eagle3Trainer and SglOnlineEagle3Trainer), significantly improving code organization and maintainability.
  • Improved Checkpointing and Resume: Enhanced the training process with more robust checkpointing and resume capabilities, including support for Zero2 optimization for improved memory efficiency during large-scale training.
  • Module Restructuring: Performed a significant reorganization of the specforge package, moving core and modeling modules into a new model subpackage, which improves logical grouping and clarity of the codebase.
  • SGLang Integration Updates: Modified the SGLang backend integration to better handle hidden states and logits processing for EAGLE3 training, enabling the target model to provide necessary data for draft model training more effectively.
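
As a rough illustration of the decoupled groups in the first highlight, here is a minimal sketch using PyTorch device meshes; the layout, sizes, and group names are illustrative assumptions, not the code in this PR:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical layout on 8 GPUs (launch with torchrun): the target model runs
# TP=8, while the draft model runs TP=2 with the remaining factor acting as DP=4.
dist.init_process_group(backend="nccl")

target_mesh = init_device_mesh("cuda", (1, 8), mesh_dim_names=("dp", "tp"))
draft_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

draft_tp_group = draft_mesh.get_group("tp")  # shards draft-model weights/activations
draft_dp_group = draft_mesh.get_group("dp")  # all-reduces draft-model gradients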

@gemini-code-assist bot left a comment

Code Review

This pull request introduces significant enhancements to support Tensor Parallelism (TP) for draft model training, which is crucial for handling long contexts and mitigating Out-of-Memory (OOM) errors. The core changes involve decoupling the draft model's TP from the target model's Data Parallelism (DP), extensive refactoring of the training scripts into a more object-oriented structure with Eagle3Trainer classes, and updating the distributed communication setup. While the overall direction and refactoring are excellent for maintainability and scalability, I have identified a few issues, including a critical script path error, some redundant code, and opportunities for better code organization.

  --standalone \
  --nproc_per_node $NUM_GPUS \
- $ROOT_DIR/scripts/train_eagle3_online.py \
+ $ROOT_DIR/scripts/train_eagle3.py \
Contributor

critical

The script path has been changed to $ROOT_DIR/scripts/train_eagle3.py, but the file being modified in this pull request is scripts/train_eagle3_online.py. The file train_eagle3.py does not appear to exist, which will cause this script to fail.

Suggested change
- $ROOT_DIR/scripts/train_eagle3.py \
+ $ROOT_DIR/scripts/train_eagle3_online.py \

Comment on lines 35 to 50
python scripts/build_eagle3_dataset_cache.py \
    --target-model-path $MODEL_PATH \
    --draft-model-config ./configs/llama3-8B-eagle3.json \
-   --train-data-path $DATASET_PATH/sharegpt_ultrachat.jsonl \
+   --train-data-path $GENERATED_DATASET_PATH/train_data.jsonl \
+   --eval-data-path $GENERATED_DATASET_PATH/eval_data.jsonl \
    --cache-dir $CACHE_DIR \
    --chat-template $CHAT_TEMPLATE \
    --max-length $MAX_LENGTH \
-   --view-train-data 1
+   --view-train-data 1 --debug
Contributor

medium

The build_eagle3_dataset_cache.py script is called here with the --debug flag, and then called again on lines 52-59 without it. This first call seems redundant for a final script. If it's intended for debugging during development, it should be removed to avoid unnecessary processing and potential confusion.

Comment on lines 1003 to 1038
class SglOnlineEagle3Trainer(Eagle3Trainer):
    def __init__(self, args):
        super().__init__(args)
        assert (
            args.draft_micro_batch_size == 1
        ), "SglOnlineEagle3Trainer only supports draft_micro_batch_size = 1"
Contributor

medium

The SglOnlineEagle3Trainer has a hardcoded assertion args.draft_micro_batch_size == 1. This is also noted in the file's docstring as a TODO. While this is a known limitation, it's a significant one that restricts the training configurations. It would be beneficial to prioritize removing this limitation to allow for more flexible batching strategies.
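
As context for lifting this limitation, the usual approach is to loop over micro-batches inside one draft step and accumulate gradients. Below is a minimal sketch under that assumption; draft_model, optimizer, and the batch layout are hypothetical, not the trainer's actual API:

def draft_train_step(draft_model, optimizer, batch, micro_batch_size):
    # Split one per-rank batch into micro-batches and accumulate gradients so the
    # resulting update matches a single large-batch step.
    total = batch["input_ids"].shape[0]
    num_micro = (total + micro_batch_size - 1) // micro_batch_size
    optimizer.zero_grad()
    for i in range(num_micro):
        sl = slice(i * micro_batch_size, (i + 1) * micro_batch_size)
        micro_batch = {k: v[sl] for k, v in batch.items()}
        loss = draft_model(**micro_batch).loss / num_micro  # scale for accumulation
        loss.backward()
    optimizer.step()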

Comment on lines 47 to 89
def get_draft_dp_device_mesh():
    global _DRAFT_DP_DEVICE_MESH
    return _DRAFT_DP_DEVICE_MESH
Contributor

medium

The function get_draft_dp_device_mesh is defined and exported but is not used anywhere in the codebase. This appears to be dead code and should be removed to improve maintainability.

Comment on lines 27 to 38
def get_dp_data_shard_from_tp(
    tensor: Union[torch.Tensor, List[torch.Tensor]]
) -> torch.Tensor:
    """
    Get the data shard from the tensor.
    """
    tp_size = dist.get_world_size(get_target_tp_group())
    tp_rank = dist.get_rank(get_target_tp_group())
    tensor_length = len(tensor) if isinstance(tensor, List) else tensor.shape[0]
    assert tensor_length % tp_size == 0, "Tensor length must be divisible by tp_size"
    chunk_size = tensor_length // tp_size
    return tensor[tp_rank * chunk_size : (tp_rank + 1) * chunk_size]
Contributor

medium

The function get_dp_data_shard_from_tp is a generic distributed utility. Placing it within eagle3_target_model.py makes the code less modular. It would be better placed in a more general utility module like specforge/distributed.py to centralize distributed helper functions.
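
To make the slicing behavior concrete, here is a standalone re-implementation of the same rule with the process-group lookups replaced by explicit arguments; this is illustrative only, and shard_for_rank is not a function in this PR:

import torch

def shard_for_rank(tensor, tp_rank, tp_size):
    # Same contiguous-chunk rule as get_dp_data_shard_from_tp above.
    assert tensor.shape[0] % tp_size == 0, "batch must divide evenly across TP ranks"
    chunk = tensor.shape[0] // tp_size
    return tensor[tp_rank * chunk : (tp_rank + 1) * chunk]

batch = torch.arange(8).unsqueeze(-1)  # stand-in for a batch of hidden states
print([shard_for_rank(batch, r, 4).flatten().tolist() for r in range(4)])
# -> [[0, 1], [2, 3], [4, 5], [6, 7]]: each target TP rank keeps one DP shard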

Comment on lines 29 to 137
# This is a modified forward function for the SGLang's logits processor, adapted from https://github.com/sgl-project/sglang/blob/v0.5.4/python/sglang/srt/layers/logits_processor.py.
# The modification is to return the logits and aux hidden states instead of the last hidden states.
# """

# if isinstance(logits_metadata, ForwardBatch):
#     logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

# # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
# multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
# if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
#     return self.compute_logprobs_for_multi_item_scoring(
#         input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
#     )

# # Get the last hidden states and last logits for the next token prediction
# if (
#     logits_metadata.forward_mode.is_decode_or_idle()
#     or logits_metadata.forward_mode.is_target_verify()
#     or logits_metadata.forward_mode.is_draft_extend_v2()
# ):
#     pruned_states = hidden_states
#     if aux_hidden_states is not None:
#         aux_pruned_states = [hidden for hidden in aux_hidden_states]
#     sample_indices = None
#     input_logprob_indices = None
# else:
#     raise RuntimeError(
#         f"The modified logits processor is not supported for this forward mode: {logits_metadata.forward_mode}"
#     )

# # Compute logits for both input and sampled tokens.
# logits = self._get_logits(pruned_states, lm_head, logits_metadata)

# hidden_states_to_store: Optional[torch.Tensor] = None
# if logits_metadata.capture_hidden_mode.need_capture():
#     if logits_metadata.capture_hidden_mode.is_full():
#         if aux_hidden_states is not None:
#             aux_hidden_states = torch.cat(aux_hidden_states, dim=-1)
#             hidden_states_to_store = aux_hidden_states
#         else:
#             hidden_states_to_store = hidden_states
#     elif logits_metadata.capture_hidden_mode.is_last():
#         # Get the last token hidden states. If sample_indices is None,
#         # pruned states only contain the last tokens already.
#         if aux_hidden_states is not None:
#             aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
#             hidden_states_to_store = (
#                 aux_pruned_states[sample_indices]
#                 if sample_indices is not None
#                 else aux_pruned_states
#             )
#         else:
#             hidden_states_to_store = (
#                 pruned_states[sample_indices]
#                 if sample_indices is not None
#                 else pruned_states
#             )
#     else:
#         assert False, "Should never reach"

# if not logits_metadata.extend_return_logprob:
#     # Decode mode or extend mode without return_logprob.
#     return ReplacedLogitsProcessorEagle3Output(
#         logits=logits,
#         aux_hidden_states=hidden_states_to_store,
#     )


class LogitsProcessorForEAGLE3(torch.nn.Module):
    def __init__(
        self, logits_processor: LogitsProcessor, return_full_logits: bool = False
    ):
        super().__init__()
        self.logits_processor = logits_processor
        self.return_full_logits = return_full_logits

    def forward(
        self,
        input_ids,
        hidden_states,
        lm_head,
        logits_metadata,
        aux_hidden_states: Optional[List[torch.Tensor]] = None,
    ) -> LogitsProcessorOutput:
        logits_metadata.forward_mode = ForwardMode.DECODE
        # ret = replaced_logits_processor_forward_for_eagle3(
        #     self.logits_processor,
        #     input_ids,
        #     hidden_states,
        #     lm_head,
        #     logits_metadata,
        #     aux_hidden_states,
        # )
        # ret = self.logits_processor.forward(
        #     input_ids, hidden_states, lm_head, logits_metadata, aux_hidden_states
        # )
        return ReplacedLogitsProcessorEagle3Output(
            hidden_states=hidden_states,
            aux_hidden_states=torch.cat(aux_hidden_states, dim=-1),
        )
Contributor

medium

This file contains a significant amount of commented-out code. This should be removed to improve code clarity and maintainability.

zyksir changed the title from "Feat: Support TP for long-context draft model training" to "Feat: Refactor & Support TP for long-context draft model training" on Nov 7, 2025
zyksir force-pushed the feature/refactor branch 2 times, most recently from cc13007 to 3530625, on November 8, 2025 at 21:58
Comment on lines +309 to 310
seq_lengths.extend([16384])

Collaborator

Why remove 32k?

Collaborator Author

32K leads to OOM when I test it on my H100.
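
For a rough sense of why 32K hits OOM, here is a back-of-the-envelope estimate assuming Llama-3.1-8B dimensions (hidden size 4096, vocabulary 128,256, bf16) and EAGLE3-style concatenation of three auxiliary hidden states; the numbers are per sequence, before any training activations or optimizer state:

seq_len, hidden, vocab, bytes_bf16 = 32_768, 4096, 128_256, 2

logits_gib = seq_len * vocab * bytes_bf16 / 2**30           # ~7.8 GiB of full-sequence logits
aux_hidden_gib = seq_len * 3 * hidden * bytes_bf16 / 2**30  # ~0.75 GiB of concatenated aux hidden states
print(f"logits: {logits_gib:.1f} GiB, aux hidden states: {aux_hidden_gib:.2f} GiB")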

from specforge.utils import padding

-from .utils import norm_tensor
+from tests.utils import norm_tensor
Collaborator

This import is not used.

Comment on lines +1 to +4
from .optimizer import BF16Optimizer
from .tracker import Tracker, build_tracker

__all__ = ["BF16Optimizer", "Tracker", "build_tracker"]
Collaborator

I think this module name is a bit odd: "helper" usually refers to utility functions, but this module contains components necessary for training. I guess these could be independent modules, i.e. specforge.tracker and specforge.optimizer.

Comment on lines +52 to +53
    def model_specific_adjustment(self):
        pass
Collaborator

What is this for?

Collaborator Author

This function will call something like log_on_rank0; at this point, dist is initialized but SGLang's parallel_state._WORLD is not set yet, which will lead to an error.

    )
    target_micro_batch_size = None
else:
    server_args = ServerArgs.from_cli_args(args)
Collaborator

I recommend that we don't do this, because it might cause some confusion for users.

  1. Some arguments in SGLang are for optimizations other than prefill; these options won't take effect even if the user specifies them.
  2. This will make the --help of our training script extremely long.

I recommend that we only keep the options important to prefill.
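
One possible shape for that suggestion, sketched under the assumption that ServerArgs is a dataclass whose unspecified fields fall back to defaults; the flag set and field names shown are illustrative, not a proposal from this thread:

import argparse
from sglang.srt.server_args import ServerArgs

def add_prefill_server_args(parser: argparse.ArgumentParser) -> None:
    # Expose only the knobs relevant to the prefill-only target forward pass,
    # instead of adding every SGLang CLI flag to the training script.
    group = parser.add_argument_group("sglang target model")
    group.add_argument("--target-tp-size", type=int, default=1)
    group.add_argument("--mem-fraction-static", type=float, default=0.8)
    group.add_argument("--context-length", type=int, default=None)

def build_server_args(args) -> ServerArgs:
    return ServerArgs(
        model_path=args.target_model_path,
        tp_size=args.target_tp_size,
        mem_fraction_static=args.mem_fraction_static,
        context_length=args.context_length,
    )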

Comment on lines +72 to +74
@dataclasses.dataclass
class Eagle3TrainerArgs:
    target_model_path: str
Collaborator

We can migrate this to specforge.arguments.

Comment on lines +88 to +93
    target_micro_batch_size: int = 8
    draft_tp_size: int = 1
    draft_dp_size: int = 1
    draft_global_batch_size: int = 16
    draft_micro_batch_size: int = 1
    draft_accumulation_steps: int = 1
Collaborator

Why do we need both micro-batch and global-batch sizes?
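
For reference, the conventional relationship between these knobs under a standard data-parallel gradient-accumulation setup is sketched below; whether this PR follows exactly this convention is an assumption:

def expected_global_batch_size(micro_batch_size, dp_size, accumulation_steps):
    # One optimizer step consumes dp_size ranks * accumulation_steps micro-steps
    # * micro_batch_size samples per micro-step.
    return micro_batch_size * dp_size * accumulation_steps

# e.g. draft_global_batch_size = 16 with micro_batch_size = 1 and dp_size = 1
# would imply 16 accumulation steps per draft optimizer step.
print(expected_global_batch_size(1, 1, 16))  # -> 16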

zyksir closed this on Nov 10, 2025
zyksir deleted the feature/refactor branch on November 10, 2025 at 19:41