From fc247be17221f2b6aa8c52228a2e86b7315ef78d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EC=9E=AC=EA=B7=A0?= Date: Mon, 2 Mar 2026 00:28:31 +0900 Subject: [PATCH 1/9] [Template] Add cat & sort template + Multi-output (WIP) --- .../torch_openreg/openreg/__init__.py | 49 +++ PyTorchSimFrontend/mlir/mlir_cat_template.py | 167 +++++++++++ PyTorchSimFrontend/mlir/mlir_common.py | 6 +- PyTorchSimFrontend/mlir/mlir_lowering.py | 281 +++++++++++++++++- PyTorchSimFrontend/mlir/mlir_sort_template.py | 253 ++++++++++++++++ PyTorchSimFrontend/mlir/mlir_template.py | 30 +- tests/DeepSeek/test_deepseek_v3_base.py | 170 +++++++++-- tests/test_cat.py | 89 ++++++ tests/test_sort.py | 112 +++++++ 9 files changed, 1121 insertions(+), 36 deletions(-) create mode 100644 PyTorchSimFrontend/mlir/mlir_cat_template.py create mode 100644 PyTorchSimFrontend/mlir/mlir_sort_template.py create mode 100644 tests/test_cat.py create mode 100644 tests/test_sort.py diff --git a/PyTorchSimDevice/torch_openreg/openreg/__init__.py b/PyTorchSimDevice/torch_openreg/openreg/__init__.py index f5aabc18..5603a4f7 100644 --- a/PyTorchSimDevice/torch_openreg/openreg/__init__.py +++ b/PyTorchSimDevice/torch_openreg/openreg/__init__.py @@ -256,6 +256,52 @@ def launch_model(model, *args, stream_index=0, timestamp=0, **kwargs): from .random import * # noqa: F403 from .amp import * +def _precheck_cat_out_args(args, kwargs): + tensors = args[0] if len(args) > 0 else kwargs.get("tensors") + dim = args[1] if len(args) > 1 else kwargs.get("dim", 0) + out = kwargs.get("out", args[2] if len(args) > 2 else None) + + if out is None: + return + if not isinstance(tensors, (list, tuple)) or len(tensors) == 0: + raise RuntimeError("aten::cat.out requires non-empty tensor list") + if not all(isinstance(t, torch.Tensor) for t in tensors): + raise RuntimeError("aten::cat.out tensors must be Tensor values") + if not isinstance(out, torch.Tensor): + raise RuntimeError("aten::cat.out out must be a Tensor") + + rank = 
tensors[0].dim() + if rank == 0: + raise RuntimeError("aten::cat.out does not support scalar inputs") + if dim < 0: + dim += rank + if dim < 0 or dim >= rank: + raise RuntimeError(f"aten::cat.out dim out of range: dim={dim}, rank={rank}") + if any(t.dim() != rank for t in tensors): + raise RuntimeError("aten::cat.out inputs must have the same rank") + if any(t.dtype != tensors[0].dtype for t in tensors): + raise RuntimeError("aten::cat.out inputs must have the same dtype") + if out.dim() != rank: + raise RuntimeError("aten::cat.out out rank mismatch") + + for d in range(rank): + if d == dim: + continue + base = tensors[0].shape[d] + if any(t.shape[d] != base for t in tensors[1:]): + raise RuntimeError( + f"aten::cat.out non-concatenated dimension mismatch at dim={d}" + ) + if out.shape[d] != base: + raise RuntimeError(f"aten::cat.out out shape mismatch at dim={d}") + + expected = sum(t.shape[dim] for t in tensors) + if out.shape[dim] != expected: + raise RuntimeError( + f"aten::cat.out out concatenated dimension mismatch at dim={dim}: " + f"expected {expected}, got {out.shape[dim]}" + ) + def eager_to_compile(op_name): """ Register an eager mode operation as a graph-based implementation using torch.compile(). 
from typing import List, Optional, cast

import sympy
from torch._inductor.ir import Buffer, IRNode
from torch._inductor.virtualized import V

from PyTorchSimFrontend.mlir import mlir_common
from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel


# Jinja template of the generated kernel.  Each input is copied to the output
# one element at a time: an MVIN into a 1x1 tile followed by an MVOUT of that
# same tile to the output's DRAM location.
TEMPLATE = r"""
{{kernel.def_global_vars()}}

func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X0, X1], outputs=[Y], names_str=NAMES_STR, input_reorder=input_reorder)}} {
  {{ kernel.def_sram_buffer("X0", X0_TILE_DESC, id=0, indent_size=2) }}
  {{ kernel.def_sram_buffer("X1", X1_TILE_DESC, id=1, indent_size=2) }}
  {{ kernel.def_sram_buffer(OUT_DVAR, Y_TILE_DESC, id=2, indent_size=2) }}
  {{ kernel.def_local_vars(indent_size=2) }}

  affine.for %cat_block = 0 to 1 step 1 {
{% if DIM == 0 %}
    affine.for %index0 = 0 to {{ X0_ROWS }} step 1 {
      affine.for %index1 = 0 to {{ COLS }} step 1 {
        {{ kernel.def_dma_op("MVIN", "X0", X0_IDX, X0_TILE_DESC, indent_size=8) }}
        {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y0_IDX, X0_TILE_DESC, indent_size=8) }}
      }
    }

    affine.for %index2 = 0 to {{ X1_ROWS }} step 1 {
      affine.for %index3 = 0 to {{ COLS }} step 1 {
        {{ kernel.def_dma_op("MVIN", "X1", X1_IDX, X1_TILE_DESC, indent_size=8) }}
        {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y1_IDX, X1_TILE_DESC, indent_size=8) }}
      }
    }
{% else %}
    affine.for %index0 = 0 to {{ ROWS }} step 1 {
      affine.for %index1 = 0 to {{ X0_COLS }} step 1 {
        {{ kernel.def_dma_op("MVIN", "X0", X0_IDX, X0_TILE_DESC, indent_size=8) }}
        {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y0_IDX, X0_TILE_DESC, indent_size=8) }}
      }
      affine.for %index3 = 0 to {{ X1_COLS }} step 1 {
        {{ kernel.def_dma_op("MVIN", "X1", X1_IDX, X1_TILE_DESC, indent_size=8) }}
        {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y1_IDX, X1_TILE_DESC, indent_size=8) }}
      }
    }
{% endif %}
  } { outer_loop=true }
  return
}
"""


def _unit_tile(kernel, name):
    """Build a 1x1 ``MLIRMultiDimTile`` named ``name`` for single-element DMA."""
    tile = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1)
    tile.set_tile_size_stride([1, 1], [1, 1])
    tile.set_name(name)
    return tile


class MLIRCatTemplate(MLIRTemplate):
    """Template kernel concatenating two rank-2 tensors along ``dim`` 0 or 1."""

    def __init__(self, input_nodes, layout, dim, input_reorder=None):
        super().__init__("kernel", input_nodes, layout, input_reorder)
        self.dim = dim

    def render(
        self,
        kernel: MLIRTemplateKernel,
        template_buffer_node=None,
        epilogue_nodes: Optional[List[IRNode]] = None,
        tile_info=None,
        **kwargs,
    ):
        """Render the MLIR source for this concatenation.

        Args:
            kernel: kernel object whose ``render_options``, ``epilogue_info``,
                ``dim_aliasing``, ``loop_size`` and ``exception_nodes`` are
                populated here.
            template_buffer_node: when given, it replaces ``self.output_node``
                and the output is addressed as ``out_ptr1`` instead of ``Y``.
            epilogue_nodes, tile_info, **kwargs: accepted for interface
                compatibility; not used by this template.

        Returns:
            The rendered kernel source string.

        Raises:
            RuntimeError: if fewer than two inputs are present, an input is
                not rank-2, or ``dim`` is outside ``{0, 1}`` (after
                normalising a negative value).
        """
        is_out_variant = template_buffer_node is not None
        if is_out_variant:
            self.output_node = template_buffer_node
        # cat template currently emits a single output buffer and does not
        # support epilogue output remapping.

        def _unwrap_node(n):
            return n.node if hasattr(n, "node") else n

        def _as_int(v):
            # Symbolic sizes cannot be converted directly; use their hint.
            try:
                return int(v)
            except (TypeError, ValueError):
                return int(V.graph.sizevars.size_hint(v))

        if len(self.input_nodes) < 2:
            raise RuntimeError("MLIRCatTemplate currently supports exactly two inputs")

        x0 = _unwrap_node(self.input_nodes[0])
        x1 = _unwrap_node(self.input_nodes[1])
        y = _unwrap_node(self.output_node)

        if len(x0.get_size()) != 2 or len(x1.get_size()) != 2:
            raise RuntimeError("MLIRCatTemplate currently supports rank-2 inputs only")

        # Normalise a negative dim; previously dim=-2 silently took the dim=1
        # branch because only ``dim == 0`` was tested.
        dim = int(self.dim)
        if dim < 0:
            dim += 2
        if dim not in (0, 1):
            raise RuntimeError(
                f"MLIRCatTemplate currently supports dim in {{0,1}} only, got dim={self.dim}"
            )

        x0_rows = _as_int(x0.get_size()[0])
        x1_rows = _as_int(x1.get_size()[0])
        x0_cols = _as_int(x0.get_size()[1])
        x1_cols = _as_int(x1.get_size()[1])
        y_cols = _as_int(y.get_size()[1])
        kernel.loop_size = None

        # 2D cat template with contiguous layout.
        x0_tile_desc = _unit_tile(kernel, "x0_cat_tile")
        x1_tile_desc = _unit_tile(kernel, "x1_cat_tile")
        y_tile_desc = _unit_tile(kernel, "y_cat_tile")

        index0, index1, index2, index3 = (sympy.Symbol(f"index{i}") for i in range(4))
        if dim == 0:
            # Flattened offsets for dim=0 cat: X1 rows land after X0's rows.
            x0_idx = [index0 * x0_cols, index1]
            x1_idx = [index2 * x1_cols, index3]
            y0_idx = [index0 * y_cols, index1]
            y1_idx = [(index2 + x0_rows) * y_cols, index3]
        else:
            # Flattened offsets for dim=1 cat: X1 columns land after X0's columns.
            x0_idx = [index0 * x0_cols, index1]
            x1_idx = [index0 * x1_cols, index3]
            y0_idx = [index0 * y_cols, index1]
            y1_idx = [index0 * y_cols, index3 + x0_cols]

        out_dvar = "out_ptr1" if is_out_variant else "Y"
        kernel.render_options = dict(
            KERNEL_NAME=self.name,
            kernel=kernel,
            X0=x0,
            X1=x1,
            Y=y,
            OUT_DVAR=out_dvar,
            NAMES_STR="X0, X1, out_ptr1" if is_out_variant else "X0, X1, Y",
            DIM=dim,
            X0_ROWS=x0_rows,
            X1_ROWS=x1_rows,
            ROWS=x0_rows,
            X0_COLS=x0_cols,
            X1_COLS=x1_cols,
            COLS=x0_cols,
            X0_TILE_DESC=x0_tile_desc,
            X1_TILE_DESC=x1_tile_desc,
            Y_TILE_DESC=y_tile_desc,
            X0_IDX=x0_idx,
            X1_IDX=x1_idx,
            Y0_IDX=y0_idx,
            Y1_IDX=y1_idx,
            input_reorder=self.input_reorder,
        )
        # Needed when epilogue fusion requests set_ranges().
        kernel.dim_aliasing = {"index0": "index0", "index1": "index1"}

        if hasattr(self.output_node, "node") and hasattr(self.output_node.node, "get_name"):
            output_node_name = self.output_node.node.get_name()
        elif hasattr(self.output_node, "get_name"):
            output_node_name = self.output_node.get_name()
        else:
            output_node_name = self.output_node.name

        if hasattr(y, "get_numel"):
            y_numel = y.get_numel()
        elif hasattr(y, "node") and hasattr(y.node, "get_numel"):
            y_numel = y.node.get_numel()
        else:
            y_numel = None

        kernel.epilogue_info = dict(
            output_node=output_node_name,
            sram_var="y_cat_tile",
            dram_var=out_dvar,
            dram_tile_desc=y_tile_desc,
        )
        if y_numel is not None:
            kernel.exception_nodes[out_dvar] = {"numel": y_numel}

        code = self._template_from_string(TEMPLATE).render(**kernel.render_options)
        return code
from PyTorchSimFrontend.mlir.mlir_cat_template import MLIRCatTemplate
from PyTorchSimFrontend.mlir.mlir_sort_template import MLIRSortTemplate

# Lowerings that were registered before this module overrides them; the custom
# paths delegate to these for cases the templates do not cover.
_orig_cat_default_lowering = lowerings.get(aten.cat.default)
_orig_cat_out_lowering = lowerings.get(aten.cat.out)
_orig_sort_values_stable_lowering = lowerings.get(aten.sort.values_stable)


def _is_tensor_like(t) -> bool:
    """Return True when ``t`` exposes the IR accessors the cat lowerings use."""
    return hasattr(t, "get_size") and hasattr(t, "get_dtype") and hasattr(t, "realize")


def _cat_layout(tensors: Sequence[TensorBox], dim: int) -> ir.Layout:
    """Compute the output layout of ``cat(tensors, dim)`` under fake mode."""
    with V.graph.fake_mode:
        output = torch.ops.aten.cat(
            [ir.ir_node_to_tensor(t, guard_shape=True) for t in tensors],
            dim,
        )
        sizes = ir.convert_shape_to_inductor(output.size())
        stride = ir.convert_shape_to_inductor(output.stride())
    return ir.FixedLayout(
        tensors[0].get_device(),
        tensors[0].get_dtype(),
        sizes,
        stride,
    )


def _can_use_cat_template(tensors: Sequence[TensorBox], dim: int) -> bool:
    """Return True when ``MLIRCatTemplate`` can handle this concatenation."""
    # Current template specialization: 2 inputs, rank-2, dim in {0, 1}.
    if len(tensors) != 2:
        return False
    if not all(_is_tensor_like(t) for t in tensors):
        return False
    if tensors[0].get_dtype() != tensors[1].get_dtype():
        return False
    rank0 = len(tensors[0].get_size())
    rank1 = len(tensors[1].get_size())
    if rank0 != 2 or rank1 != 2:
        return False
    if dim < 0:
        dim += rank0
    if dim not in (0, 1):
        return False

    # The non-concatenated extent must be provably equal.
    other = 1 - dim
    return V.graph.sizevars.statically_known_equals(
        tensors[0].get_size()[other], tensors[1].get_size()[other]
    )


def _cat_fallback(reason: str, tensors: Sequence[TensorBox], dim: int):
    """Delegate to the lowering that was registered for ``aten.cat.default``.

    ``reason`` is a free-form tag for the call site; it is not used.
    """
    # Non-template cases delegate to the original lowering path.
    if _orig_cat_default_lowering is None:
        raise RuntimeError("Original aten.cat.default lowering is missing")
    return _orig_cat_default_lowering(tensors, dim)


def _custom_cat_impl(tensors: Sequence[TensorBox], dim: int = 0):
    """Lower ``cat`` through ``MLIRCatTemplate`` when possible, else fall back."""
    if not _can_use_cat_template(tensors, dim):
        # Hand the caller's ``dim`` through untouched: the original lowering
        # does its own normalisation, and normalising against the first
        # tensor's rank here would rewrite a negative dim prematurely.
        return _cat_fallback("default-path", tensors, dim)

    # The template path implies rank-2 inputs, so normalisation is exact.
    norm_dim = dim + 2 if dim < 0 else dim
    for t in tensors:
        t.realize()
    layout = _cat_layout(tensors, norm_dim)
    mlir_template = MLIRCatTemplate(list(tensors), layout, dim=norm_dim)
    return mlir_template.generate().output_node()


def custom_cat_default(tensors: Sequence[TensorBox], dim: int = 0):
    """Lowering entry point for ``aten.cat.default``."""
    return _custom_cat_impl(tensors, dim)


def custom_cat_out(tensors: Sequence[TensorBox], dim: int = 0, out: Optional[TensorBox] = None):
    """Lowering entry point for ``aten.cat.out``.

    With an ``out`` tensor the concatenation is expressed as one
    ``slice`` + ``copy_`` pair per input; without one the call is forwarded to
    the previously registered ``aten.cat.out`` lowering.

    Raises:
        RuntimeError: on missing prerequisite lowerings or on any rank, dtype
            or shape inconsistency between ``tensors`` and ``out``.
    """
    if out is None:
        if _orig_cat_out_lowering is None:
            raise RuntimeError("Original aten.cat.out lowering is missing")
        return _orig_cat_out_lowering(tensors, dim, out)

    copy_default_lowering = lowerings.get(aten.copy_.default)
    slice_tensor_lowering = lowerings.get(aten.slice.Tensor)
    if copy_default_lowering is None or slice_tensor_lowering is None:
        raise RuntimeError("cat.out lowering requires aten.copy_.default and aten.slice.Tensor lowerings")

    # Lower cat.out as a sequence of slice+copy ops so each piece still runs
    # through the existing compiled/simulated kernel path.
    if len(tensors) == 0:
        raise RuntimeError("cat.out requires at least one input tensor")
    if not all(_is_tensor_like(t) for t in tensors):
        raise RuntimeError("cat.out inputs must be tensor-like values")
    first_size = tensors[0].get_size()
    rank = len(first_size)
    if rank == 0:
        raise RuntimeError("cat.out does not support scalar inputs")
    requested_dim = dim  # reported verbatim in the range error below
    if dim < 0:
        dim = dim + rank
    if dim < 0 or dim >= rank:
        raise RuntimeError(f"cat.out dim out of range: dim={requested_dim}, rank={rank}")
    if any(len(t.get_size()) != rank for t in tensors):
        raise RuntimeError("cat.out inputs must have the same rank")
    if any(t.get_dtype() != tensors[0].get_dtype() for t in tensors):
        raise RuntimeError("cat.out inputs must have the same dtype")

    known_equal = V.graph.sizevars.statically_known_equals
    # cat semantics: all non-cat dimensions must be equal.
    for i in range(rank):
        if i == dim:
            continue
        if any(not known_equal(first_size[i], t.get_size()[i]) for t in tensors[1:]):
            raise RuntimeError(f"cat.out non-concatenated dimension mismatch at dim={i}")

    # Output shape must match concatenated shape.
    if not hasattr(out, "get_size"):
        raise RuntimeError("cat.out output must be tensor-like")
    out_sizes = list(out.get_size())
    if len(out_sizes) != rank:
        raise RuntimeError("cat.out output rank mismatch")
    for i in range(rank):
        if i == dim:
            continue
        if not known_equal(out_sizes[i], first_size[i]):
            raise RuntimeError(f"cat.out output shape mismatch at dim={i}")
    expected_cat = sum(t.get_size()[dim] for t in tensors)
    if not known_equal(out_sizes[dim], expected_cat):
        raise RuntimeError(f"cat.out output concatenated dimension mismatch at dim={dim}")

    if isinstance(out, TensorBox):
        out.realize()

    # Copy each input into its [offset, end) window of ``out`` along ``dim``.
    offset = 0
    for src in tensors:
        src.realize()
        end = offset + src.get_size()[dim]
        dst_view = slice_tensor_lowering(out, dim, offset, end, 1)
        copy_default_lowering(dst_view, src)
        offset = end
    return out


def _custom_sort_values_impl(
    self: TensorBox,
    dim: int = -1,
    descending: bool = False,
    values: Optional[TensorBox] = None,
    indices: Optional[TensorBox] = None,
    stable: Optional[bool] = None,
):
    """Lower an out-variant sort through ``MLIRSortTemplate``.

    Returns:
        ``(sorted_values, indices)`` where ``sorted_values`` is the template's
        output node and ``indices`` is the tensor passed in.

    Raises:
        RuntimeError: if either out tensor is missing, the input is not
            tensor-like or not rank-2, or ``dim`` is out of range.
    """
    if values is None or indices is None:
        raise RuntimeError("sort.values* lowering requires both out tensors: values, indices")
    if not hasattr(self, "get_size"):
        raise RuntimeError("sort.values* lowering requires TensorBox input")

    rank = len(self.get_size())
    norm_dim = dim + rank if dim < 0 else dim
    if norm_dim < 0 or norm_dim >= rank:
        raise RuntimeError(f"sort.values* dim out of range: dim={dim}, rank={rank}")
    if rank != 2:
        raise RuntimeError(f"sort.values* lowering currently supports rank-2 only, got rank={rank}")
    if norm_dim not in (0, 1):
        raise RuntimeError(f"sort.values* lowering currently supports dim in {{0,1}} only, got dim={norm_dim}")

    self.realize()
    if isinstance(values, TensorBox):
        values.realize()
    if isinstance(indices, TensorBox):
        indices.realize()

    value_layout, _ = _sort_layouts(self, norm_dim, descending)
    mlir_template = MLIRSortTemplate(
        [self],
        value_layout,
        dim=norm_dim,
        descending=descending,
        stable=True if stable is None else stable,
        indices_node=indices,
    )
    sorted_values = mlir_template.generate(template_buffer_node=values, epilogue_nodes=[indices]).output_node()
    return sorted_values, indices


def _sort_layouts(x: TensorBox, dim: int, descending: bool):
    """Return ``(value_layout, index_layout)`` of ``sort(x)`` computed in fake mode."""
    with V.graph.fake_mode:
        v, i = torch.ops.aten.sort(
            ir.ir_node_to_tensor(x, guard_shape=True),
            dim,
            descending,
        )
        v_sizes = ir.convert_shape_to_inductor(v.size())
        v_stride = ir.convert_shape_to_inductor(v.stride())
        i_sizes = ir.convert_shape_to_inductor(i.size())
        i_stride = ir.convert_shape_to_inductor(i.stride())

    value_layout = ir.FixedLayout(x.get_device(), x.get_dtype(), v_sizes, v_stride)
    index_layout = ir.FixedLayout(x.get_device(), torch.int64, i_sizes, i_stride)
    return value_layout, index_layout


def _sort_fallback(x, dim, descending, stable):
    """Delegate a sort to the lowering captured for ``aten.sort.values_stable``."""
    # NOTE(review): the fallback for the functional ``sort.stable`` op is the
    # captured ``sort.values_stable`` lowering, as in the original code --
    # confirm that lowering accepts a call without out tensors.
    if _orig_sort_values_stable_lowering is None:
        raise RuntimeError("Original aten.sort.values_stable lowering is missing")
    return _orig_sort_values_stable_lowering(
        x, dim=dim, descending=descending, stable=True if stable is None else stable
    )


def custom_sort_stable(
    self: TensorBox,
    *,
    stable: Optional[bool] = None,
    dim: int = -1,
    descending: bool = False,
):
    """Lowering entry point for ``aten.sort.stable``.

    Rank-2 inputs get freshly allocated value/index buffers and are lowered
    through the sort template; everything else, and any failure inside the
    template path, is delegated to the captured original lowering.

    Raises:
        RuntimeError: if ``dim`` is out of range or a required fallback
            lowering is missing.
    """
    import logging

    empty_strided_lowering = lowerings.get(aten.empty_strided.default)
    if empty_strided_lowering is None:
        if _orig_sort_values_stable_lowering is None:
            raise RuntimeError("sort.stable lowering requires aten.empty_strided.default")
        return _sort_fallback(self, dim, descending, stable)

    rank = len(self.get_size()) if hasattr(self, "get_size") else 0
    norm_dim = dim + rank if dim < 0 else dim
    if rank > 0 and (norm_dim < 0 or norm_dim >= rank):
        raise RuntimeError(f"sort.stable dim out of range: dim={dim}, rank={rank}")

    # Template specialization supports rank-2 and dim in {0,1}.  Anything else
    # goes to the fallback up front, before any output buffer is allocated.
    if rank != 2 or norm_dim not in (0, 1):
        return _sort_fallback(self, dim, descending, stable)

    try:
        value_layout, index_layout = _sort_layouts(self, norm_dim, descending)
        values = empty_strided_lowering(
            list(value_layout.size),
            list(value_layout.stride),
            dtype=value_layout.dtype,
            device=self.get_device(),
        )
        indices = empty_strided_lowering(
            list(index_layout.size),
            list(index_layout.stride),
            dtype=index_layout.dtype,
            device=self.get_device(),
        )
        return _custom_sort_values_impl(
            self=self,
            dim=dim,
            descending=descending,
            values=values,
            indices=indices,
            stable=True if stable is None else stable,
        )
    except Exception:
        # Deliberate best-effort: keep the fallback, but leave a trace of why
        # the template path was abandoned.
        if _orig_sort_values_stable_lowering is None:
            raise
        logging.getLogger(__name__).debug(
            "sort template lowering failed; using original lowering", exc_info=True
        )
        return _sort_fallback(self, dim, descending, stable)


def custom_sort_values_stable(
    self: TensorBox,
    *,
    stable: Optional[bool] = None,
    dim: int = -1,
    descending: bool = False,
    values: Optional[TensorBox] = None,
    indices: Optional[TensorBox] = None,
):
    """Lowering entry point for ``aten.sort.values_stable`` (out variant)."""
    return _custom_sort_values_impl(
        self=self,
        dim=dim,
        descending=descending,
        values=values,
        indices=indices,
        stable=stable,
    )


# Register the custom lowerings defined above.
lowerings.update({aten.cat.default: custom_cat_default})
lowerings.update({aten.cat.out: custom_cat_out})

lowerings.update({aten.sort.stable: custom_sort_stable})
lowerings.update({aten.sort.values_stable: custom_sort_values_stable})
from typing import List, Optional

import sympy
from torch._inductor.ir import IRNode
from torch._inductor.virtualized import V

from PyTorchSimFrontend.mlir import mlir_common
from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel


# Jinja template of the generated kernel.  The input is first copied to the
# value output while the index output is filled with positions along the
# sorted axis; adjacent elements are then compare-and-swapped in place.
# Because the swap uses a strict comparison, equal elements are never swapped.
TEMPLATE = r"""
{{kernel.def_global_vars()}}

func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X, YI], outputs=[YV], names_str=NAMES_STR, input_reorder=input_reorder)}} {
  {{ kernel.def_sram_buffer("YI", YI_TILE_DESC, id=1, indent_size=2) }}
  {{ kernel.def_sram_buffer(OUT_DVAR, YV_TILE_DESC, id=2, indent_size=2) }}
  {{ kernel.def_local_vars(indent_size=2) }}

  %c0 = arith.constant 0 : index
  %c_cols = arith.constant {{ COLS }} : index

  affine.for %sort_block = 0 to 1 step 1 {
    // Initialize output value/index buffers.
    affine.for %row = 0 to {{ ROWS }} step 1 {
      affine.for %col = 0 to {{ COLS }} step 1 {
        {{ kernel.def_dma_op("MVIN", "X", INIT_X_IDX, X_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }}
        {{ kernel.def_dma_op("MVOUT", OUT_DVAR, INIT_YV_IDX, X_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }}
{% if DIM == 1 %}
        %idx_i64 = arith.index_cast %col : index to {{ YI_ELEM_TYPE }}
{% else %}
        %idx_i64 = arith.index_cast %row : index to {{ YI_ELEM_TYPE }}
{% endif %}
        memref.store %idx_i64, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }}
        {{ kernel.def_dma_op("MVOUT", "YI", INIT_YI_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }}
      }
    }

{% if DIM == 1 %}
    // Stable bubble sort on each row (dim=1).
    affine.for %row = 0 to {{ ROWS }} step 1 {
      affine.for %pass = 0 to {{ COLS }} step 1 {
        affine.for %j = 0 to {{ COLS_MINUS1 }} step 1 {
          {{ kernel.def_dma_op("MVIN", OUT_DVAR, D1_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }}
          %lhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }}

          {{ kernel.def_dma_op("MVIN", OUT_DVAR, D1_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }}
          %rhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }}

          %need_swap = {{ CMP_OP }}, %lhs, %rhs : {{ YV_ELEM_TYPE }}
          scf.if %need_swap {
            memref.store %rhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }}
            {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D1_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }}

            memref.store %lhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }}
            {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D1_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }}

            {{ kernel.def_dma_op("MVIN", "YI", D1_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }}
            %li = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }}

            {{ kernel.def_dma_op("MVIN", "YI", D1_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }}
            %ri = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }}

            memref.store %ri, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }}
            {{ kernel.def_dma_op("MVOUT", "YI", D1_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }}

            memref.store %li, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }}
            {{ kernel.def_dma_op("MVOUT", "YI", D1_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }}
          }
        }
      }
    }
{% else %}
    // Stable bubble sort on each column (dim=0).
    affine.for %col = 0 to {{ COLS }} step 1 {
      affine.for %pass = 0 to {{ ROWS }} step 1 {
        affine.for %i = 0 to {{ ROWS_MINUS1 }} step 1 {
          {{ kernel.def_dma_op("MVIN", OUT_DVAR, D0_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }}
          %lhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }}

          {{ kernel.def_dma_op("MVIN", OUT_DVAR, D0_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }}
          %rhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }}

          %need_swap = {{ CMP_OP }}, %lhs, %rhs : {{ YV_ELEM_TYPE }}
          scf.if %need_swap {
            memref.store %rhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }}
            {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D0_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }}

            memref.store %lhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }}
            {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D0_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }}

            {{ kernel.def_dma_op("MVIN", "YI", D0_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }}
            %li = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }}

            {{ kernel.def_dma_op("MVIN", "YI", D0_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }}
            %ri = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }}

            memref.store %ri, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }}
            {{ kernel.def_dma_op("MVOUT", "YI", D0_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }}

            memref.store %li, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }}
            {{ kernel.def_dma_op("MVOUT", "YI", D0_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }}
          }
        }
      }
    }
{% endif %}
  } { outer_loop=true }
  return
}
"""


def _unit_tile(kernel, name, offset=None):
    """Build a 1x1 ``MLIRMultiDimTile`` named ``name``.

    ``offset``, when given, is stored on the descriptor's ``offset`` attribute
    so the DMA addresses the neighbouring element.
    """
    tile = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1)
    tile.set_tile_size_stride([1, 1], [1, 1])
    tile.set_name(name)
    if offset is not None:
        tile.offset = sympy.Integer(offset)
    return tile


def _swap_compare_op(dtype, descending):
    """Return the MLIR compare op and predicate that decide a swap.

    Floating-point values use ``arith.cmpf``; other dtypes use ``arith.cmpi``
    with a signed or unsigned predicate according to ``dtype.is_signed``.
    """
    if dtype.is_floating_point:
        return "arith.cmpf olt" if descending else "arith.cmpf ogt"
    if dtype.is_signed:
        return "arith.cmpi slt" if descending else "arith.cmpi sgt"
    return "arith.cmpi ult" if descending else "arith.cmpi ugt"


class MLIRSortTemplate(MLIRTemplate):
    """Template kernel sorting a rank-2 tensor along ``dim`` 0 or 1.

    Writes sorted values to the output node and positions to ``indices_node``.
    ``stable`` is stored but not consulted by ``render``.
    """

    def __init__(self, input_nodes, layout, dim, descending=False, stable=False, indices_node=None, input_reorder=None):
        super().__init__("kernel", input_nodes, layout, input_reorder)
        self.dim = dim
        self.descending = descending
        self.stable = stable
        self.indices_node = indices_node

    def render(
        self,
        kernel: MLIRTemplateKernel,
        template_buffer_node=None,
        epilogue_nodes: Optional[List[IRNode]] = None,
        tile_info=None,
        **kwargs,
    ):
        """Render the MLIR source for this sort.

        Args:
            kernel: kernel object whose ``render_options``, ``epilogue_info``,
                ``loop_size`` and ``exception_nodes`` are populated here.
            template_buffer_node: when given, replaces ``self.output_node`` as
                the values output.
            epilogue_nodes, tile_info, **kwargs: accepted for interface
                compatibility; not used by this template.

        Returns:
            The rendered kernel source string.

        Raises:
            RuntimeError: if no indices node was supplied, the input is not
                rank-2, ``dim`` is not 0 or 1, or the input and value dtypes
                differ.
        """
        if template_buffer_node is not None:
            self.output_node = template_buffer_node
        if self.indices_node is None:
            raise RuntimeError("MLIRSortTemplate requires indices output node")

        x = self.input_nodes[0]
        yv = self.output_node
        yi = self.indices_node

        def _as_int(v):
            # Symbolic sizes cannot be converted directly; use their hint.
            try:
                return int(v)
            except (TypeError, ValueError):
                return int(V.graph.sizevars.size_hint(v))

        x_size = x.get_size()
        if len(x_size) != 2:
            raise RuntimeError("MLIRSortTemplate currently supports rank-2 input only")
        if self.dim not in (0, 1):
            raise RuntimeError(f"MLIRSortTemplate currently supports dim in {{0,1}} only, got dim={self.dim}")
        dim = int(self.dim)

        rows = _as_int(x_size[0])
        cols = _as_int(x_size[1])
        cols_minus1 = max(0, cols - 1)
        rows_minus1 = max(0, rows - 1)

        x_dtype = x.get_dtype()
        yv_dtype = yv.get_dtype()
        yi_dtype = yi.get_dtype()
        if x_dtype != yv_dtype:
            raise RuntimeError("sort template requires input/value dtype match")

        yi_tile_desc = _unit_tile(kernel, "yi_sort_tile")
        yv_tile_desc = _unit_tile(kernel, "yv_sort_tile")
        # Neighbor element descriptors use DRAM offset to preserve affine stride metadata.
        # The neighbour is one element away along a row, ``cols`` along a column.
        neighbor_offset = 1 if dim == 1 else cols
        yv_s1_tile_desc = _unit_tile(kernel, "yv_sort_tile", offset=neighbor_offset)
        yi_s1_tile_desc = _unit_tile(kernel, "yi_sort_tile", offset=neighbor_offset)

        row = sympy.Symbol("row")
        col = sympy.Symbol("col")
        i = sympy.Symbol("i")
        j = sympy.Symbol("j")

        init_x_idx = [row * cols, col]
        init_yv_idx = [row * cols, col]
        init_yi_idx = [row * cols, col]

        # Both operands of a compare share the same index; the second one is
        # displaced through the ``*_S1`` tile descriptor's offset.
        d1_s0_idx = [row * cols, j]
        d1_s1_idx = [row * cols, j]

        d0_s0_idx = [i * cols, col]
        d0_s1_idx = [i * cols, col]

        yv_mlir_type = mlir_common.DTYPE_TO_MLIR[yv_dtype]
        yi_mlir_type = mlir_common.DTYPE_TO_MLIR[yi_dtype]
        x_mlir_type = mlir_common.DTYPE_TO_MLIR[x_dtype]

        kernel.loop_size = None
        numel = rows * cols
        kernel.render_options = dict(
            KERNEL_NAME=self.name,
            kernel=kernel,
            X=x,
            YV=yv,
            YI=yi,
            OUT_DVAR="YV",
            NAMES_STR="X, YI, YV",
            ROWS=rows,
            COLS=cols,
            COLS_MINUS1=cols_minus1,
            ROWS_MINUS1=rows_minus1,
            DIM=dim,
            DESCENDING=bool(self.descending),
            CMP_OP=_swap_compare_op(yv_dtype, bool(self.descending)),
            YI_TILE_DESC=yi_tile_desc,
            YV_TILE_DESC=yv_tile_desc,
            YI_S1_TILE_DESC=yi_s1_tile_desc,
            YV_S1_TILE_DESC=yv_s1_tile_desc,
            INIT_X_IDX=init_x_idx,
            INIT_YV_IDX=init_yv_idx,
            INIT_YI_IDX=init_yi_idx,
            D1_S0_IDX=d1_s0_idx,
            D1_S1_IDX=d1_s1_idx,
            D0_S0_IDX=d0_s0_idx,
            D0_S1_IDX=d0_s1_idx,
            YV_ELEM_TYPE=yv_mlir_type,
            YI_ELEM_TYPE=yi_mlir_type,
            X_MEMREF_TYPE=f"memref<{numel}x{x_mlir_type}>",
            YV_MEMREF_TYPE=f"memref<{numel}x{yv_mlir_type}>",
            YI_MEMREF_TYPE=f"memref<{numel}x{yi_mlir_type}>",
            YV_TILE_MEMREF_TYPE=yv_tile_desc.get_mlir_shape(yv_mlir_type),
            YI_TILE_MEMREF_TYPE=yi_tile_desc.get_mlir_shape(yi_mlir_type),
            # The input is staged through the value tile; it has no SRAM
            # buffer of its own in the template.
            X_TILE_DESC=yv_tile_desc,
            input_reorder=self.input_reorder,
        )

        output_node_name = yv.get_name() if hasattr(yv, "get_name") else yv.name
        kernel.epilogue_info = dict(
            output_node=output_node_name,
            sram_var="yv_sort_tile",
            dram_var=kernel.render_options["OUT_DVAR"],
            dram_tile_desc=yv_tile_desc,
        )
        kernel.exception_nodes[kernel.render_options["OUT_DVAR"]] = {"numel": yv.get_numel()}
        kernel.exception_nodes["YI"] = {"numel": yi.get_numel()}

        code = self._template_from_string(TEMPLATE).render(**kernel.render_options)
        return code
self as kernel: @@ -628,8 +628,26 @@ def def_kernel( self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] def hook(): - arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) - return f"({', '.join(arg_defs)})" + arg_defs, call_args, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) + output_names = names[len(inputs) : len(inputs) + len(outputs)] + out_ptr_idx = 0 + renamed_arg_defs = [] + for outer, arg_def in zip(call_args, arg_defs): + raw_symbol = arg_def.split(":", 1)[0].strip().lstrip("%") + if outer in self.kernel_group.args.input_buffers: + symbol = self.kernel_group.args.input_buffers[outer] + elif outer in self.kernel_group.args.output_buffers: + symbol = self.kernel_group.args.output_buffers[outer] + elif raw_symbol.startswith("out_ptr") and out_ptr_idx < len(output_names): + symbol = output_names[out_ptr_idx] + out_ptr_idx += 1 + elif outer in self.kernel_group.args.sizevars: + symbol = self.kernel_group.args.sizevars[outer] + else: + symbol = raw_symbol + _, arg_type = arg_def.split(":", 1) + renamed_arg_defs.append(f"%{symbol}:{arg_type}") + return f"({', '.join(renamed_arg_defs)})" assert "" not in self.render_hooks self.render_hooks[""] = hook @@ -1151,6 +1169,8 @@ def __init__(self, name, input_nodes, layout, input_reorder = None): super().__init__(name) self.input_nodes = [node for node in input_nodes if node is not None] self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + # Multi-output templates can override this with explicit output buffers. 
+ self.output_nodes = [self.output_node] self.input_reorder = input_reorder self.layout = layout @@ -1166,10 +1186,12 @@ def generate(self, **kwargs) -> ChoiceCaller: kernel_hash_name = f"mlir_{self.name}_{next(self.index_counter)}" extra_args = [] # create the BenchmarkRequest + output_nodes = getattr(self, "output_nodes", None) or [self.output_node] + bmreq = MLIRBenchmarkRequest( kernel_name=kernel_name, input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), - output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + output_tensor_meta=TensorMeta.from_irnodes(output_nodes), extra_args=extra_args, source_code=code, ) diff --git a/tests/DeepSeek/test_deepseek_v3_base.py b/tests/DeepSeek/test_deepseek_v3_base.py index b8402c8b..ade787c5 100644 --- a/tests/DeepSeek/test_deepseek_v3_base.py +++ b/tests/DeepSeek/test_deepseek_v3_base.py @@ -1,8 +1,55 @@ import os import sys import argparse +import copy +from pathlib import Path import torch +# recursive compile for some ops that are caused by graph break +torch.npu.register_eager_to_compile([ + "aten::zero_", + "aten::sum.IntList_out", + "aten::mul.out", + "aten::floor_divide", + "aten::floor_divide.Tensor", + "aten::floor_divide.Scalar", + "aten::cat.out", + "aten::sort.values_stable", +]) + + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + out_cpu = out.cpu() + max_diff = (out_cpu - cpu_out).abs().max().item() + mean_diff = (out_cpu - cpu_out).abs().mean().item() + if torch.allclose(out_cpu, cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print(f"Max absolute difference: {max_diff:.6f}") + print(f"Mean absolute difference: {mean_diff:.6f}") + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("NPU out: ", out_cpu) + print("CPU out: ", cpu_out) + print(f"Max absolute difference: {max_diff:.6f}") + print(f"Mean absolute 
difference: {mean_diff:.6f}") + exit(1) + + +def _extract_logits(output): + if isinstance(output, torch.Tensor): + return output + if hasattr(output, "logits"): + return output.logits + if isinstance(output, (list, tuple)) and len(output) > 0 and isinstance(output[0], torch.Tensor): + return output[0] + raise TypeError(f"Unsupported output type for comparison: {type(output)}") + def _dtype_from_str(name: str) -> torch.dtype: return { @@ -81,7 +128,7 @@ def _maybe_scale_config(config, scale=1.0, max_layers=None): def _apply_preset(scale, max_layers, batch, seq_len, preset): if preset == "tiny": - return 0.03, 4, 1, min(seq_len, 16) + return 0.03, 1, 1, min(seq_len, 16) if preset == "small": return 0.07, 8, 1, min(seq_len, 32) if preset == "medium": @@ -89,8 +136,58 @@ def _apply_preset(scale, max_layers, batch, seq_len, preset): return scale, max_layers, batch, seq_len +def _togsim_log_count() -> int: + log_dir = Path("togsim_results") + if not log_dir.exists(): + return 0 + return len(list(log_dir.glob("*.log"))) + + +def _assert_simulation_happened(before_count: int, case_name: str): + after_count = _togsim_log_count() + if after_count <= before_count: + raise RuntimeError( + f"{case_name}: TOGSim log count did not increase " + f"(before={before_count}, after={after_count})" + ) + print(f"{case_name}: TOGSim logs increased ({before_count} -> {after_count})") + + +def test_cat_default(device): + def cat_default_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_default_fn) + + before = _togsim_log_count() + out = opt_fn(x, y) + _assert_simulation_happened(before, "cat.default") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + test_result("cat.default", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_out(device): + def cat_out_fn(a, b, out): + return torch.ops.aten.cat.out([a, b], 0, out=out) + + x = torch.randn(8, 16, device=device) + y = 
torch.randn(6, 16, device=device) + out_buf = torch.empty(14, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_out_fn) + + before = _togsim_log_count() + out = opt_fn(x, y, out_buf) + _assert_simulation_happened(before, "cat.out") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + test_result("cat.out", out, cpu_out, rtol=1e-4, atol=1e-4) + + @torch.no_grad() -def run_deep_seek_v3_base_test( +def run_deepseek_v3_base( model_id, device, init_mode="config-random", @@ -120,7 +217,6 @@ def run_deep_seek_v3_base_test( # (call .to_dict()), so only disable it for pretrained loading path. if init_mode == "pretrained" and getattr(config, "quantization_config", None) is not None: config.quantization_config = None - config = _maybe_scale_config(config, scale=scale, max_layers=max_layers) if init_mode == "config-random": @@ -141,7 +237,6 @@ def run_deep_seek_v3_base_test( else: raise ValueError(f"Unsupported init mode: {init_mode}") - model = model.to(device) model_params = sum(p.numel() for p in model.parameters()) print("init mode:", init_mode) print("scaled hidden_size:", getattr(config, "hidden_size", "n/a")) @@ -157,23 +252,33 @@ def run_deep_seek_v3_base_test( revision=revision, ) encoded = tokenizer(prompt, return_tensors="pt") - input_ids = encoded["input_ids"].to(device) + cpu_input_ids = encoded["input_ids"].cpu() else: vocab_size = getattr(config, "vocab_size", None) if vocab_size is None: raise ValueError("Config has no vocab_size; use --use-tokenizer or pass a model with vocab_size.") - input_ids = _build_random_inputs(batch, seq_len, vocab_size, device) + cpu_input_ids = _build_random_inputs(batch, seq_len, vocab_size, torch.device("cpu")) + input_ids = cpu_input_ids.to(device) - if compile_model: - model = torch.compile(model, dynamic=False) + # CPU version + model_cpu = copy.deepcopy(model).cpu().eval() + cpu_out = _extract_logits(model_cpu(cpu_input_ids)) - out = model(input_ids) - logits = out.logits + # NPU version + model_npu = 
copy.deepcopy(model_cpu).to(device).eval() + if compile_model: + model_npu = torch.compile(model_npu, dynamic=False) + npu_out = _extract_logits(model_npu(input_ids)) + + # Campare results + test_result( + "DeepSeek V3 Base", + npu_out, + cpu_out, + rtol=3e-1, + atol=2e-1, + ) - print("logits shape:", tuple(logits.shape)) - print("logits dtype:", logits.dtype) - print("logits max:", logits.max().item()) - if __name__ == "__main__": parser = argparse.ArgumentParser(description="DeepSeek V3 download-based test") @@ -181,7 +286,7 @@ def run_deep_seek_v3_base_test( parser.add_argument("--revision", type=str, default=None) parser.add_argument("--trust-remote-code", action="store_true", default=True) parser.add_argument("--init-mode", type=str, default="config-random", choices=["config-random", "pretrained"]) - parser.add_argument("--preset", type=str, default="tiny", choices=["none", "tiny", "small", "medium"]) + parser.add_argument("--preset", type=str, default="small", choices=["none", "tiny", "small", "medium"]) parser.add_argument("--scale", type=float, default=1.0) parser.add_argument("--max-layers", type=int, default=None) parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) @@ -190,6 +295,7 @@ def run_deep_seek_v3_base_test( parser.add_argument("--use-tokenizer", action="store_true") parser.add_argument("--prompt", type=str, default="Hello, DeepSeek V3") parser.add_argument("--compile", action="store_true", default=True) + parser.add_argument("--test", type=str, default="e2e", choices=["all", "e2e", "cat"]) args = parser.parse_args() @@ -203,18 +309,22 @@ def run_deep_seek_v3_base_test( device = torch.device("npu:0") - run_deep_seek_v3_base_test( - model_id=args.model_id, - device=device, - init_mode=args.init_mode, - scale=args.scale, - max_layers=args.max_layers, - dtype=args.dtype, - batch=args.batch, - seq_len=args.seq_len, - use_tokenizer=args.use_tokenizer, - prompt=args.prompt, - 
trust_remote_code=args.trust_remote_code, - revision=args.revision, - compile_model=args.compile, - ) + if args.test in ("all", "cat"): + test_cat_default(device) + test_cat_out(device) + if args.test in ("all", "e2e"): + run_deepseek_v3_base( + model_id=args.model_id, + device=device, + init_mode=args.init_mode, + scale=args.scale, + max_layers=args.max_layers, + dtype=args.dtype, + batch=args.batch, + seq_len=args.seq_len, + use_tokenizer=args.use_tokenizer, + prompt=args.prompt, + trust_remote_code=args.trust_remote_code, + revision=args.revision, + compile_model=args.compile, + ) diff --git a/tests/test_cat.py b/tests/test_cat.py new file mode 100644 index 00000000..32573a05 --- /dev/null +++ b/tests/test_cat.py @@ -0,0 +1,89 @@ +import argparse +from pathlib import Path + +import torch + + +def _test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + return + + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + raise RuntimeError(f"{name} mismatch") + + +def _togsim_log_count() -> int: + log_dir = Path("togsim_results") + if not log_dir.exists(): + return 0 + return len(list(log_dir.glob("*.log"))) + + +def _assert_simulation_happened(before_count: int, case_name: str): + after_count = _togsim_log_count() + if after_count <= before_count: + raise RuntimeError( + f"{case_name}: TOGSim log count did not increase " + f"(before={before_count}, after={after_count})" + ) + print(f"{case_name}: TOGSim logs increased ({before_count} -> {after_count})") + + +def test_cat_default(device): + def cat_default_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + opt_fn = 
torch.compile(dynamic=False)(cat_default_fn) + + before = _togsim_log_count() + out = opt_fn(x, y) + _assert_simulation_happened(before, "cat.default") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.default", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_out(device): + def cat_out_fn(a, b, out): + return torch.ops.aten.cat.out([a, b], 0, out=out) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + out_buf = torch.empty(14, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_out_fn) + + before = _togsim_log_count() + out = opt_fn(x, y, out_buf) + _assert_simulation_happened(before, "cat.out") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.out", out, cpu_out, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run cat simulation tests") + parser.add_argument( + "--case", + choices=["default", "out", "all"], + default="all", + help="Which cat case to run", + ) + args = parser.parse_args() + + device = torch.device("npu:0") + + if args.case in ("default", "all"): + test_cat_default(device) + if args.case in ("out", "all"): + test_cat_out(device) diff --git a/tests/test_sort.py b/tests/test_sort.py new file mode 100644 index 00000000..2b070223 --- /dev/null +++ b/tests/test_sort.py @@ -0,0 +1,112 @@ +import argparse +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out:", out.cpu()) + print("cpu out:", cpu_out) + raise SystemExit(1) + + +def test_equal(name, out, cpu_out): + if torch.equal(out.cpu(), cpu_out): + message = f"|{name} Test 
Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out:", out.cpu()) + print("cpu out:", cpu_out) + raise SystemExit(1) + + +def _normalize_dim(dim: int, rank: int) -> int: + d = dim if dim >= 0 else rank + dim + if d < 0 or d >= rank: + raise ValueError(f"dim out of range: dim={dim}, rank={rank}") + return d + + +def test_sort_stable(device, size=(128, 128), dim=-1, descending=False): + _normalize_dim(dim, len(size)) + + def sort_stable_fn(x): + return torch.sort(x, stable=True, dim=dim, descending=descending) + + x = torch.randn(size, dtype=torch.float32) + x_npu = x.to(device=device) + + opt_sort = torch.compile(dynamic=False)(sort_stable_fn) + out_values, out_indices = opt_sort(x_npu) + + ref_values, ref_indices = torch.sort(x, stable=True, dim=dim, descending=descending) + + test_result("Sort.stable/values", out_values, ref_values) + test_equal("Sort.stable/indices", out_indices, ref_indices) + + +def test_sort_values_stable(device, size=(128, 128), dim=-1, descending=False): + _normalize_dim(dim, len(size)) + + def sort_out_fn(x): + out_values = torch.empty_like(x, device=x.device) + out_indices = torch.empty_like(x, dtype=torch.int64, device=x.device) + return torch.sort(x, stable=True, dim=dim, descending=descending, out=(out_values, out_indices)) + + x = torch.randn(size, dtype=torch.float32) + x_npu = x.to(device=device) + + opt_sort = sort_out_fn# torch.compile(dynamic=False)(sort_out_fn) + out_values, out_indices = opt_sort(x_npu) + + ref_values, ref_indices = torch.sort(x, stable=True, dim=dim, descending=descending) + + test_result("Sort.values_stable/values", out_values, ref_values) + test_equal("Sort.values_stable/indices", out_indices, ref_indices) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run sort tests") + parser.add_argument("--shape", type=str, 
default="(128,128)") + parser.add_argument("--dim", type=int, default=0) + parser.add_argument("--descending", action="store_true") + parser.add_argument( + "--mode", + type=str, + default="all", + choices=["all", "default", "values"], + ) + args = parser.parse_args() + + shape = tuple(map(int, args.shape.strip("()").split(","))) + + from Scheduler.scheduler import PyTorchSimRunner + + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + # Register recursive-compile bridge only when values_stable path is explicitly tested. + if args.mode in ("all", "values"): + torch.npu.register_eager_to_compile([ + "aten::sort.values_stable", + ]) + + if args.mode in ("all", "default"): + test_sort_stable(device, size=shape, dim=args.dim, descending=args.descending) + if args.mode in ("all", "values"): + test_sort_values_stable(device, size=shape, dim=args.dim, descending=args.descending) From 41288bc2d300305d91559ae49a67f11984f789c0 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Tue, 3 Mar 2026 16:40:57 +0900 Subject: [PATCH 2/9] [Template] Polish template kernel of cat operation --- .../torch_openreg/openreg/__init__.py | 49 --- PyTorchSimFrontend/mlir/mlir_bmm_template.py | 3 + PyTorchSimFrontend/mlir/mlir_cat_template.py | 369 ++++++++++++------ PyTorchSimFrontend/mlir/mlir_conv_common.py | 3 + PyTorchSimFrontend/mlir/mlir_gemm_template.py | 3 + PyTorchSimFrontend/mlir/mlir_lowering.py | 118 +----- PyTorchSimFrontend/mlir/mlir_scheduling.py | 22 +- PyTorchSimFrontend/mlir/mlir_template.py | 43 +- tests/test_cat.py | 143 +++++-- 9 files changed, 424 insertions(+), 329 deletions(-) diff --git a/PyTorchSimDevice/torch_openreg/openreg/__init__.py b/PyTorchSimDevice/torch_openreg/openreg/__init__.py index 5603a4f7..f5aabc18 100644 --- a/PyTorchSimDevice/torch_openreg/openreg/__init__.py +++ b/PyTorchSimDevice/torch_openreg/openreg/__init__.py @@ -256,52 +256,6 @@ def launch_model(model, *args, stream_index=0, timestamp=0, **kwargs): from .random import 
* # noqa: F403 from .amp import * -def _precheck_cat_out_args(args, kwargs): - tensors = args[0] if len(args) > 0 else kwargs.get("tensors") - dim = args[1] if len(args) > 1 else kwargs.get("dim", 0) - out = kwargs.get("out", args[2] if len(args) > 2 else None) - - if out is None: - return - if not isinstance(tensors, (list, tuple)) or len(tensors) == 0: - raise RuntimeError("aten::cat.out requires non-empty tensor list") - if not all(isinstance(t, torch.Tensor) for t in tensors): - raise RuntimeError("aten::cat.out tensors must be Tensor values") - if not isinstance(out, torch.Tensor): - raise RuntimeError("aten::cat.out out must be a Tensor") - - rank = tensors[0].dim() - if rank == 0: - raise RuntimeError("aten::cat.out does not support scalar inputs") - if dim < 0: - dim += rank - if dim < 0 or dim >= rank: - raise RuntimeError(f"aten::cat.out dim out of range: dim={dim}, rank={rank}") - if any(t.dim() != rank for t in tensors): - raise RuntimeError("aten::cat.out inputs must have the same rank") - if any(t.dtype != tensors[0].dtype for t in tensors): - raise RuntimeError("aten::cat.out inputs must have the same dtype") - if out.dim() != rank: - raise RuntimeError("aten::cat.out out rank mismatch") - - for d in range(rank): - if d == dim: - continue - base = tensors[0].shape[d] - if any(t.shape[d] != base for t in tensors[1:]): - raise RuntimeError( - f"aten::cat.out non-concatenated dimension mismatch at dim={d}" - ) - if out.shape[d] != base: - raise RuntimeError(f"aten::cat.out out shape mismatch at dim={d}") - - expected = sum(t.shape[dim] for t in tensors) - if out.shape[dim] != expected: - raise RuntimeError( - f"aten::cat.out out concatenated dimension mismatch at dim={dim}: " - f"expected {expected}, got {out.shape[dim]}" - ) - def eager_to_compile(op_name): """ Register an eager mode operation as a graph-based implementation using torch.compile(). 
@@ -313,9 +267,6 @@ def eager_to_compile(op_name): torch.npu.eager_to_compile("aten::mul.Tensor") """ def wrapper(*args, **kwargs): - if op_name == "aten::cat.out": - _precheck_cat_out_args(args, kwargs) - @torch.compile(dynamic=False) def dummy_graph(*args, **kwargs): # Convert "aten::mul.Tensor" -> torch.ops.aten.mul.Tensor diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 178ea987..9398f90c 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -154,6 +154,9 @@ class MLIRBMMTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = True + self.support_reduction_fusion = True def render(self, kernel: MLIRTemplateKernel, diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py index 996af1de..d68af7d4 100644 --- a/PyTorchSimFrontend/mlir/mlir_cat_template.py +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -1,8 +1,9 @@ -from typing import List, Optional, cast +from typing import List, Optional +import math +import itertools import sympy -from torch._inductor.ir import Buffer, IRNode -from torch._inductor.virtualized import V +from torch._inductor.ir import IRNode from PyTorchSimFrontend.mlir import mlir_common from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel @@ -10,40 +11,28 @@ TEMPLATE = r""" {{kernel.def_global_vars()}} - -func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X0, X1], outputs=[Y], names_str=NAMES_STR, input_reorder=input_reorder)}} { - {{ kernel.def_sram_buffer("X0", X0_TILE_DESC, id=0, indent_size=2) }} - {{ kernel.def_sram_buffer("X1", X1_TILE_DESC, id=1, indent_size=2) }} - {{ kernel.def_sram_buffer(OUT_DVAR, Y_TILE_DESC, id=2, indent_size=2) }} +func.func @{{ KERNEL_NAME }} 
{{kernel.def_kernel(inputs=INPUT_NAMES, outputs=[Y], names_str=NAMES_STR, input_reorder=input_reorder)}} { +{%- for buffer_name, tile_desc in UNIQUE_BUFFER_TILE_DESCS.items() %} + {{ kernel.def_sram_buffer(buffer_name, tile_desc, indent_size=2) }} +{%- endfor %} {{ kernel.def_local_vars(indent_size=2) }} affine.for %cat_block = 0 to 1 step 1 { -{% if DIM == 0 %} - affine.for %index0 = 0 to {{ X0_ROWS }} step 1 { - affine.for %index1 = 0 to {{ COLS }} step 1 { - {{ kernel.def_dma_op("MVIN", "X0", X0_IDX, X0_TILE_DESC, indent_size=8) }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y0_IDX, X0_TILE_DESC, indent_size=8) }} - } - } - - affine.for %index2 = 0 to {{ X1_ROWS }} step 1 { - affine.for %index3 = 0 to {{ COLS }} step 1 { - {{ kernel.def_dma_op("MVIN", "X1", X1_IDX, X1_TILE_DESC, indent_size=8) }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y1_IDX, X1_TILE_DESC, indent_size=8) }} - } - } -{% else %} - affine.for %index0 = 0 to {{ ROWS }} step 1 { - affine.for %index1 = 0 to {{ X0_COLS }} step 1 { - {{ kernel.def_dma_op("MVIN", "X0", X0_IDX, X0_TILE_DESC, indent_size=8) }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y0_IDX, X0_TILE_DESC, indent_size=8) }} - } - affine.for %index3 = 0 to {{ X1_COLS }} step 1 { - {{ kernel.def_dma_op("MVIN", "X1", X1_IDX, X1_TILE_DESC, indent_size=8) }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y1_IDX, X1_TILE_DESC, indent_size=8) }} - } - } -{% endif %} +{%- for d in range(RANK-1) %} + affine.for %index{{ OUTPUT_DIM[d] }} = 0 to {{ OUTPUT_SIZES[d] }} step {{ TILE_SIZES[d] }} { +{%- endfor %} +{%- for i in range(NUM_INPUTS) %} + // Input tensor{{ i }} + affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUT_SIZES[i][DIM] }} step {{ INPUT_TILE_SIZES_DIM[i] }} { + %index{{ DIM }}_{{i}} = affine.apply affine_map<(d0) -> (d0 + {{ CUMULATIVE_OFFSETS[i] }})> (%index_local{{ DIM }}_{{ i }}) + {{ kernel.def_dma_op("MVIN", INPUT_BUFFER_NAMES[i], INPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }} + {{ kernel.def_dma_op("MVOUT", 
OUT_DVAR, OUTPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }} + } { inner_loop=true } +{%- endfor %} + +{%- for d in range(RANK-1) %} + } { outer_loop=true } +{%- endfor %} } { outer_loop=true } return } @@ -51,8 +40,8 @@ class MLIRCatTemplate(MLIRTemplate): - def __init__(self, input_nodes, layout, dim, input_reorder=None): - super().__init__("kernel", input_nodes, layout, input_reorder) + def __init__(self, input_nodes, layout, dim): + super().__init__("kernel", input_nodes, layout) self.dim = dim def render( @@ -66,87 +55,248 @@ def render( is_out_variant = template_buffer_node is not None if is_out_variant: self.output_node = template_buffer_node - # cat template currently emits a single output buffer and does not - # support epilogue output remapping. - - def _unwrap_node(n): - return n.node if hasattr(n, "node") else n - - x0 = _unwrap_node(self.input_nodes[0]) - x1 = _unwrap_node(self.input_nodes[1]) - y = _unwrap_node(self.output_node) - - def _as_int(v): - try: - return int(v) - except Exception: - return int(V.graph.sizevars.size_hint(v)) - - x0_rows = _as_int(x0.get_size()[0]) - x1_rows = _as_int(x1.get_size()[0]) - x0_cols = _as_int(x0.get_size()[1]) - x1_cols = _as_int(x1.get_size()[1]) - y_cols = _as_int(y.get_size()[1]) - kernel.loop_size = None - - # 2D cat template with contiguous layout. - x0_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) - x0_tile_desc.set_tile_size_stride([1, 1], [1, 1]) - x0_tile_desc.set_name("x0_cat_tile") - x1_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) - x1_tile_desc.set_tile_size_stride([1, 1], [1, 1]) - x1_tile_desc.set_name("x1_cat_tile") - y_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) - y_tile_desc.set_tile_size_stride([1, 1], [1, 1]) - y_tile_desc.set_name("y_cat_tile") - if self.dim == 0: - # Flattened offsets for dim=0 cat. 
- x0_idx = [sympy.Symbol("index0") * x0_cols, sympy.Symbol("index1")] - x1_idx = [sympy.Symbol("index2") * x1_cols, sympy.Symbol("index3")] - y0_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index1")] - y1_idx = [(sympy.Symbol("index2") + x0_rows) * y_cols, sympy.Symbol("index3")] - else: - # Flattened offsets for dim=1 cat. - x0_idx = [sympy.Symbol("index0") * x0_cols, sympy.Symbol("index1")] - x1_idx = [sympy.Symbol("index0") * x1_cols, sympy.Symbol("index3")] - y0_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index1")] - y1_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index3") + x0_cols] + # Extract info + input_nodes = self.input_nodes + y = self.output_node + num_inputs = len(self.input_nodes) + rank = len(y.get_size()) + + input_sizes = [x.get_size() for x in input_nodes] + output_sizes = [sz for dim, sz in enumerate(y.get_size()) if dim != self.dim] + output_dim = [dim for dim, sz in enumerate(y.get_size()) if dim != self.dim] + tile_sizes = tile_info if tile_info is not None else [1] * len(output_sizes) + output_strides = y.get_layout().stride + + # Calculate input tile sizes + input_tile_sizes_dim = self._calculate_input_tile_sizes( + kernel, input_sizes, tile_sizes, num_inputs, rank + ) + buffer_name_to_template_name, input_buffer_names = self._build_buffer_mapping(input_nodes) + input_tile_descs, unique_tile_descs = self._build_tile_descriptors( + kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names + ) + y_tile_desc = self._build_output_tile_desc( + kernel, input_tile_sizes_dim, tile_sizes, rank + ) + + input_idxs, output_idxs, cumulative_offsets = self._build_index_expressions( + input_nodes, input_sizes, output_strides, rank, num_inputs + ) + + # Map unique buffer names to their tile descriptors for template + unique_buffer_tile_descs = {} + for actual_name, template_name in buffer_name_to_template_name.items(): + if actual_name in unique_tile_descs: + 
unique_buffer_tile_descs[template_name] = unique_tile_descs[actual_name] + + names_str = ", ".join(input_buffer_names + ["out_ptr1" if is_out_variant else "Y"]) + indent_size = 2 + (rank - 1) * 2 + 4 kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, - X0=x0, - X1=x1, Y=y, OUT_DVAR="out_ptr1" if is_out_variant else "Y", - NAMES_STR="X0, X1, out_ptr1" if is_out_variant else "X0, X1, Y", + NAMES_STR=names_str, + INPUT_NAMES=input_nodes, + INPUT_BUFFER_NAMES=input_buffer_names, + NUM_INPUTS=num_inputs, + RANK=rank, DIM=self.dim, - X0_ROWS=x0_rows, - X1_ROWS=x1_rows, - ROWS=x0_rows, - X0_COLS=x0_cols, - X1_COLS=x1_cols, - COLS=x0_cols, - X0_TILE_DESC=x0_tile_desc, - X1_TILE_DESC=x1_tile_desc, - Y_TILE_DESC=y_tile_desc, - X0_IDX=x0_idx, - X1_IDX=x1_idx, - Y0_IDX=y0_idx, - Y1_IDX=y1_idx, + INPUT_SIZES=input_sizes, + OUTPUT_SIZES=output_sizes, + OUTPUT_DIM=output_dim, + TILE_SIZES=tile_sizes, + INPUT_TILE_SIZES_DIM=input_tile_sizes_dim, + INPUT_TILE_DESCS=input_tile_descs, + UNIQUE_BUFFER_TILE_DESCS=unique_buffer_tile_descs, + INPUT_IDXS=input_idxs, + OUTPUT_IDXS=output_idxs, + CUMULATIVE_OFFSETS=cumulative_offsets, + INDENT_SIZE=indent_size, input_reorder=self.input_reorder, ) - # Needed when epilogue fusion requests set_ranges(). - kernel.dim_aliasing = {"index0": "index0", "index1": "index1"} - if hasattr(self.output_node, "node") and hasattr(self.output_node.node, "get_name"): - output_node_name = self.output_node.node.get_name() - elif hasattr(self.output_node, "get_name"): - output_node_name = self.output_node.get_name() - else: - output_node_name = self.output_node.name + self._setup_epilogue_info(kernel, y) + code = self._template_from_string(TEMPLATE).render(**kernel.render_options) + return code + + def get_tile_candidates( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs, + ): + """Generate tile candidates for cat operation. 
Concat dimension always has tile size 1.""" + if template_buffer_node is not None: + self.output_node = template_buffer_node + + y = self.output_node + num_inputs = len(self.input_nodes) + output_sizes = [sz for dim, sz in enumerate(y.get_size()) if dim != self.dim] + num_non_dim_dims = len(output_sizes) + + if num_non_dim_dims == 0: + return [[1]] + + tile_candidates = [] + dim_tile_candidates = [] + + for dim_size in output_sizes: + dim_candidates = [] + max_tile = min(dim_size, kernel.spad_info["spad_size"] // (kernel.vector_lane * kernel.precision * 2 * num_inputs)) + + for mult in range(1, max_tile // kernel.vector_lane + 1): + tile = mult * kernel.vector_lane + if tile <= dim_size: + dim_candidates.append(tile) + if max_tile > 0: + for exp in range(int(math.log2(max_tile)) + 1): + tile = 2 ** exp + if tile <= dim_size and tile not in dim_candidates: + dim_candidates.append(tile) + + if dim_size not in dim_candidates: + dim_candidates.append(dim_size) + + dim_tile_candidates.append(sorted(set(dim_candidates))[:5]) + + for tile_combo in itertools.product(*dim_tile_candidates): + total_elements = math.prod(tile_combo) + total_spad_needed = total_elements * (num_inputs + 1) * kernel.precision + + if total_spad_needed <= kernel.spad_info["spad_size"] * kernel.vector_lane: + tile_candidates.append(list(tile_combo)) + + if not tile_candidates: + tile_candidates = [[1] * num_non_dim_dims] + + tile_candidates.sort(key=lambda x: -math.prod(x)) + return tile_candidates[:4] + + def _calculate_input_tile_sizes( + self, kernel, input_sizes, tile_sizes, num_inputs, rank + ): + """Calculate tile sizes for concat dimension for each input.""" + non_dim_tile_elements = math.prod(tile_sizes) if tile_sizes else 1 + non_dim_tile_spad = non_dim_tile_elements * kernel.precision + max_spad_per_input = kernel.spad_info["spad_size"] * kernel.vector_lane // 2 + extra_concat_input = math.ceil(max_spad_per_input / non_dim_tile_spad) - num_inputs + + input_tile_sizes_dim = [] + for i in 
range(num_inputs): + input_dim_size = input_sizes[i][self.dim] + if extra_concat_input > 0 and non_dim_tile_elements > 0: + max_tile_dim = min(input_dim_size, extra_concat_input) + extra_concat_input -= max_tile_dim + else: + max_tile_dim = 1 + input_tile_sizes_dim.append(max_tile_dim) + return input_tile_sizes_dim + + def _build_buffer_mapping(self, input_nodes): + """Map actual buffer names to template buffer names """ + buffer_name_to_template_name = {} + input_buffer_names = [] + for x in input_nodes: + actual_name = x.get_name() + template_name = buffer_name_to_template_name.setdefault( + actual_name, f"X{len(buffer_name_to_template_name)}" + ) + input_buffer_names.append(template_name) + return buffer_name_to_template_name, input_buffer_names + + def _build_tile_descriptors( + self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names + ): + """Build tile descriptors for each input.""" + input_tile_descs = [] + unique_tile_descs = {} + + for i, x in enumerate(input_nodes): + # Build full tile size list for this input + full_tile_sizes = [] + tile_size_idx = 0 + for d in range(rank): + if d != self.dim: + full_tile_sizes.append(tile_sizes[tile_size_idx]) + tile_size_idx += 1 + else: + full_tile_sizes.append(input_tile_sizes_dim[i]) + + tile_desc = mlir_common.MLIRMultiDimTile( + full_tile_sizes, + kernel.vector_lane, + vlane_split_axis=rank - 1, + vlane_stride=1 + ) + tile_desc.set_tile_size(full_tile_sizes) + template_buffer_name = input_buffer_names[i] + tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile") + input_tile_descs.append(tile_desc) + + # Store unique tile desc by actual buffer name + actual_name = x.get_name() + if actual_name not in unique_tile_descs: + unique_tile_descs[actual_name] = tile_desc + + return input_tile_descs, unique_tile_descs + + def _build_index_expressions( + self, input_nodes, input_sizes, output_strides, rank, num_inputs + ): + """Build index expressions for input and 
output.""" + input_idxs = [] + output_idxs = [] + cumulative_offsets = [0] + for i in range(num_inputs - 1): + cumulative_offsets.append(cumulative_offsets[-1] + input_sizes[i][self.dim]) + + for i, x in enumerate(input_nodes): + x_stride = x.get_layout().stride + input_idx = [] + output_idx = [] + for d in range(rank): + if d != self.dim: + input_idx_symbol = sympy.Symbol(f"index{d}") + output_idx_symbol = sympy.Symbol(f"index{d}") + else: + input_idx_symbol = sympy.Symbol(f"index_local{self.dim}_{i}") + output_idx_symbol = sympy.Symbol(f"index{self.dim}_{i}") + input_idx.append(input_idx_symbol * x_stride[d]) + output_idx.append(output_idx_symbol * output_strides[d]) + input_idxs.append(input_idx) + output_idxs.append(output_idx) + + return input_idxs, output_idxs, cumulative_offsets + + def _build_output_tile_desc(self, kernel, input_tile_sizes_dim, tile_sizes, rank): + """Build output tile descriptor.""" + max_output_tile_dim = max(input_tile_sizes_dim) if input_tile_sizes_dim else 1 + output_full_tile_sizes = [] + tile_size_idx = 0 + for d in range(rank): + if d != self.dim: + output_full_tile_sizes.append(tile_sizes[tile_size_idx]) + tile_size_idx += 1 + else: + output_full_tile_sizes.append(max_output_tile_dim) + + y_tile_desc = mlir_common.MLIRMultiDimTile( + output_full_tile_sizes, + kernel.vector_lane, + vlane_split_axis=rank - 1, + vlane_stride=1 + ) + y_tile_desc.set_tile_size(output_full_tile_sizes) + y_tile_desc.set_name("y_cat_tile") + return y_tile_desc + + def _setup_epilogue_info(self, kernel, y): + """Setup epilogue information.""" if hasattr(y, "get_numel"): y_numel = y.get_numel() elif hasattr(y, "node") and hasattr(y.node, "get_numel"): @@ -154,14 +304,5 @@ def _as_int(v): else: y_numel = None - kernel.epilogue_info = dict( - output_node=output_node_name, - sram_var="y_cat_tile", - dram_var=kernel.render_options["OUT_DVAR"], - dram_tile_desc=y_tile_desc, - ) if y_numel is not None: kernel.exception_nodes[kernel.render_options["OUT_DVAR"]] = 
{"numel": y_numel} - - code = self._template_from_string(TEMPLATE).render(**kernel.render_options) - return code diff --git a/PyTorchSimFrontend/mlir/mlir_conv_common.py b/PyTorchSimFrontend/mlir/mlir_conv_common.py index f8566b6d..f72a7663 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_common.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_common.py @@ -12,6 +12,9 @@ class MLIRConvCommonTemplate(MLIRTemplate): WRAPPER_TEMPLATE = None def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = False + self.support_reduction_fusion = False self.stride = kwargs["stride"] self.padding = kwargs["padding"] self.dilation = kwargs["dilation"] diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 0158caa6..5b116807 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -105,6 +105,9 @@ class MLIRGemmTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = True + self.support_reduction_fusion = True def render(self, kernel: MLIRTemplateKernel, diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index 0f28f03b..d7aee715 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -202,48 +202,9 @@ def _cat_layout(tensors: Sequence[TensorBox], dim: int) -> ir.Layout: stride, ) - -def _can_use_cat_template(tensors: Sequence[TensorBox], dim: int) -> bool: - # Current template specialization: 2 inputs, rank-2, dim in {0, 1}. 
- if len(tensors) != 2: - return False - if not all(hasattr(t, "get_size") and hasattr(t, "get_dtype") and hasattr(t, "realize") for t in tensors): - return False - if tensors[0].get_dtype() != tensors[1].get_dtype(): - return False - rank0 = len(tensors[0].get_size()) - rank1 = len(tensors[1].get_size()) - if rank0 != 2 or rank1 != 2: - return False - if dim < 0: - dim += rank0 - if dim not in (0, 1): - return False - - if dim == 0: - cols0 = tensors[0].get_size()[1] - cols1 = tensors[1].get_size()[1] - return V.graph.sizevars.statically_known_equals(cols0, cols1) - - rows0 = tensors[0].get_size()[0] - rows1 = tensors[1].get_size()[0] - return V.graph.sizevars.statically_known_equals(rows0, rows1) - - -def _cat_fallback(reason: str, tensors: Sequence[TensorBox], dim: int): - # Non-template cases delegate to the original lowering path. - return _orig_cat_default_lowering(tensors, dim) - - -def _custom_cat_impl(tensors: Sequence[TensorBox], dim: int = 0): - if _orig_cat_default_lowering is None: - raise RuntimeError("Original aten.cat.default lowering is missing") - if len(tensors) > 0: - rank = len(tensors[0].get_size()) - if dim < 0: - dim += rank - if not _can_use_cat_template(tensors, dim): - return _cat_fallback("default-path", tensors, dim) +def custom_cat_default(tensors: Sequence[TensorBox], dim: int = 0): + if tensors and dim < 0: + dim += len(tensors[0].get_size()) for t in tensors: t.realize() @@ -251,75 +212,6 @@ def _custom_cat_impl(tensors: Sequence[TensorBox], dim: int = 0): mlir_template = MLIRCatTemplate(list(tensors), layout, dim=dim) return mlir_template.generate().output_node() - -def custom_cat_default(tensors: Sequence[TensorBox], dim: int = 0): - return _custom_cat_impl(tensors, dim) - - -def custom_cat_out(tensors: Sequence[TensorBox], dim: int = 0, out: Optional[TensorBox] = None): - if _orig_cat_out_lowering is None: - raise RuntimeError("Original aten.cat.out lowering is missing") - if out is None: - return _orig_cat_out_lowering(tensors, 
dim, out) - - copy_default_lowering = lowerings.get(aten.copy_.default) - slice_tensor_lowering = lowerings.get(aten.slice.Tensor) - if copy_default_lowering is None or slice_tensor_lowering is None: - raise RuntimeError("cat.out lowering requires aten.copy_.default and aten.slice.Tensor lowerings") - - # Lower cat.out as a sequence of slice+copy ops so each piece still runs - # through the existing compiled/simulated kernel path. - if len(tensors) == 0: - raise RuntimeError("cat.out requires at least one input tensor") - if not all(hasattr(t, "get_size") and hasattr(t, "get_dtype") and hasattr(t, "realize") for t in tensors): - raise RuntimeError("cat.out inputs must be tensor-like values") - rank = len(tensors[0].get_size()) - if rank == 0: - raise RuntimeError("cat.out does not support scalar inputs") - if dim < 0: - dim = dim + rank - if dim < 0 or dim >= rank: - raise RuntimeError(f"cat.out dim out of range: dim={dim}, rank={rank}") - if any(len(t.get_size()) != rank for t in tensors): - raise RuntimeError("cat.out inputs must have the same rank") - if any(t.get_dtype() != tensors[0].get_dtype() for t in tensors): - raise RuntimeError("cat.out inputs must have the same dtype") - # cat semantics: all non-cat dimensions must be equal. - for i in range(rank): - if i == dim: - continue - base = tensors[0].get_size()[i] - if any(not V.graph.sizevars.statically_known_equals(base, t.get_size()[i]) for t in tensors[1:]): - raise RuntimeError(f"cat.out non-concatenated dimension mismatch at dim={i}") - - # Output shape must match concatenated shape. 
- if not hasattr(out, "get_size"): - raise RuntimeError("cat.out output must be tensor-like") - out_sizes = list(out.get_size()) - if len(out_sizes) != rank: - raise RuntimeError("cat.out output rank mismatch") - for i in range(rank): - if i == dim: - continue - if not V.graph.sizevars.statically_known_equals(out_sizes[i], tensors[0].get_size()[i]): - raise RuntimeError(f"cat.out output shape mismatch at dim={i}") - expected_cat = sum(t.get_size()[dim] for t in tensors) - if not V.graph.sizevars.statically_known_equals(out_sizes[dim], expected_cat): - raise RuntimeError(f"cat.out output concatenated dimension mismatch at dim={dim}") - - if isinstance(out, TensorBox): - out.realize() - - offset = 0 - for src in tensors: - src.realize() - end = offset + src.get_size()[dim] - dst_view = slice_tensor_lowering(out, dim, offset, end, 1) - copy_default_lowering(dst_view, src) - offset = end - return out - - def _custom_sort_values_impl( self: TensorBox, dim: int = -1, @@ -459,9 +351,7 @@ def custom_sort_values_stable( lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()}) lowerings.update({getattr(aten._unsafe_index, overload): custom_unsafe_index for overload in aten._unsafe_index.overloads()}) - -lowerings.update({aten.cat.default: custom_cat_default}) -lowerings.update({aten.cat.out: custom_cat_out}) +lowerings.update({getattr(aten.cat, overload): custom_cat_default for overload in aten.cat.overloads()}) lowerings.update({aten.sort.stable: custom_sort_stable}) lowerings.update({aten.sort.values_stable: custom_sort_values_stable}) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index af960533..2f9c9704 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -44,12 +44,10 @@ def can_fuse_with_exceptions(self, 
node1: BaseSchedulerNode, node2: BaseSchedule # Case 3: Prologue(Pointwise) + Tempalte if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - target_node = base_template_node2[0].node - # Currently only BMM, MM support prologue fusion - if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + + # Check if template supports prologue fusion + if not getattr(target_node.template, 'support_prologue_fusion', False): return False if len(node1.read_writes.writes) != 1: @@ -129,12 +127,14 @@ def can_fuse_horizontal(self, node1, node2): if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and not node2.is_reduction(): # Don't fuse maxpool template code from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate template_node = base_template_node1[0] epilogue_node = node2 + # Check if template supports epilogue fusion + if not getattr(template_node.node.template, 'support_epilogue_fusion', False): + return False + if isinstance(template_node.node.template, MLIRMaxPoolTemplate): return False @@ -161,7 +161,7 @@ def can_fuse_horizontal(self, node1, node2): # Revert act_node.group : simplify_and_reorder() modified _body, _size, group if template_node.group != epilogue_node.group: # We don't fuse this case... 
- if (isinstance(template_node.node.template, MLIRBMMTemplate) or isinstance(template_node.node.template, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: + if getattr(template_node.node.template, 'support_prologue_fusion', False) and template_node.group[1][0][0] == 1: return False if list(template_node.group[1][0]) != list(epilogue_node.get_nodes()[0].node.data.get_size()): @@ -171,10 +171,10 @@ def can_fuse_horizontal(self, node1, node2): # Case 2: Tempalte + Reduction fusion if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and node2.is_reduction() and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate target_node = base_template_node1[0].node - if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + + # Check if template supports reduction fusion + if not getattr(target_node.template, 'support_reduction_fusion', False): return False size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 76b0ef71..04d327f8 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -14,7 +14,7 @@ from unittest.mock import patch from torch._inductor.codegen.common import KernelTemplate, CSE, DeferredLine -from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, ChoiceCaller +from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, ChoiceCaller, ir_node_to_tensor from torch._inductor.select_algorithm import PartialRender from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller from torch._inductor.autotune_process import 
TensorMeta @@ -124,6 +124,7 @@ def __init__(self, self.epilogue_buffer_group = IndentedBufferGroup(self, prefix="epilogue_") self.global_vars = IndentedBuffer() self.exception_nodes = {} + self.epilogue_info = {} # Reduction data structure self.reduction_epilogue_suffix = IndentedBuffer() self.reduction_fusion = False @@ -403,7 +404,7 @@ def call_kernel(self, kernel_name): _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() # generate the code to call this wrapper.generate_kernel_call( - kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", call_args) + kernel_name if self.outer_func_name is None else "wrapper_" + kernel_name, call_args) def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_nodes, tile_info): with self as kernel: @@ -460,11 +461,11 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ } node.codegen((vars, reduction_vars)) - # Codegen epilogue nodes - tile_desc = kernel.set_tile_size(kernel.epilogue_info) - kernel.kernel_group.set_tile_info(tile_desc) - kernel.call_ranges = None if epilogue_nodes: + # Codegen epilogue nodes + tile_desc = kernel.set_tile_size(kernel.epilogue_info) + kernel.kernel_group.set_tile_info(tile_desc) + kernel.call_ranges = None with kernel.epilogue_buffer_group.as_local(): _, (group, reduction_group) = max( epilogue_nodes, key=lambda x: int(x.is_reduction()) @@ -625,7 +626,9 @@ def def_kernel( extra_node[node.get_name()] = node.node else: extra_node[node.get_name()] = node - self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] + + if 'sram_var' in self.epilogue_info: + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] def hook(): arg_defs, call_args, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) @@ -688,7 +691,8 @@ def def_conv_kernel( self.kernel_group.args.output_buffers[node.get_name()] = name self.store_buffer_names.add(node.get_name()) #TODO: Is this enough not 
calling store() in mlir_common.py? self.extra_node[node.get_name()] = node - self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] #TODO: Buffer name fixed + if 'sram_var' in self.epilogue_info: + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] #TODO: Buffer name fixed def kernel_hook(): arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=self.extra_node) @@ -1146,6 +1150,15 @@ def set_tile_size(self, template_fusion_info, prologue=False): return tile_desc class MLIRTemplateCaller(CUDATemplateCaller): + def __init__(self, name, category, input_nodes, layout, make_kernel_render, supports_epilogue_fusion, template, info_kwargs, description): + bmreq = MLIRBenchmarkRequest( + kernel_name=name, + input_tensor_meta=list(), + output_tensor_meta=list(), + extra_args=[], + source_code="", + ) + super().__init__(name, category, input_nodes, layout, make_kernel_render, bmreq, supports_epilogue_fusion, template, info_kwargs, description) def __str__(self): return f"MLIRTemplateCaller(source_file={self.bmreq.source_file})" @@ -1173,6 +1186,10 @@ def __init__(self, name, input_nodes, layout, input_reorder = None): self.output_nodes = [self.output_node] self.input_reorder = input_reorder self.layout = layout + # Fusion support flags (default to False) + self.support_epilogue_fusion = False + self.support_prologue_fusion = False + self.support_reduction_fusion = False def generate(self, **kwargs) -> ChoiceCaller: kernel_name = f"mlir_{self.name}" @@ -1184,18 +1201,9 @@ def generate(self, **kwargs) -> ChoiceCaller: code = self.render(kernel=kernel, **kwargs) kernel_hash_name = f"mlir_{self.name}_{next(self.index_counter)}" - extra_args = [] # create the BenchmarkRequest output_nodes = getattr(self, "output_nodes", None) or [self.output_node] - bmreq = MLIRBenchmarkRequest( - kernel_name=kernel_name, - input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), - output_tensor_meta=TensorMeta.from_irnodes(output_nodes), - 
extra_args=extra_args, - source_code=code, - ) - def make_kernel_render( template_node: TemplateBuffer, prologue_nodes: Optional[List[IRNode]] = None, @@ -1236,7 +1244,6 @@ def make_kernel_render( self.input_nodes, self.output_node.get_layout(), make_kernel_render, - bmreq, False, # supports_epilogue_fusion self, kwargs, diff --git a/tests/test_cat.py b/tests/test_cat.py index 32573a05..62de6759 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -20,24 +20,6 @@ def _test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): print("cpu out: ", cpu_out) raise RuntimeError(f"{name} mismatch") - -def _togsim_log_count() -> int: - log_dir = Path("togsim_results") - if not log_dir.exists(): - return 0 - return len(list(log_dir.glob("*.log"))) - - -def _assert_simulation_happened(before_count: int, case_name: str): - after_count = _togsim_log_count() - if after_count <= before_count: - raise RuntimeError( - f"{case_name}: TOGSim log count did not increase " - f"(before={before_count}, after={after_count})" - ) - print(f"{case_name}: TOGSim logs increased ({before_count} -> {after_count})") - - def test_cat_default(device): def cat_default_fn(a, b): return torch.cat([a, b], dim=0) @@ -46,9 +28,7 @@ def cat_default_fn(a, b): y = torch.randn(6, 16, device=device) opt_fn = torch.compile(dynamic=False)(cat_default_fn) - before = _togsim_log_count() out = opt_fn(x, y) - _assert_simulation_happened(before, "cat.default") cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) _test_result("cat.default", out, cpu_out, rtol=1e-4, atol=1e-4) @@ -63,19 +43,122 @@ def cat_out_fn(a, b, out): out_buf = torch.empty(14, 16, device=device) opt_fn = torch.compile(dynamic=False)(cat_out_fn) - before = _togsim_log_count() out = opt_fn(x, y, out_buf) - _assert_simulation_happened(before, "cat.out") cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) _test_result("cat.out", out, cpu_out, rtol=1e-4, atol=1e-4) +def test_cat_4d_dim0(device): + def cat_4d_dim0_fn(a, b): + return torch.cat([a, b], dim=0) + + 
x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(3, 3, 4, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim0_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.4d.dim0", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim1(device): + def cat_4d_dim1_fn(a, b): + return torch.cat([a, b], dim=1) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 5, 4, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim1_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=1) + _test_result("cat.4d.dim1", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim2(device): + def cat_4d_dim2_fn(a, b): + return torch.cat([a, b], dim=2) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 3, 6, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim2_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=2) + _test_result("cat.4d.dim2", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim3(device): + def cat_4d_dim3_fn(a, b): + return torch.cat([a, b], dim=3) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 3, 4, 7, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim3_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=3) + _test_result("cat.4d.dim3", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_three_inputs(device): + def cat_three_inputs_fn(a, b, c): + return torch.cat([a, b, c], dim=0) + + x = torch.randn(4, 16, device=device) + y = torch.randn(5, 16, device=device) + z = torch.randn(3, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_three_inputs_fn) + + out = opt_fn(x, y, z) + + cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu()], dim=0) + _test_result("cat.three_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_four_inputs(device): + def cat_four_inputs_fn(a, b, c, d): + return torch.cat([a, b, c, 
d], dim=0) + + x = torch.randn(3, 16, device=device) + y = torch.randn(4, 16, device=device) + z = torch.randn(5, 16, device=device) + w = torch.randn(2, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_four_inputs_fn) + + out = opt_fn(x, y, z, w) + + cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu(), w.cpu()], dim=0) + _test_result("cat.four_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_three_inputs(device): + def cat_4d_three_inputs_fn(a, b, c): + return torch.cat([a, b, c], dim=1) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 4, 4, 5, device=device) + z = torch.randn(2, 5, 4, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_three_inputs_fn) + + out = opt_fn(x, y, z) + + cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu()], dim=1) + _test_result("cat.4d.three_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run cat simulation tests") parser.add_argument( "--case", - choices=["default", "out", "all"], + choices=[ + "default", "out", "4d_dim0", "4d_dim1", "4d_dim2", "4d_dim3", + "three_inputs", "four_inputs", "4d_three_inputs", "all" + ], default="all", help="Which cat case to run", ) @@ -87,3 +170,17 @@ def cat_out_fn(a, b, out): test_cat_default(device) if args.case in ("out", "all"): test_cat_out(device) + if args.case in ("4d_dim0", "all"): + test_cat_4d_dim0(device) + if args.case in ("4d_dim1", "all"): + test_cat_4d_dim1(device) + if args.case in ("4d_dim2", "all"): + test_cat_4d_dim2(device) + if args.case in ("4d_dim3", "all"): + test_cat_4d_dim3(device) + if args.case in ("three_inputs", "all"): + test_cat_three_inputs(device) + if args.case in ("four_inputs", "all"): + test_cat_four_inputs(device) + if args.case in ("4d_three_inputs", "all"): + test_cat_4d_three_inputs(device) From 434bbb10793a68172e49e107bc3b639fd3b86264 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 4 Mar 2026 20:02:14 +0900 Subject: [PATCH 
3/9] [WIP] --- PyTorchSimFrontend/mlir/mlir_cat_template.py | 13 ------------- PyTorchSimFrontend/mlir/mlir_template.py | 2 +- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py index d68af7d4..5062e629 100644 --- a/PyTorchSimFrontend/mlir/mlir_cat_template.py +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -118,7 +118,6 @@ def render( input_reorder=self.input_reorder, ) - self._setup_epilogue_info(kernel, y) code = self._template_from_string(TEMPLATE).render(**kernel.render_options) return code @@ -294,15 +293,3 @@ def _build_output_tile_desc(self, kernel, input_tile_sizes_dim, tile_sizes, rank y_tile_desc.set_tile_size(output_full_tile_sizes) y_tile_desc.set_name("y_cat_tile") return y_tile_desc - - def _setup_epilogue_info(self, kernel, y): - """Setup epilogue information.""" - if hasattr(y, "get_numel"): - y_numel = y.get_numel() - elif hasattr(y, "node") and hasattr(y.node, "get_numel"): - y_numel = y.node.get_numel() - else: - y_numel = None - - if y_numel is not None: - kernel.exception_nodes[kernel.render_options["OUT_DVAR"]] = {"numel": y_numel} diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 04d327f8..59610228 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -813,7 +813,7 @@ def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_com if dram_var in self.exception_nodes: numel = self.exception_nodes[dram_var]["numel"] else: - numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() + numel = self.named_nodes[dram_var].get_numel() mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] dram_shape = f"memref<{numel}x{mlir_dtype}>" dram_stride = [] From 5295dfb5a16e21fda57b12d73906c1bd290c4f94 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 4 Mar 2026 22:13:26 +0900 Subject: [PATCH 4/9] 
[Template] Delay def_dma_op codegen def_dma_op find data node using dram_var. But it can't locate the proper node when output buffer has not been created. --- PyTorchSimFrontend/mlir/mlir_template.py | 146 +++++++++++++---------- 1 file changed, 81 insertions(+), 65 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 59610228..7c52bfe6 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -112,7 +112,8 @@ def __init__(self, self.outer_func_name = outer_func_name self.outer_func_render = outer_func_render self.kernel_arg_attributes = kernel_arg_attributes - self.render_hooks = OrderedDict() + self.render_hooks = OrderedDict() # Stores {key: (priority, hook)} + self.dma_op_counter = itertools.count() # Add counter for unique DMA op keys self.buffer_names = dict() self.render_options = dict() self.tile_size = [] @@ -555,7 +556,7 @@ def template_store(): dram_var = self.epilogue_info["dram_var"] index_list = self.epilogue_info["dram_idx"] tile_desc = self.epilogue_info["dram_tile_desc"] - code = self.def_dma_op("MVOUT", dram_var, index_list, tile_desc) + code = self.def_dma_op("MVOUT", dram_var, index_list, tile_desc, lazy_mode=False) self.cse.generate(self.dma_stores, code, assignment = False) body = IndentedBuffer() @@ -653,7 +654,7 @@ def hook(): return f"({', '.join(renamed_arg_defs)})" assert "" not in self.render_hooks - self.render_hooks[""] = hook + self.render_hooks[""] = (5, hook) # Default priority 5 return "" # This function is a temporal function for convolution because currently convolution kernel is not considering padding. @@ -700,7 +701,7 @@ def kernel_hook(): return f"({', '.join(arg_defs)})" assert "" not in self.render_hooks - self.render_hooks[""] = kernel_hook + self.render_hooks[""] = (5, kernel_hook) # Default priority 5 return "" # This function is for convolution wrapper function finalizing. 
@@ -711,7 +712,7 @@ def wrapper_hook(): return f"({', '.join(wrapper_arg_defs)})" if "" not in self.render_hooks: - self.render_hooks[""] = wrapper_hook + self.render_hooks[""] = (5, wrapper_hook) # Default priority 5 return "" def get_conv_inputs(self): @@ -720,15 +721,15 @@ def get_conv_inputs(self): def get_conv_outputs(self): return {k: v for k, v in self.kernel_group.args.output_buffers.items() if v != 'REMOVED'} - def load_input(self, indent_size: int = 0): + def load_input(self, indent_size: int = 0, priority: int = 1): def hook(): code = IndentedBuffer() prologue_code = self.codegen_prologue_body() if prologue_code.getvalue(): input_dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], - self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False, lazy_mode=False) weight_dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], self.prologue_info["weight_idx"], - self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False, lazy_mode=False) if (self.prologue_info["is_input_fused"]): code.splice(input_dma_code) code.splice(prologue_code) @@ -739,58 +740,63 @@ def hook(): code.splice(input_dma_code) else: dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], - self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False, lazy_mode=False) code.splice(dma_code) dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], 
self.prologue_info["weight_idx"], - self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False, lazy_mode=False) code.splice(dma_code) code = textwrap.indent(code.getvalue(), " "*indent_size).strip() return code assert "" not in self.render_hooks - self.render_hooks[""] = hook - self.render_hooks.move_to_end("", last=False) # Force order to be triggered first + self.render_hooks[""] = (priority, hook) return "" - def store_output(self, indent_size: int = 0): + def store_output(self, indent_size: int = 0, priority: int = 1): def hook(): epilogue_code = self.codegen_epilogue_body() return textwrap.indent(epilogue_code.getvalue(), " "*indent_size).strip() assert "" not in self.render_hooks - self.render_hooks[""] = hook - self.render_hooks.move_to_end("", last=False) # Force order to be triggered first + self.render_hooks[""] = (priority, hook) return "" - def reduction_output(self, indent_size: int = 0): + def reduction_output(self, indent_size: int = 0, priority: int = 5): def hook(): return textwrap.indent(self.reductions_suffix.getvalue(), " "*indent_size).strip() assert "" not in self.render_hooks - self.render_hooks[""] = hook + self.render_hooks[""] = (priority, hook) return "" + def _sort_hooks_by_priority(self): + """Sort hooks by priority (lower priority executes first).""" + sorted_hooks = OrderedDict() + for key, (priority, hook) in sorted(self.render_hooks.items(), key=lambda x: x[1][0]): + sorted_hooks[key] = hook + return sorted_hooks + def def_function(self): _, call_args, _, _ = self.kernel_group.args.python_argdefs() if self.outer_func_render is not None: partial_code, function_name = self.outer_func_render(input_args=call_args) + return PartialRender( partial_code, - self.render_hooks, + self._sort_hooks_by_priority(), ), function_name else: return None, None - def 
def_global_vars(self): + def def_global_vars(self, priority: int = 10): key = "" def hook(): return textwrap.indent(self.global_vars.getvalue(), "").strip() - assert key not in self.render_hooks - self.render_hooks[key] = hook + self.render_hooks[key] = (priority, hook) return key - def def_local_vars(self, indent_size=0): + def def_local_vars(self, indent_size=0, priority: int = 10): key = "" def hook(): code = IndentedBuffer() @@ -799,52 +805,62 @@ def hook(): code.splice(self.alloc_buffer) return textwrap.indent(code.getvalue(), " "*indent_size).strip() - assert key not in self.render_hooks - self.render_hooks[key] = hook + self.render_hooks[key] = (priority, hook) return key def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_common.MLIRMultiDimTile, - subtile_size:list=[], async_type=None, indent_size=0): - # Prepare code block - local_code = IndentedBuffer() - with self, self.override_buffer_cse(buffer=local_code, cse=self.apply_cse): - index_var = self.parse_index_list(index_list, offset=tile_desc.offset) - node_layout = self.named_nodes[dram_var].get_layout() - if dram_var in self.exception_nodes: - numel = self.exception_nodes[dram_var]["numel"] - else: - numel = self.named_nodes[dram_var].get_numel() - mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] - dram_shape = f"memref<{numel}x{mlir_dtype}>" - dram_stride = [] - for idx in index_list: - if idx.is_Mul: - dram_stride.append(int(idx.args[0])) - elif idx == sympy.Symbol("c0"): - dram_stride.append(0) - elif not idx.is_Number: - dram_stride.append(1) + subtile_size:list=[], async_type=None, indent_size=0, priority: int = 5, lazy_mode: bool = True): + def generate_dma_code(): + """Internal method to generate DMA code directly.""" + local_code = IndentedBuffer() + with self, self.override_buffer_cse(buffer=local_code, cse=self.apply_cse): + index_var = self.parse_index_list(index_list, offset=tile_desc.offset) + node_layout = self.named_nodes[dram_var].get_layout() + if 
dram_var in self.exception_nodes: + numel = self.exception_nodes[dram_var]["numel"] else: - dram_stride.append(0) - - sram_var = tile_desc.get_name() - tile_shape = tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = tile_desc.get_tile_stride() - vlane_split_axis = tile_desc.vmap.vlane_split_axis - vlane_stride = tile_desc.vmap.vlane_stride - - zero_cse = self.get_const_cse(0, "index") - sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) - - attribute_parts = [f"dram_stride={dram_stride}", f"sram_stride={tile_stride}", "padding=0"] - if subtile_size: - attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") - attribute = " {" + ", ".join(attribute_parts) + "}" - code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, "") - local_code.writeline(code) - local_code.writeline(attribute) - return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() + numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() + mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] + dram_shape = f"memref<{numel}x{mlir_dtype}>" + dram_stride = [] + for idx in index_list: + if idx.is_Mul: + dram_stride.append(int(idx.args[0])) + elif idx == sympy.Symbol("c0"): + dram_stride.append(0) + elif not idx.is_Number: + dram_stride.append(1) + else: + dram_stride.append(0) + + sram_var = tile_desc.get_name() + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) + tile_stride = tile_desc.get_tile_stride() + vlane_split_axis = tile_desc.vmap.vlane_split_axis + vlane_stride = tile_desc.vmap.vlane_stride + + zero_cse = self.get_const_cse(0, "index") + sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) + + attribute_parts = [f"dram_stride={dram_stride}", f"sram_stride={tile_stride}", "padding=0"] + if subtile_size: + attribute_parts.append(f"subtile_size={subtile_size}, 
async={int(async_type) if async_type is not None else 1}") + attribute = " {" + ", ".join(attribute_parts) + "}" + code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, "") + local_code.writeline(code) + local_code.writeline(attribute) + return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() + + if not lazy_mode: + # Immediate mode: generate code directly and return it + return generate_dma_code() + + # Lazy mode: register hook and return key + dma_op_id = next(self.dma_op_counter) + key = f"" + self.render_hooks[key] = (priority, generate_dma_code) + return key def def_sram_buffer(self, dram_name, tile_desc, id=0, indent_size=0): # Prepare code block @@ -862,7 +878,7 @@ def render(self, template, kwargs, define_function=None): return PartialRender( code, - self.render_hooks, + self._sort_hooks_by_priority(), ) def get_spad_size_per_lane(self, tile_m, tile_n): From 61caebd5708ca21a88950d4d5073445891ea32f1 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 5 Mar 2026 00:12:49 +0900 Subject: [PATCH 5/9] [Template/Cat] Fix apply offset setting --- PyTorchSimFrontend/mlir/mlir_cat_template.py | 80 +++++++++----------- 1 file changed, 37 insertions(+), 43 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py index 5062e629..5aaf3e71 100644 --- a/PyTorchSimFrontend/mlir/mlir_cat_template.py +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -26,7 +26,7 @@ affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUT_SIZES[i][DIM] }} step {{ INPUT_TILE_SIZES_DIM[i] }} { %index{{ DIM }}_{{i}} = affine.apply affine_map<(d0) -> (d0 + {{ CUMULATIVE_OFFSETS[i] }})> (%index_local{{ DIM }}_{{ i }}) {{ kernel.def_dma_op("MVIN", INPUT_BUFFER_NAMES[i], INPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, OUTPUT_IDXS[i], INPUT_TILE_DESCS[i], 
indent_size=INDENT_SIZE) }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, OUTPUT_IDXS[i], OUTPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }} } { inner_loop=true } {%- endfor %} @@ -52,10 +52,6 @@ def render( tile_info=None, **kwargs, ): - is_out_variant = template_buffer_node is not None - if is_out_variant: - self.output_node = template_buffer_node - # Extract info input_nodes = self.input_nodes y = self.output_node @@ -73,11 +69,8 @@ def render( kernel, input_sizes, tile_sizes, num_inputs, rank ) buffer_name_to_template_name, input_buffer_names = self._build_buffer_mapping(input_nodes) - input_tile_descs, unique_tile_descs = self._build_tile_descriptors( - kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names - ) - y_tile_desc = self._build_output_tile_desc( - kernel, input_tile_sizes_dim, tile_sizes, rank + input_tile_descs, output_tile_descs, unique_tile_descs = self._build_tile_descriptors( + kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, y ) input_idxs, output_idxs, cumulative_offsets = self._build_index_expressions( @@ -90,14 +83,14 @@ def render( if actual_name in unique_tile_descs: unique_buffer_tile_descs[template_name] = unique_tile_descs[actual_name] - names_str = ", ".join(input_buffer_names + ["out_ptr1" if is_out_variant else "Y"]) + names_str = ", ".join(input_buffer_names + ["Y"]) indent_size = 2 + (rank - 1) * 2 + 4 kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, Y=y, - OUT_DVAR="out_ptr1" if is_out_variant else "Y", + OUT_DVAR="Y", NAMES_STR=names_str, INPUT_NAMES=input_nodes, INPUT_BUFFER_NAMES=input_buffer_names, @@ -110,6 +103,7 @@ def render( TILE_SIZES=tile_sizes, INPUT_TILE_SIZES_DIM=input_tile_sizes_dim, INPUT_TILE_DESCS=input_tile_descs, + OUTPUT_TILE_DESCS=output_tile_descs, UNIQUE_BUFFER_TILE_DESCS=unique_buffer_tile_descs, INPUT_IDXS=input_idxs, OUTPUT_IDXS=output_idxs, @@ -209,14 +203,16 @@ def _build_buffer_mapping(self, input_nodes): 
return buffer_name_to_template_name, input_buffer_names def _build_tile_descriptors( - self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names + self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, output_node ): - """Build tile descriptors for each input.""" + """Build tile descriptors for each input and output.""" input_tile_descs = [] + output_tile_descs = [] unique_tile_descs = {} + output_offset = output_node.get_layout().offset for i, x in enumerate(input_nodes): - # Build full tile size list for this input + x_offset = x.get_layout().offset full_tile_sizes = [] tile_size_idx = 0 for d in range(rank): @@ -226,23 +222,37 @@ def _build_tile_descriptors( else: full_tile_sizes.append(input_tile_sizes_dim[i]) - tile_desc = mlir_common.MLIRMultiDimTile( + # Input tile descriptor + input_tile_desc = mlir_common.MLIRMultiDimTile( full_tile_sizes, kernel.vector_lane, vlane_split_axis=rank - 1, vlane_stride=1 ) - tile_desc.set_tile_size(full_tile_sizes) + input_tile_desc.set_tile_size(full_tile_sizes) template_buffer_name = input_buffer_names[i] - tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile") - input_tile_descs.append(tile_desc) + input_tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile") + input_tile_desc.offset = x_offset + input_tile_descs.append(input_tile_desc) + + # Output tile descriptor (same as input but with output offset) + output_tile_desc = mlir_common.MLIRMultiDimTile( + full_tile_sizes, + kernel.vector_lane, + vlane_split_axis=rank - 1, + vlane_stride=1 + ) + output_tile_desc.set_tile_size(full_tile_sizes) + output_tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile") + output_tile_desc.offset = output_offset + output_tile_descs.append(output_tile_desc) # Store unique tile desc by actual buffer name actual_name = x.get_name() if actual_name not in unique_tile_descs: - unique_tile_descs[actual_name] = tile_desc + 
unique_tile_descs[actual_name] = input_tile_desc - return input_tile_descs, unique_tile_descs + return input_tile_descs, output_tile_descs, unique_tile_descs def _build_index_expressions( self, input_nodes, input_sizes, output_strides, rank, num_inputs @@ -256,6 +266,12 @@ def _build_index_expressions( for i, x in enumerate(input_nodes): x_stride = x.get_layout().stride + x_offset = x.get_layout().offset + if hasattr(x, 'data') and hasattr(x.data, 'dims'): + # In case of PermuteView, the stride is permuted + perm_dims = x.data.dims + x_stride = [x_stride[perm_dims[d]] for d in range(rank)] + input_idx = [] output_idx = [] for d in range(rank): @@ -271,25 +287,3 @@ def _build_index_expressions( output_idxs.append(output_idx) return input_idxs, output_idxs, cumulative_offsets - - def _build_output_tile_desc(self, kernel, input_tile_sizes_dim, tile_sizes, rank): - """Build output tile descriptor.""" - max_output_tile_dim = max(input_tile_sizes_dim) if input_tile_sizes_dim else 1 - output_full_tile_sizes = [] - tile_size_idx = 0 - for d in range(rank): - if d != self.dim: - output_full_tile_sizes.append(tile_sizes[tile_size_idx]) - tile_size_idx += 1 - else: - output_full_tile_sizes.append(max_output_tile_dim) - - y_tile_desc = mlir_common.MLIRMultiDimTile( - output_full_tile_sizes, - kernel.vector_lane, - vlane_split_axis=rank - 1, - vlane_stride=1 - ) - y_tile_desc.set_tile_size(output_full_tile_sizes) - y_tile_desc.set_name("y_cat_tile") - return y_tile_desc From 47684a75942bf9d35e19a7a79a1862418c5649a6 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 5 Mar 2026 17:44:32 +0900 Subject: [PATCH 6/9] [TOGSim] Add help print --- TOGSim/src/DMA.cc | 2 +- TOGSim/src/helper/CommandLineParser.cc | 6 +++++- TOGSim/src/helper/CommandLineParser.h | 8 +++++++- TOGSim/src/main.cc | 13 +++++++++---- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/TOGSim/src/DMA.cc b/TOGSim/src/DMA.cc index f8f21025..fefee6d2 100644 --- a/TOGSim/src/DMA.cc +++ 
b/TOGSim/src/DMA.cc @@ -12,7 +12,7 @@ void DMA::issue_tile(std::shared_ptr inst) { _current_inst = std::move(inst); std::vector& tile_size = _current_inst->get_tile_size(); if (tile_size.size() <= 0 || tile_size.size() > get_max_dim()) { - spdlog::error("[DMA {}] issued tile is not supported format..", _id); + spdlog::error("[DMA {}] issued tile is not supported format.. tile.size: {}, tile_size: [{}]", _id, tile_size.size(), fmt::join(tile_size, ", ")); exit(EXIT_FAILURE); } _finished = false; diff --git a/TOGSim/src/helper/CommandLineParser.cc b/TOGSim/src/helper/CommandLineParser.cc index 66aebbe1..9cd177ac 100644 --- a/TOGSim/src/helper/CommandLineParser.cc +++ b/TOGSim/src/helper/CommandLineParser.cc @@ -12,9 +12,13 @@ void CommandLineParser::parse(int argc, char **argv) noexcept(false) { po::notify(variables_map); } +void CommandLineParser::print_help_message() const noexcept { + std::cout << options_description << std::endl; +} + void CommandLineParser::print_help_message_if_required() const noexcept { if (variables_map.count("help") > 0) { - std::cout << options_description << std::endl; + print_help_message(); exit(0); } } diff --git a/TOGSim/src/helper/CommandLineParser.h b/TOGSim/src/helper/CommandLineParser.h index 39174d5d..b41eabf3 100644 --- a/TOGSim/src/helper/CommandLineParser.h +++ b/TOGSim/src/helper/CommandLineParser.h @@ -19,7 +19,7 @@ class CommandLineParser { * Command Line Parser constructor */ CommandLineParser() noexcept { - options_description.add_options()("help", "Prints help message"); + options_description.add_options()("help,h", "Prints help message"); } /** @@ -38,6 +38,12 @@ class CommandLineParser { */ void print_help_message_if_required() const noexcept; + /** + * Prints the help message. + * (Can be called to show help for invalid options) + */ + void print_help_message() const noexcept; + /** * Add a new command line argument option. 
* (Should be called before `parse` method is called) diff --git a/TOGSim/src/main.cc b/TOGSim/src/main.cc index 7c596af5..cda8f986 100644 --- a/TOGSim/src/main.cc +++ b/TOGSim/src/main.cc @@ -96,19 +96,24 @@ int main(int argc, char** argv) { // parse command line argumnet CommandLineParser cmd_parser = CommandLineParser(); cmd_parser.add_command_line_option( - "config", "Path for hardware configuration file"); + "config", "Path for hardware configuration file (.yml)"); cmd_parser.add_command_line_option( - "models_list", "Path for the models list file (can be FIFO or regular file)"); + "models_list", "Path for the trace file (.trace)"); cmd_parser.add_command_line_option( "log_level", "Set for log level [trace, debug, info], default = info"); try { cmd_parser.parse(argc, argv); } catch (const CommandLineParser::ParsingError& e) { spdlog::error( - "Command line argument parrsing error captured. Error message: {}", + "Command line argument parsing error captured. Error message: {}", e.what()); - throw(e); + std::cerr << std::endl; + cmd_parser.print_help_message(); + exit(1); } + + // Check if help was requested + cmd_parser.print_help_message_if_required(); std::string level = "info"; cmd_parser.set_if_defined("log_level", &level); From a24f1f1081a4ce7e5e09a59f61763850d11d994f Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 5 Mar 2026 17:45:00 +0900 Subject: [PATCH 7/9] [Template/Cat] Limit maximum rank of tile --- PyTorchSimFrontend/mlir/mlir_cat_template.py | 52 +++++++++++++++----- 1 file changed, 41 insertions(+), 11 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py index 5aaf3e71..2a00ce95 100644 --- a/PyTorchSimFrontend/mlir/mlir_cat_template.py +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -64,17 +64,30 @@ def render( tile_sizes = tile_info if tile_info is not None else [1] * len(output_sizes) output_strides = y.get_layout().stride + excluded_dims = list() + max_tiled_dims = 4 
- 1 + if len(tile_sizes) > max_tiled_dims: + # Create index:tile_size dictionary and sort by tile_size + dim_tile_dict = {idx: sz for idx, sz in enumerate(tile_sizes)} + sorted_dims = sorted(dim_tile_dict.items(), key=lambda x: x[1], reverse=True) + # Keep top 4 dimensions, exclude the rest + excluded_dims = [idx for idx, _ in sorted_dims[max_tiled_dims:]] + for idx in excluded_dims: + tile_sizes[idx] = 1 + # Calculate input tile sizes input_tile_sizes_dim = self._calculate_input_tile_sizes( kernel, input_sizes, tile_sizes, num_inputs, rank ) buffer_name_to_template_name, input_buffer_names = self._build_buffer_mapping(input_nodes) input_tile_descs, output_tile_descs, unique_tile_descs = self._build_tile_descriptors( - kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, y + kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, y, + excluded_dims=excluded_dims ) input_idxs, output_idxs, cumulative_offsets = self._build_index_expressions( - input_nodes, input_sizes, output_strides, rank, num_inputs + input_nodes, input_sizes, output_strides, rank, num_inputs, + excluded_dims=excluded_dims ) # Map unique buffer names to their tile descriptors for template @@ -203,9 +216,12 @@ def _build_buffer_mapping(self, input_nodes): return buffer_name_to_template_name, input_buffer_names def _build_tile_descriptors( - self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, output_node + self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, output_node, excluded_dims=None ): """Build tile descriptors for each input and output.""" + if excluded_dims is None: + excluded_dims = set() + input_tile_descs = [] output_tile_descs = [] unique_tile_descs = {} @@ -217,16 +233,21 @@ def _build_tile_descriptors( tile_size_idx = 0 for d in range(rank): if d != self.dim: - full_tile_sizes.append(tile_sizes[tile_size_idx]) + # 
Skip excluded dimensions + if tile_size_idx not in excluded_dims: + full_tile_sizes.append(tile_sizes[tile_size_idx]) tile_size_idx += 1 else: full_tile_sizes.append(input_tile_sizes_dim[i]) + # Calculate vlane_split_axis for reduced dimensions + vlane_split_axis = len(full_tile_sizes) - 1 + # Input tile descriptor input_tile_desc = mlir_common.MLIRMultiDimTile( full_tile_sizes, kernel.vector_lane, - vlane_split_axis=rank - 1, + vlane_split_axis=vlane_split_axis, vlane_stride=1 ) input_tile_desc.set_tile_size(full_tile_sizes) @@ -239,7 +260,7 @@ def _build_tile_descriptors( output_tile_desc = mlir_common.MLIRMultiDimTile( full_tile_sizes, kernel.vector_lane, - vlane_split_axis=rank - 1, + vlane_split_axis=vlane_split_axis, vlane_stride=1 ) output_tile_desc.set_tile_size(full_tile_sizes) @@ -255,9 +276,12 @@ def _build_tile_descriptors( return input_tile_descs, output_tile_descs, unique_tile_descs def _build_index_expressions( - self, input_nodes, input_sizes, output_strides, rank, num_inputs + self, input_nodes, input_sizes, output_strides, rank, num_inputs, excluded_dims=None ): """Build index expressions for input and output.""" + if excluded_dims is None: + excluded_dims = set() + input_idxs = [] output_idxs = [] cumulative_offsets = [0] @@ -274,15 +298,21 @@ def _build_index_expressions( input_idx = [] output_idx = [] + tile_size_idx = 0 for d in range(rank): if d != self.dim: - input_idx_symbol = sympy.Symbol(f"index{d}") - output_idx_symbol = sympy.Symbol(f"index{d}") + # Skip excluded dimensions + if tile_size_idx not in excluded_dims: + input_idx_symbol = sympy.Symbol(f"index{d}") + output_idx_symbol = sympy.Symbol(f"index{d}") + input_idx.append(input_idx_symbol * x_stride[d]) + output_idx.append(output_idx_symbol * output_strides[d]) + tile_size_idx += 1 else: input_idx_symbol = sympy.Symbol(f"index_local{self.dim}_{i}") output_idx_symbol = sympy.Symbol(f"index{self.dim}_{i}") - input_idx.append(input_idx_symbol * x_stride[d]) - 
output_idx.append(output_idx_symbol * output_strides[d]) + input_idx.append(input_idx_symbol * x_stride[d]) + output_idx.append(output_idx_symbol * output_strides[d]) input_idxs.append(input_idx) output_idxs.append(output_idx) From 4e4300e2cda61dcc5eeec103c91fe5ef13ff3a73 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 5 Mar 2026 20:22:10 +0900 Subject: [PATCH 8/9] [Template/Cat] Refactor cat + Support explicit dram+stride in def_dma_op --- .github/workflows/pytorchsim_test.yml | 21 + PyTorchSimFrontend/mlir/mlir_cat_template.py | 401 ++++++++++--------- PyTorchSimFrontend/mlir/mlir_template.py | 48 ++- tests/test_cat.py | 16 +- 4 files changed, 288 insertions(+), 198 deletions(-) diff --git a/.github/workflows/pytorchsim_test.yml b/.github/workflows/pytorchsim_test.yml index 9589384b..eaaa7e50 100644 --- a/.github/workflows/pytorchsim_test.yml +++ b/.github/workflows/pytorchsim_test.yml @@ -163,6 +163,27 @@ jobs: -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_conv2d.py + test_cat: + name: Run test_cat.py + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_cat.py + run: | + echo "Running test_cat.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + -e vpu_num_lanes="${{ inputs.vector_lane }}" \ + -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/test_cat.py + test_matmul: name: Run test_matmul.py runs-on: self-hosted diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py index 2a00ce95..6eb60198 100644 --- a/PyTorchSimFrontend/mlir/mlir_cat_template.py +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing 
import List, Optional, Set import math import itertools @@ -23,10 +23,12 @@ {%- endfor %} {%- for i in range(NUM_INPUTS) %} // Input tensor{{ i }} - affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUT_SIZES[i][DIM] }} step {{ INPUT_TILE_SIZES_DIM[i] }} { - %index{{ DIM }}_{{i}} = affine.apply affine_map<(d0) -> (d0 + {{ CUMULATIVE_OFFSETS[i] }})> (%index_local{{ DIM }}_{{ i }}) - {{ kernel.def_dma_op("MVIN", INPUT_BUFFER_NAMES[i], INPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, OUTPUT_IDXS[i], OUTPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }} + affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUTS[i].sizes[DIM] }} step {{ INPUTS[i].tile_size_dim }} { + %index{{ DIM }}_{{ i }} = affine.apply affine_map<(d0) -> (d0 + {{ INPUTS[i].cum_offset }})> (%index_local{{ DIM }}_{{ i }}) + %input_dram_offset_{{ i }} = affine.apply {{ INPUTS[i].offset_map }}({{ INPUTS[i].offset_vars }}) + %output_dram_offset_{{ i }} = affine.apply {{ OUTPUTS[i].offset_map }}({{ OUTPUTS[i].offset_vars }}) + {{ kernel.def_dma_op("MVIN", INPUTS[i].dram_name, [], INPUTS[i].tile_desc, indent_size=INDENT_SIZE, dram_stride=INPUTS[i].dram_strides, dram_offset="input_dram_offset_" ~ i) }} + {{ kernel.def_dma_op("MVOUT", "Y", [], OUTPUTS[i].tile_desc, indent_size=INDENT_SIZE, dram_stride=OUTPUTS[i].dram_strides, dram_offset="output_dram_offset_" ~ i) }} } { inner_loop=true } {%- endfor %} @@ -52,81 +54,84 @@ def render( tile_info=None, **kwargs, ): - # Extract info input_nodes = self.input_nodes y = self.output_node - num_inputs = len(self.input_nodes) + num_inputs = len(input_nodes) rank = len(y.get_size()) input_sizes = [x.get_size() for x in input_nodes] - output_sizes = [sz for dim, sz in enumerate(y.get_size()) if dim != self.dim] - output_dim = [dim for dim, sz in enumerate(y.get_size()) if dim != self.dim] - tile_sizes = tile_info if tile_info is not None else [1] * len(output_sizes) + output_sizes = [sz for d, sz in 
enumerate(y.get_size()) if d != self.dim] + output_dim = [d for d, _ in enumerate(y.get_size()) if d != self.dim] output_strides = y.get_layout().stride - excluded_dims = list() - max_tiled_dims = 4 - 1 - if len(tile_sizes) > max_tiled_dims: - # Create index:tile_size dictionary and sort by tile_size - dim_tile_dict = {idx: sz for idx, sz in enumerate(tile_sizes)} - sorted_dims = sorted(dim_tile_dict.items(), key=lambda x: x[1], reverse=True) - # Keep top 4 dimensions, exclude the rest - excluded_dims = [idx for idx, _ in sorted_dims[max_tiled_dims:]] - for idx in excluded_dims: - tile_sizes[idx] = 1 - - # Calculate input tile sizes + tile_sizes = list(tile_info) if tile_info is not None else [1] * len(output_sizes) + excluded_dims = self._compute_excluded_dims(tile_sizes) + input_tile_sizes_dim = self._calculate_input_tile_sizes( kernel, input_sizes, tile_sizes, num_inputs, rank ) - buffer_name_to_template_name, input_buffer_names = self._build_buffer_mapping(input_nodes) + buffer_name_to_template_name, input_dram_names = self._build_buffer_mapping(input_nodes) input_tile_descs, output_tile_descs, unique_tile_descs = self._build_tile_descriptors( - kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, y, - excluded_dims=excluded_dims + kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, + input_dram_names, y, excluded_dims=excluded_dims ) - - input_idxs, output_idxs, cumulative_offsets = self._build_index_expressions( - input_nodes, input_sizes, output_strides, rank, num_inputs, - excluded_dims=excluded_dims + (input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets) = self._build_dma_info( + input_nodes, input_sizes, output_strides, input_tile_descs, output_tile_descs, + rank, num_inputs, excluded_dims=excluded_dims ) - # Map unique buffer names to their tile descriptors for template - unique_buffer_tile_descs 
= {} - for actual_name, template_name in buffer_name_to_template_name.items(): - if actual_name in unique_tile_descs: - unique_buffer_tile_descs[template_name] = unique_tile_descs[actual_name] - - names_str = ", ".join(input_buffer_names + ["Y"]) + unique_buffer_tile_descs = { + buffer_name_to_template_name[name]: desc + for name, desc in unique_tile_descs.items() + } + names_str = ", ".join(input_dram_names + ["Y"]) indent_size = 2 + (rank - 1) * 2 + 4 + inputs_info = [ + dict( + dram_name = input_dram_names[i], + sizes = input_sizes[i], + tile_size_dim= input_tile_sizes_dim[i], + tile_desc = input_tile_descs[i], + offset_map = input_offset_maps[i], + offset_vars = input_offset_var_strs[i], + dram_strides = input_dram_strides[i], + cum_offset = cumulative_offsets[i], + ) + for i in range(num_inputs) + ] + outputs_info = [ + dict( + tile_desc = output_tile_descs[i], + offset_map = output_offset_maps[i], + offset_vars = output_offset_var_strs[i], + dram_strides = output_dram_strides[i], + ) + for i in range(num_inputs) + ] + kernel.render_options = dict( - KERNEL_NAME=self.name, - kernel=kernel, - Y=y, - OUT_DVAR="Y", - NAMES_STR=names_str, - INPUT_NAMES=input_nodes, - INPUT_BUFFER_NAMES=input_buffer_names, - NUM_INPUTS=num_inputs, - RANK=rank, - DIM=self.dim, - INPUT_SIZES=input_sizes, - OUTPUT_SIZES=output_sizes, - OUTPUT_DIM=output_dim, - TILE_SIZES=tile_sizes, - INPUT_TILE_SIZES_DIM=input_tile_sizes_dim, - INPUT_TILE_DESCS=input_tile_descs, - OUTPUT_TILE_DESCS=output_tile_descs, - UNIQUE_BUFFER_TILE_DESCS=unique_buffer_tile_descs, - INPUT_IDXS=input_idxs, - OUTPUT_IDXS=output_idxs, - CUMULATIVE_OFFSETS=cumulative_offsets, - INDENT_SIZE=indent_size, - input_reorder=self.input_reorder, + KERNEL_NAME = self.name, + kernel = kernel, + NUM_INPUTS = num_inputs, + NAMES_STR = names_str, + Y = y, + INPUT_NAMES = input_nodes, + RANK = rank, + DIM = self.dim, + OUTPUT_SIZES = output_sizes, + OUTPUT_DIM = output_dim, + TILE_SIZES = tile_sizes, + UNIQUE_BUFFER_TILE_DESCS = 
unique_buffer_tile_descs, + INPUTS = inputs_info, + OUTPUTS = outputs_info, + INDENT_SIZE = indent_size, + input_reorder = self.input_reorder, ) - code = self._template_from_string(TEMPLATE).render(**kernel.render_options) - return code + return self._template_from_string(TEMPLATE).render(**kernel.render_options) def get_tile_candidates( self, @@ -141,179 +146,217 @@ def get_tile_candidates( y = self.output_node num_inputs = len(self.input_nodes) - output_sizes = [sz for dim, sz in enumerate(y.get_size()) if dim != self.dim] - num_non_dim_dims = len(output_sizes) + output_sizes = [sz for d, sz in enumerate(y.get_size()) if d != self.dim] - if num_non_dim_dims == 0: + if not output_sizes: return [[1]] - tile_candidates = [] - dim_tile_candidates = [] + max_tile_total = kernel.spad_info["spad_size"] // ( + kernel.vector_lane * kernel.precision * 2 * num_inputs + ) + dim_tile_candidates = [] for dim_size in output_sizes: - dim_candidates = [] - max_tile = min(dim_size, kernel.spad_info["spad_size"] // (kernel.vector_lane * kernel.precision * 2 * num_inputs)) - + max_tile = min(dim_size, max_tile_total) + candidates = set() for mult in range(1, max_tile // kernel.vector_lane + 1): - tile = mult * kernel.vector_lane - if tile <= dim_size: - dim_candidates.append(tile) - + t = mult * kernel.vector_lane + if t <= dim_size: + candidates.add(t) if max_tile > 0: for exp in range(int(math.log2(max_tile)) + 1): - tile = 2 ** exp - if tile <= dim_size and tile not in dim_candidates: - dim_candidates.append(tile) - - if dim_size not in dim_candidates: - dim_candidates.append(dim_size) - - dim_tile_candidates.append(sorted(set(dim_candidates))[:5]) - - for tile_combo in itertools.product(*dim_tile_candidates): - total_elements = math.prod(tile_combo) - total_spad_needed = total_elements * (num_inputs + 1) * kernel.precision - - if total_spad_needed <= kernel.spad_info["spad_size"] * kernel.vector_lane: - tile_candidates.append(list(tile_combo)) + t = 2 ** exp + if t <= dim_size: 
+ candidates.add(t) + candidates.add(dim_size) + dim_tile_candidates.append(sorted(candidates)[:5]) + + tile_candidates = [ + list(combo) + for combo in itertools.product(*dim_tile_candidates) + if math.prod(combo) * (num_inputs + 1) * kernel.precision + <= kernel.spad_info["spad_size"] * kernel.vector_lane + ] if not tile_candidates: - tile_candidates = [[1] * num_non_dim_dims] + tile_candidates = [[1] * len(output_sizes)] tile_candidates.sort(key=lambda x: -math.prod(x)) return tile_candidates[:4] - def _calculate_input_tile_sizes( - self, kernel, input_sizes, tile_sizes, num_inputs, rank - ): - """Calculate tile sizes for concat dimension for each input.""" + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _compute_excluded_dims(self, tile_sizes: list) -> list: + """Return non-tiled dimension indices when rank exceeds the 4-dim limit.""" + max_tiled = 3 + if len(tile_sizes) <= max_tiled: + return [] + sorted_dims = sorted(enumerate(tile_sizes), key=lambda x: x[1], reverse=True) + excluded = [idx for idx, _ in sorted_dims[max_tiled:]] + for idx in excluded: + tile_sizes[idx] = 1 + return excluded + + def _calculate_input_tile_sizes(self, kernel, input_sizes, tile_sizes, num_inputs, rank): + """Calculate tile sizes along the concat dimension for each input.""" non_dim_tile_elements = math.prod(tile_sizes) if tile_sizes else 1 - non_dim_tile_spad = non_dim_tile_elements * kernel.precision max_spad_per_input = kernel.spad_info["spad_size"] * kernel.vector_lane // 2 - extra_concat_input = math.ceil(max_spad_per_input / non_dim_tile_spad) - num_inputs + extra_concat = math.ceil(max_spad_per_input / (non_dim_tile_elements * kernel.precision)) - num_inputs input_tile_sizes_dim = [] for i in range(num_inputs): - input_dim_size = input_sizes[i][self.dim] - if extra_concat_input > 0 and non_dim_tile_elements > 0: - max_tile_dim = min(input_dim_size, 
extra_concat_input) - extra_concat_input -= max_tile_dim + if extra_concat > 0 and non_dim_tile_elements > 0: + tile_dim = min(input_sizes[i][self.dim], extra_concat) + extra_concat -= tile_dim else: - max_tile_dim = 1 - input_tile_sizes_dim.append(max_tile_dim) + tile_dim = 1 + input_tile_sizes_dim.append(tile_dim) return input_tile_sizes_dim def _build_buffer_mapping(self, input_nodes): - """Map actual buffer names to template buffer names """ - buffer_name_to_template_name = {} - input_buffer_names = [] + """Map actual buffer names to short template names (X0, X1, ...).""" + name_map = {} + template_names = [] for x in input_nodes: - actual_name = x.get_name() - template_name = buffer_name_to_template_name.setdefault( - actual_name, f"X{len(buffer_name_to_template_name)}" - ) - input_buffer_names.append(template_name) - return buffer_name_to_template_name, input_buffer_names + actual = x.get_name() + template = name_map.setdefault(actual, f"X{len(name_map)}") + template_names.append(template) + return name_map, template_names def _build_tile_descriptors( - self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, output_node, excluded_dims=None + self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, + input_buffer_names, output_node, excluded_dims=None ): - """Build tile descriptors for each input and output.""" + """Build tile descriptors for every input (and its paired output).""" if excluded_dims is None: excluded_dims = set() - input_tile_descs = [] - output_tile_descs = [] - unique_tile_descs = {} + def make_tile_desc(tile_sz, vector_lane, name, offset): + desc = mlir_common.MLIRMultiDimTile( + tile_sz, vector_lane, + vlane_split_axis=len(tile_sz) - 1, + vlane_stride=1 + ) + desc.set_tile_size(tile_sz) + desc.set_name(name) + desc.offset = offset + return desc + output_offset = output_node.get_layout().offset + input_tile_descs, output_tile_descs, unique_tile_descs = [], [], {} for i, x in 
enumerate(input_nodes): - x_offset = x.get_layout().offset - full_tile_sizes = [] - tile_size_idx = 0 + # Collect tile sizes for tiled dimensions only (skip excluded non-concat dims) + tile_sz = [] + tile_idx = 0 for d in range(rank): if d != self.dim: - # Skip excluded dimensions - if tile_size_idx not in excluded_dims: - full_tile_sizes.append(tile_sizes[tile_size_idx]) - tile_size_idx += 1 + if tile_idx not in excluded_dims: + tile_sz.append(tile_sizes[tile_idx]) + tile_idx += 1 else: - full_tile_sizes.append(input_tile_sizes_dim[i]) + tile_sz.append(input_tile_sizes_dim[i]) - # Calculate vlane_split_axis for reduced dimensions - vlane_split_axis = len(full_tile_sizes) - 1 + sram_name = f"{input_buffer_names[i].lower()}_cat_tile" + input_tile_descs.append(make_tile_desc(tile_sz, kernel.vector_lane, sram_name, x.get_layout().offset)) + output_tile_descs.append(make_tile_desc(tile_sz, kernel.vector_lane, sram_name, output_offset)) - # Input tile descriptor - input_tile_desc = mlir_common.MLIRMultiDimTile( - full_tile_sizes, - kernel.vector_lane, - vlane_split_axis=vlane_split_axis, - vlane_stride=1 - ) - input_tile_desc.set_tile_size(full_tile_sizes) - template_buffer_name = input_buffer_names[i] - input_tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile") - input_tile_desc.offset = x_offset - input_tile_descs.append(input_tile_desc) - - # Output tile descriptor (same as input but with output offset) - output_tile_desc = mlir_common.MLIRMultiDimTile( - full_tile_sizes, - kernel.vector_lane, - vlane_split_axis=vlane_split_axis, - vlane_stride=1 - ) - output_tile_desc.set_tile_size(full_tile_sizes) - output_tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile") - output_tile_desc.offset = output_offset - output_tile_descs.append(output_tile_desc) - - # Store unique tile desc by actual buffer name actual_name = x.get_name() if actual_name not in unique_tile_descs: - unique_tile_descs[actual_name] = input_tile_desc + unique_tile_descs[actual_name] = 
input_tile_descs[-1] return input_tile_descs, output_tile_descs, unique_tile_descs - def _build_index_expressions( - self, input_nodes, input_sizes, output_strides, rank, num_inputs, excluded_dims=None + def _build_dma_info( + self, input_nodes, input_sizes, output_strides, + input_tile_descs, output_tile_descs, + rank, num_inputs, excluded_dims=None ): - """Build index expressions for input and output.""" + """Build per-input DRAM offset affine maps and tile strides. + + Three stride concepts are maintained: + + * layout_strides (internal) - raw DRAM buffer strides for every rank + dimension, used to compute the flat base-address affine map. + These reflect how the tensor is physically laid out in DRAM. + * dram_strides (returned, ``def_dma_op dram_stride=``) - stride in + DRAM per *tiled* dimension (excluded dims removed). The DMA engine + uses these to walk DRAM when loading/storing a tile. + * sram_strides (inside ``def_dma_op``, from tile_desc) - stride in + SRAM per tiled dimension. The DMA engine uses these to place data + into the SRAM tile buffer. 
+ + Returns: + input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets + """ if excluded_dims is None: excluded_dims = set() - input_idxs = [] - output_idxs = [] + def make_affine_map(idx_syms, strides, layout_offset): + terms = [] + for j, s in enumerate(strides): + s = int(s) + if s == 1: + terms.append(f"d{j}") + elif s != 0: + terms.append(f"d{j} * {s}") + try: + off = int(layout_offset) + except (TypeError, ValueError): + off = 0 + if off: + terms.append(str(off)) + dim_str = ", ".join(f"d{j}" for j in range(len(idx_syms))) + return f"affine_map<({dim_str}) -> ({' + '.join(terms) if terms else '0'})>" + cumulative_offsets = [0] for i in range(num_inputs - 1): cumulative_offsets.append(cumulative_offsets[-1] + input_sizes[i][self.dim]) + input_offset_maps, input_offset_var_strs, input_dram_strides = [], [], [] + output_offset_maps, output_offset_var_strs, output_dram_strides = [], [], [] + for i, x in enumerate(input_nodes): x_stride = x.get_layout().stride - x_offset = x.get_layout().offset if hasattr(x, 'data') and hasattr(x.data, 'dims'): - # In case of PermuteView, the stride is permuted - perm_dims = x.data.dims - x_stride = [x_stride[perm_dims[d]] for d in range(rank)] + # PermuteView: re-order strides according to the permutation + perm = x.data.dims + x_stride = [x_stride[perm[d]] for d in range(rank)] + + in_syms, in_layout_strides, in_dram_strides = [], [], [] + out_syms, out_layout_strides, out_dram_strides = [], [], [] + tile_idx = 0 - input_idx = [] - output_idx = [] - tile_size_idx = 0 for d in range(rank): if d != self.dim: - # Skip excluded dimensions - if tile_size_idx not in excluded_dims: - input_idx_symbol = sympy.Symbol(f"index{d}") - output_idx_symbol = sympy.Symbol(f"index{d}") - input_idx.append(input_idx_symbol * x_stride[d]) - output_idx.append(output_idx_symbol * output_strides[d]) - tile_size_idx += 1 + 
in_syms.append(sympy.Symbol(f"index{d}")) + in_layout_strides.append(int(x_stride[d])) + out_syms.append(sympy.Symbol(f"index{d}")) + out_layout_strides.append(int(output_strides[d])) + if tile_idx not in excluded_dims: + in_dram_strides.append(int(x_stride[d])) + out_dram_strides.append(int(output_strides[d])) + tile_idx += 1 else: - input_idx_symbol = sympy.Symbol(f"index_local{self.dim}_{i}") - output_idx_symbol = sympy.Symbol(f"index{self.dim}_{i}") - input_idx.append(input_idx_symbol * x_stride[d]) - output_idx.append(output_idx_symbol * output_strides[d]) - input_idxs.append(input_idx) - output_idxs.append(output_idx) - - return input_idxs, output_idxs, cumulative_offsets + in_syms.append(sympy.Symbol(f"index_local{self.dim}_{i}")) + in_layout_strides.append(int(x_stride[d])) + out_syms.append(sympy.Symbol(f"index{self.dim}_{i}")) + out_layout_strides.append(int(output_strides[d])) + in_dram_strides.append(int(x_stride[d])) + out_dram_strides.append(int(output_strides[d])) + + input_offset_maps.append(make_affine_map(in_syms, in_layout_strides, input_tile_descs[i].offset)) + input_offset_var_strs.append(", ".join(f"%{s}" for s in in_syms)) + input_dram_strides.append(in_dram_strides) + + output_offset_maps.append(make_affine_map(out_syms, out_layout_strides, output_tile_descs[i].offset)) + output_offset_var_strs.append(", ".join(f"%{s}" for s in out_syms)) + output_dram_strides.append(out_dram_strides) + + return (input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 7c52bfe6..9cc79e0a 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -809,12 +809,18 @@ def hook(): return key def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_common.MLIRMultiDimTile, - subtile_size:list=[], 
async_type=None, indent_size=0, priority: int = 5, lazy_mode: bool = True): + subtile_size:list=[], async_type=None, indent_size=0, priority: int = 5, lazy_mode: bool = True, + dram_stride:list=None, dram_offset=None): + # Todo. Remove legacy behavior (i.e., index_list parsing) def generate_dma_code(): """Internal method to generate DMA code directly.""" local_code = IndentedBuffer() with self, self.override_buffer_cse(buffer=local_code, cse=self.apply_cse): - index_var = self.parse_index_list(index_list, offset=tile_desc.offset) + if dram_offset is not None: + # Use explicitly provided offset (pre-computed MLIR SSA variable name) + index_var = dram_offset + else: + index_var = self.parse_index_list(index_list, offset=tile_desc.offset) node_layout = self.named_nodes[dram_var].get_layout() if dram_var in self.exception_nodes: numel = self.exception_nodes[dram_var]["numel"] @@ -822,27 +828,33 @@ def generate_dma_code(): numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] dram_shape = f"memref<{numel}x{mlir_dtype}>" - dram_stride = [] - for idx in index_list: - if idx.is_Mul: - dram_stride.append(int(idx.args[0])) - elif idx == sympy.Symbol("c0"): - dram_stride.append(0) - elif not idx.is_Number: - dram_stride.append(1) - else: - dram_stride.append(0) - sram_var = tile_desc.get_name() - tile_shape = tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = tile_desc.get_tile_stride() - vlane_split_axis = tile_desc.vmap.vlane_split_axis - vlane_stride = tile_desc.vmap.vlane_stride + if dram_stride is not None: + # Use explicitly provided dram_stride + _dram_stride = dram_stride + else: + # Extract dram_stride from index_list (legacy behavior) + _dram_stride = [] + for idx in index_list: + if idx.is_Mul: + _dram_stride.append(int(idx.args[0])) + elif idx == sympy.Symbol("c0"): + _dram_stride.append(0) + elif not idx.is_Number: + _dram_stride.append(1) + else: + _dram_stride.append(0) + + 
sram_var = tile_desc.get_name() + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) + sram_strides = tile_desc.get_tile_stride() + vlane_split_axis = tile_desc.vmap.vlane_split_axis + vlane_stride = tile_desc.vmap.vlane_stride zero_cse = self.get_const_cse(0, "index") sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) - attribute_parts = [f"dram_stride={dram_stride}", f"sram_stride={tile_stride}", "padding=0"] + attribute_parts = [f"dram_stride={_dram_stride}", f"sram_stride={sram_strides}", "padding=0"] if subtile_size: attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") attribute = " {" + ", ".join(attribute_parts) + "}" diff --git a/tests/test_cat.py b/tests/test_cat.py index 62de6759..97fcc754 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -150,13 +150,25 @@ def cat_4d_three_inputs_fn(a, b, c): cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu()], dim=1) _test_result("cat.4d.three_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) +def test_cat_5d(device, dim=0): + def cat_5d_fn(a, b): + return torch.cat([a, b], dim=dim) + + x = torch.randn(2, 3, 4, 5, 6, device=device) + y = torch.randn(3, 3, 4, 5, 6, device=device) + opt_fn = torch.compile(dynamic=False)(cat_5d_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=dim) + _test_result("cat.5d.dim0", out, cpu_out, rtol=1e-4, atol=1e-4) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run cat simulation tests") parser.add_argument( "--case", choices=[ - "default", "out", "4d_dim0", "4d_dim1", "4d_dim2", "4d_dim3", + "default", "out", "4d_dim0", "4d_dim1", "4d_dim2", "4d_dim3", "5d" "three_inputs", "four_inputs", "4d_three_inputs", "all" ], default="all", @@ -184,3 +196,5 @@ def cat_4d_three_inputs_fn(a, b, c): test_cat_four_inputs(device) if args.case in ("4d_three_inputs", "all"): test_cat_4d_three_inputs(device) + if args.case in ("5d", "all"): + test_cat_5d(device) From 
591e8a98cdb7a734f58c3e2afff6b252f5b86bee Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 5 Mar 2026 23:16:40 +0900 Subject: [PATCH 9/9] [Template/Cat] Apply copy operation when node has view --- PyTorchSimFrontend/mlir/mlir_cat_template.py | 11 +++------- PyTorchSimFrontend/mlir/mlir_lowering.py | 23 +++++++++++++++++--- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py index 6eb60198..7bee54ac 100644 --- a/PyTorchSimFrontend/mlir/mlir_cat_template.py +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -161,14 +161,14 @@ def get_tile_candidates( candidates = set() for mult in range(1, max_tile // kernel.vector_lane + 1): t = mult * kernel.vector_lane - if t <= dim_size: + if t <= dim_size and dim_size % t == 0: candidates.add(t) if max_tile > 0: for exp in range(int(math.log2(max_tile)) + 1): t = 2 ** exp - if t <= dim_size: + if t <= dim_size and dim_size % t == 0: candidates.add(t) - candidates.add(dim_size) + candidates.add(dim_size) # dim_size always divides itself dim_tile_candidates.append(sorted(candidates)[:5]) tile_candidates = [ @@ -322,11 +322,6 @@ def make_affine_map(idx_syms, strides, layout_offset): for i, x in enumerate(input_nodes): x_stride = x.get_layout().stride - if hasattr(x, 'data') and hasattr(x.data, 'dims'): - # PermuteView: re-order strides according to the permutation - perm = x.data.dims - x_stride = [x_stride[perm[d]] for d in range(rank)] - in_syms, in_layout_strides, in_dram_strides = [], [], [] out_syms, out_layout_strides, out_dram_strides = [], [], [] tile_idx = 0 diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index d7aee715..e5df4b78 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -1,3 +1,4 @@ +import math from typing import List, Optional, Sequence import torch @@ -205,11 +206,27 @@ def _cat_layout(tensors: 
Sequence[TensorBox], dim: int) -> ir.Layout: def custom_cat_default(tensors: Sequence[TensorBox], dim: int = 0): if tensors and dim < 0: dim += len(tensors[0].get_size()) - + copy_default_lowering = lowerings.get(aten.copy_.default) + empty_strided_lowering = lowerings.get(aten.empty_strided.default) + new_tensors = [] for t in tensors: t.realize() - layout = _cat_layout(tensors, dim) - mlir_template = MLIRCatTemplate(list(tensors), layout, dim=dim) + # If the tensor is backed by a view (ReinterpretView, PermuteView, etc.), + # materialise it into a fresh contiguous FixedLayout buffer so the cat + # kernel always receives plain, dense strides. + if isinstance(t.data, ir.BaseView): + sizes = list(t.get_size()) + strides = [math.prod(sizes[i + 1:]) for i in range(len(sizes))] + new_buf = empty_strided_lowering( + sizes, strides, dtype=t.get_dtype(), device=t.get_device() + ) + tt = copy_default_lowering(new_buf, t) + else: + tt = t + new_tensors.append(tt) + + layout = _cat_layout(new_tensors, dim) + mlir_template = MLIRCatTemplate(list(new_tensors), layout, dim=dim) return mlir_template.generate().output_node() def _custom_sort_values_impl(