11import dataclasses
22import math
33import os
4+ import random
45import re
56import statistics
67import sys
2526BENCH_ITERS = 100
2627
2728
def make_seed_schedule():
    """Build a per-phase table of random input seeds for one benchmark case.

    A single pool of ``WARMUP_ITERS + 3 * BENCH_ITERS`` seeds is sampled
    without replacement from ``range(1, 2 ** 31 - 1)`` using
    ``random.SystemRandom``, so every seed in the schedule is distinct.
    The pool is then cut into four consecutive slices.

    Returns:
        A dict with the keys ``"warmup"``, ``"forward"``, ``"backward"`` and
        ``"combined"`` (in that order), each mapping to a list of int seeds.
        ``"warmup"`` receives the first ``WARMUP_ITERS`` seeds, the next two
        phases ``BENCH_ITERS`` seeds each, and ``"combined"`` the remainder
        of the pool.
    """
    phase_names = ("warmup", "forward", "backward", "combined")
    pool_size = WARMUP_ITERS + 3 * BENCH_ITERS
    pool = random.SystemRandom().sample(range(1, 2 ** 31 - 1), pool_size)

    # Slice boundaries between consecutive phases; the leading and trailing
    # ``None`` leave the first and last slices open-ended.
    cuts = [
        None,
        WARMUP_ITERS,
        WARMUP_ITERS + BENCH_ITERS,
        WARMUP_ITERS + 2 * BENCH_ITERS,
        None,
    ]
    return {
        phase: pool[lo:hi]
        for phase, lo, hi in zip(phase_names, cuts, cuts[1:])
    }
41+
42+
2843class PopcornOutput :
2944 def __init__ (self , fd : int ):
3045 self .file = os .fdopen (fd , "w" )
@@ -134,16 +149,20 @@ def check_correctness(mod, vocab_size):
134149 return fwd_close , bwd_close , max_fwd_err , max_bwd_err
135150
136151
137- def benchmark_one (mod , vocab_size ):
138- logits , targets , grad_output = generate_inputs (B , vocab_size , seed = 123 )
152+ def benchmark_one (mod , vocab_size , seed_schedule ):
153+ def phase_inputs (phase , idx ):
154+ seed = seed_schedule [phase ][idx ]
155+ return generate_inputs (B , vocab_size , seed = seed )
139156
140- for _ in range (WARMUP_ITERS ):
157+ for idx in range (WARMUP_ITERS ):
158+ logits , targets , grad_output = phase_inputs ("warmup" , idx )
141159 mod .cross_entropy_forward (logits , targets )
142160 mod .cross_entropy_backward (logits , targets , grad_output )
143161 torch .cuda .synchronize ()
144162
145163 fwd_times = []
146- for _ in range (BENCH_ITERS ):
164+ for idx in range (BENCH_ITERS ):
165+ logits , targets , _ = phase_inputs ("forward" , idx )
147166 start = torch .cuda .Event (enable_timing = True )
148167 end = torch .cuda .Event (enable_timing = True )
149168 start .record ()
@@ -153,7 +172,8 @@ def benchmark_one(mod, vocab_size):
153172 fwd_times .append (start .elapsed_time (end ))
154173
155174 bwd_times = []
156- for _ in range (BENCH_ITERS ):
175+ for idx in range (BENCH_ITERS ):
176+ logits , targets , grad_output = phase_inputs ("backward" , idx )
157177 start = torch .cuda .Event (enable_timing = True )
158178 end = torch .cuda .Event (enable_timing = True )
159179 start .record ()
@@ -163,7 +183,8 @@ def benchmark_one(mod, vocab_size):
163183 bwd_times .append (start .elapsed_time (end ))
164184
165185 combined_times = []
166- for _ in range (BENCH_ITERS ):
186+ for idx in range (BENCH_ITERS ):
187+ logits , targets , grad_output = phase_inputs ("combined" , idx )
167188 start = torch .cuda .Event (enable_timing = True )
168189 end = torch .cuda .Event (enable_timing = True )
169190 start .record ()
@@ -263,8 +284,9 @@ def baseline_bwd(logits, targets, grad_output):
263284 vocab_size = int (test .args ["vocab_size" ])
264285 logger .log (f"benchmark.{ idx } .spec" , test .spec )
265286 try :
266- baseline = benchmark_one (baseline_mod , vocab_size )
267- result = benchmark_one (mod , vocab_size )
287+ seed_schedule = make_seed_schedule ()
288+ baseline = benchmark_one (baseline_mod , vocab_size , seed_schedule )
289+ result = benchmark_one (mod , vocab_size , seed_schedule )
268290 speedup = baseline .mean / result .mean
269291 except Exception as exc :
270292 logger .log (f"benchmark.{ idx } .status" , "fail" )
0 commit comments