[Draft] Long Context Training VRAM Optimization #446
base: main
Conversation
Summary of Changes (Gemini Code Assist): This pull request focuses on enhancing the efficiency and stability of long context training.
Code Review
This pull request introduces several optimizations for long context training, primarily focused on reducing memory usage. The key changes include switching to a generator-based dataset loading approach to avoid loading entire datasets into memory, and reordering operations so that preprocessing happens on the CPU before large tensors are moved to the GPU. These are solid strategies for memory optimization. My review includes suggestions to improve the clarity of a new example script, remove leftover debugging code, and reconsider the use of torch.cuda.empty_cache(), which can impact performance.
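A minimal sketch of the generator-based loading idea (illustrative only; the function name and usage below are assumptions, not the PR's actual implementation):

    import json

    def iter_jsonl(path):
        """Yield one parsed record at a time instead of loading the whole file."""
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    yield json.loads(line)

    # Peak host memory stays around one record instead of the full dataset:
    # for record in iter_jsonl("data.jsonl"):
    #     preprocess(record)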
examples/repo-wiki.sh (outdated)

    TARGET_MODEL_PATH=/disk3/wjp/pretrained_models/Qwen3-Coder-30B-A3B-Instruct
    TRAIN_DATA_PATH=/disk3/wjp/datasets/repowiki/data_for_SpecForge_test.jsonl
The paths for TARGET_MODEL_PATH and TRAIN_DATA_PATH are hardcoded. This makes the script difficult for others to use without modification. Consider using environment variables with default placeholder values to make the script more portable.
Suggested change:

    - TARGET_MODEL_PATH=/disk3/wjp/pretrained_models/Qwen3-Coder-30B-A3B-Instruct
    - TRAIN_DATA_PATH=/disk3/wjp/datasets/repowiki/data_for_SpecForge_test.jsonl
    + TARGET_MODEL_PATH=${TARGET_MODEL_PATH:-"/path/to/your/Qwen3-Coder-30B-A3B-Instruct"}
    + TRAIN_DATA_PATH=${TRAIN_DATA_PATH:-"/path/to/your/data_for_SpecForge_test.jsonl"}
examples/repo-wiki.sh (outdated)

    LOR_INTERNAL=200
    SAVE_INTERNAL=10
The variable names LOR_INTERNAL and SAVE_INTERNAL are confusing as they don't clearly represent their purpose. LOR_INTERNAL is used for --save-interval and SAVE_INTERNAL for --log-interval.
To improve readability and maintainability, I suggest renaming them to SAVE_INTERVAL and LOG_INTERVAL respectively. You'll also need to update their usage on lines 62 and 63.
Suggested change:

    - LOR_INTERNAL=200
    - SAVE_INTERNAL=10
    + SAVE_INTERVAL=200
    + LOG_INTERVAL=10
examples/repo-wiki.sh (outdated)

    --save-interval $LOR_INTERNAL \
    --log-interval $SAVE_INTERNAL \
scripts/train_eagle3.py (outdated)

    target_model: Optional[Eagle3TargetModel] = None,
    is_online: bool = True,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    print(data["input_ids"].shape)
    length=self.length,
    )
    del target
    torch.cuda.empty_cache()
Explicitly calling torch.cuda.empty_cache() can introduce significant performance overhead due to CPU-GPU synchronization. The preceding del target should be sufficient to free the tensor's memory if there are no other references. Is this call strictly necessary for memory optimization in this case? If so, a comment explaining why would be helpful for future maintenance.
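For context, here is a small self-contained illustration (not code from this PR) of the difference between dropping a reference, which lets PyTorch's caching allocator reuse the block, and calling empty_cache(), which returns reserved memory to the driver at the cost of a synchronization:

    import torch

    if torch.cuda.is_available():
        x = torch.empty(4096, 4096, device="cuda")   # ~64 MiB in fp32
        print("allocated:", torch.cuda.memory_allocated())

        del x                                        # block returns to the caching allocator for reuse
        print("allocated after del:", torch.cuda.memory_allocated())
        print("still reserved:", torch.cuda.memory_reserved())

        torch.cuda.empty_cache()                     # hands cached blocks back to the driver (syncs)
        print("reserved after empty_cache:", torch.cuda.memory_reserved())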
specforge/core/eagle3.py (outdated)

    # from .forkedpdb import ForkedPdb
    # ForkedPdb().set_trace()
specforge/core/eagle3.py (outdated)

    # ForkedPdb().set_trace()
    logits = gather_outputs_and_unpad(logits_, gather_dim=1)
    del logits_
    torch.cuda.empty_cache()
Same question as Gemini: how much impact will this have on training performance?
I've deleted it, and the impact on training time is minimal based on the current results.
    input_ids = input_ids.cuda()
    target = target_model(
        target.cuda()
    )  # The `data['target']` value occupies a large amount of GPU memory, with a shape of [seqlen, vocab_size]. It needs to be processed before being loaded into the GPU.
    loss_mask = loss_mask.cuda()
What is the impact of this on performance? If it is large, maybe we can add a flag to control whether this is done on the GPU or the CPU.
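If such a flag turns out to be worthwhile, it could look roughly like this (the --preprocess-on-cpu name and the preprocess() stand-in are hypothetical, not part of this PR):

    import argparse
    import torch

    def preprocess(target: torch.Tensor) -> torch.Tensor:
        # Stand-in for the padding/preprocessing step discussed in this thread.
        return torch.nn.functional.pad(target, (0, 0, 0, 1))

    parser = argparse.ArgumentParser()
    # Hypothetical flag name.
    parser.add_argument("--preprocess-on-cpu", action="store_true")
    args = parser.parse_args()

    def prepare_target(target: torch.Tensor) -> torch.Tensor:
        if args.preprocess_on_cpu:
            return preprocess(target).cuda()   # pad on the CPU, then move the result
        return preprocess(target.cuda())       # original behaviour: pad directly on the GPU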
We can simply compute this: vocab_size (150,000) × seq_length (64k) will cost about 10 GB more.
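A quick back-of-the-envelope version of that calculation (the dtypes are assumptions, and the per-GPU figure also depends on how sequence parallelism shards the tensor):

    # Size of a [seq_len, vocab_size] tensor for a few element widths.
    vocab_size = 150_000
    seq_len = 64 * 1024
    for dtype, bytes_per_elem in [("fp32", 4), ("bf16/fp16", 2), ("fp8/int8", 1)]:
        gib = vocab_size * seq_len * bytes_per_elem / 1024**3
        print(f"{dtype}: {gib:.1f} GiB")   # fp32 ~36.6, bf16 ~18.3, fp8 ~9.2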
The reason is that target_head's preprocess function uses padding, which generates an extra copy of the target tensor in memory.
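For reference, padding in PyTorch is not in-place, so the original tensor and the padded copy briefly coexist (illustrative only, not the target_head code itself):

    import torch
    import torch.nn.functional as F

    x = torch.randn(8, 16)            # stand-in for a [seq_len, vocab_size] target tensor
    padded = F.pad(x, (0, 0, 0, 4))   # pads the sequence dim; returns a NEW tensor of shape [12, 16]
    # Until `x` is released, both copies are alive, which is the extra-copy
    # effect described above; doing this on the CPU keeps it out of VRAM.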
We can split the hidden state in the dataset's __getitem__ for USP to reduce memory use.
> We can split the hidden state in the dataset's __getitem__ for USP to reduce memory use.

Yes, I think this is a better optimization method. Can you help add this optimization?
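A sketch of that idea: slice each sample's hidden state to the local sequence-parallel shard inside __getitem__, so only seq_len / sp_size of it is ever collated and moved to the GPU (class and field names are assumptions, not SpecForge's actual API):

    import torch
    from torch.utils.data import Dataset

    class ShardedHiddenStateDataset(Dataset):
        """Returns only the local USP shard of each sample's hidden state."""

        def __init__(self, samples, sp_rank: int, sp_size: int):
            self.samples = samples      # each item holds "hidden_state": [seq_len, hidden_dim]
            self.sp_rank = sp_rank
            self.sp_size = sp_size

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            hidden = self.samples[idx]["hidden_state"]
            shard = hidden.chunk(self.sp_size, dim=0)[self.sp_rank]
            # Only seq_len / sp_size rows leave the dataset, so the collated
            # batch and the later .cuda() copy are sp_size times smaller.
            return {"hidden_state": shard}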
Motivation
This PR reduces the VRAM usage of 64k-context training with sp=8 from 94 GB per GPU to 76 GB per GPU. It also contains PR #429.
Modifications
Related Issues
Accuracy Test
Benchmark & Profiling
Checklist