ASV-format microbenchmark suite #487
Open

Micky774 wants to merge 15 commits into `dev` from `zain/asv-demo`.

Commits:
- `d7c643c` Initial benchmark porting to ASV
- `b829122` Update casting benchmark
- `21678b4` Added helper script and documentation
- `6cb91a5` Corrected local benchmarking
- `1a98989` Added direct-run option to bypass subprocess overhead
- `498f16d` Refactor to prefer direct runs, and moved asv conf
- `1e41715` Allowed for direct run of bench files
- `c1e489d` Remove CI component
- `9772f2d` Rename direct_run to driver
- `770a3f0` Refactored driver, streamlined README.md
- `aa2a4a1` Updated to CUDA event based timing
- `a2e5999` Added throughput/bandwidth calc, improved driver
- `89ebfa5` Streamline and clean code
- `91b6b2c` Updated readme, simplified helper script
- `1b5d042` Updated docstrings to include config sources

(All commits by Micky774.)
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Modified file (ignore patterns):

```diff
@@ -55,3 +55,4 @@ artifacts/
 **/times.csv
 transformer_engine/build_info.txt
 transformer_engine/common/util/hip_nvml.*
+.asv/
```
New file (+136 lines), the benchmark suite README:
# ASV Benchmarks for TransformerEngine

Performance benchmarks built on [ASV (Air Speed Velocity)](https://asv.readthedocs.io/), a framework for benchmarking Python packages over their lifetime.

## Prerequisites

- TransformerEngine must already be built and installed in the current Python environment.
- A ROCm or CUDA GPU must be available.
- Install ASV: `pip install asv`

ASV is configured with `environment_type: "existing"` (in `benchmarks/asv/asv.conf.json`), meaning it uses the current Python environment directly; it does not create virtualenvs or attempt to build TE itself. The config sets `branches: ["HEAD"]` so that `asv publish` accepts results from whichever branch is currently checked out. This works for both local development and CI (where `HEAD` points to `dev`).

## Running benchmarks

### Direct execution (recommended for development)

Each `bench_*.py` file is directly executable. Results are saved in ASV-compatible format by default.

```bash
cd benchmarks/asv
python driver.py --all              # run every suite
python driver.py bench_gemm         # run one suite via driver
python bench_gemm.py                # run one suite directly
python bench_gemm.py time_forward   # filter to a specific method
python bench_gemm.py -w 5 -n 20     # custom warmup/iteration counts
python bench_casting.py --no-save   # skip saving results
```

### Helper script

`run_benchmarks.sh` wraps common tasks and can be run from anywhere.

```bash
bash benchmarks/asv/run_benchmarks.sh <command> [options]
```

| Command | Description |
|---|---|
| `setup [name]` | Register machine with ASV (defaults to `hostname`) |
| `run [suite] [method]` | Run benchmarks in-process (fast, saves ASV-compatible results) |
| `run --asv [suite]` | Run via ASV subprocess isolation (for CI or statistical rigor) |
| `compare [ref] [new]` | Compare two commits (defaults to `HEAD~1` vs `HEAD`) |
| `view` | Generate HTML dashboard and serve on `localhost:8080` |
| `list` | List available benchmark suites |

### Manual ASV commands

All `asv` commands require `--config` with an **absolute path** and should be run from the **repo root**. The common flags are:

```bash
ASV="asv --config $(pwd)/benchmarks/asv/asv.conf.json"
COMMON="--python=same --launch-method spawn --set-commit-hash $(git rev-parse HEAD)"
```

- `--python=same`: use the current interpreter (required with `environment_type: "existing"`)
- `--launch-method spawn`: required for CUDA/ROCm (fork causes reinitialization errors)
- `--set-commit-hash`: **required** with `environment_type: "existing"`, otherwise ASV silently discards results

```bash
$ASV machine --yes --machine mi325       # register machine
$ASV run $COMMON                         # run all benchmarks
$ASV run $COMMON --bench bench_casting   # single suite (regex match)
$ASV continuous $COMMON HEAD~1 HEAD      # compare two commits
$ASV publish && $ASV preview             # HTML dashboard on localhost:8080
```

## How results are stored

### Local results

ASV stores results as JSON files under `benchmarks/.asv/results/`:

```
benchmarks/.asv/results/
  my-machine-name/
    machine.json         # Hardware/OS metadata
    <commit-hash>.json   # Timing results for that commit
    <commit-hash>.json
    ...
```

Each commit JSON contains the wall-clock timings for every benchmark + parameter combination run on that machine. The `benchmarks/.asv/` directory is in `.gitignore`.
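The directory layout above can be enumerated with a few lines of standard-library Python. This is an illustrative sketch, not part of the suite; the `list_results` helper name is ours, and only the machine-directory/`machine.json` layout described above is assumed.

```python
# Sketch: enumerate stored ASV result files per machine, following the
# benchmarks/.asv/results/ layout shown above (hypothetical helper).
from pathlib import Path


def list_results(root: Path) -> dict:
    """Map each machine directory under `root` to its commit-result JSON files."""
    out = {}
    if not root.exists():
        return out
    for machine_dir in sorted(root.iterdir()):
        if machine_dir.is_dir():
            out[machine_dir.name] = sorted(
                p.name for p in machine_dir.glob("*.json")
                if p.name != "machine.json"  # skip hardware/OS metadata
            )
    return out


# Example: list_results(Path("benchmarks/.asv/results"))
```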
## Writing new benchmarks

Create a new file in `benchmarks/asv/` following the naming convention `bench_<name>.py`.

```python
#!/usr/bin/env python3
import torch
import transformer_engine.pytorch as te


class BenchSomething:
    params = [[1024, 4096], ["config_a", "config_b"]]
    param_names = ["M", "config"]
    timeout = 300  # seconds, per parameter combination

    def setup(self, M, config):
        # Allocate tensors, create modules.
        # This runs before each time_* method but is NOT timed.
        self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)]
        ...

    def time_forward(self, M, config):
        # Use CUDA events for accurate GPU timing.
        # Return elapsed seconds; the driver uses this instead of wall time.
        self._evt[0].record()
        self.module(self.x)
        self._evt[1].record()
        torch.cuda.synchronize()
        return self._evt[0].elapsed_time(self._evt[1]) / 1000

    # Optional: define work_<name> to get throughput columns (TFLOPS / GB/s).
    def work_forward(self, M, config):
        return {"flops": 2 * M * self.N * self.K}  # compute-bound
        # return {"bytes": M * self.hidden * 4}    # memory-bound


if __name__ == "__main__":
    from driver import run_as_main
    run_as_main(__file__)
```

Key rules:
- Method names starting with `time_` are automatically timed.
- Use CUDA events and return elapsed seconds for accurate GPU timing.
- Optionally define `work_<name>` companions to get TFLOPS or GB/s columns.
- Clear `.grad` attributes in backward benchmarks to prevent memory accumulation.
- The `params` list defines a cross-product; keep the matrix size reasonable.
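The cross-product in the last rule can be sketched with `itertools.product`. ASV performs this expansion internally; the values below simply mirror the `BenchSomething` example to show how many runs a `params` list implies.

```python
# Sketch: how a `params` list expands into one run per parameter combination.
from itertools import product

params = [[1024, 4096], ["config_a", "config_b"]]
param_names = ["M", "config"]

# Each tuple is passed positionally to setup()/time_*(): (M, config).
combinations = list(product(*params))
```

Two values for `M` times two configs yields four timed combinations, so a benchmark with many parameter axes multiplies quickly.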
Empty file.
New file (+16 lines), `benchmarks/asv/asv.conf.json` (path as given in the README):
```json
{
    "version": 1,
    "project": "TransformerEngine",
    "project_url": "https://github.com/ROCm/TransformerEngine",
    "repo": "../..",
    "branches": ["HEAD"],
    "environment_type": "existing",
    "install_command": [],
    "build_command": [],
    "benchmark_dir": ".",
    "results_dir": "../.asv/results",
    "html_dir": "../.asv/html",
    "install_timeout": 600,
    "benchmark_timeout": 1200,
    "launch_method": "spawn"
}
```
New file (+97 lines), the attention benchmark suite:
```python
#!/usr/bin/env python3
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
"""
Attention micro-benchmark using te.DotProductAttention.

Benchmarks fused multi-head attention (with flash attention backend) for
model configurations with grouped-query attention (GQA).

Models:
- Llama 3 8B (TP=1, TP=8), 70B (TP=8), 405B (TP=8)
- Qwen 2.5 7B (TP=1), 72B (TP=8)

Forward FLOPs = 4 * batch * num_q_heads * seq_len^2 * head_dim
(two matmuls: Q@K^T and attn@V, each contributing 2*b*h*s^2*d)
Backward FLOPs = 2 * Forward FLOPs (approximately)

Sources for model configs:
    https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json
    https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json
    https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json
    https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json
    https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json
"""

import torch
import transformer_engine.pytorch as te

BATCH = 2

# (num_q_heads, num_kv_heads, head_dim, tp)
MODELS = {
    "Llama3-8B_TP1": (32, 8, 128, 1),
    "Llama3-8B_TP8": (32, 8, 128, 8),
    "Llama3-70B_TP8": (64, 8, 128, 8),
    "Llama3-405B_TP8": (128, 8, 128, 8),
    "Qwen2.5-7B_TP1": (28, 4, 128, 1),
    "Qwen2.5-72B_TP8": (64, 8, 128, 8),
}


class BenchAttention:
    params = [[1024, 2048, 4096, 8192], list(MODELS)]
    param_names = ["seq_len", "model"]
    timeout = 300

    def setup(self, seq_len, model):
        n_q, n_kv, hd, tp = MODELS[model]
        qh, kvh = n_q // tp, n_kv // tp
        dtype = torch.bfloat16

        self.attn = te.DotProductAttention(
            num_attention_heads=qh, kv_channels=hd,
            num_gqa_groups=kvh, attn_mask_type="causal",
        ).to(device="cuda", dtype=dtype)

        self.q = torch.randn(seq_len, BATCH, qh, hd, dtype=dtype, device="cuda", requires_grad=True)
        self.k = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True)
        self.v = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True)
        self.grad_out = torch.randn_like(self.attn(self.q, self.k, self.v))
        self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)]

    def work_forward(self, seq_len, model):
        n_q, n_kv, hd, tp = MODELS[model]
        qh = n_q // tp
        return {"flops": 4 * BATCH * qh * seq_len * seq_len * hd}

    def work_forward_backward(self, seq_len, model):
        n_q, n_kv, hd, tp = MODELS[model]
        qh = n_q // tp
        return {"flops": 3 * 4 * BATCH * qh * seq_len * seq_len * hd}

    def time_forward(self, seq_len, model):
        self._evt[0].record()
        self.attn(self.q, self.k, self.v)
        self._evt[1].record()
        torch.cuda.synchronize()
        return self._evt[0].elapsed_time(self._evt[1]) / 1000

    def time_forward_backward(self, seq_len, model):
        self._evt[0].record()
        out = self.attn(self.q, self.k, self.v)
        out.backward(self.grad_out)
        self._evt[1].record()
        torch.cuda.synchronize()
        self.q.grad = self.k.grad = self.v.grad = None
        return self._evt[0].elapsed_time(self._evt[1]) / 1000


if __name__ == "__main__":
    from driver import run_as_main
    run_as_main(__file__)
```
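As a sanity check on the FLOPs accounting in this suite, the `work_forward` figure can be computed by hand for one configuration. This is a standalone sketch using the Llama3-8B_TP1 entry (32 query heads, TP=1, head_dim=128) with the suite's BATCH=2 at seq_len=4096; the elapsed time `t` is hypothetical.

```python
# Worked check of: flops = 4 * batch * num_q_heads * seq_len^2 * head_dim
BATCH = 2
num_q_heads, tp, head_dim = 32, 1, 128  # Llama3-8B_TP1 entry from MODELS
seq_len = 4096

qh = num_q_heads // tp
fwd_flops = 4 * BATCH * qh * seq_len * seq_len * head_dim
fwd_bwd_flops = 3 * fwd_flops  # forward + ~2x backward, as in work_forward_backward

# Given an elapsed time t (seconds, from CUDA events), throughput would be
# fwd_flops / t / 1e12 TFLOPS.
```

The forward figure works out to 2^39 FLOPs, roughly 0.55 TFLOP per call, so a run at, say, 1 ms would correspond to about 550 TFLOPS.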
New file (+90 lines), the casting benchmark suite:
```python
#!/usr/bin/env python3
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
"""
Benchmarks quantization (BF16 -> FP8) and dequantization (FP8 -> BF16) for
both E4M3 (activations/weights) and E5M2 (gradients) formats.

Shapes are (M, hidden_size) matching the activation tensors from models:
- Llama 3.1 8B, 70B, 405B
- Qwen 2.5 7B, 72B

These casts are memory-bound; we report GB/s (input + output bytes).

Sources for model configs:
    https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json
    https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json
    https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json
    https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json
    https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json
"""

import torch
from transformer_engine.pytorch import Float8CurrentScalingQuantizer
from transformer_engine_torch import DType as TE_DType

HIDDEN_SIZES = {
    "Llama3-8B": 4096,
    "Llama3-70B": 8192,
    "Llama3-405B": 16384,
    "Qwen2.5-7B": 3584,
    "Qwen2.5-72B": 8192,
}

CAST_CONFIGS = {
    "BF16_to_E4M3": ("quantize", TE_DType.kFloat8E4M3),
    "E4M3_to_BF16": ("dequantize", TE_DType.kFloat8E4M3),
    "BF16_to_E5M2": ("quantize", TE_DType.kFloat8E5M2),
    "E5M2_to_BF16": ("dequantize", TE_DType.kFloat8E5M2),
}


class BenchCasting:
    params = [[1024, 2048, 4096, 8192], list(HIDDEN_SIZES), list(CAST_CONFIGS)]
    param_names = ["M", "model", "cast"]
    timeout = 120

    def setup(self, M, model, cast):
        hidden = HIDDEN_SIZES[model]
        direction, fp8_dtype = CAST_CONFIGS[cast]
        self.direction = direction
        quantizer = Float8CurrentScalingQuantizer(
            fp8_dtype=fp8_dtype,
            device=torch.device("cuda"),
            rowwise=True,
            columnwise=False,
        )
        if direction == "dequantize":
            bf16_tensor = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
            self.x = quantizer.quantize(bf16_tensor)
        else:
            self.x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
        self.quantizer = quantizer
        self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)]

    def work_cast(self, M, model, cast):
        hidden = HIDDEN_SIZES[model]
        direction = CAST_CONFIGS[cast][0]
        if direction == "quantize":
            # Read BF16 (2B) + write FP8 (1B) + write scale
            return {"bytes": M * hidden * 3}
        else:
            # Read FP8 (1B) + read scale + write BF16 (2B)
            return {"bytes": M * hidden * 3}

    def time_cast(self, M, model, cast):
        self._evt[0].record()
        if self.direction == "quantize":
            self.quantizer.quantize(self.x)
        else:
            self.x.dequantize(dtype=torch.bfloat16)
        self._evt[1].record()
        torch.cuda.synchronize()
        return self._evt[0].elapsed_time(self._evt[1]) / 1000

    if __name__ == "__main__":
        pass  # placeholder removed below


if __name__ == "__main__":
    from driver import run_as_main
    run_as_main(__file__)
```
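The bytes accounting in `work_cast` can likewise be checked by hand. This standalone sketch uses the Llama3-8B entry (hidden size 4096) at M=8192; the `gbps` helper and the example elapsed time are ours, not part of the suite.

```python
# Worked check of: bytes = M * hidden * 3 (2B BF16 + 1B FP8 per element)
M, hidden = 8192, 4096  # Llama3-8B entry from HIDDEN_SIZES
bytes_moved = M * hidden * 3


def gbps(nbytes: int, seconds: float) -> float:
    """Effective bandwidth in GB/s given bytes moved and elapsed seconds."""
    return nbytes / seconds / 1e9
```

At roughly 100 MB moved per cast, an elapsed time of 0.1 ms would correspond to about 1 TB/s of effective bandwidth, which is the scale these memory-bound kernels are judged against.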
Review comment: Does it need to be in root of TE?
Reply: No, I've updated it.