From 50ce1d4eef80ee07827e5b5831ad090ef47fb532 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EC=9E=AC=EA=B7=A0?= Date: Mon, 2 Mar 2026 00:28:31 +0900 Subject: [PATCH] [Template] Add cat & sort template + Multi-output (WIP) --- .../torch_openreg/openreg/__init__.py | 49 +++ PyTorchSimFrontend/mlir/mlir_cat_template.py | 167 +++++++++++ PyTorchSimFrontend/mlir/mlir_common.py | 6 +- PyTorchSimFrontend/mlir/mlir_lowering.py | 281 +++++++++++++++++- PyTorchSimFrontend/mlir/mlir_sort_template.py | 253 ++++++++++++++++ PyTorchSimFrontend/mlir/mlir_template.py | 30 +- tests/DeepSeek/test_deepseek_v3_base.py | 170 +++++++++-- tests/test_cat.py | 89 ++++++ tests/test_sort.py | 112 +++++++ 9 files changed, 1121 insertions(+), 36 deletions(-) create mode 100644 PyTorchSimFrontend/mlir/mlir_cat_template.py create mode 100644 PyTorchSimFrontend/mlir/mlir_sort_template.py create mode 100644 tests/test_cat.py create mode 100644 tests/test_sort.py diff --git a/PyTorchSimDevice/torch_openreg/openreg/__init__.py b/PyTorchSimDevice/torch_openreg/openreg/__init__.py index f5aabc18..5603a4f7 100644 --- a/PyTorchSimDevice/torch_openreg/openreg/__init__.py +++ b/PyTorchSimDevice/torch_openreg/openreg/__init__.py @@ -256,6 +256,52 @@ def launch_model(model, *args, stream_index=0, timestamp=0, **kwargs): from .random import * # noqa: F403 from .amp import * +def _precheck_cat_out_args(args, kwargs): + tensors = args[0] if len(args) > 0 else kwargs.get("tensors") + dim = args[1] if len(args) > 1 else kwargs.get("dim", 0) + out = kwargs.get("out", args[2] if len(args) > 2 else None) + + if out is None: + return + if not isinstance(tensors, (list, tuple)) or len(tensors) == 0: + raise RuntimeError("aten::cat.out requires non-empty tensor list") + if not all(isinstance(t, torch.Tensor) for t in tensors): + raise RuntimeError("aten::cat.out tensors must be Tensor values") + if not isinstance(out, torch.Tensor): + raise RuntimeError("aten::cat.out out must be a Tensor") + + rank = tensors[0].dim() + if rank == 0: + raise RuntimeError("aten::cat.out does not support scalar inputs") + if dim < 0: + dim += rank + if dim < 0 or dim >= rank: + raise RuntimeError(f"aten::cat.out dim out of range: dim={dim}, rank={rank}") + if any(t.dim() != rank for t in tensors): + raise RuntimeError("aten::cat.out inputs must have the same rank") + if any(t.dtype != tensors[0].dtype for t in tensors): + raise RuntimeError("aten::cat.out inputs must have the same dtype") + if out.dim() != rank: + raise RuntimeError("aten::cat.out out rank mismatch") + + for d in range(rank): + if d == dim: + continue + base = tensors[0].shape[d] + if any(t.shape[d] != base for t in tensors[1:]): + raise RuntimeError( + f"aten::cat.out non-concatenated dimension mismatch at dim={d}" + ) + if out.shape[d] != base: + raise RuntimeError(f"aten::cat.out out shape mismatch at dim={d}") + + expected = sum(t.shape[dim] for t in tensors) + if out.shape[dim] != expected: + raise RuntimeError( + f"aten::cat.out out concatenated dimension mismatch at dim={dim}: " + f"expected {expected}, got {out.shape[dim]}" + ) + def eager_to_compile(op_name): """ Register an eager mode operation as a graph-based implementation using torch.compile(). 
@@ -267,6 +313,9 @@ def eager_to_compile(op_name): torch.npu.eager_to_compile("aten::mul.Tensor") """ def wrapper(*args, **kwargs): + if op_name == "aten::cat.out": + _precheck_cat_out_args(args, kwargs) + @torch.compile(dynamic=False) def dummy_graph(*args, **kwargs): # Convert "aten::mul.Tensor" -> torch.ops.aten.mul.Tensor diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py new file mode 100644 index 00000000..996af1de --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -0,0 +1,167 @@ +from typing import List, Optional, cast + +import sympy +from torch._inductor.ir import Buffer, IRNode +from torch._inductor.virtualized import V + +from PyTorchSimFrontend.mlir import mlir_common +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel + + +TEMPLATE = r""" +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X0, X1], outputs=[Y], names_str=NAMES_STR, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X0", X0_TILE_DESC, id=0, indent_size=2) }} + {{ kernel.def_sram_buffer("X1", X1_TILE_DESC, id=1, indent_size=2) }} + {{ kernel.def_sram_buffer(OUT_DVAR, Y_TILE_DESC, id=2, indent_size=2) }} + {{ kernel.def_local_vars(indent_size=2) }} + + affine.for %cat_block = 0 to 1 step 1 { +{% if DIM == 0 %} + affine.for %index0 = 0 to {{ X0_ROWS }} step 1 { + affine.for %index1 = 0 to {{ COLS }} step 1 { + {{ kernel.def_dma_op("MVIN", "X0", X0_IDX, X0_TILE_DESC, indent_size=8) }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y0_IDX, X0_TILE_DESC, indent_size=8) }} + } + } + + affine.for %index2 = 0 to {{ X1_ROWS }} step 1 { + affine.for %index3 = 0 to {{ COLS }} step 1 { + {{ kernel.def_dma_op("MVIN", "X1", X1_IDX, X1_TILE_DESC, indent_size=8) }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y1_IDX, X1_TILE_DESC, indent_size=8) }} + } + } +{% else %} + affine.for %index0 = 0 to {{ ROWS }} step 1 { + affine.for %index1 = 0 to {{ X0_COLS }} step 1 { + {{ kernel.def_dma_op("MVIN", "X0", X0_IDX, X0_TILE_DESC, indent_size=8) }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y0_IDX, X0_TILE_DESC, indent_size=8) }} + } + affine.for %index3 = 0 to {{ X1_COLS }} step 1 { + {{ kernel.def_dma_op("MVIN", "X1", X1_IDX, X1_TILE_DESC, indent_size=8) }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y1_IDX, X1_TILE_DESC, indent_size=8) }} + } + } +{% endif %} + } { outer_loop=true } + return +} +""" + + +class MLIRCatTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, dim, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.dim = dim + + def render( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + tile_info=None, + **kwargs, + ): + is_out_variant = template_buffer_node is not None + if is_out_variant: + self.output_node = template_buffer_node + # cat template currently emits a single output buffer and does not + # support epilogue output remapping. 
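+        # For intuition, the emitted loops reduce to (NumPy-style sketch,
+        # shapes hypothetical):
+        #   dim == 0: Y[:X0_ROWS, :] = X0; Y[X0_ROWS:, :] = X1
+        #   dim == 1: Y[:, :X0_COLS] = X0; Y[:, X0_COLS:] = X1
+        # realized as per-element MVIN/MVOUT DMA pairs.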
+ + def _unwrap_node(n): + return n.node if hasattr(n, "node") else n + + x0 = _unwrap_node(self.input_nodes[0]) + x1 = _unwrap_node(self.input_nodes[1]) + y = _unwrap_node(self.output_node) + + def _as_int(v): + try: + return int(v) + except Exception: + return int(V.graph.sizevars.size_hint(v)) + + x0_rows = _as_int(x0.get_size()[0]) + x1_rows = _as_int(x1.get_size()[0]) + x0_cols = _as_int(x0.get_size()[1]) + x1_cols = _as_int(x1.get_size()[1]) + y_cols = _as_int(y.get_size()[1]) + kernel.loop_size = None + + # 2D cat template with contiguous layout. + x0_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + x0_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + x0_tile_desc.set_name("x0_cat_tile") + x1_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + x1_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + x1_tile_desc.set_name("x1_cat_tile") + y_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + y_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + y_tile_desc.set_name("y_cat_tile") + + if self.dim == 0: + # Flattened offsets for dim=0 cat. + x0_idx = [sympy.Symbol("index0") * x0_cols, sympy.Symbol("index1")] + x1_idx = [sympy.Symbol("index2") * x1_cols, sympy.Symbol("index3")] + y0_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index1")] + y1_idx = [(sympy.Symbol("index2") + x0_rows) * y_cols, sympy.Symbol("index3")] + else: + # Flattened offsets for dim=1 cat. + x0_idx = [sympy.Symbol("index0") * x0_cols, sympy.Symbol("index1")] + x1_idx = [sympy.Symbol("index0") * x1_cols, sympy.Symbol("index3")] + y0_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index1")] + y1_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index3") + x0_cols] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X0=x0, + X1=x1, + Y=y, + OUT_DVAR="out_ptr1" if is_out_variant else "Y", + NAMES_STR="X0, X1, out_ptr1" if is_out_variant else "X0, X1, Y", + DIM=self.dim, + X0_ROWS=x0_rows, + X1_ROWS=x1_rows, + ROWS=x0_rows, + X0_COLS=x0_cols, + X1_COLS=x1_cols, + COLS=x0_cols, + X0_TILE_DESC=x0_tile_desc, + X1_TILE_DESC=x1_tile_desc, + Y_TILE_DESC=y_tile_desc, + X0_IDX=x0_idx, + X1_IDX=x1_idx, + Y0_IDX=y0_idx, + Y1_IDX=y1_idx, + input_reorder=self.input_reorder, + ) + # Needed when epilogue fusion requests set_ranges(). 
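+        # The mapping is the identity here: the template's loop indices
+        # already use Inductor's index names.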
+ kernel.dim_aliasing = {"index0": "index0", "index1": "index1"} + + if hasattr(self.output_node, "node") and hasattr(self.output_node.node, "get_name"): + output_node_name = self.output_node.node.get_name() + elif hasattr(self.output_node, "get_name"): + output_node_name = self.output_node.get_name() + else: + output_node_name = self.output_node.name + + if hasattr(y, "get_numel"): + y_numel = y.get_numel() + elif hasattr(y, "node") and hasattr(y.node, "get_numel"): + y_numel = y.node.get_numel() + else: + y_numel = None + + kernel.epilogue_info = dict( + output_node=output_node_name, + sram_var="y_cat_tile", + dram_var=kernel.render_options["OUT_DVAR"], + dram_tile_desc=y_tile_desc, + ) + if y_numel is not None: + kernel.exception_nodes[kernel.render_options["OUT_DVAR"]] = {"numel": y_numel} + + code = self._template_from_string(TEMPLATE).render(**kernel.render_options) + return code diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 34b185b8..256d7101 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -173,7 +173,11 @@ def get_mlir_shape(info): def mlir_argdefs(self, extra_node=dict()): buffer_types = {} for x in V.graph.buffers: - if not isinstance(x.layout, MultiOutputLayout): # FIXME: MultiOutputLayout should be handled + if isinstance(x.layout, MultiOutputLayout): + # MultiOutput kernel containers own concrete output nodes in `outputs`. + for out in getattr(x, "outputs", []): + buffer_types[out.get_name()] = [out.get_dtype(), out.get_numel(), out.get_size(), out.get_stride()] + else: buffer_types[x.get_name()] = [x.get_dtype(), x.get_numel(), x.get_size(), x.get_stride()] for name, val in V.graph.graph_inputs.items(): if isinstance(val, sympy.Expr): diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index ebf0c80e..0f28f03b 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -15,10 +15,15 @@ from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate +from PyTorchSimFrontend.mlir.mlir_cat_template import MLIRCatTemplate +from PyTorchSimFrontend.mlir.mlir_sort_template import MLIRSortTemplate from PyTorchSimFrontend import extension_config aten = torch.ops.aten aten_spmm = MLIRExternKernelChoice(torch.sparse.mm, "custom_op::sparse_addmm") +_orig_cat_default_lowering = lowerings.get(aten.cat.default) +_orig_cat_out_lowering = lowerings.get(aten.cat.out) +_orig_sort_values_stable_lowering = lowerings.get(aten.sort.values_stable) def tuned_mm(mat1, mat2, * ,layout=None): m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) @@ -181,11 +186,285 @@ def custom_unsafe_index(x, indices): x.realize() return index_impl(x, indices, check=False) + +def _cat_layout(tensors: Sequence[TensorBox], dim: int) -> ir.Layout: + with V.graph.fake_mode: + output = torch.ops.aten.cat( + [ir.ir_node_to_tensor(t, guard_shape=True) for t in tensors], + dim, + ) + sizes = ir.convert_shape_to_inductor(output.size()) + stride = ir.convert_shape_to_inductor(output.stride()) + return ir.FixedLayout( + tensors[0].get_device(), + tensors[0].get_dtype(), + sizes, + stride, + ) + + +def _can_use_cat_template(tensors: Sequence[TensorBox], dim: int) -> bool: + # Current template specialization: 2 inputs, rank-2, dim in {0, 
1}. + if len(tensors) != 2: + return False + if not all(hasattr(t, "get_size") and hasattr(t, "get_dtype") and hasattr(t, "realize") for t in tensors): + return False + if tensors[0].get_dtype() != tensors[1].get_dtype(): + return False + rank0 = len(tensors[0].get_size()) + rank1 = len(tensors[1].get_size()) + if rank0 != 2 or rank1 != 2: + return False + if dim < 0: + dim += rank0 + if dim not in (0, 1): + return False + + if dim == 0: + cols0 = tensors[0].get_size()[1] + cols1 = tensors[1].get_size()[1] + return V.graph.sizevars.statically_known_equals(cols0, cols1) + + rows0 = tensors[0].get_size()[0] + rows1 = tensors[1].get_size()[0] + return V.graph.sizevars.statically_known_equals(rows0, rows1) + + +def _cat_fallback(reason: str, tensors: Sequence[TensorBox], dim: int): + # Non-template cases delegate to the original lowering path. + return _orig_cat_default_lowering(tensors, dim) + + +def _custom_cat_impl(tensors: Sequence[TensorBox], dim: int = 0): + if _orig_cat_default_lowering is None: + raise RuntimeError("Original aten.cat.default lowering is missing") + if len(tensors) > 0: + rank = len(tensors[0].get_size()) + if dim < 0: + dim += rank + if not _can_use_cat_template(tensors, dim): + return _cat_fallback("default-path", tensors, dim) + + for t in tensors: + t.realize() + layout = _cat_layout(tensors, dim) + mlir_template = MLIRCatTemplate(list(tensors), layout, dim=dim) + return mlir_template.generate().output_node() + + +def custom_cat_default(tensors: Sequence[TensorBox], dim: int = 0): + return _custom_cat_impl(tensors, dim) + + +def custom_cat_out(tensors: Sequence[TensorBox], dim: int = 0, out: Optional[TensorBox] = None): + if _orig_cat_out_lowering is None: + raise RuntimeError("Original aten.cat.out lowering is missing") + if out is None: + return _orig_cat_out_lowering(tensors, dim, out) + + copy_default_lowering = lowerings.get(aten.copy_.default) + slice_tensor_lowering = lowerings.get(aten.slice.Tensor) + if copy_default_lowering is None or slice_tensor_lowering is None: + raise RuntimeError("cat.out lowering requires aten.copy_.default and aten.slice.Tensor lowerings") + + # Lower cat.out as a sequence of slice+copy ops so each piece still runs + # through the existing compiled/simulated kernel path. + if len(tensors) == 0: + raise RuntimeError("cat.out requires at least one input tensor") + if not all(hasattr(t, "get_size") and hasattr(t, "get_dtype") and hasattr(t, "realize") for t in tensors): + raise RuntimeError("cat.out inputs must be tensor-like values") + rank = len(tensors[0].get_size()) + if rank == 0: + raise RuntimeError("cat.out does not support scalar inputs") + if dim < 0: + dim = dim + rank + if dim < 0 or dim >= rank: + raise RuntimeError(f"cat.out dim out of range: dim={dim}, rank={rank}") + if any(len(t.get_size()) != rank for t in tensors): + raise RuntimeError("cat.out inputs must have the same rank") + if any(t.get_dtype() != tensors[0].get_dtype() for t in tensors): + raise RuntimeError("cat.out inputs must have the same dtype") + # cat semantics: all non-cat dimensions must be equal. + for i in range(rank): + if i == dim: + continue + base = tensors[0].get_size()[i] + if any(not V.graph.sizevars.statically_known_equals(base, t.get_size()[i]) for t in tensors[1:]): + raise RuntimeError(f"cat.out non-concatenated dimension mismatch at dim={i}") + + # Output shape must match concatenated shape. 
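+    # e.g. cat([A(8, 16), B(6, 16)], dim=0) only accepts an out of shape (14, 16).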
+ if not hasattr(out, "get_size"): + raise RuntimeError("cat.out output must be tensor-like") + out_sizes = list(out.get_size()) + if len(out_sizes) != rank: + raise RuntimeError("cat.out output rank mismatch") + for i in range(rank): + if i == dim: + continue + if not V.graph.sizevars.statically_known_equals(out_sizes[i], tensors[0].get_size()[i]): + raise RuntimeError(f"cat.out output shape mismatch at dim={i}") + expected_cat = sum(t.get_size()[dim] for t in tensors) + if not V.graph.sizevars.statically_known_equals(out_sizes[dim], expected_cat): + raise RuntimeError(f"cat.out output concatenated dimension mismatch at dim={dim}") + + if isinstance(out, TensorBox): + out.realize() + + offset = 0 + for src in tensors: + src.realize() + end = offset + src.get_size()[dim] + dst_view = slice_tensor_lowering(out, dim, offset, end, 1) + copy_default_lowering(dst_view, src) + offset = end + return out + + +def _custom_sort_values_impl( + self: TensorBox, + dim: int = -1, + descending: bool = False, + values: Optional[TensorBox] = None, + indices: Optional[TensorBox] = None, + stable: Optional[bool] = None, +): + if values is None or indices is None: + raise RuntimeError("sort.values* lowering requires both out tensors: values, indices") + + def _normalize_dim(rank: int, d: int) -> int: + return d + rank if d < 0 else d + + if not hasattr(self, "get_size"): + raise RuntimeError("sort.values* lowering requires TensorBox input") + + rank = len(self.get_size()) + norm_dim = _normalize_dim(rank, dim) + if norm_dim < 0 or norm_dim >= rank: + raise RuntimeError(f"sort.values* dim out of range: dim={dim}, rank={rank}") + if rank != 2: + raise RuntimeError(f"sort.values* lowering currently supports rank-2 only, got rank={rank}") + if norm_dim not in (0, 1): + raise RuntimeError(f"sort.values* lowering currently supports dim in {{0,1}} only, got dim={norm_dim}") + + self.realize() + if isinstance(values, TensorBox): + values.realize() + if isinstance(indices, TensorBox): + indices.realize() + + value_layout, _ = _sort_layouts(self, norm_dim, descending) + mlir_template = MLIRSortTemplate( + [self], + value_layout, + dim=norm_dim, + descending=descending, + stable=True if stable is None else stable, + indices_node=indices, + ) + sorted_values = mlir_template.generate(template_buffer_node=values, epilogue_nodes=[indices]).output_node() + return sorted_values, indices + + +def _sort_layouts(x: TensorBox, dim: int, descending: bool): + with V.graph.fake_mode: + v, i = torch.ops.aten.sort( + ir.ir_node_to_tensor(x, guard_shape=True), + dim, + descending, + ) + v_sizes = ir.convert_shape_to_inductor(v.size()) + v_stride = ir.convert_shape_to_inductor(v.stride()) + i_sizes = ir.convert_shape_to_inductor(i.size()) + i_stride = ir.convert_shape_to_inductor(i.stride()) + + value_layout = ir.FixedLayout(x.get_device(), x.get_dtype(), v_sizes, v_stride) + index_layout = ir.FixedLayout(x.get_device(), torch.int64, i_sizes, i_stride) + return value_layout, index_layout + + +def custom_sort_stable( + self: TensorBox, + *, + stable: Optional[bool] = None, + dim: int = -1, + descending: bool = False, +): + empty_strided_lowering = lowerings.get(aten.empty_strided.default) + if empty_strided_lowering is None: + if _orig_sort_values_stable_lowering is None: + raise RuntimeError("sort.stable lowering requires aten.empty_strided.default") + return _orig_sort_values_stable_lowering(self, dim=dim, descending=descending, stable=True) + + rank = len(self.get_size()) if hasattr(self, "get_size") else 0 + norm_dim = dim + rank if 
dim < 0 else dim
+    if rank > 0 and (norm_dim < 0 or norm_dim >= rank):
+        raise RuntimeError(f"sort.stable dim out of range: dim={dim}, rank={rank}")
+
+    # Template specialization supports rank-2 and dim in {0,1}; everything else
+    # falls back to the original lowering.
+    if rank != 2 or norm_dim not in (0, 1):
+        if _orig_sort_values_stable_lowering is None:
+            raise RuntimeError("Original aten.sort.values_stable lowering is missing")
+        return _orig_sort_values_stable_lowering(self, dim=dim, descending=descending, stable=True)
+
+    try:
+        value_layout, index_layout = _sort_layouts(self, norm_dim, descending)
+        values = empty_strided_lowering(
+            list(value_layout.size),
+            list(value_layout.stride),
+            dtype=value_layout.dtype,
+            device=self.get_device(),
+        )
+        indices = empty_strided_lowering(
+            list(index_layout.size),
+            list(index_layout.stride),
+            dtype=index_layout.dtype,
+            device=self.get_device(),
+        )
+        return _custom_sort_values_impl(
+            self=self,
+            dim=dim,
+            descending=descending,
+            values=values,
+            indices=indices,
+            stable=True if stable is None else stable,
+        )
+    except Exception:
+        if _orig_sort_values_stable_lowering is None:
+            raise
+        return _orig_sort_values_stable_lowering(self, dim=dim, descending=descending, stable=stable)
+
+
+def custom_sort_values_stable(
+    self: TensorBox,
+    *,
+    stable: Optional[bool] = None,
+    dim: int = -1,
+    descending: bool = False,
+    values: Optional[TensorBox] = None,
+    indices: Optional[TensorBox] = None,
+):
+    return _custom_sort_values_impl(
+        self=self,
+        dim=dim,
+        descending=descending,
+        values=values,
+        indices=indices,
+        stable=stable,
+    )
+
+
 lowerings.update({getattr(aten.mm, overload): tuned_mm for overload in aten.mm.overloads()})
 lowerings.update({getattr(aten.addmm, overload): tuned_addmm for overload in aten.addmm.overloads()})
 lowerings.update({getattr(aten.convolution, overload): convolution for overload in aten.convolution.overloads()})
 lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()})
 lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()})
 lowerings.update({getattr(aten._unsafe_index, overload): custom_unsafe_index for overload in aten._unsafe_index.overloads()})
+
+lowerings.update({aten.cat.default: custom_cat_default})
+lowerings.update({aten.cat.out: custom_cat_out})
+
+lowerings.update({aten.sort.stable: custom_sort_stable})
+lowerings.update({aten.sort.values_stable: custom_sort_values_stable})
+
 if extension_config.CONFIG_USE_TIMING_POOLING:
-    lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template
\ No newline at end of file
+    lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template
diff --git a/PyTorchSimFrontend/mlir/mlir_sort_template.py b/PyTorchSimFrontend/mlir/mlir_sort_template.py
new file mode 100644
index 00000000..d12c7570
--- /dev/null
+++ b/PyTorchSimFrontend/mlir/mlir_sort_template.py
@@ -0,0 +1,253 @@
+from typing import List, Optional
+
+import sympy
+from torch._inductor.ir import IRNode
+from torch._inductor.virtualized import V
+
+from PyTorchSimFrontend.mlir import mlir_common
+from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel
+
+
+TEMPLATE = r"""
+{{kernel.def_global_vars()}}
+
+func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X, YI], 
outputs=[YV], names_str=NAMES_STR, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("YI", YI_TILE_DESC, id=1, indent_size=2) }} + {{ kernel.def_sram_buffer(OUT_DVAR, YV_TILE_DESC, id=2, indent_size=2) }} + {{ kernel.def_local_vars(indent_size=2) }} + + %c0 = arith.constant 0 : index + %c_cols = arith.constant {{ COLS }} : index + + affine.for %sort_block = 0 to 1 step 1 { + // Initialize output value/index buffers. + affine.for %row = 0 to {{ ROWS }} step 1 { + affine.for %col = 0 to {{ COLS }} step 1 { + {{ kernel.def_dma_op("MVIN", "X", INIT_X_IDX, X_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, INIT_YV_IDX, X_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }} +{% if DIM == 1 %} + %idx_i64 = arith.index_cast %col : index to {{ YI_ELEM_TYPE }} +{% else %} + %idx_i64 = arith.index_cast %row : index to {{ YI_ELEM_TYPE }} +{% endif %} + memref.store %idx_i64, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", INIT_YI_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }} + } + } + +{% if DIM == 1 %} + // Stable bubble sort on each row (dim=1). + affine.for %row = 0 to {{ ROWS }} step 1 { + affine.for %pass = 0 to {{ COLS }} step 1 { + affine.for %j = 0 to {{ COLS_MINUS1 }} step 1 { + {{ kernel.def_dma_op("MVIN", OUT_DVAR, D1_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} + %lhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + + {{ kernel.def_dma_op("MVIN", OUT_DVAR, D1_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} + %rhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + +{% if DESCENDING %} + %need_swap = arith.cmpf olt, %lhs, %rhs : {{ YV_ELEM_TYPE }} +{% else %} + %need_swap = arith.cmpf ogt, %lhs, %rhs : {{ YV_ELEM_TYPE }} +{% endif %} + scf.if %need_swap { + memref.store %rhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D1_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + memref.store %lhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D1_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + {{ kernel.def_dma_op("MVIN", "YI", D1_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + %li = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + + {{ kernel.def_dma_op("MVIN", "YI", D1_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + %ri = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + + memref.store %ri, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", D1_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + memref.store %li, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", D1_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + } + } + } + } +{% else %} + // Stable bubble sort on each column (dim=0). 
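+  // Same compare-and-swap scheme as the dim=1 branch, but the neighbors are
+  // (i, col) and (i+1, col), i.e. COLS elements apart in the flattened DRAM layout.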
+ affine.for %col = 0 to {{ COLS }} step 1 { + affine.for %pass = 0 to {{ ROWS }} step 1 { + affine.for %i = 0 to {{ ROWS_MINUS1 }} step 1 { + {{ kernel.def_dma_op("MVIN", OUT_DVAR, D0_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} + %lhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + + {{ kernel.def_dma_op("MVIN", OUT_DVAR, D0_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} + %rhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + +{% if DESCENDING %} + %need_swap = arith.cmpf olt, %lhs, %rhs : {{ YV_ELEM_TYPE }} +{% else %} + %need_swap = arith.cmpf ogt, %lhs, %rhs : {{ YV_ELEM_TYPE }} +{% endif %} + scf.if %need_swap { + memref.store %rhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D0_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + memref.store %lhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D0_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + {{ kernel.def_dma_op("MVIN", "YI", D0_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + %li = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + + {{ kernel.def_dma_op("MVIN", "YI", D0_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + %ri = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + + memref.store %ri, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", D0_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + memref.store %li, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", D0_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + } + } + } + } +{% endif %} + } { outer_loop=true } + return +} +""" + + +class MLIRSortTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, dim, descending=False, stable=False, indices_node=None, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.dim = dim + self.descending = descending + self.stable = stable + self.indices_node = indices_node + + def render( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + tile_info=None, + **kwargs, + ): + if template_buffer_node is not None: + self.output_node = template_buffer_node + if self.indices_node is None: + raise RuntimeError("MLIRSortTemplate requires indices output node") + + x = self.input_nodes[0] + yv = self.output_node + yi = self.indices_node + + def _as_int(v): + try: + return int(v) + except Exception: + return int(V.graph.sizevars.size_hint(v)) + + x_size = x.get_size() + if len(x_size) != 2: + raise RuntimeError("MLIRSortTemplate currently supports rank-2 input only") + if self.dim not in (0, 1): + raise RuntimeError(f"MLIRSortTemplate currently supports dim in {{0,1}} only, got dim={self.dim}") + + rows = _as_int(x_size[0]) + cols = _as_int(x_size[1]) + cols_minus1 = max(0, cols - 1) + rows_minus1 = max(0, rows - 1) + + x_dtype = x.get_dtype() + yv_dtype = yv.get_dtype() + yi_dtype = yi.get_dtype() + if x_dtype != yv_dtype: + raise RuntimeError("sort template requires input/value dtype match") + + yi_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + yi_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + 
yi_tile_desc.set_name("yi_sort_tile") + yv_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + yv_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + yv_tile_desc.set_name("yv_sort_tile") + # Neighbor element descriptors use DRAM offset to preserve affine stride metadata. + yv_s1_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + yv_s1_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + yv_s1_tile_desc.set_name("yv_sort_tile") + yi_s1_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + yi_s1_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + yi_s1_tile_desc.set_name("yi_sort_tile") + if int(self.dim) == 1: + yv_s1_tile_desc.offset = sympy.Integer(1) + yi_s1_tile_desc.offset = sympy.Integer(1) + else: + yv_s1_tile_desc.offset = sympy.Integer(cols) + yi_s1_tile_desc.offset = sympy.Integer(cols) + + row = sympy.Symbol("row") + col = sympy.Symbol("col") + i = sympy.Symbol("i") + j = sympy.Symbol("j") + + init_x_idx = [row * cols, col] + init_yv_idx = [row * cols, col] + init_yi_idx = [row * cols, col] + + d1_s0_idx = [row * cols, j] + d1_s1_idx = [row * cols, j] + + d0_s0_idx = [i * cols, col] + d0_s1_idx = [i * cols, col] + + kernel.loop_size = None + numel = rows * cols + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X=x, + YV=yv, + YI=yi, + OUT_DVAR="YV", + NAMES_STR="X, YI, YV", + ROWS=rows, + COLS=cols, + COLS_MINUS1=cols_minus1, + ROWS_MINUS1=rows_minus1, + DIM=int(self.dim), + DESCENDING=bool(self.descending), + YI_TILE_DESC=yi_tile_desc, + YV_TILE_DESC=yv_tile_desc, + YI_S1_TILE_DESC=yi_s1_tile_desc, + YV_S1_TILE_DESC=yv_s1_tile_desc, + INIT_X_IDX=init_x_idx, + INIT_YV_IDX=init_yv_idx, + INIT_YI_IDX=init_yi_idx, + D1_S0_IDX=d1_s0_idx, + D1_S1_IDX=d1_s1_idx, + D0_S0_IDX=d0_s0_idx, + D0_S1_IDX=d0_s1_idx, + YV_ELEM_TYPE=mlir_common.DTYPE_TO_MLIR[yv_dtype], + YI_ELEM_TYPE=mlir_common.DTYPE_TO_MLIR[yi_dtype], + X_MEMREF_TYPE=f"memref<{numel}x{mlir_common.DTYPE_TO_MLIR[x_dtype]}>", + YV_MEMREF_TYPE=f"memref<{numel}x{mlir_common.DTYPE_TO_MLIR[yv_dtype]}>", + YI_MEMREF_TYPE=f"memref<{numel}x{mlir_common.DTYPE_TO_MLIR[yi_dtype]}>", + YV_TILE_MEMREF_TYPE=yv_tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[yv_dtype]), + YI_TILE_MEMREF_TYPE=yi_tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[yi_dtype]), + X_TILE_DESC=yv_tile_desc, + input_reorder=self.input_reorder, + ) + + output_node_name = yv.get_name() if hasattr(yv, "get_name") else yv.name + kernel.epilogue_info = dict( + output_node=output_node_name, + sram_var="yv_sort_tile", + dram_var=kernel.render_options["OUT_DVAR"], + dram_tile_desc=yv_tile_desc, + ) + kernel.exception_nodes[kernel.render_options["OUT_DVAR"]] = {"numel": yv.get_numel()} + kernel.exception_nodes["YI"] = {"numel": yi.get_numel()} + + code = self._template_from_string(TEMPLATE).render(**kernel.render_options) + return code diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index b1c756ba..76b0ef71 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -403,7 +403,7 @@ def call_kernel(self, kernel_name): _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() # generate the code to call this wrapper.generate_kernel_call( - kernel_name if self.outer_func_name is None else "wrapper_" + kernel_name, call_args) + kernel_name if self.outer_func_name is None else self.outer_func_name + 
f"_{len(call_args)}", call_args) def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_nodes, tile_info): with self as kernel: @@ -628,8 +628,26 @@ def def_kernel( self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] def hook(): - arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) - return f"({', '.join(arg_defs)})" + arg_defs, call_args, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) + output_names = names[len(inputs) : len(inputs) + len(outputs)] + out_ptr_idx = 0 + renamed_arg_defs = [] + for outer, arg_def in zip(call_args, arg_defs): + raw_symbol = arg_def.split(":", 1)[0].strip().lstrip("%") + if outer in self.kernel_group.args.input_buffers: + symbol = self.kernel_group.args.input_buffers[outer] + elif outer in self.kernel_group.args.output_buffers: + symbol = self.kernel_group.args.output_buffers[outer] + elif raw_symbol.startswith("out_ptr") and out_ptr_idx < len(output_names): + symbol = output_names[out_ptr_idx] + out_ptr_idx += 1 + elif outer in self.kernel_group.args.sizevars: + symbol = self.kernel_group.args.sizevars[outer] + else: + symbol = raw_symbol + _, arg_type = arg_def.split(":", 1) + renamed_arg_defs.append(f"%{symbol}:{arg_type}") + return f"({', '.join(renamed_arg_defs)})" assert "" not in self.render_hooks self.render_hooks[""] = hook @@ -1151,6 +1169,8 @@ def __init__(self, name, input_nodes, layout, input_reorder = None): super().__init__(name) self.input_nodes = [node for node in input_nodes if node is not None] self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + # Multi-output templates can override this with explicit output buffers. + self.output_nodes = [self.output_node] self.input_reorder = input_reorder self.layout = layout @@ -1166,10 +1186,12 @@ def generate(self, **kwargs) -> ChoiceCaller: kernel_hash_name = f"mlir_{self.name}_{next(self.index_counter)}" extra_args = [] # create the BenchmarkRequest + output_nodes = getattr(self, "output_nodes", None) or [self.output_node] + bmreq = MLIRBenchmarkRequest( kernel_name=kernel_name, input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), - output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + output_tensor_meta=TensorMeta.from_irnodes(output_nodes), extra_args=extra_args, source_code=code, ) diff --git a/tests/DeepSeek/test_deepseek_v3_base.py b/tests/DeepSeek/test_deepseek_v3_base.py index b8402c8b..ade787c5 100644 --- a/tests/DeepSeek/test_deepseek_v3_base.py +++ b/tests/DeepSeek/test_deepseek_v3_base.py @@ -1,8 +1,55 @@ import os import sys import argparse +import copy +from pathlib import Path import torch +# recursive compile for some ops that are caused by graph break +torch.npu.register_eager_to_compile([ + "aten::zero_", + "aten::sum.IntList_out", + "aten::mul.out", + "aten::floor_divide", + "aten::floor_divide.Tensor", + "aten::floor_divide.Scalar", + "aten::cat.out", + "aten::sort.values_stable", +]) + + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + out_cpu = out.cpu() + max_diff = (out_cpu - cpu_out).abs().max().item() + mean_diff = (out_cpu - cpu_out).abs().mean().item() + if torch.allclose(out_cpu, cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print(f"Max absolute difference: {max_diff:.6f}") + print(f"Mean absolute difference: {mean_diff:.6f}") + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("NPU 
out: ", out_cpu) + print("CPU out: ", cpu_out) + print(f"Max absolute difference: {max_diff:.6f}") + print(f"Mean absolute difference: {mean_diff:.6f}") + exit(1) + + +def _extract_logits(output): + if isinstance(output, torch.Tensor): + return output + if hasattr(output, "logits"): + return output.logits + if isinstance(output, (list, tuple)) and len(output) > 0 and isinstance(output[0], torch.Tensor): + return output[0] + raise TypeError(f"Unsupported output type for comparison: {type(output)}") + def _dtype_from_str(name: str) -> torch.dtype: return { @@ -81,7 +128,7 @@ def _maybe_scale_config(config, scale=1.0, max_layers=None): def _apply_preset(scale, max_layers, batch, seq_len, preset): if preset == "tiny": - return 0.03, 4, 1, min(seq_len, 16) + return 0.03, 1, 1, min(seq_len, 16) if preset == "small": return 0.07, 8, 1, min(seq_len, 32) if preset == "medium": @@ -89,8 +136,58 @@ def _apply_preset(scale, max_layers, batch, seq_len, preset): return scale, max_layers, batch, seq_len +def _togsim_log_count() -> int: + log_dir = Path("togsim_results") + if not log_dir.exists(): + return 0 + return len(list(log_dir.glob("*.log"))) + + +def _assert_simulation_happened(before_count: int, case_name: str): + after_count = _togsim_log_count() + if after_count <= before_count: + raise RuntimeError( + f"{case_name}: TOGSim log count did not increase " + f"(before={before_count}, after={after_count})" + ) + print(f"{case_name}: TOGSim logs increased ({before_count} -> {after_count})") + + +def test_cat_default(device): + def cat_default_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_default_fn) + + before = _togsim_log_count() + out = opt_fn(x, y) + _assert_simulation_happened(before, "cat.default") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + test_result("cat.default", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_out(device): + def cat_out_fn(a, b, out): + return torch.ops.aten.cat.out([a, b], 0, out=out) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + out_buf = torch.empty(14, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_out_fn) + + before = _togsim_log_count() + out = opt_fn(x, y, out_buf) + _assert_simulation_happened(before, "cat.out") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + test_result("cat.out", out, cpu_out, rtol=1e-4, atol=1e-4) + + @torch.no_grad() -def run_deep_seek_v3_base_test( +def run_deepseek_v3_base( model_id, device, init_mode="config-random", @@ -120,7 +217,6 @@ def run_deep_seek_v3_base_test( # (call .to_dict()), so only disable it for pretrained loading path. 
    if init_mode == "pretrained" and getattr(config, "quantization_config", None) is not None:
         config.quantization_config = None
-
     config = _maybe_scale_config(config, scale=scale, max_layers=max_layers)
 
     if init_mode == "config-random":
@@ -141,7 +237,6 @@ def run_deep_seek_v3_base_test(
     else:
         raise ValueError(f"Unsupported init mode: {init_mode}")
 
-    model = model.to(device)
     model_params = sum(p.numel() for p in model.parameters())
     print("init mode:", init_mode)
     print("scaled hidden_size:", getattr(config, "hidden_size", "n/a"))
@@ -157,23 +252,33 @@ def run_deep_seek_v3_base_test(
             revision=revision,
         )
         encoded = tokenizer(prompt, return_tensors="pt")
-        input_ids = encoded["input_ids"].to(device)
+        cpu_input_ids = encoded["input_ids"].cpu()
     else:
         vocab_size = getattr(config, "vocab_size", None)
         if vocab_size is None:
             raise ValueError("Config has no vocab_size; use --use-tokenizer or pass a model with vocab_size.")
-        input_ids = _build_random_inputs(batch, seq_len, vocab_size, device)
+        cpu_input_ids = _build_random_inputs(batch, seq_len, vocab_size, torch.device("cpu"))
+    input_ids = cpu_input_ids.to(device)
 
-    if compile_model:
-        model = torch.compile(model, dynamic=False)
+    # CPU reference run
+    model_cpu = copy.deepcopy(model).cpu().eval()
+    cpu_out = _extract_logits(model_cpu(cpu_input_ids))
 
-    out = model(input_ids)
-    logits = out.logits
+    # NPU run
+    model_npu = copy.deepcopy(model_cpu).to(device).eval()
+    if compile_model:
+        model_npu = torch.compile(model_npu, dynamic=False)
+    npu_out = _extract_logits(model_npu(input_ids))
+
+    # Compare results
+    test_result(
+        "DeepSeek V3 Base",
+        npu_out,
+        cpu_out,
+        rtol=3e-1,
+        atol=2e-1,
+    )
-    print("logits shape:", tuple(logits.shape))
-    print("logits dtype:", logits.dtype)
-    print("logits max:", logits.max().item())
-
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="DeepSeek V3 download-based test")
@@ -181,7 +286,7 @@ def run_deep_seek_v3_base_test(
     parser.add_argument("--revision", type=str, default=None)
     parser.add_argument("--trust-remote-code", action="store_true", default=True)
     parser.add_argument("--init-mode", type=str, default="config-random", choices=["config-random", "pretrained"])
-    parser.add_argument("--preset", type=str, default="tiny", choices=["none", "tiny", "small", "medium"])
+    parser.add_argument("--preset", type=str, default="small", choices=["none", "tiny", "small", "medium"])
     parser.add_argument("--scale", type=float, default=1.0)
     parser.add_argument("--max-layers", type=int, default=None)
     parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"])
@@ -190,6 +295,7 @@ def run_deep_seek_v3_base_test(
     parser.add_argument("--use-tokenizer", action="store_true")
     parser.add_argument("--prompt", type=str, default="Hello, DeepSeek V3")
     parser.add_argument("--compile", action="store_true", default=True)
+    parser.add_argument("--test", type=str, default="e2e", choices=["all", "e2e", "cat"])
 
     args = parser.parse_args()
 
@@ -203,18 +309,22 @@ def run_deep_seek_v3_base_test(
 
     device = torch.device("npu:0")
 
-    run_deep_seek_v3_base_test(
-        model_id=args.model_id,
-        device=device,
-        init_mode=args.init_mode,
-        scale=args.scale,
-        max_layers=args.max_layers,
-        dtype=args.dtype,
-        batch=args.batch,
-        seq_len=args.seq_len,
-        use_tokenizer=args.use_tokenizer,
-        prompt=args.prompt,
-        trust_remote_code=args.trust_remote_code,
-        revision=args.revision,
-        compile_model=args.compile,
-    )
+    if args.test in ("all", "cat"):
+        test_cat_default(device)
+        test_cat_out(device)
+    if 
args.test in ("all", "e2e"): + run_deepseek_v3_base( + model_id=args.model_id, + device=device, + init_mode=args.init_mode, + scale=args.scale, + max_layers=args.max_layers, + dtype=args.dtype, + batch=args.batch, + seq_len=args.seq_len, + use_tokenizer=args.use_tokenizer, + prompt=args.prompt, + trust_remote_code=args.trust_remote_code, + revision=args.revision, + compile_model=args.compile, + ) diff --git a/tests/test_cat.py b/tests/test_cat.py new file mode 100644 index 00000000..32573a05 --- /dev/null +++ b/tests/test_cat.py @@ -0,0 +1,89 @@ +import argparse +from pathlib import Path + +import torch + + +def _test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + return + + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + raise RuntimeError(f"{name} mismatch") + + +def _togsim_log_count() -> int: + log_dir = Path("togsim_results") + if not log_dir.exists(): + return 0 + return len(list(log_dir.glob("*.log"))) + + +def _assert_simulation_happened(before_count: int, case_name: str): + after_count = _togsim_log_count() + if after_count <= before_count: + raise RuntimeError( + f"{case_name}: TOGSim log count did not increase " + f"(before={before_count}, after={after_count})" + ) + print(f"{case_name}: TOGSim logs increased ({before_count} -> {after_count})") + + +def test_cat_default(device): + def cat_default_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_default_fn) + + before = _togsim_log_count() + out = opt_fn(x, y) + _assert_simulation_happened(before, "cat.default") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.default", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_out(device): + def cat_out_fn(a, b, out): + return torch.ops.aten.cat.out([a, b], 0, out=out) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + out_buf = torch.empty(14, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_out_fn) + + before = _togsim_log_count() + out = opt_fn(x, y, out_buf) + _assert_simulation_happened(before, "cat.out") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.out", out, cpu_out, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run cat simulation tests") + parser.add_argument( + "--case", + choices=["default", "out", "all"], + default="all", + help="Which cat case to run", + ) + args = parser.parse_args() + + device = torch.device("npu:0") + + if args.case in ("default", "all"): + test_cat_default(device) + if args.case in ("out", "all"): + test_cat_out(device) diff --git a/tests/test_sort.py b/tests/test_sort.py new file mode 100644 index 00000000..2b070223 --- /dev/null +++ b/tests/test_sort.py @@ -0,0 +1,112 @@ +import argparse +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" 
* len(message))
+        print("custom out:", out.cpu())
+        print("cpu out:", cpu_out)
+        raise SystemExit(1)
+
+
+def test_equal(name, out, cpu_out):
+    if torch.equal(out.cpu(), cpu_out):
+        message = f"|{name} Test Passed|"
+        print("-" * len(message))
+        print(message)
+        print("-" * len(message))
+    else:
+        message = f"|{name} Test Failed|"
+        print("-" * len(message))
+        print(message)
+        print("-" * len(message))
+        print("custom out:", out.cpu())
+        print("cpu out:", cpu_out)
+        raise SystemExit(1)
+
+
+def _normalize_dim(dim: int, rank: int) -> int:
+    d = dim if dim >= 0 else rank + dim
+    if d < 0 or d >= rank:
+        raise ValueError(f"dim out of range: dim={dim}, rank={rank}")
+    return d
+
+
+def test_sort_stable(device, size=(128, 128), dim=-1, descending=False):
+    _normalize_dim(dim, len(size))
+
+    def sort_stable_fn(x):
+        return torch.sort(x, stable=True, dim=dim, descending=descending)
+
+    x = torch.randn(size, dtype=torch.float32)
+    x_npu = x.to(device=device)
+
+    opt_sort = torch.compile(dynamic=False)(sort_stable_fn)
+    out_values, out_indices = opt_sort(x_npu)
+
+    ref_values, ref_indices = torch.sort(x, stable=True, dim=dim, descending=descending)
+
+    test_result("Sort.stable/values", out_values, ref_values)
+    test_equal("Sort.stable/indices", out_indices, ref_indices)
+
+
+def test_sort_values_stable(device, size=(128, 128), dim=-1, descending=False):
+    _normalize_dim(dim, len(size))
+
+    def sort_out_fn(x):
+        out_values = torch.empty_like(x, device=x.device)
+        out_indices = torch.empty_like(x, dtype=torch.int64, device=x.device)
+        return torch.sort(x, stable=True, dim=dim, descending=descending, out=(out_values, out_indices))
+
+    x = torch.randn(size, dtype=torch.float32)
+    x_npu = x.to(device=device)
+
+    # Intentionally not compiled here: the out-variant is exercised through the
+    # eager_to_compile bridge registered in __main__ below.
+    opt_sort = sort_out_fn  # torch.compile(dynamic=False)(sort_out_fn)
+    out_values, out_indices = opt_sort(x_npu)
+
+    ref_values, ref_indices = torch.sort(x, stable=True, dim=dim, descending=descending)
+
+    test_result("Sort.values_stable/values", out_values, ref_values)
+    test_equal("Sort.values_stable/indices", out_indices, ref_indices)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run sort tests")
+    parser.add_argument("--shape", type=str, default="(128,128)")
+    parser.add_argument("--dim", type=int, default=0)
+    parser.add_argument("--descending", action="store_true")
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="all",
+        choices=["all", "default", "values"],
+    )
+    args = parser.parse_args()
+
+    shape = tuple(map(int, args.shape.strip("()").split(",")))
+
+    from Scheduler.scheduler import PyTorchSimRunner
+
+    module = PyTorchSimRunner.setup_device()
+    device = module.custom_device()
+
+    # Register the recursive-compile bridge only when the values_stable path is
+    # explicitly tested.
+    if args.mode in ("all", "values"):
+        torch.npu.register_eager_to_compile([
+            "aten::sort.values_stable",
+        ])
+
+    if args.mode in ("all", "default"):
+        test_sort_stable(device, size=shape, dim=args.dim, descending=args.descending)
+    if args.mode in ("all", "values"):
+        test_sort_values_stable(device, size=shape, dim=args.dim, descending=args.descending)
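+
+# Example invocations (flag names as defined above; shapes are arbitrary):
+#   python tests/test_sort.py --shape "(64,32)" --dim 1 --descending
+#   python tests/test_sort.py --mode values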