2626 affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUT_SIZES[i][DIM] }} step {{ INPUT_TILE_SIZES_DIM[i] }} {
2727 %index{{ DIM }}_{{i}} = affine.apply affine_map<(d0) -> (d0 + {{ CUMULATIVE_OFFSETS[i] }})> (%index_local{{ DIM }}_{{ i }})
2828 {{ kernel.def_dma_op("MVIN", INPUT_BUFFER_NAMES[i], INPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }}
29- {{ kernel.def_dma_op("MVOUT", OUT_DVAR, OUTPUT_IDXS[i], INPUT_TILE_DESCS [i], indent_size=INDENT_SIZE) }}
29+ {{ kernel.def_dma_op("MVOUT", OUT_DVAR, OUTPUT_IDXS[i], OUTPUT_TILE_DESCS [i], indent_size=INDENT_SIZE) }}
3030 } { inner_loop=true }
3131{%- endfor %}
3232
@@ -52,10 +52,6 @@ def render(
5252 tile_info = None ,
5353 ** kwargs ,
5454 ):
55- is_out_variant = template_buffer_node is not None
56- if is_out_variant :
57- self .output_node = template_buffer_node
58-
5955 # Extract info
6056 input_nodes = self .input_nodes
6157 y = self .output_node
@@ -73,11 +69,8 @@ def render(
7369 kernel , input_sizes , tile_sizes , num_inputs , rank
7470 )
7571 buffer_name_to_template_name , input_buffer_names = self ._build_buffer_mapping (input_nodes )
76- input_tile_descs , unique_tile_descs = self ._build_tile_descriptors (
77- kernel , input_nodes , input_sizes , input_tile_sizes_dim , tile_sizes , rank , input_buffer_names
78- )
79- y_tile_desc = self ._build_output_tile_desc (
80- kernel , input_tile_sizes_dim , tile_sizes , rank
72+ input_tile_descs , output_tile_descs , unique_tile_descs = self ._build_tile_descriptors (
73+ kernel , input_nodes , input_sizes , input_tile_sizes_dim , tile_sizes , rank , input_buffer_names , y
8174 )
8275
8376 input_idxs , output_idxs , cumulative_offsets = self ._build_index_expressions (
@@ -90,14 +83,14 @@ def render(
9083 if actual_name in unique_tile_descs :
9184 unique_buffer_tile_descs [template_name ] = unique_tile_descs [actual_name ]
9285
93- names_str = ", " .join (input_buffer_names + ["out_ptr1" if is_out_variant else " Y" ])
86+ names_str = ", " .join (input_buffer_names + ["Y" ])
9487 indent_size = 2 + (rank - 1 ) * 2 + 4
9588
9689 kernel .render_options = dict (
9790 KERNEL_NAME = self .name ,
9891 kernel = kernel ,
9992 Y = y ,
100- OUT_DVAR = "out_ptr1" if is_out_variant else " Y" ,
93+ OUT_DVAR = "Y" ,
10194 NAMES_STR = names_str ,
10295 INPUT_NAMES = input_nodes ,
10396 INPUT_BUFFER_NAMES = input_buffer_names ,
@@ -110,6 +103,7 @@ def render(
110103 TILE_SIZES = tile_sizes ,
111104 INPUT_TILE_SIZES_DIM = input_tile_sizes_dim ,
112105 INPUT_TILE_DESCS = input_tile_descs ,
106+ OUTPUT_TILE_DESCS = output_tile_descs ,
113107 UNIQUE_BUFFER_TILE_DESCS = unique_buffer_tile_descs ,
114108 INPUT_IDXS = input_idxs ,
115109 OUTPUT_IDXS = output_idxs ,
@@ -209,14 +203,16 @@ def _build_buffer_mapping(self, input_nodes):
209203 return buffer_name_to_template_name , input_buffer_names
210204
211205 def _build_tile_descriptors (
212- self , kernel , input_nodes , input_sizes , input_tile_sizes_dim , tile_sizes , rank , input_buffer_names
206+ self , kernel , input_nodes , input_sizes , input_tile_sizes_dim , tile_sizes , rank , input_buffer_names , output_node
213207 ):
214- """Build tile descriptors for each input."""
208+ """Build tile descriptors for each input and output ."""
215209 input_tile_descs = []
210+ output_tile_descs = []
216211 unique_tile_descs = {}
212+ output_offset = output_node .get_layout ().offset
217213
218214 for i , x in enumerate (input_nodes ):
219- # Build full tile size list for this input
215+ x_offset = x . get_layout (). offset
220216 full_tile_sizes = []
221217 tile_size_idx = 0
222218 for d in range (rank ):
@@ -226,23 +222,37 @@ def _build_tile_descriptors(
226222 else :
227223 full_tile_sizes .append (input_tile_sizes_dim [i ])
228224
229- tile_desc = mlir_common .MLIRMultiDimTile (
225+ # Input tile descriptor
226+ input_tile_desc = mlir_common .MLIRMultiDimTile (
230227 full_tile_sizes ,
231228 kernel .vector_lane ,
232229 vlane_split_axis = rank - 1 ,
233230 vlane_stride = 1
234231 )
235- tile_desc .set_tile_size (full_tile_sizes )
232+ input_tile_desc .set_tile_size (full_tile_sizes )
236233 template_buffer_name = input_buffer_names [i ]
237- tile_desc .set_name (f"{ template_buffer_name .lower ()} _cat_tile" )
238- input_tile_descs .append (tile_desc )
234+ input_tile_desc .set_name (f"{ template_buffer_name .lower ()} _cat_tile" )
235+ input_tile_desc .offset = x_offset
236+ input_tile_descs .append (input_tile_desc )
237+
238+ # Output tile descriptor (same as input but with output offset)
239+ output_tile_desc = mlir_common .MLIRMultiDimTile (
240+ full_tile_sizes ,
241+ kernel .vector_lane ,
242+ vlane_split_axis = rank - 1 ,
243+ vlane_stride = 1
244+ )
245+ output_tile_desc .set_tile_size (full_tile_sizes )
246+ output_tile_desc .set_name (f"{ template_buffer_name .lower ()} _cat_tile" )
247+ output_tile_desc .offset = output_offset
248+ output_tile_descs .append (output_tile_desc )
239249
240250 # Store unique tile desc by actual buffer name
241251 actual_name = x .get_name ()
242252 if actual_name not in unique_tile_descs :
243- unique_tile_descs [actual_name ] = tile_desc
253+ unique_tile_descs [actual_name ] = input_tile_desc
244254
245- return input_tile_descs , unique_tile_descs
255+ return input_tile_descs , output_tile_descs , unique_tile_descs
246256
247257 def _build_index_expressions (
248258 self , input_nodes , input_sizes , output_strides , rank , num_inputs
@@ -256,6 +266,12 @@ def _build_index_expressions(
256266
257267 for i , x in enumerate (input_nodes ):
258268 x_stride = x .get_layout ().stride
269+ x_offset = x .get_layout ().offset
270+ if hasattr (x , 'data' ) and hasattr (x .data , 'dims' ):
271+ # In case of PermuteView, the stride is permuted
272+ perm_dims = x .data .dims
273+ x_stride = [x_stride [perm_dims [d ]] for d in range (rank )]
274+
259275 input_idx = []
260276 output_idx = []
261277 for d in range (rank ):
@@ -271,25 +287,3 @@ def _build_index_expressions(
271287 output_idxs .append (output_idx )
272288
273289 return input_idxs , output_idxs , cumulative_offsets
274-
def _build_output_tile_desc(self, kernel, input_tile_sizes_dim, tile_sizes, rank):
    """Create the tile descriptor named ``y_cat_tile`` for the output.

    The descriptor's per-axis tile sizes are assembled as follows: the axis
    equal to ``self.dim`` gets the largest entry of ``input_tile_sizes_dim``
    (or 1 when that sequence is empty); every other axis takes the next
    unused entry of ``tile_sizes``, in order.

    Args:
        kernel: Object providing ``vector_lane``, forwarded to the tile
            constructor.
        input_tile_sizes_dim: Per-input tile sizes along ``self.dim``.
        tile_sizes: Tile sizes for the axes other than ``self.dim``.
        rank: Number of axes in the descriptor.

    Returns:
        The configured ``mlir_common.MLIRMultiDimTile`` instance.
    """
    # Size used on the ``self.dim`` axis; fall back to 1 if no inputs given.
    widest_dim_tile = 1
    if input_tile_sizes_dim:
        widest_dim_tile = max(input_tile_sizes_dim)

    shape = []
    consumed = 0  # how many entries of ``tile_sizes`` have been used so far
    for axis in range(rank):
        if axis == self.dim:
            shape.append(widest_dim_tile)
            continue
        shape.append(tile_sizes[consumed])
        consumed += 1

    # The last axis is the one split across vector lanes, with unit stride.
    descriptor = mlir_common.MLIRMultiDimTile(
        shape,
        kernel.vector_lane,
        vlane_split_axis=rank - 1,
        vlane_stride=1,
    )
    descriptor.set_tile_size(shape)
    descriptor.set_name("y_cat_tile")
    return descriptor
0 commit comments