@@ -25,13 +25,48 @@ class MLIRScheduling(BaseScheduling):
2525 target_kernel = MLIRKernel
def __init__(self, scheduler):
    """Initialise the MLIR scheduling backend and hook the fusion check.

    Args:
        scheduler: The scheduler this backend is attached to, or ``None``.
            When it is not ``None``, its ``can_fuse`` attribute is replaced
            with :meth:`can_fuse_with_exceptions`, and the callable that was
            replaced is kept on the scheduler as ``can_fuse_origin``.

    Side effects:
        Sets ``config.inplace_buffers`` to ``False``.
    """
    self.scheduler = scheduler
    if scheduler is not None:
        # FIXME. Monkey patch: For prologue fusion.
        # Capture the original ``can_fuse`` only once. If a second backend is
        # built on the same scheduler, re-capturing would store the already
        # patched method as ``can_fuse_origin`` and the fallback call in
        # ``can_fuse_with_exceptions`` would then recurse without end.
        if not hasattr(scheduler, "can_fuse_origin"):
            scheduler.can_fuse_origin = scheduler.can_fuse
        scheduler.can_fuse = self.can_fuse_with_exceptions
    self.kernel_group = mlir_common.MLIRWrapperKenrelGroup()
    self._ready_to_flush = False
    self.outer_function = set()
    config.inplace_buffers = False  # FIXME. inout kernel makes trouble.. So disabled it!
    self.max_fusion_size = 5
3436
def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool:
    """Fusion predicate that adds a prologue-fusion case before the default check.

    Returns ``False`` outright when ``extension_config.CONFIG_FUSION`` is off.
    When ``node1`` is a single non-template, non-reduction node, ``node2``
    contains exactly one template node and
    ``extension_config.CONFIG_FUSION_PROLOGUE`` is on, the pair is treated as
    a prologue + template candidate and checked here. In every other
    situation, and when a candidate is neither rejected nor accepted, the
    decision is delegated to ``self.scheduler.can_fuse_origin``.

    Args:
        node1: Candidate producer (prologue) node.
        node2: Candidate consumer node.

    Returns:
        Whether the two nodes may be fused.
    """
    if not extension_config.CONFIG_FUSION:
        return False

    # Template members of each candidate.
    templates_in_node1 = [member for member in node1.get_nodes() if member.is_template()]
    templates_in_node2 = [member for member in node2.get_nodes() if member.is_template()]

    # Case 3: Prologue (pointwise) + Template
    is_prologue_candidate = (
        len(templates_in_node1) == 0
        and len(node1.get_nodes()) == 1
        and not node1.is_reduction()
        and len(templates_in_node2) == 1
        and extension_config.CONFIG_FUSION_PROLOGUE
    )
    if is_prologue_candidate:
        from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate
        from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate

        template_entry = templates_in_node2[0]
        target_node = template_entry.node

        # Currently only BMM, MM support prologue fusion
        if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)):
            return False

        # The prologue must write exactly one buffer.
        if len(node1.read_writes.writes) != 1:
            return False

        # The prologue's IR node must be a direct input of the template.
        if node1.node not in target_node.inputs:
            return False
        # FIXME: rejects any prologue whose origins mention "view".
        if any("view" in str(origin) for origin in node1.node.origins):
            return False

        # We don't fuse this edge case...
        if template_entry.group[1][0][0] == 1:
            return False

        written_name = list(node1.read_writes.writes)[0].name
        read_names = [dep.name for dep in node2.read_writes.reads]
        if written_name in read_names:
            # NOTE(review): the original rebound the result to the local
            # ``node1``, which was never read again; the call is kept for
            # whatever effect it has on the node - confirm that is intended.
            self.revert_group(node1)
            return True

    return self.scheduler.can_fuse_origin(node1, node2)
68+
69+
def _set_flush_status(self, status: bool):
    """Store ``status`` as this backend's ready-to-flush flag."""
    self._ready_to_flush = status
3772
@@ -45,6 +80,9 @@ def get_backend_features(self, device):
def can_fuse_vertical(self, node1, node2):
    """Decide vertical fusion by applying the same rules as horizontal fusion.

    Returns whatever ``self.can_fuse_horizontal(node1, node2)`` returns.
    """
    verdict = self.can_fuse_horizontal(node1, node2)
    return verdict
4782
def can_fuse_multi_outputs_template(self, node1, node2):
    """Decide multi-output template fusion with the horizontal-fusion rules.

    Returns whatever ``self.can_fuse_horizontal(node1, node2)`` returns.
    """
    verdict = self.can_fuse_horizontal(node1, node2)
    return verdict
85+
4886 def can_fuse_horizontal (self , node1 , node2 ):
4987 if not extension_config .CONFIG_FUSION :
5088 return False
@@ -88,7 +126,7 @@ def can_fuse_horizontal(self, node1, node2):
88126 return same_iter and no_dependency
89127
90128 # Case 1: Template + Pointwise fusion
91- if len (base_template_node1 ) == 1 and len (base_template_node2 ) == 0 and not node2 .is_reduction ():
129+ if len (base_template_node1 ) == 1 and len (node1 . get_nodes ()) == 1 and len ( base_template_node2 ) == 0 and not node2 .is_reduction ():
92130 # Don't fuse maxpool template code
93131 from PyTorchSimFrontend .mlir .mlir_maxpool_template import MLIRMaxPoolTemplate
94132 from PyTorchSimFrontend .mlir .mlir_bmm_template import MLIRBMMTemplate
@@ -132,9 +170,10 @@ def can_fuse_horizontal(self, node1, node2):
132170 return True
133171
134172 # Case 2: Template + Reduction fusion
135- if len (base_template_node1 ) == 1 and len (base_template_node2 ) == 0 and node2 .is_reduction () and extension_config .CONFIG_FUSION_REDUCTION_EPILOGUE :
173+ if len (base_template_node1 ) == 1 and len (node1 . get_nodes ()) == 1 and len ( base_template_node2 ) == 0 and node2 .is_reduction () and extension_config .CONFIG_FUSION_REDUCTION_EPILOGUE :
136174 from PyTorchSimFrontend .mlir .mlir_gemm_template import MLIRGemmTemplate
137175 from PyTorchSimFrontend .mlir .mlir_bmm_template import MLIRBMMTemplate
176+ target_node = base_template_node1 [0 ].node
138177 if not isinstance (target_node .template , (MLIRBMMTemplate , MLIRGemmTemplate )):
139178 return False
140179
@@ -149,7 +188,7 @@ def can_fuse_horizontal(self, node1, node2):
149188 # We can't fuse dim=-1
150189 layout_possible = stride != 1
151190 # Directed linked?
152- dependency_check = node2 . get_nodes ()[ 0 ] in [ node . node for node in base_template_node1 [ 0 ]. users ] # and len(node2.read_writes.reads)==1
191+ dependency_check = writes1 & reads2
153192 dependency_size = all ([i .get_numel () == node1 .get_nodes ()[0 ].node .get_numel () for i in node2 .read_writes .reads ])
154193 return size_match and layout_possible and dependency_check and dependency_size
155194
@@ -177,8 +216,8 @@ def can_fuse_horizontal(self, node1, node2):
177216 return True
178217
179218 # Check elementwise fusion
180- if vars1 == vars2 and reduce1 == reduce2 :
181- return True
219+ if vars1 == vars2 and reduce1 == reduce2 and not node1 . is_reduction () and not node2 . is_reduction () :
220+ return writes1 & reads2
182221 return False
183222
184223 def revert_group (self , act_nodes , args = None , var_ranges = None ):
0 commit comments