|
#!/usr/bin/env python3
"""Measure pure graph.replay() time vs kernel launches."""

# Imports grouped stdlib / third-party / local per PEP 8. The original file
# assigned model_path before the pygpukit imports; that worked, but constants
# belong after the import block.
import gc
import time

import numpy as np

from pygpukit._pygpukit_native import CudaGraph
from pygpukit.core import default_stream, from_numpy
from pygpukit.llm import detect_model_spec, load_model_from_safetensors, load_safetensors
from pygpukit.llm.model import DecodeBuffers, precompute_freqs_cis
from pygpukit.ops.basic import add_inplace, copy_to, embedding_lookup, kv_cache_prefill_gqa, rmsnorm

# Hard-coded local HF snapshot of the model's safetensors index.
# NOTE(review): machine-specific path — consider an env var or CLI argument.
model_path = "C:/Users/y_har/.cache/huggingface/hub/models--Aratako--Qwen3-8B-ERP-v0.1/snapshots/8311aa4482f02c2de93872e4979887def1841faf/model.safetensors.index.json"

# Fixed cache capacity used for both the KV cache and the RoPE tables below.
MAX_SEQ_LEN = 512
| 17 | + |
print("=" * 60)
print("Pure Graph Replay Benchmark")
print("=" * 60)

# Load weights and detect the architecture from the tensor names.
print("\nLoading model...")
st = load_safetensors(model_path)
spec = detect_model_spec(st.tensor_names)
model = load_model_from_safetensors(model_path, dtype="float16", spec=spec)
dtype = str(model.embed_tokens.dtype)
use_qk_norm = model.spec is not None and model.spec.use_qk_norm

# Preallocate a fixed-size KV cache per block plus reusable decode buffers,
# so the decode step below performs no allocations.
print("Initializing buffers...")
for blk in model.blocks:
    blk.attn.init_fixed_cache(MAX_SEQ_LEN, dtype=dtype)

buffers = DecodeBuffers.allocate(model.config, dtype=dtype, use_qk_norm=use_qk_norm)

# Upload RoPE cos/sin tables when the model uses rotary embeddings.
if model.config.use_rope:
    cos_table, sin_table = precompute_freqs_cis(
        model.config.head_dim, MAX_SEQ_LEN, model.config.rope_theta
    )
    host_dtype = np.float16 if dtype == "float16" else np.float32
    model._rope_cos_gpu = from_numpy(cos_table.astype(host_dtype))
    model._rope_sin_gpu = from_numpy(sin_table.astype(host_dtype))

# Run prefill to initialize KV cache
print("Running prefill...")
input_ids = [1, 2, 3, 4, 5]  # Dummy tokens
hidden, past_key_values = model(input_ids, use_cache=True)
for blk, (past_k, past_v) in zip(model.blocks, past_key_values):
    kv_cache_prefill_gqa(past_k, blk.attn._k_cache, blk.attn.num_heads, start_pos=0)
    kv_cache_prefill_gqa(past_v, blk.attn._v_cache, blk.attn.num_heads, start_pos=0)

# Decode-step inputs: one token appended after the 5-token prefill.
token_id = 100
position = 5
context_len = 6
| 55 | + |
# Define inline decode step
def _inline_decode_step():
    """Run one single-token decode pass through all transformer blocks.

    Reads module globals (model, buffers, token_id, position, context_len)
    and works entirely in the preallocated DecodeBuffers via the
    *_zero_alloc / out= paths, so every invocation issues the same kernel
    sequence — which is what makes it capturable as a CUDA graph below.
    Deliberately stops before the logits projection; the final normalized
    hidden state is left in buffers.hidden.
    """
    embedding_lookup(model.embed_tokens, buffers.hidden, token_id)
    for block in model.blocks:
        # Attention sub-layer: pre-norm into norm_out, stash the residual,
        # run attention, then add the residual back.
        # (assumes the attention kernel writes its output into
        # buffers.hidden — TODO confirm against _attention_forward_zero_alloc)
        rmsnorm(buffers.hidden, block.attn_norm.weight, block.attn_norm.eps, out=buffers.norm_out)
        copy_to(buffers.hidden, buffers.residual)
        model._attention_forward_zero_alloc(
            block.attn, buffers.norm_out, position, context_len, buffers,
            use_position_ptr=False,
        )
        add_inplace(buffers.hidden, buffers.residual)
        # MLP sub-layer: same pre-norm / residual pattern.
        copy_to(buffers.hidden, buffers.residual)
        rmsnorm(buffers.hidden, block.mlp_norm.weight, block.mlp_norm.eps, out=buffers.norm_out)
        model._mlp_forward_zero_alloc(block.mlp, buffers.norm_out, buffers)
        add_inplace(buffers.hidden, buffers.residual)
    # Final RMSNorm; copy the result back so callers read buffers.hidden.
    rmsnorm(buffers.hidden, model.final_norm.weight, model.final_norm.eps, out=buffers.norm_out)
    copy_to(buffers.norm_out, buffers.hidden)
| 73 | + |
# ============================================================
# Test 1: Direct kernel launches (no graph)
# ============================================================
print("\n--- Test 1: Direct Kernel Launches ---")

# Warm the kernels up so compilation/first-launch costs are excluded.
for _ in range(3):
    _inline_decode_step()
default_stream().synchronize()

# Time ten fully synchronized iterations of the raw launch path.
times_direct = []
for iteration in range(1, 11):
    default_stream().synchronize()
    t0 = time.perf_counter()
    _inline_decode_step()
    default_stream().synchronize()
    ms = (time.perf_counter() - t0) * 1000
    times_direct.append(ms)
    print(f" {iteration}: {ms:.2f} ms")

mean_direct = np.mean(times_direct)
print(f" Mean: {mean_direct:.2f} ms")
| 97 | + |
# ============================================================
# Test 2: Graph capture and replay
# ============================================================
print("\n--- Test 2: CUDA Graph Replay ---")

# Capture the decode step into a CUDA graph. GC is paused for the duration
# of capture so a collection cannot run arbitrary Python (e.g. free GPU
# buffers) while the stream is in capture mode.
print("Capturing graph...")
graph = CudaGraph()
# Fix: the original unconditionally called gc.enable() in the finally
# clause, which would turn GC on even if it had been disabled before this
# section. Restore the prior state instead.
gc_was_enabled = gc.isenabled()
gc.disable()
try:
    graph.begin_capture()
    _inline_decode_step()
    graph.end_capture()
finally:
    if gc_was_enabled:
        gc.enable()
print(f" Captured {graph.num_nodes} nodes")

# Warmup replay
for _ in range(3):
    graph.replay()
graph.synchronize()

# Time ten synchronized replays of the captured graph.
times_graph = []
for i in range(10):
    graph.synchronize()  # Ensure previous is done
    start = time.perf_counter()
    graph.replay()
    graph.synchronize()
    elapsed = (time.perf_counter() - start) * 1000
    times_graph.append(elapsed)
    print(f" {i+1}: {elapsed:.2f} ms")

mean_graph = np.mean(times_graph)
print(f" Mean: {mean_graph:.2f} ms")
| 133 | + |
# ============================================================
# Summary
# ============================================================
# Report both means plus the relative and absolute per-step savings.
sep = "=" * 60
print("\n" + sep)
print("SUMMARY (Transformer blocks only, no get_logits)")
print(sep)
print(f"Direct launches: {mean_direct:.2f} ms")
print(f"Graph replay: {mean_graph:.2f} ms")
print(f"Speedup: {mean_direct/mean_graph:.2f}x")
print(f"Saved per step: {mean_direct - mean_graph:.2f} ms")
print(sep)
0 commit comments