@@ -35,15 +35,15 @@ def __init__(self, scheduler):
3535 self .max_fusion_size = 5
3636
3737 def can_fuse_with_exceptions (self , node1 : BaseSchedulerNode , node2 : BaseSchedulerNode ) -> bool :
38- if not extension_config .CONFIG_FUSION :
39- return False
38+ if not extension_config .CONFIG_FUSION_PROLOGUE :
39+ return self . scheduler . can_fuse_origin ( node1 , node2 )
4040
4141 # Extract base template node
4242 base_template_node1 = [node for node in node1 .get_nodes () if node .is_template ()]
4343 base_template_node2 = [node for node in node2 .get_nodes () if node .is_template ()]
4444
4545 # Case 3: Prologue(Pointwise) + Template
46- if len (base_template_node1 ) == 0 and len (node1 .get_nodes ())== 1 and not node1 .is_reduction () and len (base_template_node2 ) == 1 and extension_config .CONFIG_FUSION_PROLOGUE :
46+ if len (base_template_node1 ) == 0 and len (node1 .get_nodes ())== 1 and len ( node2 . get_nodes ()) == 1 and not node1 .is_reduction () and len (base_template_node2 ) == 1 and extension_config .CONFIG_FUSION_PROLOGUE :
4747 from PyTorchSimFrontend .mlir .mlir_gemm_template import MLIRGemmTemplate
4848 from PyTorchSimFrontend .mlir .mlir_bmm_template import MLIRBMMTemplate
4949
@@ -126,7 +126,7 @@ def can_fuse_horizontal(self, node1, node2):
126126 return same_iter and no_dependency
127127
128128 # Case 1: Template + Pointwise fusion
129- if len (base_template_node1 ) == 1 and len (node1 .get_nodes ())== 1 and len (base_template_node2 ) == 0 and not node2 .is_reduction ():
129+ if len (base_template_node1 ) == 1 and len (node1 .get_nodes ())== 1 and len (node2 . get_nodes ()) == 1 and len ( base_template_node2 ) == 0 and not node2 .is_reduction ():
130130 # Don't fuse maxpool template code
131131 from PyTorchSimFrontend .mlir .mlir_maxpool_template import MLIRMaxPoolTemplate
132132 from PyTorchSimFrontend .mlir .mlir_bmm_template import MLIRBMMTemplate
@@ -170,7 +170,7 @@ def can_fuse_horizontal(self, node1, node2):
170170 return True
171171
172172 # Case 2: Template + Reduction fusion
173- if len (base_template_node1 ) == 1 and len (node1 .get_nodes ())== 1 and len (base_template_node2 ) == 0 and node2 .is_reduction () and extension_config .CONFIG_FUSION_REDUCTION_EPILOGUE :
173+ if len (base_template_node1 ) == 1 and len (node1 .get_nodes ())== 1 and len (node2 . get_nodes ()) == 1 and len ( base_template_node2 ) == 0 and node2 .is_reduction () and extension_config .CONFIG_FUSION_REDUCTION_EPILOGUE :
174174 from PyTorchSimFrontend .mlir .mlir_gemm_template import MLIRGemmTemplate
175175 from PyTorchSimFrontend .mlir .mlir_bmm_template import MLIRBMMTemplate
176176 target_node = base_template_node1 [0 ].node
@@ -185,39 +185,35 @@ def can_fuse_horizontal(self, node1, node2):
185185 except :
186186 return False
187187
188- # We can't fuse dim=-1
189- layout_possible = stride != 1
188+ # We can't fuse dim=-1 & N == 1
189+ layout_possible = stride != 1 and ( 1 not in node1 . node . get_size ())
190190 # Directed linked?
191191 dependency_check = writes1 & reads2
192192 dependency_size = all ([i .get_numel () == node1 .get_nodes ()[0 ].node .get_numel () for i in node2 .read_writes .reads ])
193193 return size_match and layout_possible and dependency_check and dependency_size
194194
195195 # Case 3: Prologue(Pointwise) + Template
196- if len (base_template_node1 ) == 0 and len (node1 .get_nodes ())== 1 and not node1 .is_reduction () and len (base_template_node2 ) == 1 and extension_config .CONFIG_FUSION_PROLOGUE :
197- from PyTorchSimFrontend .mlir .mlir_gemm_template import MLIRGemmTemplate
198- from PyTorchSimFrontend .mlir .mlir_bmm_template import MLIRBMMTemplate
199-
200- target_node = base_template_node2 [0 ].node
201- # Currently only BMM, MM support prologue fusion
202- if not isinstance (target_node .template , (MLIRBMMTemplate , MLIRGemmTemplate )):
203- return False
204-
205- if len (node1 .read_writes .writes ) != 1 :
206- return False
207- if node1 .node not in target_node .inputs or any (["view" in str (ori ) for ori in node1 .node .origins ]): #FIXME
208- return False
209-
210- # We don't fuse this edge case...
211- if base_template_node2 [0 ].group [1 ][0 ][0 ] == 1 :
212- return False
213-
214- if list (node1 .read_writes .writes )[0 ].name in [dep .name for dep in node2 .read_writes .reads ]:
215- node1 = self .revert_group (node1 )
216- return True
217-
218- # Check elementwise fusion
219- if vars1 == vars2 and reduce1 == reduce2 and not node1 .is_reduction () and not node2 .is_reduction ():
220- return writes1 & reads2
196+ # if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE:
197+ # from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate
198+ # from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate
199+
200+ # target_node = base_template_node2[0].node
201+ # # Currently only BMM, MM support prologue fusion
202+ # if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)):
203+ # return False
204+
205+ # if len(node1.read_writes.writes) != 1:
206+ # return False
207+ # if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME
208+ # return False
209+
210+ # # We don't fuse this edge case...
211+ # if base_template_node2[0].group[1][0][0] == 1:
212+ # return False
213+
214+ # if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]:
215+ # node1 = self.revert_group(node1)
216+ # return True
221217 return False
222218
223219 def revert_group (self , act_nodes , args = None , var_ranges = None ):
0 commit comments