@@ -15,7 +15,7 @@
     is_welford_reduction,
     sympy_product
 )
-from torch.utils._sympy.functions import ModularIndexing
+from torch.utils._sympy.functions import ModularIndexing, FloorDiv
 import PyTorchSimFrontend.extension_codecache as extension_codecache
 
 from PyTorchSimFrontend import extension_config
@@ -260,10 +260,10 @@ def binary_elementwise_common(operand1, operand2, var_info):
         operand2 = ops.to_dtype(operand2, op_type1[1], var_info)
         op_type2 = var_info[operand2]
     elif op_type1[1][0] == op_type2[1][0]:
-        if int(op_type1[1][1:]) > int(op_type2[1][1:]):
+        if mlir_common.MLIR_TO_BIT[op_type1[1]] > mlir_common.MLIR_TO_BIT[op_type2[1]]:
            operand2 = ops.ext(operand2, op_type1[1])
            op_type2 = var_info[operand2]
-        elif int(op_type1[1][1:]) < int(op_type2[1][1:]):
+        elif mlir_common.MLIR_TO_BIT[op_type1[1]] < mlir_common.MLIR_TO_BIT[op_type2[1]]:
            operand1 = ops.ext(operand1, op_type2[1])
            op_type1 = var_info[operand1]
        else:
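
Note: the old comparison sliced the bit width out of the dtype string (`int("i32"[1:])`), which breaks on any type whose name is not one letter plus a width, e.g. `bf16`. A lookup table avoids the parsing entirely. Below is a minimal sketch of the shape `mlir_common.MLIR_TO_BIT` presumably has; the exact contents are an assumption based on its usage in this diff:

```python
# Hypothetical sketch -- the real table lives in mlir_common.
MLIR_TO_BIT = {
    "i1": 1, "i8": 8, "i16": 16, "i32": 32, "i64": 64,
    "f16": 16, "bf16": 16, "f32": 32, "f64": 64,
}

print(int("i32"[1:]))        # 32: slicing happens to work here
try:
    int("bf16"[1:])          # but "f16" is not an int
except ValueError as e:
    print(e)
print(MLIR_TO_BIT["bf16"])   # 16: the lookup handles every listed type
```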
@@ -348,17 +348,21 @@ def maximum(operand1, operand2, *args, var_info=None, **kwargs):
     @staticmethod
     def to_dtype(operand, dst_mlir_dtype, *args, var_info=None, **kwargs):
         src_mlir_dtype = var_info[operand][1]
+        if src_mlir_dtype == "index":
+            operand = ops.index_cast(operand, "i64", var_info=var_info)
+            src_mlir_dtype = var_info[operand][1]
+
         tile_size = var_info[operand][0]
         if isinstance(dst_mlir_dtype, torch.dtype):
             dst_mlir_dtype = mlir_common.DTYPE_TO_MLIR[dst_mlir_dtype]
-        dst_bits = int(dst_mlir_dtype[1:])
-        src_bits = int(src_mlir_dtype[1:])
+        dst_bits = mlir_common.MLIR_TO_BIT[dst_mlir_dtype]
+        src_bits = mlir_common.MLIR_TO_BIT[src_mlir_dtype]
         shape = f"vector<{tile_size}x{dst_mlir_dtype}>" if tile_size > 1 else dst_mlir_dtype
         src_shape = f"vector<{tile_size}x{src_mlir_dtype}>" if tile_size > 1 else src_mlir_dtype
         if dst_mlir_dtype[0] == "i" and src_mlir_dtype[0] == "f":
-            return f"arith.fptoui%{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype]
+            return f"arith.fptoui %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype]
         if dst_mlir_dtype[0] == "f" and src_mlir_dtype[0] == "i":
-            return f"arith.uitofp%{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype]
+            return f"arith.uitofp %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype]
         if dst_mlir_dtype[0] == "i":
             if dst_bits > src_bits:
                 return f"arith.extui %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype]
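
Note: two fixes land here. `index` operands carry no parsable bit width, so they are first bridged to `i64` (presumably via `arith.index_cast`) before the usual int/float conversion logic runs; and the conversion f-strings gain the space MLIR requires between the op name and its operand. Illustrative output for an index-to-f32 cast on a 16-lane tile (SSA names are made up):

```python
operand, tile = 3, 16
# Bridge step added by this diff (assuming ops.index_cast emits arith.index_cast):
print(f"arith.index_cast %{operand} : vector<{tile}xindex> to vector<{tile}xi64>")
# Conversion step; before the fix this came out as "arith.uitofp%4 : ...",
# which MLIR rejects as an unknown operation name.
print(f"arith.uitofp %{operand + 1} : vector<{tile}xi64> to vector<{tile}xf32>")
```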
@@ -955,6 +959,8 @@ def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> com
         # Extract index var
         indirect_args = [f"%{i}" for i in indirect_dims]
         expr_str = str(expr)
+        if "//" in expr_str:
+            expr_str = expr_str.replace("//", " floordiv ")
         args = ", ".join(map(str, indices))
         map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[{','.join(indirect_dims)}] -> ({expr_str})>")
         args = ", ".join([f"%{i}" for i in indices])
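
Note: PyTorch's sympy `FloorDiv` prints with Python's `//`, but MLIR affine expressions spell floor division `floordiv`, so the printed expression has to be rewritten before it is pasted into an `affine_map` body. A small reproduction (the variable name and exact sympy printing are illustrative):

```python
import sympy
from torch.utils._sympy.functions import FloorDiv

index0 = sympy.Symbol("index0")
expr_str = str(FloorDiv(index0, 4) + 2)   # printed with Python's "//"
if "//" in expr_str:
    expr_str = expr_str.replace("//", " floordiv ")
print(f"affine_map<(index0) -> ({expr_str})>")
# e.g. affine_map<(index0) -> (index0 floordiv 4 + 2)>
```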
@@ -1063,6 +1069,9 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs):
         vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype)
         compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size()
         require_store = True
+        if compute_vec_size < self.var_info[value][0]:
+            value = self.cse.generate(self.stores, f"vector.extract_strided_slice %{value} {{offsets = [0], sizes = [{compute_vec_size}], strides = [1]}}: vector<{self.var_info[value][0]}x{self.var_info[value][1]}> to {vshape}")
+            self.register_var_info(value, [compute_vec_size, mlir_dtype])
 
         if str(value) in self.spad_buffer_dict:
             # Todo. If tile_size is not same (i.e., view operation), we can't apply peephole optimization easily
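
Note: when the cached compute vector is wider than the store width, a prefix slice is cut out with `vector.extract_strided_slice` before storing. What the generated op looks like for a 32-wide value narrowed to a 16-wide store (the value name, sizes, and dtype are made up):

```python
value, src_size, dtype, compute_vec_size = "v12", 32, "f32", 16
vshape = f"vector<{compute_vec_size}x{dtype}>"
print(
    f"vector.extract_strided_slice %{value} "
    f"{{offsets = [0], sizes = [{compute_vec_size}], strides = [1]}}"
    f": vector<{src_size}x{dtype}> to {vshape}"
)
# vector.extract_strided_slice %v12 {offsets = [0], sizes = [16], strides = [1]}: vector<32xf32> to vector<16xf32>
```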
@@ -1680,6 +1689,40 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
         sorted_keys = sorted(dram_dict.keys())
         dram_stride = sum((dram_dict[key] for key in sorted_keys), [])
 
+        # Support floordiv pattern
+        # FIXME. How to integrate implicit dims and floordiv?
+        # This was introduced to support GroupNorm
+        if index.has(FloorDiv) and not index.has(ModularIndexing):
+            dim_divisor = [1] * len(local_dims)
+            for sub in sympy.preorder_traversal(index):
+                if isinstance(sub, FloorDiv):
+                    if not str(sub.args[0]).startswith("index"):
+                        continue
+                    dim_idx = int(str(sub.args[0])[5:])
+                    if int(self.kernel_group.tile_desc.get_tile_size()[dim_idx] % sub.args[1]) != 0:
+                        # In this case, need to recompile
+                        original_size = self.kernel_group.tile_desc.get_tile_size()[dim_idx]
+                        divisor = sub.args[1]
+                        new_size = ((original_size + divisor - 1) // divisor) * divisor
+                        new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size())
+                        new_tile_sizes[dim_idx] = new_size
+                        self.kernel_group.tile_desc.set_tile_size(new_tile_sizes)
+
+                        # Send recompile signal
+                        self.reset("recompile")
+                        raise mlir_common.RecompileSignal(f"Tile size {original_size} is not divisible by {divisor}")
+                    dim_divisor[dim_idx] = sub.args[1]
+
+            # Update dram_stride: insert a 0 stride right after each target dim
+            offset = 0
+            for dim_idx, divisor in enumerate(dim_divisor):
+                if divisor == 1:
+                    continue
+                dram_stride.insert(dim_idx + offset + 1, 0)
+                local_tile_desc.apply_divisor(dim_idx + offset, divisor, "pad")
+                local_tile_desc.apply_divisor(dim_idx + offset, divisor, "split")
+                offset = offset + 1
+
         # FIXME. It will be nice to modify node instead of this exception handling...
         if len(self.itervars) == 1 and self.reduction_depth == 0:
             # In case of reduction loop only case, we will add dummy loop so shift it once
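
Note: the `RecompileSignal` message previously read the tile size back after `set_tile_size`, so it reported the already-rounded (divisible) size; it now reports the offending original size. A worked example of the two paths through the new floordiv handling, with made-up numbers for a GroupNorm-style index such as `index1 // 8`:

```python
# 1) First compile: dim 1 has tile size 12, which 8 does not divide, so the
#    size is rounded up to the next multiple of 8 and a recompile is signaled.
original_size, divisor = 12, 8
new_size = ((original_size + divisor - 1) // divisor) * divisor
assert new_size == 16 and new_size % divisor == 0

# 2) On the retry every floordiv'd dim divides evenly; the dim is padded and
#    split, and a stride-0 entry is inserted right after it (presumably so the
#    split-off dim revisits the same DRAM addresses instead of advancing).
dram_stride = [4096, 64, 1]   # illustrative strides, one per dim
dim_divisor = [1, 8, 1]       # dim 1 is addressed through index1 // 8
offset = 0
for dim_idx, d in enumerate(dim_divisor):
    if d == 1:
        continue
    dram_stride.insert(dim_idx + offset + 1, 0)
    offset += 1
print(dram_stride)            # [4096, 64, 0, 1]
```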
@@ -1810,7 +1853,11 @@ def convert_indirect_indexing(self, index :sympy.Expr):
 
         # Load indirect operands
         for target_dim in indirect_dims:
-            sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim]
+            if target_dim in self.spad_buffer_dict:
+                sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim]
+            else:
+                raise NotImplementedError("TODO.")
+
             mlir_dtype = vshape.split("x")[1][:-1]
             vshape = f"vector<{tile_numel_per_lane}x{mlir_dtype}>" # FIXME. Maybe require fine grain compute...
             if tile_numel_per_lane > 1: