 from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate
 from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel
 from torch._inductor.ir import IRNode
-from torch._inductor.codecache import write_atomic
-import PyTorchSimFrontend.extension_codecache as extension_codecache
 from PyTorchSimFrontend.mlir import mlir_common
 
 BMM_TEMPLATE = r"""
@@ -162,51 +160,31 @@ def render(self,
                template_buffer_node=None,
                epilogue_nodes: Optional[List[IRNode]] = None,
                prologue_nodes: Optional[List[IRNode]] = None,
+               tile_info=None,
                **kwargs):
-        if template_buffer_node is not None:
-            self.output_node = template_buffer_node
-
-        # Extract input arguments info
-        X, W = self.input_nodes[0], self.input_nodes[1]
-        Y = self.output_node
-        Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2]
-
-        W_tensor = empty_strided(W.layout.size, W.layout.stride)
-        X_tensor = empty_strided(X.layout.size, X.layout.stride)
-        if len(W_tensor.size()) > 3 or len(W_tensor.size()) == 2:
-            W_tensor = W_tensor.view([-1, W_tensor.shape[-2], W_tensor.shape[-1]])
-        if len(X_tensor.size()) > 3 or len(X_tensor.size()) == 2:
-            X_tensor = X_tensor.view([-1, X_tensor.shape[-2], X_tensor.shape[-1]])
-        B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2]
-
-        W_stride = W_tensor.stride()
-        X_stride = X_tensor.stride()
-
-        # Select tile size
-        n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0
-        TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node)
-        SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or prologue_nodes else kernel.vector_lane
-        SUB_TILE_N = TILE_N  # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane
-        SUB_TILE_K = TILE_K  # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane
+        X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes)
+        if tile_info is None:
+            TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node)[0]
+        else:
+            TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info
 
         TOG_latency = M if TILE_M > M else TILE_M
         kernel.loop_size = [TOG_latency, TILE_N, TILE_K]
-        TILE_K = TILE_K // 2 if prologue_nodes else TILE_K
 
         # Select template code
         nr_reduction_nodes = [node for node in epilogue_nodes if node.is_reduction()] if epilogue_nodes is not None else []
         if nr_reduction_nodes:
-            template = BMM_REDUCTION_TEMPLATE
-            epilogue_dim_aliasing = {"index0": "index0", "index1": "index2", "index2": "index1"}
-            nr_rdim = 1
+            template = BMM_REDUCTION_TEMPLATE
+            epilogue_dim_aliasing = {"index0": "index0", "index1": "index2", "index2": "index1"}
+            nr_rdim = 1
         elif prologue_nodes:
-            template = BMM_PROLOGUE_TEMPLATE
-            epilogue_dim_aliasing = {"index0": "index0", "index1": "index1", "index2": "index2"}
-            nr_rdim = 0
+            template = BMM_PROLOGUE_TEMPLATE
+            epilogue_dim_aliasing = {"index0": "index0", "index1": "index1", "index2": "index2"}
+            nr_rdim = 0
         else:
-            template = BMM_TEMPLATE
-            epilogue_dim_aliasing = {"index0": "index0", "index1": "index1", "index2": "index2"}
-            nr_rdim = 0
+            template = BMM_TEMPLATE
+            epilogue_dim_aliasing = {"index0": "index0", "index1": "index1", "index2": "index2"}
+            nr_rdim = 0
 
         # Prepare tile descriptors
         vlane_stride = 1
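Note: the new tile_info keyword lets a caller pin the tile configuration instead of taking the first candidate from select_tile. A minimal sketch of both call shapes, assuming hypothetical bmm_template, kernel, and buf objects in scope (only render and its keyword arguments come from this diff):

    # Explicit tile: (TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K)
    code = bmm_template.render(kernel, template_buffer_node=buf,
                               tile_info=(32, 32, 32, 16, 32, 32))

    # Without tile_info, render() falls back to select_tile(...)[0], the first candidate.
    code = bmm_template.render(kernel, template_buffer_node=buf)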
@@ -323,19 +301,53 @@ def render(self,
             dram_idx=Y_idx,
             dram_tile_desc=Y_tile_desc,
             nr_rdim=nr_rdim,
+            r_dim_size=M,
             dim_aliasing=epilogue_dim_aliasing
         )
         code = self._template_from_string(template).render(**kernel.render_options)
         kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], kernel.render_options["K"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]])
         return code
 
-    def codegen_header(self, code, extra_headers):
-        write_path = extension_codecache.get_write_path(code)
-        if not os.path.exists(write_path):
-            os.makedirs(write_path)
-        spike_write_path = os.path.join(write_path, "global_var.h")
-        gem5_write_path = os.path.join(write_path, "gem5_global_var.h")
-        if not os.path.exists(spike_write_path):
-            write_atomic(spike_write_path, extra_headers[0])
-        if not os.path.exists(gem5_write_path):
-            write_atomic(gem5_write_path, extra_headers[1])
+    def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes):
+        if template_buffer_node is not None:
+            self.output_node = template_buffer_node
+
+        # Extract input arguments info
+        X, W = self.input_nodes[0], self.input_nodes[1]
+        Y = self.output_node
+        Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2]
+
+        W_tensor = empty_strided(W.layout.size, W.layout.stride)
+        X_tensor = empty_strided(X.layout.size, X.layout.stride)
+        if len(W_tensor.size()) > 3 or len(W_tensor.size()) == 2:
+            W_tensor = W_tensor.view([-1, W_tensor.shape[-2], W_tensor.shape[-1]])
+        if len(X_tensor.size()) > 3 or len(X_tensor.size()) == 2:
+            X_tensor = X_tensor.view([-1, X_tensor.shape[-2], X_tensor.shape[-1]])
+        B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2]
+
+        W_stride = W_tensor.stride()
+        X_stride = X_tensor.stride()
+
+        # Count fused epilogue/prologue nodes; these feed tile selection
+        n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0
+        n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0
+        return X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node
+
+    def get_tile_candidates(self,
+                            kernel: MLIRTemplateKernel,
+                            template_buffer_node=None,
+                            epilogue_nodes: Optional[List[IRNode]] = None,
+                            prologue_nodes: Optional[List[IRNode]] = None,
+                            **kwargs):
+        X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes)
+        return self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node)
+
+    def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node):
+        tile_candidates = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node)
+        for idx, (TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates):
+            SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or n_prologue_node else kernel.vector_lane
+            SUB_TILE_N = TILE_N  # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane
+            SUB_TILE_K = TILE_K  # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane
+            TILE_K = TILE_K // 2 if n_prologue_node else TILE_K
+            tile_candidates[idx] = TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K
+        return tile_candidates
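Taken together, get_tile_candidates and the tile_info hook open the template up to external tile search. A minimal sketch of such a driver loop, assuming a hypothetical benchmark() that compiles and times the rendered kernel (bmm_template, buf, epilogues, and prologues are also illustrative; only the method names come from this diff):

    # Enumerate post-processed (TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K) tuples.
    candidates = bmm_template.get_tile_candidates(kernel,
                                                  template_buffer_node=buf,
                                                  epilogue_nodes=epilogues,
                                                  prologue_nodes=prologues)
    # Render once per candidate and keep the fastest tile.
    best_tile = min(candidates,
                    key=lambda tile: benchmark(bmm_template.render(
                        kernel,
                        template_buffer_node=buf,
                        epilogue_nodes=epilogues,
                        prologue_nodes=prologues,
                        tile_info=tile)))

Note that select_tile rewrites each candidate in place: when prologue nodes are present, TILE_K is halved and SUB_TILE_M is pinned to TILE_M, so callers receive the already post-processed six-tuples.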