
Commit 4ed48b2

Merge pull request #167 from PSAL-POSTECH/autotune
[Frontend] Template autotune
2 parents 5f9f098 + 3692365 commit 4ed48b2

15 files changed: 575 additions & 783 deletions

PyTorchSimFrontend/extension_codecache.py

Lines changed: 1 addition & 1 deletion
@@ -299,7 +299,7 @@ def dummy_simulator(*args, **kwargs):
             # Dump arguments and meta data
             dump_metadata(args, arg_attributes, result_path)
             runtime_path = FunctionalSimulator.get_runtime_dump_path(result_path)
-            if extension_config.CONFIG_TORCHSIM_VALIDATION_MODE or validate:
+            if not autotune and (extension_config.CONFIG_TORCHSIM_VALIDATION_MODE or validate):
                 funcsim = FunctionalSimulator(result_path, key)
                 funcsim.run_spike(args, arg_attributes,
                                   runtime_path, self.validation_binary_name,
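
In effect, a kernel benchmarked during autotuning skips the Spike functional-validation pass even when validation mode is enabled globally. A minimal sketch of the resulting gate, assuming only the three flags visible in this hunk (the helper name is illustrative):

    def should_run_functional_sim(autotune: bool, validation_mode: bool, validate: bool) -> bool:
        # Autotune benchmark runs skip validation entirely; otherwise either the
        # global validation mode or the per-call flag can request it.
        return (not autotune) and (validation_mode or validate)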

PyTorchSimFrontend/extension_config.py

Lines changed: 2 additions & 0 deletions
@@ -46,7 +46,9 @@

 # AUTOTUNE config
 CONFIG_AUTOTUNE = int(os.environ.get('AUTOTUNE', default=True))
+CONFIG_AUTOTUNE_TEMPLATE = int(os.environ.get('AUTOTUNE_TEMPLATE', default=True))
 CONFIG_MAX_AUTOTUNE_TRY = int(os.environ.get('MAX_AUTOTUNE_TRY', default=10))
+CONFIG_AUTOTUNE_TEMPLATE_TOPK = int(os.environ.get('AUTOTUNE_TEMPLATE_TOPK', default=4))

 # For block sparse
 CONFIG_BLOCK_SPARSE = int(os.environ.get('BLOCK_SPARSE', default=0))
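
Both new knobs follow the file's existing pattern: integer-valued environment variables read once at import time. AUTOTUNE_TEMPLATE_TOPK presumably caps how many template tile candidates are benchmarked; its consumer sits in one of the files not shown in this capture. A hedged usage sketch (variable names from the diff, values are the defaults):

    import os

    # Must be set before PyTorchSimFrontend.extension_config is imported,
    # since that module reads the environment at import time.
    os.environ["AUTOTUNE"] = "1"                # enable autotuning
    os.environ["AUTOTUNE_TEMPLATE"] = "1"       # enable template autotuning (new)
    os.environ["AUTOTUNE_TEMPLATE_TOPK"] = "4"  # candidates to benchmark per template (new)
    os.environ["MAX_AUTOTUNE_TRY"] = "10"       # retry budget per benchmark run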

PyTorchSimFrontend/mlir/mlir_autotune.py

Lines changed: 2 additions & 1 deletion
@@ -74,7 +74,8 @@ def cached_run_fn(*args, **kwargs):
             self.source_code, vectorlane_size=self.extra_args["vector_lane"],
             loop_size=None, spad_info=self.extra_args["spad_info"],
             vlen=self.extra_args["vlen"], arg_attributes=self.extra_args["arg_attributes"],
-            origins="Unknown", silent_mode=True)
+            origins="Unknown", silent_mode=True,
+            validate=self.extra_args['validate'], autotune=self.extra_args['autotune'])

         args = [
             tensor

PyTorchSimFrontend/mlir/mlir_bmm_template.py

Lines changed: 59 additions & 47 deletions
@@ -6,8 +6,6 @@
 from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate
 from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel
 from torch._inductor.ir import IRNode
-from torch._inductor.codecache import write_atomic
-import PyTorchSimFrontend.extension_codecache as extension_codecache
 from PyTorchSimFrontend.mlir import mlir_common

 BMM_TEMPLATE = r"""

@@ -162,51 +160,31 @@ def render(self,
                template_buffer_node = None,
                epilogue_nodes: Optional[List[IRNode]] = None,
                prologue_nodes: Optional[List[IRNode]] = None,
+               tile_info = None,
                **kwargs):
-        if template_buffer_node is not None:
-            self.output_node = template_buffer_node
-
-        # Extract input arguments info
-        X, W = self.input_nodes[0], self.input_nodes[1]
-        Y = self.output_node
-        Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2]
-
-        W_tensor = empty_strided(W.layout.size, W.layout.stride)
-        X_tensor = empty_strided(X.layout.size, X.layout.stride)
-        if len(W_tensor.size()) > 3 or len(W_tensor.size()) == 2:
-            W_tensor = W_tensor.view([-1, W_tensor.shape[-2], W_tensor.shape[-1]])
-        if len(X_tensor.size()) > 3 or len(X_tensor.size()) == 2:
-            X_tensor = X_tensor.view([-1, X_tensor.shape[-2], X_tensor.shape[-1]])
-        B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2]
-
-        W_stride = W_tensor.stride()
-        X_stride = X_tensor.stride()
-
-        # Select tile size
-        n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0
-        TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node)
-        SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or prologue_nodes else kernel.vector_lane
-        SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane
-        SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane
+        X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes)
+        if tile_info is None:
+            TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node)[0]
+        else:
+            TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info

         TOG_latency = M if TILE_M > M else TILE_M
         kernel.loop_size = [TOG_latency, TILE_N, TILE_K]
-        TILE_K = TILE_K // 2 if prologue_nodes else TILE_K

         # Select template code
         nr_reduction_nodes = [node for node in epilogue_nodes if node.is_reduction()] if epilogue_nodes is not None else []
         if nr_reduction_nodes:
-            template = BMM_REDUCTION_TEMPLATE
-            epilogue_dim_aliasing = {"index0":"index0", "index1":"index2", "index2": "index1"}
-            nr_rdim = 1
+            template = BMM_REDUCTION_TEMPLATE
+            epilogue_dim_aliasing = {"index0":"index0", "index1":"index2", "index2": "index1"}
+            nr_rdim = 1
         elif prologue_nodes:
-            template = BMM_PROLOGUE_TEMPLATE
-            epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"}
-            nr_rdim = 0
+            template = BMM_PROLOGUE_TEMPLATE
+            epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"}
+            nr_rdim = 0
         else:
-            template = BMM_TEMPLATE
-            epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"}
-            nr_rdim = 0
+            template = BMM_TEMPLATE
+            epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"}
+            nr_rdim = 0

         # Prepare tile descriptors
         vlane_stride = 1

@@ -323,19 +301,53 @@ def render(self,
             dram_idx = Y_idx,
             dram_tile_desc = Y_tile_desc,
             nr_rdim = nr_rdim,
+            r_dim_size = M,
             dim_aliasing = epilogue_dim_aliasing
         )
         code = self._template_from_string(template).render(**kernel.render_options)
         kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], kernel.render_options["K"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]])
         return code

-    def codegen_header(self, code, extra_headers):
-        write_path = extension_codecache.get_write_path(code)
-        if not os.path.exists(write_path):
-            os.makedirs(write_path)
-        spike_write_path = os.path.join(write_path, "global_var.h")
-        gem5_write_path = os.path.join(write_path, "gem5_global_var.h")
-        if not os.path.exists(spike_write_path):
-            write_atomic(spike_write_path, extra_headers[0])
-        if not os.path.exists(gem5_write_path):
-            write_atomic(gem5_write_path, extra_headers[1])
+    def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes):
+        if template_buffer_node is not None:
+            self.output_node = template_buffer_node
+
+        # Extract input arguments info
+        X, W = self.input_nodes[0], self.input_nodes[1]
+        Y = self.output_node
+        Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2]
+
+        W_tensor = empty_strided(W.layout.size, W.layout.stride)
+        X_tensor = empty_strided(X.layout.size, X.layout.stride)
+        if len(W_tensor.size()) > 3 or len(W_tensor.size()) == 2:
+            W_tensor = W_tensor.view([-1, W_tensor.shape[-2], W_tensor.shape[-1]])
+        if len(X_tensor.size()) > 3 or len(X_tensor.size()) == 2:
+            X_tensor = X_tensor.view([-1, X_tensor.shape[-2], X_tensor.shape[-1]])
+        B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2]
+
+        W_stride = W_tensor.stride()
+        X_stride = X_tensor.stride()
+
+        # Select tile size
+        n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0
+        n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0
+        return X,W,Y,Bias,W_tensor,X_tensor,B,M,N,K,n_extra_node, n_prologue_node
+
+    def get_tile_candidates(self,
+                            kernel: MLIRTemplateKernel,
+                            template_buffer_node = None,
+                            epilogue_nodes: Optional[List[IRNode]] = None,
+                            prologue_nodes: Optional[List[IRNode]] = None,
+                            **kwargs):
+        X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes)
+        return self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node)
+
+    def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node):
+        tile_candidates = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node)
+        for idx, (TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates):
+            SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or n_prologue_node else kernel.vector_lane
+            SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane
+            SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane
+            TILE_K = TILE_K // 2 if n_prologue_node else TILE_K
+            tile_candidates[idx] = TILE_M,TILE_N,TILE_K,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K
+        return tile_candidates
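
With tile_info now an explicit render() parameter and get_tile_candidates() exposing the whole candidate list, a driver can render one kernel variant per candidate and benchmark them all. A sketch of that consumption pattern, assuming a template-autotune driver along these lines exists in one of the files not shown in this capture (render_template_variants and topk are illustrative names):

    def render_template_variants(template, kernel, template_buffer_node=None,
                                 epilogue_nodes=None, prologue_nodes=None, topk=4):
        # Hypothetical driver: render one BMM kernel variant per tile candidate,
        # capped at topk (mirroring CONFIG_AUTOTUNE_TEMPLATE_TOPK above).
        candidates = template.get_tile_candidates(
            kernel,
            template_buffer_node=template_buffer_node,
            epilogue_nodes=epilogue_nodes,
            prologue_nodes=prologue_nodes,
        )
        return [
            template.render(kernel,
                            template_buffer_node=template_buffer_node,
                            epilogue_nodes=epilogue_nodes,
                            prologue_nodes=prologue_nodes,
                            tile_info=tile_info)  # skips select_tile()'s default pick
            for tile_info in candidates[:topk]
        ]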

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 27 additions & 18 deletions
@@ -17,8 +17,7 @@
     sympy_product
 )
 from torch.utils._sympy.functions import ModularIndexing, FloorDiv
-import PyTorchSimFrontend.extension_codecache as extension_codecache
-
+from PyTorchSimFrontend import extension_codecache
 from PyTorchSimFrontend import extension_config
 from . import mlir_common
 from .mlir_common import LoopLevel, LoopNest

@@ -1565,10 +1564,10 @@ def make_choices(self, nodes, kernel_name):
         current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size())
         search_space.add(current_tile_sz)

-        print(f"[Auto-tune] Trying tile size: {current_tile_sz}, vlane_stride: {vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}")
+        print(f"[Auto-tune] Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}")
         self._prepare_simulator_headers(src_code)
         bench_runner = self.run_bench(nodes, kernel_name, src_code)
-        choices.append((bench_runner, src_code, self.kernel_group))
+        choices.append((bench_runner, src_code, current_tile_sz, self.kernel_group.tile_desc.vmap.vlane_stride))

         while prevent_infinite_loop < 10 and candidate_axes:
             for axis in list(candidate_axes):

@@ -1593,33 +1592,39 @@
                 src_code = super().codegen_nodes(nodes, kernel_name)
                 current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size())

+                # FIXME. How to intergrate this constraint to tile system?
+                pad = self.kernel_group.tile_desc.vmap.get_used_vlane(current_tile_sz) * self.kernel_group.tile_desc.vmap.vlane_stride
+                vlane_size = current_tile_sz[self.kernel_group.tile_desc.vmap.vlane_split_axis]
+                if vlane_size > pad and vlane_size % pad:
+                    prevent_infinite_loop += 1
+                    continue
+
                 # If tile size is converged for this axis, remove from candidate axes
                 if current_tile_sz in search_space:
                     candidate_axes.remove(axis)
                     continue

                 # Add this choice
                 search_space.add(current_tile_sz)
-                print(f"[Auto-tune] Trying tile size: {current_tile_sz}, vlane_stride: {vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}")
+                print(f"[Auto-tune] Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}")
                 self._prepare_simulator_headers(src_code)
                 bench_runner = self.run_bench(nodes, kernel_name, src_code)
-                choices.append((bench_runner, src_code, self.kernel_group))
+                choices.append((bench_runner, src_code, self.kernel_group.tile_desc.get_tile_size(), self.kernel_group.tile_desc.vmap.vlane_stride))
                 prevent_infinite_loop += 1
         self.kernel_group.tile_desc.prev_tail_threshold = prev_tail_threshold
         return choices

-    def autotune(self, nodes, kernel_name):
+    def autotune(self, *args):
         def get_cycle(choice):
-            bench_runner, src_code, kernel_group = choice
+            bench_runner = choice[0]
             for n_try in range(extension_config.CONFIG_MAX_AUTOTUNE_TRY): # TODO: make simple
                 try:
-                    # bench_runner = self.run_bench(nodes, kernel_name, src_code)
-                    out = bench_runner(validate=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, autotune=True)
+                    out = bench_runner()
                     return out[-1]
                 except (extension_codecache.SpadOverflowError, RuntimeError) as e:
                     return float("inf")
             return float("inf") # Exceeded maximum number of autotuning attempts
-        choices = self.make_choices(nodes, kernel_name)
+        choices = self.make_choices(*args)

         if len(choices) == 0: # can't autotune
             return None

@@ -1628,21 +1633,25 @@ def get_cycle(choice):
         max_idx = results.index(min(results))
         if min(results) == float("inf"):
             raise RuntimeError("Failed to find optimal tile size...")
-        print(f"[Auto-tune] Optimal tile size: {choices[max_idx][2].tile_desc.get_tile_size()}, vlane_stride: {choices[max_idx][2].tile_desc.vmap.vlane_stride}, cycles: {results[max_idx]}")
+        self._log_autotune_result(choices[max_idx], results[max_idx])
         optimal_src_code = choices[max_idx][1]
         return optimal_src_code

+    def _log_autotune_result(self, best_choice, best_cycle):
+        print(
+            f"[Auto-tune] Optimal tile size: {list(best_choice[2])}, "
+            f"vlane_stride: {best_choice[3]}, "
+            f"cycles: {best_cycle}"
+        )
+
     def codegen_nodes(self, nodes, kernel_name):
         src_code = super().codegen_nodes(nodes, kernel_name)
         self._prepare_simulator_headers(src_code)
-        if not extension_config.CONFIG_AUTOTUNE or extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY:
-            return src_code
-        else:
+        if extension_config.CONFIG_AUTOTUNE and not extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY:
             optimal_src_code = self.autotune(nodes, kernel_name)
-            if optimal_src_code:
+            if optimal_src_code is not None:
                 return optimal_src_code
-            else:
-                return src_code
+        return src_code

     def _prepare_simulator_headers(self, src_code):
         write_path = extension_codecache.get_write_path(src_code)
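
Stripped of the retry loop, the selection logic reads simply: benchmark every choice, keep the one with the fewest simulated cycles, and fail only if nothing ran. A self-contained sketch of that core, using the (bench_runner, src_code, tile_size, vlane_stride) choice layout introduced above (pick_best is an illustrative name; the real code also catches extension_codecache.SpadOverflowError):

    def pick_best(choices):
        # Each choice is (bench_runner, src_code, tile_size, vlane_stride).
        def get_cycle(choice):
            try:
                return choice[0]()[-1]  # the runner reports cycles as its last output
            except RuntimeError:        # failed or overflowing variants score infinity
                return float("inf")
        results = [get_cycle(choice) for choice in choices]
        best_idx = results.index(min(results))
        if results[best_idx] == float("inf"):
            raise RuntimeError("Failed to find optimal tile size...")
        return choices[best_idx][1]     # source code of the cheapest variant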

PyTorchSimFrontend/mlir/mlir_common.py

Lines changed: 4 additions & 1 deletion
@@ -408,6 +408,7 @@ def select_vlane_axis(self):
         self.vmap.vlane_split_axis = best_vlane_split_axis

     def pad_vlane_tile(self):
+        # FIXME. this doesn't follow tile constraints...
         vlane_split_axis, vlane_stride, vector_lane = self.vmap.vlane_split_axis, self.vmap.vlane_stride, self.vmap.vector_lane
         used_vlane = min(math.ceil(self._tile_size[vlane_split_axis] / vlane_stride), vector_lane)
         padded_size = used_vlane * vlane_stride

@@ -790,7 +791,9 @@ def run_bench(self, nodes, kernel_name, src_code):
                 "vector_lane" : self.vector_lane,
                 "spad_info": self.spad_info,
                 "vlen" : self.vlen,
-                "arg_attributes" : arg_attributes
+                "arg_attributes" : arg_attributes,
+                "validate" : extension_config.CONFIG_TORCHSIM_VALIDATION_MODE,
+                "autotune" : True,
             },
             source_code=src_code,
         )
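
Together with the mlir_autotune.py hunk above, this closes the loop: run_bench packs the two flags into extra_args, cached_run_fn forwards them to the compile call, and the extension_codecache.py change uses autotune=True to skip Spike validation during benchmarking. A minimal sketch of the packing side (make_bench_extra_args is an illustrative name; only the dict keys come from the diff):

    def make_bench_extra_args(arg_attributes, validation_mode):
        # Mirrors the dict literal in run_bench(); key names match the diff.
        return {
            "arg_attributes": arg_attributes,
            "validate": validation_mode,  # CONFIG_TORCHSIM_VALIDATION_MODE in the real code
            "autotune": True,             # benchmark runs always identify themselves
        }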
