@@ -45,6 +45,8 @@ using namespace mlir;
4545#define BLOCK_SIZE_Z_ATTR " BlockSize.z"
4646#define ARG_RANKS_ATTR " arg_ranks"
4747#define CALL_CONVENTION_ATTR " call_convention"
48+ #define DYNAMIC_CONFIG " __byteir_dynamic_config__"
49+ #define KERNEL_LAUNCH_CONFIG_NUM 6
4850
4951namespace brt {
5052namespace cuda {
@@ -123,42 +125,50 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info)
123125 impl_->call_convention = " all" ;
124126 // static assignment for config
125127 // TODO extend to support dynamic
126- if (!info.GetOperation ()->hasAttrOfType <IntegerAttr>(GRID_SIZE_X_ATTR)) {
127- BRT_THROW_EX (std::runtime_error, " no GridSize.x attr" );
128+ bool dynamic_config_flag = false ;
129+ if (info.GetOperation ()->hasAttr (DYNAMIC_CONFIG)) {
130+ dynamic_config_flag = true ;
128131 }
132+ int gx, gy, gz, bx, by, bz;
133+ gx = gy = gz = bx = by = bz = 1 ;
134+ if (!dynamic_config_flag) {
135+ if (!info.GetOperation ()->hasAttrOfType <IntegerAttr>(GRID_SIZE_X_ATTR)) {
136+ BRT_THROW_EX (std::runtime_error, " no GridSize.x attr" );
137+ }
129138
130- if (!info.GetOperation ()->hasAttrOfType <IntegerAttr>(BLOCK_SIZE_X_ATTR)) {
131- BRT_THROW_EX (std::runtime_error, " no BlockSize.x attr" );
132- }
139+ if (!info.GetOperation ()->hasAttrOfType <IntegerAttr>(BLOCK_SIZE_X_ATTR)) {
140+ BRT_THROW_EX (std::runtime_error, " no BlockSize.x attr" );
141+ }
133142
134- int gx = static_cast <int >(info.GetOperation ()
135- ->getAttrOfType <IntegerAttr>(GRID_SIZE_X_ATTR)
136- .getInt ()),
137- gy = 1 , gz = 1 ;
138- if (info.GetOperation ()->hasAttrOfType <IntegerAttr>(GRID_SIZE_Y_ATTR)) {
139- gy = static_cast <int >(info.GetOperation ()
140- ->getAttrOfType <IntegerAttr>(GRID_SIZE_Y_ATTR)
141- .getInt ());
142- }
143- if (info.GetOperation ()->hasAttrOfType <IntegerAttr>(GRID_SIZE_Z_ATTR)) {
144- gz = static_cast <int >(info.GetOperation ()
145- ->getAttrOfType <IntegerAttr>(GRID_SIZE_Z_ATTR)
146- .getInt ());
147- }
143+ gx = static_cast <int >(info.GetOperation ()
144+ ->getAttrOfType <IntegerAttr>(GRID_SIZE_X_ATTR)
145+ .getInt ()),
146+ gy = 1 , gz = 1 ;
147+ if (info.GetOperation ()->hasAttrOfType <IntegerAttr>(GRID_SIZE_Y_ATTR)) {
148+ gy = static_cast <int >(info.GetOperation ()
149+ ->getAttrOfType <IntegerAttr>(GRID_SIZE_Y_ATTR)
150+ .getInt ());
151+ }
152+ if (info.GetOperation ()->hasAttrOfType <IntegerAttr>(GRID_SIZE_Z_ATTR)) {
153+ gz = static_cast <int >(info.GetOperation ()
154+ ->getAttrOfType <IntegerAttr>(GRID_SIZE_Z_ATTR)
155+ .getInt ());
156+ }
148157
149- int bx = static_cast <int >(info.GetOperation ()
150- ->getAttrOfType <IntegerAttr>(BLOCK_SIZE_X_ATTR)
151- .getInt ()),
152- by = 1 , bz = 1 ;
153- if (info.GetOperation ()->hasAttrOfType <IntegerAttr>(BLOCK_SIZE_Y_ATTR)) {
154- by = static_cast <int >(info.GetOperation ()
155- ->getAttrOfType <IntegerAttr>(BLOCK_SIZE_Y_ATTR)
156- .getInt ());
157- }
158- if (info.GetOperation ()->hasAttrOfType <IntegerAttr>(BLOCK_SIZE_Z_ATTR)) {
159- bz = static_cast <int >(info.GetOperation ()
160- ->getAttrOfType <IntegerAttr>(BLOCK_SIZE_Z_ATTR)
161- .getInt ());
158+ bx = static_cast <int >(info.GetOperation ()
159+ ->getAttrOfType <IntegerAttr>(BLOCK_SIZE_X_ATTR)
160+ .getInt ()),
161+ by = 1 , bz = 1 ;
162+ if (info.GetOperation ()->hasAttrOfType <IntegerAttr>(BLOCK_SIZE_Y_ATTR)) {
163+ by = static_cast <int >(info.GetOperation ()
164+ ->getAttrOfType <IntegerAttr>(BLOCK_SIZE_Y_ATTR)
165+ .getInt ());
166+ }
167+ if (info.GetOperation ()->hasAttrOfType <IntegerAttr>(BLOCK_SIZE_Z_ATTR)) {
168+ bz = static_cast <int >(info.GetOperation ()
169+ ->getAttrOfType <IntegerAttr>(BLOCK_SIZE_Z_ATTR)
170+ .getInt ());
171+ }
162172 }
163173
164174 std::vector<int > ranks;
@@ -172,6 +182,10 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info)
172182 }
173183
174184 auto num_arg = GetOpArgNum (info_);
185+ // exclude the trailing launch-config operands from the kernel argument count
186+ // TODO: make `shared_size` an input operand in the compiler.
187+ if (dynamic_config_flag)
188+ num_arg -= KERNEL_LAUNCH_CONFIG_NUM;
175189 impl_->grid = dim3 (gx, gy, gz);
176190 impl_->block = dim3 (bx, by, bz);
177191 impl_->shared_size = 0 ;
@@ -198,20 +212,34 @@ common::Status PTXOpKernel::RunImpl(const ExecutionContext &ctx) {
198212 std::vector<void *> args;
199213 std::vector<MLIREngineMemRefDescriptor> descs;
200214 args.reserve (impl_->arg_reserve_size );
215+ bool dynamic_config_flag = false ;
216+ if (info_.GetOperation ()->hasAttr (DYNAMIC_CONFIG)) {
217+ dynamic_config_flag = true ;
218+ auto num_arg = GetOpArgNum (info_);
219+ std::vector<int64_t > launch_config;
220+ launch_config.reserve (KERNEL_LAUNCH_CONFIG_NUM);
221+ for (size_t i = num_arg - KERNEL_LAUNCH_CONFIG_NUM; i < num_arg; ++i) {
222+ size_t idx = GetScalarIndexFromOpArgIndex (info_, i);
223+ launch_config.emplace_back (ctx.exec_frame ->GetScalar <int64_t >(idx));
224+ }
225+ impl_->grid = dim3 (launch_config[0 ], launch_config[1 ], launch_config[2 ]);
226+ impl_->block = dim3 (launch_config[3 ], launch_config[4 ], launch_config[5 ]);
227+ }
228+
201229 args.push_back (&(impl_->grid ));
202230 args.push_back (&(impl_->block ));
203231 args.push_back (&(impl_->shared_size ));
204232
205233 descs.reserve (impl_->tensor_ids .size ());
206234 for (size_t i = 0 ; i < impl_->tensor_ids .size (); ++i) {
207235 descs.emplace_back (ctx.exec_frame ->GetAsyncValueRef (impl_->tensor_ids [i]),
208- impl_->tensor_ranks [i]);
236+ ctx. exec_frame -> GetShapeRef ( impl_->tensor_ids [i]) );
209237 if (impl_->call_convention == " bare_ptr" )
210238 args.push_back (&descs.back ().data );
211- else
239+ else {
212240 InsertMemDescToArgs (descs.back (), args);
241+ }
213242 }
214-
215243 auto work_queue = static_cast <CUDAWorkQueue *>(ctx.work_queue );
216244 auto cuda_env = work_queue->GetCudaEnv ();
217245 BRT_ENFORCE (cuda_env.IsPrimaryContext (),
0 commit comments