-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbenchmark_fa3_sm120.py
More file actions
162 lines (120 loc) · 5.17 KB
/
benchmark_fa3_sm120.py
File metadata and controls
162 lines (120 loc) · 5.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""
FA3 SM120 Configuration Benchmark
Uses sdpa_causal_timed to measure attention kernel performance.
Environment variables control which FA3 variant is used:
- PYGPUKIT_FA3=1: Force FA3 on
- PYGPUKIT_FA3_TMA=1: Force TMA variant
Current: FA3 TMA at 51.97 TFLOPS (baseline)
Target: 60+ TFLOPS with SM120 tuning
"""
import numpy as np
import time
import os
import sys
import pygpukit as gpk
from pygpukit.core.backend import get_native_module
from pygpukit.core.dtypes import DataType
native = get_native_module()
def compute_attention_flops(batch: int, heads: int, seq_q: int, seq_kv: int, head_dim: int) -> int:
    """Return total FLOPs for one attention forward pass (Q@K^T plus P@V).

    Each of the two matmuls costs 2 * batch * heads * seq_q * seq_kv * head_dim
    multiply-adds; their dimensions are the same product, just reordered, so the
    total is simply twice that per-matmul cost.
    """
    per_matmul = 2 * batch * heads * seq_q * seq_kv * head_dim
    return 2 * per_matmul
def benchmark_sdpa_timed(heads: int, seq_len: int, head_dim: int, num_iters: int = 50):
    """Benchmark causal SDPA using the native kernel-side timer.

    ``native.sdpa_causal_timed`` returns the per-launch kernel time in
    microseconds, so Python/launch overhead is excluded from the measurement.
    Returns a tuple ``(avg_kernel_time_us, tflops)``.
    """
    dtype = DataType.from_string("bfloat16")
    shape = (heads, seq_len, head_dim)

    # Small-magnitude random inputs, converted to bf16 on device.
    q_host = np.random.randn(*shape).astype(np.float32) * 0.1
    k_host = np.random.randn(*shape).astype(np.float32) * 0.1
    v_host = np.random.randn(*shape).astype(np.float32) * 0.1
    q_dev = gpk.from_numpy(q_host).astype(dtype)
    k_dev = gpk.from_numpy(k_host).astype(dtype)
    v_dev = gpk.from_numpy(v_host).astype(dtype)
    out = gpk.zeros(shape, dtype=dtype)
    softmax_scale = 1.0 / np.sqrt(head_dim)

    # Warm up (triggers any lazy kernel selection / caching).
    for _ in range(3):
        native.sdpa_causal_(q_dev._native, k_dev._native, v_dev._native, out._native, softmax_scale)
    native.device_synchronize()

    # Sum the per-launch kernel times reported by the native timer.
    total_us = 0.0
    for _ in range(num_iters):
        total_us += native.sdpa_causal_timed(
            q_dev._native, k_dev._native, v_dev._native, out._native, softmax_scale
        )
    avg_us = total_us / num_iters

    # batch=1: heads are already folded into the FLOP count.
    flops = compute_attention_flops(1, heads, seq_len, seq_len, head_dim)
    tflops = flops / (avg_us * 1e-6) / 1e12
    return avg_us, tflops
def benchmark_sdpa_python_timing(heads: int, seq_len: int, head_dim: int, num_iters: int = 50):
    """Benchmark causal SDPA with host-side wall-clock timing.

    Measures ``time.perf_counter`` around the launch loop, so the result
    includes Python call and kernel-launch overhead (unlike the kernel-only
    timer).  Returns a tuple ``(avg_time_us, tflops)``.
    """
    dtype = DataType.from_string("bfloat16")
    shape = (heads, seq_len, head_dim)

    # Small-magnitude random inputs, converted to bf16 on device.
    q_host = np.random.randn(*shape).astype(np.float32) * 0.1
    k_host = np.random.randn(*shape).astype(np.float32) * 0.1
    v_host = np.random.randn(*shape).astype(np.float32) * 0.1
    q_dev = gpk.from_numpy(q_host).astype(dtype)
    k_dev = gpk.from_numpy(k_host).astype(dtype)
    v_dev = gpk.from_numpy(v_host).astype(dtype)
    out = gpk.zeros(shape, dtype=dtype)
    softmax_scale = 1.0 / np.sqrt(head_dim)

    # Warm up before timing.
    for _ in range(3):
        native.sdpa_causal_(q_dev._native, k_dev._native, v_dev._native, out._native, softmax_scale)

    # Synchronize so pending warmup work doesn't bleed into the measurement,
    # then time the whole launch loop and average per iteration.
    native.device_synchronize()
    t0 = time.perf_counter()
    for _ in range(num_iters):
        native.sdpa_causal_(q_dev._native, k_dev._native, v_dev._native, out._native, softmax_scale)
    native.device_synchronize()
    per_iter_s = (time.perf_counter() - t0) / num_iters

    flops = compute_attention_flops(1, heads, seq_len, seq_len, head_dim)
    tflops = flops / per_iter_s / 1e12
    return per_iter_s * 1e6, tflops
def main():
    """Run the FA3 benchmark sweep and print a results table."""
    banner = "=" * 70
    print(banner)
    print("FA3 SM120 Attention Benchmark")
    print(banner)

    # Report the FA3-related environment knobs in effect for this run.
    print(f"PYGPUKIT_FA3={os.environ.get('PYGPUKIT_FA3', 'auto')}")
    print(f"PYGPUKIT_FA3_TMA={os.environ.get('PYGPUKIT_FA3_TMA', 'auto')}")

    # Device identification.
    print(f"\nDevice: SM{native.get_sm_version()}")

    # Benchmark sweep: (heads, seq_len, head_dim).
    configs = [
        (32, 512, 128),
        (32, 1024, 128),
        (32, 2048, 128),
        (32, 4096, 128),
    ]
    num_iters = 50

    print(f"\n{'Config':<25} {'Kernel (us)':<12} {'TFLOPS':<10} {'Python (us)':<12} {'TFLOPS':<10}")
    print("-" * 80)
    for n_heads, s_len, d_head in configs:
        label = f"h={n_heads}, s={s_len}, d={d_head}"
        try:
            # Kernel-only timing, then host-side timing for comparison.
            k_us, k_tflops = benchmark_sdpa_timed(n_heads, s_len, d_head, num_iters)
            p_us, p_tflops = benchmark_sdpa_python_timing(n_heads, s_len, d_head, num_iters)
            print(f"{label:<25} {k_us:<12.1f} {k_tflops:<10.2f} {p_us:<12.1f} {p_tflops:<10.2f}")
        except Exception as exc:
            # Keep sweeping even if one shape fails (e.g. out of memory).
            print(f"{label:<25} ERROR: {exc}")

    # Native-side TMA descriptor cache statistics.
    print("\n" + banner)
    print("TMA Descriptor Cache Stats:")
    native.print_tma_cache_stats()

    print("\n" + banner)
    print("Notes:")
    print("- Kernel timing uses CUDA Events (excludes Python/host overhead)")
    print("- Python timing includes launch overhead")
    print("- TFLOPS calculated from kernel timing")
    print(banner)
# Standard entry guard: run the benchmark only when executed as a script.
if __name__ == "__main__":
    main()