forked from ROCm/composable_kernel
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path05_numpy_integration.py
More file actions
166 lines (128 loc) · 4.83 KB
/
Copy path05_numpy_integration.py
File metadata and controls
166 lines (128 loc) · 4.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 05: NumPy Integration

Shows how to create a GPU-accelerated matmul wrapper.

Usage:
    python3 05_numpy_integration.py
    python3 05_numpy_integration.py --help
    python3 05_numpy_integration.py --dtype bf16
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
Dispatcher,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
detect_gpu_arch,
)
class GPUMatmul:
    """GPU-accelerated matrix multiplication wrapper.

    Wraps a dispatcher so callers can write ``C = gpu_matmul(A, B)``. The
    product is computed with ``np.matmul`` instead whenever the dispatcher
    reports the problem size as unsupported or its run does not succeed.
    """

    def __init__(self, dispatcher: "Dispatcher"):
        """Store the dispatcher used by every subsequent call.

        Args:
            dispatcher: Object providing ``is_supported(M, N, K)`` and
                ``run(A, B, M, N, K)``; the value returned by ``run`` must
                expose ``success`` and ``output`` attributes.
        """
        self.dispatcher = dispatcher

    def __call__(self, A: np.ndarray, B: np.ndarray) -> np.ndarray:
        """Compute C = A @ B on GPU with CPU fallback.

        Args:
            A: Left operand of shape ``(M, K)``.
            B: Right operand of shape ``(K, N)``.

        Returns:
            The dispatcher's ``output`` when the GPU run succeeds, otherwise
            ``np.matmul(A, B)``.

        Raises:
            ValueError: If either operand is not 2-D, or if the inner
                dimensions of ``A`` and ``B`` differ.
        """
        # Reject non-matrix operands up front; otherwise the shape unpacking
        # below fails with an unrelated "values to unpack" message.
        if A.ndim != 2 or B.ndim != 2:
            raise ValueError(
                f"GPUMatmul expects 2-D arrays, got shapes {A.shape} and {B.shape}"
            )

        M, K = A.shape
        K2, N = B.shape
        if K != K2:
            raise ValueError(f"Dimension mismatch: {A.shape} @ {B.shape}")

        # Problem sizes the dispatcher cannot handle go straight to NumPy.
        if not self.dispatcher.is_supported(M, N, K):
            return np.matmul(A, B)

        result = self.dispatcher.run(A, B, M, N, K)
        # A failed GPU run is not fatal: fall back to the CPU product.
        return result.output if result.success else np.matmul(A, B)
def main():
    """Run the NumPy integration demo end to end.

    Parses ``--dtype`` and ``--arch``, sets up a GEMM dispatcher, wraps it in
    ``GPUMatmul``, then runs one standalone matmul and a pair of chained
    FFN-shaped matmuls through it, printing shapes and timings.

    Returns:
        0 when every step succeeds; 1 if dispatcher setup fails or any
        direct dispatcher run reports ``success`` as false.
    """
    parser = argparse.ArgumentParser(
        description="NumPy Integration Example - GPU-accelerated matmul wrapper",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 05_numpy_integration.py # Default FP16
python3 05_numpy_integration.py --dtype bf16 # BF16 mode
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch",
        default=detect_gpu_arch(),
        help="Target architecture (auto-detected from rocminfo)",
    )
    args = parser.parse_args()

    print("=" * 60)
    print("Example 05: NumPy Integration")
    print("=" * 60)

    # =========================================================================
    # Step 1: Setup dispatcher
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")
    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )
    setup = setup_gemm_dispatcher(config, registry_name="numpy", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    dispatcher = setup.dispatcher
    # NumPy has no native bfloat16 dtype, so bf16 runs build their host
    # arrays as float16; only fp32 uses float32.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    # Everything after a successful setup runs under try/finally so that
    # cleanup_gemm() is called even when a step fails or raises.
    try:
        # =====================================================================
        # Step 2: Create GPU matmul wrapper
        # =====================================================================
        print("\nStep 2: Create GPUMatmul")
        gpu_matmul = GPUMatmul(dispatcher=dispatcher)
        print(" gpu_matmul ready")

        # =====================================================================
        # Step 3: Demo - Simple multiplication using gpu_matmul
        # =====================================================================
        print("\nStep 3: Demo - Simple Multiplication")
        A = np.random.randn(1024, 512).astype(np_dtype) * 0.1
        B = np.random.randn(512, 256).astype(np_dtype) * 0.1

        # Use the gpu_matmul wrapper
        C = gpu_matmul(A, B)
        print(f" gpu_matmul result: {C.shape}, sum={C.sum():.4f}")

        # Run the same product through the dispatcher directly to read the
        # timing fields, which the wrapper does not expose.
        M, K = A.shape
        _, N = B.shape
        result = dispatcher.run(A, B, M, N, K)
        if not result.success:
            print(" ERROR: GPU matmul failed in Step 3")
            return 1
        print(f" A: {A.shape}, B: {B.shape} -> C: {result.output.shape}")
        print(f" GPU: {result.time_ms:.4f} ms, {result.tflops:.2f} TFLOPS")

        # =====================================================================
        # Step 4: Demo - FFN block
        # =====================================================================
        print("\nStep 4: Demo - FFN Block")
        batch, hidden, ffn = 128, 768, 3072
        X = np.random.randn(batch, hidden).astype(np_dtype) * 0.02
        W1 = np.random.randn(hidden, ffn).astype(np_dtype) * 0.02
        W2 = np.random.randn(ffn, hidden).astype(np_dtype) * 0.02

        result1 = dispatcher.run(X, W1, batch, ffn, hidden)
        if not result1.success:
            print(" ERROR: GPU matmul failed in Step 4 (X @ W1)")
            return 1
        H = result1.output

        # The second matmul consumes the first one's output.
        result2 = dispatcher.run(H, W2, batch, hidden, ffn)
        if not result2.success:
            print(" ERROR: GPU matmul failed in Step 4 (H @ W2)")
            return 1
        print(f" X: {X.shape} -> H: {H.shape} -> Y: {result2.output.shape}")
        print(f" Total: {result1.time_ms + result2.time_ms:.4f} ms")
    finally:
        # Cleanup
        cleanup_gemm()

    # Summary
    print("\n" + "=" * 60)
    print("NumPy Integration Pattern:")
    print("=" * 60)
    print(" 1. setup_gemm_dispatcher(config)")
    print(" 2. GPUMatmul(dispatcher)")
    print(" 3. C = gpu_matmul(A, B)")
    print("=" * 60)
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())