-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbenchmark_gatr.py
More file actions
98 lines (77 loc) · 3.02 KB
/
benchmark_gatr.py
File metadata and controls
98 lines (77 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import time
import os
import resource
import sys
try:
import kernel as algebra
except ImportError:
import sys
sys.path.append('.')
import kernel as algebra
def get_memory_mb():
# Helper to get current memory usage
if torch.cuda.is_available():
return torch.cuda.memory_allocated() / (1024**2)
else:
# Resource usage on Unix (Mac/Linux) returns maxrss in kilobytes on Linux, bytes on Mac?
# Actually python resource.getrusage(resource.RUSAGE_SELF).ru_maxrss is KB on Linux, but bytes on Mac often?
# Documentation says KB on Linux. On Mac it is bytes.
# Let's standardize:
usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
if sys.platform == 'darwin':
return usage / (1024**2) # bytes -> MB
else:
return usage / 1024 # KB -> MB
def benchmark_gatr_lift():
print("="*60)
print("GATr LIFT BENCHMARK: Memory & Speed")
print("="*60)
device = "cpu" # Force CPU for reliable consistent measuring on Mac if MPS is unstable
if torch.cuda.is_available():
device = "cuda"
# elif torch.backends.mps.is_available(): # MPS often lacks dual-event timing
# device = "mps"
print(f"Device: {device}")
# Config
B, N, K = 1000, 64, 32 # Match paper B=1000 claim
x = torch.randn(B, 32, 32, device=device) # (B, K_in, 32)
w = torch.randn(32, 32, 32, device=device) # (N_out, K_in, 32)
# 1. Standard PyTorch Baseline (simulating large tensor)
print("\n--- Benchmarking Standard Approach ---")
start_mem = get_memory_mb()
start_time = time.time()
# Naive simulation:
# To be fair, we do the computation.
# On CPU this might differ, but we show the relative gap.
try:
cayley_sim = torch.randn(32, 32, 32, device=device)
res = torch.einsum('bki, nkj, ijk -> bni', x, w, cayley_sim)
end_time = time.time()
max_mem = get_memory_mb()
std_time = (end_time - start_time)
std_mem = max_mem - start_mem
print(f"Standard Time: {std_time*1000:.2f} ms")
print(f"Standard Memory Delta: {std_mem:.2f} MB")
except Exception as e:
print(f"Standard Failed: {e}")
std_time = 1.0; std_mem = 1000.0
# 2. Optimized Kernel
print("\n--- Benchmarking Optimized Kernel ---")
start_mem = get_memory_mb()
start_time = time.time()
res_opt = algebra.geometric_linear_layer(x, w)
end_time = time.time()
max_mem = get_memory_mb()
opt_time = (end_time - start_time)
opt_mem = max_mem - start_mem
# If delta is 0/neg due to GC, assume small const
if opt_mem <= 0: opt_mem = 0.1
print(f"Optimized Time: {opt_time*1000:.2f} ms")
print(f"Optimized Memory Delta: {opt_mem:.2f} MB")
print("\n" + "="*60)
print(f"Speedup: {std_time/opt_time:.1f}x")
print(f"Memory Reduction: {std_mem/opt_mem:.1f}x")
print("="*60)
if __name__ == "__main__":
benchmark_gatr_lift()