Skip to content

Commit c98e5fe

Browse files
committed
[Frontend] Introduce recompile signal + Support floordiv pattern
1 parent a9c4e1f commit c98e5fe

3 files changed

Lines changed: 154 additions & 31 deletions

File tree

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
is_welford_reduction,
1616
sympy_product
1717
)
18-
from torch.utils._sympy.functions import ModularIndexing
18+
from torch.utils._sympy.functions import ModularIndexing, FloorDiv
1919
import PyTorchSimFrontend.extension_codecache as extension_codecache
2020

2121
from PyTorchSimFrontend import extension_config
@@ -260,10 +260,10 @@ def binary_elementwise_common(operand1, operand2, var_info):
260260
operand2 = ops.to_dtype(operand2, op_type1[1], var_info)
261261
op_type2 = var_info[operand2]
262262
elif op_type1[1][0] == op_type2[1][0]:
263-
if int(op_type1[1][1:]) > int(op_type2[1][1:]):
263+
if mlir_common.MLIR_TO_BIT[op_type1[1]] > mlir_common.MLIR_TO_BIT[op_type2[1]]:
264264
operand2 = ops.ext(operand2, op_type1[1])
265265
op_type2 = var_info[operand2]
266-
elif int(op_type1[1][1:]) < int(op_type2[1][1:]):
266+
elif mlir_common.MLIR_TO_BIT[op_type1[1]] < mlir_common.MLIR_TO_BIT[op_type2[1]]:
267267
operand1 = ops.ext(operand1, op_type2[1])
268268
op_type1 = var_info[operand1]
269269
else:
@@ -348,17 +348,21 @@ def maximum(operand1, operand2, *args, var_info=None, **kwargs):
348348
@staticmethod
349349
def to_dtype(operand, dst_mlir_dtype, *args, var_info=None, **kwargs):
350350
src_mlir_dtype = var_info[operand][1]
351+
if src_mlir_dtype == "index":
352+
operand = ops.index_cast(operand, "i64", var_info=var_info)
353+
src_mlir_dtype = var_info[operand][1]
354+
351355
tile_size = var_info[operand][0]
352356
if isinstance(dst_mlir_dtype, torch.dtype):
353357
dst_mlir_dtype = mlir_common.DTYPE_TO_MLIR[dst_mlir_dtype]
354-
dst_bits = int(dst_mlir_dtype[1:])
355-
src_bits = int(src_mlir_dtype[1:])
358+
dst_bits = mlir_common.MLIR_TO_BIT[dst_mlir_dtype]
359+
src_bits = mlir_common.MLIR_TO_BIT[src_mlir_dtype]
356360
shape = f"vector<{tile_size}x{dst_mlir_dtype}>" if tile_size > 1 else dst_mlir_dtype
357361
src_shape = f"vector<{tile_size}x{src_mlir_dtype}>" if tile_size > 1 else src_mlir_dtype
358362
if dst_mlir_dtype[0] == "i" and src_mlir_dtype[0] == "f":
359-
return f"arith.fptoui%{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype]
363+
return f"arith.fptoui %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype]
360364
if dst_mlir_dtype[0] == "f" and src_mlir_dtype[0] == "i":
361-
return f"arith.uitofp%{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype]
365+
return f"arith.uitofp %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype]
362366
if dst_mlir_dtype[0] == "i":
363367
if dst_bits > src_bits:
364368
return f"arith.extui %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype]
@@ -955,6 +959,8 @@ def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> com
955959
# Extract index var
956960
indirect_args = [f"%{i}" for i in indirect_dims]
957961
expr_str = str(expr)
962+
if "//" in expr_str:
963+
expr_str = expr_str.replace("//", " floordiv ")
958964
args = ", ".join(map(str, indices))
959965
map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[{','.join(indirect_dims)}] -> ({expr_str})>")
960966
args = ", ".join([f"%{i}" for i in indices])
@@ -1063,6 +1069,9 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs):
10631069
vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype)
10641070
compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size()
10651071
require_store = True
1072+
if compute_vec_size < self.var_info[value][0]:
1073+
value = self.cse.generate(self.stores, f"vector.extract_strided_slice %{value} {{offsets = [0], sizes = [{compute_vec_size}], strides = [1]}}: vector<{self.var_info[value][0]}x{self.var_info[value][1]}> to {vshape}")
1074+
self.register_var_info(value, [compute_vec_size, mlir_dtype])
10661075

10671076
if str(value) in self.spad_buffer_dict:
10681077
# Todo. If tile_size is not same (i.e., view operation), we can't apply peephole optimization easily
@@ -1680,6 +1689,40 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
16801689
sorted_keys = sorted(dram_dict.keys())
16811690
dram_stride = sum((dram_dict[key] for key in sorted_keys), [])
16821691

1692+
# Support floordiv pattern
1693+
# FIXME. How to integrate implicit dims and floordiv?
1694+
# This was introduced to support GroupNorm
1695+
if index.has(FloorDiv) and not index.has(ModularIndexing):
1696+
dim_divisor = [1] * len(local_dims)
1697+
for sub in sympy.preorder_traversal(index):
1698+
if isinstance(sub, FloorDiv):
1699+
if not str(sub.args[0]).startswith("index"):
1700+
continue
1701+
dim_idx = int((str(sub.args[0])[5:]))
1702+
if int(self.kernel_group.tile_desc.get_tile_size()[dim_idx] % sub.args[1]) != 0:
1703+
# In this case, need to recompile
1704+
original_size = self.kernel_group.tile_desc.get_tile_size()[dim_idx]
1705+
divisor = sub.args[1]
1706+
new_size = ((original_size + divisor - 1) // divisor) * divisor
1707+
new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size())
1708+
new_tile_sizes[dim_idx] = new_size
1709+
self.kernel_group.tile_desc.set_tile_size(new_tile_sizes)
1710+
1711+
# Send recompile signal
1712+
self.reset("recompile")
1713+
raise mlir_common.RecompileSignal(f"Tile size {self.kernel_group.tile_desc.get_tile_size()[dim_idx]} is not divisible by {sub.args[1]}")
1714+
dim_divisor[dim_idx] = sub.args[1]
1715+
1716+
# Update dram_stride, just insert 0 next to target dim
1717+
offset = 0
1718+
for dim_idx, divisor in enumerate(dim_divisor):
1719+
if divisor == 1:
1720+
continue
1721+
dram_stride.insert(dim_idx+offset+1, 0)
1722+
local_tile_desc.apply_divisor(dim_idx+offset, divisor, "pad")
1723+
local_tile_desc.apply_divisor(dim_idx+offset, divisor, "split")
1724+
offset = offset+1
1725+
16831726
# FIXME. It will be nice to modify node instead of this exception handling...
16841727
if len(self.itervars) == 1 and self.reduction_depth == 0:
16851728
# In case of reduction loop only case, we will add dummy loop so shift it once
@@ -1810,7 +1853,11 @@ def convert_indirect_indexing(self, index :sympy.Expr):
18101853

18111854
# Load indirect operands
18121855
for target_dim in indirect_dims:
1813-
sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim]
1856+
if target_dim in self.spad_buffer_dict:
1857+
sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim]
1858+
else:
1859+
raise NotImplementedError("TODO.")
1860+
18141861
mlir_dtype = vshape.split("x")[1][:-1]
18151862
vshape = f"vector<{tile_numel_per_lane}x{mlir_dtype}>" # FIXME. Maybe require fine grain compute...
18161863
if tile_numel_per_lane > 1:

PyTorchSimFrontend/mlir/mlir_common.py

Lines changed: 92 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,19 @@
6161
torch.bfloat16: "bfloat16",
6262
}
6363

64+
MLIR_TO_BIT = {
65+
"i1": 1,
66+
"i8": 8,
67+
"i16": 16,
68+
"i32": 32,
69+
"i64": 64,
70+
"f16": 16,
71+
"f32": 32,
72+
"f64": 64,
73+
"bf16": 16,
74+
"index": 64
75+
}
76+
6477
DTYPE_LOWP_FP = [
6578
torch.bfloat16,
6679
torch.float16,
@@ -105,6 +118,14 @@ def ctx():
105118

106119
return ctx()
107120

121+
class RecompileSignal(BaseException):
122+
"""
123+
Exception raised when a recompilation of a kernel or code block is required.
124+
"""
125+
def __init__(self, message="Recompilation requested."):
126+
self.message = message
127+
super().__init__(self.message)
128+
108129
class MLIRKernelArgs(common.KernelArgs):
109130
MLIR_ARGS_IN = 0x01
110131
MLIR_ARGS_OUT = 0x02
@@ -310,7 +331,7 @@ def get_compute_vec_size(self):
310331
if self.vec_size is not None:
311332
return self.vec_size
312333
if self.nr_rdim:
313-
assert self.nr_rdim==1
334+
assert self.nr_rdim!=0
314335
val = self.get_numel_per_lane() // self._tile_size[-1]
315336
if self.get_numel_per_lane() >= val * 8:
316337
return val*8
@@ -331,6 +352,44 @@ def get_compute_vec_size(self):
331352
def div_round_up(size, round_val):
332353
return (size + round_val - 1) // round_val
333354

355+
def apply_divisor(self, axis: int, divisor: int, mode: str = "split"):
356+
# Apply divisor to tile size at given axis.
357+
# This method based on axis order.
358+
old_size = self._tile_size[axis]
359+
if divisor == 1:
360+
return
361+
padded = self.div_round_up(old_size, divisor) * divisor
362+
outer = self.div_round_up(old_size, divisor)
363+
inner = divisor
364+
if mode == "pad":
365+
self._tile_size[axis] = padded
366+
self.update_tile_stride()
367+
return
368+
elif mode == "split":
369+
new_sizes = list(self._tile_size)
370+
new_sizes[axis] = outer
371+
new_sizes.insert(axis + 1, inner)
372+
self._tile_size = new_sizes
373+
374+
# Update tile_axis_order
375+
old_order_val = self.tile_axis_order[axis]
376+
new_order = list(self.tile_axis_order)
377+
new_order.insert(axis + 1, old_order_val + 0.1)
378+
sorted_pairs = sorted(
379+
zip(range(len(new_order)), new_order),
380+
key=lambda x: x[1]
381+
)
382+
self.tile_axis_order = [idx for idx, _ in sorted_pairs]
383+
self.update_tile_stride()
384+
385+
if self.vlane_split_axis == axis:
386+
self.vlane_split_axis = axis
387+
elif self.vlane_split_axis > axis:
388+
self.vlane_split_axis += 1
389+
return
390+
else:
391+
raise ValueError(f"Unknown mode: {mode}. Supported modes are 'pad' and 'split'.")
392+
334393
class MLIRWrapperKenrelGroup(cpp.KernelGroup):
335394
def __init__(self):
336395
super().__init__()
@@ -538,6 +597,8 @@ def dummy_tile_size():
538597
dim = int(self.recodegen.split("_")[-1])
539598
tile_size = self.kernel_group.tile_desc.get_tile_size() # TODO:
540599
tile_size[dim] = tile_size[dim] * 2
600+
elif self.recodegen == "recompile":
601+
return self.kernel_group.tile_desc
541602
else:
542603
raise NotImplementedError(f"Unknown recodegen reason: {self.recodegen}")
543604

@@ -608,26 +669,36 @@ def dummy_tile_size():
608669
return tile_desc
609670

610671
def codegen_nodes(self, nodes, kernel_name):
611-
_, (group, reduction_group) = max(
612-
nodes, key=lambda x: int(x.is_reduction())
613-
).group
614-
615-
# Set node range info
616-
vars, reduction_vars = self.set_ranges(group, reduction_group)
617-
tile_desc = self.compute_tile_size(nodes, vars, reduction_vars)
618-
self.compute_body_loop.size = tile_desc.get_numel_per_lane()
619-
self.compute_body_loop.step = tile_desc.get_compute_vec_size()
620-
self.kernel_group.set_tile_info(tile_desc)
621-
622-
_, _, _, self.buffer_types = self.kernel_group.args.mlir_argdefs()
623-
with self as kernel:
624-
for node in nodes:
625-
node.run(vars, reduction_vars)
626-
V.graph.removed_buffers |= self.removed_buffers
627-
# V.graph.inplaced_to_remove |= self.inplaced_to_remove
628-
src_code = self.codegen_kernel(kernel_name=kernel_name)
629-
self.meta_kernel()
630-
return src_code
672+
recompile_try = 0
673+
max_retry_compile = 5
674+
while True:
675+
_, (group, reduction_group) = max(
676+
nodes, key=lambda x: int(x.is_reduction())
677+
).group
678+
679+
# Set node range info
680+
vars, reduction_vars = self.set_ranges(group, reduction_group)
681+
tile_desc = self.compute_tile_size(nodes, vars, reduction_vars)
682+
self.compute_body_loop.size = tile_desc.get_numel_per_lane()
683+
self.compute_body_loop.step = tile_desc.get_compute_vec_size()
684+
self.kernel_group.set_tile_info(tile_desc)
685+
try:
686+
_, _, _, self.buffer_types = self.kernel_group.args.mlir_argdefs()
687+
with self as kernel:
688+
for node in nodes:
689+
node.run(vars, reduction_vars)
690+
except RecompileSignal as e:
691+
recompile_try += 1
692+
if recompile_try > max_retry_compile:
693+
raise RuntimeError("Failed to compile kernel after multiple attempts.")
694+
# Retry compile nodes
695+
#print(f"Try recompile({recompile_try}/{max_retry_compile}). Reason: {e}")
696+
continue
697+
V.graph.removed_buffers |= self.removed_buffers
698+
# V.graph.inplaced_to_remove |= self.inplaced_to_remove
699+
src_code = self.codegen_kernel(kernel_name=kernel_name)
700+
self.meta_kernel()
701+
return src_code
631702

632703
def run_bench(self, nodes, kernel_name, src_code):
633704
_, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs()

PyTorchSimFrontend/mlir/mlir_scheduling.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,15 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule
4646
if (isinstance(base_template_node1[0].node.template, MLIRGemmTemplate) or isinstance(base_template_node1[0].node.template, MLIRBMMTemplate)) and node2.is_reduction():
4747
# For matmul/bmm+reduction case
4848
size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1)
49-
stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1]
5049
target_symbol = symbols("r0")
50+
try:
51+
stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1]
52+
stride = int(sympify(stride).coeff(target_symbol))
53+
except sympy.core.SympifyError:
54+
return False
55+
5156
# We can't fuse dim=-1
52-
layout_possible = int(sympify(stride).coeff(target_symbol)) != 1
57+
layout_possible = stride != 1
5358
# Directed linked?
5459
dependency_check = node2.get_nodes()[0] in [node.node for node in base_template_node1[0].users]# and len(node2.read_writes.reads)==1
5560
dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads])

0 commit comments

Comments
 (0)