Feat: Refactor & Support TP for long-context draft model training #280
Conversation
Summary of Changes

Hello @zyksir, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces comprehensive support for Tensor Parallelism (TP) for training long-context draft models, particularly addressing Out-of-Memory (OOM) issues encountered with large Llama-3.1 models. It achieves this through a major refactoring of the training infrastructure, including a new class-based trainer, decoupled distributed training groups for the target and draft models, and a reorganized module structure for improved scalability and maintainability.
Code Review
This pull request introduces significant enhancements to support Tensor Parallelism (TP) for draft model training, which is crucial for handling long contexts and mitigating Out-of-Memory (OOM) errors. The core changes involve decoupling the draft model's TP from the target model's Data Parallelism (DP), extensive refactoring of the training scripts into a more object-oriented structure with Eagle3Trainer classes, and updating the distributed communication setup. While the overall direction and refactoring are excellent for maintainability and scalability, I have identified a few issues, including a critical script path error, some redundant code, and opportunities for better code organization.
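To make the decoupling concrete, below is a minimal sketch of how separate device meshes for the target model's TP group and the draft model's DP/TP groups could be set up with torch.distributed. The function name, mesh names, and sizing assumptions are illustrative and not taken from specforge.distributed.

```python
# Illustrative sketch only; the actual specforge.distributed implementation may differ.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def init_parallel_groups(draft_tp_size: int, draft_dp_size: int):
    """Create independent meshes for the target model's TP and the draft model's DP/TP."""
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = dist.get_world_size()

    # Target model: a single TP group spanning every rank.
    target_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))

    # Draft model: a 2-D mesh, data parallelism on the outer dim, TP on the inner dim.
    assert draft_dp_size * draft_tp_size == world_size
    draft_mesh = init_device_mesh(
        "cuda", (draft_dp_size, draft_tp_size), mesh_dim_names=("dp", "tp")
    )
    return target_mesh, draft_mesh
```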
examples/run_llama3_eagle3_online.sh (outdated)

    --standalone \
    --nproc_per_node $NUM_GPUS \
-   $ROOT_DIR/scripts/train_eagle3_online.py \
+   $ROOT_DIR/scripts/train_eagle3.py \
The script path has been changed to $ROOT_DIR/scripts/train_eagle3.py, but the file being modified in this pull request is scripts/train_eagle3_online.py. The file train_eagle3.py does not appear to exist, which will cause this script to fail.
Suggested change:
-   $ROOT_DIR/scripts/train_eagle3.py \
+   $ROOT_DIR/scripts/train_eagle3_online.py \
    python scripts/build_eagle3_dataset_cache.py \
        --target-model-path $MODEL_PATH \
        --draft-model-config ./configs/llama3-8B-eagle3.json \
-       --train-data-path $DATASET_PATH/sharegpt_ultrachat.jsonl \
+       --train-data-path $GENERATED_DATASET_PATH/train_data.jsonl \
        --eval-data-path $GENERATED_DATASET_PATH/eval_data.jsonl \
        --cache-dir $CACHE_DIR \
        --chat-template $CHAT_TEMPLATE \
        --max-length $MAX_LENGTH \
-       --view-train-data 1
+       --view-train-data 1 --debug
The build_eagle3_dataset_cache.py script is called here with the --debug flag, and then called again on lines 52-59 without it. This first call seems redundant for a final script. If it's intended for debugging during development, it should be removed to avoid unnecessary processing and potential confusion.
scripts/train_eagle3_online.py (outdated)

    class SglOnlineEagle3Trainer(Eagle3Trainer):
        def __init__(self, args):
            super().__init__(args)
            assert (
                args.draft_micro_batch_size == 1
            ), "SglOnlineEagle3Trainer only supports draft_micro_batch_size = 1"
The SglOnlineEagle3Trainer has a hardcoded assertion args.draft_micro_batch_size == 1. This is also noted in the file's docstring as a TODO. While this is a known limitation, it's a significant one that restricts the training configurations. It would be beneficial to prioritize removing this limitation to allow for more flexible batching strategies.
specforge/distributed.py (outdated)

    def get_draft_dp_device_mesh():
        global _DRAFT_DP_DEVICE_MESH
        return _DRAFT_DP_DEVICE_MESH
    def get_dp_data_shard_from_tp(
        tensor: Union[torch.Tensor, List[torch.Tensor]]
    ) -> torch.Tensor:
        """
        Get the data shard from the tensor.
        """
        tp_size = dist.get_world_size(get_target_tp_group())
        tp_rank = dist.get_rank(get_target_tp_group())
        tensor_length = len(tensor) if isinstance(tensor, List) else tensor.shape[0]
        assert tensor_length % tp_size == 0, "Tensor length must be divisible by tp_size"
        chunk_size = tensor_length // tp_size
        return tensor[tp_rank * chunk_size : (tp_rank + 1) * chunk_size]
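For clarity, the slicing behaviour of this helper can be checked in isolation; the snippet below reproduces only the indexing logic from the diff above, with the distributed calls replaced by explicit tp_rank/tp_size arguments.

```python
# Standalone illustration of the chunking logic in get_dp_data_shard_from_tp:
# a batch of 8 samples on a tp_size=4 group gives each rank a contiguous slice of 2.
import torch


def shard_for_rank(tensor: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    assert tensor.shape[0] % tp_size == 0, "Tensor length must be divisible by tp_size"
    chunk_size = tensor.shape[0] // tp_size
    return tensor[tp_rank * chunk_size : (tp_rank + 1) * chunk_size]


batch = torch.arange(8)
print(shard_for_rank(batch, tp_rank=1, tp_size=4))  # tensor([2, 3])
```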
    # This is a modified forward function for the SGLang's logits processor, adapted from https://github.com/sgl-project/sglang/blob/v0.5.4/python/sglang/srt/layers/logits_processor.py.
    # The modification is to return the logits and aux hidden states instead of the last hidden states.
    # """

    # if isinstance(logits_metadata, ForwardBatch):
    #     logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

    # # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
    # multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
    # if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
    #     return self.compute_logprobs_for_multi_item_scoring(
    #         input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
    #     )

    # # Get the last hidden states and last logits for the next token prediction
    # if (
    #     logits_metadata.forward_mode.is_decode_or_idle()
    #     or logits_metadata.forward_mode.is_target_verify()
    #     or logits_metadata.forward_mode.is_draft_extend_v2()
    # ):
    #     pruned_states = hidden_states
    #     if aux_hidden_states is not None:
    #         aux_pruned_states = [hidden for hidden in aux_hidden_states]
    #     sample_indices = None
    #     input_logprob_indices = None
    # else:
    #     raise RuntimeError(
    #         f"The modified logits processor is not supported for this forward mode: {logits_metadata.forward_mode}"
    #     )

    # # Compute logits for both input and sampled tokens.
    # logits = self._get_logits(pruned_states, lm_head, logits_metadata)

    # hidden_states_to_store: Optional[torch.Tensor] = None
    # if logits_metadata.capture_hidden_mode.need_capture():
    #     if logits_metadata.capture_hidden_mode.is_full():
    #         if aux_hidden_states is not None:
    #             aux_hidden_states = torch.cat(aux_hidden_states, dim=-1)
    #             hidden_states_to_store = aux_hidden_states
    #         else:
    #             hidden_states_to_store = hidden_states
    #     elif logits_metadata.capture_hidden_mode.is_last():
    #         # Get the last token hidden states. If sample_indices is None,
    #         # pruned states only contain the last tokens already.
    #         if aux_hidden_states is not None:
    #             aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
    #             hidden_states_to_store = (
    #                 aux_pruned_states[sample_indices]
    #                 if sample_indices is not None
    #                 else aux_pruned_states
    #             )
    #         else:
    #             hidden_states_to_store = (
    #                 pruned_states[sample_indices]
    #                 if sample_indices is not None
    #                 else pruned_states
    #             )
    #     else:
    #         assert False, "Should never reach"

    # if not logits_metadata.extend_return_logprob:
    #     # Decode mode or extend mode without return_logprob.
    #     return ReplacedLogitsProcessorEagle3Output(
    #         logits=logits,
    #         aux_hidden_states=hidden_states_to_store,
    #     )


    class LogitsProcessorForEAGLE3(torch.nn.Module):
        def __init__(
            self, logits_processor: LogitsProcessor, return_full_logits: bool = False
        ):
            super().__init__()
            self.logits_processor = logits_processor
            self.return_full_logits = return_full_logits

        def forward(
            self,
            input_ids,
            hidden_states,
            lm_head,
            logits_metadata,
            aux_hidden_states: Optional[List[torch.Tensor]] = None,
        ) -> LogitsProcessorOutput:
            logits_metadata.forward_mode = ForwardMode.DECODE
            # ret = replaced_logits_processor_forward_for_eagle3(
            #     self.logits_processor,
            #     input_ids,
            #     hidden_states,
            #     lm_head,
            #     logits_metadata,
            #     aux_hidden_states,
            # )
            # ret = self.logits_processor.forward(
            #     input_ids, hidden_states, lm_head, logits_metadata, aux_hidden_states
            # )
            return ReplacedLogitsProcessorEagle3Output(
                hidden_states=hidden_states,
                aux_hidden_states=torch.cat(aux_hidden_states, dim=-1),
            )
    seq_lengths.extend([16384])
Why remove 32k?
32k leads to OOM when I test on my H100.
    from specforge.utils import padding

-   from .utils import norm_tensor
+   from tests.utils import norm_tensor
This import is not used.
    from .optimizer import BF16Optimizer
    from .tracker import Tracker, build_tracker

    __all__ = ["BF16Optimizer", "Tracker", "build_tracker"]
I think this module name is a bit weird; "helper" usually refers to utility functions, but this module contains components necessary for training. I suggest these become independent files, i.e. specforge.tracker and specforge.optimizer.
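A sketch of the layout this suggestion implies, assuming the two components move into dedicated modules (the file paths are hypothetical):

```python
# Hypothetical layout if the helper package were split as suggested:
#   specforge/optimizer.py -> BF16Optimizer
#   specforge/tracker.py   -> Tracker, build_tracker
from specforge.optimizer import BF16Optimizer
from specforge.tracker import Tracker, build_tracker
```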
    def model_specific_adjustment(self):
        pass
What is this for?
This function will call something like log_on_rank0; at this point, dist is initialized but SGLang's parallel_state._WORLD is not set yet, which will lead to an error.
        )
        target_micro_batch_size = None
    else:
        server_args = ServerArgs.from_cli_args(args)
I recommend that we don't do this, because it might cause some confusion for users:
- Some arguments in SGLang are for optimizations other than prefill; these options won't take effect even if the user specifies them.
- It will make the --help of our training script extremely long.

I recommend that we only keep the arguments that matter for prefill, as sketched below.
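A minimal sketch of that alternative, assuming a hand-picked argument group is registered instead of inheriting the full ServerArgs CLI; the specific flags below are only examples of prefill-related options, not a definitive list.

```python
# Illustrative only: expose a small, curated set of prefill-related options
# instead of pulling every SGLang ServerArgs flag into --help.
import argparse


def add_target_prefill_args(parser: argparse.ArgumentParser) -> None:
    group = parser.add_argument_group("target model prefill")
    # Example subset; keep only the ServerArgs fields that actually affect prefill.
    group.add_argument("--tp-size", type=int, default=1)
    group.add_argument("--mem-fraction-static", type=float, default=0.8)
    group.add_argument("--chunked-prefill-size", type=int, default=8192)
```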
    @dataclasses.dataclass
    class Eagle3TrainerArgs:
        target_model_path: str
We can migrate this to specforge.arguments.
        target_micro_batch_size: int = 8
        draft_tp_size: int = 1
        draft_dp_size: int = 1
        draft_global_batch_size: int = 16
        draft_micro_batch_size: int = 1
        draft_accumulation_steps: int = 1
Why do we need both micro-batch and global-batch?
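For reference, these fields typically relate through gradient accumulation; the sketch below shows the conventional relationship and is an assumption, not necessarily the exact formula used in this PR.

```python
# Conventional relationship (assumption, not taken from this PR's code):
# global batch = micro batch * accumulation steps * data-parallel size.
draft_dp_size = 2
draft_micro_batch_size = 1
draft_global_batch_size = 16

draft_accumulation_steps = draft_global_batch_size // (
    draft_micro_batch_size * draft_dp_size
)
assert draft_accumulation_steps == 8
```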
Motivation
(See #117 ) Training Llama-3.1 models (8B and 70B) in offline mode with long context lengths (e.g., 8K, 16K, or 32K) currently fails with Out-of-Memory (OOM) errors, even on multi-GPU setups.
Modifications
Support the num_kv_head < tp case. Thanks @yd-oom for the previous PR.
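For context on the num_kv_head < tp case, the sketch below shows the usual KV-head replication rule for grouped-query attention under tensor parallelism; this is the general convention, not necessarily the exact logic added in this PR.

```python
# When a model has fewer KV heads than TP ranks, KV heads are replicated so
# every rank still owns one. General convention; the PR's handling may differ.
def local_kv_heads(num_kv_heads: int, tp_size: int) -> tuple[int, int]:
    """Return (KV heads per rank, replication factor per KV head)."""
    if num_kv_heads >= tp_size:
        assert num_kv_heads % tp_size == 0
        return num_kv_heads // tp_size, 1
    assert tp_size % num_kv_heads == 0
    return 1, tp_size // num_kv_heads


print(local_kv_heads(num_kv_heads=8, tp_size=4))   # (2, 1)
print(local_kv_heads(num_kv_heads=8, tp_size=16))  # (1, 2)
```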
Related Issues
#112
Accuracy Test
Benchmark & Profiling
Checklist