22import sympy
33import re
44import os
5- import math
65from functools import reduce
76from operator import mul
87import torch
2928from .mlir_ops import ExtensionOverrides
3029from PyTorchSimFrontend .mlir .mlir_autotune import MLIRBenchmarkRequest
3130
31+ # Configure logger for mlir_codegen_backend module
32+ logger = extension_config .setup_logger ()
33+
3234def reduction_init (reduction_type , dtype ):
3335 if dtype in cpp .DTYPE_LOWP_FP :
3436 # Since load promotes all half-precision inputs to float, the initial
@@ -95,11 +97,14 @@ def write_header(self):
9597
9698 from torch import device, empty, empty_strided
9799 from { extension_codecache .__name__ } import CustomAsyncCompile
98- from PyTorchSimFrontend.extension_config import CONFIG_SRAM_BUFFER_PLAN, CONFIG_TOGSIM_EAGER_MODE
100+ from PyTorchSimFrontend.extension_config import CONFIG_SRAM_BUFFER_PLAN, CONFIG_TOGSIM_EAGER_MODE, setup_logger
99101 from Simulator.simulator import TOGSimulator
100102 from PyTorchSimFrontend.extension_op import sparse_mm_dummy_stonne_outer
101103 from torch._inductor.select_algorithm import extern_kernels
102104
105+ # Configure logger for generated wrapper code
106+ _logger = setup_logger("PyTorchSimFrontend.mlir.generated_wrapper")
107+
103108 aten = torch.ops.aten
104109 inductor_ops = torch.ops.inductor
105110 assert_size_stride = torch._C._dynamo.guards.assert_size_stride
@@ -108,7 +113,7 @@ def write_header(self):
108113 custom_async_compile = CustomAsyncCompile()
109114 async_compile = AsyncCompile()
110115 os.environ["TORCHSIM_LAST_COMPILED_MODULE"] = __file__
111- print(f \ ' Wrapper Codegen Path = {{__file__}}\ ' )
116+ _logger.info(f 'Wrapper Codegen Path = {{__file__}}')
112117 """
113118 )
114119 self .header .splice (
@@ -909,15 +914,14 @@ def make_choices(self, nodes, kernel_name):
909914
910915 # Try initial tile size
911916 self .reset (None )
912- src_code = super ().codegen_nodes (nodes , kernel_name )
917+ src_code , meta_code = super ().codegen_nodes (nodes , kernel_name )
913918 current_tile_sz = tuple (self .kernel_group .tile_desc .get_tile_size ())
914919 search_space .add (current_tile_sz )
915920
916- if extension_config .CONFIG_DEBUG_MODE :
917- print (f"[Auto-tune] Trying tile size: { list (current_tile_sz )} , vlane_stride: { self .kernel_group .tile_desc .vmap .vlane_stride } , split_axis: { self .kernel_group .tile_desc .vmap .vlane_split_axis } " )
921+ logger .debug (f"Auto-tune: Trying tile size: { list (current_tile_sz )} , vlane_stride: { self .kernel_group .tile_desc .vmap .vlane_stride } , split_axis: { self .kernel_group .tile_desc .vmap .vlane_split_axis } " )
918922 self ._prepare_simulator_headers (src_code )
919923 bench_runner = self .run_bench (nodes , kernel_name , src_code )
920- choices .append ((bench_runner , src_code , current_tile_sz , self .kernel_group .tile_desc .vmap .vlane_stride ))
924+ choices .append ((bench_runner , src_code , meta_code , current_tile_sz , self .kernel_group .tile_desc .vmap .vlane_stride ))
921925
922926 while prevent_infinite_loop < 10 and candidate_axes :
923927 for axis in list (candidate_axes ):
@@ -939,7 +943,7 @@ def make_choices(self, nodes, kernel_name):
939943 continue
940944
941945 self .reset (None )
942- src_code = super ().codegen_nodes (nodes , kernel_name )
946+ src_code , meta_code = super ().codegen_nodes (nodes , kernel_name )
943947 current_tile_sz = tuple (self .kernel_group .tile_desc .get_tile_size ())
944948
945949 # FIXME. How to intergrate this constraint to tile system?
@@ -956,11 +960,10 @@ def make_choices(self, nodes, kernel_name):
956960
957961 # Add this choice
958962 search_space .add (current_tile_sz )
959- if extension_config .CONFIG_DEBUG_MODE :
960- print (f"[Auto-tune] Trying tile size: { list (current_tile_sz )} , vlane_stride: { self .kernel_group .tile_desc .vmap .vlane_stride } , split_axis: { self .kernel_group .tile_desc .vmap .vlane_split_axis } " )
963+ logger .debug (f"Auto-tune: Trying tile size: { list (current_tile_sz )} , vlane_stride: { self .kernel_group .tile_desc .vmap .vlane_stride } , split_axis: { self .kernel_group .tile_desc .vmap .vlane_split_axis } " )
961964 self ._prepare_simulator_headers (src_code )
962965 bench_runner = self .run_bench (nodes , kernel_name , src_code )
963- choices .append ((bench_runner , src_code , self .kernel_group .tile_desc .get_tile_size (), self .kernel_group .tile_desc .vmap .vlane_stride ))
966+ choices .append ((bench_runner , src_code , meta_code , self .kernel_group .tile_desc .get_tile_size (), self .kernel_group .tile_desc .vmap .vlane_stride ))
964967 prevent_infinite_loop += 1
965968 self .kernel_group .tile_desc .prev_tail_threshold = prev_tail_threshold
966969 return choices
@@ -976,18 +979,20 @@ def get_cycle(choice):
976979 return float ("inf" )
977980 return float ("inf" ) # Exceeded maximum number of autotuning attempts
978981 choices = self .make_choices (* args )
979-
980982 if len (choices ) == 0 : # Can't autotune
981- return [None , None ]
983+ return [None , None , None ]
984+
985+ # Get cycle time for each choice
982986 with ThreadPoolExecutor (max_workers = 8 ) as executor :
983987 results = list (executor .map (get_cycle , choices ))
984- max_idx = results .index (min (results ))
988+ min_idx = results .index (min (results ))
985989 if min (results ) == float ("inf" ):
986990 raise RuntimeError ("Failed to find optimal tile size..." )
987- if extension_config .CONFIG_DEBUG_MODE :
988- self ._log_autotune_result (choices [max_idx ], results [max_idx ])
989- optimal_src_code , loop_size = choices [max_idx ][1 ], choices [max_idx ][- 1 ]
990- return optimal_src_code , loop_size
991+
992+ self ._log_autotune_result (choices [min_idx ], results [min_idx ])
993+
994+ optimal_src_code , meta_code , loop_size = choices [min_idx ][1 ], choices [min_idx ][2 ], choices [min_idx ][- 1 ]
995+ return optimal_src_code , meta_code , loop_size
991996
992997 def run_bench (self , nodes , kernel_name , src_code ):
993998 _ , _ , arg_attributes , _ = self .kernel_group .args .mlir_argdefs ()
@@ -1015,19 +1020,19 @@ def run_bench(self, nodes, kernel_name, src_code):
10151020 return bmreq .make_run_fn (dummy_inputs , dummy_outputs )
10161021
10171022 def _log_autotune_result (self , best_choice , best_cycle ):
1018- print (
1019- f"[ Auto-tune] Optimal tile size: { list (best_choice [2 ])} , "
1020- f"vlane_stride: { best_choice [3 ]} , "
1023+ logger . debug (
1024+ f"Auto-tune: Optimal tile size: { list (best_choice [3 ])} , "
1025+ f"vlane_stride: { best_choice [4 ]} , "
10211026 f"cycles: { best_cycle } "
10221027 )
10231028
def codegen_nodes(self, nodes, kernel_name):
    """Generate kernel source (and metadata) for the fused nodes.

    Always runs the default codegen first; when the "autotune" mapping
    strategy is enabled and timing mode is on, tries the autotuner and
    returns its result only if it produced a usable variant.

    Returns:
        Tuple of (src_code, meta_code).
    """
    src_code, meta_code = super().codegen_nodes(nodes, kernel_name)
    self._prepare_simulator_headers(src_code)
    if "autotune" in extension_config.codegen_mapping_strategy and extension_config.pytorchsim_timing_mode:
        # Bug fix: unpack into fresh names. A failed autotune run yields
        # (None, None, ...); re-binding `meta_code` here would clobber the
        # valid metadata from the default codegen above, making the
        # fallback return (src_code, None).
        optimal_src_code, optimal_meta_code = self.autotune(nodes, kernel_name)[:2]
        if optimal_src_code is not None:
            return optimal_src_code, optimal_meta_code
    return src_code, meta_code
10321037
10331038 def _prepare_simulator_headers (self , src_code ):
0 commit comments