Skip to content

Commit 8db4e0b

Browse files
committed
[Fix+Log] Change logging system + Fix meta_code interface
1 parent 1d1508a commit 8db4e0b

8 files changed

Lines changed: 144 additions & 81 deletions

File tree

PyTorchSimFrontend/extension_codecache.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from PyTorchSimFrontend import extension_config
1111
from Simulator.simulator import FunctionalSimulator, CycleSimulator, TOGSimulator
1212

13+
# Configure logger for extension_codecache module (INFO level by default, DEBUG when TORCHSIM_DEBUG_MODE is set)
14+
logger = extension_config.setup_logger()
15+
1316
LOCK_TIMEOUT = 600
1417

1518
def hash_prefix(hash_value):
@@ -166,8 +169,8 @@ def load(cls, source_code,
166169
subprocess.check_call(translate_cmd)
167170
subprocess.check_call(llc_cmd)
168171
except subprocess.CalledProcessError as e:
169-
print("Command failed with exit code", e.returncode)
170-
print("Error output:", e.output)
172+
logger.error(f"Command failed with exit code {e.returncode}")
173+
logger.error(f"Error output: {e.output.decode() if isinstance(e.output, bytes) else e.output}")
171174
assert(0)
172175

173176
val_llvm_caller = MLIRKernelCallerCodeGen(extension_config.pytorchsim_functional_mode, arg_attributes)
@@ -179,8 +182,10 @@ def load(cls, source_code,
179182
spad_size = val_llvm_caller.get_spad_size(target)
180183
spad_usage = stack_size + spad_size # Spad usage per lane
181184
if extension_config.CONFIG_SPAD_INFO["spad_size"] < spad_usage:
182-
print(f"[Warning] Scratchpad size exceeded: required {spad_usage} bytes, "
183-
f"but only {extension_config.CONFIG_SPAD_INFO['spad_size']} bytes available.")
185+
logger.debug(
186+
f"Scratchpad size exceeded: required {spad_usage} bytes, "
187+
f"but only {extension_config.CONFIG_SPAD_INFO['spad_size']} bytes available."
188+
)
184189
raise SpadOverflowError()
185190

186191
# Launch tile graph generator
@@ -197,8 +202,8 @@ def load(cls, source_code,
197202
subprocess.check_call(gem5_translate_cmd)
198203
subprocess.check_call(gem5_llc_cmd)
199204
except subprocess.CalledProcessError as e:
200-
print("Command failed with exit code", e.returncode)
201-
print("Error output:", e.output)
205+
logger.error(f"Command failed with exit code {e.returncode}")
206+
logger.error(f"Error output: {e.output.decode() if isinstance(e.output, bytes) else e.output}")
202207
assert(0)
203208

204209
if not extension_config.pytorchsim_timing_mode:

PyTorchSimFrontend/extension_config.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
import importlib
44
import yaml
5+
import logging
56

67
CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')
78
CONFIG_GEM5_PATH = os.environ.get('GEM5_PATH', default="/workspace/gem5/build/RISCV/gem5.opt")
@@ -134,4 +135,43 @@ def load_plan_from_module(module_path):
134135

135136
CONFIG_USE_TIMING_POOLING = int(os.environ.get('TORCHSIM_USE_TIMING_POOLING', default=0))
136137

137-
CONFIG_DEBUG_MODE = int(os.environ.get('TORCHSIM_DEBUG_MODE', default=0))
138+
CONFIG_DEBUG_MODE = int(os.environ.get('TORCHSIM_DEBUG_MODE', default=0))
139+
140+
141+
def setup_logger(name=None, level=None):
    """
    Set up a logger with consistent formatting across all modules.

    Args:
        name: Logger name. When omitted, the ``__name__`` of the calling
            module is used (falling back to ``'PyTorchSim'`` if it cannot
            be determined). The name is always lowercased before the
            logger is looked up.
        level: Logging level. When omitted, ``logging.DEBUG`` is used if
            ``CONFIG_DEBUG_MODE`` is truthy, otherwise ``logging.INFO``.

    Returns:
        The ``logging.Logger`` registered under the lowercased name. A
        stream handler with the shared format is attached only if the
        logger has no handlers yet; the level is (re)applied on every call.
    """
    if name is None:
        import inspect
        # inspect.currentframe() may return None on interpreters without
        # Python stack frame support, so guard before following f_back.
        frame = inspect.currentframe()
        caller = frame.f_back if frame is not None else None
        try:
            caller_globals = caller.f_globals if caller is not None else {}
            name = caller_globals.get('__name__', 'PyTorchSim')
        finally:
            # Release frame references promptly to avoid reference cycles.
            del frame, caller

    # Convert logger name to lowercase
    name = name.lower()
    logger = logging.getLogger(name)

    # Only configure if not already configured (avoid duplicate handlers)
    if not logger.handlers:
        handler = logging.StreamHandler()
        formatter = logging.Formatter(
            fmt='[%(asctime)s.%(msecs)03d] [%(levelname)s] [%(name)s] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    # Set log level
    if level is None:
        level = logging.DEBUG if CONFIG_DEBUG_MODE else logging.INFO
    logger.setLevel(level)

    return logger

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import sympy
33
import re
44
import os
5-
import math
65
from functools import reduce
76
from operator import mul
87
import torch
@@ -29,6 +28,9 @@
2928
from .mlir_ops import ExtensionOverrides
3029
from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest
3130

31+
# Configure logger for mlir_codegen_backend module
32+
logger = extension_config.setup_logger()
33+
3234
def reduction_init(reduction_type, dtype):
3335
if dtype in cpp.DTYPE_LOWP_FP:
3436
# Since load promotes all half-precision inputs to float, the initial
@@ -95,11 +97,14 @@ def write_header(self):
9597
9698
from torch import device, empty, empty_strided
9799
from {extension_codecache.__name__} import CustomAsyncCompile
98-
from PyTorchSimFrontend.extension_config import CONFIG_SRAM_BUFFER_PLAN, CONFIG_TOGSIM_EAGER_MODE
100+
from PyTorchSimFrontend.extension_config import CONFIG_SRAM_BUFFER_PLAN, CONFIG_TOGSIM_EAGER_MODE, setup_logger
99101
from Simulator.simulator import TOGSimulator
100102
from PyTorchSimFrontend.extension_op import sparse_mm_dummy_stonne_outer
101103
from torch._inductor.select_algorithm import extern_kernels
102104
105+
# Configure logger for generated wrapper code
106+
_logger = setup_logger("PyTorchSimFrontend.mlir.generated_wrapper")
107+
103108
aten = torch.ops.aten
104109
inductor_ops = torch.ops.inductor
105110
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
@@ -108,7 +113,7 @@ def write_header(self):
108113
custom_async_compile = CustomAsyncCompile()
109114
async_compile = AsyncCompile()
110115
os.environ["TORCHSIM_LAST_COMPILED_MODULE"] = __file__
111-
print(f\'Wrapper Codegen Path = {{__file__}}\')
116+
_logger.info(f'Wrapper Codegen Path = {{__file__}}')
112117
"""
113118
)
114119
self.header.splice(
@@ -909,15 +914,14 @@ def make_choices(self, nodes, kernel_name):
909914

910915
# Try initial tile size
911916
self.reset(None)
912-
src_code = super().codegen_nodes(nodes, kernel_name)
917+
src_code, meta_code = super().codegen_nodes(nodes, kernel_name)
913918
current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size())
914919
search_space.add(current_tile_sz)
915920

916-
if extension_config.CONFIG_DEBUG_MODE:
917-
print(f"[Auto-tune] Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}")
921+
logger.debug(f"Auto-tune: Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}")
918922
self._prepare_simulator_headers(src_code)
919923
bench_runner = self.run_bench(nodes, kernel_name, src_code)
920-
choices.append((bench_runner, src_code, current_tile_sz, self.kernel_group.tile_desc.vmap.vlane_stride))
924+
choices.append((bench_runner, src_code, meta_code, current_tile_sz, self.kernel_group.tile_desc.vmap.vlane_stride))
921925

922926
while prevent_infinite_loop < 10 and candidate_axes:
923927
for axis in list(candidate_axes):
@@ -939,7 +943,7 @@ def make_choices(self, nodes, kernel_name):
939943
continue
940944

941945
self.reset(None)
942-
src_code = super().codegen_nodes(nodes, kernel_name)
946+
src_code, meta_code = super().codegen_nodes(nodes, kernel_name)
943947
current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size())
944948

945949
# FIXME: How to integrate this constraint into the tile system?
@@ -956,11 +960,10 @@ def make_choices(self, nodes, kernel_name):
956960

957961
# Add this choice
958962
search_space.add(current_tile_sz)
959-
if extension_config.CONFIG_DEBUG_MODE:
960-
print(f"[Auto-tune] Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}")
963+
logger.debug(f"Auto-tune: Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}")
961964
self._prepare_simulator_headers(src_code)
962965
bench_runner = self.run_bench(nodes, kernel_name, src_code)
963-
choices.append((bench_runner, src_code, self.kernel_group.tile_desc.get_tile_size(), self.kernel_group.tile_desc.vmap.vlane_stride))
966+
choices.append((bench_runner, src_code, meta_code, self.kernel_group.tile_desc.get_tile_size(), self.kernel_group.tile_desc.vmap.vlane_stride))
964967
prevent_infinite_loop += 1
965968
self.kernel_group.tile_desc.prev_tail_threshold = prev_tail_threshold
966969
return choices
@@ -976,18 +979,20 @@ def get_cycle(choice):
976979
return float("inf")
977980
return float("inf") # Exceeded maximum number of autotuning attempts
978981
choices = self.make_choices(*args)
979-
980982
if len(choices) == 0: # Can't autotune
981-
return [None, None]
983+
return [None, None, None]
984+
985+
# Get cycle time for each choice
982986
with ThreadPoolExecutor(max_workers=8) as executor:
983987
results = list(executor.map(get_cycle, choices))
984-
max_idx = results.index(min(results))
988+
min_idx = results.index(min(results))
985989
if min(results) == float("inf"):
986990
raise RuntimeError("Failed to find optimal tile size...")
987-
if extension_config.CONFIG_DEBUG_MODE:
988-
self._log_autotune_result(choices[max_idx], results[max_idx])
989-
optimal_src_code, loop_size = choices[max_idx][1], choices[max_idx][-1]
990-
return optimal_src_code, loop_size
991+
992+
self._log_autotune_result(choices[min_idx], results[min_idx])
993+
994+
optimal_src_code, meta_code, loop_size = choices[min_idx][1], choices[min_idx][2], choices[min_idx][-1]
995+
return optimal_src_code, meta_code, loop_size
991996

992997
def run_bench(self, nodes, kernel_name, src_code):
993998
_, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs()
@@ -1015,19 +1020,19 @@ def run_bench(self, nodes, kernel_name, src_code):
10151020
return bmreq.make_run_fn(dummy_inputs, dummy_outputs)
10161021

10171022
def _log_autotune_result(self, best_choice, best_cycle):
1018-
print(
1019-
f"[Auto-tune] Optimal tile size: {list(best_choice[2])}, "
1020-
f"vlane_stride: {best_choice[3]}, "
1023+
logger.debug(
1024+
f"Auto-tune: Optimal tile size: {list(best_choice[3])}, "
1025+
f"vlane_stride: {best_choice[4]}, "
10211026
f"cycles: {best_cycle}"
10221027
)
10231028

10241029
def codegen_nodes(self, nodes, kernel_name):
10251030
src_code, meta_code = super().codegen_nodes(nodes, kernel_name)
10261031
self._prepare_simulator_headers(src_code)
10271032
if "autotune" in extension_config.codegen_mapping_strategy and extension_config.pytorchsim_timing_mode:
1028-
optimal_src_code = self.autotune(nodes, kernel_name)[0]
1033+
optimal_src_code, meta_code = self.autotune(nodes, kernel_name)[:2]
10291034
if optimal_src_code is not None:
1030-
return optimal_src_code
1035+
return optimal_src_code, meta_code
10311036
return src_code, meta_code
10321037

10331038
def _prepare_simulator_headers(self, src_code):

PyTorchSimFrontend/mlir/mlir_ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import math
22
import torch
3+
import warnings
34

45
from torch._inductor.codegen import common
56
from torch._inductor.virtualized import V, _ops as ops
67
from . import mlir_common
78

9+
warnings.filterwarnings('ignore', message='undefined OpHandler\\..*, please add missing op schema')
10+
811
def reduction_combine_vec(reduction_type, vector_value, init_value, axis, shape, reduced_shape):
912
if reduction_type == "sum":
1013
return f"vector.multi_reduction <add>, %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}"

PyTorchSimFrontend/mlir/mlir_scheduling.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,7 @@ def codegen_template(self, template_node, epilogue_nodes, prologue_nodes):
299299
template_buffer = template_node.node
300300
kernel, tile_candidates, render = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group)
301301
_, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs()
302-
src_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes)
303-
meta_code = kernel.meta_kernel()
302+
src_code, meta_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes)
304303

305304
with V.set_kernel_handler(kernel):
306305
kernel_name = self.define_kernel(src_code, meta_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info,

PyTorchSimFrontend/mlir/mlir_template.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
from PyTorchSimFrontend import extension_config
3333
from . import mlir_common
3434

35+
# Configure logger for mlir_template module
36+
logger = extension_config.setup_logger()
37+
3538
class IndentedBufferGroup:
3639
def __init__(self, kernel: 'MLIRTemplateKernel', prefix=""):
3740
self.kernel = kernel
@@ -386,7 +389,6 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio
386389
return tile_candidates
387390

388391
def meta_kernel(self):
389-
wrapper = V.graph.wrapper_code
390392
kernel_arg_attributes = self.kernel_arg_attributes
391393
_, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs()
392394
if kernel_arg_attributes is not None:
@@ -483,38 +485,36 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_
483485
buffer.splice(src_code)
484486
src_code = buffer.getvalue()
485487
self._prepare_simulator_headers(src_code)
486-
return src_code
488+
meta_code = self.meta_kernel()
489+
return src_code, meta_code
487490

488491
def make_choices(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes):
489492
choices = []
490493
for tile_info in tile_candidates:
491-
if extension_config.CONFIG_DEBUG_MODE:
492-
# Compute Tile M, N, K DMA Tile M, N, K
493-
print(f"[Auto-tune] Trying tile size: {list(tile_info)}")
494-
src_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info)
494+
# Compute Tile M, N, K DMA Tile M, N, K
495+
logger.debug(f"Auto-tune: Trying tile size: {list(tile_info)}")
496+
src_code, meta_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info)
495497
bench_runner = self.run_bench([template_node], self.kernel_name, src_code)
496-
choices.append((bench_runner, src_code, tile_info, self.loop_size))
498+
choices.append((bench_runner, src_code, meta_code, tile_info, self.loop_size))
497499
self.reset(reason=None)
498500
return choices
499501

500502
def _log_autotune_result(self, best_choice, best_cycle):
501-
tile_size = best_choice[2]
502-
print(
503-
f"[Auto-tune] Optimal tile size: {list(tile_size)}, "
503+
tile_size = best_choice[3]
504+
logger.debug(
505+
f"Auto-tune: Optimal tile size: {list(tile_size)}, "
504506
f"cycles: {best_cycle}"
505507
)
506508

507509
def codegen_nodes(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes):
508510
if "autotune" in extension_config.codegen_mapping_strategy and len(tile_candidates):
509-
src_code, loop_size = self.autotune(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes)
511+
src_code, meta_code, loop_size = self.autotune(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes)
510512
self.loop_size = loop_size
511513
else:
512514
tile_info = tile_candidates[0] if tile_candidates else None
513-
src_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info)
515+
src_code, meta_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info)
514516

515-
with V.set_kernel_handler(self):
516-
self.meta_kernel()
517-
return src_code
517+
return src_code, meta_code
518518

519519
def _prepare_simulator_headers(self, src_code):
520520
spad_end_symbol = f"int spad_end[0] __attribute__ ((section(\".spad\")));\n"

Scheduler/scheduler.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212

1313
from torch._dynamo.device_interface import register_interface_for_device
1414

15+
# Configure logger for Scheduler module
16+
logger = extension_config.setup_logger()
17+
1518

1619
def import_module_from_path(module_name, path):
1720
module_path = Path(path) # Convert to Path object for safety
@@ -380,7 +383,7 @@ def __init__(self, num_request_queue=1, max_batch=1, engine_select=FIFO_ENGINE,
380383
elif engine_select == Scheduler.RR_ENGINE:
381384
self.execution_engine = RoundRobinRunner(self.tog_simulator, self.num_request_queue)
382385
else:
383-
print(f"Not supporetd engine type {engine_select}")
386+
logger.error(f"Not supported engine type {engine_select}")
384387
exit(1)
385388

386389
def add_request(self, request: Request, request_time=-1):
@@ -441,9 +444,11 @@ def finish_request(self, req : Request):
441444
self.finish_queue.append(req)
442445
self.request_queue[req.request_queue_idx].remove(req)
443446
turnaround_time, response_time, tbt_time = req.get_latency()
444-
print(f"[Request-{req.id} finished] partition: {req.request_queue_idx} arrival_time: "
445-
f"{req.arrival_time} start_time: {req.start_time[0]} turnaround latency: {turnaround_time}, "
446-
f"response time: {response_time} tbt_time: {tbt_time}")
447+
logger.info(
448+
f"[Request-{req.id} finished] partition: {req.request_queue_idx} arrival_time: "
449+
f"{req.arrival_time} start_time: {req.start_time[0]} turnaround latency: {turnaround_time}, "
450+
f"response time: {response_time} tbt_time: {tbt_time}"
451+
)
447452

448453
def per_schedule(self, request_queue_idx):
449454
# Wait partition is idle
@@ -454,11 +459,13 @@ def per_schedule(self, request_queue_idx):
454459
if not request_list:
455460
return False
456461

457-
print(f"[Request issue] partition: {request_queue_idx} batch size: {len(request_list)}", flush=True)
462+
logger.info(f"[Request issue] partition: {request_queue_idx} batch size: {len(request_list)}")
458463
for req in request_list:
459464
req.set_start(self.current_time())
460-
print(f"[Request-{req.id} issue] partition: {req.request_queue_idx} "
461-
f"arrival_time: {req.arrival_time} start_time: {req.start_time[0]}", flush=True)
465+
logger.info(
466+
f"[Request-{req.id} issue] partition: {req.request_queue_idx} "
467+
f"arrival_time: {req.arrival_time} start_time: {req.start_time[0]}"
468+
)
462469
# Submit batched request
463470
self.execution_engine.submit(request_list, request_queue_idx)
464471

0 commit comments

Comments
 (0)