Skip to content

Commit 61caebd

Browse files
committed
[Template/Cat] Fix apply offset setting
1 parent 5295dfb commit 61caebd

1 file changed

Lines changed: 37 additions & 43 deletions

File tree

PyTorchSimFrontend/mlir/mlir_cat_template.py

Lines changed: 37 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626
affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUT_SIZES[i][DIM] }} step {{ INPUT_TILE_SIZES_DIM[i] }} {
2727
%index{{ DIM }}_{{i}} = affine.apply affine_map<(d0) -> (d0 + {{ CUMULATIVE_OFFSETS[i] }})> (%index_local{{ DIM }}_{{ i }})
2828
{{ kernel.def_dma_op("MVIN", INPUT_BUFFER_NAMES[i], INPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }}
29-
{{ kernel.def_dma_op("MVOUT", OUT_DVAR, OUTPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }}
29+
{{ kernel.def_dma_op("MVOUT", OUT_DVAR, OUTPUT_IDXS[i], OUTPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }}
3030
} { inner_loop=true }
3131
{%- endfor %}
3232
@@ -52,10 +52,6 @@ def render(
5252
tile_info=None,
5353
**kwargs,
5454
):
55-
is_out_variant = template_buffer_node is not None
56-
if is_out_variant:
57-
self.output_node = template_buffer_node
58-
5955
# Extract info
6056
input_nodes = self.input_nodes
6157
y = self.output_node
@@ -73,11 +69,8 @@ def render(
7369
kernel, input_sizes, tile_sizes, num_inputs, rank
7470
)
7571
buffer_name_to_template_name, input_buffer_names = self._build_buffer_mapping(input_nodes)
76-
input_tile_descs, unique_tile_descs = self._build_tile_descriptors(
77-
kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names
78-
)
79-
y_tile_desc = self._build_output_tile_desc(
80-
kernel, input_tile_sizes_dim, tile_sizes, rank
72+
input_tile_descs, output_tile_descs, unique_tile_descs = self._build_tile_descriptors(
73+
kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, y
8174
)
8275

8376
input_idxs, output_idxs, cumulative_offsets = self._build_index_expressions(
@@ -90,14 +83,14 @@ def render(
9083
if actual_name in unique_tile_descs:
9184
unique_buffer_tile_descs[template_name] = unique_tile_descs[actual_name]
9285

93-
names_str = ", ".join(input_buffer_names + ["out_ptr1" if is_out_variant else "Y"])
86+
names_str = ", ".join(input_buffer_names + ["Y"])
9487
indent_size = 2 + (rank - 1) * 2 + 4
9588

9689
kernel.render_options = dict(
9790
KERNEL_NAME=self.name,
9891
kernel=kernel,
9992
Y=y,
100-
OUT_DVAR="out_ptr1" if is_out_variant else "Y",
93+
OUT_DVAR="Y",
10194
NAMES_STR=names_str,
10295
INPUT_NAMES=input_nodes,
10396
INPUT_BUFFER_NAMES=input_buffer_names,
@@ -110,6 +103,7 @@ def render(
110103
TILE_SIZES=tile_sizes,
111104
INPUT_TILE_SIZES_DIM=input_tile_sizes_dim,
112105
INPUT_TILE_DESCS=input_tile_descs,
106+
OUTPUT_TILE_DESCS=output_tile_descs,
113107
UNIQUE_BUFFER_TILE_DESCS=unique_buffer_tile_descs,
114108
INPUT_IDXS=input_idxs,
115109
OUTPUT_IDXS=output_idxs,
@@ -209,14 +203,16 @@ def _build_buffer_mapping(self, input_nodes):
209203
return buffer_name_to_template_name, input_buffer_names
210204

211205
def _build_tile_descriptors(
212-
self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names
206+
self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, output_node
213207
):
214-
"""Build tile descriptors for each input."""
208+
"""Build tile descriptors for each input and output."""
215209
input_tile_descs = []
210+
output_tile_descs = []
216211
unique_tile_descs = {}
212+
output_offset = output_node.get_layout().offset
217213

218214
for i, x in enumerate(input_nodes):
219-
# Build full tile size list for this input
215+
x_offset = x.get_layout().offset
220216
full_tile_sizes = []
221217
tile_size_idx = 0
222218
for d in range(rank):
@@ -226,23 +222,37 @@ def _build_tile_descriptors(
226222
else:
227223
full_tile_sizes.append(input_tile_sizes_dim[i])
228224

229-
tile_desc = mlir_common.MLIRMultiDimTile(
225+
# Input tile descriptor
226+
input_tile_desc = mlir_common.MLIRMultiDimTile(
230227
full_tile_sizes,
231228
kernel.vector_lane,
232229
vlane_split_axis=rank - 1,
233230
vlane_stride=1
234231
)
235-
tile_desc.set_tile_size(full_tile_sizes)
232+
input_tile_desc.set_tile_size(full_tile_sizes)
236233
template_buffer_name = input_buffer_names[i]
237-
tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile")
238-
input_tile_descs.append(tile_desc)
234+
input_tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile")
235+
input_tile_desc.offset = x_offset
236+
input_tile_descs.append(input_tile_desc)
237+
238+
# Output tile descriptor (same as input but with output offset)
239+
output_tile_desc = mlir_common.MLIRMultiDimTile(
240+
full_tile_sizes,
241+
kernel.vector_lane,
242+
vlane_split_axis=rank - 1,
243+
vlane_stride=1
244+
)
245+
output_tile_desc.set_tile_size(full_tile_sizes)
246+
output_tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile")
247+
output_tile_desc.offset = output_offset
248+
output_tile_descs.append(output_tile_desc)
239249

240250
# Store unique tile desc by actual buffer name
241251
actual_name = x.get_name()
242252
if actual_name not in unique_tile_descs:
243-
unique_tile_descs[actual_name] = tile_desc
253+
unique_tile_descs[actual_name] = input_tile_desc
244254

245-
return input_tile_descs, unique_tile_descs
255+
return input_tile_descs, output_tile_descs, unique_tile_descs
246256

247257
def _build_index_expressions(
248258
self, input_nodes, input_sizes, output_strides, rank, num_inputs
@@ -256,6 +266,12 @@ def _build_index_expressions(
256266

257267
for i, x in enumerate(input_nodes):
258268
x_stride = x.get_layout().stride
269+
x_offset = x.get_layout().offset
270+
if hasattr(x, 'data') and hasattr(x.data, 'dims'):
271+
# In case of PermuteView, the stride is permuted
272+
perm_dims = x.data.dims
273+
x_stride = [x_stride[perm_dims[d]] for d in range(rank)]
274+
259275
input_idx = []
260276
output_idx = []
261277
for d in range(rank):
@@ -271,25 +287,3 @@ def _build_index_expressions(
271287
output_idxs.append(output_idx)
272288

273289
return input_idxs, output_idxs, cumulative_offsets
274-
275-
def _build_output_tile_desc(self, kernel, input_tile_sizes_dim, tile_sizes, rank):
276-
"""Build output tile descriptor."""
277-
max_output_tile_dim = max(input_tile_sizes_dim) if input_tile_sizes_dim else 1
278-
output_full_tile_sizes = []
279-
tile_size_idx = 0
280-
for d in range(rank):
281-
if d != self.dim:
282-
output_full_tile_sizes.append(tile_sizes[tile_size_idx])
283-
tile_size_idx += 1
284-
else:
285-
output_full_tile_sizes.append(max_output_tile_dim)
286-
287-
y_tile_desc = mlir_common.MLIRMultiDimTile(
288-
output_full_tile_sizes,
289-
kernel.vector_lane,
290-
vlane_split_axis=rank - 1,
291-
vlane_stride=1
292-
)
293-
y_tile_desc.set_tile_size(output_full_tile_sizes)
294-
y_tile_desc.set_name("y_cat_tile")
295-
return y_tile_desc

0 commit comments

Comments
 (0)