Triton reverse_kl loss kernel is ~3x slower than torch.compile (single GPU) #497

@jlamypoirier

Summary

The Triton `reverse_kl` loss kernel (`triton_reverse_kl_forward_backward` in `fast_llm/functional/triton/entropy_loss.py`) is ~2.9× slower than `torch.compile` at vocab=32K on a single GPU. Note: in tensor-parallel training, logits are sharded (`vocab / tp_size` per rank), so the gap narrows at higher TP degrees — and at large vocab, TP is the only feasible approach (see #triton-tp-benchmark plan).

Benchmark results (H100 SXM, bf16, single GPU, fwd+bwd)

| Shape (tokens × vocab) | pytorch_compiled | fast_llm_triton | Triton memory |
|---|---|---|---|
| 4 Ki × 32 Ki | 485 GB/s, 14.5% BW, Δpeak 0.75 GiB | 169 GB/s, 5.1% BW, Δpeak 0.25 GiB | 3× less |
| 4 Ki × 64 Ki | 481 GB/s, 14.4% BW, Δpeak 1.50 GiB | anomalous timing* | — |
| 4 Ki × 128 Ki | 483 GB/s, 14.4% BW, Δpeak 3.00 GiB | 389 GB/s, 11.6% BW, Δpeak 1.00 GiB | 3× less |

*The 64 Ki case recorded a negative backward time (measurement artifact).
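
For reference, the %BW figures are achieved throughput over peak HBM bandwidth. A quick sanity check, assuming the ~3.35 TB/s HBM3 peak of an H100 SXM (the exact peak constant used by the benchmark is an assumption here):

```python
PEAK_BW_GBS = 3350  # assumed H100 SXM HBM3 peak, GB/s

print(f"{485 / PEAK_BW_GBS:.1%}")  # 14.5% (compiled, 4 Ki x 32 Ki)
print(f"{169 / PEAK_BW_GBS:.1%}")  # ~5.0% (Triton; the table's 5.1% implies a slightly lower peak constant)
```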

Triton uses 3× less activation memory by fusing fwd+bwd into one pass without intermediate tensors — but that doesn't compensate for the throughput deficit at 32K vocab.
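
The Δpeak numbers line up with that: a rough accounting for the 4 Ki × 32 Ki case (what torch.compile actually keeps alive is an assumption here, not a profile):

```python
tokens, vocab = 4096, 32 * 1024

bf16_tensor = tokens * vocab * 2 / 2**30  # 0.25 GiB: one bf16 tokens x vocab tensor
fp32_tensor = tokens * vocab * 4 / 2**30  # 0.50 GiB: one fp32 intermediate

# Triton Δpeak 0.25 GiB ~= just the bf16 grad buffer; compiled Δpeak
# 0.75 GiB ~= grad buffer + one fp32 softmax intermediate held for backward.
print(bf16_tensor, fp32_tensor)
```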

Root cause

`block_size = min(next_power_of_2(n_cols), 32768)` → at vocab=32K, block_size=32K. With 512 threads/block each thread holds 64 fp32 values in registers. H100 has 65536 registers/SM → only ~2 blocks fit per SM → ~50% warp occupancy → DRAM latency not hidden.
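
The arithmetic behind those numbers, as a quick back-of-the-envelope check (using H100's 65536 32-bit registers and 2048 max resident threads per SM):

```python
BLOCK_SIZE = 32768          # min(next_power_of_2(32768), 32768)
THREADS_PER_BLOCK = 512
REGS_PER_SM = 65536         # 32-bit registers per H100 SM
MAX_THREADS_PER_SM = 2048

vals_per_thread = BLOCK_SIZE // THREADS_PER_BLOCK     # 64 fp32 values per thread
regs_per_block = vals_per_thread * THREADS_PER_BLOCK  # 32768 registers for data alone
blocks_per_sm = REGS_PER_SM // regs_per_block         # 2 blocks fit
occupancy = blocks_per_sm * THREADS_PER_BLOCK / MAX_THREADS_PER_SM
print(f"{occupancy:.0%}")                             # 50%
```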

Three compounding factors vs `cross_entropy_labels` (which achieves 69–96% BW):

  1. Dual softmax in forward — `triton_reverse_kl_forward_from_distribution` runs `triton_fused_softmax_iter_base` on both logits and target simultaneously. Extra live registers (target tile, target_max, target_sum_exp) squeeze occupancy further.
  2. Heavier backward formula — reverse KL grad needs an extra `log` + scalar broadcast vs cross-entropy's `p − q`.
  3. torch.compile advantage — for CE, the unfused backward is slow enough that Triton wins despite low occupancy. For reverse KL, `F.kl_div(target.log_softmax(), logits.softmax())` (sketched below) produces two high-occupancy softmax kernels + pointwise, each well-tuned individually.
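
For concreteness, a minimal sketch of that compiled baseline (the `reduction` choice and dtype handling are assumptions, not necessarily what the benchmark used). Note `F.kl_div(input, t)` computes `t * (log t - input)`, so passing the *model* probabilities as `t` gives KL(model ‖ target), i.e. the reverse direction:

```python
import torch
import torch.nn.functional as F

@torch.compile
def reverse_kl(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # F.kl_div(input, t) = sum t * (log t - input); with t = model probs and
    # input = target log-probs this is KL(model || target), the reverse KL.
    return F.kl_div(
        target.log_softmax(dim=-1),  # log-probs of the target distribution
        logits.softmax(dim=-1),      # model probabilities
        reduction="batchmean",       # assumption: mean over rows
    )

logits = torch.randn(4096, 32768, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randn(4096, 32768, device="cuda", dtype=torch.bfloat16)
reverse_kl(logits, target).backward()
```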

Fix options

Option A — Lower the block_size cap for distribution kernels (e.g. 8192 instead of 32768): 4× more blocks/SM at the cost of 3–4 extra re-reads of logits+target in the backward. Empirically likely a net win on H100. Try caps of 4096, 8192, 16384 and benchmark.
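
A sketch of the knob and the trade-off (illustrative host-side code only; `pick_block_size` is a hypothetical helper, and the kernel itself would additionally need a loop over column tiles once `block_size < n_cols`):

```python
import triton

def pick_block_size(n_cols: int, cap: int) -> int:
    # The current code is equivalent to cap=32768.
    return min(triton.next_power_of_2(n_cols), cap)

for cap in (4096, 8192, 16384):
    bs = pick_block_size(32768, cap)
    tiles = triton.cdiv(32768, bs)  # backward re-reads logits+target once per tile
    print(f"cap={cap}: block_size={bs}, column tiles={tiles}")
```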

Option B — Two-pass fwd/bwd with cached stats: the `group is not None` (distributed) path already does this split via the `max_logits_ptr is not None` branch — a separate forward kernel stores per-row max/sum stats to DRAM, then the backward kernel reloads them without redoing the softmax. The fix for the non-distributed path is to invoke those same existing forward + backward kernels separately rather than the fused `forward_backward` kernel. Both passes run at their independently optimal block sizes; the backward only reads logits+target once. Any fix should stay consistent with the distributed path or unify them.
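
As a runnable reference for the math the cached stats enable (pure PyTorch; assumes the loss is per-row KL(model ‖ target) averaged over rows, which may differ from fast_llm's exact normalization; the real fix would launch the existing Triton kernels instead):

```python
import torch

def forward_with_stats(logits, target):
    # Pass 1: compute the loss and keep per-row max / sum-exp stats.
    m = logits.float().max(dim=-1, keepdim=True).values
    s = (logits.float() - m).exp().sum(dim=-1, keepdim=True)
    tm = target.float().max(dim=-1, keepdim=True).values
    ts = (target.float() - tm).exp().sum(dim=-1, keepdim=True)
    p = (logits.float() - m).exp() / s        # model probs
    log_p = logits.float() - m - s.log()
    log_q = target.float() - tm - ts.log()    # target log-probs
    loss = (p * (log_p - log_q)).sum(dim=-1).mean()
    return loss, (m, s, tm, ts)

def backward_from_stats(logits, target, stats, grad_output=1.0):
    # Pass 2: rebuild both softmaxes from the cached stats, so logits+target
    # are read exactly once and no second max/sum-exp reduction is needed.
    m, s, tm, ts = stats
    p = (logits.float() - m).exp() / s
    log_p = logits.float() - m - s.log()
    log_q = target.float() - tm - ts.log()
    kl_row = (p * (log_p - log_q)).sum(dim=-1, keepdim=True)
    return grad_output * p * (log_p - log_q - kl_row) / logits.shape[0]
```

The backward uses the closed form `grad_i = p_i * (log p_i - log q_i - KL_row) / n_rows`, which is why the cached stats suffice: reconstructing `p`, `log_p`, and `log_q` from them is purely pointwise.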

Expected outcome (after fix)

| variant | current | estimated after fix |
|---|---|---|
| pytorch_compiled | 1.634 ms, 14.7% BW | — (baseline) |
| fast_llm_triton | 4.669 ms, 5.1% BW | ~1.0–1.5 ms, ~20–30% BW |

Triton should match or beat compiled, especially at larger vocab sizes where the fused single-pass advantage (no intermediate softmax tensors written to DRAM) matters more.

Notes

`cross_entropy_logits` (CE with a soft target distribution rather than hard labels) has the same dual-softmax register pressure, and the same fix applies. It is currently only ~1.5× faster than compiled, vs ~3× for CE with hard labels, consistent with the same occupancy issue.
