77"""
88from collections import namedtuple
99from contextlib import ExitStack
10+ from typing import Tuple
1011
1112import numba .core .event as ev
1213from numba .core import errors , sigutils , types
2425from numba_dpex .core .pipelines import kernel_compiler
2526from numba_dpex .core .types import DpnpNdArray
2627
27- from .target import dpex_exp_kernel_target
28+ from .target import DPEX_KERNEL_EXP_TARGET_NAME , dpex_exp_kernel_target
2829
2930_KernelModule = namedtuple ("_KernelModule" , ["kernel_name" , "kernel_bitcode" ])
3031
3132_KernelCompileResult = namedtuple (
32- "_KernelCompileResult" ,
33- ["status" , "cres_or_error" , "entry_point" ],
33+ "_KernelCompileResult" , CompileResult ._fields + ("kernel_device_ir_module" ,)
3434)
3535
3636
@@ -96,15 +96,15 @@ def _compile_to_spirv(
9696 )
9797
9898 def compile (self , args , return_type ):
99- kcres = self ._compile_cached (args , return_type )
100- if kcres . status :
99+ status , kcres = self ._compile_cached (args , return_type )
100+ if status :
101101 return kcres
102102
103- raise kcres . cres_or_error
103+ raise kcres
104104
105105 def _compile_cached (
106106 self , args , return_type : types .Type
107- ) -> _KernelCompileResult :
107+ ) -> Tuple [ bool , _KernelCompileResult ] :
108108 """Compiles the kernel function to bitcode and generates a host-callable
109109 wrapper to submit the kernel to a SYCL queue.
110110
@@ -137,34 +137,45 @@ def _compile_cached(
137137 """
138138 key = tuple (args ), return_type
139139 try :
140- return _KernelCompileResult ( False , self ._failed_cache [key ], None )
140+ return False , self ._failed_cache [key ]
141141 except KeyError :
142142 pass
143143
144144 try :
145- kernel_cres : CompileResult = self ._compile_core (args , return_type )
145+ cres : CompileResult = self ._compile_core (args , return_type )
146146
147- kernel_library = kernel_cres .library
148- kernel_fndesc = kernel_cres .fndesc
149- kernel_targetctx = kernel_cres .target_context
150-
151- kernel_module = self ._compile_to_spirv (
152- kernel_library , kernel_fndesc , kernel_targetctx
147+ kernel_device_ir_module = self ._compile_to_spirv (
148+ cres .library , cres .fndesc , cres .target_context
153149 )
154150
151+ kcres_attrs = []
152+
153+ for cres_field in cres ._fields :
154+ cres_attr = getattr (cres , cres_field )
155+ if cres_field == "entry_point" :
156+ if cres_attr is not None :
157+ raise AssertionError (
158+ "Compiled kernel and device_func should be "
159+ "compiled with compile_cfunc option turned off"
160+ )
161+ cres_attr = cres .fndesc .qualname
162+ kcres_attrs .append (cres_attr )
163+
164+ kcres_attrs .append (kernel_device_ir_module )
165+
155166 if config .DUMP_KERNEL_LLVM :
156167 with open (
157- kernel_cres .fndesc .llvm_func_name + ".ll" ,
168+ cres .fndesc .llvm_func_name + ".ll" ,
158169 "w" ,
159170 encoding = "UTF-8" ,
160171 ) as f :
161- f .write (kernel_cres .library .final_module )
172+ f .write (cres .library .final_module )
162173
163174 except errors .TypingError as e :
164175 self ._failed_cache [key ] = e
165- return _KernelCompileResult ( False , e , None )
176+ return False , e
166177
167- return _KernelCompileResult ( True , kernel_cres , kernel_module )
178+ return True , _KernelCompileResult ( * kcres_attrs )
168179
169180
170181class KernelDispatcher (Dispatcher ):
@@ -234,7 +245,14 @@ def typeof_pyval(self, val):
234245
235246 def add_overload (self , cres ):
236247 args = tuple (cres .signature .args )
237- self .overloads [args ] = cres .entry_point
248+ self .overloads [args ] = cres
249+
250+ def get_overload_device_ir (self , sig ):
251+ """
252+ Return the compiled device bitcode for the given signature.
253+ """
254+ args , _ = sigutils .normalize_signature (sig )
255+ return self .overloads [tuple (args )].kernel_device_ir_module
238256
239257 def compile (self , sig ) -> _KernelCompileResult :
240258 disp = self ._get_dispatcher_for_current_target ()
@@ -274,7 +292,7 @@ def cb_llvm(dur):
274292 # Don't recompile if signature already exists
275293 existing = self .overloads .get (tuple (args ))
276294 if existing is not None :
277- return existing
295+ return existing . entry_point
278296
279297 # TODO: Enable caching
280298 # Add code to enable on disk caching of a binary spirv kernel.
@@ -298,7 +316,11 @@ def folded(args, kws):
298316 )[1 ]
299317
300318 raise e .bind_fold_arguments (folded )
301- self .add_overload (kcres .cres_or_error )
319+ self .add_overload (kcres )
320+
321+ kcres .target_context .insert_user_function (
322+ kcres .entry_point , kcres .fndesc , [kcres .library ]
323+ )
302324
303325 # TODO: enable caching of kernel_module
304326 # https://github.com/IntelPython/numba-dpex/issues/1197
@@ -318,5 +340,5 @@ def __call__(self, *args, **kw_args):
318340 raise NotImplementedError
319341
320342
321- _dpex_target = target_registry ["dpex_kernel" ]
343+ _dpex_target = target_registry [DPEX_KERNEL_EXP_TARGET_NAME ]
322344dispatcher_registry [_dpex_target ] = KernelDispatcher
0 commit comments