Skip to content

Commit 16022b0

Browse files
authored
Merge pull request #142 from Ammaar-Alam/fix/princeton-cross-entropy-replay-exploit
Fix Princeton cross-entropy replay exploit via phase-specific inputs
2 parents 575362a + 5a11ebd commit 16022b0

File tree

1 file changed

+30
-8
lines changed
  • problems/princeton/cross_entropy_py/eval.py

1 file changed

+30
-8
lines changed

problems/princeton/cross_entropy_py/eval.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import dataclasses
22
import math
33
import os
4+
import random
45
import re
56
import statistics
67
import sys
@@ -25,6 +26,20 @@
2526
BENCH_ITERS = 100
2627

2728

29+
def make_seed_schedule():
30+
total = WARMUP_ITERS + 3 * BENCH_ITERS
31+
seeds = random.SystemRandom().sample(range(1, 2**31 - 1), total)
32+
warmup_end = WARMUP_ITERS
33+
forward_end = warmup_end + BENCH_ITERS
34+
backward_end = forward_end + BENCH_ITERS
35+
return {
36+
"warmup": seeds[:warmup_end],
37+
"forward": seeds[warmup_end:forward_end],
38+
"backward": seeds[forward_end:backward_end],
39+
"combined": seeds[backward_end:],
40+
}
41+
42+
2843
class PopcornOutput:
2944
def __init__(self, fd: int):
3045
self.file = os.fdopen(fd, "w")
@@ -134,16 +149,20 @@ def check_correctness(mod, vocab_size):
134149
return fwd_close, bwd_close, max_fwd_err, max_bwd_err
135150

136151

137-
def benchmark_one(mod, vocab_size):
138-
logits, targets, grad_output = generate_inputs(B, vocab_size, seed=123)
152+
def benchmark_one(mod, vocab_size, seed_schedule):
153+
def phase_inputs(phase, idx):
154+
seed = seed_schedule[phase][idx]
155+
return generate_inputs(B, vocab_size, seed=seed)
139156

140-
for _ in range(WARMUP_ITERS):
157+
for idx in range(WARMUP_ITERS):
158+
logits, targets, grad_output = phase_inputs("warmup", idx)
141159
mod.cross_entropy_forward(logits, targets)
142160
mod.cross_entropy_backward(logits, targets, grad_output)
143161
torch.cuda.synchronize()
144162

145163
fwd_times = []
146-
for _ in range(BENCH_ITERS):
164+
for idx in range(BENCH_ITERS):
165+
logits, targets, _ = phase_inputs("forward", idx)
147166
start = torch.cuda.Event(enable_timing=True)
148167
end = torch.cuda.Event(enable_timing=True)
149168
start.record()
@@ -153,7 +172,8 @@ def benchmark_one(mod, vocab_size):
153172
fwd_times.append(start.elapsed_time(end))
154173

155174
bwd_times = []
156-
for _ in range(BENCH_ITERS):
175+
for idx in range(BENCH_ITERS):
176+
logits, targets, grad_output = phase_inputs("backward", idx)
157177
start = torch.cuda.Event(enable_timing=True)
158178
end = torch.cuda.Event(enable_timing=True)
159179
start.record()
@@ -163,7 +183,8 @@ def benchmark_one(mod, vocab_size):
163183
bwd_times.append(start.elapsed_time(end))
164184

165185
combined_times = []
166-
for _ in range(BENCH_ITERS):
186+
for idx in range(BENCH_ITERS):
187+
logits, targets, grad_output = phase_inputs("combined", idx)
167188
start = torch.cuda.Event(enable_timing=True)
168189
end = torch.cuda.Event(enable_timing=True)
169190
start.record()
@@ -263,8 +284,9 @@ def baseline_bwd(logits, targets, grad_output):
263284
vocab_size = int(test.args["vocab_size"])
264285
logger.log(f"benchmark.{idx}.spec", test.spec)
265286
try:
266-
baseline = benchmark_one(baseline_mod, vocab_size)
267-
result = benchmark_one(mod, vocab_size)
287+
seed_schedule = make_seed_schedule()
288+
baseline = benchmark_one(baseline_mod, vocab_size, seed_schedule)
289+
result = benchmark_one(mod, vocab_size, seed_schedule)
268290
speedup = baseline.mean / result.mean
269291
except Exception as exc:
270292
logger.log(f"benchmark.{idx}.status", "fail")

0 commit comments

Comments
 (0)