diff --git a/.github/workflows/pytorchsim_test.yml b/.github/workflows/pytorchsim_test.yml index 9589384b..eaaa7e50 100644 --- a/.github/workflows/pytorchsim_test.yml +++ b/.github/workflows/pytorchsim_test.yml @@ -163,6 +163,27 @@ jobs: -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_conv2d.py + test_cat: + name: Run test_cat.py + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_cat.py + run: | + echo "Running test_cat.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + -e vpu_num_lanes="${{ inputs.vector_lane }}" \ + -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/test_cat.py + test_matmul: name: Run test_matmul.py runs-on: self-hosted diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 178ea987..9398f90c 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -154,6 +154,9 @@ class MLIRBMMTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = True + self.support_reduction_fusion = True def render(self, kernel: MLIRTemplateKernel, diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py new file mode 100644 index 00000000..7bee54ac --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -0,0 +1,357 @@ +from typing import List, Optional, Set +import math +import itertools + +import sympy +from torch._inductor.ir import IRNode + +from PyTorchSimFrontend.mlir import mlir_common +from 
PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel + + +TEMPLATE = r""" +{{kernel.def_global_vars()}} +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=INPUT_NAMES, outputs=[Y], names_str=NAMES_STR, input_reorder=input_reorder)}} { +{%- for buffer_name, tile_desc in UNIQUE_BUFFER_TILE_DESCS.items() %} + {{ kernel.def_sram_buffer(buffer_name, tile_desc, indent_size=2) }} +{%- endfor %} + {{ kernel.def_local_vars(indent_size=2) }} + + affine.for %cat_block = 0 to 1 step 1 { +{%- for d in range(RANK-1) %} + affine.for %index{{ OUTPUT_DIM[d] }} = 0 to {{ OUTPUT_SIZES[d] }} step {{ TILE_SIZES[d] }} { +{%- endfor %} +{%- for i in range(NUM_INPUTS) %} + // Input tensor{{ i }} + affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUTS[i].sizes[DIM] }} step {{ INPUTS[i].tile_size_dim }} { + %index{{ DIM }}_{{ i }} = affine.apply affine_map<(d0) -> (d0 + {{ INPUTS[i].cum_offset }})> (%index_local{{ DIM }}_{{ i }}) + %input_dram_offset_{{ i }} = affine.apply {{ INPUTS[i].offset_map }}({{ INPUTS[i].offset_vars }}) + %output_dram_offset_{{ i }} = affine.apply {{ OUTPUTS[i].offset_map }}({{ OUTPUTS[i].offset_vars }}) + {{ kernel.def_dma_op("MVIN", INPUTS[i].dram_name, [], INPUTS[i].tile_desc, indent_size=INDENT_SIZE, dram_stride=INPUTS[i].dram_strides, dram_offset="input_dram_offset_" ~ i) }} + {{ kernel.def_dma_op("MVOUT", "Y", [], OUTPUTS[i].tile_desc, indent_size=INDENT_SIZE, dram_stride=OUTPUTS[i].dram_strides, dram_offset="output_dram_offset_" ~ i) }} + } { inner_loop=true } +{%- endfor %} + +{%- for d in range(RANK-1) %} + } { outer_loop=true } +{%- endfor %} + } { outer_loop=true } + return +} +""" + + +class MLIRCatTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, dim): + super().__init__("kernel", input_nodes, layout) + self.dim = dim + + def render( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + tile_info=None, + **kwargs, + ): + input_nodes = 
self.input_nodes + y = self.output_node + num_inputs = len(input_nodes) + rank = len(y.get_size()) + + input_sizes = [x.get_size() for x in input_nodes] + output_sizes = [sz for d, sz in enumerate(y.get_size()) if d != self.dim] + output_dim = [d for d, _ in enumerate(y.get_size()) if d != self.dim] + output_strides = y.get_layout().stride + + tile_sizes = list(tile_info) if tile_info is not None else [1] * len(output_sizes) + excluded_dims = self._compute_excluded_dims(tile_sizes) + + input_tile_sizes_dim = self._calculate_input_tile_sizes( + kernel, input_sizes, tile_sizes, num_inputs, rank + ) + buffer_name_to_template_name, input_dram_names = self._build_buffer_mapping(input_nodes) + input_tile_descs, output_tile_descs, unique_tile_descs = self._build_tile_descriptors( + kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, + input_dram_names, y, excluded_dims=excluded_dims + ) + (input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets) = self._build_dma_info( + input_nodes, input_sizes, output_strides, input_tile_descs, output_tile_descs, + rank, num_inputs, excluded_dims=excluded_dims + ) + + unique_buffer_tile_descs = { + buffer_name_to_template_name[name]: desc + for name, desc in unique_tile_descs.items() + } + names_str = ", ".join(input_dram_names + ["Y"]) + indent_size = 2 + (rank - 1) * 2 + 4 + + inputs_info = [ + dict( + dram_name = input_dram_names[i], + sizes = input_sizes[i], + tile_size_dim= input_tile_sizes_dim[i], + tile_desc = input_tile_descs[i], + offset_map = input_offset_maps[i], + offset_vars = input_offset_var_strs[i], + dram_strides = input_dram_strides[i], + cum_offset = cumulative_offsets[i], + ) + for i in range(num_inputs) + ] + outputs_info = [ + dict( + tile_desc = output_tile_descs[i], + offset_map = output_offset_maps[i], + offset_vars = output_offset_var_strs[i], + dram_strides = output_dram_strides[i], + ) + for i 
in range(num_inputs) + ] + + kernel.render_options = dict( + KERNEL_NAME = self.name, + kernel = kernel, + NUM_INPUTS = num_inputs, + NAMES_STR = names_str, + Y = y, + INPUT_NAMES = input_nodes, + RANK = rank, + DIM = self.dim, + OUTPUT_SIZES = output_sizes, + OUTPUT_DIM = output_dim, + TILE_SIZES = tile_sizes, + UNIQUE_BUFFER_TILE_DESCS = unique_buffer_tile_descs, + INPUTS = inputs_info, + OUTPUTS = outputs_info, + INDENT_SIZE = indent_size, + input_reorder = self.input_reorder, + ) + + return self._template_from_string(TEMPLATE).render(**kernel.render_options) + + def get_tile_candidates( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs, + ): + """Generate tile candidates for cat operation. Concat dimension always has tile size 1.""" + if template_buffer_node is not None: + self.output_node = template_buffer_node + + y = self.output_node + num_inputs = len(self.input_nodes) + output_sizes = [sz for d, sz in enumerate(y.get_size()) if d != self.dim] + + if not output_sizes: + return [[1]] + + max_tile_total = kernel.spad_info["spad_size"] // ( + kernel.vector_lane * kernel.precision * 2 * num_inputs + ) + + dim_tile_candidates = [] + for dim_size in output_sizes: + max_tile = min(dim_size, max_tile_total) + candidates = set() + for mult in range(1, max_tile // kernel.vector_lane + 1): + t = mult * kernel.vector_lane + if t <= dim_size and dim_size % t == 0: + candidates.add(t) + if max_tile > 0: + for exp in range(int(math.log2(max_tile)) + 1): + t = 2 ** exp + if t <= dim_size and dim_size % t == 0: + candidates.add(t) + candidates.add(dim_size) # dim_size always divides itself + dim_tile_candidates.append(sorted(candidates)[:5]) + + tile_candidates = [ + list(combo) + for combo in itertools.product(*dim_tile_candidates) + if math.prod(combo) * (num_inputs + 1) * kernel.precision + <= kernel.spad_info["spad_size"] * kernel.vector_lane + ] + + if not tile_candidates: + tile_candidates = 
[[1] * len(output_sizes)] + + tile_candidates.sort(key=lambda x: -math.prod(x)) + return tile_candidates[:4] + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _compute_excluded_dims(self, tile_sizes: list) -> list: + """Return non-tiled dimension indices when rank exceeds the 4-dim limit.""" + max_tiled = 3 + if len(tile_sizes) <= max_tiled: + return [] + sorted_dims = sorted(enumerate(tile_sizes), key=lambda x: x[1], reverse=True) + excluded = [idx for idx, _ in sorted_dims[max_tiled:]] + for idx in excluded: + tile_sizes[idx] = 1 + return excluded + + def _calculate_input_tile_sizes(self, kernel, input_sizes, tile_sizes, num_inputs, rank): + """Calculate tile sizes along the concat dimension for each input.""" + non_dim_tile_elements = math.prod(tile_sizes) if tile_sizes else 1 + max_spad_per_input = kernel.spad_info["spad_size"] * kernel.vector_lane // 2 + extra_concat = math.ceil(max_spad_per_input / (non_dim_tile_elements * kernel.precision)) - num_inputs + + input_tile_sizes_dim = [] + for i in range(num_inputs): + if extra_concat > 0 and non_dim_tile_elements > 0: + tile_dim = min(input_sizes[i][self.dim], extra_concat) + extra_concat -= tile_dim + else: + tile_dim = 1 + input_tile_sizes_dim.append(tile_dim) + return input_tile_sizes_dim + + def _build_buffer_mapping(self, input_nodes): + """Map actual buffer names to short template names (X0, X1, ...).""" + name_map = {} + template_names = [] + for x in input_nodes: + actual = x.get_name() + template = name_map.setdefault(actual, f"X{len(name_map)}") + template_names.append(template) + return name_map, template_names + + def _build_tile_descriptors( + self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, + input_buffer_names, output_node, excluded_dims=None + ): + """Build tile descriptors for every input (and its paired output).""" + if excluded_dims is None: + 
excluded_dims = set() + + def make_tile_desc(tile_sz, vector_lane, name, offset): + desc = mlir_common.MLIRMultiDimTile( + tile_sz, vector_lane, + vlane_split_axis=len(tile_sz) - 1, + vlane_stride=1 + ) + desc.set_tile_size(tile_sz) + desc.set_name(name) + desc.offset = offset + return desc + + output_offset = output_node.get_layout().offset + input_tile_descs, output_tile_descs, unique_tile_descs = [], [], {} + + for i, x in enumerate(input_nodes): + # Collect tile sizes for tiled dimensions only (skip excluded non-concat dims) + tile_sz = [] + tile_idx = 0 + for d in range(rank): + if d != self.dim: + if tile_idx not in excluded_dims: + tile_sz.append(tile_sizes[tile_idx]) + tile_idx += 1 + else: + tile_sz.append(input_tile_sizes_dim[i]) + + sram_name = f"{input_buffer_names[i].lower()}_cat_tile" + input_tile_descs.append(make_tile_desc(tile_sz, kernel.vector_lane, sram_name, x.get_layout().offset)) + output_tile_descs.append(make_tile_desc(tile_sz, kernel.vector_lane, sram_name, output_offset)) + + actual_name = x.get_name() + if actual_name not in unique_tile_descs: + unique_tile_descs[actual_name] = input_tile_descs[-1] + + return input_tile_descs, output_tile_descs, unique_tile_descs + + def _build_dma_info( + self, input_nodes, input_sizes, output_strides, + input_tile_descs, output_tile_descs, + rank, num_inputs, excluded_dims=None + ): + """Build per-input DRAM offset affine maps and tile strides. + + Three stride concepts are maintained: + + * layout_strides (internal) - raw DRAM buffer strides for every rank + dimension, used to compute the flat base-address affine map. + These reflect how the tensor is physically laid out in DRAM. + * dram_strides (returned, ``def_dma_op dram_stride=``) - stride in + DRAM per *tiled* dimension (excluded dims removed). The DMA engine + uses these to walk DRAM when loading/storing a tile. + * sram_strides (inside ``def_dma_op``, from tile_desc) - stride in + SRAM per tiled dimension. 
The DMA engine uses these to place data + into the SRAM tile buffer. + + Returns: + input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets + """ + if excluded_dims is None: + excluded_dims = set() + + def make_affine_map(idx_syms, strides, layout_offset): + terms = [] + for j, s in enumerate(strides): + s = int(s) + if s == 1: + terms.append(f"d{j}") + elif s != 0: + terms.append(f"d{j} * {s}") + try: + off = int(layout_offset) + except (TypeError, ValueError): + off = 0 + if off: + terms.append(str(off)) + dim_str = ", ".join(f"d{j}" for j in range(len(idx_syms))) + return f"affine_map<({dim_str}) -> ({' + '.join(terms) if terms else '0'})>" + + cumulative_offsets = [0] + for i in range(num_inputs - 1): + cumulative_offsets.append(cumulative_offsets[-1] + input_sizes[i][self.dim]) + + input_offset_maps, input_offset_var_strs, input_dram_strides = [], [], [] + output_offset_maps, output_offset_var_strs, output_dram_strides = [], [], [] + + for i, x in enumerate(input_nodes): + x_stride = x.get_layout().stride + in_syms, in_layout_strides, in_dram_strides = [], [], [] + out_syms, out_layout_strides, out_dram_strides = [], [], [] + tile_idx = 0 + + for d in range(rank): + if d != self.dim: + in_syms.append(sympy.Symbol(f"index{d}")) + in_layout_strides.append(int(x_stride[d])) + out_syms.append(sympy.Symbol(f"index{d}")) + out_layout_strides.append(int(output_strides[d])) + if tile_idx not in excluded_dims: + in_dram_strides.append(int(x_stride[d])) + out_dram_strides.append(int(output_strides[d])) + tile_idx += 1 + else: + in_syms.append(sympy.Symbol(f"index_local{self.dim}_{i}")) + in_layout_strides.append(int(x_stride[d])) + out_syms.append(sympy.Symbol(f"index{self.dim}_{i}")) + out_layout_strides.append(int(output_strides[d])) + in_dram_strides.append(int(x_stride[d])) + out_dram_strides.append(int(output_strides[d])) + + 
input_offset_maps.append(make_affine_map(in_syms, in_layout_strides, input_tile_descs[i].offset)) + input_offset_var_strs.append(", ".join(f"%{s}" for s in in_syms)) + input_dram_strides.append(in_dram_strides) + + output_offset_maps.append(make_affine_map(out_syms, out_layout_strides, output_tile_descs[i].offset)) + output_offset_var_strs.append(", ".join(f"%{s}" for s in out_syms)) + output_dram_strides.append(out_dram_strides) + + return (input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets) diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 34b185b8..256d7101 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -173,7 +173,11 @@ def get_mlir_shape(info): def mlir_argdefs(self, extra_node=dict()): buffer_types = {} for x in V.graph.buffers: - if not isinstance(x.layout, MultiOutputLayout): # FIXME: MultiOutputLayout should be handled + if isinstance(x.layout, MultiOutputLayout): + # MultiOutput kernel containers own concrete output nodes in `outputs`. 
+ for out in getattr(x, "outputs", []): + buffer_types[out.get_name()] = [out.get_dtype(), out.get_numel(), out.get_size(), out.get_stride()] + else: buffer_types[x.get_name()] = [x.get_dtype(), x.get_numel(), x.get_size(), x.get_stride()] for name, val in V.graph.graph_inputs.items(): if isinstance(val, sympy.Expr): diff --git a/PyTorchSimFrontend/mlir/mlir_conv_common.py b/PyTorchSimFrontend/mlir/mlir_conv_common.py index f8566b6d..f72a7663 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_common.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_common.py @@ -12,6 +12,9 @@ class MLIRConvCommonTemplate(MLIRTemplate): WRAPPER_TEMPLATE = None def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = False + self.support_reduction_fusion = False self.stride = kwargs["stride"] self.padding = kwargs["padding"] self.dilation = kwargs["dilation"] diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 0158caa6..5b116807 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -105,6 +105,9 @@ class MLIRGemmTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = True + self.support_reduction_fusion = True def render(self, kernel: MLIRTemplateKernel, diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index ebf0c80e..e5df4b78 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -1,3 +1,4 @@ +import math from typing import List, Optional, Sequence import torch @@ -15,10 +16,15 @@ from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate from 
PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate +from PyTorchSimFrontend.mlir.mlir_cat_template import MLIRCatTemplate +from PyTorchSimFrontend.mlir.mlir_sort_template import MLIRSortTemplate from PyTorchSimFrontend import extension_config aten = torch.ops.aten aten_spmm = MLIRExternKernelChoice(torch.sparse.mm, "custom_op::sparse_addmm") +_orig_cat_default_lowering = lowerings.get(aten.cat.default) +_orig_cat_out_lowering = lowerings.get(aten.cat.out) +_orig_sort_values_stable_lowering = lowerings.get(aten.sort.values_stable) def tuned_mm(mat1, mat2, * ,layout=None): m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) @@ -181,11 +187,191 @@ def custom_unsafe_index(x, indices): x.realize() return index_impl(x, indices, check=False) + +def _cat_layout(tensors: Sequence[TensorBox], dim: int) -> ir.Layout: + with V.graph.fake_mode: + output = torch.ops.aten.cat( + [ir.ir_node_to_tensor(t, guard_shape=True) for t in tensors], + dim, + ) + sizes = ir.convert_shape_to_inductor(output.size()) + stride = ir.convert_shape_to_inductor(output.stride()) + return ir.FixedLayout( + tensors[0].get_device(), + tensors[0].get_dtype(), + sizes, + stride, + ) + +def custom_cat_default(tensors: Sequence[TensorBox], dim: int = 0): + if tensors and dim < 0: + dim += len(tensors[0].get_size()) + copy_default_lowering = lowerings.get(aten.copy_.default) + empty_strided_lowering = lowerings.get(aten.empty_strided.default) + new_tensors = [] + for t in tensors: + t.realize() + # If the tensor is backed by a view (ReinterpretView, PermuteView, etc.), + # materialise it into a fresh contiguous FixedLayout buffer so the cat + # kernel always receives plain, dense strides. 
+ if isinstance(t.data, ir.BaseView): + sizes = list(t.get_size()) + strides = [math.prod(sizes[i + 1:]) for i in range(len(sizes))] + new_buf = empty_strided_lowering( + sizes, strides, dtype=t.get_dtype(), device=t.get_device() + ) + tt = copy_default_lowering(new_buf, t) + else: + tt = t + new_tensors.append(tt) + + layout = _cat_layout(new_tensors, dim) + mlir_template = MLIRCatTemplate(list(new_tensors), layout, dim=dim) + return mlir_template.generate().output_node() + +def _custom_sort_values_impl( + self: TensorBox, + dim: int = -1, + descending: bool = False, + values: Optional[TensorBox] = None, + indices: Optional[TensorBox] = None, + stable: Optional[bool] = None, +): + if values is None or indices is None: + raise RuntimeError("sort.values* lowering requires both out tensors: values, indices") + + def _normalize_dim(rank: int, d: int) -> int: + return d + rank if d < 0 else d + + if not hasattr(self, "get_size"): + raise RuntimeError("sort.values* lowering requires TensorBox input") + + rank = len(self.get_size()) + norm_dim = _normalize_dim(rank, dim) + if norm_dim < 0 or norm_dim >= rank: + raise RuntimeError(f"sort.values* dim out of range: dim={dim}, rank={rank}") + if rank != 2: + raise RuntimeError(f"sort.values* lowering currently supports rank-2 only, got rank={rank}") + if norm_dim not in (0, 1): + raise RuntimeError(f"sort.values* lowering currently supports dim in {{0,1}} only, got dim={norm_dim}") + + self.realize() + if isinstance(values, TensorBox): + values.realize() + if isinstance(indices, TensorBox): + indices.realize() + + value_layout, _ = _sort_layouts(self, norm_dim, descending) + mlir_template = MLIRSortTemplate( + [self], + value_layout, + dim=norm_dim, + descending=descending, + stable=True if stable is None else stable, + indices_node=indices, + ) + sorted_values = mlir_template.generate(template_buffer_node=values, epilogue_nodes=[indices]).output_node() + return sorted_values, indices + + +def _sort_layouts(x: TensorBox, 
dim: int, descending: bool): + with V.graph.fake_mode: + v, i = torch.ops.aten.sort( + ir.ir_node_to_tensor(x, guard_shape=True), + dim, + descending, + ) + v_sizes = ir.convert_shape_to_inductor(v.size()) + v_stride = ir.convert_shape_to_inductor(v.stride()) + i_sizes = ir.convert_shape_to_inductor(i.size()) + i_stride = ir.convert_shape_to_inductor(i.stride()) + + value_layout = ir.FixedLayout(x.get_device(), x.get_dtype(), v_sizes, v_stride) + index_layout = ir.FixedLayout(x.get_device(), torch.int64, i_sizes, i_stride) + return value_layout, index_layout + + +def custom_sort_stable( + self: TensorBox, + *, + stable: Optional[bool] = None, + dim: int = -1, + descending: bool = False, +): + empty_strided_lowering = lowerings.get(aten.empty_strided.default) + if empty_strided_lowering is None: + if _orig_sort_values_stable_lowering is None: + raise RuntimeError("sort.stable lowering requires aten.empty_strided.default") + return _orig_sort_values_stable_lowering(self, dim=dim, descending=descending, stable=True) + + rank = len(self.get_size()) if hasattr(self, "get_size") else 0 + norm_dim = dim + rank if dim < 0 else dim + if rank > 0 and (norm_dim < 0 or norm_dim >= rank): + raise RuntimeError(f"sort.stable dim out of range: dim={dim}, rank={rank}") + + # Template specialization supports rank-2 and dim in {0,1}. 
+ if rank == 2 and norm_dim not in (0, 1): + if _orig_sort_values_stable_lowering is None: + raise RuntimeError("Original aten.sort.values_stable lowering is missing") + return _orig_sort_values_stable_lowering(self, dim=dim, descending=descending, stable=True) + + try: + value_layout, index_layout = _sort_layouts(self, norm_dim, descending) + values = empty_strided_lowering( + list(value_layout.size), + list(value_layout.stride), + dtype=value_layout.dtype, + device=self.get_device(), + ) + indices = empty_strided_lowering( + list(index_layout.size), + list(index_layout.stride), + dtype=index_layout.dtype, + device=self.get_device(), + ) + return _custom_sort_values_impl( + self=self, + dim=dim, + descending=descending, + values=values, + indices=indices, + stable=True if stable is None else stable, + ) + except Exception: + if _orig_sort_values_stable_lowering is None: + raise + return _orig_sort_values_stable_lowering(self, dim=dim, descending=descending, stable=stable) + + +def custom_sort_values_stable( + self: TensorBox, + *, + stable: Optional[bool] = None, + dim: int = -1, + descending: bool = False, + values: Optional[TensorBox] = None, + indices: Optional[TensorBox] = None, +): + return _custom_sort_values_impl( + self=self, + dim=dim, + descending=descending, + values=values, + indices=indices, + stable=stable, + ) + + lowerings.update({getattr(aten.mm, overload): tuned_mm for overload in aten.mm.overloads()}) lowerings.update({getattr(aten.addmm, overload): tuned_addmm for overload in aten.addmm.overloads()}) lowerings.update({getattr(aten.convolution, overload): convolution for overload in aten.convolution.overloads()}) lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()}) lowerings.update({getattr(aten._unsafe_index, overload): custom_unsafe_index for overload in 
aten._unsafe_index.overloads()}) +lowerings.update({getattr(aten.cat, overload): custom_cat_default for overload in aten.cat.overloads()}) + +lowerings.update({aten.sort.stable: custom_sort_stable}) +lowerings.update({aten.sort.values_stable: custom_sort_values_stable}) + if extension_config.CONFIG_USE_TIMING_POOLING: - lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template \ No newline at end of file + lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index af960533..2f9c9704 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -44,12 +44,10 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule # Case 3: Prologue(Pointwise) + Tempalte if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - target_node = base_template_node2[0].node - # Currently only BMM, MM support prologue fusion - if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + + # Check if template supports prologue fusion + if not getattr(target_node.template, 'support_prologue_fusion', False): return False if len(node1.read_writes.writes) != 1: @@ -129,12 +127,14 @@ def can_fuse_horizontal(self, node1, node2): if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) 
== 0 and not node2.is_reduction(): # Don't fuse maxpool template code from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate template_node = base_template_node1[0] epilogue_node = node2 + # Check if template supports epilogue fusion + if not getattr(template_node.node.template, 'support_epilogue_fusion', False): + return False + if isinstance(template_node.node.template, MLIRMaxPoolTemplate): return False @@ -161,7 +161,7 @@ def can_fuse_horizontal(self, node1, node2): # Revert act_node.group : simplify_and_reorder() modified _body, _size, group if template_node.group != epilogue_node.group: # We don't fuse this case... - if (isinstance(template_node.node.template, MLIRBMMTemplate) or isinstance(template_node.node.template, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: + if getattr(template_node.node.template, 'support_prologue_fusion', False) and template_node.group[1][0][0] == 1: return False if list(template_node.group[1][0]) != list(epilogue_node.get_nodes()[0].node.data.get_size()): @@ -171,10 +171,10 @@ def can_fuse_horizontal(self, node1, node2): # Case 2: Tempalte + Reduction fusion if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and node2.is_reduction() and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate target_node = base_template_node1[0].node - if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + + # Check if template supports reduction fusion + if not getattr(target_node.template, 'support_reduction_fusion', False): return False size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, 
node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) diff --git a/PyTorchSimFrontend/mlir/mlir_sort_template.py b/PyTorchSimFrontend/mlir/mlir_sort_template.py new file mode 100644 index 00000000..d12c7570 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_sort_template.py @@ -0,0 +1,253 @@ +from typing import List, Optional + +import sympy +from torch._inductor.ir import IRNode +from torch._inductor.virtualized import V + +from PyTorchSimFrontend.mlir import mlir_common +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel + + +TEMPLATE = r""" +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X, YI], outputs=[YV], names_str=NAMES_STR, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("YI", YI_TILE_DESC, id=1, indent_size=2) }} + {{ kernel.def_sram_buffer(OUT_DVAR, YV_TILE_DESC, id=2, indent_size=2) }} + {{ kernel.def_local_vars(indent_size=2) }} + + %c0 = arith.constant 0 : index + %c_cols = arith.constant {{ COLS }} : index + + affine.for %sort_block = 0 to 1 step 1 { + // Initialize output value/index buffers. + affine.for %row = 0 to {{ ROWS }} step 1 { + affine.for %col = 0 to {{ COLS }} step 1 { + {{ kernel.def_dma_op("MVIN", "X", INIT_X_IDX, X_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, INIT_YV_IDX, X_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }} +{% if DIM == 1 %} + %idx_i64 = arith.index_cast %col : index to {{ YI_ELEM_TYPE }} +{% else %} + %idx_i64 = arith.index_cast %row : index to {{ YI_ELEM_TYPE }} +{% endif %} + memref.store %idx_i64, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", INIT_YI_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }} + } + } + +{% if DIM == 1 %} + // Stable bubble sort on each row (dim=1). 
+ affine.for %row = 0 to {{ ROWS }} step 1 { + affine.for %pass = 0 to {{ COLS }} step 1 { + affine.for %j = 0 to {{ COLS_MINUS1 }} step 1 { + {{ kernel.def_dma_op("MVIN", OUT_DVAR, D1_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} + %lhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + + {{ kernel.def_dma_op("MVIN", OUT_DVAR, D1_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} + %rhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + +{% if DESCENDING %} + %need_swap = arith.cmpf olt, %lhs, %rhs : {{ YV_ELEM_TYPE }} +{% else %} + %need_swap = arith.cmpf ogt, %lhs, %rhs : {{ YV_ELEM_TYPE }} +{% endif %} + scf.if %need_swap { + memref.store %rhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D1_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + memref.store %lhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D1_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + {{ kernel.def_dma_op("MVIN", "YI", D1_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + %li = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + + {{ kernel.def_dma_op("MVIN", "YI", D1_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + %ri = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + + memref.store %ri, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", D1_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + memref.store %li, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", D1_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + } + } + } + } +{% else %} + // Stable bubble sort on each column (dim=0). 
+ affine.for %col = 0 to {{ COLS }} step 1 { + affine.for %pass = 0 to {{ ROWS }} step 1 { + affine.for %i = 0 to {{ ROWS_MINUS1 }} step 1 { + {{ kernel.def_dma_op("MVIN", OUT_DVAR, D0_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} + %lhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + + {{ kernel.def_dma_op("MVIN", OUT_DVAR, D0_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} + %rhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + +{% if DESCENDING %} + %need_swap = arith.cmpf olt, %lhs, %rhs : {{ YV_ELEM_TYPE }} +{% else %} + %need_swap = arith.cmpf ogt, %lhs, %rhs : {{ YV_ELEM_TYPE }} +{% endif %} + scf.if %need_swap { + memref.store %rhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D0_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + memref.store %lhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D0_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + {{ kernel.def_dma_op("MVIN", "YI", D0_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + %li = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + + {{ kernel.def_dma_op("MVIN", "YI", D0_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + %ri = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + + memref.store %ri, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", D0_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + memref.store %li, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", D0_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + } + } + } + } +{% endif %} + } { outer_loop=true } + return +} +""" + + +class 
MLIRSortTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, dim, descending=False, stable=False, indices_node=None, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.dim = dim + self.descending = descending + self.stable = stable + self.indices_node = indices_node + + def render( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + tile_info=None, + **kwargs, + ): + if template_buffer_node is not None: + self.output_node = template_buffer_node + if self.indices_node is None: + raise RuntimeError("MLIRSortTemplate requires indices output node") + + x = self.input_nodes[0] + yv = self.output_node + yi = self.indices_node + + def _as_int(v): + try: + return int(v) + except Exception: + return int(V.graph.sizevars.size_hint(v)) + + x_size = x.get_size() + if len(x_size) != 2: + raise RuntimeError("MLIRSortTemplate currently supports rank-2 input only") + if self.dim not in (0, 1): + raise RuntimeError(f"MLIRSortTemplate currently supports dim in {{0,1}} only, got dim={self.dim}") + + rows = _as_int(x_size[0]) + cols = _as_int(x_size[1]) + cols_minus1 = max(0, cols - 1) + rows_minus1 = max(0, rows - 1) + + x_dtype = x.get_dtype() + yv_dtype = yv.get_dtype() + yi_dtype = yi.get_dtype() + if x_dtype != yv_dtype: + raise RuntimeError("sort template requires input/value dtype match") + + yi_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + yi_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + yi_tile_desc.set_name("yi_sort_tile") + yv_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + yv_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + yv_tile_desc.set_name("yv_sort_tile") + # Neighbor element descriptors use DRAM offset to preserve affine stride metadata. 
+ yv_s1_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + yv_s1_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + yv_s1_tile_desc.set_name("yv_sort_tile") + yi_s1_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + yi_s1_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + yi_s1_tile_desc.set_name("yi_sort_tile") + if int(self.dim) == 1: + yv_s1_tile_desc.offset = sympy.Integer(1) + yi_s1_tile_desc.offset = sympy.Integer(1) + else: + yv_s1_tile_desc.offset = sympy.Integer(cols) + yi_s1_tile_desc.offset = sympy.Integer(cols) + + row = sympy.Symbol("row") + col = sympy.Symbol("col") + i = sympy.Symbol("i") + j = sympy.Symbol("j") + + init_x_idx = [row * cols, col] + init_yv_idx = [row * cols, col] + init_yi_idx = [row * cols, col] + + d1_s0_idx = [row * cols, j] + d1_s1_idx = [row * cols, j] + + d0_s0_idx = [i * cols, col] + d0_s1_idx = [i * cols, col] + + kernel.loop_size = None + numel = rows * cols + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X=x, + YV=yv, + YI=yi, + OUT_DVAR="YV", + NAMES_STR="X, YI, YV", + ROWS=rows, + COLS=cols, + COLS_MINUS1=cols_minus1, + ROWS_MINUS1=rows_minus1, + DIM=int(self.dim), + DESCENDING=bool(self.descending), + YI_TILE_DESC=yi_tile_desc, + YV_TILE_DESC=yv_tile_desc, + YI_S1_TILE_DESC=yi_s1_tile_desc, + YV_S1_TILE_DESC=yv_s1_tile_desc, + INIT_X_IDX=init_x_idx, + INIT_YV_IDX=init_yv_idx, + INIT_YI_IDX=init_yi_idx, + D1_S0_IDX=d1_s0_idx, + D1_S1_IDX=d1_s1_idx, + D0_S0_IDX=d0_s0_idx, + D0_S1_IDX=d0_s1_idx, + YV_ELEM_TYPE=mlir_common.DTYPE_TO_MLIR[yv_dtype], + YI_ELEM_TYPE=mlir_common.DTYPE_TO_MLIR[yi_dtype], + X_MEMREF_TYPE=f"memref<{numel}x{mlir_common.DTYPE_TO_MLIR[x_dtype]}>", + YV_MEMREF_TYPE=f"memref<{numel}x{mlir_common.DTYPE_TO_MLIR[yv_dtype]}>", + YI_MEMREF_TYPE=f"memref<{numel}x{mlir_common.DTYPE_TO_MLIR[yi_dtype]}>", + 
YV_TILE_MEMREF_TYPE=yv_tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[yv_dtype]), + YI_TILE_MEMREF_TYPE=yi_tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[yi_dtype]), + X_TILE_DESC=yv_tile_desc, + input_reorder=self.input_reorder, + ) + + output_node_name = yv.get_name() if hasattr(yv, "get_name") else yv.name + kernel.epilogue_info = dict( + output_node=output_node_name, + sram_var="yv_sort_tile", + dram_var=kernel.render_options["OUT_DVAR"], + dram_tile_desc=yv_tile_desc, + ) + kernel.exception_nodes[kernel.render_options["OUT_DVAR"]] = {"numel": yv.get_numel()} + kernel.exception_nodes["YI"] = {"numel": yi.get_numel()} + + code = self._template_from_string(TEMPLATE).render(**kernel.render_options) + return code diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index b1c756ba..9cc79e0a 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -14,7 +14,7 @@ from unittest.mock import patch from torch._inductor.codegen.common import KernelTemplate, CSE, DeferredLine -from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, ChoiceCaller +from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, ChoiceCaller, ir_node_to_tensor from torch._inductor.select_algorithm import PartialRender from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller from torch._inductor.autotune_process import TensorMeta @@ -112,7 +112,8 @@ def __init__(self, self.outer_func_name = outer_func_name self.outer_func_render = outer_func_render self.kernel_arg_attributes = kernel_arg_attributes - self.render_hooks = OrderedDict() + self.render_hooks = OrderedDict() # Stores {key: (priority, hook)} + self.dma_op_counter = itertools.count() # Add counter for unique DMA op keys self.buffer_names = dict() self.render_options = dict() self.tile_size = [] @@ -124,6 +125,7 @@ def __init__(self, self.epilogue_buffer_group = IndentedBufferGroup(self, prefix="epilogue_") 
self.global_vars = IndentedBuffer() self.exception_nodes = {} + self.epilogue_info = {} # Reduction data structure self.reduction_epilogue_suffix = IndentedBuffer() self.reduction_fusion = False @@ -460,11 +462,11 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ } node.codegen((vars, reduction_vars)) - # Codegen epilogue nodes - tile_desc = kernel.set_tile_size(kernel.epilogue_info) - kernel.kernel_group.set_tile_info(tile_desc) - kernel.call_ranges = None if epilogue_nodes: + # Codegen epilogue nodes + tile_desc = kernel.set_tile_size(kernel.epilogue_info) + kernel.kernel_group.set_tile_info(tile_desc) + kernel.call_ranges = None with kernel.epilogue_buffer_group.as_local(): _, (group, reduction_group) = max( epilogue_nodes, key=lambda x: int(x.is_reduction()) @@ -554,7 +556,7 @@ def template_store(): dram_var = self.epilogue_info["dram_var"] index_list = self.epilogue_info["dram_idx"] tile_desc = self.epilogue_info["dram_tile_desc"] - code = self.def_dma_op("MVOUT", dram_var, index_list, tile_desc) + code = self.def_dma_op("MVOUT", dram_var, index_list, tile_desc, lazy_mode=False) self.cse.generate(self.dma_stores, code, assignment = False) body = IndentedBuffer() @@ -625,14 +627,34 @@ def def_kernel( extra_node[node.get_name()] = node.node else: extra_node[node.get_name()] = node - self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] + + if 'sram_var' in self.epilogue_info: + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] def hook(): - arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) - return f"({', '.join(arg_defs)})" + arg_defs, call_args, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) + output_names = names[len(inputs) : len(inputs) + len(outputs)] + out_ptr_idx = 0 + renamed_arg_defs = [] + for outer, arg_def in zip(call_args, arg_defs): + raw_symbol = arg_def.split(":", 1)[0].strip().lstrip("%") + if outer in self.kernel_group.args.input_buffers: 
+ symbol = self.kernel_group.args.input_buffers[outer] + elif outer in self.kernel_group.args.output_buffers: + symbol = self.kernel_group.args.output_buffers[outer] + elif raw_symbol.startswith("out_ptr") and out_ptr_idx < len(output_names): + symbol = output_names[out_ptr_idx] + out_ptr_idx += 1 + elif outer in self.kernel_group.args.sizevars: + symbol = self.kernel_group.args.sizevars[outer] + else: + symbol = raw_symbol + _, arg_type = arg_def.split(":", 1) + renamed_arg_defs.append(f"%{symbol}:{arg_type}") + return f"({', '.join(renamed_arg_defs)})" assert "" not in self.render_hooks - self.render_hooks[""] = hook + self.render_hooks[""] = (5, hook) # Default priority 5 return "" # This function is a temporal function for convolution because currently convolution kernel is not considering padding. @@ -670,7 +692,8 @@ def def_conv_kernel( self.kernel_group.args.output_buffers[node.get_name()] = name self.store_buffer_names.add(node.get_name()) #TODO: Is this enough not calling store() in mlir_common.py? self.extra_node[node.get_name()] = node - self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] #TODO: Buffer name fixed + if 'sram_var' in self.epilogue_info: + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] #TODO: Buffer name fixed def kernel_hook(): arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=self.extra_node) @@ -678,7 +701,7 @@ def kernel_hook(): return f"({', '.join(arg_defs)})" assert "" not in self.render_hooks - self.render_hooks[""] = kernel_hook + self.render_hooks[""] = (5, kernel_hook) # Default priority 5 return "" # This function is for convolution wrapper function finalizing. 
@@ -689,7 +712,7 @@ def wrapper_hook(): return f"({', '.join(wrapper_arg_defs)})" if "" not in self.render_hooks: - self.render_hooks[""] = wrapper_hook + self.render_hooks[""] = (5, wrapper_hook) # Default priority 5 return "" def get_conv_inputs(self): @@ -698,15 +721,15 @@ def get_conv_inputs(self): def get_conv_outputs(self): return {k: v for k, v in self.kernel_group.args.output_buffers.items() if v != 'REMOVED'} - def load_input(self, indent_size: int = 0): + def load_input(self, indent_size: int = 0, priority: int = 1): def hook(): code = IndentedBuffer() prologue_code = self.codegen_prologue_body() if prologue_code.getvalue(): input_dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], - self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False, lazy_mode=False) weight_dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], self.prologue_info["weight_idx"], - self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False, lazy_mode=False) if (self.prologue_info["is_input_fused"]): code.splice(input_dma_code) code.splice(prologue_code) @@ -717,58 +740,63 @@ def hook(): code.splice(input_dma_code) else: dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], - self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False, lazy_mode=False) code.splice(dma_code) dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], 
self.prologue_info["weight_idx"], - self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False, lazy_mode=False) code.splice(dma_code) code = textwrap.indent(code.getvalue(), " "*indent_size).strip() return code assert "" not in self.render_hooks - self.render_hooks[""] = hook - self.render_hooks.move_to_end("", last=False) # Force order to be triggered first + self.render_hooks[""] = (priority, hook) return "" - def store_output(self, indent_size: int = 0): + def store_output(self, indent_size: int = 0, priority: int = 1): def hook(): epilogue_code = self.codegen_epilogue_body() return textwrap.indent(epilogue_code.getvalue(), " "*indent_size).strip() assert "" not in self.render_hooks - self.render_hooks[""] = hook - self.render_hooks.move_to_end("", last=False) # Force order to be triggered first + self.render_hooks[""] = (priority, hook) return "" - def reduction_output(self, indent_size: int = 0): + def reduction_output(self, indent_size: int = 0, priority: int = 5): def hook(): return textwrap.indent(self.reductions_suffix.getvalue(), " "*indent_size).strip() assert "" not in self.render_hooks - self.render_hooks[""] = hook + self.render_hooks[""] = (priority, hook) return "" + def _sort_hooks_by_priority(self): + """Sort hooks by priority (lower priority executes first).""" + sorted_hooks = OrderedDict() + for key, (priority, hook) in sorted(self.render_hooks.items(), key=lambda x: x[1][0]): + sorted_hooks[key] = hook + return sorted_hooks + def def_function(self): _, call_args, _, _ = self.kernel_group.args.python_argdefs() if self.outer_func_render is not None: partial_code, function_name = self.outer_func_render(input_args=call_args) + return PartialRender( partial_code, - self.render_hooks, + self._sort_hooks_by_priority(), ), function_name else: return None, None - def 
def_global_vars(self): + def def_global_vars(self, priority: int = 10): key = "" def hook(): return textwrap.indent(self.global_vars.getvalue(), "").strip() - assert key not in self.render_hooks - self.render_hooks[key] = hook + self.render_hooks[key] = (priority, hook) return key - def def_local_vars(self, indent_size=0): + def def_local_vars(self, indent_size=0, priority: int = 10): key = "" def hook(): code = IndentedBuffer() @@ -777,52 +805,74 @@ def hook(): code.splice(self.alloc_buffer) return textwrap.indent(code.getvalue(), " "*indent_size).strip() - assert key not in self.render_hooks - self.render_hooks[key] = hook + self.render_hooks[key] = (priority, hook) return key def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_common.MLIRMultiDimTile, - subtile_size:list=[], async_type=None, indent_size=0): - # Prepare code block - local_code = IndentedBuffer() - with self, self.override_buffer_cse(buffer=local_code, cse=self.apply_cse): - index_var = self.parse_index_list(index_list, offset=tile_desc.offset) - node_layout = self.named_nodes[dram_var].get_layout() - if dram_var in self.exception_nodes: - numel = self.exception_nodes[dram_var]["numel"] - else: - numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() - mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] - dram_shape = f"memref<{numel}x{mlir_dtype}>" - dram_stride = [] - for idx in index_list: - if idx.is_Mul: - dram_stride.append(int(idx.args[0])) - elif idx == sympy.Symbol("c0"): - dram_stride.append(0) - elif not idx.is_Number: - dram_stride.append(1) + subtile_size:list=[], async_type=None, indent_size=0, priority: int = 5, lazy_mode: bool = True, + dram_stride:list=None, dram_offset=None): + # Todo. 
Remove legacy behavior (i.e., index_list parsing) + def generate_dma_code(): + """Internal method to generate DMA code directly.""" + local_code = IndentedBuffer() + with self, self.override_buffer_cse(buffer=local_code, cse=self.apply_cse): + if dram_offset is not None: + # Use explicitly provided offset (pre-computed MLIR SSA variable name) + index_var = dram_offset + else: + index_var = self.parse_index_list(index_list, offset=tile_desc.offset) + node_layout = self.named_nodes[dram_var].get_layout() + if dram_var in self.exception_nodes: + numel = self.exception_nodes[dram_var]["numel"] else: - dram_stride.append(0) - - sram_var = tile_desc.get_name() - tile_shape = tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = tile_desc.get_tile_stride() - vlane_split_axis = tile_desc.vmap.vlane_split_axis - vlane_stride = tile_desc.vmap.vlane_stride - - zero_cse = self.get_const_cse(0, "index") - sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) - - attribute_parts = [f"dram_stride={dram_stride}", f"sram_stride={tile_stride}", "padding=0"] - if subtile_size: - attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") - attribute = " {" + ", ".join(attribute_parts) + "}" - code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, "") - local_code.writeline(code) - local_code.writeline(attribute) - return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() + numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() + mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] + dram_shape = f"memref<{numel}x{mlir_dtype}>" + + if dram_stride is not None: + # Use explicitly provided dram_stride + _dram_stride = dram_stride + else: + # Extract dram_stride from index_list (legacy behavior) + _dram_stride = [] + for idx in index_list: + if idx.is_Mul: + 
_dram_stride.append(int(idx.args[0])) + elif idx == sympy.Symbol("c0"): + _dram_stride.append(0) + elif not idx.is_Number: + _dram_stride.append(1) + else: + _dram_stride.append(0) + + sram_var = tile_desc.get_name() + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) + sram_strides = tile_desc.get_tile_stride() + vlane_split_axis = tile_desc.vmap.vlane_split_axis + vlane_stride = tile_desc.vmap.vlane_stride + + zero_cse = self.get_const_cse(0, "index") + sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) + + attribute_parts = [f"dram_stride={_dram_stride}", f"sram_stride={sram_strides}", "padding=0"] + if subtile_size: + attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") + attribute = " {" + ", ".join(attribute_parts) + "}" + code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, "") + local_code.writeline(code) + local_code.writeline(attribute) + return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() + + if not lazy_mode: + # Immediate mode: generate code directly and return it + return generate_dma_code() + + # Lazy mode: register hook and return key + dma_op_id = next(self.dma_op_counter) + key = f"" + self.render_hooks[key] = (priority, generate_dma_code) + return key def def_sram_buffer(self, dram_name, tile_desc, id=0, indent_size=0): # Prepare code block @@ -840,7 +890,7 @@ def render(self, template, kwargs, define_function=None): return PartialRender( code, - self.render_hooks, + self._sort_hooks_by_priority(), ) def get_spad_size_per_lane(self, tile_m, tile_n): @@ -1128,6 +1178,15 @@ def set_tile_size(self, template_fusion_info, prologue=False): return tile_desc class MLIRTemplateCaller(CUDATemplateCaller): + def __init__(self, name, category, input_nodes, layout, make_kernel_render, supports_epilogue_fusion, template, info_kwargs, description): + bmreq = 
MLIRBenchmarkRequest( + kernel_name=name, + input_tensor_meta=list(), + output_tensor_meta=list(), + extra_args=[], + source_code="", + ) + super().__init__(name, category, input_nodes, layout, make_kernel_render, bmreq, supports_epilogue_fusion, template, info_kwargs, description) def __str__(self): return f"MLIRTemplateCaller(source_file={self.bmreq.source_file})" @@ -1151,8 +1210,14 @@ def __init__(self, name, input_nodes, layout, input_reorder = None): super().__init__(name) self.input_nodes = [node for node in input_nodes if node is not None] self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + # Multi-output templates can override this with explicit output buffers. + self.output_nodes = [self.output_node] self.input_reorder = input_reorder self.layout = layout + # Fusion support flags (default to False) + self.support_epilogue_fusion = False + self.support_prologue_fusion = False + self.support_reduction_fusion = False def generate(self, **kwargs) -> ChoiceCaller: kernel_name = f"mlir_{self.name}" @@ -1164,15 +1229,8 @@ def generate(self, **kwargs) -> ChoiceCaller: code = self.render(kernel=kernel, **kwargs) kernel_hash_name = f"mlir_{self.name}_{next(self.index_counter)}" - extra_args = [] # create the BenchmarkRequest - bmreq = MLIRBenchmarkRequest( - kernel_name=kernel_name, - input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), - output_tensor_meta=TensorMeta.from_irnodes(self.output_node), - extra_args=extra_args, - source_code=code, - ) + output_nodes = getattr(self, "output_nodes", None) or [self.output_node] def make_kernel_render( template_node: TemplateBuffer, @@ -1214,7 +1272,6 @@ def make_kernel_render( self.input_nodes, self.output_node.get_layout(), make_kernel_render, - bmreq, False, # supports_epilogue_fusion self, kwargs, diff --git a/TOGSim/src/DMA.cc b/TOGSim/src/DMA.cc index f8f21025..fefee6d2 100644 --- a/TOGSim/src/DMA.cc +++ b/TOGSim/src/DMA.cc @@ -12,7 +12,7 @@ void DMA::issue_tile(std::shared_ptr inst) { 
_current_inst = std::move(inst); std::vector& tile_size = _current_inst->get_tile_size(); if (tile_size.size() <= 0 || tile_size.size() > get_max_dim()) { - spdlog::error("[DMA {}] issued tile is not supported format..", _id); + spdlog::error("[DMA {}] issued tile is not supported format.. tile.size: {}, tile_size: [{}]", _id, tile_size.size(), fmt::join(tile_size, ", ")); exit(EXIT_FAILURE); } _finished = false; diff --git a/TOGSim/src/helper/CommandLineParser.cc b/TOGSim/src/helper/CommandLineParser.cc index 66aebbe1..9cd177ac 100644 --- a/TOGSim/src/helper/CommandLineParser.cc +++ b/TOGSim/src/helper/CommandLineParser.cc @@ -12,9 +12,13 @@ void CommandLineParser::parse(int argc, char **argv) noexcept(false) { po::notify(variables_map); } +void CommandLineParser::print_help_message() const noexcept { + std::cout << options_description << std::endl; +} + void CommandLineParser::print_help_message_if_required() const noexcept { if (variables_map.count("help") > 0) { - std::cout << options_description << std::endl; + print_help_message(); exit(0); } } diff --git a/TOGSim/src/helper/CommandLineParser.h b/TOGSim/src/helper/CommandLineParser.h index 39174d5d..b41eabf3 100644 --- a/TOGSim/src/helper/CommandLineParser.h +++ b/TOGSim/src/helper/CommandLineParser.h @@ -19,7 +19,7 @@ class CommandLineParser { * Command Line Parser constructor */ CommandLineParser() noexcept { - options_description.add_options()("help", "Prints help message"); + options_description.add_options()("help,h", "Prints help message"); } /** @@ -38,6 +38,12 @@ class CommandLineParser { */ void print_help_message_if_required() const noexcept; + /** + * Prints the help message. + * (Can be called to show help for invalid options) + */ + void print_help_message() const noexcept; + /** * Add a new command line argument option. 
* (Should be called before `parse` method is called) diff --git a/TOGSim/src/main.cc b/TOGSim/src/main.cc index 7c596af5..cda8f986 100644 --- a/TOGSim/src/main.cc +++ b/TOGSim/src/main.cc @@ -96,19 +96,24 @@ int main(int argc, char** argv) { // parse command line argumnet CommandLineParser cmd_parser = CommandLineParser(); cmd_parser.add_command_line_option( - "config", "Path for hardware configuration file"); + "config", "Path for hardware configuration file (.yml)"); cmd_parser.add_command_line_option( - "models_list", "Path for the models list file (can be FIFO or regular file)"); + "models_list", "Path for the trace file (.trace)"); cmd_parser.add_command_line_option( "log_level", "Set for log level [trace, debug, info], default = info"); try { cmd_parser.parse(argc, argv); } catch (const CommandLineParser::ParsingError& e) { spdlog::error( - "Command line argument parrsing error captured. Error message: {}", + "Command line argument parsing error captured. Error message: {}", e.what()); - throw(e); + std::cerr << std::endl; + cmd_parser.print_help_message(); + exit(1); } + + // Check if help was requested + cmd_parser.print_help_message_if_required(); std::string level = "info"; cmd_parser.set_if_defined("log_level", &level); diff --git a/tests/DeepSeek/test_deepseek_v3_base.py b/tests/DeepSeek/test_deepseek_v3_base.py index b8402c8b..ade787c5 100644 --- a/tests/DeepSeek/test_deepseek_v3_base.py +++ b/tests/DeepSeek/test_deepseek_v3_base.py @@ -1,8 +1,55 @@ import os import sys import argparse +import copy +from pathlib import Path import torch +# recursive compile for some ops that are caused by graph break +torch.npu.register_eager_to_compile([ + "aten::zero_", + "aten::sum.IntList_out", + "aten::mul.out", + "aten::floor_divide", + "aten::floor_divide.Tensor", + "aten::floor_divide.Scalar", + "aten::cat.out", + "aten::sort.values_stable", +]) + + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + out_cpu = out.cpu() + max_diff = (out_cpu - 
cpu_out).abs().max().item() + mean_diff = (out_cpu - cpu_out).abs().mean().item() + if torch.allclose(out_cpu, cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print(f"Max absolute difference: {max_diff:.6f}") + print(f"Mean absolute difference: {mean_diff:.6f}") + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("NPU out: ", out_cpu) + print("CPU out: ", cpu_out) + print(f"Max absolute difference: {max_diff:.6f}") + print(f"Mean absolute difference: {mean_diff:.6f}") + exit(1) + + +def _extract_logits(output): + if isinstance(output, torch.Tensor): + return output + if hasattr(output, "logits"): + return output.logits + if isinstance(output, (list, tuple)) and len(output) > 0 and isinstance(output[0], torch.Tensor): + return output[0] + raise TypeError(f"Unsupported output type for comparison: {type(output)}") + def _dtype_from_str(name: str) -> torch.dtype: return { @@ -81,7 +128,7 @@ def _maybe_scale_config(config, scale=1.0, max_layers=None): def _apply_preset(scale, max_layers, batch, seq_len, preset): if preset == "tiny": - return 0.03, 4, 1, min(seq_len, 16) + return 0.03, 1, 1, min(seq_len, 16) if preset == "small": return 0.07, 8, 1, min(seq_len, 32) if preset == "medium": @@ -89,8 +136,58 @@ def _apply_preset(scale, max_layers, batch, seq_len, preset): return scale, max_layers, batch, seq_len +def _togsim_log_count() -> int: + log_dir = Path("togsim_results") + if not log_dir.exists(): + return 0 + return len(list(log_dir.glob("*.log"))) + + +def _assert_simulation_happened(before_count: int, case_name: str): + after_count = _togsim_log_count() + if after_count <= before_count: + raise RuntimeError( + f"{case_name}: TOGSim log count did not increase " + f"(before={before_count}, after={after_count})" + ) + print(f"{case_name}: TOGSim logs increased ({before_count} -> {after_count})") + + 
+def test_cat_default(device): + def cat_default_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_default_fn) + + before = _togsim_log_count() + out = opt_fn(x, y) + _assert_simulation_happened(before, "cat.default") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + test_result("cat.default", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_out(device): + def cat_out_fn(a, b, out): + return torch.ops.aten.cat.out([a, b], 0, out=out) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + out_buf = torch.empty(14, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_out_fn) + + before = _togsim_log_count() + out = opt_fn(x, y, out_buf) + _assert_simulation_happened(before, "cat.out") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + test_result("cat.out", out, cpu_out, rtol=1e-4, atol=1e-4) + + @torch.no_grad() -def run_deep_seek_v3_base_test( +def run_deepseek_v3_base( model_id, device, init_mode="config-random", @@ -120,7 +217,6 @@ def run_deep_seek_v3_base_test( # (call .to_dict()), so only disable it for pretrained loading path. 
if init_mode == "pretrained" and getattr(config, "quantization_config", None) is not None: config.quantization_config = None - config = _maybe_scale_config(config, scale=scale, max_layers=max_layers) if init_mode == "config-random": @@ -141,7 +237,6 @@ def run_deep_seek_v3_base_test( else: raise ValueError(f"Unsupported init mode: {init_mode}") - model = model.to(device) model_params = sum(p.numel() for p in model.parameters()) print("init mode:", init_mode) print("scaled hidden_size:", getattr(config, "hidden_size", "n/a")) @@ -157,23 +252,33 @@ def run_deep_seek_v3_base_test( revision=revision, ) encoded = tokenizer(prompt, return_tensors="pt") - input_ids = encoded["input_ids"].to(device) + cpu_input_ids = encoded["input_ids"].cpu() else: vocab_size = getattr(config, "vocab_size", None) if vocab_size is None: raise ValueError("Config has no vocab_size; use --use-tokenizer or pass a model with vocab_size.") - input_ids = _build_random_inputs(batch, seq_len, vocab_size, device) + cpu_input_ids = _build_random_inputs(batch, seq_len, vocab_size, torch.device("cpu")) + input_ids = cpu_input_ids.to(device) - if compile_model: - model = torch.compile(model, dynamic=False) + # CPU version + model_cpu = copy.deepcopy(model).cpu().eval() + cpu_out = _extract_logits(model_cpu(cpu_input_ids)) - out = model(input_ids) - logits = out.logits + # NPU version + model_npu = copy.deepcopy(model_cpu).to(device).eval() + if compile_model: + model_npu = torch.compile(model_npu, dynamic=False) + npu_out = _extract_logits(model_npu(input_ids)) + + # Campare results + test_result( + "DeepSeek V3 Base", + npu_out, + cpu_out, + rtol=3e-1, + atol=2e-1, + ) - print("logits shape:", tuple(logits.shape)) - print("logits dtype:", logits.dtype) - print("logits max:", logits.max().item()) - if __name__ == "__main__": parser = argparse.ArgumentParser(description="DeepSeek V3 download-based test") @@ -181,7 +286,7 @@ def run_deep_seek_v3_base_test( parser.add_argument("--revision", type=str, 
default=None) parser.add_argument("--trust-remote-code", action="store_true", default=True) parser.add_argument("--init-mode", type=str, default="config-random", choices=["config-random", "pretrained"]) - parser.add_argument("--preset", type=str, default="tiny", choices=["none", "tiny", "small", "medium"]) + parser.add_argument("--preset", type=str, default="small", choices=["none", "tiny", "small", "medium"]) parser.add_argument("--scale", type=float, default=1.0) parser.add_argument("--max-layers", type=int, default=None) parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) @@ -190,6 +295,7 @@ def run_deep_seek_v3_base_test( parser.add_argument("--use-tokenizer", action="store_true") parser.add_argument("--prompt", type=str, default="Hello, DeepSeek V3") parser.add_argument("--compile", action="store_true", default=True) + parser.add_argument("--test", type=str, default="e2e", choices=["all", "e2e", "cat"]) args = parser.parse_args() @@ -203,18 +309,22 @@ def run_deep_seek_v3_base_test( device = torch.device("npu:0") - run_deep_seek_v3_base_test( - model_id=args.model_id, - device=device, - init_mode=args.init_mode, - scale=args.scale, - max_layers=args.max_layers, - dtype=args.dtype, - batch=args.batch, - seq_len=args.seq_len, - use_tokenizer=args.use_tokenizer, - prompt=args.prompt, - trust_remote_code=args.trust_remote_code, - revision=args.revision, - compile_model=args.compile, - ) + if args.test in ("all", "cat"): + test_cat_default(device) + test_cat_out(device) + if args.test in ("all", "e2e"): + run_deepseek_v3_base( + model_id=args.model_id, + device=device, + init_mode=args.init_mode, + scale=args.scale, + max_layers=args.max_layers, + dtype=args.dtype, + batch=args.batch, + seq_len=args.seq_len, + use_tokenizer=args.use_tokenizer, + prompt=args.prompt, + trust_remote_code=args.trust_remote_code, + revision=args.revision, + compile_model=args.compile, + ) diff --git a/tests/test_cat.py 
b/tests/test_cat.py new file mode 100644 index 00000000..97fcc754 --- /dev/null +++ b/tests/test_cat.py @@ -0,0 +1,200 @@ +import argparse +from pathlib import Path + +import torch + + +def _test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + return + + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + raise RuntimeError(f"{name} mismatch") + +def test_cat_default(device): + def cat_default_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_default_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.default", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_out(device): + def cat_out_fn(a, b, out): + return torch.ops.aten.cat.out([a, b], 0, out=out) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + out_buf = torch.empty(14, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_out_fn) + + out = opt_fn(x, y, out_buf) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.out", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim0(device): + def cat_4d_dim0_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(3, 3, 4, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim0_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.4d.dim0", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim1(device): + def cat_4d_dim1_fn(a, b): + return torch.cat([a, b], dim=1) + + x = torch.randn(2, 3, 4, 5, device=device) + y = 
torch.randn(2, 5, 4, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim1_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=1) + _test_result("cat.4d.dim1", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim2(device): + def cat_4d_dim2_fn(a, b): + return torch.cat([a, b], dim=2) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 3, 6, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim2_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=2) + _test_result("cat.4d.dim2", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim3(device): + def cat_4d_dim3_fn(a, b): + return torch.cat([a, b], dim=3) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 3, 4, 7, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim3_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=3) + _test_result("cat.4d.dim3", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_three_inputs(device): + def cat_three_inputs_fn(a, b, c): + return torch.cat([a, b, c], dim=0) + + x = torch.randn(4, 16, device=device) + y = torch.randn(5, 16, device=device) + z = torch.randn(3, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_three_inputs_fn) + + out = opt_fn(x, y, z) + + cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu()], dim=0) + _test_result("cat.three_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_four_inputs(device): + def cat_four_inputs_fn(a, b, c, d): + return torch.cat([a, b, c, d], dim=0) + + x = torch.randn(3, 16, device=device) + y = torch.randn(4, 16, device=device) + z = torch.randn(5, 16, device=device) + w = torch.randn(2, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_four_inputs_fn) + + out = opt_fn(x, y, z, w) + + cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu(), w.cpu()], dim=0) + _test_result("cat.four_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def 
test_cat_4d_three_inputs(device):
+    def cat_4d_three_inputs_fn(a, b, c):
+        return torch.cat([a, b, c], dim=1)
+
+    x = torch.randn(2, 3, 4, 5, device=device)
+    y = torch.randn(2, 4, 4, 5, device=device)
+    z = torch.randn(2, 5, 4, 5, device=device)
+    opt_fn = torch.compile(dynamic=False)(cat_4d_three_inputs_fn)
+
+    out = opt_fn(x, y, z)
+
+    cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu()], dim=1)
+    _test_result("cat.4d.three_inputs", out, cpu_out, rtol=1e-4, atol=1e-4)
+
+def test_cat_5d(device, dim=0):
+    def cat_5d_fn(a, b):
+        return torch.cat([a, b], dim=dim)
+
+    x = torch.randn(2, 3, 4, 5, 6, device=device)
+    y = torch.randn(3, 3, 4, 5, 6, device=device)
+    opt_fn = torch.compile(dynamic=False)(cat_5d_fn)
+
+    out = opt_fn(x, y)
+
+    cpu_out = torch.cat([x.cpu(), y.cpu()], dim=dim)
+    _test_result("cat.5d.dim0", out, cpu_out, rtol=1e-4, atol=1e-4)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run cat simulation tests")
+    parser.add_argument(
+        "--case",
+        choices=[
+            "default", "out", "4d_dim0", "4d_dim1", "4d_dim2", "4d_dim3", "5d",
+            "three_inputs", "four_inputs", "4d_three_inputs", "all"
+        ],
+        default="all",
+        help="Which cat case to run",
+    )
+    args = parser.parse_args()
+
+    device = torch.device("npu:0")
+
+    if args.case in ("default", "all"):
+        test_cat_default(device)
+    if args.case in ("out", "all"):
+        test_cat_out(device)
+    if args.case in ("4d_dim0", "all"):
+        test_cat_4d_dim0(device)
+    if args.case in ("4d_dim1", "all"):
+        test_cat_4d_dim1(device)
+    if args.case in ("4d_dim2", "all"):
+        test_cat_4d_dim2(device)
+    if args.case in ("4d_dim3", "all"):
+        test_cat_4d_dim3(device)
+    if args.case in ("three_inputs", "all"):
+        test_cat_three_inputs(device)
+    if args.case in ("four_inputs", "all"):
+        test_cat_four_inputs(device)
+    if args.case in ("4d_three_inputs", "all"):
+        test_cat_4d_three_inputs(device)
+    if args.case in ("5d", "all"):
+        test_cat_5d(device)
diff --git a/tests/test_sort.py b/tests/test_sort.py
new file mode 100644
index 00000000..2b070223 --- /dev/null +++ b/tests/test_sort.py @@ -0,0 +1,112 @@ +import argparse +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out:", out.cpu()) + print("cpu out:", cpu_out) + raise SystemExit(1) + + +def test_equal(name, out, cpu_out): + if torch.equal(out.cpu(), cpu_out): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out:", out.cpu()) + print("cpu out:", cpu_out) + raise SystemExit(1) + + +def _normalize_dim(dim: int, rank: int) -> int: + d = dim if dim >= 0 else rank + dim + if d < 0 or d >= rank: + raise ValueError(f"dim out of range: dim={dim}, rank={rank}") + return d + + +def test_sort_stable(device, size=(128, 128), dim=-1, descending=False): + _normalize_dim(dim, len(size)) + + def sort_stable_fn(x): + return torch.sort(x, stable=True, dim=dim, descending=descending) + + x = torch.randn(size, dtype=torch.float32) + x_npu = x.to(device=device) + + opt_sort = torch.compile(dynamic=False)(sort_stable_fn) + out_values, out_indices = opt_sort(x_npu) + + ref_values, ref_indices = torch.sort(x, stable=True, dim=dim, descending=descending) + + test_result("Sort.stable/values", out_values, ref_values) + test_equal("Sort.stable/indices", out_indices, ref_indices) + + +def test_sort_values_stable(device, size=(128, 128), dim=-1, descending=False): + _normalize_dim(dim, len(size)) + + def sort_out_fn(x): + out_values = torch.empty_like(x, device=x.device) + out_indices = 
torch.empty_like(x, dtype=torch.int64, device=x.device)
+        return torch.sort(x, stable=True, dim=dim, descending=descending, out=(out_values, out_indices))
+
+    x = torch.randn(size, dtype=torch.float32)
+    x_npu = x.to(device=device)
+
+    opt_sort = torch.compile(dynamic=False)(sort_out_fn)
+    out_values, out_indices = opt_sort(x_npu)
+
+    ref_values, ref_indices = torch.sort(x, stable=True, dim=dim, descending=descending)
+
+    test_result("Sort.values_stable/values", out_values, ref_values)
+    test_equal("Sort.values_stable/indices", out_indices, ref_indices)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run sort tests")
+    parser.add_argument("--shape", type=str, default="(128,128)")
+    parser.add_argument("--dim", type=int, default=0)
+    parser.add_argument("--descending", action="store_true")
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="all",
+        choices=["all", "default", "values"],
+    )
+    args = parser.parse_args()
+
+    shape = tuple(map(int, args.shape.strip("()").split(",")))
+
+    from Scheduler.scheduler import PyTorchSimRunner
+
+    module = PyTorchSimRunner.setup_device()
+    device = module.custom_device()
+
+    # Register recursive-compile bridge only when values_stable path is explicitly tested.
+    if args.mode in ("all", "values"):
+        torch.npu.register_eager_to_compile([
+            "aten::sort.values_stable",
+        ])
+
+    if args.mode in ("all", "default"):
+        test_sort_stable(device, size=shape, dim=args.dim, descending=args.descending)
+    if args.mode in ("all", "values"):
+        test_sort_values_stable(device, size=shape, dim=args.dim, descending=args.descending)