
Conversation

@Ximingwang-09
Contributor

Motivation

Compared to EAGLE3, DFlash needs more epochs to converge, so offline training can greatly improve training efficiency. This PR adds offline training support for DFlash, enabling the DFlash draft model to be trained from pre-computed hidden states. This eliminates the need to load the full target model during training, significantly reducing GPU memory requirements and training costs.

Key benefits:

  • Memory Efficiency: No need to load the target model during training (only draft model + embeddings/lm_head needed)

  • Decoupled Pipeline: Hidden states generation can be done separately from training, enabling better resource utilization

  • Consistent with Online Mode: The offline mode uses the same data preprocessing logic as online training

Modifications

1. Hidden States Generation (scripts/prepare_hidden_states.py)

  • Added DFlashHiddenStatesGenerator class for generating DFlash-specific hidden states

    • Captures hidden states from target layers based on draft model configuration

    • Supports filtering out samples with insufficient loss tokens (fewer than 2 * block_size; see the sketch after this list)

  • Added build_dflash_target_model() function to build DFlash target model with layer capture configuration

  • Added DFlash-specific CLI arguments: --model-type dflash, --num-draft-layers, --target-layers, --block-size
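
The filtering rule itself is small enough to sketch. The snippet below is illustrative rather than the actual code in scripts/prepare_hidden_states.py:

```python
import torch

# Illustrative sketch of the filtering rule above; the real implementation
# in scripts/prepare_hidden_states.py may differ in detail.
def has_enough_loss_tokens(loss_mask: torch.Tensor, block_size: int) -> bool:
    # loss_mask holds 1 for tokens that contribute to the loss, 0 otherwise.
    return int(loss_mask.sum().item()) >= 2 * block_size

# Example: with block_size=4, a sample needs at least 8 loss tokens to be kept.
```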

2. Offline Dataset Support (specforge/data/preprocessing.py)

  • Added OfflineDFlashDataset class for loading pre-computed hidden states (a simplified sketch follows this list)
    • Minimal preprocessing to maintain consistency with online training (block-size truncation is handled in the forward pass)
  • Added build_offline_dflash_dataset() factory function
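
For illustration, a minimal offline dataset could look like the sketch below. The on-disk layout (one .pt file per sample containing input_ids, hidden_state, and loss_mask) is an assumption made for this example, not necessarily the format OfflineDFlashDataset uses:

```python
import os
import torch
from torch.utils.data import Dataset

class OfflineHiddenStatesDataset(Dataset):
    """Hypothetical stand-in for OfflineDFlashDataset: loads pre-computed
    hidden states with minimal preprocessing, leaving block-size truncation
    to the forward pass as in online mode."""

    def __init__(self, data_dir: str):
        self.files = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.endswith(".pt")
        )

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx: int) -> dict:
        sample = torch.load(self.files[idx])
        return {
            "input_ids": sample["input_ids"],
            "hidden_state": sample["hidden_state"],
            "loss_mask": sample["loss_mask"],
        }
```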

3. Training Script Updates (scripts/train_dflash.py)

  • Added automatic mode detection: online (from conversation data) vs. offline (from pre-computed hidden states); see the sketch after this list
  • Added --train-hidden-states-path and --eval-hidden-states-path arguments for offline mode
  • Added loss mask filtering for online mode (filters samples with loss_mask.sum() < 2 * block_size)
  • Refactored build_target_model() to skip loading the target model in offline mode
  • Refactored build_dataloader() to handle both online and offline datasets
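
As a rough sketch, the mode detection reduces to checking whether a hidden-states path was supplied. The argument name mirrors the new --train-hidden-states-path flag, but the helper itself is illustrative:

```python
import argparse

def detect_training_mode(args: argparse.Namespace) -> str:
    # Offline mode: train from pre-computed hidden states, so the full target
    # model never has to be loaded. Online mode: compute hidden states from
    # conversation data during training.
    if getattr(args, "train_hidden_states_path", None):
        return "offline"
    return "online"
```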

4. Data Collator Improvements (specforge/data/utils.py)

  • Added requires_target parameter to DataCollatorWithPadding for flexible field handling (simplified sketch after this list)
    • DFlash: requires_target=False (only needs hidden_state)
    • Eagle3: requires_target=True (needs both hidden_state and target)
  • Updated prepare_dp_dataloaders() to pass requires_target parameter
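
A simplified view of what the requires_target flag gates is shown below; padding and tensor conversion are omitted, so this is not the actual DataCollatorWithPadding implementation:

```python
from dataclasses import dataclass

@dataclass
class SimpleCollator:  # illustrative stand-in for DataCollatorWithPadding
    requires_target: bool = True  # Eagle3: True, DFlash: False

    def __call__(self, features: list[dict]) -> dict:
        batch = {
            "input_ids": [f["input_ids"] for f in features],
            "hidden_state": [f["hidden_state"] for f in features],
            "loss_mask": [f["loss_mask"] for f in features],
        }
        if self.requires_target:
            # Eagle3 also carries the target field; DFlash skips it.
            batch["target"] = [f["target"] for f in features]
        return batch
```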

Related Issues

Accuracy Test


Benchmark & Profiling

For Qwen3-8B:

  • Offline training: 1.72 s/step
  • Online training: 3.85 s/step

Checklist


@sleepcoo
Collaborator

The implementation of DFlashHiddenStatesGenerator seems overly complex. My understanding is that the only difference from how Eagle3 retrieves hidden states is that we need to filter out unnecessary ones based on block size. Could we basically reuse the original hidden state logic? As for the layer differences, couldn't those be handled via configuration?

@Ximingwang-09
Contributor Author

> The implementation of DFlashHiddenStatesGenerator seems overly complex. My understanding is that the only difference from how Eagle3 retrieves hidden states is that we need to filter out unnecessary ones based on block size. Could we basically reuse the original hidden state logic? As for the layer differences, couldn't those be handled via configuration?

Thanks for the suggestion. I’ve refactored the code accordingly:

  • Removed the DFlashHiddenStatesGenerator class entirely.
  • Unified the logic in HiddenStatesGenerator, which now supports both Eagle3 and DFlash via configuration (rough sketch below).
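
Conceptually, the unification amounts to configuration-driven branching along the lines of the sketch below; the attribute names are illustrative and not the actual specforge config fields:

```python
def select_capture_layers(model_type: str, config) -> list[int]:
    # Illustrative only: the unified HiddenStatesGenerator picks which target
    # layers to capture based on configuration instead of subclassing.
    if model_type == "dflash":
        return list(config.target_layers)  # e.g. from --target-layers
    return list(config.eagle3_layers)      # hypothetical Eagle3 layer config
```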
