51 changes: 49 additions & 2 deletions benchmarks/bench_gru_reset_backends.py
@@ -111,6 +111,37 @@ def _first_call_ms(fn: Callable[[], object], device: torch.device) -> float:
     return (time.perf_counter() - start) * 1000
 
 
+def _cudagraph_wrap(
+    fn: Callable[[], object],
+    device: torch.device,
+    warmup_iters: int = 5,
+) -> Callable[[], object]:
+    """Capture ``fn`` into a ``torch.cuda.CUDAGraph`` and return a replay wrapper.
+
+    The captured graph reuses the input/output buffers that ``fn`` closes over,
+    so on replay there is no Python-side work and no per-launch driver
+    overhead. Warmup iterations on a side stream let Triton autotune settle
+    and let the allocator stake out its working set before capture.
+    """
+    if device.type != "cuda":
+        return fn
+    s = torch.cuda.Stream(device=device)
+    s.wait_stream(torch.cuda.current_stream(device))
+    with torch.cuda.stream(s), torch.inference_mode():
+        for _ in range(warmup_iters):
+            fn()
+    torch.cuda.current_stream(device).wait_stream(s)
+    torch.cuda.synchronize(device)
+    g = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(g), torch.inference_mode():
+        fn()
+
+    def graphed_fn():
+        g.replay()
+
+    return graphed_fn
+
+
 def _make_modules(
     rnn_type: RNNType,
     input_size: int,
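For readers unfamiliar with the capture idiom, here is a minimal standalone sketch of how a wrapper like `_cudagraph_wrap` gets used. The module and shapes are illustrative stand-ins, not taken from the benchmark:

    import torch

    device = torch.device("cuda")
    # Stand-in workload; the benchmark wraps its RNN callables the same way.
    net = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).to(device)
    x = torch.randn(8, 64, device=device)  # static buffer, reused on every replay

    def step():
        net(x)

    graphed = _cudagraph_wrap(step, device)
    graphed()  # replays the captured kernels; no Python-side launch work

Because the graph replays against fixed memory, new inputs are fed by copying into the captured buffer (e.g. `x.copy_(new_x)`) rather than rebinding `x`.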
@@ -478,6 +509,16 @@ def main() -> None:
             "'auto' leaves the argument unset."
         ),
     )
+    parser.add_argument(
+        "--cudagraph",
+        action="store_true",
+        help=(
+            "Wrap the (optionally compiled) benchmark callable in a "
+            "torch.cuda.CUDAGraph. Applied after --compile, so the two "
+            "compose. cuDNN's reset path is incompatible with capture and is "
+            "skipped automatically when reset_prob > 0."
+        ),
+    )
     parser.add_argument("--seed", type=int, default=0)
     args = parser.parse_args()
     if "scan_compile_td" in args.modes or args.compile != "none":
@@ -496,7 +537,7 @@
 
     print(
         "device,rnn_type,batch,steps,num_layers,dropout,reset_prob,mode,"
-        "compile,compile_fullgraph,compile_dynamic,first_call_ms,"
+        "compile,compile_fullgraph,compile_dynamic,cudagraph,first_call_ms,"
         "median_ms,min_ms,frames_per_s,actual_reset_frac"
     )
     for rnn_type, batch, steps, num_layers, dropout, reset_prob in itertools.product(
@@ -564,14 +605,20 @@
                 args.compile_fullgraph,
                 compile_dynamic,
             )
+            # The cudnn_pad backend inspects ``is_init.any()`` host-side to
+            # pick between the dense and split-segment paths, which forbids
+            # stream capture. Skip cudnn_pad whenever --cudagraph is on.
+            apply_cudagraph = args.cudagraph and mode != "cudnn_pad_td"
+            if apply_cudagraph:
+                fn = _cudagraph_wrap(fn, device)
             first_call_ms = _first_call_ms(fn, device)
             median_ms, min_ms = _bench(fn, device, args.warmup, args.iters)
             frames_per_s = batch * steps / (median_ms / 1000)
             print(
                 f"{device},{rnn_type},{batch},{steps},{num_layers},{dropout},"
                 f"{reset_prob},{mode},"
                 f"{mode_compile},{args.compile_fullgraph},{args.compile_dynamic},"
-                f"{first_call_ms:.4f},"
+                f"{apply_cudagraph},{first_call_ms:.4f},"
                 f"{median_ms:.4f},{min_ms:.4f},{frames_per_s:.2f},"
                 f"{actual_reset_frac:.6f}"
             )
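As background for the cudnn_pad exclusion above: any result consumed on the host mid-capture, as `is_init.any()` is in that backend, forces a device synchronization, and CUDA stream capture rejects synchronization. A standalone sketch of the failure mode, independent of this benchmark (illustrative only; note that a failed capture can leave the CUDA context in a bad state):

    import torch

    x = torch.zeros(4, device="cuda")
    g = torch.cuda.CUDAGraph()
    try:
        with torch.cuda.graph(g):
            if x.any():  # bool() forces a device->host copy and a sync
                pass
    except RuntimeError as err:
        print(f"capture failed as expected: {err}")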