Skip to content

Commit ea79ad0

Browse files
committed
[Fusion] Fix template codegen + Add custom fusion hook
1 parent 8df3bee commit ea79ad0

File tree

2 files changed

+59
-15
lines changed

2 files changed

+59
-15
lines changed

PyTorchSimFrontend/mlir/mlir_scheduling.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,48 @@ class MLIRScheduling(BaseScheduling):
2525
target_kernel = MLIRKernel
2626
def __init__(self, scheduler):
2727
self.scheduler = scheduler
28-
#self.scheduler.enter_context = self.enter_context_fixed # FIXME. Monkey patch: For fixing the inductor bug
28+
if scheduler is not None:
29+
self.scheduler.can_fuse_origin = self.scheduler.can_fuse
30+
self.scheduler.can_fuse = self.can_fuse_with_exceptions # FIXME. Monkey patch: For prologue fusion
2931
self.kernel_group = mlir_common.MLIRWrapperKenrelGroup()
3032
self._ready_to_flush = False
3133
self.outer_function = set()
3234
config.inplace_buffers = False # FIXME. inout kernels cause trouble, so this is disabled!
3335
self.max_fusion_size = 5
3436

37+
def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool:
38+
if not extension_config.CONFIG_FUSION:
39+
return False
40+
41+
# Extract base template node
42+
base_template_node1 = [node for node in node1.get_nodes() if node.is_template()]
43+
base_template_node2 = [node for node in node2.get_nodes() if node.is_template()]
44+
45+
# Case 3: Prologue(Pointwise) + Template
46+
if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE:
47+
from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate
48+
from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate
49+
50+
target_node = base_template_node2[0].node
51+
# Currently only BMM, MM support prologue fusion
52+
if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)):
53+
return False
54+
55+
if len(node1.read_writes.writes) != 1:
56+
return False
57+
if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME
58+
return False
59+
60+
# We don't fuse this edge case...
61+
if base_template_node2[0].group[1][0][0] == 1:
62+
return False
63+
64+
if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]:
65+
node1 = self.revert_group(node1)
66+
return True
67+
return self.scheduler.can_fuse_origin(node1, node2)
68+
69+
3570
def _set_flush_status(self, status: bool):
3671
self._ready_to_flush = status
3772

@@ -45,6 +80,9 @@ def get_backend_features(self, device):
4580
def can_fuse_vertical(self, node1, node2):
4681
return self.can_fuse_horizontal(node1, node2)
4782

83+
def can_fuse_multi_outputs_template(self, node1, node2):
84+
return self.can_fuse_horizontal(node1, node2)
85+
4886
def can_fuse_horizontal(self, node1, node2):
4987
if not extension_config.CONFIG_FUSION:
5088
return False
@@ -88,7 +126,7 @@ def can_fuse_horizontal(self, node1, node2):
88126
return same_iter and no_dependency
89127

90128
# Case 1: Template + Pointwise fusion
91-
if len(base_template_node1) == 1 and len(base_template_node2) == 0 and not node2.is_reduction():
129+
if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(base_template_node2) == 0 and not node2.is_reduction():
92130
# Don't fuse maxpool template code
93131
from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate
94132
from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate
@@ -132,9 +170,10 @@ def can_fuse_horizontal(self, node1, node2):
132170
return True
133171

134172
# Case 2: Template + Reduction fusion
135-
if len(base_template_node1) == 1 and len(base_template_node2) == 0 and node2.is_reduction() and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE:
173+
if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(base_template_node2) == 0 and node2.is_reduction() and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE:
136174
from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate
137175
from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate
176+
target_node = base_template_node1[0].node
138177
if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)):
139178
return False
140179

@@ -149,7 +188,7 @@ def can_fuse_horizontal(self, node1, node2):
149188
# We can't fuse dim=-1
150189
layout_possible = stride != 1
151190
# Directed linked?
152-
dependency_check = node2.get_nodes()[0] in [node.node for node in base_template_node1[0].users]# and len(node2.read_writes.reads)==1
191+
dependency_check = writes1 & reads2
153192
dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads])
154193
return size_match and layout_possible and dependency_check and dependency_size
155194

@@ -177,8 +216,8 @@ def can_fuse_horizontal(self, node1, node2):
177216
return True
178217

179218
# Check elementwise fusion
180-
if vars1 == vars2 and reduce1 == reduce2:
181-
return True
219+
if vars1 == vars2 and reduce1 == reduce2 and not node1.is_reduction() and not node2.is_reduction():
220+
return writes1 & reads2
182221
return False
183222

184223
def revert_group(self, act_nodes, args=None, var_ranges=None):

PyTorchSimFrontend/mlir/mlir_template.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -573,8 +573,8 @@ def template_store():
573573
with contextlib.ExitStack() as stack:
574574
stack.enter_context(compute_body.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line()))
575575
if self.reduction_fusion:
576-
compute_body.writelines(self.reduction_body_loop.lines())
577576
compute_body.splice(self.masks)
577+
compute_body.writelines(self.reduction_body_loop.lines())
578578
stack.enter_context(compute_body.indent(attribute="{inner_loop=false}"))
579579
compute_body.splice(self.loads)
580580
compute_body.splice(self.compute)
@@ -848,7 +848,6 @@ def get_spad_size_per_lane(self, tile_m, tile_n):
848848
return max(size, 2) # vector load/store
849849

850850
def load_epilogue(self, name: str, index: sympy.Expr):
851-
index = self.rename_indexing(index)
852851
dram_var = self.kernel_group.args.input(name)
853852
dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name])
854853
dtype = V.graph.get_dtype(name)
@@ -898,7 +897,6 @@ def load_epilogue(self, name: str, index: sympy.Expr):
898897
return out
899898

900899
def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs):
901-
index = self.rename_indexing(index)
902900
dram_var = self.kernel_group.args.output(name)
903901
dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name])
904902
dtype = V.graph.get_dtype(name)
@@ -1000,7 +998,6 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value):
1000998
return sram_var
1001999

10021000
def store_reduction_epilogue(self, name, index, value):
1003-
index = self.rename_indexing(index)
10041001
dram_var = self.kernel_group.args.output(name)
10051002
dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name])
10061003
dtype = V.graph.get_dtype(name)
@@ -1119,11 +1116,19 @@ def set_tile_size(self, template_fusion_info, prologue=False):
11191116
return tile_desc
11201117

11211118
def rename_indexing(self, index) -> sympy.Expr:
1122-
for dim_name, dim_aliased_name in self.dim_aliasing.items():
1123-
index = index.subs(sympy.Symbol(dim_name), sympy.Symbol("tmp_"+dim_aliased_name))
1124-
# To avoid this case ({"index0":"index1", "index1":"index0"})
1125-
for dim_aliased_name in self.dim_aliasing.values():
1126-
index = index.subs(sympy.Symbol("tmp_"+dim_aliased_name), sympy.Symbol(dim_aliased_name))
1119+
# First step: replace dim_name with tmp_+dim_aliased_name to avoid circular dependencies
1120+
# (e.g., {"index0":"index1", "index1":"index0"})
1121+
tmp_subs = {
1122+
sympy.Symbol(dim_name): sympy.Symbol("tmp_"+dim_aliased_name)
1123+
for dim_name, dim_aliased_name in self.dim_aliasing.items()
1124+
}
1125+
index = index.subs(tmp_subs)
1126+
# Second step: replace tmp_+dim_aliased_name with dim_aliased_name
1127+
final_subs = {
1128+
sympy.Symbol("tmp_"+dim_aliased_name): sympy.Symbol(dim_aliased_name)
1129+
for dim_aliased_name in self.dim_aliasing.values()
1130+
}
1131+
index = index.subs(final_subs)
11271132
return index
11281133

11291134
class MLIRTemplateCaller(CUDATemplateCaller):

0 commit comments

Comments
 (0)