
Conversation

@jiapingW (Collaborator) commented Jan 21, 2026

Motivation

This PR reduces VRAM usage for 64k-context training with sp=8, decreasing it from 94 GB to 76 GB per GPU. It also contains PR #429.

Modifications

Related Issues

Accuracy Test

Benchmark & Profiling

Checklist

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @jiapingW, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on improving the efficiency and stability of long-context training for the eagle3 model, particularly for large models like Qwen3-30B-A3B. It introduces memory optimizations that reduce GPU memory (VRAM) usage during forward passes, and it makes data loading more robust by handling diverse conversation data structures gracefully. Together, these changes enable more reliable and scalable training with extended sequence lengths.

Highlights

  • VRAM Optimization: Implemented several memory optimizations, including explicit GPU cache clearing (torch.cuda.empty_cache()) after intermediate tensors are consumed and deferring the GPU transfer of large tensors such as data['target'], to significantly reduce peak GPU memory consumption during long-context training.
  • Robust Dataset Loading: Introduced a new safe_conversations_generator utility to make dataset loading more robust. The generator converts complex data types (lists and dictionaries) within conversation messages to JSON strings, preventing mixed-type conflicts and ensuring compatibility with data processing frameworks like Arrow (a minimal sketch of this idea follows this list).
  • Long Context Tokenization: Updated the tokenization logic to pass truncation=True and max_length, ensuring very long input sequences are handled correctly during processing.
  • New Training Example and Chat Template: Added a new example script (examples/repo-wiki.sh) demonstrating offline training of the eagle3 model for Qwen3-30B-A3B with long contexts (65536 tokens). A corresponding repo-wiki chat template has also been registered to support this new training configuration.
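
As a concrete illustration of the dataset-loading change referenced above, a generator along these lines would do the job. This is only a minimal sketch based on the description in this summary; the function name safe_conversations_generator comes from the PR, but the field names and exact behaviour are assumptions:

import json

def safe_conversations_generator(dataset):
    # Yield conversations with list/dict message values serialized to JSON strings,
    # so Arrow never sees mixed-type columns. Field names here are assumptions.
    for row in dataset:
        safe_messages = []
        for message in row["conversations"]:
            safe_message = {}
            for key, value in message.items():
                if isinstance(value, (list, dict)):
                    # Complex payloads (e.g. tool calls) become JSON strings.
                    safe_message[key] = json.dumps(value, ensure_ascii=False)
                else:
                    safe_message[key] = value
            safe_messages.append(safe_message)
        yield {"conversations": safe_messages}

A generator like this can be streamed (e.g. via datasets.Dataset.from_generator) instead of materializing the whole dataset in memory at once.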


@gemini-code-assist (Contributor, bot) left a comment

Code Review

This pull request introduces several optimizations for long-context training, primarily focused on reducing GPU memory (VRAM) usage. The key changes include switching to a generator-based dataset-loading approach to avoid loading entire datasets into memory, and reordering operations so that preprocessing happens on the CPU before large tensors are moved to the GPU. These are solid strategies for memory optimization. My review includes suggestions to improve the clarity of a new example script, remove leftover debugging code, and reconsider the use of torch.cuda.empty_cache(), which can impact performance.

Comment on lines 14 to 15
TARGET_MODEL_PATH=/disk3/wjp/pretrained_models/Qwen3-Coder-30B-A3B-Instruct
TRAIN_DATA_PATH=/disk3/wjp/datasets/repowiki/data_for_SpecForge_test.jsonl

Severity: medium

The paths for TARGET_MODEL_PATH and TRAIN_DATA_PATH are hardcoded. This makes the script difficult for others to use without modification. Consider using environment variables with default placeholder values to make the script more portable.

Suggested change
- TARGET_MODEL_PATH=/disk3/wjp/pretrained_models/Qwen3-Coder-30B-A3B-Instruct
- TRAIN_DATA_PATH=/disk3/wjp/datasets/repowiki/data_for_SpecForge_test.jsonl
+ TARGET_MODEL_PATH=${TARGET_MODEL_PATH:-"/path/to/your/Qwen3-Coder-30B-A3B-Instruct"}
+ TRAIN_DATA_PATH=${TRAIN_DATA_PATH:-"/path/to/your/data_for_SpecForge_test.jsonl"}

Comment on lines 40 to 41
LOR_INTERNAL=200
SAVE_INTERNAL=10

Severity: medium

The variable names LOR_INTERNAL and SAVE_INTERNAL are confusing as they don't clearly represent their purpose. LOR_INTERNAL is used for --save-interval and SAVE_INTERNAL for --log-interval.

To improve readability and maintainability, I suggest renaming them to SAVE_INTERVAL and LOG_INTERVAL respectively. You'll also need to update their usage on lines 62 and 63.

Suggested change
- LOR_INTERNAL=200
- SAVE_INTERNAL=10
+ SAVE_INTERVAL=200
+ LOG_INTERVAL=10

Comment on lines 62 to 63
--save-interval $LOR_INTERNAL \
--log-interval $SAVE_INTERNAL \

Severity: medium

Following the renaming of variables for clarity, please update the usage here.

Suggested change
- --save-interval $LOR_INTERNAL \
- --log-interval $SAVE_INTERNAL \
+ --save-interval $SAVE_INTERVAL \
+ --log-interval $LOG_INTERVAL \

target_model: Optional[Eagle3TargetModel] = None,
is_online: bool = True,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
print(data["input_ids"].shape)

Severity: medium

This print statement appears to be a leftover from debugging. It should be removed to avoid cluttering the logs.

length=self.length,
)
del target
torch.cuda.empty_cache()

Severity: medium

Explicitly calling torch.cuda.empty_cache() can introduce significant performance overhead due to CPU-GPU synchronization. The preceding del target should be sufficient to free the tensor's memory if there are no other references. Is this call strictly necessary for memory optimization in this case? If so, a comment explaining why would be helpful for future maintenance.
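
One hedged way to act on this, sketched below, is to keep the del unconditionally but gate the synchronizing empty_cache() call behind an opt-in switch and document why it exists. The environment-variable name is invented for illustration and is not part of the PR; target comes from the surrounding training step shown in the diff above:

import os
import torch

# Hypothetical opt-in: torch.cuda.empty_cache() returns cached blocks to the driver
# and synchronizes, so only pay that cost when long-context runs would otherwise OOM.
AGGRESSIVE_CACHE_CLEAR = os.environ.get("SPECFORGE_EMPTY_CACHE", "0") == "1"

del target  # drop the last reference so the caching allocator can reuse the block
if AGGRESSIVE_CACHE_CLEAR:
    torch.cuda.empty_cache()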

Comment on lines 211 to 212
# from .forkedpdb import ForkedPdb
# ForkedPdb().set_trace()

Severity: medium

These commented-out lines appear to be leftover debugging code and should be removed.

# ForkedPdb().set_trace()
logits = gather_outputs_and_unpad(logits_, gather_dim=1)
del logits_
torch.cuda.empty_cache()

Severity: medium

Similar to my other comment, this empty_cache() call can impact performance and may be redundant after del logits_. Is this explicit cache clearing essential here to prevent out-of-memory errors? If so, please add a comment explaining why.

A collaborator commented:

Same question as Gemini: how much impact will this have on training performance?

@jiapingW (Author) replied:

I've deleted it, and the impact on training time is minimal based on the current results.

@jiapingW changed the title from "[Draft] Long Context Training DRAM Optimization" to "[Draft] Long Context Training VRAM Optimization" on Jan 21, 2026
# ForkedPdb().set_trace()
logits = gather_outputs_and_unpad(logits_, gather_dim=1)
del logits_
torch.cuda.empty_cache()
A collaborator commented:

Same question as Gemini: how much impact will this have on training performance?

Comment on lines +604 to +608
input_ids = input_ids.cuda()
target = target_model(
target.cuda()
) # The `data['target']` value occupies a large amount of GPU memory, with a shape of [seqlen, vocab_size]. It needs to be processed before being loaded into the GPU.
loss_mask = loss_mask.cuda()
A collaborator commented:

What is the impact of this on performance? If it is large, maybe we can add a flag to control whether this is done on the GPU or the CPU.

@jiapingW (Author) replied:

We can compute this simply: vocab_size (150,000) × seq_length (64k) costs about 10 GB more.
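
For reference, a back-of-the-envelope check of that figure (my own arithmetic, not from the PR; the per-element byte count depends on which dtype is actually held in memory):

# Rough size of the [seqlen, vocab_size] target tensor discussed above.
vocab_size, seq_len = 150_000, 64 * 1024
elements = vocab_size * seq_len        # ≈ 9.83e9 logits
print(elements / 1e9)                  # ≈ 9.8  -> GB at 1 byte per element
print(elements * 2 / 1e9)              # ≈ 19.7 -> GB in bf16/fp16 (2 bytes per element)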

@jiapingW (Author) added:

The reason is that target_head's preprocess function applies padding, which generates an extra copy of the target tensor in memory.

A collaborator commented:

We can split the hidden state in the dataset's __getitem__ for USP to reduce memory use.

@jiapingW (Author) replied:

> We can split the hidden state in the dataset's __getitem__ for USP to reduce memory use.

Yes, I think this is a better optimization method. Can you help add this optimization?
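
For readers following the thread, the proposed optimization could look roughly like the sketch below. Everything here is assumed for illustration (the class name, the hidden_state field, and the sp_rank/sp_size attributes are not taken from the repository): slice each sample's cached hidden state to this sequence-parallel rank's shard inside __getitem__, so the full [seqlen, hidden_size] tensor never has to be materialized on every GPU.

import torch
from torch.utils.data import Dataset

class ShardedHiddenStateDataset(Dataset):
    # Sketch only: per-rank sequence slicing of offline hidden states for USP.
    def __init__(self, samples, sp_rank, sp_size):
        self.samples = samples      # each sample holds a [seqlen, hidden_size] tensor
        self.sp_rank = sp_rank
        self.sp_size = sp_size

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = dict(self.samples[idx])
        hidden = item["hidden_state"]             # [seqlen, hidden_size]
        shard = hidden.shape[0] // self.sp_size   # assumes seqlen % sp_size == 0
        start = self.sp_rank * shard
        item["hidden_state"] = hidden[start:start + shard]
        return item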
