Skip to content

Commit fc2aebf

Browse files
committed
[Template] Fix template fusion codegen
1 parent ea79ad0 commit fc2aebf

3 files changed

Lines changed: 37 additions & 36 deletions

File tree

PyTorchSimFrontend/mlir/mlir_gemm_template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def render(self,
154154
W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride)
155155
W_tile_desc.set_name("W_buffer")
156156
W_tile_desc.offset = W.get_layout().offset
157-
W_stride = W.get_layout().stride
157+
W_stride = W.get_layout().stride if N>1 else [Y.get_layout().stride[0], 0]
158158
W_idx = [sympy.Symbol("index2") * W_stride[0], sympy.Symbol("index1") * W_stride[1]]
159159

160160
vlane_split_axis = vlane_split_axis if nr_rdim==0 else 0
@@ -163,7 +163,7 @@ def render(self,
163163
Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride)
164164
Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride)
165165
Y_tile_desc.set_name("Y_buffer")
166-
Y_stride = Y.get_layout().stride
166+
Y_stride = Y.get_layout().stride if N>1 else [Y.get_layout().stride[0], 0]
167167
if nr_rdim == 0:
168168
Y_idx = [sympy.Symbol("index0") * Y_stride[0], sympy.Symbol("index1") * Y_stride[1]]
169169
else:

PyTorchSimFrontend/mlir/mlir_scheduling.py

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ def __init__(self, scheduler):
3535
self.max_fusion_size = 5
3636

3737
def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool:
38-
if not extension_config.CONFIG_FUSION:
39-
return False
38+
if not extension_config.CONFIG_FUSION_PROLOGUE:
39+
return self.scheduler.can_fuse_origin(node1, node2)
4040

4141
# Extract base template node
4242
base_template_node1 = [node for node in node1.get_nodes() if node.is_template()]
4343
base_template_node2 = [node for node in node2.get_nodes() if node.is_template()]
4444

4545
# Case 3: Prologue(Pointwise) + Template
46-
if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE:
46+
if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE:
4747
from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate
4848
from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate
4949

@@ -126,7 +126,7 @@ def can_fuse_horizontal(self, node1, node2):
126126
return same_iter and no_dependency
127127

128128
# Case 1: Template + Pointwise fusion
129-
if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(base_template_node2) == 0 and not node2.is_reduction():
129+
if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and not node2.is_reduction():
130130
# Don't fuse maxpool template code
131131
from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate
132132
from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate
@@ -170,7 +170,7 @@ def can_fuse_horizontal(self, node1, node2):
170170
return True
171171

172172
# Case 2: Template + Reduction fusion
173-
if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(base_template_node2) == 0 and node2.is_reduction() and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE:
173+
if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and node2.is_reduction() and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE:
174174
from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate
175175
from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate
176176
target_node = base_template_node1[0].node
@@ -185,39 +185,35 @@ def can_fuse_horizontal(self, node1, node2):
185185
except:
186186
return False
187187

188-
# We can't fuse dim=-1
189-
layout_possible = stride != 1
188+
# We can't fuse dim=-1 & N == 1
189+
layout_possible = stride != 1 and (1 not in node1.node.get_size())
190190
# Directly linked?
191191
dependency_check = writes1 & reads2
192192
dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads])
193193
return size_match and layout_possible and dependency_check and dependency_size
194194

195195
# Case 3: Prologue(Pointwise) + Template
196-
if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE:
197-
from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate
198-
from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate
199-
200-
target_node = base_template_node2[0].node
201-
# Currently only BMM, MM support prologue fusion
202-
if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)):
203-
return False
204-
205-
if len(node1.read_writes.writes) != 1:
206-
return False
207-
if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME
208-
return False
209-
210-
# We don't fuse this edge case...
211-
if base_template_node2[0].group[1][0][0] == 1:
212-
return False
213-
214-
if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]:
215-
node1 = self.revert_group(node1)
216-
return True
217-
218-
# Check elementwise fusion
219-
if vars1 == vars2 and reduce1 == reduce2 and not node1.is_reduction() and not node2.is_reduction():
220-
return writes1 & reads2
196+
# if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE:
197+
# from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate
198+
# from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate
199+
200+
# target_node = base_template_node2[0].node
201+
# # Currently only BMM, MM support prologue fusion
202+
# if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)):
203+
# return False
204+
205+
# if len(node1.read_writes.writes) != 1:
206+
# return False
207+
# if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME
208+
# return False
209+
210+
# # We don't fuse this edge case...
211+
# if base_template_node2[0].group[1][0][0] == 1:
212+
# return False
213+
214+
# if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]:
215+
# node1 = self.revert_group(node1)
216+
# return True
221217
return False
222218

223219
def revert_group(self, act_nodes, args=None, var_ranges=None):

PyTorchSimFrontend/mlir/mlir_template.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,6 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value):
981981
compute_index_var = ", ".join(zero_var_list)
982982
with self.override_buffer_cse(buffer=self.loads):
983983
out = ops._load(vec_size, type_name, sram_var, compute_index_var, tile_shape)
984-
985984
# Reduction body codegen
986985
with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse):
987986
init = ops.constant(reduction_init(reduction_type, dtype), type_name)
@@ -990,6 +989,12 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value):
990989
mask_shape, mask_var = self.get_mask()
991990
if mask_var is not None:
992991
value = ops.where(mask_var, value, init_vec)
992+
993+
with self.override_buffer_cse(buffer=self.masks, cse=self.mask_cse):
994+
not_first_idx = ops.ne(self.compute_idx, ops.constant(0, "index"))
995+
not_first_idx = ops.broadcast(not_first_idx, compute_vec_size)
996+
out = ops.where(not_first_idx, out, init_vec)
997+
993998
result = reduction_partial_combine_vec(reduction_type, value, out)
994999

9951000
# Store partial result
@@ -1100,7 +1105,7 @@ def set_tile_size(self, template_fusion_info, prologue=False):
11001105
self.r_tile_size = tile_desc.get_tile_size()[-1]
11011106
self.r_dim_size = template_fusion_info['r_dim_size']
11021107
self.reduction_nr_outer_loop = nr_outer_loop
1103-
self.reduction_loop_idx = "reduce_loop_idx"
1108+
self.reduction_loop_idx = self.register_var_cse("reduce_loop_idx", 1, "index")
11041109
self.compute_body_loop.size = r_tile_size
11051110
self.compute_body_loop.step = tile_desc.get_compute_vec_size() // nr_outer_loop
11061111
self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop)

0 commit comments

Comments
 (0)