forked from ademeure/cuda-side-boost
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample_wrapper.py
More file actions
187 lines (159 loc) · 10.3 KB
/
example_wrapper.py
File metadata and controls
187 lines (159 loc) · 10.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# ---------------------------------------------------------------------------
# CUDA L2 Side Boost wrapper originally created for DeeperGEMM
# Useful as an example of how it can be integrated with PyTorch
# ---------------------------------------------------------------------------
# https://github.com/ademeure/cuda-side-boost
# https://github.com/ademeure/DeeperGEMM
# ---------------------------------------------------------------------------
import torch
import ctypes
# Module state; every value below is replaced by sideaware_init().
_lib = None               # ctypes.CDLL handle to the loaded sideaware library
_cpu_side_index = None    # address returned by the native sideaware_cpu_side_index()
_gpu_side_index = None    # address returned by the native sideaware_gpu_side_index()
_torch_side_index = None  # torch.uint8 CUDA tensor filled by sideaware_fill_side_index()
_info = (0,) * 5          # raw summary tuple: (num_sms, side0, side1, min, hash)
_info_str = dict.fromkeys(("num_sms", "side0", "side1", "min", "hash"), 0)  # same values, keyed by name
# -----------------------------------------------------------------------------
# Externally visible API (+torch.ops.sideaware.[memcpy/one_to_one/elementwise])
# -----------------------------------------------------------------------------
def sideaware_enabled():
    """Return 1 if sideaware_init() has loaded the native library, otherwise 0.

    An int (not a bool) is returned to stay compatible with existing callers.
    """
    return 1 if _lib else 0
# Create custom elementwise kernels (returns id for sideaware_elementwise)
def sideaware_create_kernel(header_code: bytes) -> int:
    """Create a custom elementwise kernel in the native library and return its id.

    The returned id is what ``sideaware_one_to_one`` / ``sideaware_elementwise``
    (``torch.ops.sideaware.one_to_one`` / ``.elementwise``) take as ``kernel_id``.
    Requires sideaware_init() to have been called.

    Args:
        header_code: kernel header source, passed to the native function as a
            C string. A ``str`` is also accepted and encoded as UTF-8.

    Returns:
        The integer kernel id returned by the native library.
    """
    # The ctypes argtype is c_char_p, which only accepts bytes; encode str for convenience.
    if isinstance(header_code, str):
        header_code = header_code.encode("utf-8")
    return _lib.sideaware_create_kernel(header_code)
# GPU SM side metadata (which SM is on which side, SMs per side, etc...)
def sideaware_torch_side_index():
    """Return the per-SM side index as a ``torch.uint8`` CUDA tensor.

    After sideaware_init() the tensor has one entry per SM. If called earlier,
    a single-element zero tensor is created on the CUDA device, cached in the
    module global, and returned.
    """
    global _torch_side_index
    if _torch_side_index is not None:
        return _torch_side_index
    _torch_side_index = torch.zeros(1, dtype=torch.uint8, device="cuda")
    return _torch_side_index
def sideaware_gpu_side_index():
    """Return the GPU side-index buffer address (one entry per SM) from sideaware_init()."""
    return _gpu_side_index
def sideaware_cpu_side_index():
    """Return the CPU side-index buffer address (one entry per SM) from sideaware_init()."""
    return _cpu_side_index
def sideaware_info():
    """Return the SM side summary as a dict with keys num_sms, side0, side1, min, hash.

    This is the module's own dict object (not a copy).
    """
    return _info_str
def sideaware_info_raw():
    """Return the SM side summary as the raw tuple (num_sms, side0, side1, min, hash)."""
    return _info
# torch.library.Library object that owns the torch.ops.sideaware.* registrations.
# A Library drops its registrations when it is garbage collected, so it must
# outlive sideaware_init(); sideaware_init() stores it here.
_sideaware_torch_lib = None


def _declare_ctypes_signatures(lib):
    """Set ctypes argtypes/restype for every sideaware library entry point used by this module."""
    lib.sideaware_create_kernel.argtypes = [ctypes.c_char_p]
    lib.sideaware_create_kernel.restype = ctypes.c_int
    lib.sideaware_sm_side_summary.argtypes = []
    lib.sideaware_sm_side_summary.restype = ctypes.POINTER(ctypes.c_int * 5)
    lib.sideaware_fill_side_index.argtypes = [ctypes.c_void_p]
    lib.sideaware_fill_side_index.restype = None
    lib.sideaware_gpu_side_index.argtypes = []
    lib.sideaware_gpu_side_index.restype = ctypes.c_void_p
    lib.sideaware_cpu_side_index.argtypes = []
    lib.sideaware_cpu_side_index.restype = ctypes.c_void_p
    lib.sideaware_memcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t,  # dst, src, num_bytes
                                     ctypes.c_int, ctypes.c_void_p]                      # device, stream
    lib.sideaware_memcpy.restype = None
    lib.sideaware_one_to_one.argtypes = [ctypes.c_int, ctypes.c_size_t,     # kernel_id, num_bytes
                                         ctypes.c_void_p, ctypes.c_void_p,  # out0, in0
                                         ctypes.c_int, ctypes.c_void_p]     # device, stream
    lib.sideaware_one_to_one.restype = None
    lib.sideaware_elementwise.argtypes = [ctypes.c_int, ctypes.c_size_t,     # kernel_id, num_bytes
                                          ctypes.c_void_p, ctypes.c_void_p,  # out0, out1
                                          ctypes.c_void_p, ctypes.c_void_p,  # out2, out3
                                          ctypes.c_void_p, ctypes.c_void_p,  # in0, in1
                                          ctypes.c_void_p, ctypes.c_void_p,  # in2, in3
                                          ctypes.c_void_p, ctypes.c_void_p,  # sideband_ptr, sideband_value
                                          ctypes.c_int, ctypes.c_int,        # parallel_chunks, forced_sm_per_side
                                          ctypes.c_int, ctypes.c_void_p]     # device, stream
    lib.sideaware_elementwise.restype = None


def _register_torch_ops():
    """Define torch.ops.sideaware.{memcpy,one_to_one,elementwise} with CUDA implementations.

    Returns:
        The torch.library.Library holding the registrations; the caller must keep it alive.
    """
    op_lib = torch.library.Library("sideaware", "FRAGMENT")

    def direct_register_custom_op(op_name, op_func, mutates_args):
        # The operator schema is inferred from op_func's type annotations.
        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
        op_lib.define(op_name + schema_str)
        op_lib.impl(op_name, op_func, "CUDA")

    direct_register_custom_op("memcpy", sideaware_memcpy, mutates_args=["dst"])
    direct_register_custom_op("one_to_one", sideaware_one_to_one, mutates_args=["dst"])
    direct_register_custom_op("elementwise", sideaware_elementwise,
                              mutates_args=["out0", "out1", "out2", "out3", "sideband_tensor"])
    return op_lib


def sideaware_init(path = 'sideaware.so'):
    """Load the sideaware shared library and set up the L2 side-aware runtime.

    Steps, in order:
      1. Route PyTorch CUDA allocations through the library's
         ``sideaware_malloc_auto`` / ``sideaware_free_auto`` (CUDAPluggableAllocator).
      2. Load the same library through ctypes and declare its C signatures.
      3. Register torch.ops.sideaware.{memcpy,one_to_one,elementwise}.
      4. Read the per-SM side metadata into the module globals and print it.

    Args:
        path: path of the shared library (default ``'sideaware.so'``).
    """
    global _lib, _info, _info_str, _torch_side_index, _gpu_side_index, _cpu_side_index
    global _sideaware_torch_lib

    sideaware_alloc = torch.cuda.memory.CUDAPluggableAllocator(path, 'sideaware_malloc_auto', 'sideaware_free_auto')
    torch.cuda.memory.change_current_allocator(sideaware_alloc)

    _lib = ctypes.CDLL(path)
    _declare_ctypes_signatures(_lib)

    # Keep the Library referenced at module level: as a local, it would be collected
    # when this function returned and the torch.ops.sideaware.* ops would vanish.
    _sideaware_torch_lib = _register_torch_ops()

    # Initialize sideaware metadata
    _info = tuple(_lib.sideaware_sm_side_summary().contents)
    _info_str = dict(zip(("num_sms", "side0", "side1", "min", "hash"), _info))
    _torch_side_index = torch.zeros(_info_str["num_sms"], dtype=torch.uint8, device="cuda")
    _lib.sideaware_fill_side_index(_torch_side_index.data_ptr())
    _gpu_side_index = _lib.sideaware_gpu_side_index()
    _cpu_side_index = _lib.sideaware_cpu_side_index()

    # Print metadata (shows we are done with initialization)
    print(f"L2 Side Aware metadata: {_info_str}")
# -----------------------------------------------------------------------------
# Exposed via torch.ops.sideaware.[memcpy/one_to_one/elementwise]() only
# -----------------------------------------------------------------------------
# Sideaware memcpy (i.e. "default kernel" when no custom kernel is provided via sideaware_create_kernel)
def sideaware_memcpy(dst: torch.Tensor, src: torch.Tensor) -> None:
    """Copy all of ``src`` into the start of ``dst`` using the native sideaware memcpy.

    Exposed as ``torch.ops.sideaware.memcpy``.

    Args:
        dst: contiguous, 16-byte aligned CUDA tensor with the same dtype as ``src``
            and at least ``src.numel()`` elements. Written in place.
        src: contiguous, 16-byte aligned CUDA tensor.

    Raises:
        AssertionError: if any of the requirements above is violated.
    """
    # Validate inputs
    assert dst.device.type == "cuda" and src.device.type == "cuda", "Both tensors must be on CUDA"
    assert dst.dtype == src.dtype, "Source and destination must have the same dtype"
    assert dst.numel() >= src.numel(), "Destination tensor must be at least as large as source"
    assert src.is_contiguous(), "src must be contiguous"
    assert dst.is_contiguous(), "dst must be contiguous"

    dst_ptr = dst.data_ptr()
    src_ptr = src.data_ptr()
    assert (dst_ptr % 16 == 0) and (src_ptr % 16 == 0), "dst and src must be 16-byte aligned"
    num_bytes = src.numel() * src.element_size()

    # Launch on dst's device and that device's current stream rather than whatever
    # device happens to be current (they differ when dst lives on another GPU).
    with torch.cuda.device(dst.device):
        device, stream = torch.cuda.current_device(), torch.cuda.current_stream()
        _lib.sideaware_memcpy(dst_ptr, src_ptr, num_bytes, device, stream)
# Sideaware single-input / single-output elementwise API (simple version)
def sideaware_one_to_one(kernel_id: int, dst: torch.Tensor, src: torch.Tensor) -> None:
    """Run custom kernel ``kernel_id`` with one output and one input (simple API).

    Exposed as ``torch.ops.sideaware.one_to_one``. Either tensor may be ``None``
    (passed to the native library as a null pointer), but not both. The byte
    count handed to the kernel is the larger of the two tensors' sizes in bytes.

    Args:
        kernel_id: id returned by ``sideaware_create_kernel``.
        dst: output tensor (written in place), or ``None``.
        src: input tensor, or ``None``.

    Raises:
        AssertionError: if both tensors are ``None``/empty, or a given tensor is
            not a contiguous, 16-byte aligned CUDA tensor.
    """
    # Size each side from its own tensor. src_bytes used to be computed from dst,
    # which crashed when src was None and silently ignored src when dst was None.
    src_bytes = src.numel() * src.element_size() if src is not None else 0
    dst_bytes = dst.numel() * dst.element_size() if dst is not None else 0
    num_bytes = max(src_bytes, dst_bytes)
    assert num_bytes > 0, "dst and src must not both be None or empty"

    # Make sure src and dst are contiguous, on CUDA, and aligned
    assert src is None or (src.is_contiguous() and src.device.type == "cuda"), "src must be a contiguous CUDA tensor"
    assert dst is None or (dst.is_contiguous() and dst.device.type == "cuda"), "dst must be a contiguous CUDA tensor"
    dst_ptr = dst.data_ptr() if dst is not None else 0
    src_ptr = src.data_ptr() if src is not None else 0
    assert dst_ptr % 16 == 0 and src_ptr % 16 == 0, "dst and src must be 16-byte aligned"

    # Launch on the tensors' device (dst preferred) and its current stream.
    launch_tensor = dst if dst is not None else src
    with torch.cuda.device(launch_tensor.device):
        device, stream = torch.cuda.current_device(), torch.cuda.current_stream()
        _lib.sideaware_one_to_one(kernel_id, num_bytes, dst_ptr, src_ptr, device, stream)
# Sideaware multi-input / multi-output elementwise API (advanced version)
def sideaware_elementwise(kernel_id: int,
                          out0: torch.Tensor, out1: torch.Tensor, out2: torch.Tensor, out3: torch.Tensor,
                          in0: torch.Tensor, in1: torch.Tensor, in2: torch.Tensor, in3: torch.Tensor,
                          sideband_tensor: torch.Tensor = None, sideband_value: int = 0,
                          parallel_chunks: int = 0, forced_sm_per_side: int = 0) -> None:
    """Run custom kernel ``kernel_id`` with up to four outputs and four inputs.

    Exposed as ``torch.ops.sideaware.elementwise``. Any of out0..out3 / in0..in3
    may be ``None`` (passed as a null pointer). The byte count handed to the
    kernel is the larger of ``out0``'s and ``in0``'s sizes in bytes.

    Args:
        kernel_id: id returned by ``sideaware_create_kernel``.
        out0, out1, out2, out3: output tensors (written in place), or ``None``.
        in0, in1, in2, in3: input tensors, or ``None``.
        sideband_tensor: optional tensor whose address is passed as ``sideband_ptr``
            (registered as mutated by the torch op), or ``None``.
        sideband_value: integer passed through to the native call as ``sideband_value``.
        parallel_chunks: passed through unchanged to the native call.
        forced_sm_per_side: passed through unchanged to the native call.

    Raises:
        AssertionError: if out0 and in0 are both ``None``/empty, or any given
            out*/in* tensor is not a contiguous, 16-byte aligned CUDA tensor.
    """
    def _nbytes(t):
        # Size in bytes of an optional tensor (0 for None).
        return t.numel() * t.element_size() if t is not None else 0

    # Size the output side from out0 and the input side from in0. Both used to be
    # computed from out0, so kernels without out0 could never be launched.
    num_bytes = max(_nbytes(in0), _nbytes(out0))
    assert num_bytes > 0, "out0 and in0 must not both be None or empty"

    outputs = (("out0", out0), ("out1", out1), ("out2", out2), ("out3", out3))
    inputs = (("in0", in0), ("in1", in1), ("in2", in2), ("in3", in3))

    # Make sure every given tensor is contiguous, on CUDA, and 16-byte aligned
    for name, t in inputs + outputs:
        assert t is None or (t.is_contiguous() and t.device.type == "cuda"), f"{name} must be a contiguous CUDA tensor"
    ptrs = [t.data_ptr() if t is not None else 0 for _, t in outputs + inputs]  # out0..out3, in0..in3
    assert all(p % 16 == 0 for p in ptrs), "out*/in* tensors must be 16-byte aligned"
    sideband_ptr = sideband_tensor.data_ptr() if sideband_tensor is not None else 0

    # Launch on the first given tensor's device and its current stream; num_bytes > 0
    # guarantees out0 or in0 is present, so next() always finds one.
    launch_tensor = next(t for _, t in outputs + inputs if t is not None)
    with torch.cuda.device(launch_tensor.device):
        device, stream = torch.cuda.current_device(), torch.cuda.current_stream()
        _lib.sideaware_elementwise(kernel_id, num_bytes, *ptrs,
                                   sideband_ptr, sideband_value, parallel_chunks, forced_sm_per_side,
                                   device, stream)