@@ -1007,6 +1007,8 @@ def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> com
10071007
10081008 # Extract index var
10091009 indirect_args = [f"%{ i } " for i in indirect_dims ]
1010+ if len (indirect_args ):
1011+ comments = "{indirect_access} " + comments # Add indirect access attribute
10101012 expr_str = str (expr )
10111013 if "//" in expr_str :
10121014 expr_str = expr_str .replace ("//" , " floordiv " )
@@ -1057,17 +1059,27 @@ def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0))
10571059
10581060 def load (self , name : str , index : sympy .Expr ):
10591061 index = self .rename_indexing (index )
1060- index = self .convert_indirect_indexing (index )
1062+ index , comptute_depedency = self .convert_indirect_indexing (index )
10611063 padding = self .get_padding_type ()
10621064
1065+ # In case of special form of indirect access, we need to put load in dma_store buffer
1066+ if comptute_depedency :
1067+ apply_buffer = self .dma_stores
1068+ dma_buffer = self .dma_stores
1069+ load_buffer = self .dma_stores
1070+ else :
1071+ apply_buffer = None
1072+ dma_buffer = self .dma_loads
1073+ load_buffer = self .loads
1074+
10631075 # Extract dram info
10641076 dram_var = self .kernel_group .args .input (name )
10651077 dram_shape = mlir_common .MLIRKernelArgs .get_mlir_shape (self .buffer_types [name ])
10661078 dtype = V .graph .get_dtype (name )
10671079 mlir_dtype = mlir_common .DTYPE_TO_MLIR [dtype ]
10681080
10691081 # Extract sram info
1070- local_tile_desc , index_var , dram_stride = self .get_dma_info (name , index )
1082+ local_tile_desc , index_var , dram_stride = self .get_dma_info (name , index , buffer = apply_buffer )
10711083 vlane_split_axis = local_tile_desc .vlane_split_axis
10721084 vlane_stride = local_tile_desc .vlane_stride
10731085 tile_numel_per_lane = local_tile_desc .get_numel_per_lane ()
@@ -1085,19 +1097,27 @@ def load(self, name: str, index: sympy.Expr):
10851097 attribute = f"{{dram_stride={ dram_stride } , sram_stride={ tile_stride } , padding={ padding } }}"
10861098 code = self .get_dma_code ("MVIN" , vlane_split_axis , vlane_stride , mlir_dtype , dram_var , index_var , sram_var , sram_index_var ,
10871099 dram_shape , tile_shape , attribute )
1088- self .cse .generate (self .dma_loads , code , assignment = False ) # FIXME: assignment = False does not support caching
1089- compute_index_var = "," .join (sram_index_var .split ("," )[:- 1 ] + [f"%{ self .compute_idx } " ])
1090- # Generate vector load instruction
1091- if compute_vec_size > 1 :
1092- operation = "affine.vector_load"
1093- line = f"{ operation } %{ sram_var } [{ compute_index_var } ] : { tile_shape } , { vshape } "
1100+ self .cse .generate (dma_buffer , code , assignment = False ) # FIXME: assignment = False does not support caching
1101+
1102+ if not comptute_depedency :
1103+ compute_index_var = "," .join (sram_index_var .split ("," )[:- 1 ] + [f"%{ self .compute_idx } " ])
1104+ # Generate vector load instruction
1105+ if compute_vec_size > 1 :
1106+ operation = "affine.vector_load"
1107+ line = f"{ operation } %{ sram_var } [{ compute_index_var } ] : { tile_shape } , { vshape } "
1108+ else :
1109+ operation = "affine.load"
1110+ line = f"{ operation } %{ sram_var } [{ compute_index_var } ] : { tile_shape } "
1111+
1112+ out = self .cse .generate (load_buffer , line )
1113+ self .register_var_info (out , [compute_vec_size , mlir_dtype ])
1114+ self .spad_buffer_dict [str (out )] = [sram_var , local_tile_desc .get_tile_size (), tile_numel_per_lane , sram_index_var , tile_shape , vshape ]
1115+ return out
10941116 else :
1095- operation = "affine.load"
1096- line = f"{ operation } %{ sram_var } [{ compute_index_var } ] : { tile_shape } "
1097- out = self .cse .generate (self .loads , line )
1098- self .register_var_info (out , [compute_vec_size , mlir_dtype ])
1099- self .spad_buffer_dict [str (out )] = [sram_var , local_tile_desc .get_tile_size (), tile_numel_per_lane , sram_index_var , tile_shape , vshape ]
1100- return out
1117+ out = sram_var
1118+ self .register_var_info (out , [compute_vec_size , mlir_dtype ])
1119+ self .spad_buffer_dict [str (out )] = [sram_var , local_tile_desc .get_tile_size (), tile_numel_per_lane , sram_index_var , tile_shape , vshape ]
1120+ return out
11011121
11021122 def store (self , name : str , index : sympy .Expr , value , * args , ** kwargs ):
11031123 index = self .rename_indexing (index )
@@ -1312,6 +1332,13 @@ def indirect_indexing(self, index_var, size, check=True):
13121332 return str (index_var )
13131333
13141334 def _index_expr (self , tile_desc , renamed_expression , index , base_vector_index ):
1335+ # In case of index expr, dimension size should be divisible by tile size
1336+ if not self .kernel_group .tile_desc .is_dim_dividable (self .ranges ):
1337+ new_tile_size = self .kernel_group .tile_desc .adjust_tile_to_divisible (self .ranges )
1338+ self .kernel_group .tile_desc .set_tile_size (new_tile_size )
1339+ self .reset ("recompile" )
1340+ raise mlir_common .RecompileSignal (f"Index access (tile size { self .kernel_group .tile_desc .get_tile_size ()} is not divisible by { self .ranges } )" )
1341+
13151342 tile_size = tile_desc .get_tile_size_per_lane ()
13161343 compute_vec_size = tile_desc .get_compute_vec_size ()
13171344 strides = tile_desc .get_tile_stride_per_lane ()
@@ -1892,22 +1919,50 @@ def get_mask(self):
18921919
18931920 def convert_indirect_indexing (self , index :sympy .Expr ):
18941921 if "tmp" not in str (index ):
1895- return index
1922+ return index , None
1923+
1924+ # Note: In case of indirect indexing, dimensions should be divisible by tile size
1925+ if not self .kernel_group .tile_desc .is_dim_dividable (self .ranges ):
1926+ new_tile_size = self .kernel_group .tile_desc .adjust_tile_to_divisible (self .ranges )
1927+ self .kernel_group .tile_desc .set_tile_size (new_tile_size )
1928+ self .reset ("recompile" )
1929+ raise mlir_common .RecompileSignal (f"Indirect access (tile size { self .kernel_group .tile_desc .get_tile_size ()} is not divisible by { self .ranges } )" )
18961930
18971931 # Process start
18981932 indirect_dims = [str (dim ) for dim in index .free_symbols if "tmp" in str (dim )]
18991933 indirect_dims .sort ()
19001934 first_dim = indirect_dims [0 ]
19011935 spad_vars = dict ()
1902- tmp_comp , self .compute = self .compute , self .dma_loads
1936+ old_compute , old_dma_lods , old_dma_stores = self .compute , self .dma_loads , self .dma_stores
1937+ compute_dependecy = any ([target_dim not in self .spad_buffer_dict for target_dim in indirect_dims ])
1938+ if compute_dependecy :
1939+ self .compute = old_dma_stores
1940+ target_dma_buffers = self .dma_stores
1941+ else :
1942+ self .compute = old_dma_lods
1943+ target_dma_buffers = self .dma_loads
19031944
19041945 # Load indirect operands
19051946 for target_dim in indirect_dims :
19061947 if target_dim in self .spad_buffer_dict :
19071948 sram_var , _ , tile_numel_per_lane , sram_index_var , tile_shape , vshape = self .spad_buffer_dict [target_dim ]
19081949 else :
1909- raise NotImplementedError ("TODO." )
1910-
1950+ # FIXME.
1951+ var_info = [v for k , v in self .var_info .items () if str (k ) == target_dim ][0 ]
1952+ dtype = mlir_common .MLIR_TO_DTYPE [var_info [1 ]]
1953+
1954+ local_tile_desc = self .kernel_group .tile_desc
1955+ tile_numel_per_lane = local_tile_desc .get_numel_per_lane ()
1956+ tile_shape = local_tile_desc .get_mlir_shape (var_info [1 ])
1957+ vshape = f"vector<{ var_info [0 ]} x{ var_info [1 ]} >"
1958+ sram_var , sram_index_var = self .get_scratchpad_buffer (dtype , target_dim , local_tile_desc , target_dim )
1959+ self .spad_buffer_dict [target_dim ] = [sram_var , local_tile_desc .get_tile_size (), tile_numel_per_lane , sram_index_var , tile_shape , vshape ]
1960+
1961+ # Store the indirect index variable
1962+ opeartion = "affine.vector_store"
1963+ compute_index_var = "," .join (sram_index_var .split ("," )[:- 1 ] + [f"%{ self .compute_idx } " ])
1964+ line = f"{ opeartion } %{ target_dim } , %{ sram_var } [{ compute_index_var } ] : { tile_shape } , { vshape } "
1965+ self .stores .writeline (line )
19111966 mlir_dtype = vshape .split ("x" )[1 ][:- 1 ]
19121967 vshape = f"vector<{ tile_numel_per_lane } x{ mlir_dtype } >" # FIXME. Maybe require fine grain compute...
19131968 if tile_numel_per_lane > 1 :
@@ -1916,7 +1971,7 @@ def convert_indirect_indexing(self, index :sympy.Expr):
19161971 else :
19171972 operation = "affine.load"
19181973 line = f"{ operation } %{ sram_var } [{ sram_index_var } ] : { tile_shape } // For indirect access"
1919- out = self .cse .generate (self . dma_loads , line )
1974+ out = self .cse .generate (target_dma_buffers , line )
19201975 self .register_var_info (out , [tile_numel_per_lane , mlir_dtype ])
19211976 spad_vars [target_dim ] = out
19221977
@@ -1946,15 +2001,15 @@ def convert_indirect_indexing(self, index :sympy.Expr):
19462001 else :
19472002 operation = "affine.store"
19482003 line = f"{ operation } %{ spad_vars [first_dim ]} , %{ sram_var } [{ sram_index_var } ] : { tile_shape } "
1949- out = self .cse .generate (self . dma_loads , line , assignment = False )
2004+ out = self .cse .generate (target_dma_buffers , line , assignment = False )
19502005
19512006 # Conversion
19522007 mlir_dtype = self .var_info [spad_vars [first_dim ]][1 ]
19532008 line = f"affine.load %{ sram_var } [{ sram_index_var } ] : { tile_shape } "
1954- out = self .cse .generate (self . dma_loads , line )
2009+ out = self .cse .generate (target_dma_buffers , line )
19552010 if mlir_dtype != "index" :
19562011 line = f"arith.index_cast %{ out } : { mlir_dtype } to { 'index' } "
1957- out = self .cse .generate (self . dma_loads , line )
2012+ out = self .cse .generate (target_dma_buffers , line )
19582013 self .register_var_info (out , [1 , "index" , [1 ]])
1959- self .compute = tmp_comp
1960- return index + sympy .Symbol (str (out ))
2014+ self .compute , self . dma_loads , self . dma_stores = old_compute , old_dma_lods , old_dma_stores
2015+ return index + sympy .Symbol (str (out )), compute_dependecy
0 commit comments