Non-record: QAT ablation — int8 QAT overhead exceeds quantization gap recovery#145

Open
mrdavtan wants to merge 2 commits into openai:main from mrdavtan:qat-ablation-pr

Conversation

@mrdavtan

Summary

Clean ablation of per-row int8 quantization-aware training (QAT) on the baseline 9L×512d architecture with default hyperparameters. Finding: the overhead costs more than it recovers.

  • QAT uses a straight-through estimator matching quantize_float_tensor exactly (same INT8_CLIP_Q = 0.9999984 percentile, same per-row scale)
  • torch.quantile adds ~20% per-step overhead (64ms → 77ms), costing ~2,000 training steps in the 600s budget
  • Post-roundtrip val_bpb: 1.2052 (vs SlidingWindowEval's 1.1925 without QAT)
  • The lost training tokens hurt more than the ~0.007 BPB quantization gap recovery helps
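The straight-through estimator described above can be sketched as follows. This is a minimal illustration, not the repo's code: the actual `quantize_float_tensor` may differ in rounding and edge-case handling, and `quantize_int8_ste` is a hypothetical name.

```python
import torch

INT8_CLIP_Q = 0.9999984  # clip percentile from the summary above

def quantize_int8_ste(w: torch.Tensor) -> torch.Tensor:
    """Fake-quantize `w` per row to int8 with a straight-through estimator."""
    # Per-row clip value at the INT8_CLIP_Q percentile of |w|;
    # this torch.quantile call is the ~20% per-step overhead noted above.
    clip = torch.quantile(w.abs().float(), INT8_CLIP_Q, dim=-1, keepdim=True)
    scale = clip / 127.0
    # Quantize-dequantize roundtrip to the int8 grid
    q = torch.clamp(torch.round(w / scale), -127, 127) * scale
    # STE: forward uses the quantized value, backward passes gradients
    # through to `w` unchanged
    return w + (q - w).detach()
```

Because the quantized branch is detached, the backward pass sees an identity, so gradients flow to `w` as if no quantization had happened.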

Key takeaways for other participants

  1. Int8 QAT with exact percentile matching is not worth it under the 10-minute wallclock cap — aggressive warmdown (WD=20000) achieves similar quant gap reduction at zero training cost
  2. Int6 QAT likely does pay off — the quantization gap is larger (~0.01+ BPB), making the overhead worthwhile. PRs #128 (Record: Int6 MLP3x + STE QAT + Sliding Window, val_bpb=1.1594) and #137 (Record: Int6 + MLP 3x + STE QAT + NorMuon + sliding window, val_bpb=1.1666) confirm this.
  3. torch.compile graph priming pitfall: pre-compiling both QAT and non-QAT graphs during warmup causes the compiler to use a slower path for the non-QAT forward pass. Don't do this — accept the one-time recompile instead.
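The graph priming pitfall in takeaway 3 concerns a conditional code path like the following. This is a hypothetical sketch (names like `QATLinear` and `fake_quant` are illustrative, not from the repo):

```python
import torch

def fake_quant(w: torch.Tensor) -> torch.Tensor:
    # Placeholder fake-quantization (max-based scale, not the percentile
    # version used in the actual run)
    scale = w.abs().amax(dim=-1, keepdim=True) / 127.0
    q = torch.clamp(torch.round(w / scale), -127, 127) * scale
    return w + (q - w).detach()

class QATLinear(torch.nn.Module):
    """Linear layer with a QAT flag branched on in forward."""
    def __init__(self, dim: int = 512):
        super().__init__()
        self.fc = torch.nn.Linear(dim, dim)
        self.qat_enabled = False  # flipped once when QAT activates

    def forward(self, x):
        w = fake_quant(self.fc.weight) if self.qat_enabled else self.fc.weight
        return torch.nn.functional.linear(x, w, self.fc.bias)

# Per the finding above: do NOT run warmup batches with qat_enabled set to
# both False and True to "prime" both compiled graphs -- torch.compile then
# routes the non-QAT forward through a slower path. Flip the flag once at
# the activation step and accept the single recompile.
```

The flag flip changes the traced graph, so `torch.compile` recompiles once at the activation step; the takeaway is that this one-time cost beats priming both graphs up front.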

Results

| Metric | SlidingWindowEval (no QAT) | This run (QAT) |
| --- | --- | --- |
| Steps | 13,450 | 8,011 |
| step_avg | 44.6 ms | 75.2 ms |
| Post-quant val_bpb (sliding window) | 1.1925 | 1.2052 |

Test plan

  • Runs on 8×H100 SXM within 10-minute wallclock cap
  • Artifact under 16,000,000 bytes (15,868,103)
  • QAT activates at step 6,000 and runs for 2,011 steps
  • Training log included

Built with Claude Code

Clean ablation of per-row int8 QAT with exact INT8_CLIP_Q percentile
match on baseline 9L×512d architecture. Finding: torch.quantile adds
~20% step overhead, costing ~2000 training steps in the 600s budget.
The lost training tokens hurt more than the ~0.007 BPB quantization
gap recovery helps. QAT likely only pays off with int6 (larger gap)
or a faster approximate quantile. Includes graph priming finding
for torch.compile conditional code paths.
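One shape the "faster approximate quantile" could take is an order-statistic lookup via `torch.kthvalue` in place of `torch.quantile`'s interpolated computation. A hypothetical sketch — whether it is actually faster depends on tensor shape and hardware, so profile before relying on it:

```python
import math
import torch

def approx_row_clip(w: torch.Tensor, q: float = 0.9999984) -> torch.Tensor:
    """Approximate the per-row q-th percentile of |w| with the nearest
    order statistic (no interpolation between neighbors)."""
    n = w.shape[-1]
    k = min(n, max(1, math.ceil(q * n)))  # rank of the q-th order statistic
    clip, _ = torch.kthvalue(w.abs(), k, dim=-1, keepdim=True)
    return clip
```

For the 512-wide rows here, `q = 0.9999984` rounds up to the row maximum, so an even simpler `amax`-based clip might behave identically; the ablation's exact-match requirement is what forced the full `torch.quantile` path.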
