55from llvmlite import ir as llvmir
66from numba .core import cgutils , types
77
8- from numba_dpex import utils
8+ from numba_dpex import config , utils
99from numba_dpex .core .runtime .context import DpexRTContext
1010from numba_dpex .core .types import DpnpNdArray
1111from numba_dpex .dpctl_iface import libsyclinterface_bindings as sycl
@@ -361,6 +361,60 @@ def _create_sycl_range(self, idx_range):
361361
362362 return self .builder .bitcast (range_list , intp_ptr_t )
363363
364+ def submit_kernel (
365+ self ,
366+ kernel_ref : llvmir .CallInstr ,
367+ queue_ref : llvmir .PointerType ,
368+ kernel_args : list ,
369+ ty_kernel_args : list ,
370+ global_range_extents : list ,
371+ local_range_extents : list ,
372+ ):
373+ if config .DEBUG_KERNEL_LAUNCHER :
374+ cgutils .printf (
375+ self .builder ,
376+ "DPEX-DEBUG: Populating kernel args and arg type arrays.\n " ,
377+ )
378+
379+ num_flattened_kernel_args = self .get_num_flattened_kernel_args (
380+ kernel_argtys = ty_kernel_args ,
381+ )
382+
383+ # Create LLVM values for the kernel args list and kernel arg types list
384+ args_list = self .allocate_kernel_arg_array (num_flattened_kernel_args )
385+
386+ args_ty_list = self .allocate_kernel_arg_ty_array (
387+ num_flattened_kernel_args
388+ )
389+
390+ kernel_args_ptrs = []
391+ for arg in kernel_args :
392+ ptr = self .builder .alloca (arg .type )
393+ self .builder .store (arg , ptr )
394+ kernel_args_ptrs .append (ptr )
395+
396+ # Populate the args_list and the args_ty_list LLVM arrays
397+ self .populate_kernel_args_and_args_ty_arrays (
398+ callargs_ptrs = kernel_args_ptrs ,
399+ kernel_argtys = ty_kernel_args ,
400+ args_list = args_list ,
401+ args_ty_list = args_ty_list ,
402+ )
403+
404+ if config .DEBUG_KERNEL_LAUNCHER :
405+ cgutils .printf (self ._builder , "DPEX-DEBUG: Submit kernel.\n " )
406+
407+ return self .submit_sycl_kernel (
408+ sycl_kernel_ref = kernel_ref ,
409+ sycl_queue_ref = queue_ref ,
410+ total_kernel_args = num_flattened_kernel_args ,
411+ arg_list = args_list ,
412+ arg_ty_list = args_ty_list ,
413+ global_range = global_range_extents ,
414+ local_range = local_range_extents ,
415+ wait_before_return = False ,
416+ )
417+
364418 def submit_sycl_kernel (
365419 self ,
366420 sycl_kernel_ref ,
@@ -373,7 +427,7 @@ def submit_sycl_kernel(
373427 wait_before_return = True ,
374428 ) -> llvmir .PointerType (llvmir .IntType (8 )):
375429 """
376- Submits the kernel to the specified queue, waits.
430+ Submits the kernel to the specified queue, waits by default .
377431 """
378432 eref = None
379433 gr = self ._create_sycl_range (global_range )
@@ -411,19 +465,34 @@ def submit_sycl_kernel(
411465 else :
412466 return eref
413467
468+ def get_num_flattened_kernel_args (
469+ self ,
470+ kernel_argtys : tuple [types .Type , ...],
471+ ):
472+ num_flattened_kernel_args = 0
473+ for arg_type in kernel_argtys :
474+ if isinstance (arg_type , DpnpNdArray ):
475+ datamodel = self .context .data_model_manager .lookup (arg_type )
476+ num_flattened_kernel_args += datamodel .flattened_field_count
477+ elif arg_type in [types .complex64 , types .complex128 ]:
478+ num_flattened_kernel_args += 2
479+ else :
480+ num_flattened_kernel_args += 1
481+
482+ return num_flattened_kernel_args
483+
414484 def populate_kernel_args_and_args_ty_arrays (
415485 self ,
416486 kernel_argtys ,
417487 callargs_ptrs ,
418488 args_list ,
419489 args_ty_list ,
420- datamodel_mgr ,
421490 ):
422491 kernel_arg_num = 0
423492 for arg_num , argtype in enumerate (kernel_argtys ):
424493 llvm_val = callargs_ptrs [arg_num ]
425494 if isinstance (argtype , DpnpNdArray ):
426- datamodel = datamodel_mgr .lookup (argtype )
495+ datamodel = self . context . data_model_manager .lookup (argtype )
427496 self .build_array_arg (
428497 array_val = llvm_val ,
429498 array_data_model = datamodel ,
0 commit comments