From feb30e6ed3ed7d25aa8c9229ac96a784ebb049c1 Mon Sep 17 00:00:00 2001
From: Subrahmanya Pavankumar Dubagunta
Date: Tue, 5 May 2026 16:00:58 -0500
Subject: [PATCH] Move measurement out of kernel.py into the harness for 5 tasks

Closes the "kernel.py contains the timer" attack surface in geak_eval tasks
where test_kernel_harness.py imported benchmark_config (or
run_benchmark/run_correctness/run_profile) from kernel.py. An agent
optimising kernel.py could legitimately rewrite the imported timer function
and inflate the reported speedup with no actual GPU change (observed
previously on refk_identity at 12.07x).

For each of the 5 in-scope tasks the harness now follows the fused_rms_fp8
structural convention:

- refk_identity, refk_fp8_blockwise_mm, ff_backward, lean_atten_paged,
  gemm_a16wfp4

Per task:

- kernel.py: deleted benchmark_config / check_correctness / run_* /
  evaluate / __main__ blocks, the in-kernel WARMUP/ITERATIONS module
  constants, the torch_op wrapper, and any *_pytorch reference.
- test_kernel_harness.py: rewritten as a self-contained script that
  * imports only kernel callables, get_inputs, shape configs, and
    tolerance constants from kernel.py;
  * carries the moved PyTorch reference and uses it ONLY in
    run_correctness via torch.testing.assert_close;
  * times only the candidate kernel via torch.cuda.Event median;
  * exposes the standard --correctness / --benchmark / --full-benchmark /
    --profile modes plus --warmup (default 50) and --iterations
    (default 200);
  * emits GEAK_SHAPES_USED=[...] and GEAK_RESULT_LATENCY_MS=;
  * does NOT compute a speedup in-process (no
    GEAK_RESULT_GEOMEAN_SPEEDUP). Speedup is now the orchestrator's job:
    run the harness twice (unmodified vs patched kernel) and divide.

Notes:

- gemm_a16wfp4: kernel.py was already kernel-only. Harness rewritten to
  swap perf_counter timing for cuda.Event median and standardise the
  iteration defaults to 50/200 (was 5/10 and 5/20).
- refk_fp8_blockwise_mm: the new benchmark loop allocates the output c
  once per cfg (the OLD benchmark_config did c.clone() per iteration,
  which incorrectly timed an allocation alongside the kernel). The
  baseline absolute ms will drop slightly on this task as a result.
- lean_atten_paged: --correctness uses CORRECTNESS_CONFIGS (not
  HARNESS_SHAPES) to match the kernel.py-side convention.

Out of scope (intentionally): orchestrator file-allowlist on patches,
which is the sufficient condition to fully defend against an agent that
patches the harness file directly.
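For concreteness, the "run the harness twice and divide" step described
above could look roughly like the sketch below. This is illustrative only:
harness_latency_ms, the checkout directory names, and the exact invocation
are assumptions about the orchestrator, not part of this patch; only the
--benchmark flag and the GEAK_RESULT_LATENCY_MS= marker come from the
harness contract.

    # Hypothetical orchestrator logic: run the same (trusted) harness file
    # against the unmodified and the patched kernel.py, parse the emitted
    # GEAK_RESULT_LATENCY_MS marker from each run, and divide.
    import re
    import subprocess

    def harness_latency_ms(task_dir: str) -> float:
        # Run the task's harness in benchmark mode and parse the marker it emits.
        out = subprocess.run(
            ["python", "test_kernel_harness.py", "--benchmark"],
            cwd=task_dir, capture_output=True, text=True, check=True,
        ).stdout
        return float(re.search(r"GEAK_RESULT_LATENCY_MS=([0-9.]+)", out).group(1))

    # Same harness file both times; only kernel.py differs between checkouts.
    baseline_ms = harness_latency_ms("checkout_baseline")  # unmodified kernel.py
    patched_ms = harness_latency_ms("checkout_patched")    # agent-patched kernel.py
    speedup = baseline_ms / patched_ms

Because the timer now lives only in the harness process, rewriting code in
kernel.py cannot change how either run is measured.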
Co-authored-by: Cursor
---
 .../L1/refk_fp8_blockwise_mm/kernel.py        | 193 +--------
 .../test_kernel_harness.py                    | 299 +++++++++----
 .../geak_eval/L1/refk_identity/kernel.py      | 167 +------
 .../L1/refk_identity/test_kernel_harness.py   | 267 ++++++++----
 .../geak_eval/L2/ff_backward/kernel.py        | 202 +--------
 .../L2/ff_backward/test_kernel_harness.py     | 329 +++++++++-----
 .../geak_eval/L2/lean_atten_paged/kernel.py   | 408 ------------------
 .../lean_atten_paged/test_kernel_harness.py   | 398 ++++++++---------
 .../L3/gemm_a16wfp4/test_kernel_harness.py    | 348 +++++++++------
 9 files changed, 1039 insertions(+), 1572 deletions(-)

diff --git a/tasks/triton2triton/geak_eval/L1/refk_fp8_blockwise_mm/kernel.py b/tasks/triton2triton/geak_eval/L1/refk_fp8_blockwise_mm/kernel.py
index 77646be..97cfb7a 100644
--- a/tasks/triton2triton/geak_eval/L1/refk_fp8_blockwise_mm/kernel.py
+++ b/tasks/triton2triton/geak_eval/L1/refk_fp8_blockwise_mm/kernel.py
@@ -16,10 +16,6 @@
 c: [m, n] bfloat16 (pre-allocated output)
 """
 
-import math
-import os
-import time
-
 import torch
 import triton
 import triton.language as tl
@@ -177,36 +173,7 @@ def fp8_blockwise_mm_triton(a, b, a_scale, b_scale, c):
 # ============================================================================
-# REFERENCE IMPLEMENTATION (pure PyTorch — same as submission.py)
-# ============================================================================
-
-
-def fp8_blockwise_mm_pytorch(a, b, a_scale, b_scale, c):
-    a_c = a.contiguous()
-    a_s = a_scale.contiguous()
-    b_s = b_scale.contiguous()
-
-    m, k = a_c.shape
-    n = b.shape[0]
-    block_n, block_k = BLOCK_SHAPE_N, BLOCK_SHAPE_K
-    sn = b_s.shape[0]
-    sk = b_s.shape[1]
-
-    a_sc = a_s.unsqueeze(-1).repeat(1, 1, block_k).reshape(m, sk * block_k)[:, :k]
-    a_deq = a_c.to(a_sc.dtype) * a_sc
-
-    b_sc = (b_s.view(-1, 1).repeat(1, block_n * block_k)
-            .view(sn, sk, block_n, block_k)
-            .permute(0, 2, 1, 3)
-            .reshape(sn * block_n, sk * block_k))[:n, :k]
-    b_deq = b.to(b_sc.dtype) * b_sc
-
-    c[...]
= (a_deq @ b_deq.T).to(torch.bfloat16) - return c - - -# ============================================================================ -# ENTRY POINTS (for GEAK harness) +# ENTRY POINT (callable form of the kernel for downstream profilers) # ============================================================================ @@ -215,11 +182,6 @@ def triton_op(m, n, k, seed): return fp8_blockwise_mm_triton(*data) -def torch_op(m, n, k, seed): - data = _generate_input(m, n, k, seed) - return fp8_blockwise_mm_pytorch(*data) - - # ============================================================================ # SYNTHETIC INPUT BUILDER (matches reference.py generate_input) # ============================================================================ @@ -296,157 +258,4 @@ def get_inputs(m, n, k, seed=42, device="cuda"): {"m": 6144, "n": 4608, "k": 7168, "seed": 65436}, ] -WARMUP = 50 -ITERATIONS = int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200")) RTOL, ATOL = 2e-2, 1e-3 - - -# ============================================================================ -# SELF-TEST HARNESS -# ============================================================================ - - -def check_correctness(cfg) -> dict: - try: - data = get_inputs(**cfg) - a, b, a_scale, b_scale, c_triton = data - c_ref = c_triton.clone() - - fp8_blockwise_mm_triton(a, b, a_scale, b_scale, c_triton) - fp8_blockwise_mm_pytorch(a, b, a_scale, b_scale, c_ref) - torch.cuda.synchronize() - - correct = torch.allclose(c_triton.float(), c_ref.float(), rtol=RTOL, atol=ATOL) - max_diff = torch.max(torch.abs(c_triton.float() - c_ref.float())).item() - return {"correct": correct, "max_diff": max_diff, "error": None} - except Exception as e: - return {"correct": False, "max_diff": float("inf"), "error": str(e)} - - -def benchmark_config(cfg, warmup=WARMUP, iters=ITERATIONS) -> dict: - data = get_inputs(**cfg) - a, b, a_scale, b_scale, c = data - - for _ in range(warmup): - c_t = c.clone() - fp8_blockwise_mm_triton(a, b, a_scale, b_scale, c_t) - torch.cuda.synchronize() - - start = time.perf_counter() - for _ in range(iters): - c_t = c.clone() - fp8_blockwise_mm_triton(a, b, a_scale, b_scale, c_t) - torch.cuda.synchronize() - triton_ms = (time.perf_counter() - start) * 1000 / iters - - for _ in range(warmup): - c_r = c.clone() - fp8_blockwise_mm_pytorch(a, b, a_scale, b_scale, c_r) - torch.cuda.synchronize() - - start = time.perf_counter() - for _ in range(iters): - c_r = c.clone() - fp8_blockwise_mm_pytorch(a, b, a_scale, b_scale, c_r) - torch.cuda.synchronize() - torch_ms = (time.perf_counter() - start) * 1000 / iters - - return {"triton_ms": triton_ms, "torch_ms": torch_ms, - "speedup": torch_ms / triton_ms if triton_ms > 0 else 0.0} - - -def _config_label(cfg): - return f"(M={cfg['m']},N={cfg['n']},K={cfg['k']})" - - -def evaluate(configs=None, warmup=WARMUP, iters=ITERATIONS, verbose=True) -> dict: - configs = configs or TEST_CONFIGS[:5] - results, failures = [], [] - - if verbose: - print(f"{'Config':<26} {'Correct':>8} {'Torch':>10} {'Triton':>10} {'Speedup':>10}") - print("-" * 66) - - for cfg in configs: - label = _config_label(cfg) - corr = check_correctness(cfg) - if not corr["correct"]: - failures.append({"config": cfg, **corr}) - if verbose: - err = corr["error"] or f"max_diff={corr['max_diff']:.2e}" - print(f"{label:<26} {'FAIL':>8} {err[:30]}") - continue - - bench = benchmark_config(cfg, warmup=warmup, iters=iters) - results.append({"config": cfg, "correct": True, **bench}) - - if verbose: - marker = " *" if bench["speedup"] > 1.0 else "" - 
print( - f"{label:<26} {'PASS':>8} " - f"{bench['torch_ms']:>8.3f}ms {bench['triton_ms']:>8.3f}ms " - f"{bench['speedup']:>8.2f}x{marker}" - ) - - speedups = [r["speedup"] for r in results] - geomean = math.prod(speedups) ** (1 / len(speedups)) if speedups else 0.0 - - if verbose: - print("-" * 66) - status = "ALL PASS" if not failures else f"FAILED ({len(failures)}/{len(configs)})" - print(f"{'Status:':<26} {status}") - if speedups: - print(f"{'Speedup (geomean):':<26} {geomean:.2f}x") - - return { - "correct": len(failures) == 0, - "num_correct": len(results), - "num_failed": len(failures), - "failures": failures, - "results": results, - "speedup_geomean": geomean, - } - - -def run_profile(configs=None, warmup=5, iters=1, verbose=True): - configs = configs or PROFILE_CONFIGS - if verbose: - print(f"Profile: {len(configs)} config(s)") - for cfg in configs: - data = get_inputs(**cfg) - a, b, a_scale, b_scale, c = data - for _ in range(warmup): - ct = c.clone() - fp8_blockwise_mm_triton(a, b, a_scale, b_scale, ct) - torch.cuda.synchronize() - for _ in range(iters): - ct = c.clone() - fp8_blockwise_mm_triton(a, b, a_scale, b_scale, ct) - torch.cuda.synchronize() - if verbose: - print(f" {_config_label(cfg)} done") - - -# ============================================================================ -# MAIN -# ============================================================================ - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="FP8 Block-Scale GEMM (Triton dequant)") - parser.add_argument("--profile", action="store_true") - args = parser.parse_args() - - print("=" * 66) - print("FP8 Block-Scale GEMM — Triton dequant + torch.mm") - print("=" * 66) - - if args.profile: - print("\n[Profile Mode]") - run_profile() - else: - print("\n[Evaluation]") - evaluate() - - print("=" * 66) diff --git a/tasks/triton2triton/geak_eval/L1/refk_fp8_blockwise_mm/test_kernel_harness.py b/tasks/triton2triton/geak_eval/L1/refk_fp8_blockwise_mm/test_kernel_harness.py index 37eb395..a6a55cb 100644 --- a/tasks/triton2triton/geak_eval/L1/refk_fp8_blockwise_mm/test_kernel_harness.py +++ b/tasks/triton2triton/geak_eval/L1/refk_fp8_blockwise_mm/test_kernel_harness.py @@ -1,92 +1,225 @@ #!/usr/bin/env python3 -"""Generic test harness wrapping kernel.py's built-in test functions.""" +""" +Test harness for the FP8 block-scale GEMM Triton kernel. + +Modes: --correctness, --profile, --benchmark, --full-benchmark +""" import argparse import math import os import sys +import torch + +from kernel import ( + fp8_blockwise_mm_triton, get_inputs, + EVAL_CONFIGS, PROFILE_CONFIGS, + BLOCK_SHAPE_N, BLOCK_SHAPE_K, + RTOL, ATOL, +) + + +# ============================================================================ +# SHAPE SUBSETS +# ============================================================================ + +ALL_SHAPES = EVAL_CONFIGS + +_n_all = len(ALL_SHAPES) +if _n_all <= 25: + HARNESS_SHAPES = ALL_SHAPES +else: + _harness_indices = [int(round(i * (_n_all - 1) / 24)) for i in range(25)] + HARNESS_SHAPES = [ALL_SHAPES[i] for i in _harness_indices] + +# PROFILE_CONFIGS already provides the profile subset (3 entries) per kernel.py. 
+PROFILE_SHAPES = PROFILE_CONFIGS + + +# ============================================================================ +# PYTORCH REFERENCE (moved from kernel.py; correctness-only) +# ============================================================================ + +def fp8_blockwise_mm_pytorch(a, b, a_scale, b_scale, c): + a_c = a.contiguous() + a_s = a_scale.contiguous() + b_s = b_scale.contiguous() + + m, k = a_c.shape + n = b.shape[0] + block_n, block_k = BLOCK_SHAPE_N, BLOCK_SHAPE_K + sn = b_s.shape[0] + sk = b_s.shape[1] + + a_sc = a_s.unsqueeze(-1).repeat(1, 1, block_k).reshape(m, sk * block_k)[:, :k] + a_deq = a_c.to(a_sc.dtype) * a_sc + + b_sc = (b_s.view(-1, 1).repeat(1, block_n * block_k) + .view(sn, sk, block_n, block_k) + .permute(0, 2, 1, 3) + .reshape(sn * block_n, sk * block_k))[:n, :k] + b_deq = b.to(b_sc.dtype) * b_sc + + c[...] = (a_deq @ b_deq.T).to(torch.bfloat16) + return c + + +# ============================================================================ +# TEST HARNESS +# ============================================================================ + +def _label(cfg): + return f"M={cfg['m']:>5}, N={cfg['n']:>5}, K={cfg['k']:>5}" + + +def run_correctness(shapes=None, verbose=True): + if shapes is None: + shapes = HARNESS_SHAPES + if verbose: + print(f"Running correctness on {len(shapes)} shapes...") + + results, failures = [], [] + + for cfg in shapes: + try: + a, b, a_scale, b_scale, c_triton = get_inputs(**cfg) + c_ref = c_triton.clone() + + fp8_blockwise_mm_triton(a, b, a_scale, b_scale, c_triton) + fp8_blockwise_mm_pytorch(a, b, a_scale, b_scale, c_ref) + torch.cuda.synchronize() + + torch.testing.assert_close(c_triton.float(), c_ref.float(), atol=ATOL, rtol=RTOL) + results.append({"config": cfg, "correct": True}) + if verbose: + print(f" PASS: {_label(cfg)}") + del a, b, a_scale, b_scale, c_triton, c_ref + torch.cuda.empty_cache() + except Exception as e: + failures.append({"config": cfg, "error": str(e)}) + if verbose: + print(f" FAIL: {_label(cfg)} - {str(e)[:80]}") + + if verbose: + print("-" * 62) + status = "ALL PASS" if not failures else f"FAILED ({len(failures)}/{len(shapes)})" + print(f"{'Status:':<22} {status}") -_harness_dir = os.path.dirname(os.path.abspath(__file__)) -if _harness_dir not in sys.path: - sys.path.insert(0, _harness_dir) - -from kernel import EVAL_CONFIGS, check_correctness, benchmark_config - -ALL_CONFIGS = EVAL_CONFIGS -HARNESS_CONFIGS = ALL_CONFIGS # use all configs so benchmark matches full-benchmark - -def _pick(configs, count): - if len(configs) <= count: - return list(range(len(configs))) - n = len(configs) - return [round(i * (n - 1) / (count - 1)) for i in range(count)] - -def run_correctness(configs, indices): - print(f"Running correctness on {len(indices)} configs...") - all_ok = True - for idx in indices: - r = check_correctness(configs[idx]) - tag = f"config[{idx}]" - if r["correct"]: - print(f" PASS {tag}") - else: - print(f" FAIL {tag}: {r.get('error','')[:80]}") - all_ok = False - print(f"GEAK_SHAPES_USED={indices}") - if all_ok: - print("ALL CORRECTNESS CHECKS PASSED") - return 0 - print("CORRECTNESS FAILED") - return 1 - -def run_benchmark(configs, indices, warmup=50, iters=200): - print(f"Running benchmark on {len(indices)} configs...") - lats = [] - for idx in indices: - r = benchmark_config(configs[idx], warmup=warmup, iters=iters) - lat = r.get("triton_ms", 0) - lats.append(lat) - print(f" config[{idx}] {lat:.4f}ms") - valid = [l for l in lats if l > 0] - geo = math.exp(sum(math.log(l) for l in valid) / len(valid)) if 
valid else 0 - print(f"GEAK_SHAPES_USED={indices}") - print(f"GEAK_RESULT_LATENCY_MS={geo:.4f}") - return 0 - -def run_profile(configs, indices): - from kernel import triton_op, get_inputs - import torch - print(f"Running profile on {len(indices)} configs...") - for idx in indices: - cfg = configs[idx] - for _ in range(3): - if isinstance(cfg, dict): - triton_op(**cfg) - elif isinstance(cfg, (list, tuple)): - triton_op(*cfg) - else: - triton_op(cfg) + return { + "correct": len(failures) == 0, + "num_correct": len(results), + "num_failed": len(failures), + "failures": failures, + "results": results, + } + + +def run_profile(shapes=None, warmup=50, iters=200, verbose=True): + if shapes is None: + shapes = PROFILE_SHAPES + if verbose: + print(f"Profile: {len(shapes)} config(s), {warmup} warmup, {iters} iter(s)") + + for cfg in shapes: + a, b, a_scale, b_scale, c = get_inputs(**cfg) + for _ in range(warmup): + fp8_blockwise_mm_triton(a, b, a_scale, b_scale, c) + torch.cuda.synchronize() + for _ in range(iters): + fp8_blockwise_mm_triton(a, b, a_scale, b_scale, c) + torch.cuda.synchronize() + if verbose: + print(f" {_label(cfg)} done") + del a, b, a_scale, b_scale, c + torch.cuda.empty_cache() + + +def run_benchmark(shapes=None, warmup=50, iters=200, verbose=True): + if shapes is None: + shapes = HARNESS_SHAPES + + latencies = [] + + print(f"Running benchmark on {len(shapes)} shapes, {warmup} warmup, {iters} iterations each...") + if verbose: + print(f"{'Config':<32} {'Triton':>10}") + print("-" * 44) + + for cfg in shapes: + # NOTE: c is mutated in place by the kernel. Allocate inputs ONCE per cfg + # OUTSIDE the timed loop and reuse them; the OLD benchmark_config did + # c.clone() per iteration, which incorrectly timed allocation. + a, b, a_scale, b_scale, c = get_inputs(**cfg) + + for _ in range(warmup): + fp8_blockwise_mm_triton(a, b, a_scale, b_scale, c) torch.cuda.synchronize() - return 0 - -def main(): - iters = int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200")) - p = argparse.ArgumentParser() - g = p.add_mutually_exclusive_group(required=True) - g.add_argument("--correctness", action="store_true") - g.add_argument("--benchmark", action="store_true") - g.add_argument("--full-benchmark", action="store_true") - g.add_argument("--profile", action="store_true") - p.add_argument("--iterations", type=int, default=iters) - p.add_argument("--warmup", type=int, default=50) - a = p.parse_args() - if a.correctness: - sys.exit(run_correctness(ALL_CONFIGS, _pick(ALL_CONFIGS, 25))) - elif a.benchmark: - sys.exit(run_benchmark(HARNESS_CONFIGS, _pick(HARNESS_CONFIGS, 25), a.warmup, a.iterations)) - elif a.full_benchmark: - sys.exit(run_benchmark(ALL_CONFIGS, list(range(len(ALL_CONFIGS))), a.warmup, a.iterations)) - elif a.profile: - sys.exit(run_profile(ALL_CONFIGS, _pick(ALL_CONFIGS, 5))) + + triton_times = [] + for _ in range(iters): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fp8_blockwise_mm_triton(a, b, a_scale, b_scale, c) + end.record() + torch.cuda.synchronize() + triton_times.append(start.elapsed_time(end)) + + triton_ms = sorted(triton_times)[len(triton_times) // 2] + latencies.append(triton_ms) + + if verbose: + print(f"{_label(cfg):<32} {triton_ms:>8.4f}ms", flush=True) + + del a, b, a_scale, b_scale, c + torch.cuda.empty_cache() + + geomean_latency = math.exp(sum(math.log(l) for l in latencies) / len(latencies)) + + print("-" * 44) + print(f"{'Geometric mean latency:':<22} {geomean_latency:.4f} ms") + 
print(f"GEAK_SHAPES_USED={list(range(len(shapes)))}") + print(f"GEAK_RESULT_LATENCY_MS={geomean_latency:.4f}", flush=True) + + return {"geomean_latency_ms": geomean_latency, "latencies": latencies} + + +# ============================================================================ +# MAIN +# ============================================================================ if __name__ == "__main__": - main() + parser = argparse.ArgumentParser(description="FP8 Block-Scale GEMM Test Harness") + parser.add_argument("--correctness", action="store_true", + help="Run correctness tests on HARNESS_SHAPES") + parser.add_argument("--profile", action="store_true", + help="Run minimal profiling workload") + parser.add_argument("--benchmark", action="store_true", + help="Run benchmark on HARNESS_SHAPES (25 uniformly sampled)") + parser.add_argument("--full-benchmark", action="store_true", + help="Run benchmark on ALL_SHAPES (complete set)") + parser.add_argument("--warmup", type=int, default=50, + help="Number of warmup iterations (default: 50)") + parser.add_argument("--iterations", type=int, + default=int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200")), + help="Number of benchmark iterations (default: GEAK_BENCHMARK_ITERATIONS or 200)") + args = parser.parse_args() + + print("=" * 62) + print("FP8 Block-Scale GEMM Test Harness") + print("=" * 62) + + if args.correctness: + print("\n[Correctness Mode]") + result = run_correctness(HARNESS_SHAPES) + sys.exit(0 if result["correct"] else 1) + elif args.profile: + print("\n[Profile Mode]") + run_profile(PROFILE_SHAPES, warmup=args.warmup, iters=args.iterations) + elif args.full_benchmark: + print("\n[Full Benchmark Mode]") + run_benchmark(ALL_SHAPES, warmup=args.warmup, iters=args.iterations) + else: + print("\n[Benchmark Mode]") + run_benchmark(HARNESS_SHAPES, warmup=args.warmup, iters=args.iterations) + + print("=" * 62) diff --git a/tasks/triton2triton/geak_eval/L1/refk_identity/kernel.py b/tasks/triton2triton/geak_eval/L1/refk_identity/kernel.py index 44a21d6..d4b1941 100644 --- a/tasks/triton2triton/geak_eval/L1/refk_identity/kernel.py +++ b/tasks/triton2triton/geak_eval/L1/refk_identity/kernel.py @@ -6,10 +6,6 @@ Triton kernel generated from PyTorch's `output.copy_(input)` on float16 1-D tensors. """ -import math -import os -import time - import torch import triton import triton.language as tl @@ -51,17 +47,7 @@ def identity_triton(input_tensor: torch.Tensor, output_tensor: torch.Tensor) -> # ============================================================================ -# REFERENCE IMPLEMENTATION (pure PyTorch) -# ============================================================================ - - -def identity_pytorch(input_tensor: torch.Tensor, output_tensor: torch.Tensor) -> torch.Tensor: - output_tensor[...] 
= input_tensor - return output_tensor - - -# ============================================================================ -# ENTRY POINTS (for GEAK harness) +# ENTRY POINT (callable form of the kernel for downstream profilers) # ============================================================================ @@ -74,15 +60,6 @@ def triton_op(size, seed): return identity_triton(data, output) -def torch_op(size, seed): - gen = torch.Generator(device="cuda") - gen.manual_seed(seed) - data = torch.empty(size, device="cuda", dtype=torch.float16) - data.uniform_(0, 1, generator=gen) - output = torch.empty_like(data) - return identity_pytorch(data, output) - - # ============================================================================ # SYNTHETIC INPUT BUILDER # ============================================================================ @@ -125,146 +102,4 @@ def get_inputs(size, seed=42, device="cuda"): {"size": 65536, "seed": 125432}, ] -WARMUP = 50 -ITERATIONS = int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200")) RTOL, ATOL = 1e-5, 1e-5 - - -# ============================================================================ -# SELF-TEST HARNESS -# ============================================================================ - - -def check_correctness(cfg) -> dict: - try: - data, out_triton = get_inputs(**cfg) - out_ref = torch.empty_like(data) - identity_triton(data, out_triton) - identity_pytorch(data, out_ref) - torch.cuda.synchronize() - correct = torch.equal(out_triton, out_ref) - max_diff = torch.max(torch.abs(out_triton.float() - out_ref.float())).item() - return {"correct": correct, "max_diff": max_diff, "error": None} - except Exception as e: - return {"correct": False, "max_diff": float("inf"), "error": str(e)} - - -def benchmark_config(cfg, warmup=WARMUP, iters=ITERATIONS) -> dict: - data, output = get_inputs(**cfg) - for _ in range(warmup): - identity_triton(data, output) - torch.cuda.synchronize() - - start = time.perf_counter() - for _ in range(iters): - identity_triton(data, output) - torch.cuda.synchronize() - triton_ms = (time.perf_counter() - start) * 1000 / iters - - output2 = torch.empty_like(data) - for _ in range(warmup): - identity_pytorch(data, output2) - torch.cuda.synchronize() - - start = time.perf_counter() - for _ in range(iters): - identity_pytorch(data, output2) - torch.cuda.synchronize() - torch_ms = (time.perf_counter() - start) * 1000 / iters - - return {"triton_ms": triton_ms, "torch_ms": torch_ms, - "speedup": torch_ms / triton_ms if triton_ms > 0 else 0.0} - - -def _config_label(cfg): - return f"(size={cfg['size']})" - - -def evaluate(configs=None, warmup=WARMUP, iters=ITERATIONS, verbose=True) -> dict: - configs = configs or EVAL_CONFIGS - results, failures = [], [] - - if verbose: - print(f"{'Config':<22} {'Correct':>8} {'Torch':>10} {'Triton':>10} {'Speedup':>10}") - print("-" * 62) - - for cfg in configs: - label = _config_label(cfg) - corr = check_correctness(cfg) - if not corr["correct"]: - failures.append({"config": cfg, **corr}) - if verbose: - err = corr["error"] or f"max_diff={corr['max_diff']:.2e}" - print(f"{label:<22} {'FAIL':>8} {err[:30]}") - continue - - bench = benchmark_config(cfg, warmup=warmup, iters=iters) - results.append({"config": cfg, "correct": True, **bench}) - - if verbose: - marker = " *" if bench["speedup"] > 1.0 else "" - print( - f"{label:<22} {'PASS':>8} " - f"{bench['torch_ms']:>8.4f}ms {bench['triton_ms']:>8.4f}ms " - f"{bench['speedup']:>8.2f}x{marker}" - ) - - speedups = [r["speedup"] for r in results] - geomean = 
math.prod(speedups) ** (1 / len(speedups)) if speedups else 0.0 - - if verbose: - print("-" * 62) - status = "ALL PASS" if not failures else f"FAILED ({len(failures)}/{len(configs)})" - print(f"{'Status:':<22} {status}") - if speedups: - print(f"{'Speedup (geomean):':<22} {geomean:.2f}x") - - return { - "correct": len(failures) == 0, - "num_correct": len(results), - "num_failed": len(failures), - "failures": failures, - "results": results, - "speedup_geomean": geomean, - } - - -def run_profile(configs=None, warmup=5, iters=1, verbose=True): - configs = configs or PROFILE_CONFIGS - if verbose: - print(f"Profile: {len(configs)} config(s)") - for cfg in configs: - data, output = get_inputs(**cfg) - for _ in range(warmup): - identity_triton(data, output) - torch.cuda.synchronize() - for _ in range(iters): - identity_triton(data, output) - torch.cuda.synchronize() - if verbose: - print(f" {_config_label(cfg)} done") - - -# ============================================================================ -# MAIN -# ============================================================================ - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="Identity Kernel (Triton)") - parser.add_argument("--profile", action="store_true") - args = parser.parse_args() - - print("=" * 62) - print("Identity Kernel — Triton (torch.compile extracted)") - print("=" * 62) - - if args.profile: - print("\n[Profile Mode]") - run_profile() - else: - print("\n[Evaluation]") - evaluate() - - print("=" * 62) diff --git a/tasks/triton2triton/geak_eval/L1/refk_identity/test_kernel_harness.py b/tasks/triton2triton/geak_eval/L1/refk_identity/test_kernel_harness.py index 37eb395..42c978b 100644 --- a/tasks/triton2triton/geak_eval/L1/refk_identity/test_kernel_harness.py +++ b/tasks/triton2triton/geak_eval/L1/refk_identity/test_kernel_harness.py @@ -1,92 +1,193 @@ #!/usr/bin/env python3 -"""Generic test harness wrapping kernel.py's built-in test functions.""" +""" +Test harness for the identity copy Triton kernel. + +Modes: --correctness, --profile, --benchmark, --full-benchmark +""" import argparse import math import os import sys +import torch + +from kernel import identity_triton, get_inputs, EVAL_CONFIGS, PROFILE_CONFIGS, RTOL, ATOL + + +# ============================================================================ +# SHAPE SUBSETS +# ============================================================================ + +ALL_SHAPES = EVAL_CONFIGS + +_n_all = len(ALL_SHAPES) +if _n_all <= 25: + HARNESS_SHAPES = ALL_SHAPES +else: + _harness_indices = [int(round(i * (_n_all - 1) / 24)) for i in range(25)] + HARNESS_SHAPES = [ALL_SHAPES[i] for i in _harness_indices] + +# PROFILE_CONFIGS already provides the profile subset (3 entries) per kernel.py. +PROFILE_SHAPES = PROFILE_CONFIGS + + +# ============================================================================ +# PYTORCH REFERENCE (moved from kernel.py; correctness-only) +# ============================================================================ + +def identity_pytorch(input_tensor, output_tensor): + output_tensor[...] 
= input_tensor + return output_tensor + + +# ============================================================================ +# TEST HARNESS +# ============================================================================ + +def run_correctness(shapes=None, verbose=True): + if shapes is None: + shapes = HARNESS_SHAPES + if verbose: + print(f"Running correctness on {len(shapes)} shapes...") + + results, failures = [], [] + + for cfg in shapes: + try: + data, out_triton = get_inputs(**cfg) + out_ref = torch.empty_like(data) + identity_triton(data, out_triton) + identity_pytorch(data, out_ref) + torch.cuda.synchronize() + + torch.testing.assert_close(out_triton, out_ref, atol=ATOL, rtol=RTOL) + results.append({"config": cfg, "correct": True}) + if verbose: + print(f" PASS: size={cfg['size']}") + del data, out_triton, out_ref + torch.cuda.empty_cache() + except Exception as e: + failures.append({"config": cfg, "error": str(e)}) + if verbose: + print(f" FAIL: size={cfg['size']} - {str(e)[:50]}") + + if verbose: + print("-" * 62) + status = "ALL PASS" if not failures else f"FAILED ({len(failures)}/{len(shapes)})" + print(f"{'Status:':<22} {status}") + + return { + "correct": len(failures) == 0, + "num_correct": len(results), + "num_failed": len(failures), + "failures": failures, + "results": results, + } + + +def run_profile(shapes=None, warmup=50, iters=200, verbose=True): + if shapes is None: + shapes = PROFILE_SHAPES + if verbose: + print(f"Profile: {len(shapes)} config(s), {warmup} warmup, {iters} iter(s)") -_harness_dir = os.path.dirname(os.path.abspath(__file__)) -if _harness_dir not in sys.path: - sys.path.insert(0, _harness_dir) - -from kernel import EVAL_CONFIGS, check_correctness, benchmark_config - -ALL_CONFIGS = EVAL_CONFIGS -HARNESS_CONFIGS = ALL_CONFIGS # use all configs so benchmark matches full-benchmark - -def _pick(configs, count): - if len(configs) <= count: - return list(range(len(configs))) - n = len(configs) - return [round(i * (n - 1) / (count - 1)) for i in range(count)] - -def run_correctness(configs, indices): - print(f"Running correctness on {len(indices)} configs...") - all_ok = True - for idx in indices: - r = check_correctness(configs[idx]) - tag = f"config[{idx}]" - if r["correct"]: - print(f" PASS {tag}") - else: - print(f" FAIL {tag}: {r.get('error','')[:80]}") - all_ok = False - print(f"GEAK_SHAPES_USED={indices}") - if all_ok: - print("ALL CORRECTNESS CHECKS PASSED") - return 0 - print("CORRECTNESS FAILED") - return 1 - -def run_benchmark(configs, indices, warmup=50, iters=200): - print(f"Running benchmark on {len(indices)} configs...") - lats = [] - for idx in indices: - r = benchmark_config(configs[idx], warmup=warmup, iters=iters) - lat = r.get("triton_ms", 0) - lats.append(lat) - print(f" config[{idx}] {lat:.4f}ms") - valid = [l for l in lats if l > 0] - geo = math.exp(sum(math.log(l) for l in valid) / len(valid)) if valid else 0 - print(f"GEAK_SHAPES_USED={indices}") - print(f"GEAK_RESULT_LATENCY_MS={geo:.4f}") - return 0 - -def run_profile(configs, indices): - from kernel import triton_op, get_inputs - import torch - print(f"Running profile on {len(indices)} configs...") - for idx in indices: - cfg = configs[idx] - for _ in range(3): - if isinstance(cfg, dict): - triton_op(**cfg) - elif isinstance(cfg, (list, tuple)): - triton_op(*cfg) - else: - triton_op(cfg) + for cfg in shapes: + data, output = get_inputs(**cfg) + for _ in range(warmup): + identity_triton(data, output) torch.cuda.synchronize() - return 0 - -def main(): - iters = 
int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200")) - p = argparse.ArgumentParser() - g = p.add_mutually_exclusive_group(required=True) - g.add_argument("--correctness", action="store_true") - g.add_argument("--benchmark", action="store_true") - g.add_argument("--full-benchmark", action="store_true") - g.add_argument("--profile", action="store_true") - p.add_argument("--iterations", type=int, default=iters) - p.add_argument("--warmup", type=int, default=50) - a = p.parse_args() - if a.correctness: - sys.exit(run_correctness(ALL_CONFIGS, _pick(ALL_CONFIGS, 25))) - elif a.benchmark: - sys.exit(run_benchmark(HARNESS_CONFIGS, _pick(HARNESS_CONFIGS, 25), a.warmup, a.iterations)) - elif a.full_benchmark: - sys.exit(run_benchmark(ALL_CONFIGS, list(range(len(ALL_CONFIGS))), a.warmup, a.iterations)) - elif a.profile: - sys.exit(run_profile(ALL_CONFIGS, _pick(ALL_CONFIGS, 5))) + for _ in range(iters): + identity_triton(data, output) + torch.cuda.synchronize() + if verbose: + print(f" size={cfg['size']} done") + del data, output + torch.cuda.empty_cache() + + +def run_benchmark(shapes=None, warmup=50, iters=200, verbose=True): + if shapes is None: + shapes = HARNESS_SHAPES + + latencies = [] + + print(f"Running benchmark on {len(shapes)} shapes, {warmup} warmup, {iters} iterations each...") + if verbose: + print(f"{'Config':<22} {'Triton':>10}") + print("-" * 34) + + for cfg in shapes: + data, output = get_inputs(**cfg) + + for _ in range(warmup): + identity_triton(data, output) + torch.cuda.synchronize() + + triton_times = [] + for _ in range(iters): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + identity_triton(data, output) + end.record() + torch.cuda.synchronize() + triton_times.append(start.elapsed_time(end)) + + triton_ms = sorted(triton_times)[len(triton_times) // 2] + latencies.append(triton_ms) + + if verbose: + print(f"size={cfg['size']:<14} {triton_ms:>8.4f}ms", flush=True) + + del data, output + torch.cuda.empty_cache() + + geomean_latency = math.exp(sum(math.log(l) for l in latencies) / len(latencies)) + + print("-" * 34) + print(f"{'Geometric mean latency:':<22} {geomean_latency:.4f} ms") + print(f"GEAK_SHAPES_USED={list(range(len(shapes)))}") + print(f"GEAK_RESULT_LATENCY_MS={geomean_latency:.4f}", flush=True) + + return {"geomean_latency_ms": geomean_latency, "latencies": latencies} + + +# ============================================================================ +# MAIN +# ============================================================================ if __name__ == "__main__": - main() + parser = argparse.ArgumentParser(description="Identity Kernel Test Harness") + parser.add_argument("--correctness", action="store_true", + help="Run correctness tests on HARNESS_SHAPES") + parser.add_argument("--profile", action="store_true", + help="Run minimal profiling workload") + parser.add_argument("--benchmark", action="store_true", + help="Run benchmark on HARNESS_SHAPES (25 uniformly sampled)") + parser.add_argument("--full-benchmark", action="store_true", + help="Run benchmark on ALL_SHAPES (complete set)") + parser.add_argument("--warmup", type=int, default=50, + help="Number of warmup iterations (default: 50)") + parser.add_argument("--iterations", type=int, + default=int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200")), + help="Number of benchmark iterations (default: GEAK_BENCHMARK_ITERATIONS or 200)") + args = parser.parse_args() + + print("=" * 62) + print("Identity Kernel Test Harness") + print("=" * 62) + + if 
args.correctness: + print("\n[Correctness Mode]") + result = run_correctness(HARNESS_SHAPES) + sys.exit(0 if result["correct"] else 1) + elif args.profile: + print("\n[Profile Mode]") + run_profile(PROFILE_SHAPES, warmup=args.warmup, iters=args.iterations) + elif args.full_benchmark: + print("\n[Full Benchmark Mode]") + run_benchmark(ALL_SHAPES, warmup=args.warmup, iters=args.iterations) + else: + print("\n[Benchmark Mode]") + run_benchmark(HARNESS_SHAPES, warmup=args.warmup, iters=args.iterations) + + print("=" * 62) diff --git a/tasks/triton2triton/geak_eval/L2/ff_backward/kernel.py b/tasks/triton2triton/geak_eval/L2/ff_backward/kernel.py index 513c690..1967553 100755 --- a/tasks/triton2triton/geak_eval/L2/ff_backward/kernel.py +++ b/tasks/triton2triton/geak_eval/L2/ff_backward/kernel.py @@ -13,49 +13,17 @@ _dw_down_kernel : dw_down = g.T @ dy """ -import math - import torch import triton import triton.language as tl # ============================================================================ -# REFERENCE HELPERS (PyTorch, for correctness checking) +# FORWARD (PyTorch) — required by both the Triton backward call site and +# any reference implementation; produces the cached intermediates h0/h1/a/g. # ============================================================================ -def silu_backward(x, grad): - sigmoid_x = torch.sigmoid(x) - return grad * (sigmoid_x + x * sigmoid_x * (1 - sigmoid_x)) - - -def _pytorch_backward_reference(dy, x, w_up, w_down, h0, h1, a, g, activation='silu'): - N_half = h0.shape[1] - - dg = torch.matmul(dy, w_down.t()) - - if activation == 'silu': - da = silu_backward(h0, dg * h1) - else: - da = dg * h1 - - dh0 = da - dh1 = dg * a - - w_gate = w_up[:N_half, :] - w_value = w_up[N_half:, :] - dx = torch.matmul(dh0, w_gate) + torch.matmul(dh1, w_value) - - dw_gate = torch.matmul(dh0.t(), x) - dw_value = torch.matmul(dh1.t(), x) - dw_up = torch.cat([dw_gate, dw_value], dim=0) - - dw_down = torch.matmul(g.t(), dy) - - return dx, dw_up, dw_down - - def ff_fused_gated_forward(x, w_up, w_down, activation='silu'): N = w_up.shape[0] N_half = N // 2 @@ -454,7 +422,7 @@ def grid_dw_down(META): # ============================================================================ -# ENTRY POINTS (triton_op / torch_op for GEAK harness) +# ENTRY POINT (callable form of the kernel for downstream profilers) # ============================================================================ @@ -464,12 +432,6 @@ def triton_op(M, N, K, x, w_up, w_down, dy, activation='silu'): return ff_fused_gated_backward_triton(dy, x, w_up, w_down, h0, h1, a, g, activation) -def torch_op(M, N, K, x, w_up, w_down, dy, activation='silu'): - """Run forward then PyTorch reference backward.""" - y, h0, h1, a, g = ff_fused_gated_forward(x, w_up, w_down, activation) - return _pytorch_backward_reference(dy, x, w_up, w_down, h0, h1, a, g, activation) - - # ============================================================================ # TEST CONFIGURATIONS # ============================================================================ @@ -497,7 +459,7 @@ def torch_op(M, N, K, x, w_up, w_down, dy, activation='silu'): # ============================================================================ -# TEST HARNESS +# SYNTHETIC INPUT BUILDER # ============================================================================ @@ -511,155 +473,7 @@ def get_inputs(M, K, N, dtype=DTYPE, device="cuda"): return x, w_up, w_down, dy -def check_correctness(M, K, N, activation=ACTIVATION, dtype=DTYPE) -> dict: - try: - x, w_up, w_down, dy = 
get_inputs(M, K, N, dtype) - - dx_tri, dwup_tri, dwdown_tri = triton_op(M, N, K, x, w_up, w_down, dy, activation) - dx_ref, dwup_ref, dwdown_ref = torch_op(M, N, K, x, w_up, w_down, dy, activation) - - def rel_diff(a, b): - max_diff = (a - b).abs().max().item() - max_val = max(a.abs().max().item(), b.abs().max().item()) - return max_diff / max_val if max_val > 0 else max_diff - - rd_dx = rel_diff(dx_tri, dx_ref) - rd_dwup = rel_diff(dwup_tri, dwup_ref) - rd_dwdown = rel_diff(dwdown_tri, dwdown_ref) - - correct = rd_dx < 0.01 and rd_dwup < 0.01 and rd_dwdown < 0.01 - return { - "correct": correct, - "rel_dx": rd_dx, "rel_dwup": rd_dwup, "rel_dwdown": rd_dwdown, - "error": None, - } - except Exception as e: - import traceback - return {"correct": False, "error": str(e) + "\n" + traceback.format_exc()} - - -def benchmark_config(M, K, N, activation=ACTIVATION, warmup=50, iters=200) -> dict: - import time - x, w_up, w_down, dy = get_inputs(M, K, N) - - y, h0, h1, a, g = ff_fused_gated_forward(x, w_up, w_down, activation) - - # Torch reference - for _ in range(warmup): - _pytorch_backward_reference(dy, x, w_up, w_down, h0, h1, a, g, activation) - torch.cuda.synchronize() - start = time.perf_counter() - for _ in range(iters): - _pytorch_backward_reference(dy, x, w_up, w_down, h0, h1, a, g, activation) - torch.cuda.synchronize() - torch_ms = (time.perf_counter() - start) * 1000 / iters - - # Triton - for _ in range(warmup): - ff_fused_gated_backward_triton(dy, x, w_up, w_down, h0, h1, a, g, activation) - torch.cuda.synchronize() - start = time.perf_counter() - for _ in range(iters): - ff_fused_gated_backward_triton(dy, x, w_up, w_down, h0, h1, a, g, activation) - torch.cuda.synchronize() - triton_ms = (time.perf_counter() - start) * 1000 / iters - - return { - "torch_ms": torch_ms, - "triton_ms": triton_ms, - "speedup": torch_ms / triton_ms if triton_ms > 0 else 0.0, - } - - -def evaluate(configs=None, warmup=50, iters=200, verbose=True) -> dict: - configs = configs or EVAL_CONFIGS - results, failures = [], [] - - if verbose: - print(f"{'Config (M,N,K)':<22} {'Correct':>8} {'Torch':>10} {'Triton':>10} {'Speedup':>10}") - print("-" * 62) - - for M, N, K in configs: - corr = check_correctness(M, K, N) - if not corr["correct"]: - failures.append({"config": (M, N, K), **corr}) - if verbose: - err = corr["error"] or f"dx={corr.get('rel_dx',0):.4f}" - print(f"({M},{N},{K}){'':<8} {'FAIL':>8} {err[:30]}") - continue - - bench = benchmark_config(M, K, N, warmup=warmup, iters=iters) - results.append({"config": (M, N, K), "correct": True, **bench}) - - if verbose: - marker = " *" if bench["speedup"] > 1.0 else "" - print( - f"({M},{N},{K}){'':<8} {'PASS':>8} " - f"{bench['torch_ms']:>8.3f}ms {bench['triton_ms']:>8.3f}ms " - f"{bench['speedup']:>8.2f}x{marker}" - ) - - speedups = [r["speedup"] for r in results] - geomean = math.prod(speedups) ** (1 / len(speedups)) if speedups else 0.0 - - if verbose: - print("-" * 62) - status = "ALL PASS" if not failures else f"FAILED ({len(failures)}/{len(configs)})" - print(f"{'Status:':<22} {status}") - if speedups: - print(f"{'Speedup (geomean):':<22} {geomean:.2f}x") - - return { - "correct": len(failures) == 0, - "num_correct": len(results), - "num_failed": len(failures), - "failures": failures, - "results": results, - "speedup_geomean": geomean, - } - - -def run_profile(configs=None, warmup=3, iters=1, verbose=True): - configs = configs or PROFILE_CONFIGS - if verbose: - print(f"Profile: {len(configs)} config(s)") - - for M, N, K in configs: - x, w_up, w_down, dy = 
get_inputs(M, K, N) - y, h0, h1, a, g = ff_fused_gated_forward(x, w_up, w_down, ACTIVATION) - - for _ in range(warmup): - ff_fused_gated_backward_triton(dy, x, w_up, w_down, h0, h1, a, g, ACTIVATION) - torch.cuda.synchronize() - - for _ in range(iters): - ff_fused_gated_backward_triton(dy, x, w_up, w_down, h0, h1, a, g, ACTIVATION) - torch.cuda.synchronize() - - if verbose: - print(f" ({M},{N},{K}) done") - - -# ============================================================================ -# MAIN -# ============================================================================ - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="FF Backward Kernel (Pure Triton)") - parser.add_argument("--profile", action="store_true") - args = parser.parse_args() - - print("=" * 62) - print("Fused Gated MLP Backward — Pure Triton") - print("=" * 62) - - if args.profile: - print("\n[Profile Mode]") - run_profile() - else: - print("\n[Evaluation]") - evaluate() - - print("=" * 62) +# Tolerance for the harness's torch.testing.assert_close. The previous +# correctness check used max_relative_diff < 0.01 across all three +# returned gradients; mirror that here as atol=rtol=1e-2. +RTOL, ATOL = 1e-2, 1e-2 diff --git a/tasks/triton2triton/geak_eval/L2/ff_backward/test_kernel_harness.py b/tasks/triton2triton/geak_eval/L2/ff_backward/test_kernel_harness.py index 3608b1e..2a4b1de 100755 --- a/tasks/triton2triton/geak_eval/L2/ff_backward/test_kernel_harness.py +++ b/tasks/triton2triton/geak_eval/L2/ff_backward/test_kernel_harness.py @@ -1,134 +1,241 @@ #!/usr/bin/env python3 """ -Test harness for ff_backward (SwiGLU fused gated backward) kernel. - -Modes: - --correctness : validate Triton backward against PyTorch reference - --benchmark : benchmark on HARNESS_CONFIGS, report GEAK_RESULT_LATENCY_MS - --full-benchmark : benchmark on ALL configs, report GEAK_RESULT_LATENCY_MS - --profile : run 3 configs for profiler capture - --iterations N : override iteration count (default from GEAK_BENCHMARK_ITERATIONS or 200) +Test harness for the fused gated MLP (SwiGLU) backward Triton kernel. 
+ +Modes: --correctness, --profile, --benchmark, --full-benchmark """ import argparse import math import os import sys -import time - import torch -# Ensure kernel.py is importable -_harness_dir = os.path.dirname(os.path.abspath(__file__)) -if _harness_dir not in sys.path: - sys.path.insert(0, _harness_dir) - from kernel import ( - EVAL_CONFIGS, - check_correctness, - benchmark_config, - triton_op, + ff_fused_gated_forward, ff_fused_gated_backward_triton, get_inputs, + EVAL_CONFIGS, PROFILE_CONFIGS, ACTIVATION, + RTOL, ATOL, ) -# ── Config space ──────────────────────────────────────────────────────────── -ALL_CONFIGS = EVAL_CONFIGS -HARNESS_CONFIGS = ALL_CONFIGS # use all configs so benchmark matches full-benchmark -PROFILE_CONFIGS = ALL_CONFIGS[:3] - - -def _pick(configs, count): - if len(configs) <= count: - return list(range(len(configs))) - n = len(configs) - return [round(i * (n - 1) / (count - 1)) for i in range(count)] - - -# ── Correctness ──────────────────────────────────────────────────────────── -def run_correctness(configs, indices): - print(f"Running correctness on {len(indices)} configs...") - all_passed = True - for idx in indices: - M, N, K = configs[idx] - result = check_correctness(M, K, N) - if result["correct"]: - print(f" PASS config[{idx}] M={M} N={N} K={K}") - else: - err = result.get("error", f"rel_dx={result.get('rel_dx', '?')}") - print(f" FAIL config[{idx}] M={M} N={N} K={K}: {err}") - all_passed = False - print(f"GEAK_SHAPES_USED={indices}") - if all_passed: - print("ALL CORRECTNESS CHECKS PASSED") - return 0 - print("CORRECTNESS FAILED") - return 1 - - -# ── Benchmark ────────────────────────────────────────────────────────────── -def run_benchmark(configs, indices, warmup=50, iters=200): - print(f"Running benchmark on {len(indices)} configs...") - latencies = [] - for idx in indices: - M, N, K = configs[idx] - result = benchmark_config(M, K, N, warmup=warmup, iters=iters) - lat = result.get("triton_ms", 0) - latencies.append(lat) - print(f" M={M} N={N} K={K} {lat:.4f}ms") - - valid = [l for l in latencies if l > 0] - if valid: - geo_mean = math.exp(sum(math.log(l) for l in valid) / len(valid)) - else: - geo_mean = 0.0 - print(f"GEAK_SHAPES_USED={indices}") - print(f"GEAK_RESULT_LATENCY_MS={geo_mean:.4f}") - return 0 +# ============================================================================ +# SHAPE SUBSETS +# ============================================================================ + +ALL_SHAPES = EVAL_CONFIGS + +_n_all = len(ALL_SHAPES) +if _n_all <= 25: + HARNESS_SHAPES = ALL_SHAPES +else: + _harness_indices = [int(round(i * (_n_all - 1) / 24)) for i in range(25)] + HARNESS_SHAPES = [ALL_SHAPES[i] for i in _harness_indices] + +# PROFILE_CONFIGS already provides the profile subset per kernel.py. 
+PROFILE_SHAPES = PROFILE_CONFIGS + + +# ============================================================================ +# PYTORCH REFERENCE (moved from kernel.py; correctness-only) +# ============================================================================ + +def silu_backward(x, grad): + sigmoid_x = torch.sigmoid(x) + return grad * (sigmoid_x + x * sigmoid_x * (1 - sigmoid_x)) -# ── Profile ──────────────────────────────────────────────────────────────── -def run_profile(configs, indices): - print(f"Running profile on {len(indices)} configs...") - for idx in indices: - M, N, K = configs[idx] + +def _pytorch_backward_reference(dy, x, w_up, w_down, h0, h1, a, g, activation='silu'): + N_half = h0.shape[1] + + dg = torch.matmul(dy, w_down.t()) + + if activation == 'silu': + da = silu_backward(h0, dg * h1) + else: + da = dg * h1 + + dh0 = da + dh1 = dg * a + + w_gate = w_up[:N_half, :] + w_value = w_up[N_half:, :] + dx = torch.matmul(dh0, w_gate) + torch.matmul(dh1, w_value) + + dw_gate = torch.matmul(dh0.t(), x) + dw_value = torch.matmul(dh1.t(), x) + dw_up = torch.cat([dw_gate, dw_value], dim=0) + + dw_down = torch.matmul(g.t(), dy) + + return dx, dw_up, dw_down + + +# ============================================================================ +# TEST HARNESS +# ============================================================================ + +def _label(cfg): + M, N, K = cfg + return f"M={M},N={N},K={K}" + + +def run_correctness(shapes=None, verbose=True): + if shapes is None: + shapes = HARNESS_SHAPES + if verbose: + print(f"Running correctness on {len(shapes)} shapes...") + + results, failures = [], [] + + for cfg in shapes: + try: + M, N, K = cfg + x, w_up, w_down, dy = get_inputs(M, K, N) + y, h0, h1, a, g = ff_fused_gated_forward(x, w_up, w_down, ACTIVATION) + dx_t, dwup_t, dwdown_t = ff_fused_gated_backward_triton( + dy, x, w_up, w_down, h0, h1, a, g, ACTIVATION) + dx_r, dwup_r, dwdown_r = _pytorch_backward_reference( + dy, x, w_up, w_down, h0, h1, a, g, ACTIVATION) + torch.cuda.synchronize() + + torch.testing.assert_close(dx_t, dx_r, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(dwup_t, dwup_r, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(dwdown_t, dwdown_r, atol=ATOL, rtol=RTOL) + results.append({"config": cfg, "correct": True}) + if verbose: + print(f" PASS: {_label(cfg)}") + del x, w_up, w_down, dy, y, h0, h1, a, g + del dx_t, dwup_t, dwdown_t, dx_r, dwup_r, dwdown_r + torch.cuda.empty_cache() + except Exception as e: + failures.append({"config": cfg, "error": str(e)}) + if verbose: + print(f" FAIL: {_label(cfg)} - {str(e)[:80]}") + + if verbose: + print("-" * 62) + status = "ALL PASS" if not failures else f"FAILED ({len(failures)}/{len(shapes)})" + print(f"{'Status:':<22} {status}") + + return { + "correct": len(failures) == 0, + "num_correct": len(results), + "num_failed": len(failures), + "failures": failures, + "results": results, + } + + +def run_profile(shapes=None, warmup=50, iters=200, verbose=True): + if shapes is None: + shapes = PROFILE_SHAPES + if verbose: + print(f"Profile: {len(shapes)} config(s), {warmup} warmup, {iters} iter(s)") + + for cfg in shapes: + M, N, K = cfg x, w_up, w_down, dy = get_inputs(M, K, N) - # Warmup - for _ in range(3): - triton_op(M, N, K, x, w_up, w_down, dy) + y, h0, h1, a, g = ff_fused_gated_forward(x, w_up, w_down, ACTIVATION) + for _ in range(warmup): + ff_fused_gated_backward_triton(dy, x, w_up, w_down, h0, h1, a, g, ACTIVATION) torch.cuda.synchronize() - # One profiled run - triton_op(M, N, K, x, w_up, w_down, dy) + for _ in 
range(iters): + ff_fused_gated_backward_triton(dy, x, w_up, w_down, h0, h1, a, g, ACTIVATION) torch.cuda.synchronize() - print(f" M={M} N={N} K={K} done") - return 0 - - -# ── Main ─────────────────────────────────────────────────────────────────── -def main(): - default_iters = int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200")) - - parser = argparse.ArgumentParser(description="ff_backward test harness") - group = parser.add_mutually_exclusive_group(required=True) - group.add_argument("--correctness", action="store_true") - group.add_argument("--benchmark", action="store_true") - group.add_argument("--full-benchmark", action="store_true") - group.add_argument("--profile", action="store_true") - parser.add_argument("--iterations", type=int, default=default_iters) - parser.add_argument("--warmup", type=int, default=50) + if verbose: + print(f" {_label(cfg)} done") + del x, w_up, w_down, dy, y, h0, h1, a, g + torch.cuda.empty_cache() + + +def run_benchmark(shapes=None, warmup=50, iters=200, verbose=True): + if shapes is None: + shapes = HARNESS_SHAPES + + latencies = [] + + print(f"Running benchmark on {len(shapes)} shapes, {warmup} warmup, {iters} iterations each...") + if verbose: + print(f"{'Config':<22} {'Triton':>10}") + print("-" * 34) + + for cfg in shapes: + M, N, K = cfg + x, w_up, w_down, dy = get_inputs(M, K, N) + # Forward computed ONCE outside the timed loop; the candidate Triton + # backward consumes the cached intermediates h0/h1/a/g. + y, h0, h1, a, g = ff_fused_gated_forward(x, w_up, w_down, ACTIVATION) + + for _ in range(warmup): + ff_fused_gated_backward_triton(dy, x, w_up, w_down, h0, h1, a, g, ACTIVATION) + torch.cuda.synchronize() + + triton_times = [] + for _ in range(iters): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + ff_fused_gated_backward_triton(dy, x, w_up, w_down, h0, h1, a, g, ACTIVATION) + end.record() + torch.cuda.synchronize() + triton_times.append(start.elapsed_time(end)) + + triton_ms = sorted(triton_times)[len(triton_times) // 2] + latencies.append(triton_ms) + + if verbose: + print(f"{_label(cfg):<22} {triton_ms:>8.4f}ms", flush=True) + + del x, w_up, w_down, dy, y, h0, h1, a, g + torch.cuda.empty_cache() + + geomean_latency = math.exp(sum(math.log(l) for l in latencies) / len(latencies)) + + print("-" * 34) + print(f"{'Geometric mean latency:':<22} {geomean_latency:.4f} ms") + print(f"GEAK_SHAPES_USED={list(range(len(shapes)))}") + print(f"GEAK_RESULT_LATENCY_MS={geomean_latency:.4f}", flush=True) + + return {"geomean_latency_ms": geomean_latency, "latencies": latencies} + + +# ============================================================================ +# MAIN +# ============================================================================ + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FF Backward (SwiGLU) Test Harness") + parser.add_argument("--correctness", action="store_true", + help="Run correctness tests on HARNESS_SHAPES") + parser.add_argument("--profile", action="store_true", + help="Run minimal profiling workload") + parser.add_argument("--benchmark", action="store_true", + help="Run benchmark on HARNESS_SHAPES (25 uniformly sampled)") + parser.add_argument("--full-benchmark", action="store_true", + help="Run benchmark on ALL_SHAPES (complete set)") + parser.add_argument("--warmup", type=int, default=50, + help="Number of warmup iterations (default: 50)") + parser.add_argument("--iterations", type=int, + 
default=int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200")), + help="Number of benchmark iterations (default: GEAK_BENCHMARK_ITERATIONS or 200)") args = parser.parse_args() + print("=" * 62) + print("FF Backward (SwiGLU) Test Harness") + print("=" * 62) + if args.correctness: - indices = list(range(len(ALL_CONFIGS))) - sys.exit(run_correctness(ALL_CONFIGS, indices)) - elif args.benchmark: - indices = _pick(HARNESS_CONFIGS, 25) - sys.exit(run_benchmark(HARNESS_CONFIGS, indices, args.warmup, args.iterations)) - elif args.full_benchmark: - indices = list(range(len(ALL_CONFIGS))) - sys.exit(run_benchmark(ALL_CONFIGS, indices, args.warmup, args.iterations)) + print("\n[Correctness Mode]") + result = run_correctness(HARNESS_SHAPES) + sys.exit(0 if result["correct"] else 1) elif args.profile: - indices = list(range(len(PROFILE_CONFIGS))) - sys.exit(run_profile(PROFILE_CONFIGS, indices)) - + print("\n[Profile Mode]") + run_profile(PROFILE_SHAPES, warmup=args.warmup, iters=args.iterations) + elif args.full_benchmark: + print("\n[Full Benchmark Mode]") + run_benchmark(ALL_SHAPES, warmup=args.warmup, iters=args.iterations) + else: + print("\n[Benchmark Mode]") + run_benchmark(HARNESS_SHAPES, warmup=args.warmup, iters=args.iterations) -if __name__ == "__main__": - main() + print("=" * 62) diff --git a/tasks/triton2triton/geak_eval/L2/lean_atten_paged/kernel.py b/tasks/triton2triton/geak_eval/L2/lean_atten_paged/kernel.py index c434fba..a8fffda 100644 --- a/tasks/triton2triton/geak_eval/L2/lean_atten_paged/kernel.py +++ b/tasks/triton2triton/geak_eval/L2/lean_atten_paged/kernel.py @@ -15,8 +15,6 @@ from __future__ import annotations -import argparse -import math import random from typing import Sequence @@ -584,29 +582,6 @@ def _make_test_case( } -def torch_op( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - ref_indices, - n_ctx_q: int, - sm_scale: float, -): - ref_out = torch.empty_like(q, dtype=v.dtype) - for head_idx in range(q.shape[0]): - start_q = 0 - for batch_idx in range(len(ref_indices[head_idx])): - qb = q[head_idx, start_q : start_q + n_ctx_q, :] - idxs = ref_indices[head_idx][batch_idx] - kb = torch.index_select(k[head_idx], dim=0, index=idxs) - vb = torch.index_select(v[head_idx], dim=0, index=idxs) - p = torch.matmul(qb, kb.transpose(0, 1)) * sm_scale - p = torch.softmax(p.float(), dim=-1).to(q.dtype) - ref_out[head_idx, start_q : start_q + n_ctx_q, :] = torch.matmul(p, vb) - start_q += n_ctx_q - return ref_out - - # ============================================================================ # CONFIGS # ============================================================================ @@ -646,386 +621,3 @@ def torch_op( PROFILE_SHAPES = PROFILE_CONFIGS -# ============================================================================ -# TEST HARNESS -# ============================================================================ - - -def _run_single_correctness( - batch: int, - h: int, - n_ctx_q: int, - n_ctx: Sequence[int], - d: int, - total_programs: int, - dtype: torch.dtype, - block_m: int, - block_n: int, - waves_per_eu: int, - num_warps: int, -): - case = _make_test_case( - batch, - h, - n_ctx_q, - n_ctx, - d, - total_programs, - dtype, - block_m, - block_n, - waves_per_eu, - num_warps, - ) - - out_triton = persistent_lean_attention_paged( - q=case["q"], - k=case["k"], - v=case["v"], - kv_block_tables=case["kv_block_tables"], - Mp=case["Mp"], - Lp=case["Lp"], - Op=case["Op"], - locks=case["locks"], - batch_num_block_n=case["batch_num_block_n"], - 
-        total_programs=total_programs,
-        BLOCK_M=block_m,
-        BLOCK_N=block_n,
-        batch_size=batch,
-        sm_scale=case["sm_scale"],
-        num_warps=case["num_warps"],
-        waves_per_eu=case["waves_per_eu"],
-    )
-    out_torch = torch_op(
-        case["q"],
-        case["k"],
-        case["v"],
-        case["ref_indices"],
-        n_ctx_q,
-        case["sm_scale"],
-    )
-    torch.testing.assert_close(out_torch, out_triton, atol=ATOL, rtol=RTOL)
-
-
-def run_correctness(configs=None, verbose=True):
-    if configs is None:
-        configs = CORRECTNESS_CONFIGS
-    print(f"Running correctness on {len(configs)} configs...")
-    results = []
-    failures = []
-
-    for cfg in configs:
-        batch, h, n_ctx_q, n_ctx, d, total_programs, dtype, block_m, block_n, waves_per_eu, num_warps = cfg
-        tag = _config_tag(
-            batch, h, n_ctx_q, n_ctx, d, total_programs, block_m, block_n, waves_per_eu, num_warps
-        )
-        try:
-            _run_single_correctness(*cfg)
-            results.append(tag)
-            if verbose:
-                print(f" PASS: {tag}")
-        except Exception as exc:
-            failures.append({"config": tag, "error": str(exc)})
-            if verbose:
-                print(f" FAIL: {tag} - {str(exc)[:120]}")
-        torch.cuda.empty_cache()
-
-    if verbose:
-        print("-" * 70)
-        status = "ALL PASS" if not failures else f"FAILED ({len(failures)}/{len(configs)})"
-        print(f"{'Status:':<22} {status}")
-
-    return {
-        "correct": len(failures) == 0,
-        "num_correct": len(results),
-        "num_failed": len(failures),
-        "failures": failures,
-    }
-
-
-def run_profile(configs=None, warmup=50, iters=200, verbose=True):
-    if configs is None:
-        configs = PROFILE_CONFIGS
-    if verbose:
-        print(f"Profile: {len(configs)} config(s), {warmup} warmup, {iters} iter(s)")
-
-    for cfg in configs:
-        batch, h, n_ctx_q, n_ctx, d, total_programs, dtype, block_m, block_n, waves_per_eu, num_warps = cfg
-        case = _make_test_case(
-            batch,
-            h,
-            n_ctx_q,
-            n_ctx,
-            d,
-            total_programs,
-            dtype,
-            block_m,
-            block_n,
-            waves_per_eu,
-            num_warps,
-        )
-        for _ in range(warmup):
-            persistent_lean_attention_paged(
-                q=case["q"],
-                k=case["k"],
-                v=case["v"],
-                kv_block_tables=case["kv_block_tables"],
-                Mp=case["Mp"],
-                Lp=case["Lp"],
-                Op=case["Op"],
-                locks=case["locks"],
-                batch_num_block_n=case["batch_num_block_n"],
-                total_programs=total_programs,
-                BLOCK_M=block_m,
-                BLOCK_N=block_n,
-                batch_size=batch,
-                sm_scale=case["sm_scale"],
-                num_warps=case["num_warps"],
-                waves_per_eu=case["waves_per_eu"],
-            )
-        torch.cuda.synchronize()
-        for _ in range(iters):
-            persistent_lean_attention_paged(
-                q=case["q"],
-                k=case["k"],
-                v=case["v"],
-                kv_block_tables=case["kv_block_tables"],
-                Mp=case["Mp"],
-                Lp=case["Lp"],
-                Op=case["Op"],
-                locks=case["locks"],
-                batch_num_block_n=case["batch_num_block_n"],
-                total_programs=total_programs,
-                BLOCK_M=block_m,
-                BLOCK_N=block_n,
-                batch_size=batch,
-                sm_scale=case["sm_scale"],
-                num_warps=case["num_warps"],
-                waves_per_eu=case["waves_per_eu"],
-            )
-        torch.cuda.synchronize()
-        if verbose:
-            print(f" {_config_tag(batch, h, n_ctx_q, n_ctx, d, total_programs, block_m, block_n, waves_per_eu, num_warps)} done")
-        torch.cuda.empty_cache()
-
-
-def run_benchmark(configs=None, warmup=50, iters=200, verbose=True, baseline_fn=None):
-    """Benchmark kernel vs reference. Uses baseline_fn (Triton) when provided; else torch_op (PyTorch)."""
-    if configs is None:
-        configs = HARNESS_CONFIGS
-
-    latencies = []
-    speedups = []
-    results = []
-    ref_label = "baseline_triton" if baseline_fn is not None else "PyTorch"
-
-    print(
-        f"Running benchmark on {len(configs)} configs, {warmup} warmup, {iters} iterations each..."
-    )
-    print(f" Comparing kernel vs {ref_label}")
-    if verbose:
-        print(f"{'Config':<72} {'Ref':>10} {'Triton':>10} {'Speedup':>10}")
-        print("-" * 108)
-
-    for cfg in configs:
-        batch, h, n_ctx_q, n_ctx, d, total_programs, dtype, block_m, block_n, waves_per_eu, num_warps = cfg
-        case = _make_test_case(
-            batch,
-            h,
-            n_ctx_q,
-            n_ctx,
-            d,
-            total_programs,
-            dtype,
-            block_m,
-            block_n,
-            waves_per_eu,
-            num_warps,
-        )
-        tag = _config_tag(
-            batch, h, n_ctx_q, n_ctx, d, total_programs, block_m, block_n, waves_per_eu, num_warps
-        )
-
-        for _ in range(warmup):
-            persistent_lean_attention_paged(
-                q=case["q"],
-                k=case["k"],
-                v=case["v"],
-                kv_block_tables=case["kv_block_tables"],
-                Mp=case["Mp"],
-                Lp=case["Lp"],
-                Op=case["Op"],
-                locks=case["locks"],
-                batch_num_block_n=case["batch_num_block_n"],
-                total_programs=total_programs,
-                BLOCK_M=block_m,
-                BLOCK_N=block_n,
-                batch_size=batch,
-                sm_scale=case["sm_scale"],
-                num_warps=case["num_warps"],
-                waves_per_eu=case["waves_per_eu"],
-            )
-        torch.cuda.synchronize()
-
-        triton_times = []
-        for _ in range(iters):
-            start = torch.cuda.Event(enable_timing=True)
-            end = torch.cuda.Event(enable_timing=True)
-            start.record()
-            persistent_lean_attention_paged(
-                q=case["q"],
-                k=case["k"],
-                v=case["v"],
-                kv_block_tables=case["kv_block_tables"],
-                Mp=case["Mp"],
-                Lp=case["Lp"],
-                Op=case["Op"],
-                locks=case["locks"],
-                batch_num_block_n=case["batch_num_block_n"],
-                total_programs=total_programs,
-                BLOCK_M=block_m,
-                BLOCK_N=block_n,
-                batch_size=batch,
-                sm_scale=case["sm_scale"],
-                num_warps=case["num_warps"],
-                waves_per_eu=case["waves_per_eu"],
-            )
-            end.record()
-            torch.cuda.synchronize()
-            triton_times.append(start.elapsed_time(end))
-        triton_ms = sorted(triton_times)[len(triton_times) // 2]
-
-        ref_times = []
-        for _ in range(iters):
-            start = torch.cuda.Event(enable_timing=True)
-            end = torch.cuda.Event(enable_timing=True)
-            start.record()
-            if baseline_fn is not None:
-                baseline_fn(
-                    q=case["q"],
-                    k=case["k"],
-                    v=case["v"],
-                    kv_block_tables=case["kv_block_tables"],
-                    Mp=case["Mp"],
-                    Lp=case["Lp"],
-                    Op=case["Op"],
-                    locks=case["locks"],
-                    batch_num_block_n=case["batch_num_block_n"],
-                    total_programs=total_programs,
-                    BLOCK_M=block_m,
-                    BLOCK_N=block_n,
-                    batch_size=batch,
-                    sm_scale=case["sm_scale"],
-                    num_warps=case["num_warps"],
-                    waves_per_eu=case["waves_per_eu"],
-                )
-            else:
-                torch_op(
-                    case["q"],
-                    case["k"],
-                    case["v"],
-                    case["ref_indices"],
-                    n_ctx_q,
-                    case["sm_scale"],
-                )
-            end.record()
-            torch.cuda.synchronize()
-            ref_times.append(start.elapsed_time(end))
-        ref_ms = sorted(ref_times)[len(ref_times) // 2]
-
-        speedup = ref_ms / triton_ms if triton_ms > 0 else 1.0
-        latencies.append(triton_ms)
-        speedups.append(speedup)
-        results.append(
-            {
-                "config": tag,
-                "ref_ms": ref_ms,
-                "triton_ms": triton_ms,
-                "speedup": speedup,
-            }
-        )
-
-        if verbose:
-            marker = " *" if speedup > 1.0 else ""
-            print(f"{tag:<72} {ref_ms:>8.4f}ms {triton_ms:>8.4f}ms {speedup:>8.2f}x{marker}")
-
-        torch.cuda.empty_cache()
-
-    geomean_latency = math.exp(sum(math.log(t) for t in latencies) / len(latencies))
-    geomean_speedup = math.exp(sum(math.log(s) for s in speedups) / len(speedups))
-
-    if verbose:
-        print("-" * 108)
-        print(f"{'Geometric mean latency:':<72} {geomean_latency:.4f} ms")
-        print(f"{'Geometric mean speedup:':<72} {geomean_speedup:.2f}x")
-        print(f"GEAK_RESULT_LATENCY_MS={geomean_latency:.4f}")
-        print(f"GEAK_RESULT_GEOMEAN_SPEEDUP={geomean_speedup:.4f}")
-
-    return {
-        "geomean_latency_ms": geomean_latency,
-        "geomean_speedup": geomean_speedup,
-        "results": results,
-    }
-
-
-# ============================================================================
-# MAIN
-# ============================================================================
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Lean Attention + Paged Attention Kernel Test Harness"
-    )
-    parser.add_argument(
-        "--correctness",
-        action="store_true",
-        help="Run correctness tests on correctness configs",
-    )
-    parser.add_argument(
-        "--profile",
-        action="store_true",
-        help="Run minimal profiling workload",
-    )
-    parser.add_argument(
-        "--benchmark",
-        action="store_true",
-        help="Run benchmark on HARNESS_CONFIGS",
-    )
-    parser.add_argument(
-        "--full-benchmark",
-        action="store_true",
-        help="Run benchmark on ALL_CONFIGS",
-    )
-    parser.add_argument(
-        "--warmup",
-        type=int,
-        default=50,
-        help="Number of warmup iterations (default: 50)",
-    )
-    parser.add_argument(
-        "--iterations",
-        type=int,
-        default=200,
-        help="Number of benchmark iterations (default: 200)",
-    )
-    args = parser.parse_args()
-
-    print("=" * 70)
-    print("Lean Attention + Paged Attention Kernel Test Harness")
-    print("=" * 70)
-
-    if args.correctness:
-        print("\n[Correctness Mode]")
-        run_correctness(CORRECTNESS_CONFIGS)
-    elif args.profile:
-        print("\n[Profile Mode]")
-        run_profile(PROFILE_CONFIGS, warmup=args.warmup, iters=args.iterations)
-    elif args.full_benchmark:
-        print("\n[Full Benchmark Mode]")
-        run_benchmark(ALL_CONFIGS, warmup=args.warmup, iters=args.iterations)
-    else:
-        print("\n[Benchmark Mode]")
-        run_benchmark(HARNESS_CONFIGS, warmup=args.warmup, iters=args.iterations)
-
-    print("=" * 70)
diff --git a/tasks/triton2triton/geak_eval/L2/lean_atten_paged/test_kernel_harness.py b/tasks/triton2triton/geak_eval/L2/lean_atten_paged/test_kernel_harness.py
index 71c976b..f9e3816 100644
--- a/tasks/triton2triton/geak_eval/L2/lean_atten_paged/test_kernel_harness.py
+++ b/tasks/triton2triton/geak_eval/L2/lean_atten_paged/test_kernel_harness.py
@@ -1,222 +1,232 @@
 #!/usr/bin/env python3
 """
-Lean Attention + Paged Attention kernel test harness.
-
-Wraps the built-in harness in kernel.py to ensure:
-- --correctness exits non-zero on failure
-- --iterations reads GEAK_BENCHMARK_ITERATIONS env var
-- --benchmark uses HARNESS_CONFIGS
-- --full-benchmark uses ALL_CONFIGS
-- --profile uses PROFILE_CONFIGS
-- GEAK_RESULT_LATENCY_MS is always the LAST line of benchmark output
-
-Modes:
-  --correctness    : validate kernel against torch reference
-  --profile        : run kernel once per PROFILE_SHAPES for profiler capture
-  --benchmark      : benchmark on HARNESS_CONFIGS, print GEAK_RESULT_LATENCY_MS
-  --full-benchmark : benchmark on ALL_CONFIGS, print GEAK_RESULT_LATENCY_MS
-  --iterations N   : override iteration count (default from GEAK_BENCHMARK_ITERATIONS or 20)
-"""
-from __future__ import annotations
+Test harness for the persistent lean attention + paged attention Triton kernel.

-import argparse
-import os
-import sys
+Modes: --correctness, --profile, --benchmark, --full-benchmark

-# GEAK materialized harness bootstrap
-import importlib.util
+Per-task divergence from the standard dispatch convention: --correctness
+uses CORRECTNESS_CONFIGS (a small fixed correctness-focused set), not
+HARNESS_CONFIGS. CORRECTNESS_CONFIGS and ALL_CONFIGS are different sets in
+this task and the project plan calls for preserving the kernel.py-side
+convention here.
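+
+Illustrative invocations (these use only the flags defined in this file):
+    python test_kernel_harness.py --correctness
+    python test_kernel_harness.py --benchmark --warmup 50 --iterations 200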
+""" +import argparse +import math import os import sys -import types -from pathlib import Path - -def _find_baseline_kernel_dir(): - """Find preprocess dir (has benchmark_baseline.txt) by walking up from GEAK_WORK_DIR.""" - work = os.environ.get("GEAK_WORK_DIR", "").strip() - if not work: - return None - d = Path(work).resolve() - for _ in range(10): - if d is None or not d.exists(): - break - bb = d / "benchmark_baseline.txt" - if bb.is_file(): - return str(d) - d = d.parent - return None - -def _load_baseline_triton(baseline_dir, module_alias, entry_name): - """Load kernel from baseline_dir. Returns callable or None.""" - entry_file = Path(baseline_dir) / "kernel.py" - if not entry_file.is_file(): - return None - if baseline_dir not in sys.path: - sys.path.insert(0, baseline_dir) - spec = importlib.util.spec_from_file_location(module_alias, entry_file) - if spec is None or spec.loader is None: - return None - module = importlib.util.module_from_spec(spec) - sys.modules[module_alias] = module - try: - spec.loader.exec_module(module) - return getattr(module, entry_name, None) - except Exception: - return None - -def _resolve_geak_kernel_dir(): - candidates = [] - work_dir = os.environ.get("GEAK_WORK_DIR", "").strip() - if work_dir: - candidates.append(work_dir) - repo_root = os.environ.get("GEAK_REPO_ROOT", "").strip() - rel_kernel_dir = '.' - if repo_root and rel_kernel_dir: - candidates.append(os.path.join(repo_root, rel_kernel_dir)) - original_kernel_dir = os.path.dirname(os.path.abspath(__file__)) - if original_kernel_dir: - candidates.append(original_kernel_dir) - for candidate in candidates: - if candidate and os.path.isfile(os.path.join(candidate, "kernel.py")): - return candidate - return original_kernel_dir or os.getcwd() - -def _ensure_geak_package(module_name): - parts = module_name.split(".") - for idx in range(1, len(parts)): - prefix = ".".join(parts[:idx]) - if prefix in sys.modules: - continue - pkg = types.ModuleType(prefix) - pkg.__path__ = [] - sys.modules[prefix] = pkg - -def _ensure_geak_aiter_fp8_dtype(module): - fp8_value = getattr(module, "fp8_dtype", None) - if fp8_value is None: - return - aiter_mod = sys.modules.get("aiter") - if aiter_mod is None: - try: - import aiter as aiter_mod - except Exception: - _ensure_geak_package("aiter") - aiter_mod = sys.modules.get("aiter") - if aiter_mod is None: - return - dtypes_obj = getattr(aiter_mod, "dtypes", None) - if dtypes_obj is None: - dtypes_obj = types.SimpleNamespace() - setattr(aiter_mod, "dtypes", dtypes_obj) - if getattr(dtypes_obj, "fp8", None) is None: - setattr(dtypes_obj, "fp8", fp8_value) - -def _register_geak_aliases(kernel_dir): - aliases = ['lean_atten_paged'] - entry_file = os.path.join(kernel_dir, "kernel.py") - if not os.path.isfile(entry_file): - return - for alias in aliases: - if alias in sys.modules: - continue - _ensure_geak_package(alias) - spec = importlib.util.spec_from_file_location(alias, entry_file) - if spec is None or spec.loader is None: - continue - module = importlib.util.module_from_spec(spec) - sys.modules[alias] = module - spec.loader.exec_module(module) - _ensure_geak_aiter_fp8_dtype(module) - -_KERNEL_DIR = _resolve_geak_kernel_dir() -if _KERNEL_DIR and _KERNEL_DIR not in sys.path: - sys.path.insert(0, _KERNEL_DIR) -_register_geak_aliases(_KERNEL_DIR) +import torch from kernel import ( - run_correctness, - run_profile, - run_benchmark, - CORRECTNESS_CONFIGS, - HARNESS_CONFIGS, - ALL_CONFIGS, - PROFILE_CONFIGS, + persistent_lean_attention_paged, + _make_test_case, _config_tag, + 
+    CORRECTNESS_CONFIGS, ALL_CONFIGS, HARNESS_CONFIGS, PROFILE_CONFIGS,
+    RTOL, ATOL,
 )


-def _get_baseline_fn():
-    """Resolve baseline Triton kernel when in patch-eval mode."""
-    baseline_dir = _find_baseline_kernel_dir()
-    kernel_dir = _resolve_geak_kernel_dir()
-    if baseline_dir and baseline_dir != kernel_dir:
-        return _load_baseline_triton(baseline_dir, "baseline_lean_atten", "persistent_lean_attention_paged")
-    return None
+# ============================================================================
+# SHAPE SUBSETS
+# ============================================================================
+
+# kernel.py already pre-samples HARNESS_CONFIGS (25) and PROFILE_CONFIGS (5)
+# from ALL_CONFIGS, so we just re-export under the standard names here.
+ALL_SHAPES = ALL_CONFIGS
+HARNESS_SHAPES = HARNESS_CONFIGS
+PROFILE_SHAPES = PROFILE_CONFIGS
+
+
+# ============================================================================
+# PYTORCH REFERENCE (moved from kernel.py; correctness-only)
+# ============================================================================
+
+def torch_op(q, k, v, ref_indices, n_ctx_q, sm_scale):
+    ref_out = torch.empty_like(q, dtype=v.dtype)
+    for head_idx in range(q.shape[0]):
+        start_q = 0
+        for batch_idx in range(len(ref_indices[head_idx])):
+            qb = q[head_idx, start_q : start_q + n_ctx_q, :]
+            idxs = ref_indices[head_idx][batch_idx]
+            kb = torch.index_select(k[head_idx], dim=0, index=idxs)
+            vb = torch.index_select(v[head_idx], dim=0, index=idxs)
+            p = torch.matmul(qb, kb.transpose(0, 1)) * sm_scale
+            p = torch.softmax(p.float(), dim=-1).to(q.dtype)
+            ref_out[head_idx, start_q : start_q + n_ctx_q, :] = torch.matmul(p, vb)
+            start_q += n_ctx_q
+    return ref_out
+
+
+# ============================================================================
+# Helpers
+# ============================================================================
+
+def _call_triton(case, cfg):
+    batch, h, n_ctx_q, n_ctx, d, total_programs, dtype, block_m, block_n, waves_per_eu, num_warps = cfg
+    return persistent_lean_attention_paged(
+        q=case["q"], k=case["k"], v=case["v"],
+        kv_block_tables=case["kv_block_tables"],
+        Mp=case["Mp"], Lp=case["Lp"], Op=case["Op"], locks=case["locks"],
+        batch_num_block_n=case["batch_num_block_n"],
+        total_programs=total_programs,
+        BLOCK_M=block_m, BLOCK_N=block_n,
+        batch_size=batch,
+        sm_scale=case["sm_scale"],
+        num_warps=case["num_warps"],
+        waves_per_eu=case["waves_per_eu"],
+    )


-def main():
-    default_iters = int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200"))
+# ============================================================================
+# TEST HARNESS
+# ============================================================================

-    parser = argparse.ArgumentParser(
-        description="Lean Attention + Paged Attention Kernel Test Harness"
-    )
+def run_correctness(shapes=None, verbose=True):
+    if shapes is None:
+        shapes = CORRECTNESS_CONFIGS
+    if verbose:
+        print(f"Running correctness on {len(shapes)} shapes...")
+
+    results, failures = [], []
+
+    for cfg in shapes:
+        batch, h, n_ctx_q, n_ctx, d, total_programs, dtype, block_m, block_n, waves_per_eu, num_warps = cfg
+        tag = _config_tag(batch, h, n_ctx_q, n_ctx, d, total_programs, block_m, block_n, waves_per_eu, num_warps)
+        try:
+            case = _make_test_case(*cfg)
+            out_triton = _call_triton(case, cfg)
+            out_torch = torch_op(case["q"], case["k"], case["v"],
+                                 case["ref_indices"], n_ctx_q, case["sm_scale"])
+            torch.cuda.synchronize()
+
+            torch.testing.assert_close(out_torch, out_triton, atol=ATOL, rtol=RTOL)
results.append({"config": tag, "correct": True}) + if verbose: + print(f" PASS: {tag}") + except Exception as exc: + failures.append({"config": tag, "error": str(exc)}) + if verbose: + print(f" FAIL: {tag} - {str(exc)[:120]}") + torch.cuda.empty_cache() + + if verbose: + print("-" * 70) + status = "ALL PASS" if not failures else f"FAILED ({len(failures)}/{len(shapes)})" + print(f"{'Status:':<22} {status}") + + return { + "correct": len(failures) == 0, + "num_correct": len(results), + "num_failed": len(failures), + "failures": failures, + "results": results, + } + + +def run_profile(shapes=None, warmup=50, iters=200, verbose=True): + if shapes is None: + shapes = PROFILE_SHAPES + if verbose: + print(f"Profile: {len(shapes)} config(s), {warmup} warmup, {iters} iter(s)") + + for cfg in shapes: + case = _make_test_case(*cfg) + for _ in range(warmup): + _call_triton(case, cfg) + torch.cuda.synchronize() + for _ in range(iters): + _call_triton(case, cfg) + torch.cuda.synchronize() + if verbose: + batch, h, n_ctx_q, n_ctx, d, total_programs, dtype, block_m, block_n, waves_per_eu, num_warps = cfg + tag = _config_tag(batch, h, n_ctx_q, n_ctx, d, total_programs, block_m, block_n, waves_per_eu, num_warps) + print(f" {tag} done") + torch.cuda.empty_cache() + + +def run_benchmark(shapes=None, warmup=50, iters=200, verbose=True): + if shapes is None: + shapes = HARNESS_SHAPES + + latencies = [] + + print(f"Running benchmark on {len(shapes)} shapes, {warmup} warmup, {iters} iterations each...") + if verbose: + print(f"{'Config':<72} {'Triton':>10}") + print("-" * 84) + + for cfg in shapes: + batch, h, n_ctx_q, n_ctx, d, total_programs, dtype, block_m, block_n, waves_per_eu, num_warps = cfg + tag = _config_tag(batch, h, n_ctx_q, n_ctx, d, total_programs, block_m, block_n, waves_per_eu, num_warps) + case = _make_test_case(*cfg) + + for _ in range(warmup): + _call_triton(case, cfg) + torch.cuda.synchronize() + + triton_times = [] + for _ in range(iters): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + _call_triton(case, cfg) + end.record() + torch.cuda.synchronize() + triton_times.append(start.elapsed_time(end)) + + triton_ms = sorted(triton_times)[len(triton_times) // 2] + latencies.append(triton_ms) + + if verbose: + print(f"{tag:<72} {triton_ms:>8.4f}ms", flush=True) + + torch.cuda.empty_cache() + + geomean_latency = math.exp(sum(math.log(l) for l in latencies) / len(latencies)) + + print("-" * 84) + print(f"{'Geometric mean latency:':<72} {geomean_latency:.4f} ms") + print(f"GEAK_SHAPES_USED={list(range(len(shapes)))}") + print(f"GEAK_RESULT_LATENCY_MS={geomean_latency:.4f}", flush=True) + + return {"geomean_latency_ms": geomean_latency, "latencies": latencies} + + +# ============================================================================ +# MAIN +# ============================================================================ + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Lean Attention (Paged) Test Harness") parser.add_argument("--correctness", action="store_true", - help="Run correctness tests") + help="Run correctness tests on CORRECTNESS_CONFIGS") parser.add_argument("--profile", action="store_true", help="Run minimal profiling workload") parser.add_argument("--benchmark", action="store_true", - help="Run benchmark on HARNESS_CONFIGS") + help="Run benchmark on HARNESS_SHAPES (25 uniformly sampled)") parser.add_argument("--full-benchmark", action="store_true", - help="Run benchmark on ALL_CONFIGS") - 
parser.add_argument("--iterations", type=int, default=default_iters, - help=f"Number of benchmark iterations (default: {default_iters})") + help="Run benchmark on ALL_SHAPES (complete set)") parser.add_argument("--warmup", type=int, default=50, help="Number of warmup iterations (default: 50)") + parser.add_argument("--iterations", type=int, + default=int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200")), + help="Number of benchmark iterations (default: GEAK_BENCHMARK_ITERATIONS or 200)") args = parser.parse_args() - if args.correctness: - print("=" * 70) - print("[Correctness Mode]") - print("=" * 70) - result = run_correctness(CORRECTNESS_CONFIGS, verbose=True) - if not result["correct"]: - print(f"\nFAILED: {result['num_failed']} correctness test(s) failed") - sys.exit(1) - print("\nAll correctness tests PASSED") - sys.exit(0) + print("=" * 70) + print("Lean Attention (Paged) Test Harness") + print("=" * 70) + if args.correctness: + print("\n[Correctness Mode]") + result = run_correctness(CORRECTNESS_CONFIGS) + sys.exit(0 if result["correct"] else 1) elif args.profile: - print("=" * 70) - print("[Profile Mode]") - print("=" * 70) - run_profile(PROFILE_CONFIGS, warmup=args.warmup, iters=args.iterations, - verbose=True) - sys.exit(0) - + print("\n[Profile Mode]") + run_profile(PROFILE_SHAPES, warmup=args.warmup, iters=args.iterations) elif args.full_benchmark: - print("=" * 70) - print("[Full Benchmark Mode]") - print("=" * 70) - baseline_fn = _get_baseline_fn() - result = run_benchmark(ALL_CONFIGS, warmup=args.warmup, - iters=args.iterations, verbose=True, baseline_fn=baseline_fn) - # Ensure GEAK_RESULT_LATENCY_MS is the LAST line of output - print(f"GEAK_RESULT_LATENCY_MS={result['geomean_latency_ms']:.4f}") - sys.exit(0) - - elif args.benchmark: - print("=" * 70) - print("[Benchmark Mode]") - print("=" * 70) - baseline_fn = _get_baseline_fn() - result = run_benchmark(HARNESS_CONFIGS, warmup=args.warmup, - iters=args.iterations, verbose=True, baseline_fn=baseline_fn) - # Ensure GEAK_RESULT_LATENCY_MS is the LAST line of output - print(f"GEAK_RESULT_LATENCY_MS={result['geomean_latency_ms']:.4f}") - sys.exit(0) - + print("\n[Full Benchmark Mode]") + run_benchmark(ALL_SHAPES, warmup=args.warmup, iters=args.iterations) else: - parser.print_help() - sys.exit(1) + print("\n[Benchmark Mode]") + run_benchmark(HARNESS_SHAPES, warmup=args.warmup, iters=args.iterations) - -if __name__ == "__main__": - main() + print("=" * 70) diff --git a/tasks/triton2triton/geak_eval/L3/gemm_a16wfp4/test_kernel_harness.py b/tasks/triton2triton/geak_eval/L3/gemm_a16wfp4/test_kernel_harness.py index 10d3793..bd2f672 100644 --- a/tasks/triton2triton/geak_eval/L3/gemm_a16wfp4/test_kernel_harness.py +++ b/tasks/triton2triton/geak_eval/L3/gemm_a16wfp4/test_kernel_harness.py @@ -1,20 +1,37 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: MIT -# Test harness for gemm_a16wfp4 kernel +""" +Test harness for the gemm_a16wfp4 (MXFP4) Triton kernel. +Modes: --correctness, --profile, --benchmark, --full-benchmark +""" import argparse +import math import os import sys -import time import torch -# Import kernel and utilities from kernel import gemm_a16wfp4, is_fp4_avail -# Note this is specified by the HW and cannot be changed. + +# ============================================================================ +# CONSTANTS +# ============================================================================ + +# Specified by the HW and cannot be changed. 
 SCALE_GROUP_SIZE = 32

-# ALL_SHAPES: All unique shapes from test file, sorted by total element count
+DTYPE = torch.bfloat16
+
+# Tolerance defaults — match the previous in-harness assert_close (rtol=1e-2, atol=1e-2).
+RTOL, ATOL = 1e-2, 1e-2
+
+
+# ============================================================================
+# SHAPE LISTS
+# ============================================================================
+
+# ALL_SHAPES: All unique shapes from test file, sorted by total element count.
 ALL_SHAPES = [
     (1, 8192, 1024),
     (1, 1280, 8192),
@@ -75,28 +92,27 @@
     # (9728, 8192, 65536),  # Too large, may cause OOM
 ]

-# HARNESS_SHAPES: use ALL shapes so task-local and verified benchmarks match
-HARNESS_SHAPES = ALL_SHAPES
+# HARNESS_SHAPES: 25 uniformly sampled shapes from ALL_SHAPES.
+_n_all = len(ALL_SHAPES)
+if _n_all <= 25:
+    HARNESS_SHAPES = ALL_SHAPES
+else:
+    _harness_indices = [int(round(i * (_n_all - 1) / 24)) for i in range(25)]
+    HARNESS_SHAPES = [ALL_SHAPES[i] for i in _harness_indices]

-# PROFILE_SHAPES: 5 evenly-spaced shapes for profiling
+# PROFILE_SHAPES: 5 evenly-spaced shapes for profiling.
 PROFILE_SHAPES = [
-    (1, 8192, 1024),     # smallest
-    (32, 7168, 2048),    # small-medium
-    (256, 8192, 1024),   # medium
-    (2048, 2048, 2048),  # medium-large
-    (4096, 4096, 4096),  # large
+    (1, 8192, 1024),
+    (32, 7168, 2048),
+    (256, 8192, 1024),
+    (2048, 2048, 2048),
+    (4096, 4096, 4096),
 ]

-def shuffle_scales(scales: torch.Tensor):
-    """Shuffle scales for preshuffle kernel."""
-    scales_shuffled = scales.clone()
-    sm, sn = scales_shuffled.shape
-    scales_shuffled = scales_shuffled.view(sm // 32, 2, 16, sn // 8, 2, 4, 1)
-    scales_shuffled = scales_shuffled.permute(0, 3, 5, 2, 4, 1, 6).contiguous()
-    scales_shuffled = scales_shuffled.view(sm // 32, sn * 32)
-    return scales_shuffled
-
+# ============================================================================
+# PYTORCH REFERENCE (correctness-only)
+# ============================================================================

 def mxfp4_to_f32(x):
     """Convert MXFP4 packed uint8 to float32."""
@@ -118,15 +134,30 @@ def e8m0_to_f32(x):
     return x_f32

-def generate_inputs(M: int, N: int, K: int, dtype=torch.bfloat16):
+def run_torch_reference(x, w, w_scales, dtype):
+    """Compute reference output using PyTorch."""
+    x_f32 = x.to(torch.float32)
+    w_f32 = mxfp4_to_f32(w)
+    w_scales_expanded = w_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=-1).to(torch.float32)
+    w_scales_f32 = e8m0_to_f32(w_scales_expanded)
+    assert w_f32.shape == w_scales_f32.shape
+    w_f32 = w_f32 * w_scales_f32
+    return torch.mm(x_f32, w_f32.T).to(dtype)
+
+
+# ============================================================================
+# INPUT GENERATION
+# ============================================================================
+
+def generate_inputs(M, N, K, dtype=DTYPE):
     """Generate inputs for gemm_a16wfp4 kernel."""
     torch.manual_seed(42)
-
-    # Generate x (bf16 input) - TN layout only
+
+    # Generate x (bf16 input) — TN layout only
     x_low = torch.randint(0, 16, (M, K // 2), dtype=torch.uint8, device="cuda")
     x_high = torch.randint(0, 16, (M, K // 2), dtype=torch.uint8, device="cuda")
     x_uint8 = x_low | x_high << 4
-
+
     # Generate x_scales and convert x to bf16
     x_scales = torch.randint(124, 128, (K // SCALE_GROUP_SIZE, M), dtype=torch.uint8, device="cuda").T
     x_f32 = mxfp4_to_f32(x_uint8)
@@ -134,152 +165,187 @@ def generate_inputs(M: int, N: int, K: int, dtype=torch.bfloat16):
     x_scales_f32 = e8m0_to_f32(x_scales_expanded)
     x_f32 = x_f32 * x_scales_f32
     x = x_f32.to(dtype)
-
-    # Generate w (fp4 weights) - TN layout only
+
+    # Generate w (fp4 weights) — TN layout only
     w_low = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda")
     w_high = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda")
     w = w_low | w_high << 4
-
+
     # Generate w_scales
     w_scales = torch.randint(124, 128, (K // SCALE_GROUP_SIZE, N), dtype=torch.uint8, device="cuda").T
-
-    # Non-preshuffled deterministic path only
+
+    # Non-preshuffled deterministic path only.
     return x, w, w, w_scales, w_scales

-def run_torch_reference(x, w, w_scales, dtype):
-    """Compute reference output using PyTorch."""
-    x_f32 = x.to(torch.float32)
-    w_f32 = mxfp4_to_f32(w)
-    w_scales_expanded = w_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=-1).to(torch.float32)
-    w_scales_f32 = e8m0_to_f32(w_scales_expanded)
-    assert w_f32.shape == w_scales_f32.shape
-    w_f32 = w_f32 * w_scales_f32
-    return torch.mm(x_f32, w_f32.T).to(dtype)
+# ============================================================================
+# TEST HARNESS
+# ============================================================================
+
+def _label(cfg):
+    M, N, K = cfg
+    return f"({M}, {N}, {K})"

-def run_correctness(shapes):
-    """Run correctness tests on given shapes."""
+def run_correctness(shapes=None, verbose=True):
+    if shapes is None:
+        shapes = HARNESS_SHAPES
     if not is_fp4_avail():
         print("MXFP4 not supported on this architecture, skipping correctness tests")
-        return True
-
-    print(f"Running correctness tests on {len(shapes)} shapes...")
-    all_passed = True
-
-    for i, (M, N, K) in enumerate(shapes):
-        torch.cuda.empty_cache()
-        dtype = torch.bfloat16
-
+        return {"correct": True, "num_correct": 0, "num_failed": 0,
+                "failures": [], "results": [], "skipped": True}
+
+    if verbose:
+        print(f"Running correctness on {len(shapes)} shapes...")
+
+    results, failures = [], []
+
+    for cfg in shapes:
+        M, N, K = cfg
         try:
-            x, w, w_kernel, w_scales, w_scales_kernel = generate_inputs(M, N, K, dtype=dtype)
-
-            # Run kernel
-            y = gemm_a16wfp4(x, w_kernel, w_scales_kernel, atomic_add=False, dtype=dtype)
-
-            # Run reference
-            y_ref = run_torch_reference(x, w, w_scales, dtype)
-
-            # Compare
-            torch.testing.assert_close(y, y_ref, rtol=1e-2, atol=1e-2)
-            print(f" [{i+1}/{len(shapes)}] Shape ({M}, {N}, {K}): PASSED")
+            x, w, w_kernel, w_scales, w_scales_kernel = generate_inputs(M, N, K, DTYPE)
+            y = gemm_a16wfp4(x, w_kernel, w_scales_kernel, atomic_add=False, dtype=DTYPE)
+            y_ref = run_torch_reference(x, w, w_scales, DTYPE)
+            torch.cuda.synchronize()
+
+            torch.testing.assert_close(y, y_ref, atol=ATOL, rtol=RTOL)
+            results.append({"config": cfg, "correct": True})
+            if verbose:
+                print(f" PASS: {_label(cfg)}")
+            del x, w, w_kernel, w_scales, w_scales_kernel, y, y_ref
+            torch.cuda.empty_cache()
         except Exception as e:
-            print(f" [{i+1}/{len(shapes)}] Shape ({M}, {N}, {K}): FAILED - {e}")
-            all_passed = False
-
-    return all_passed
+            failures.append({"config": cfg, "error": str(e)})
+            if verbose:
+                print(f" FAIL: {_label(cfg)} - {str(e)[:80]}")
+
+    if verbose:
+        print("-" * 62)
+        status = "ALL PASS" if not failures else f"FAILED ({len(failures)}/{len(shapes)})"
+        print(f"{'Status:':<22} {status}")
+    return {
+        "correct": len(failures) == 0,
+        "num_correct": len(results),
+        "num_failed": len(failures),
+        "failures": failures,
+        "results": results,
+    }

-def run_profile(shapes):
-    """Run kernel once for profiling."""
+
+def run_profile(shapes=None, warmup=50, iters=200, verbose=True):
+    if shapes is None:
+        shapes = PROFILE_SHAPES
     if not is_fp4_avail():
print("MXFP4 not supported on this architecture") return - - for M, N, K in shapes: - torch.cuda.empty_cache() - dtype = torch.bfloat16 - - x, w, w_kernel, w_scales, w_scales_kernel = generate_inputs(M, N, K, dtype=dtype) - - # Warmup - y = gemm_a16wfp4(x, w_kernel, w_scales_kernel, atomic_add=False, dtype=dtype) + if verbose: + print(f"Profile: {len(shapes)} config(s), {warmup} warmup, {iters} iter(s)") + + for cfg in shapes: + M, N, K = cfg + x, w, w_kernel, w_scales, w_scales_kernel = generate_inputs(M, N, K, DTYPE) + for _ in range(warmup): + gemm_a16wfp4(x, w_kernel, w_scales_kernel, atomic_add=False, dtype=DTYPE) torch.cuda.synchronize() - - # Profile run - y = gemm_a16wfp4(x, w_kernel, w_scales_kernel, atomic_add=False, dtype=dtype) + for _ in range(iters): + gemm_a16wfp4(x, w_kernel, w_scales_kernel, atomic_add=False, dtype=DTYPE) torch.cuda.synchronize() - - print(f"Profiled shape ({M}, {N}, {K})") + if verbose: + print(f" {_label(cfg)} done") + del x, w, w_kernel, w_scales, w_scales_kernel + torch.cuda.empty_cache() -def run_benchmark(shapes, iterations=20): - """Run benchmark on given shapes.""" +def run_benchmark(shapes=None, warmup=50, iters=200, verbose=True): + if shapes is None: + shapes = HARNESS_SHAPES if not is_fp4_avail(): print("MXFP4 not supported on this architecture") - print("GEAK_RESULT_LATENCY_MS=0.0") - return - - print(f"Running benchmark on {len(shapes)} shapes with {iterations} iterations...") + print("GEAK_RESULT_LATENCY_MS=0.0", flush=True) + return {"geomean_latency_ms": 0.0, "latencies": [], "skipped": True} + latencies = [] - - for i, (M, N, K) in enumerate(shapes): - torch.cuda.empty_cache() - dtype = torch.bfloat16 - - x, w, w_kernel, w_scales, w_scales_kernel = generate_inputs(M, N, K, dtype=dtype) - - # Warmup - for _ in range(5): - y = gemm_a16wfp4(x, w_kernel, w_scales_kernel, atomic_add=False, dtype=dtype) + + print(f"Running benchmark on {len(shapes)} shapes, {warmup} warmup, {iters} iterations each...") + if verbose: + print(f"{'Config':<22} {'Triton':>10}") + print("-" * 34) + + for cfg in shapes: + M, N, K = cfg + x, w, w_kernel, w_scales, w_scales_kernel = generate_inputs(M, N, K, DTYPE) + + for _ in range(warmup): + gemm_a16wfp4(x, w_kernel, w_scales_kernel, atomic_add=False, dtype=DTYPE) torch.cuda.synchronize() - - # Benchmark - times = [] - for _ in range(iterations): - torch.cuda.synchronize() - start = time.perf_counter() - y = gemm_a16wfp4(x, w_kernel, w_scales_kernel, atomic_add=False, dtype=dtype) + + triton_times = [] + for _ in range(iters): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + gemm_a16wfp4(x, w_kernel, w_scales_kernel, atomic_add=False, dtype=DTYPE) + end.record() torch.cuda.synchronize() - end = time.perf_counter() - times.append((end - start) * 1000) # Convert to ms - - median_time = sorted(times)[len(times) // 2] - latencies.append(median_time) - print(f" [{i+1}/{len(shapes)}] Shape ({M}, {N}, {K}): {median_time:.4f} ms") - - # Compute geometric mean of latencies - import math - geomean = math.exp(sum(math.log(t) for t in latencies) / len(latencies)) - print(f"\nGeometric mean latency: {geomean:.4f} ms") - print(f"GEAK_RESULT_LATENCY_MS={geomean:.4f}") - - -def main(): - parser = argparse.ArgumentParser(description="Test harness for gemm_a16wfp4 kernel") - parser.add_argument("--correctness", action="store_true", help="Run correctness tests") - parser.add_argument("--profile", action="store_true", help="Run kernel once for profiling") - 
parser.add_argument("--benchmark", action="store_true", help="Run benchmark on HARNESS_SHAPES") - parser.add_argument("--full-benchmark", action="store_true", help="Run benchmark on ALL_SHAPES") - parser.add_argument("--iterations", type=int, default=None, help="Number of benchmark iterations") - + triton_times.append(start.elapsed_time(end)) + + triton_ms = sorted(triton_times)[len(triton_times) // 2] + latencies.append(triton_ms) + + if verbose: + print(f"{_label(cfg):<22} {triton_ms:>8.4f}ms", flush=True) + + del x, w, w_kernel, w_scales, w_scales_kernel + torch.cuda.empty_cache() + + geomean_latency = math.exp(sum(math.log(l) for l in latencies) / len(latencies)) + + print("-" * 34) + print(f"{'Geometric mean latency:':<22} {geomean_latency:.4f} ms") + print(f"GEAK_SHAPES_USED={list(range(len(shapes)))}") + print(f"GEAK_RESULT_LATENCY_MS={geomean_latency:.4f}", flush=True) + + return {"geomean_latency_ms": geomean_latency, "latencies": latencies} + + +# ============================================================================ +# MAIN +# ============================================================================ + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="gemm_a16wfp4 (MXFP4) Test Harness") + parser.add_argument("--correctness", action="store_true", + help="Run correctness tests on HARNESS_SHAPES") + parser.add_argument("--profile", action="store_true", + help="Run minimal profiling workload") + parser.add_argument("--benchmark", action="store_true", + help="Run benchmark on HARNESS_SHAPES (25 uniformly sampled)") + parser.add_argument("--full-benchmark", action="store_true", + help="Run benchmark on ALL_SHAPES (complete set)") + parser.add_argument("--warmup", type=int, default=50, + help="Number of warmup iterations (default: 50)") + parser.add_argument("--iterations", type=int, + default=int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200")), + help="Number of benchmark iterations (default: GEAK_BENCHMARK_ITERATIONS or 200)") args = parser.parse_args() - + + print("=" * 62) + print("gemm_a16wfp4 (MXFP4) Test Harness") + print("=" * 62) + if args.correctness: - success = run_correctness(HARNESS_SHAPES) - sys.exit(0 if success else 1) + print("\n[Correctness Mode]") + result = run_correctness(HARNESS_SHAPES) + sys.exit(0 if result["correct"] else 1) elif args.profile: - run_profile(PROFILE_SHAPES) - elif args.benchmark: - iterations = args.iterations if args.iterations is not None else int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "10")) - run_benchmark(HARNESS_SHAPES, iterations) + print("\n[Profile Mode]") + run_profile(PROFILE_SHAPES, warmup=args.warmup, iters=args.iterations) elif args.full_benchmark: - iterations = args.iterations if args.iterations is not None else int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "20")) - run_benchmark(ALL_SHAPES, iterations) + print("\n[Full Benchmark Mode]") + run_benchmark(ALL_SHAPES, warmup=args.warmup, iters=args.iterations) else: - parser.print_help() - sys.exit(1) + print("\n[Benchmark Mode]") + run_benchmark(HARNESS_SHAPES, warmup=args.warmup, iters=args.iterations) - -if __name__ == "__main__": - main() + print("=" * 62)