Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 1 addition & 192 deletions tasks/triton2triton/geak_eval/L1/refk_fp8_blockwise_mm/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@
c: [m, n] bfloat16 (pre-allocated output)
"""

import math
import os
import time

import torch
import triton
import triton.language as tl
Expand Down Expand Up @@ -177,36 +173,7 @@ def fp8_blockwise_mm_triton(a, b, a_scale, b_scale, c):


# ============================================================================
# REFERENCE IMPLEMENTATION (pure PyTorch — same as submission.py)
# ============================================================================


def fp8_blockwise_mm_pytorch(a, b, a_scale, b_scale, c):
    """Pure-PyTorch reference for the block-scaled FP8 GEMM.

    Dequantizes `a` (per-(row, K-block) scales) and `b` (per-(N-block,
    K-block) scale grid), multiplies, and writes the bfloat16 product into
    the pre-allocated output `c`, which is also returned.
    NOTE(review): assumes `b_scale` is 2-D of shape (ceil(n/BLOCK_SHAPE_N),
    ceil(k/BLOCK_SHAPE_K)) — confirm against the input builder.
    """
    a_mat = a.contiguous()
    a_scales = a_scale.contiguous()
    b_scales = b_scale.contiguous()

    m, k = a_mat.shape
    n = b.shape[0]

    # Expand a's per-block scales along K: every scale covers BLOCK_SHAPE_K
    # consecutive columns; trim any padding past k.
    a_full = a_scales.repeat_interleave(BLOCK_SHAPE_K, dim=1)[:, :k]
    a_deq = a_mat.to(a_full.dtype) * a_full

    # Expand b's (sn, sk) scale grid to one scale per element of (n, k):
    # each grid cell covers a BLOCK_SHAPE_N x BLOCK_SHAPE_K tile.
    b_full = (b_scales
              .repeat_interleave(BLOCK_SHAPE_N, dim=0)
              .repeat_interleave(BLOCK_SHAPE_K, dim=1))[:n, :k]
    b_deq = b.to(b_full.dtype) * b_full

    c[...] = (a_deq @ b_deq.T).to(torch.bfloat16)
    return c


# ============================================================================
# ENTRY POINTS (for GEAK harness)
# ENTRY POINT (callable form of the kernel for downstream profilers)
# ============================================================================


Expand All @@ -215,11 +182,6 @@ def triton_op(m, n, k, seed):
return fp8_blockwise_mm_triton(*data)


def torch_op(m, n, k, seed):
    """Build the synthetic inputs for (m, n, k, seed) and run the PyTorch reference."""
    inputs = _generate_input(m, n, k, seed)
    return fp8_blockwise_mm_pytorch(*inputs)


# ============================================================================
# SYNTHETIC INPUT BUILDER (matches reference.py generate_input)
# ============================================================================
Expand Down Expand Up @@ -296,157 +258,4 @@ def get_inputs(m, n, k, seed=42, device="cuda"):
{"m": 6144, "n": 4608, "k": 7168, "seed": 65436},
]

# Harness settings shared by the correctness and benchmark passes below.
WARMUP = 50  # un-timed warm-up calls before each timed loop
# Timed iterations per config; overridable via GEAK_BENCHMARK_ITERATIONS.
ITERATIONS = int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200"))
# Tolerances for torch.allclose when comparing Triton vs. reference output.
RTOL, ATOL = 2e-2, 1e-3


# ============================================================================
# SELF-TEST HARNESS
# ============================================================================


def check_correctness(cfg) -> dict:
    """Run Triton and PyTorch reference on one config and compare outputs.

    Never raises: any failure is reported through the returned dict, which
    has keys "correct" (bool), "max_diff" (float), and "error" (str | None).
    """
    try:
        a, b, a_scale, b_scale, c_triton = get_inputs(**cfg)
        c_ref = c_triton.clone()

        fp8_blockwise_mm_triton(a, b, a_scale, b_scale, c_triton)
        fp8_blockwise_mm_pytorch(a, b, a_scale, b_scale, c_ref)
        torch.cuda.synchronize()

        out = c_triton.float()
        ref = c_ref.float()
        is_ok = torch.allclose(out, ref, rtol=RTOL, atol=ATOL)
        worst = (out - ref).abs().max().item()
        return {"correct": is_ok, "max_diff": worst, "error": None}
    except Exception as exc:  # deliberate catch-all: report, don't crash the harness
        return {"correct": False, "max_diff": float("inf"), "error": str(exc)}


def benchmark_config(cfg, warmup=WARMUP, iters=ITERATIONS) -> dict:
    """Time both implementations on one config.

    Returns {"triton_ms", "torch_ms", "speedup"} where the times are
    milliseconds per call and speedup = torch_ms / triton_ms.
    """
    a, b, a_scale, b_scale, c = get_inputs(**cfg)

    def _time_ms(fn):
        # Warm up (compilation, caches) with un-timed calls, then time
        # `iters` calls end-to-end, syncing the device around the window.
        for _ in range(warmup):
            fn(a, b, a_scale, b_scale, c.clone())
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(a, b, a_scale, b_scale, c.clone())
        torch.cuda.synchronize()
        return (time.perf_counter() - t0) * 1000 / iters

    triton_ms = _time_ms(fp8_blockwise_mm_triton)
    torch_ms = _time_ms(fp8_blockwise_mm_pytorch)

    return {
        "triton_ms": triton_ms,
        "torch_ms": torch_ms,
        "speedup": torch_ms / triton_ms if triton_ms > 0 else 0.0,
    }


def _config_label(cfg):
return f"(M={cfg['m']},N={cfg['n']},K={cfg['k']})"


def evaluate(configs=None, warmup=WARMUP, iters=ITERATIONS, verbose=True) -> dict:
    """Check correctness and benchmark every config, printing a table if verbose.

    Falls back to the first five TEST_CONFIGS when `configs` is empty/None.
    Returns a summary dict with pass/fail counts, per-config timings, and
    the geometric-mean speedup over the passing configs.
    """
    if not configs:
        configs = TEST_CONFIGS[:5]

    passed = []
    failures = []

    if verbose:
        print(f"{'Config':<26} {'Correct':>8} {'Torch':>10} {'Triton':>10} {'Speedup':>10}")
        print("-" * 66)

    for cfg in configs:
        tag = _config_label(cfg)
        verdict = check_correctness(cfg)

        if not verdict["correct"]:
            # Record the failure and skip benchmarking this config.
            failures.append({"config": cfg, **verdict})
            if verbose:
                reason = verdict["error"] or f"max_diff={verdict['max_diff']:.2e}"
                print(f"{tag:<26} {'FAIL':>8} {reason[:30]}")
            continue

        timing = benchmark_config(cfg, warmup=warmup, iters=iters)
        passed.append({"config": cfg, "correct": True, **timing})

        if verbose:
            star = " *" if timing["speedup"] > 1.0 else ""
            print(
                f"{tag:<26} {'PASS':>8} "
                f"{timing['torch_ms']:>8.3f}ms {timing['triton_ms']:>8.3f}ms "
                f"{timing['speedup']:>8.2f}x{star}"
            )

    speedups = [r["speedup"] for r in passed]
    geomean = math.prod(speedups) ** (1 / len(speedups)) if speedups else 0.0

    if verbose:
        print("-" * 66)
        status = "ALL PASS" if not failures else f"FAILED ({len(failures)}/{len(configs)})"
        print(f"{'Status:':<26} {status}")
        if speedups:
            print(f"{'Speedup (geomean):':<26} {geomean:.2f}x")

    return {
        "correct": not failures,
        "num_correct": len(passed),
        "num_failed": len(failures),
        "failures": failures,
        "results": passed,
        "speedup_geomean": geomean,
    }


def run_profile(configs=None, warmup=5, iters=1, verbose=True):
    """Run only the Triton kernel per config, for use under external profilers.

    Keeps warm-up and profiled iterations in separate synced loops so a
    profiler window can bracket just the second loop.
    """
    if not configs:
        configs = PROFILE_CONFIGS

    if verbose:
        print(f"Profile: {len(configs)} config(s)")

    for cfg in configs:
        a, b, a_scale, b_scale, c = get_inputs(**cfg)

        for _ in range(warmup):
            fp8_blockwise_mm_triton(a, b, a_scale, b_scale, c.clone())
        torch.cuda.synchronize()

        for _ in range(iters):
            fp8_blockwise_mm_triton(a, b, a_scale, b_scale, c.clone())
        torch.cuda.synchronize()

        if verbose:
            print(f"  {_config_label(cfg)} done")


# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    import argparse

    # --profile runs the Triton-only loop for external profilers;
    # default mode runs the full correctness + benchmark evaluation.
    cli = argparse.ArgumentParser(description="FP8 Block-Scale GEMM (Triton dequant)")
    cli.add_argument("--profile", action="store_true")
    opts = cli.parse_args()

    banner = "=" * 66
    print(banner)
    print("FP8 Block-Scale GEMM — Triton dequant + torch.mm")
    print(banner)

    if opts.profile:
        print("\n[Profile Mode]")
        run_profile()
    else:
        print("\n[Evaluation]")
        evaluate()

    print(banner)
Loading