diff --git a/examples/minimal/ShardTensorExamples/6_ring_attention/README.md b/examples/minimal/ShardTensorExamples/6_ring_attention/README.md index 4877f4a144..46d2a74d0b 100644 --- a/examples/minimal/ShardTensorExamples/6_ring_attention/README.md +++ b/examples/minimal/ShardTensorExamples/6_ring_attention/README.md @@ -31,7 +31,14 @@ torchrun --nproc-per-node 4 benchmark_sharded_attention.py \ | `--dtype` | `float32` | `float32`, `float16`, or `bfloat16` | | `--num_warmup` | 5 | Warmup iterations | | `--num_iterations` | 10 | Timed iterations | -| `--output_file` | — | Path to write JSON results | +| `--results_dir` | `results/` beside the script | Directory for the auto-named JSON output | +| `--print-only` | off | Skip writing JSON; print to stdout only | + +By default the benchmark writes a JSON file to `results/` whose name +encodes the run configuration, e.g. +`single_gpu_inference_float32_seq4096.json` or +`distributed_4gpu_train_bfloat16_seq8192.json`. Pass `--print-only` to +disable this. ## Plotting results diff --git a/examples/minimal/ShardTensorExamples/6_ring_attention/benchmark_sharded_attention.py b/examples/minimal/ShardTensorExamples/6_ring_attention/benchmark_sharded_attention.py index 99704a3cb4..3844b954db 100644 --- a/examples/minimal/ShardTensorExamples/6_ring_attention/benchmark_sharded_attention.py +++ b/examples/minimal/ShardTensorExamples/6_ring_attention/benchmark_sharded_attention.py @@ -25,6 +25,7 @@ import argparse import json from datetime import datetime +from pathlib import Path import numpy as np import torch @@ -36,6 +37,13 @@ from physicsnemo.utils import Profiler +# Default output directory for benchmark JSON results, sibling to this script. +# Filenames are built to match the regex consumed by ``plot_scaling_results.py``: +# {topology}_{mode}_{dtype}_seq{seq_len}.json +# where ``{topology}`` is either ``single_gpu`` or ``distributed_{world_size}gpu``. 
+_SCRIPT_DIR = Path(__file__).resolve().parent +DEFAULT_RESULTS_DIR = _SCRIPT_DIR / "results" + def parse_args(): """Parse command-line arguments for the attention benchmark. @@ -43,7 +51,8 @@ def parse_args(): Returns: argparse.Namespace: Parsed arguments including seq_len, num_heads, head_dim, batch_size, warmup/iteration counts, dtype, benchmark - mode (inference or train), and an optional output file path. + mode (inference or train), the results directory, and a + ``--print-only`` flag that disables JSON output. """ parser = argparse.ArgumentParser( description="Benchmark scaled_dot_product_attention: single GPU vs ShardTensor" @@ -85,14 +94,37 @@ def parse_args(): help="Benchmark mode: 'inference' (forward only) or 'train' (forward + backward)", ) parser.add_argument( - "--output_file", + "--results_dir", type=str, - default=None, - help="Path to write JSON results file. If not set, results are only printed.", + default=str(DEFAULT_RESULTS_DIR), + help=( + "Directory in which to write the JSON results file. " + "The filename is auto-generated to match the format expected by " + "plot_scaling_results.py. Ignored when --print-only is set." + ), + ) + parser.add_argument( + "--print-only", + dest="print_only", + action="store_true", + help="Print results to stdout only; do not write a JSON file.", ) return parser.parse_args() +def build_output_filename( + *, distributed: bool, world_size: int, mode: str, dtype: str, seq_len: int +) -> str: + """Build a results filename compatible with ``plot_scaling_results.py``. + + Format: ``{topology}_{mode}_{dtype}_seq{seq_len}.json`` where ``{topology}`` + is ``single_gpu`` for non-distributed runs and ``distributed_{world_size}gpu`` + otherwise. 
+ """ + topology = f"distributed_{world_size}gpu" if distributed else "single_gpu" + return f"{topology}_{mode}_{dtype}_seq{seq_len}.json" + + DTYPE_MAP = { "float32": torch.float32, "float16": torch.float16, @@ -317,10 +349,20 @@ def main(): f"Max peak allocated (across {dm.world_size} ranks): {max_peak_allocated_across_ranks / mb:.2f} MB" ) - if args.output_file: - with open(args.output_file, "w") as f: + if not args.print_only: + results_dir = Path(args.results_dir).expanduser() + results_dir.mkdir(parents=True, exist_ok=True) + fname = build_output_filename( + distributed=distributed, + world_size=dm.world_size, + mode=args.mode, + dtype=args.dtype, + seq_len=S, + ) + output_path = results_dir / fname + with open(output_path, "w") as f: json.dump(results, f, indent=2) - print(f"Results saved to {args.output_file}") + print(f"Results saved to {output_path}") if __name__ == "__main__":