feat: add tensor-parallel entropy loss benchmark #500

@jlamypoirier

Motivation

The Triton entropy loss kernel supports tensor-parallel training natively via its group parameter: each rank holds vocab / tp_size logits, all-reduces two scalars per row (max and sum-exp), then computes loss and grad locally. This avoids materializing the full logits on any rank.

The PyTorch alternative requires an all-gather of the full logit tensor first — O(tokens × vocab) communication. At realistic scale this is prohibitive:

  • Llama 3.1 405B: vocab=128K, tokens=8K, TP=8 → 16 GB all-gather per step (the gathered 8192 × 131072 logit tensor is ~2 GiB in bf16, and every one of the 8 ranks must receive it)

The Triton path is not just faster; it is the only feasible approach at large vocab × TP. The current single-GPU benchmark does not demonstrate this. This issue tracks building a multi-GPU benchmark that makes it concrete.

Variants to benchmark

Three qualitatively different approaches:

  1. triton_tp — existing triton_entropy_loss_forward_backward(..., group=group). Shards vocab across ranks, all-reduces two scalars per row. O(tokens) communication.

  2. pytorch_tp_manual — same algorithm in PyTorch without Triton: local_max = logits.max(-1) → all_reduce(MAX) → local_sum = (logits − local_max).exp().sum(-1) → all_reduce(SUM) → loss. Tests whether Triton fusion still wins when both paths use the same O(tokens) communication pattern; see the sketch after this list.

  3. pytorch_gather — all-gather logits to full vocab on each rank, then F.cross_entropy. O(tokens × vocab) communication. Included as a reference to show where the naive approach becomes infeasible; expected to OOM at large vocab × TP.
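
A minimal forward-only sketch of variant 2, assuming contiguous vocab sharding, targets replicated on every rank, and the target logit packed into the SUM all-reduce so communication stays at two collectives per batch. The function name tp_cross_entropy and the exact packing are illustrative, not the library's API:

```python
import torch
import torch.distributed as dist

def tp_cross_entropy(local_logits, targets, group):
    """Sketch of pytorch_tp_manual: cross-entropy over vocab-sharded logits
    with O(tokens) communication: two all-reduces, no all-gather.

    local_logits: (tokens, vocab // tp_size) shard held by this rank.
    targets:      (tokens,) global vocab indices, replicated on all ranks.
    """
    shard = local_logits.shape[-1]
    lo = dist.get_rank(group) * shard          # first vocab id on this rank

    # 1) Row-wise max: local reduction, then a MAX all-reduce (one scalar/row).
    row_max = local_logits.max(dim=-1).values
    dist.all_reduce(row_max, op=dist.ReduceOp.MAX, group=group)

    # 2) Numerically stable sum-exp over the local shard.
    sum_exp = (local_logits - row_max[:, None]).exp().sum(dim=-1)

    # 3) Target logit: nonzero only on the rank owning the target index;
    #    stacked with sum_exp so a single SUM all-reduce covers both.
    rows = torch.arange(targets.shape[0], device=targets.device)
    in_shard = (targets >= lo) & (targets < lo + shard)
    picked = local_logits[rows, (targets - lo).clamp(0, shard - 1)]
    target_logit = torch.where(in_shard, picked, torch.zeros_like(picked))

    packed = torch.stack([sum_exp, target_logit])
    dist.all_reduce(packed, op=dist.ReduceOp.SUM, group=group)
    sum_exp, target_logit = packed[0], packed[1]

    # loss_i = logsumexp(logits_i) - logits_i[target_i]
    return (row_max + sum_exp.log() - target_logit).mean()
```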

Shapes

Fix tokens=4096. Sweep (vocab, tp_size) pairs:

| vocab  | tp_size | shard / rank |
|-------:|--------:|-------------:|
| 32768  | 2       | 16384        |
| 32768  | 4       | 8192         |
| 65536  | 4       | 16384        |
| 131072 | 4       | 32768        |
| 131072 | 8       | 16384        |

Infrastructure changes

Multi-process runner

The current runner.py is single-process. Two options:

Option A — new tools/benchmark/run_tp.py entry point using torch.multiprocessing.spawn(worker, nprocs=tp_size). Each worker calls dist.init_process_group, creates a TP process group, and runs the benchmark variants with group=group; rank 0 collects and prints the results.
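
A minimal sketch of Option A, assuming single-node NCCL and a hypothetical run_variants(group) that times the three variants above:

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, tp_size: int):
    # One process per GPU; NCCL backend for the all-reduces in the variants.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=tp_size)
    torch.cuda.set_device(rank)
    group = dist.new_group(range(tp_size))  # TP group == full world here

    results = run_variants(group)  # hypothetical: runs and times each variant

    if rank == 0:
        print(results)
    dist.destroy_process_group()

if __name__ == "__main__":
    tp_size = 4
    mp.spawn(worker, args=(tp_size,), nprocs=tp_size)
```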

Option B — extend __main__.py to detect a --tp N flag and re-launch itself via torchrun --nproc_per_node=N.
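
For Option B, a sketch of the re-launch logic; the LOCAL_RANK check for detecting an existing torchrun context and the tools.benchmark module path are assumptions:

```python
import os
import subprocess
import sys

def maybe_relaunch_under_torchrun(tp_size: int):
    # torchrun sets LOCAL_RANK (among others) in each worker's environment;
    # its absence means we are still the user's original invocation.
    if tp_size > 1 and "LOCAL_RANK" not in os.environ:
        cmd = [
            "torchrun", f"--nproc_per_node={tp_size}",
            "-m", "tools.benchmark", *sys.argv[1:],  # assumed module path
        ]
        sys.exit(subprocess.call(cmd))
```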

Option A is simpler to implement. Option B integrates more cleanly with the existing CLI.

Timing

  • dist.barrier() + torch.cuda.synchronize() before and after each timed region so all ranks agree on wall time.
  • Report the max latency across ranks (dist.all_reduce(MAX) gives every rank the same number; rank 0 prints it).
  • Communication time is included automatically since the all-reduce is inside the kernel call.
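
A sketch of that scheme, with assumed warmup and iteration counts:

```python
import time
import torch
import torch.distributed as dist

def timed(fn, group, iters: int = 20, warmup: int = 5) -> float:
    for _ in range(warmup):
        fn()
    # Align all ranks before starting the clock.
    dist.barrier(group=group)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    # Drain the GPU and re-align before stopping the clock.
    torch.cuda.synchronize()
    dist.barrier(group=group)
    per_iter = torch.tensor((time.perf_counter() - t0) / iters, device="cuda")
    # Ranks can still disagree slightly; report the slowest.
    dist.all_reduce(per_iter, op=dist.ReduceOp.MAX, group=group)
    return per_iter.item()
```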

OOM guard for pytorch_gather

Wrap in try/except and report OOM in the table instead of a time. This makes the table show exactly where the naive approach becomes infeasible.
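
A sketch of the guard; torch.cuda.OutOfMemoryError requires a reasonably recent PyTorch, and older versions would catch RuntimeError and inspect the message instead:

```python
import torch

def run_guarded(variant_fn) -> str:
    """Return a latency cell for the results table, or 'OOM'."""
    try:
        return f"{variant_fn():.2f} ms"  # variant_fn: hypothetical timed call
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()         # release the partial gather buffers
        return "OOM"
```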

Expected outcome

At vocab=131072, TP=8:

| variant           | result                                                   |
|-------------------|----------------------------------------------------------|
| triton_tp         | fast (O(tokens) communication)                           |
| pytorch_tp_manual | slightly slower (same communication, no kernel fusion)   |
| pytorch_gather    | OOM, or ~10–20× slower (the 16 GB all-gather dominates)  |

Files

  • New: tools/benchmark/run_tp.py (or extend __main__.py)
  • Modified: tools/benchmark/bench_entropy_loss.py — add TP variants alongside existing single-GPU variants
