Skip to content

Commit 208fd03

Browse files
Fix batch-and-skip benchmark exploit via per-call timing and correctness checks
The current eval times all 15 custom_kernel() calls as a single batch and divides by 15. A malicious submission can exploit this by deferring all work to one call (batching 15 problems into a single kernel launch) and making the other 14 calls no-ops, reporting ~1/15th of the real per-call cost. Cloning data alone (as proposed in #102) does not fully prevent this -- a shape-matching fallback path can still collect new data objects and batch them. This fix: - Clones data each timing iteration (prevents object-identity caching) - Times each call individually with its own CUDA events and GPU sync (prevents amortization across calls) - Checks correctness after each individual call in recheck/leaderboard mode (catches deferred-computation exploits that return uncomputed tensors) - Uses a local seed variable instead of mutating test.args - Fixes the recheck indentation bug where only the last call was checked
1 parent 04c0b02 commit 208fd03

File tree

1 file changed

+32
-23
lines changed

1 file changed

+32
-23
lines changed

problems/nvidia/eval_better_bench_grouped_gemm.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -240,12 +240,16 @@ def _run_single_benchmark(
240240

241241
durations = []
242242
data_list = []
243-
# generate input data once
243+
# generate input data once (local seed avoids mutating test.args)
244244

245+
local_seed = test.args.get("seed", None)
245246
for i in range(NUM_ITERATIONS_PER_BENCHMARK):
246-
if "seed" in test.args:
247-
test.args["seed"] += 42
248-
data = generate_input(**test.args)
247+
if local_seed is not None:
248+
local_seed += 42
249+
args = {**test.args, "seed": local_seed}
250+
else:
251+
args = test.args
252+
data = generate_input(**args)
249253
data_list.append(data)
250254

251255
check_copy = _clone_data(data_list)
@@ -263,35 +267,40 @@ def _run_single_benchmark(
263267
if not good:
264268
return message
265269

266-
# now, do multiple timing runs without further correctness testing
267-
# there is an upper bound of 200 runs, and a lower bound of 3 runs;
268-
# otherwise, we repeat until we either measure at least 10 full seconds,
269-
# or the relative error of the mean is below 1%.
270+
# Timing: individual per-call measurement with GPU sync between calls.
271+
# This prevents "batch-and-skip" exploits where a submission defers all
272+
# work to one call and returns cached/uncomputed results for the rest.
273+
# Data is cloned each iteration to prevent object-identity caching.
270274

271275
bm_start_time = time.perf_counter_ns()
272276
for i in range(max_repeats):
277+
iteration_data = _clone_data(data_list)
273278
torch.cuda.synchronize()
279+
clear_l2_cache()
274280

281+
per_call_durations = []
275282
outputs = []
276-
clear_l2_cache()
277-
start_event = torch.cuda.Event(enable_timing=True)
278-
end_event = torch.cuda.Event(enable_timing=True)
279-
start_event.record()
280-
for data in data_list:
283+
for j, data in enumerate(iteration_data):
284+
start_event = torch.cuda.Event(enable_timing=True)
285+
end_event = torch.cuda.Event(enable_timing=True)
286+
start_event.record()
281287
output = custom_kernel(data)
288+
end_event.record()
289+
torch.cuda.synchronize()
290+
per_call_durations.append(
291+
start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns
292+
)
282293
outputs.append(output)
283-
end_event.record()
284-
torch.cuda.synchronize()
285-
duration = (
286-
start_event.elapsed_time(end_event) / NUM_ITERATIONS_PER_BENCHMARK
287-
) * 1e6 # Convert ms to ns
288294

289-
if recheck:
290-
for reference_output, custom_output in zip(check_copy, outputs):
291-
good, message = check_implementation(reference_output, custom_output)
292-
if not good:
293-
return message
295+
# Per-call correctness check catches deferred-computation exploits:
296+
# if a submission skips the kernel and returns uncomputed tensors,
297+
# the check fails immediately.
298+
if recheck:
299+
good, message = check_implementation(check_copy[j], output)
300+
if not good:
301+
return message
294302

303+
duration = sum(per_call_durations) / NUM_ITERATIONS_PER_BENCHMARK
295304
durations.append(duration)
296305

297306
total_bm_duration = time.perf_counter_ns() - bm_start_time

0 commit comments

Comments
 (0)