diff --git a/problems/princeton/cross_entropy_py/eval.py b/problems/princeton/cross_entropy_py/eval.py
index 6fb8fd6c..65124d54 100644
--- a/problems/princeton/cross_entropy_py/eval.py
+++ b/problems/princeton/cross_entropy_py/eval.py
@@ -1,6 +1,7 @@
 import dataclasses
 import math
 import os
+import random
 import re
 import statistics
 import sys
@@ -25,6 +26,20 @@
 BENCH_ITERS = 100
 
 
+def make_seed_schedule():
+    total = WARMUP_ITERS + 3 * BENCH_ITERS
+    seeds = random.SystemRandom().sample(range(1, 2**31 - 1), total)
+    warmup_end = WARMUP_ITERS
+    forward_end = warmup_end + BENCH_ITERS
+    backward_end = forward_end + BENCH_ITERS
+    return {
+        "warmup": seeds[:warmup_end],
+        "forward": seeds[warmup_end:forward_end],
+        "backward": seeds[forward_end:backward_end],
+        "combined": seeds[backward_end:],
+    }
+
+
 class PopcornOutput:
     def __init__(self, fd: int):
         self.file = os.fdopen(fd, "w")
@@ -134,16 +149,20 @@ def check_correctness(mod, vocab_size):
     return fwd_close, bwd_close, max_fwd_err, max_bwd_err
 
 
-def benchmark_one(mod, vocab_size):
-    logits, targets, grad_output = generate_inputs(B, vocab_size, seed=123)
+def benchmark_one(mod, vocab_size, seed_schedule):
+    def phase_inputs(phase, idx):
+        seed = seed_schedule[phase][idx]
+        return generate_inputs(B, vocab_size, seed=seed)
 
-    for _ in range(WARMUP_ITERS):
+    for idx in range(WARMUP_ITERS):
+        logits, targets, grad_output = phase_inputs("warmup", idx)
         mod.cross_entropy_forward(logits, targets)
         mod.cross_entropy_backward(logits, targets, grad_output)
     torch.cuda.synchronize()
 
     fwd_times = []
-    for _ in range(BENCH_ITERS):
+    for idx in range(BENCH_ITERS):
+        logits, targets, _ = phase_inputs("forward", idx)
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
         start.record()
@@ -153,7 +172,8 @@ def benchmark_one(mod, vocab_size):
         fwd_times.append(start.elapsed_time(end))
 
     bwd_times = []
-    for _ in range(BENCH_ITERS):
+    for idx in range(BENCH_ITERS):
+        logits, targets, grad_output = phase_inputs("backward", idx)
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
         start.record()
@@ -163,7 +183,8 @@ def benchmark_one(mod, vocab_size):
         bwd_times.append(start.elapsed_time(end))
 
     combined_times = []
-    for _ in range(BENCH_ITERS):
+    for idx in range(BENCH_ITERS):
+        logits, targets, grad_output = phase_inputs("combined", idx)
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
         start.record()
@@ -263,8 +284,9 @@ def baseline_bwd(logits, targets, grad_output):
         vocab_size = int(test.args["vocab_size"])
         logger.log(f"benchmark.{idx}.spec", test.spec)
         try:
-            baseline = benchmark_one(baseline_mod, vocab_size)
-            result = benchmark_one(mod, vocab_size)
+            seed_schedule = make_seed_schedule()
+            baseline = benchmark_one(baseline_mod, vocab_size, seed_schedule)
+            result = benchmark_one(mod, vocab_size, seed_schedule)
             speedup = baseline.mean / result.mean
         except Exception as exc:
             logger.log(f"benchmark.{idx}.status", "fail")