dLLM(Dflash) Offline Training Support #445
Conversation
Reviewer: The implementation of DFlashHiddenStatesGenerator seems overly complex. My understanding is that the only difference from how Eagle3 retrieves hidden states is that we need to filter out the unnecessary ones based on block size. Could we basically reuse the original hidden-state logic? As for the layer differences, couldn't those be handled via configuration?
Author: Thanks for the suggestion. I've refactored the code accordingly.
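For readers following the discussion, here is a minimal sketch of the configuration-driven approach being described: hidden states are captured from whichever layers the draft-model configuration lists, using the same forward-hook pattern as the existing Eagle3 path. The function name, the `target_layers` parameter, and the HuggingFace-style `model.model.layers` layout are assumptions for illustration, not the actual SpecForge code.

```python
import torch
from torch import nn


def register_hidden_state_hooks(model: nn.Module, target_layers: list[int]):
    """Capture hidden states from the layers listed in the draft-model config,
    reusing the same forward-hook pattern as the Eagle3 hidden-state path."""
    captured: dict[int, torch.Tensor] = {}

    def make_hook(layer_idx: int):
        def hook(module, inputs, output):
            # Decoder layers typically return a tuple; the hidden states come first.
            hidden = output[0] if isinstance(output, tuple) else output
            captured[layer_idx] = hidden.detach()
        return hook

    # Assumes a HuggingFace-style layout (model.model.layers); adjust for the
    # actual target architecture.
    handles = [
        model.model.layers[idx].register_forward_hook(make_hook(idx))
        for idx in target_layers
    ]
    return captured, handles
```

With this shape, the DFlash-specific part reduces to which layers are hooked and how the captured states are filtered afterwards, which is what the configuration-based refactor aims for.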
Motivation
Compared to EAGLE3, DFlash requires more epochs to converge, so offline training can greatly improve training efficiency. This PR adds offline training support for DFlash, enabling the DFlash draft model to be trained on pre-computed hidden states. This eliminates the need to load the full target model during training, significantly reducing GPU memory requirements and training costs.
Key benefits:
- Memory Efficiency: No need to load the target model during training (only the draft model + embeddings/lm_head are needed)
- Decoupled Pipeline: Hidden-states generation can be done separately from training, enabling better resource utilization
- Consistent with Online Mode: The offline mode uses the same data preprocessing logic as online training
Modifications
1. Hidden States Generation (scripts/prepare_hidden_states.py)
- Added the DFlashHiddenStatesGenerator class for generating DFlash-specific hidden states
- Captures hidden states from target layers based on the draft model configuration
- Supports filtering out samples with insufficient loss tokens (fewer than 2 * block_size)
- Added a build_dflash_target_model() function to build the DFlash target model with layer-capture configuration
- Added DFlash-specific CLI arguments: --model-type dflash, --num-draft-layers, --target-layers, --block-size (see the sketch and example invocation after this list)
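As a concrete illustration of the filtering rule above, here is a minimal sketch; the function name and the list-of-ints loss-mask representation are assumptions, only the 2 * block_size threshold comes from this PR.

```python
def has_enough_loss_tokens(loss_mask: list[int], block_size: int) -> bool:
    """Keep a sample only if it has at least 2 * block_size loss tokens,
    mirroring the filtering rule described in this section."""
    return sum(loss_mask) >= 2 * block_size


# With block_size = 4, a sample needs at least 8 loss tokens to be kept.
assert has_enough_loss_tokens([1] * 10, block_size=4)
assert not has_enough_loss_tokens([1] * 5, block_size=4)
```

An invocation would presumably look something like `python scripts/prepare_hidden_states.py --model-type dflash --num-draft-layers ... --target-layers ... --block-size ...`, with the remaining arguments (model path, dataset, output directory) and the exact value formats following the script's existing interface rather than anything shown here.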
2. Offline Dataset Support (specforge/data/preprocessing.py)
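To make the offline path concrete, the sketch below shows one way a dataset could pair tokenized samples with pre-computed hidden states. The class name, file layout, and field names are all assumptions for illustration; the actual preprocessing.py changes may organize this differently.

```python
import torch
from torch.utils.data import Dataset


class OfflineHiddenStatesDataset(Dataset):
    """Hypothetical dataset that reads pre-computed hidden states from disk
    instead of running the target model during training."""

    def __init__(self, sample_paths: list[str]):
        self.sample_paths = sample_paths

    def __len__(self) -> int:
        return len(self.sample_paths)

    def __getitem__(self, idx: int) -> dict:
        # Each file is assumed to hold input_ids, a loss mask, and the
        # target-layer hidden states produced by prepare_hidden_states.py.
        record = torch.load(self.sample_paths[idx])
        return {
            "input_ids": record["input_ids"],
            "loss_mask": record["loss_mask"],
            "hidden_states": record["hidden_states"],
        }
```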
3. Training Script Updates (scripts/train_dflash.py)
- Refactored build_target_model() to skip loading the target model in offline mode
- Refactored build_dataloader() to handle both online and offline datasets (see the sketch after this list)
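The branching described above might look roughly like the following; `args.offline_data_path` and the helper builders are placeholders for illustration, not the actual signatures in train_dflash.py.

```python
from torch.utils.data import DataLoader


def build_target_model(args):
    if args.offline_data_path is not None:
        # Offline mode: hidden states were pre-computed, so the full target
        # model is never loaded; only the draft model and embeddings/lm_head
        # are needed for training.
        return None
    return load_target_model(args.target_model_path)  # hypothetical helper


def build_dataloader(args):
    if args.offline_data_path is not None:
        # Pre-computed hidden states read from disk.
        dataset = build_offline_dataset(args.offline_data_path)  # hypothetical helper
    else:
        # Online mode: the target model produces hidden states during training.
        dataset = build_online_dataset(args.train_data_path)  # hypothetical helper
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
```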
4. Data Collator Improvements (specforge/data/utils.py)
Related Issues
Accuracy Test
Benchmark & Profiling
For Qwen3-8B:
Checklist