Skip to content

Commit ccf12db

Browse files
committed
[Frontend] Support indexed expr + indirect access case
1 parent 8cbfd98 commit ccf12db

9 files changed

Lines changed: 389 additions & 54 deletions

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 79 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1007,6 +1007,8 @@ def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> com
10071007

10081008
# Extract index var
10091009
indirect_args = [f"%{i}" for i in indirect_dims]
1010+
if len(indirect_args):
1011+
comments = "{indirect_access} " + comments # Add indirect access attribute
10101012
expr_str = str(expr)
10111013
if "//" in expr_str:
10121014
expr_str = expr_str.replace("//", " floordiv ")
@@ -1057,17 +1059,27 @@ def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0))
10571059

10581060
def load(self, name: str, index: sympy.Expr):
10591061
index = self.rename_indexing(index)
1060-
index = self.convert_indirect_indexing(index)
1062+
index, comptute_depedency = self.convert_indirect_indexing(index)
10611063
padding = self.get_padding_type()
10621064

1065+
# In case of special form of indirect access, we need to put load in dma_store buffer
1066+
if comptute_depedency:
1067+
apply_buffer = self.dma_stores
1068+
dma_buffer = self.dma_stores
1069+
load_buffer = self.dma_stores
1070+
else:
1071+
apply_buffer = None
1072+
dma_buffer = self.dma_loads
1073+
load_buffer = self.loads
1074+
10631075
# Extract dram info
10641076
dram_var = self.kernel_group.args.input(name)
10651077
dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name])
10661078
dtype = V.graph.get_dtype(name)
10671079
mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype]
10681080

10691081
# Extract sram info
1070-
local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index)
1082+
local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index, buffer=apply_buffer)
10711083
vlane_split_axis = local_tile_desc.vlane_split_axis
10721084
vlane_stride = local_tile_desc.vlane_stride
10731085
tile_numel_per_lane = local_tile_desc.get_numel_per_lane()
@@ -1085,19 +1097,27 @@ def load(self, name: str, index: sympy.Expr):
10851097
attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding={padding}}}"
10861098
code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var,
10871099
dram_shape, tile_shape, attribute)
1088-
self.cse.generate(self.dma_loads, code, assignment = False) # FIXME: assignment = False does not support caching
1089-
compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"])
1090-
# Generate vector load instruction
1091-
if compute_vec_size > 1:
1092-
operation = "affine.vector_load"
1093-
line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}"
1100+
self.cse.generate(dma_buffer, code, assignment = False) # FIXME: assignment = False does not support caching
1101+
1102+
if not comptute_depedency:
1103+
compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"])
1104+
# Generate vector load instruction
1105+
if compute_vec_size > 1:
1106+
operation = "affine.vector_load"
1107+
line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}"
1108+
else:
1109+
operation = "affine.load"
1110+
line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}"
1111+
1112+
out = self.cse.generate(load_buffer, line)
1113+
self.register_var_info(out, [compute_vec_size, mlir_dtype])
1114+
self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape]
1115+
return out
10941116
else:
1095-
operation = "affine.load"
1096-
line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}"
1097-
out = self.cse.generate(self.loads, line)
1098-
self.register_var_info(out, [compute_vec_size, mlir_dtype])
1099-
self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape]
1100-
return out
1117+
out = sram_var
1118+
self.register_var_info(out, [compute_vec_size, mlir_dtype])
1119+
self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape]
1120+
return out
11011121

11021122
def store(self, name: str, index: sympy.Expr, value, *args, **kwargs):
11031123
index = self.rename_indexing(index)
@@ -1312,6 +1332,13 @@ def indirect_indexing(self, index_var, size, check=True):
13121332
return str(index_var)
13131333

13141334
def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index):
1335+
# In case of index expr, dimension size should be divisible by tile size
1336+
if not self.kernel_group.tile_desc.is_dim_dividable(self.ranges):
1337+
new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges)
1338+
self.kernel_group.tile_desc.set_tile_size(new_tile_size)
1339+
self.reset("recompile")
1340+
raise mlir_common.RecompileSignal(f"Index access (tile size {self.kernel_group.tile_desc.get_tile_size()} is not divisible by {self.ranges})")
1341+
13151342
tile_size = tile_desc.get_tile_size_per_lane()
13161343
compute_vec_size = tile_desc.get_compute_vec_size()
13171344
strides = tile_desc.get_tile_stride_per_lane()
@@ -1892,22 +1919,50 @@ def get_mask(self):
18921919

18931920
def convert_indirect_indexing(self, index :sympy.Expr):
18941921
if "tmp" not in str(index):
1895-
return index
1922+
return index, None
1923+
1924+
# Note: In case of indirect indexing, dimensions should be divisible by tile size
1925+
if not self.kernel_group.tile_desc.is_dim_dividable(self.ranges):
1926+
new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges)
1927+
self.kernel_group.tile_desc.set_tile_size(new_tile_size)
1928+
self.reset("recompile")
1929+
raise mlir_common.RecompileSignal(f"Indirect access (tile size {self.kernel_group.tile_desc.get_tile_size()} is not divisible by {self.ranges})")
18961930

18971931
# Process start
18981932
indirect_dims = [str(dim) for dim in index.free_symbols if "tmp" in str(dim)]
18991933
indirect_dims.sort()
19001934
first_dim = indirect_dims[0]
19011935
spad_vars = dict()
1902-
tmp_comp, self.compute = self.compute, self.dma_loads
1936+
old_compute, old_dma_lods, old_dma_stores = self.compute, self.dma_loads, self.dma_stores
1937+
compute_dependecy = any([target_dim not in self.spad_buffer_dict for target_dim in indirect_dims])
1938+
if compute_dependecy:
1939+
self.compute = old_dma_stores
1940+
target_dma_buffers = self.dma_stores
1941+
else:
1942+
self.compute = old_dma_lods
1943+
target_dma_buffers = self.dma_loads
19031944

19041945
# Load indirect operands
19051946
for target_dim in indirect_dims:
19061947
if target_dim in self.spad_buffer_dict:
19071948
sram_var, _, tile_numel_per_lane, sram_index_var, tile_shape, vshape = self.spad_buffer_dict[target_dim]
19081949
else:
1909-
raise NotImplementedError("TODO.")
1910-
1950+
# FIXME.
1951+
var_info = [v for k, v in self.var_info.items() if str(k) == target_dim][0]
1952+
dtype = mlir_common.MLIR_TO_DTYPE[var_info[1]]
1953+
1954+
local_tile_desc = self.kernel_group.tile_desc
1955+
tile_numel_per_lane = local_tile_desc.get_numel_per_lane()
1956+
tile_shape = local_tile_desc.get_mlir_shape(var_info[1])
1957+
vshape = f"vector<{var_info[0]}x{var_info[1]}>"
1958+
sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, target_dim, local_tile_desc, target_dim)
1959+
self.spad_buffer_dict[target_dim] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape]
1960+
1961+
# Store the indirect index variable
1962+
opeartion = "affine.vector_store"
1963+
compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"])
1964+
line = f"{opeartion} %{target_dim}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}"
1965+
self.stores.writeline(line)
19111966
mlir_dtype = vshape.split("x")[1][:-1]
19121967
vshape = f"vector<{tile_numel_per_lane}x{mlir_dtype}>" # FIXME. Maybe require fine grain compute...
19131968
if tile_numel_per_lane > 1:
@@ -1916,7 +1971,7 @@ def convert_indirect_indexing(self, index :sympy.Expr):
19161971
else:
19171972
operation = "affine.load"
19181973
line = f"{operation} %{sram_var}[{sram_index_var}] : {tile_shape} // For indirect access"
1919-
out = self.cse.generate(self.dma_loads, line)
1974+
out = self.cse.generate(target_dma_buffers, line)
19201975
self.register_var_info(out, [tile_numel_per_lane, mlir_dtype])
19211976
spad_vars[target_dim] = out
19221977

@@ -1946,15 +2001,15 @@ def convert_indirect_indexing(self, index :sympy.Expr):
19462001
else:
19472002
operation = "affine.store"
19482003
line = f"{operation} %{spad_vars[first_dim]}, %{sram_var}[{sram_index_var}] : {tile_shape}"
1949-
out = self.cse.generate(self.dma_loads, line, assignment=False)
2004+
out = self.cse.generate(target_dma_buffers, line, assignment=False)
19502005

19512006
# Conversion
19522007
mlir_dtype = self.var_info[spad_vars[first_dim]][1]
19532008
line = f"affine.load %{sram_var}[{sram_index_var}] : {tile_shape}"
1954-
out = self.cse.generate(self.dma_loads, line)
2009+
out = self.cse.generate(target_dma_buffers, line)
19552010
if mlir_dtype != "index":
19562011
line = f"arith.index_cast %{out} : {mlir_dtype} to {'index'}"
1957-
out = self.cse.generate(self.dma_loads, line)
2012+
out = self.cse.generate(target_dma_buffers, line)
19582013
self.register_var_info(out, [1, "index", [1]])
1959-
self.compute = tmp_comp
1960-
return index + sympy.Symbol(str(out))
2014+
self.compute, self.dma_loads, self.dma_stores = old_compute, old_dma_lods, old_dma_stores
2015+
return index + sympy.Symbol(str(out)), compute_dependecy

PyTorchSimFrontend/mlir/mlir_common.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,17 @@
4848
torch.bfloat16: "bf16",
4949
}
5050

51+
MLIR_TO_DTYPE = {
52+
"f32": torch.float32,
53+
"f64": torch.float64,
54+
"f16": torch.float16,
55+
"i64": torch.int64,
56+
"i32": torch.int32,
57+
"i16": torch.int16,
58+
"i8": torch.int8,
59+
"bf16": torch.bfloat16,
60+
}
61+
5162
DTYPE_TO_C = {
5263
torch.float32: "float",
5364
torch.float64: "double",
@@ -393,6 +404,22 @@ def apply_divisor(self, axis: int, divisor: int, mode: str = "split"):
393404
def get_reduction_numel(self):
394405
return reduce(mul, self.get_tile_size()[-1*self.nr_rdim:], 1)
395406

407+
def is_dim_dividable(self, dim_sizes):
408+
if len(dim_sizes) != len(self._tile_size):
409+
raise ValueError("dim_sizes must match the tile size dimensions.")
410+
return all(d % t == 0 for d, t in zip(dim_sizes, self._tile_size))
411+
412+
def adjust_tile_to_divisible(self, dim_sizes):
413+
def _adjust_one(dim_size, tile_size):
414+
for candidate in range(tile_size, 0, -1):
415+
if dim_size % candidate == 0:
416+
return candidate
417+
return 1
418+
419+
if len(dim_sizes) != len(self._tile_size):
420+
raise ValueError("dim_sizes must match the tile size dimensions.")
421+
return [_adjust_one(d, t) for d, t in zip(dim_sizes, self._tile_size)]
422+
396423
class MLIRWrapperKenrelGroup(cpp.KernelGroup):
397424
def __init__(self):
398425
super().__init__()

PyTorchSimFrontend/mlir/mlir_conv_mt_template.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -138,8 +138,8 @@ def __init__(self, input_nodes, layout, input_reorder=None, **kwargs):
138138
self.padding = kwargs["padding"]
139139
self.dilation = kwargs["dilation"]
140140
self.weight_shape = [str(i) for i in input_nodes[1].layout.size]
141-
self.input_shape = [i for i in input_nodes[0].layout.size]
142-
self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \
141+
self.input_shape = [str(i) for i in input_nodes[0].layout.size]
142+
self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \
143143
+ "_".join([str(i) for i in self.stride]) \
144144
+ "_" + "_".join([str(i) for i in self.padding]) \
145145
+ "_" + "_".join([str(i) for i in self.dilation])

PyTorchSimFrontend/mlir/mlir_conv_sb_template.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -139,8 +139,8 @@ def __init__(self, input_nodes, layout, input_reorder=None, **kwargs):
139139
self.padding = kwargs["padding"]
140140
self.dilation = kwargs["dilation"]
141141
self.weight_shape = [str(i) for i in input_nodes[1].layout.size]
142-
self.input_shape = [i for i in input_nodes[0].layout.size]
143-
self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \
142+
self.input_shape = [str(i) for i in input_nodes[0].layout.size]
143+
self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \
144144
+ "_".join([str(i) for i in self.stride]) \
145145
+ "_" + "_".join([str(i) for i in self.padding]) \
146146
+ "_" + "_".join([str(i) for i in self.dilation])

PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -139,8 +139,8 @@ def __init__(self, input_nodes, layout, input_reorder=None, **kwargs):
139139
self.padding = kwargs["padding"]
140140
self.dilation = kwargs["dilation"]
141141
self.weight_shape = [str(i) for i in input_nodes[1].layout.size]
142-
self.input_shape = [i for i in input_nodes[0].layout.size]
143-
self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \
142+
self.input_shape = [str(i) for i in input_nodes[0].layout.size]
143+
self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \
144144
+ "_".join([str(i) for i in self.stride]) \
145145
+ "_" + "_".join([str(i) for i in self.padding]) \
146146
+ "_" + "_".join([str(i) for i in self.dilation])

PyTorchSimFrontend/mlir/mlir_conv_template.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -143,8 +143,8 @@ def __init__(self, input_nodes, layout, input_reorder=None, **kwargs):
143143
self.padding = kwargs["padding"]
144144
self.dilation = kwargs["dilation"]
145145
self.weight_shape = [str(i) for i in input_nodes[1].layout.size]
146-
self.input_shape = [i for i in input_nodes[0].layout.size]
147-
self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \
146+
self.input_shape = [str(i) for i in input_nodes[0].layout.size]
147+
self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \
148148
+ "_".join([str(i) for i in self.stride]) \
149149
+ "_" + "_".join([str(i) for i in self.padding]) \
150150
+ "_" + "_".join([str(i) for i in self.dilation])

PyTorchSimFrontend/mlir/mlir_lowering.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from typing import List, Optional, Sequence
22

33
import torch
4-
from torch._inductor.lowering import lowerings
4+
from torch._inductor.lowering import lowerings, index_impl
55
from torch._inductor.kernel.mm_common import mm_args
66
# from torch._inductor.select_algorithm import ExternKernelChoice
77
from torch._inductor import ir
@@ -175,10 +175,17 @@ def sparse_addmm(*args, **kwargs):
175175
)
176176
return aten_spmm.bind((sp_mat1, sp_mat2), layout).output_node()
177177

178+
def custom_unsafe_index(x, indices):
179+
# We can't fuse indirect access + indexed_expression + computation
180+
if isinstance(x, TensorBox):
181+
x.realize()
182+
return index_impl(x, indices, check=False)
183+
178184
lowerings.update({getattr(aten.mm, overload): tuned_mm for overload in aten.mm.overloads()})
179185
lowerings.update({getattr(aten.addmm, overload): tuned_addmm for overload in aten.addmm.overloads()})
180186
lowerings.update({getattr(aten.convolution, overload): convolution for overload in aten.convolution.overloads()})
181187
lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()})
182188
lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()})
189+
lowerings.update({getattr(aten._unsafe_index, overload): custom_unsafe_index for overload in aten._unsafe_index.overloads()})
183190
if CONFIG_USE_TIMING_POOLING:
184191
lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template

PyTorchSimFrontend/mlir/mlir_scheduling.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule
5050
try:
5151
stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1]
5252
stride = int(sympify(stride).coeff(target_symbol))
53-
except sympy.core.SympifyError:
53+
except:
5454
return False
5555

5656
# We can't fuse dim=-1
@@ -109,6 +109,9 @@ def can_fuse_horizontal(self, node1, node2):
109109
if node1.is_template() and node2.is_template():
110110
return False
111111

112+
if '_unsafe_index' in node1.get_nodes()[0].node.origins or "_unsafe_index" in node2.get_nodes()[0].node.origins:
113+
return False
114+
112115
# Check template node fusion
113116
if node1.is_template() or node2.is_template():
114117
# Don't fuse maxpool template code

0 commit comments

Comments (0)