diff --git a/.gitignore b/.gitignore index d3b18b358..a5fd89b4b 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,4 @@ artifacts/ **/times.csv transformer_engine/build_info.txt transformer_engine/common/util/hip_nvml.* +.asv/ diff --git a/benchmarks/asv/README.md b/benchmarks/asv/README.md new file mode 100644 index 000000000..0d128da7a --- /dev/null +++ b/benchmarks/asv/README.md @@ -0,0 +1,136 @@ +# ASV Benchmarks for TransformerEngine + +Performance benchmarks built on [ASV (Air Speed Velocity)](https://asv.readthedocs.io/), +a framework for benchmarking Python packages over their lifetime. + +## Prerequisites + +- TransformerEngine must already be built and installed in the current Python environment. +- A ROCm or CUDA GPU must be available. +- Install ASV: `pip install asv` + +ASV is configured with `environment_type: "existing"` (in `benchmarks/asv/asv.conf.json`), +meaning it uses the current Python environment directly — it does not create virtualenvs or +attempt to build TE itself. The config sets `branches: ["HEAD"]` so that `asv publish` accepts results from +whichever branch is currently checked out — this works for both local development +and CI (where `HEAD` points to `dev`). + +## Running benchmarks + +### Direct execution (recommended for development) + +Each `bench_*.py` file is directly executable. Results are saved in ASV-compatible +format by default. + +```bash +cd benchmarks/asv +python driver.py --all # run every suite +python driver.py bench_gemm # run one suite via driver +python bench_gemm.py # run one suite directly +python bench_gemm.py time_forward # filter to a specific method +python bench_gemm.py -w 5 -n 20 # custom warmup/iteration counts +python bench_casting.py --no-save # skip saving results +``` + +### Helper script + +`run_benchmarks.sh` wraps common tasks and can be run from anywhere. 
+ +```bash +bash benchmarks/asv/run_benchmarks.sh [options] +``` + +| Command | Description | +|---|---| +| `setup [name]` | Register machine with ASV (defaults to `hostname`) | +| `run [suite] [method]` | Run benchmarks in-process (fast, saves ASV-compatible results) | +| `run --asv [suite]` | Run via ASV subprocess isolation (for CI or statistical rigor) | +| `compare [ref] [new]` | Compare two commits (defaults to `HEAD~1` vs `HEAD`) | +| `view` | Generate HTML dashboard and serve on `localhost:8080` | +| `list` | List available benchmark suites | + +### Manual ASV commands + +All `asv` commands require `--config` with an **absolute path** and should be run +from the **repo root**. The common flags are: + +```bash +ASV="asv --config $(pwd)/benchmarks/asv/asv.conf.json" +COMMON="--python=same --launch-method spawn --set-commit-hash $(git rev-parse HEAD)" +``` + +- `--python=same` — use the current interpreter (required with `environment_type: "existing"`) +- `--launch-method spawn` — required for CUDA/ROCm (fork causes reinitialization errors) +- `--set-commit-hash` — **required** with `environment_type: "existing"`, otherwise ASV silently discards results + +```bash +$ASV machine --yes --machine mi325 # register machine +$ASV run $COMMON # run all benchmarks +$ASV run $COMMON --bench bench_casting # single suite (regex match) +$ASV continuous $COMMON HEAD~1 HEAD # compare two commits +$ASV publish && $ASV preview # HTML dashboard on localhost:8080 +``` + +## How results are stored + +### Local results + +ASV stores results as JSON files under `benchmarks/.asv/results/`: + +``` +benchmarks/.asv/results/ + my-machine-name/ + machine.json # Hardware/OS metadata + .json # Timing results for that commit + .json + ... +``` + +Each commit JSON contains the wall-clock timings for every benchmark + parameter combination +run on that machine. The `benchmarks/.asv/` directory is in `.gitignore`. 
+ +## Writing new benchmarks + +Create a new file in `benchmarks/asv/` following the naming convention `bench_.py`. + +```python +#!/usr/bin/env python3 +import torch +import transformer_engine.pytorch as te + +class BenchSomething: + params = [[1024, 4096], ["config_a", "config_b"]] + param_names = ["M", "config"] + timeout = 300 # seconds, per parameter combination + + def setup(self, M, config): + # Allocate tensors, create modules. + # This runs before each time_* method but is NOT timed. + self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + ... + + def time_forward(self, M, config): + # Use CUDA events for accurate GPU timing. + # Return elapsed seconds — the driver uses this instead of wall time. + self._evt[0].record() + self.module(self.x) + self._evt[1].record() + torch.cuda.synchronize() + return self._evt[0].elapsed_time(self._evt[1]) / 1000 + + # Optional: define work_ to get throughput columns (TFLOPS / GB/s). + def work_forward(self, M, config): + return {"flops": 2 * M * self.N * self.K} # compute-bound + # return {"bytes": M * self.hidden * 4} # memory-bound + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) +``` + +Key rules: +- Method names starting with `time_` are automatically timed. +- Use CUDA events and return elapsed seconds for accurate GPU timing. +- Optionally define `work_` companions to get TFLOPS or GB/s columns. +- Clear `.grad` attributes in backward benchmarks to prevent memory accumulation. +- The `params` list defines a cross-product; keep the matrix size reasonable. 
diff --git a/benchmarks/asv/__init__.py b/benchmarks/asv/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/benchmarks/asv/asv.conf.json b/benchmarks/asv/asv.conf.json new file mode 100644 index 000000000..3c1616aac --- /dev/null +++ b/benchmarks/asv/asv.conf.json @@ -0,0 +1,16 @@ +{ + "version": 1, + "project": "TransformerEngine", + "project_url": "https://github.com/ROCm/TransformerEngine", + "repo": "../..", + "branches": ["HEAD"], + "environment_type": "existing", + "install_command": [], + "build_command": [], + "benchmark_dir": ".", + "results_dir": "../.asv/results", + "html_dir": "../.asv/html", + "install_timeout": 600, + "benchmark_timeout": 1200, + "launch_method": "spawn" +} diff --git a/benchmarks/asv/bench_attention.py b/benchmarks/asv/bench_attention.py new file mode 100644 index 000000000..fd202cdca --- /dev/null +++ b/benchmarks/asv/bench_attention.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +Attention micro-benchmark using te.DotProductAttention. + +Benchmarks fused multi-head attention (with flash attention backend) for +model configurations with grouped-query attention (GQA). 
+ +Models: + - Llama 3 8B (TP=1, TP=8), 70B (TP=8), 405B (TP=8) + - Qwen 2.5 7B (TP=1), 72B (TP=8) + +Forward FLOPs = 4 * batch * num_q_heads * seq_len^2 * head_dim + (two matmuls: Q@K^T and attn@V, each contributing 2*b*h*s^2*d) +Backward FLOPs = 2 * Forward FLOPs (approximately) + +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json + +Forward FLOPs = 4 * batch * num_q_heads * seq_len^2 * head_dim +Backward FLOPs ~ 2x forward +""" + +import torch +import transformer_engine.pytorch as te + +BATCH = 2 + +# (num_q_heads, num_kv_heads, head_dim, tp) +MODELS = { + "Llama3-8B_TP1": (32, 8, 128, 1), + "Llama3-8B_TP8": (32, 8, 128, 8), + "Llama3-70B_TP8": (64, 8, 128, 8), + "Llama3-405B_TP8": (128, 8, 128, 8), + "Qwen2.5-7B_TP1": (28, 4, 128, 1), + "Qwen2.5-72B_TP8": (64, 8, 128, 8), +} + + +class BenchAttention: + params = [[1024, 2048, 4096, 8192], list(MODELS)] + param_names = ["seq_len", "model"] + timeout = 300 + + def setup(self, seq_len, model): + n_q, n_kv, hd, tp = MODELS[model] + qh, kvh = n_q // tp, n_kv // tp + dtype = torch.bfloat16 + + self.attn = te.DotProductAttention( + num_attention_heads=qh, kv_channels=hd, + num_gqa_groups=kvh, attn_mask_type="causal", + ).to(device="cuda", dtype=dtype) + + self.q = torch.randn(seq_len, BATCH, qh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.k = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.v = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn_like(self.attn(self.q, self.k, self.v)) + self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + + def 
work_forward(self, seq_len, model): + n_q, n_kv, hd, tp = MODELS[model] + qh = n_q // tp + return {"flops": 4 * BATCH * qh * seq_len * seq_len * hd} + + def work_forward_backward(self, seq_len, model): + n_q, n_kv, hd, tp = MODELS[model] + qh = n_q // tp + return {"flops": 3 * 4 * BATCH * qh * seq_len * seq_len * hd} + + def time_forward(self, seq_len, model): + self._evt[0].record() + self.attn(self.q, self.k, self.v) + self._evt[1].record() + torch.cuda.synchronize() + return self._evt[0].elapsed_time(self._evt[1]) / 1000 + + def time_forward_backward(self, seq_len, model): + self._evt[0].record() + out = self.attn(self.q, self.k, self.v) + out.backward(self.grad_out) + self._evt[1].record() + torch.cuda.synchronize() + self.q.grad = self.k.grad = self.v.grad = None + return self._evt[0].elapsed_time(self._evt[1]) / 1000 + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/asv/bench_casting.py b/benchmarks/asv/bench_casting.py new file mode 100644 index 000000000..bcdfd865a --- /dev/null +++ b/benchmarks/asv/bench_casting.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +Benchmarks quantization (BF16 -> FP8) and dequantization (FP8 -> BF16) for +both E4M3 (activations/weights) and E5M2 (gradients) formats. + +Shapes are (M, hidden_size) matching the activation tensors from models: + - Llama 3.1 8B, 70B, 405B + - Qwen 2.5 7B, 72B + +These casts are memory-bound; we report GB/s (input + output bytes). 
+ +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json +""" + +import torch +from transformer_engine.pytorch import Float8CurrentScalingQuantizer +from transformer_engine_torch import DType as TE_DType + +HIDDEN_SIZES = { + "Llama3-8B": 4096, + "Llama3-70B": 8192, + "Llama3-405B": 16384, + "Qwen2.5-7B": 3584, + "Qwen2.5-72B": 8192, +} + +CAST_CONFIGS = { + "BF16_to_E4M3": ("quantize", TE_DType.kFloat8E4M3), + "E4M3_to_BF16": ("dequantize", TE_DType.kFloat8E4M3), + "BF16_to_E5M2": ("quantize", TE_DType.kFloat8E5M2), + "E5M2_to_BF16": ("dequantize", TE_DType.kFloat8E5M2), +} + + +class BenchCasting: + params = [[1024, 2048, 4096, 8192], list(HIDDEN_SIZES), list(CAST_CONFIGS)] + param_names = ["M", "model", "cast"] + timeout = 120 + + def setup(self, M, model, cast): + hidden = HIDDEN_SIZES[model] + direction, fp8_dtype = CAST_CONFIGS[cast] + self.direction = direction + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=fp8_dtype, + device=torch.device("cuda"), + rowwise=True, + columnwise=False, + ) + if direction == "dequantize": + bf16_tensor = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") + self.x = quantizer.quantize(bf16_tensor) + else: + self.x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") + self.quantizer = quantizer + self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + + def work_cast(self, M, model, cast): + hidden = HIDDEN_SIZES[model] + direction = CAST_CONFIGS[cast][0] + if direction == "quantize": + # Read BF16 (2B) + write FP8 (1B) + write scale + return {"bytes": M * hidden * 3} + else: + # Read FP8 (1B) + read scale + write BF16 (2B) + return {"bytes": M * hidden * 3} 
+ + def time_cast(self, M, model, cast): + self._evt[0].record() + if self.direction == "quantize": + self.quantizer.quantize(self.x) + else: + self.x.dequantize(dtype=torch.bfloat16) + self._evt[1].record() + torch.cuda.synchronize() + return self._evt[0].elapsed_time(self._evt[1]) / 1000 + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/asv/bench_gemm.py b/benchmarks/asv/bench_gemm.py new file mode 100644 index 000000000..85152d1da --- /dev/null +++ b/benchmarks/asv/bench_gemm.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""BF16 GEMM benchmarks via te.Linear. + +GEMM shapes derived from transformer layer projections: + QKV, AttnOut, GateUp (SwiGLU), Down. 
+ +Model configuration sources: +- Llama 3 8B (hidden=4096, intermediate=14336, heads=32, kv_heads=8, head_dim=128) + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + +- Llama 3 70B (hidden=8192, intermediate=28672, heads=64, kv_heads=8, head_dim=128) + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + +- Llama 3 405B (hidden=16384, intermediate=53248, heads=128, kv_heads=8, head_dim=128) + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + +- Qwen 2.5 7B (hidden=3584, intermediate=18944, heads=28, kv_heads=4, head_dim=128) + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + +- Qwen 2.5 72B (hidden=8192, intermediate=29568, heads=64, kv_heads=8, head_dim=128) + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json + """ + +import torch +import transformer_engine.pytorch as te + +# (hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) +MODELS = { + "Llama3-8B_TP1": (4096, 14336, 32, 8, 128, 1), + "Llama3-8B_TP8": (4096, 14336, 32, 8, 128, 8), + "Llama3-70B_TP8": (8192, 28672, 64, 8, 128, 8), + "Llama3-405B_TP8": (16384, 53248, 128, 8, 128, 8), + "Qwen2.5-7B_TP1": (3584, 18944, 28, 4, 128, 1), + "Qwen2.5-72B_TP8": (8192, 29568, 64, 8, 128, 8), +} + +# Pre-compute (N, K) for each GEMM shape +SHAPES = {} +for _name, (h, inter, nq, nkv, hd, tp) in MODELS.items(): + SHAPES[f"{_name}-QKV"] = ((nq * hd + 2 * nkv * hd) // tp, h) + SHAPES[f"{_name}-AttnOut"] = (h, (nq * hd) // tp) + SHAPES[f"{_name}-GateUp"] = ((2 * inter) // tp, h) + SHAPES[f"{_name}-Down"] = (h, inter // tp) + + +class BenchGemm: + params = [[1024, 2048, 4096, 8192], list(SHAPES)] + param_names = ["M", "shape"] + timeout = 300 + + def setup(self, M, shape): + N, K = SHAPES[shape] + dtype = torch.bfloat16 + self.linear = te.Linear(K, N, bias=False).to(device="cuda", dtype=dtype) + self.x = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = 
torch.randn_like(self.linear(self.x)) + self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + + def work_forward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 2 * M * N * K} + + def work_forward_backward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 3 * 2 * M * N * K} + + def time_forward(self, M, shape): + self._evt[0].record() + self.linear(self.x) + self._evt[1].record() + torch.cuda.synchronize() + return self._evt[0].elapsed_time(self._evt[1]) / 1000 + + def time_forward_backward(self, M, shape): + self._evt[0].record() + out = self.linear(self.x) + out.backward(self.grad_out) + self._evt[1].record() + torch.cuda.synchronize() + self.x.grad = None + self.linear.weight.grad = None + return self._evt[0].elapsed_time(self._evt[1]) / 1000 + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/asv/bench_gemm_fp8.py b/benchmarks/asv/bench_gemm_fp8.py new file mode 100644 index 000000000..ce8d8b708 --- /dev/null +++ b/benchmarks/asv/bench_gemm_fp8.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +FP8 GEMM benchmarks via te.Linear under fp8_autocast. 
+ +Same shapes as bench_gemm.py but with FP8 quantized compute: + - Llama 3 8B (TP=1, TP=8), 70B (TP=8), 405B (TP=8) + - Qwen 2.5 7B (TP=1), 72B (TP=8) + +Each model contributes four GEMM shapes: + QKV projection (column-parallel) N = (Qheads + 2*KVheads)*head_dim / TP, K = hidden + Attention output (row-parallel) N = hidden, K = Qheads*head_dim / TP + MLP Gate+Up (column-parallel) N = 2*intermediate / TP, K = hidden (SwiGLU) + MLP Down (row-parallel) N = hidden, K = intermediate / TP + +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json +""" + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling, Format + +# (hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) +MODELS = { + "Llama3-8B_TP1": (4096, 14336, 32, 8, 128, 1), + "Llama3-8B_TP8": (4096, 14336, 32, 8, 128, 8), + "Llama3-70B_TP8": (8192, 28672, 64, 8, 128, 8), + "Llama3-405B_TP8": (16384, 53248, 128, 8, 128, 8), + "Qwen2.5-7B_TP1": (3584, 18944, 28, 4, 128, 1), + "Qwen2.5-72B_TP8": (8192, 29568, 64, 8, 128, 8), +} + +SHAPES = {} +for _name, (h, inter, nq, nkv, hd, tp) in MODELS.items(): + SHAPES[f"{_name}-QKV"] = ((nq * hd + 2 * nkv * hd) // tp, h) + SHAPES[f"{_name}-AttnOut"] = (h, (nq * hd) // tp) + SHAPES[f"{_name}-GateUp"] = ((2 * inter) // tp, h) + SHAPES[f"{_name}-Down"] = (h, inter // tp) + +FP8_RECIPE = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max", +) + + +class BenchGemmFP8: + params = [[1024, 2048, 4096, 8192], list(SHAPES)] + param_names = ["M", "shape"] + timeout = 300 + + def setup(self, M, shape): + N, K = SHAPES[shape] + dtype = torch.bfloat16 + 
self.linear = te.Linear(K, N, bias=False).to(device="cuda", dtype=dtype) + self.x = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn(M, N, dtype=dtype, device="cuda") + self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + + def work_forward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 2 * M * N * K} + + def work_forward_backward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 3 * 2 * M * N * K} + + def time_forward(self, M, shape): + self._evt[0].record() + with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + self.linear(self.x) + self._evt[1].record() + torch.cuda.synchronize() + return self._evt[0].elapsed_time(self._evt[1]) / 1000 + + def time_forward_backward(self, M, shape): + self._evt[0].record() + with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + out = self.linear(self.x) + out.backward(self.grad_out) + self._evt[1].record() + torch.cuda.synchronize() + self.x.grad = None + self.linear.weight.grad = None + return self._evt[0].elapsed_time(self._evt[1]) / 1000 + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/asv/bench_grouped_gemm.py b/benchmarks/asv/bench_grouped_gemm.py new file mode 100644 index 000000000..40e1f2b0d --- /dev/null +++ b/benchmarks/asv/bench_grouped_gemm.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""Grouped GEMM benchmarks via te.GroupedLinear. + +MoE model configurations with GateUp and Down projections. 
+Configurations are based on: +https://github.com/AMD-AGI/Primus-Turbo/blob/main/benchmark/ops/config.py +""" + +import torch +import transformer_engine.pytorch as te + +# (n_routed_experts, moe_intermediate_size, hidden_size) +MOE_MODELS = { + "DSV2-Lite": (64, 1408, 2048), + "DSV2": (160, 1536, 5120), + "DSV3": (256, 2048, 7168), + "Grok-V2": (8, 16384, 8192), +} + +# Build (config_key -> (num_gemms, N, K)) mapping +CONFIGS = {} +for model, (n_experts, inter, hidden) in MOE_MODELS.items(): + for ep in [32, 16, 8]: + if n_experts % ep != 0: + continue + B = n_experts // ep + CONFIGS[f"{model}_EP{ep}-GateUp"] = (B, 2 * inter, hidden) + CONFIGS[f"{model}_EP{ep}-Down"] = (B, hidden, inter) + + +class BenchGroupedGemm: + params = [[512, 1024, 2048, 4096], list(CONFIGS)] + param_names = ["M", "config"] + timeout = 300 + + def setup(self, M, config): + B, N, K = CONFIGS[config] + dtype = torch.bfloat16 + + self.module = te.GroupedLinear( + num_gemms=B, in_features=K, out_features=N, bias=False, + ).to(device="cuda", dtype=dtype) + + self.xs = [ + torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + for _ in range(B) + ] + outs = self.module(self.xs) + self.grad_outs = [torch.randn_like(o) for o in outs] + self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + + def work_forward(self, M, config): + B, N, K = CONFIGS[config] + return {"flops": B * 2 * M * N * K} + + def work_forward_backward(self, M, config): + B, N, K = CONFIGS[config] + return {"flops": B * 3 * 2 * M * N * K} + + def time_forward(self, M, config): + self._evt[0].record() + self.module(self.xs) + self._evt[1].record() + torch.cuda.synchronize() + return self._evt[0].elapsed_time(self._evt[1]) / 1000 + + def time_forward_backward(self, M, config): + self._evt[0].record() + outs = self.module(self.xs) + torch.autograd.backward(outs, self.grad_outs) + self._evt[1].record() + torch.cuda.synchronize() + for x in self.xs: + x.grad = None + for p in self.module.parameters(): + 
p.grad = None + return self._evt[0].elapsed_time(self._evt[1]) / 1000 + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/asv/bench_normalization.py b/benchmarks/asv/bench_normalization.py new file mode 100644 index 000000000..9c0c571ca --- /dev/null +++ b/benchmarks/asv/bench_normalization.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +RMSNorm and LayerNorm benchmarks on activation-sized tensors. + +Shapes are derived from training workloads: + - Llama 3 8B, 70B, 405B (all use RMSNorm) + - Qwen 2.5 7B, 72B (all use RMSNorm) + +Modern models predominantly use RMSNorm, but we benchmark both +LayerNorm and RMSNorm since TE supports both and they share the +same kernel infrastructure. + +The M dimension (batch * seq_len) is swept across typical training sizes. 
+ +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json +""" + +import torch +import transformer_engine.pytorch as te + +NORMS = {"RMSNorm": te.RMSNorm, "LayerNorm": te.LayerNorm} +HIDDEN_SIZES = [3584, 4096, 8192, 16384] + + +class BenchNormalization: + params = [[1024, 2048, 4096, 8192], HIDDEN_SIZES, list(NORMS)] + param_names = ["M", "hidden", "norm_type"] + timeout = 120 + + def setup(self, M, hidden, norm_type): + dtype = torch.bfloat16 + self.norm = NORMS[norm_type](hidden).to(device="cuda", dtype=dtype) + self.x = torch.randn(M, hidden, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn_like(self.norm(self.x)) + self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + + def work_forward(self, M, hidden, norm_type): + # Read input (2B) + write output (2B) = 4 bytes per element + return {"bytes": M * hidden * 4} + + def work_forward_backward(self, M, hidden, norm_type): + # Fwd: read+write (4B), Bwd: read input+grad_out+write grad_in (6B) = 10B + return {"bytes": M * hidden * 10} + + def time_forward(self, M, hidden, norm_type): + self._evt[0].record() + self.norm(self.x) + self._evt[1].record() + torch.cuda.synchronize() + return self._evt[0].elapsed_time(self._evt[1]) / 1000 + + def time_forward_backward(self, M, hidden, norm_type): + self._evt[0].record() + out = self.norm(self.x) + out.backward(self.grad_out) + self._evt[1].record() + torch.cuda.synchronize() + self.x.grad = None + for p in self.norm.parameters(): + p.grad = None + return self._evt[0].elapsed_time(self._evt[1]) / 1000 + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git 
a/benchmarks/asv/driver.py b/benchmarks/asv/driver.py new file mode 100644 index 000000000..e7831586c --- /dev/null +++ b/benchmarks/asv/driver.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""ASV benchmark driver — runs bench classes in-process and saves ASV-compatible results. + +Usage: + python driver.py [method_filter] [-w W] [-n N] [--no-save] + python driver.py --all [-w W] [-n N] [--no-save] + python bench_gemm.py [method_filter] [-w W] [-n N] [--no-save] +""" + +import argparse +import glob +import hashlib +import importlib +import inspect +import itertools +import json +import math +import os +import platform +import subprocess +import sys +import textwrap +import time + + +# --------------------------------------------------------------------------- +# ASV result generation +# --------------------------------------------------------------------------- + +def _get_benchmark_code_and_version(cls, method_name): + """Build the code string and version hash the same way ASV does. + + ASV hashes a code string built from the time_* and setup methods. + The string is class header + indented time method + indented setup, + with no trailing newline. + + Returns (code, version_hash). 
+ """ + time_src = textwrap.dedent(inspect.getsource(getattr(cls, method_name))) + setup_src = textwrap.dedent(inspect.getsource(cls.setup)) + code = ( + f"class {cls.__name__}:\n" + + textwrap.indent(time_src, " ") + "\n" + + textwrap.indent(setup_src, " ") + ).rstrip("\n") + return code, hashlib.sha256(code.encode()).hexdigest() + + +def _format_param_value(v): + """Format a parameter value the way ASV stores it in JSON.""" + if isinstance(v, str): + return f"'{v}'" + return repr(v) + + +def _get_machine_info(): + """Build the params/machine dict ASV expects.""" + machine = platform.node() + info = { + "arch": platform.machine(), + "cpu": "", + "machine": machine, + "num_cpu": str(os.cpu_count()), + "os": f"{platform.system()} {platform.release()}", + "ram": "", + } + try: + with open("/proc/cpuinfo") as f: + for line in f: + if line.startswith("model name"): + info["cpu"] = line.split(":", 1)[1].strip() + break + with open("/proc/meminfo") as f: + for line in f: + if line.startswith("MemTotal"): + info["ram"] = line.split()[1] # kB + break + except OSError: + pass + return machine, info + + +def _get_commit_hash(): + """Get the current git HEAD hash.""" + try: + return subprocess.check_output( + ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL + ).decode().strip() + except Exception: + return "unknown" + + +def _compute_stats(samples): + """Return (median, mean, stdev, ci_lo, ci_hi, q25, q75) for *samples*.""" + s = sorted(samples) + n = len(s) + mean = sum(s) / n + stdev = math.sqrt(sum((t - mean) ** 2 for t in s) / n) + ci = 2.576 * stdev / math.sqrt(n) # 99 % CI half-width + return (s[n // 2], mean, stdev, + max(0, mean - ci), mean + ci, + s[max(0, n // 4)], s[min(n - 1, 3 * n // 4)]) + + +def _get_results_dir(): + """Read results_dir from asv.conf.json, resolved to an absolute path.""" + conf_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "asv.conf.json") + with open(conf_path) as f: + conf = json.load(f) + conf_dir = 
os.path.dirname(conf_path) + return os.path.normpath(os.path.join(conf_dir, conf["results_dir"])) + + +def save_asv_results(all_results, bench_meta): + """Write results and benchmark index to ASV's results directory.""" + commit_hash = _get_commit_hash() + machine_name, machine_info = _get_machine_info() + env_name = "existing-" + sys.executable.replace("/", "_").strip("_") + results_dir = _get_results_dir() + machine_dir = os.path.join(results_dir, machine_name) + os.makedirs(machine_dir, exist_ok=True) + + # Write machine.json if missing + machine_json = os.path.join(machine_dir, "machine.json") + if not os.path.exists(machine_json): + with open(machine_json, "w") as f: + json.dump({**machine_info, "version": 1}, f, indent=4) + + # Load existing result file or start fresh + filename = f"{commit_hash[:8]}-{env_name}.json" + result_path = os.path.join(machine_dir, filename) + if os.path.exists(result_path): + with open(result_path) as f: + data = json.load(f) + else: + data = { + "commit_hash": commit_hash, + "env_name": env_name, + "date": int(time.time() * 1000), + "params": {**machine_info, "python": sys.executable}, + "python": sys.executable, + "requirements": {}, + "env_vars": {}, + "result_columns": [ + "result", "params", "version", + "started_at", "duration", + "stats_ci_99_a", "stats_ci_99_b", + "stats_q_25", "stats_q_75", + "stats_number", "stats_repeat", + "samples", + ], + "results": {}, + "durations": {}, + "version": 2, + } + + # Merge new results + for bench_key, bench_data in all_results.items(): + data["results"][bench_key] = bench_data + + with open(result_path, "w") as f: + json.dump(data, f, indent=2) + + print(f"\nResults saved to {result_path}") + + # Update benchmarks.json index so ASV dashboard stays in sync + benchmarks_path = os.path.join(results_dir, "benchmarks.json") + if os.path.exists(benchmarks_path): + with open(benchmarks_path) as f: + benchmarks_data = json.load(f) + else: + benchmarks_data = {"version": 2} + + 
benchmarks_data.update(bench_meta) + + with open(benchmarks_path, "w") as f: + json.dump(benchmarks_data, f, indent=4) + + print(f"Updated {benchmarks_path}") + + +# --------------------------------------------------------------------------- +# Benchmark runner +# --------------------------------------------------------------------------- + +_ASV_META_DEFAULTS = { + "min_run_count": 2, "number": 0, "repeat": 0, "rounds": 2, + "sample_time": 0.01, "type": "time", "unit": "seconds", "warmup_time": -1, +} + + +def run_class(suite_name, cls, class_name, method_filter=None, warmup=3, iters=7): + """Run all benchmarks in a class, returning (results, metadata) dicts.""" + methods = sorted(m for m in dir(cls) if m.startswith("time_")) + if method_filter: + methods = [m for m in methods if method_filter in m] + if not methods: + return {}, {} + + params = getattr(cls, "params", [[]]) + param_names = getattr(cls, "param_names", []) + combos = list(itertools.product(*params)) + asv_params = [[_format_param_value(v) for v in dim] for dim in params] + + # Discover throughput columns from work_* companions + # Each entry: (dict_key, column_header, unit_divisor) + probe_keys = set() + for m in methods: + wfn = getattr(cls, "work_" + m[5:], None) + if wfn: + try: + probe_keys.update(wfn(cls(), *combos[0])) + except Exception: + pass + throughput_cols = [] + if "flops" in probe_keys: + throughput_cols.append(("flops", "TFLOPS", 1e12)) + if "bytes" in probe_keys: + throughput_cols.append(("bytes", "GB/s", 1e9)) + + # Print table header + print(f"\n{class_name} ({len(combos)} combos x {len(methods)} methods, " + f"{warmup} warmup, {iters} timed)") + extra_hdr = "".join(f" {label:>10}" for _, label, _ in throughput_cols) + HDR = (f" {'median':>10} {'mean':>10} {'stdev':>10}" + f" {'q25':>10} {'q75':>10} {'min':>10} {'max':>10}" + + extra_hdr + f" {'method':<30} params") + print("-" * len(HDR)) + print(HDR) + print("-" * len(HDR)) + + all_results = {} + all_meta = {} + + for 
method_name in methods: + bench_key = f"{suite_name}.{class_name}.{method_name}" + code, version = _get_benchmark_code_and_version(cls, method_name) + + all_meta[bench_key] = { + **_ASV_META_DEFAULTS, + "code": code, "name": bench_key, "version": version, + "param_names": list(param_names), "params": asv_params, + "timeout": getattr(cls, "timeout", 300), + } + + medians, ci_los, ci_his, q25s, q75s = [], [], [], [], [] + numbers, repeats = [], [] + started_at = int(time.time() * 1000) + t_start = time.perf_counter() + + for combo in combos: + label = ", ".join(f"{n}={v}" for n, v in zip(param_names, combo)) + instance = cls() + try: + instance.setup(*combo) + except Exception as e: + print(f" SKIP {label} setup failed: {e}") + for lst in (medians, ci_los, ci_his, q25s, q75s, numbers, repeats): + lst.append(None) + continue + + method = getattr(instance, method_name) + for _ in range(warmup): + method(*combo) + + samples = [] + for _ in range(iters): + t0 = time.perf_counter() + result = method(*combo) + wall = time.perf_counter() - t0 + samples.append(wall if result is None else result) + + median, mean, stdev, ci_lo, ci_hi, q25, q75 = _compute_stats(samples) + s_min, s_max = min(samples), max(samples) + + medians.append(median) + ci_los.append(ci_lo) + ci_his.append(ci_hi) + q25s.append(q25) + q75s.append(q75) + numbers.append(1) + repeats.append(iters) + + # Derive throughput from work_* companion + work = {} + wfn = getattr(instance, "work_" + method_name[5:], None) + if wfn and median > 0: + try: + work = wfn(*combo) + except Exception: + pass + extra_cols = "" + for key, _, divisor in throughput_cols: + if key in work and median > 0: + extra_cols += f" {work[key] / median / divisor:>10.1f}" + else: + extra_cols += f" {'':>10}" + + print(f" {median*1000:>8.3f}ms {mean*1000:>8.3f}ms " + f"{stdev*1000:>8.3f}ms {q25*1000:>8.3f}ms {q75*1000:>8.3f}ms " + f"{s_min*1000:>8.3f}ms {s_max*1000:>8.3f}ms" + f"{extra_cols} " + f"{method_name:<30} {label}") + + 
all_results[bench_key] = [ + medians, asv_params, version, started_at, round(duration, 2), + ci_los, ci_his, q25s, q75s, numbers, repeats, + ] + + return all_results, all_meta + + +def run_as_main(caller_file=None): + """Run benchmarks from a bench file or from the command line. + + When called with a file path (from a bench file's ``__main__`` block), + the suite is derived from the filename. When called without arguments + (i.e. ``python driver.py bench_gemm``), the suite is taken from argv. + + Usage from a bench file:: + + if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) + """ + parser = argparse.ArgumentParser( + description="Run ASV benchmarks directly in-process (no subprocess overhead).") + if caller_file is None: + parser.add_argument("suite", nargs="?", default=None, + help="Benchmark module name (e.g. bench_casting)") + parser.add_argument("--all", action="store_true", + help="Run all bench_*.py suites in the directory") + parser.add_argument("method_filter", nargs="?", default=None, + help="Only run time_* methods containing this string") + parser.add_argument("-w", "--warmup", type=int, default=3, + help="Number of warmup iterations (default: 3)") + parser.add_argument("-n", "--iters", type=int, default=7, + help="Number of timed iterations (default: 7)") + parser.add_argument("--no-save", action="store_true", + help="Skip saving results to ASV format") + args = parser.parse_args() + + if caller_file is not None: + script_dir = os.path.dirname(os.path.abspath(caller_file)) + suite_names = [os.path.splitext(os.path.basename(caller_file))[0]] + else: + script_dir = os.path.dirname(os.path.abspath(__file__)) + run_all = getattr(args, "all", False) + if run_all: + suite_names = sorted( + os.path.splitext(os.path.basename(f))[0] + for f in glob.glob(os.path.join(script_dir, "bench_*.py")) + ) + elif args.suite: + suite_names = [args.suite] + else: + parser.error("provide a suite name or use --all") + + os.chdir(script_dir) 
+ if script_dir not in sys.path: + sys.path.insert(0, script_dir) + + all_results = {} + all_meta = {} + for suite_name in suite_names: + mod = importlib.import_module(suite_name) + for name in sorted(dir(mod)): + obj = getattr(mod, name) + if isinstance(obj, type) and name.startswith("Bench"): + results, meta = run_class( + suite_name, obj, name, args.method_filter, args.warmup, args.iters) + all_results.update(results) + all_meta.update(meta) + + if all_results and not args.no_save: + save_asv_results(all_results, all_meta) + + +if __name__ == "__main__": + run_as_main() diff --git a/benchmarks/asv/run_benchmarks.sh b/benchmarks/asv/run_benchmarks.sh new file mode 100755 index 000000000..43b4465d4 --- /dev/null +++ b/benchmarks/asv/run_benchmarks.sh @@ -0,0 +1,69 @@ +#!/usr/bin/env bash +# Helper script for common ASV benchmark tasks. +set -euo pipefail + +cd "$(git rev-parse --show-toplevel)" + +BENCH_DIR="benchmarks/asv" +ASV_CONF="$(pwd)/$BENCH_DIR/asv.conf.json" + +usage() { + cat < [options] + +Commands: + setup Register this machine with ASV + run [-w W] [-n N] [SUITE] [METHOD] + Run benchmarks in-process (fast, saves ASV-compatible results) + run --asv [SUITE] Run benchmarks via ASV (subprocess isolation per benchmark) + compare [REF] [NEW] Compare two commits (default: HEAD~1 vs HEAD) + view Generate HTML dashboard and open preview server + list List available benchmark suites + +EOF +} + +case "${1:-}" in + setup) + MACHINE="${2:-$(hostname)}" + echo "Registering machine as: $MACHINE" + asv machine --yes --machine "$MACHINE" --config "$ASV_CONF" + ;; + run) + shift + if [[ "${1:-}" == "--asv" ]]; then + shift + CMD=(asv run --config "$ASV_CONF" --python=same --launch-method spawn + --set-commit-hash "$(git rev-parse HEAD)") + [[ -n "${1:-}" ]] && CMD+=(--bench "$1") + echo "Running (asv): ${CMD[*]}" + "${CMD[@]}" + else + # Default: fast in-process run (--all when no suite given) + if [[ $# -eq 0 ]]; then + python "$BENCH_DIR/driver.py" --all + else + 
        # Forward suite/method/flag arguments straight to the driver.
        python "$BENCH_DIR/driver.py" "$@"
      fi
    fi
    ;;
  compare)
    # Default to comparing the previous commit against the current one.
    REF="${2:-HEAD~1}"
    NEW="${3:-HEAD}"
    echo "Comparing $REF vs $NEW"
    # NOTE(review): the README's manual recipe also passes --set-commit-hash
    # via $COMMON for `asv continuous`; confirm whether it is needed here.
    asv continuous --config "$ASV_CONF" --python=same --launch-method spawn "$REF" "$NEW"
    ;;
  view)
    # Build the static HTML dashboard, then serve it on localhost:8080.
    asv publish --config "$ASV_CONF"
    echo "Starting preview server at http://localhost:8080"
    asv preview --config "$ASV_CONF"
    ;;
  list)
    echo "Available benchmark suites:"
    # Strip the directory prefix and the .py suffix from each bench_*.py file.
    ls "$BENCH_DIR"/bench_*.py 2>/dev/null | sed 's|.*/bench_| bench_|;s|\.py$||'
    ;;
  *)
    # Unknown or missing command: show usage and fail.
    usage
    exit 1
    ;;
esac