From 9ac4dfb5c9c53937d297f65bbf8dad376614770d Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 22 Oct 2023 21:48:28 -0500 Subject: [PATCH] Adds an experimental KernelDispacther to numba_dpex. - The numba_dpex.experimental module adds a new dispatcher class for numba_dpex kernels. The new dispatcher is a numba.dispatcher.Dispathcer subclass. - Introduce a new compiler class that is used to compile a numba_dpex.kernel decorated function to spirv and then store the spirv module as the compiled "overload". - Adds an experimental `call_kernel` dpjit function that will be used to submit or launch kernels. The `call_kernel` function generates LLVM IR code for all the functionality currenty done in pure Python in JitKernel.__call__. --- numba_dpex/core/descriptor.py | 2 + numba_dpex/experimental/__init__.py | 22 ++ numba_dpex/experimental/decorators.py | 76 +++++ numba_dpex/experimental/kernel_dispatcher.py | 324 +++++++++++++++++++ numba_dpex/experimental/launcher.py | 275 ++++++++++++++++ numba_dpex/experimental/models.py | 16 + numba_dpex/experimental/types.py | 9 + 7 files changed, 724 insertions(+) create mode 100644 numba_dpex/experimental/__init__.py create mode 100644 numba_dpex/experimental/decorators.py create mode 100644 numba_dpex/experimental/kernel_dispatcher.py create mode 100644 numba_dpex/experimental/launcher.py create mode 100644 numba_dpex/experimental/models.py create mode 100644 numba_dpex/experimental/types.py diff --git a/numba_dpex/core/descriptor.py b/numba_dpex/core/descriptor.py index 6d40686289..406f7115f8 100644 --- a/numba_dpex/core/descriptor.py +++ b/numba_dpex/core/descriptor.py @@ -38,11 +38,13 @@ def _inherit_if_not_set(flags, options, name, default=targetconfig._NotSet): class DpexTargetOptions(CPUTargetOptions): experimental = _option_mapping("experimental") release_gil = _option_mapping("release_gil") + no_compile = _option_mapping("no_compile") def finalize(self, flags, options): super().finalize(flags, options) _inherit_if_not_set(flags, options, "experimental", False) _inherit_if_not_set(flags, options, "release_gil", False) + _inherit_if_not_set(flags, options, "no_compile", True) class DpexKernelTarget(TargetDescriptor): diff --git a/numba_dpex/experimental/__init__.py b/numba_dpex/experimental/__init__.py new file mode 100644 index 0000000000..31b9b4c90d --- /dev/null +++ b/numba_dpex/experimental/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +from numba.core.imputils import Registry + +from .decorators import kernel +from .kernel_dispatcher import KernelDispatcher +from .launcher import call_kernel +from .models import * +from .types import KernelDispatcherType + +registry = Registry() +lower_constant = registry.lower_constant + + +@lower_constant(KernelDispatcherType) +def dpex_dispatcher_const(context, builder, ty, pyval): + return context.get_dummy_value() + + +__all__ = ["kernel", "KernelDispatcher", "dpex_dispatcher_const"] diff --git a/numba_dpex/experimental/decorators.py b/numba_dpex/experimental/decorators.py new file mode 100644 index 0000000000..914641e62e --- /dev/null +++ b/numba_dpex/experimental/decorators.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import inspect + +from numba.core import sigutils + +from .kernel_dispatcher import KernelDispatcher + + +def kernel(func_or_sig=None, debug=False, cache=False, **options): + """A decorator to define a kernel function. + + A kernel function is conceptually equivalent to a SYCL kernel function, and + gets compiled into either an OpenCL or a LevelZero SPIR-V binary kernel. + A kernel decorated Python function has the following restrictions: + + * The function can not return any value. + * All array arguments passed to a kernel should adhere to compute + follows data programming model. + """ + # FIXME: The options need to be evaluated and checked here like it is + # done in numba.core.decorators.jit + + def _kernel_dispatcher(pyfunc, sigs=None): + return KernelDispatcher( + pyfunc=pyfunc, + debug_flags=debug, + enable_cache=cache, + specialization_sigs=sigs, + targetoptions=options, + ) + + if func_or_sig is None: + return _kernel_dispatcher + elif isinstance(func_or_sig, str): + raise NotImplementedError( + "Specifying signatures as string is not yet supported by numba-dpex" + ) + elif isinstance(func_or_sig, list) or sigutils.is_signature(func_or_sig): + # String signatures are not supported as passing usm_ndarray type as + # a string is not possible. Numba's sigutils relies on the type being + # available in Numba's `types.__dict__` and dpex types are not + # registered there yet. + if isinstance(func_or_sig, list): + for sig in func_or_sig: + if isinstance(sig, str): + raise NotImplementedError( + "Specifying signatures as string is not yet supported " + "by numba-dpex" + ) + # Specialized signatures can either be a single signature or a list. + # In case only one signature is provided convert it to a list + if not isinstance(func_or_sig, list): + func_or_sig = [func_or_sig] + + def _specialized_kernel_dispatcher(pyfunc): + return KernelDispatcher( + pyfunc=pyfunc, + debug_flags=debug, + enable_cache=cache, + specialization_sigs=func_or_sig, + ) + + return _specialized_kernel_dispatcher + else: + func = func_or_sig + if not inspect.isfunction(func): + raise ValueError( + "Argument passed to the kernel decorator is neither a " + "function object, nor a signature. If you are trying to " + "specialize the kernel that takes a single argument, specify " + "the return type as void explicitly." + ) + return _kernel_dispatcher(func) diff --git a/numba_dpex/experimental/kernel_dispatcher.py b/numba_dpex/experimental/kernel_dispatcher.py new file mode 100644 index 0000000000..821965d4fe --- /dev/null +++ b/numba_dpex/experimental/kernel_dispatcher.py @@ -0,0 +1,324 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import functools +from collections import Counter, OrderedDict, namedtuple +from contextlib import ExitStack + +import numba.core.event as ev +from numba.core import errors, sigutils, types, utils +from numba.core.caching import NullCache +from numba.core.compiler import CompileResult +from numba.core.compiler_lock import global_compiler_lock +from numba.core.dispatcher import Dispatcher, _DispatcherBase, _FunctionCompiler +from numba.core.typing.typeof import Purpose, typeof + +from numba_dpex import config, spirv_generator +from numba_dpex.core.descriptor import dpex_kernel_target +from numba_dpex.core.exceptions import ( + InvalidKernelLaunchArgsError, + UnsupportedKernelArgumentError, +) +from numba_dpex.core.kernel_interface.indexers import NdRange, Range +from numba_dpex.core.pipelines import kernel_compiler +from numba_dpex.core.types import DpnpNdArray + +_KernelLauncherLowerResult = namedtuple( + "_KernelLauncherLowerResult", + ["sig", "fndesc", "library", "call_helper"], +) + +_KernelModule = namedtuple("_KernelModule", ["kernel_name", "kernel_bitcode"]) + +_KernelCompileResult = namedtuple( + "_KernelCompileResult", + ["status", "cres_or_error", "kernel_module"], +) + + +class _KernelCompiler(_FunctionCompiler): + def _compile_to_spirv( + self, kernel_library, kernel_fndesc, kernel_targetctx + ): + kernel_func = kernel_library.get_function(kernel_fndesc.llvm_func_name) + + # Create a spir_kernel wrapper function + kernel_fn = kernel_targetctx.prepare_spir_kernel( + kernel_func, kernel_fndesc.argtypes + ) + + # makes sure that the spir_func is completely inlined into the + # spir_kernel wrapper + kernel_library._optimize_final_module() + # Compiled the LLVM IR to SPIR-V + kernel_spirv_module = spirv_generator.llvm_to_spirv( + kernel_targetctx, + kernel_library._final_module, + kernel_library._final_module.as_bitcode(), + ) + return _KernelModule( + kernel_name=kernel_fn.name, kernel_bitcode=kernel_spirv_module + ) + + def compile(self, args, return_type): + kcres = self._compile_cached(args, return_type) + if kcres.status: + return kcres + else: + raise kcres.cres_or_error + + def _compile_cached( + self, kernel_args, return_type: types.Type + ) -> _KernelCompileResult: + """Compiles the kernel function to bitcode and generates a host-callable + wrapper to submit the kernel to a SYCL queue. + + The LLVM IR generated for the kernel function is available in the + CompileResult objected returned by + numba_dpex.core.pipeline.kernel_compiler.KernelCompiler. + + Once the kernel decorated function is compiled down to LLVM IR, the + following steps are performed: + + a) compile the IR into SPIR-V kernel + b) generate a host callable wrapper function that will create a + sycl::kernel_bundle from the SPIR-V and then submits the + kernel_bundle to a sycl::queue + c) create a cpython_wrapper_function for the host callable wrapper + function. + d) create a cfunc_wrapper_function to make the host callable wrapper + function callable inside another JIT-compiled function. + + Args: + args (tuple(types.Type)): A tuple of numba.core.Type instances each + representing the numba-inferred type of a kernel argument. + + return_type (types.Type): The numba-inferred type of the returned + value from the kernel. Should always be types.NoneType. + + Returns: + CompileResult: A CompileResult object storing the LLVM library for + the host-callable wrapper function. + """ + key = tuple(kernel_args), return_type + try: + return _KernelCompileResult(False, self._failed_cache[key], None) + except KeyError: + pass + + try: + kernel_cres: CompileResult = self._compile_core( + kernel_args, return_type + ) + + kernel_library = kernel_cres.library + kernel_fndesc = kernel_cres.fndesc + kernel_targetctx = kernel_cres.target_context + + kernel_module = self._compile_to_spirv( + kernel_library, kernel_fndesc, kernel_targetctx + ) + + if config.DUMP_KERNEL_LLVM: + with open( + kernel_cres.fndesc.llvm_func_name + ".ll", + "w", + ) as f: + f.write(kernel_cres.library._final_module.__str__()) + + except errors.TypingError as e: + self._failed_cache[key] = e + return _KernelCompileResult(False, e, None) + else: + return _KernelCompileResult(True, kernel_cres, kernel_module) + + +class KernelDispatcher(Dispatcher): + targetdescr = dpex_kernel_target + _fold_args = False + + Dispatcher._impl_kinds["kernel"] = _KernelCompiler + + def __init__( + self, + pyfunc, + debug_flags=None, + compile_flags=None, + specialization_sigs=None, + enable_cache=True, + locals={}, + targetoptions={}, + impl_kind="kernel", + pipeline_class=kernel_compiler.KernelCompiler, + ): + targetoptions["nopython"] = True + targetoptions["experimental"] = True + + self._kernel_name = pyfunc.__name__ + self._range = None + self._ndrange = None + + self.typingctx = self.targetdescr.typing_context + self.targetctx = self.targetdescr.target_context + + pysig = utils.pysignature(pyfunc) + arg_count = len(pysig.parameters) + + self.overloads = OrderedDict() + + can_fallback = not targetoptions.get("nopython", False) + + _DispatcherBase.__init__( + self, + arg_count, + pyfunc, + pysig, + can_fallback, + exact_match_required=False, + ) + # XXX: What does this function do exactly? + functools.update_wrapper(self, pyfunc) + + self.targetoptions = targetoptions + self.locals = locals + self._cache = NullCache() + compiler_class = self._impl_kinds[impl_kind] + self._impl_kind = impl_kind + self._compiler = compiler_class( + pyfunc, self.targetdescr, targetoptions, locals, pipeline_class + ) + self._cache_hits = Counter() + self._cache_misses = Counter() + + self._type = types.Dispatcher(self) + self.typingctx.insert_global(self, self._type) + + # Remember target restriction + self._required_target_backend = targetoptions.get("target_backend") + + def typeof_pyval(self, val): + """ + Resolve the Numba type of Python value *val*. + This is called from numba._dispatcher as a fallback if the native code + cannot decide the type. + """ + # Not going through the resolve_argument_type() indirection + # can save a couple µs. + try: + tp = typeof(val, Purpose.argument) + if isinstance(tp, types.Array) and not isinstance(tp, DpnpNdArray): + raise UnsupportedKernelArgumentError( + type=str(type(val)), value=val + ) + except ValueError: + tp = types.pyobject + else: + if tp is None: + tp = types.pyobject + self._types_active_call.append(tp) + return tp + + def add_overload(self, cres, kernel_module): + args = tuple(cres.signature.args) + self.overloads[args] = kernel_module + + def compile(self, sig) -> _KernelCompileResult: + disp = self._get_dispatcher_for_current_target() + if disp is not self: + return disp.compile(sig) + + with ExitStack() as scope: + cres = None + + def cb_compiler(dur): + if cres is not None: + self._callback_add_compiler_timer(dur, cres) + + def cb_llvm(dur): + if cres is not None: + self._callback_add_llvm_timer(dur, cres) + + scope.enter_context( + ev.install_timer("numba:compiler_lock", cb_compiler) + ) + scope.enter_context(ev.install_timer("numba:llvm_lock", cb_llvm)) + scope.enter_context(global_compiler_lock) + + if not self._can_compile: + raise RuntimeError("compilation disabled") + # Use counter to track recursion compilation depth + with self._compiling_counter: + args, return_type = sigutils.normalize_signature(sig) + # Don't recompile if signature already exists + existing = self.overloads.get(tuple(args)) + if existing is not None: + return existing + + # FIXME: Enable caching + # Add code to enable on disk caching of a binary spirv kernel + self._cache_misses[sig] += 1 + ev_details = dict( + dispatcher=self, + args=args, + return_type=return_type, + ) + with ev.trigger_event("numba_dpex:compile", data=ev_details): + try: + kcres: _KernelCompileResult = self._compiler.compile( + args, return_type + ) + except errors.ForceLiteralArg as e: + + def folded(args, kws): + return self._compiler.fold_argument_types( + args, kws + )[1] + + raise e.bind_fold_arguments(folded) + self.add_overload(kcres.cres_or_error, kcres.kernel_module) + + # FIXME: enable caching + + return kcres.kernel_module + + def __getitem__(self, args): + """Square-bracket notation for configuring the global_range and + local_range settings when launching a kernel on a SYCL queue. + + When a Python function decorated with the @kernel decorator, + is invoked it creates a KernelLauncher object. Calling the + KernelLauncher objects ``__getitem__`` function inturn clones the object + and sets the ``global_range`` and optionally the ``local_range`` + attributes with the arguments passed to ``__getitem__``. + + Args: + args (tuple): A tuple of tuples that specify the global and + optionally the local range for the kernel execution. If the + argument is a two-tuple of tuple, then it is assumed that both + global and local range options are specified. The first entry is + considered to be the global range and the second the local range. + + If only a single tuple value is provided, then the kernel is + launched with only a global range and the local range configuration + is decided by the SYCL runtime. + + Returns: + KernelLauncher: A clone of the KernelLauncher object, but with the + global_range and local_range attributes initialized. + """ + + if isinstance(args, Range): + self._range = args + elif isinstance(args, NdRange): + self._ndrange = args + else: + # FIXME: Improve error message + raise InvalidKernelLaunchArgsError(kernel_name=self._kernel_name) + + return self + + def __call__(self, *args, **kw_args): + """Functor to launch a kernel.""" + + raise NotImplementedError diff --git a/numba_dpex/experimental/launcher.py b/numba_dpex/experimental/launcher.py new file mode 100644 index 0000000000..f887f8a1b1 --- /dev/null +++ b/numba_dpex/experimental/launcher.py @@ -0,0 +1,275 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +from llvmlite import ir as llvmir +from numba.core import cgutils, cpu, types +from numba.extending import intrinsic + +from numba_dpex import config, dpjit +from numba_dpex.core.exceptions import UnreachableError +from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext +from numba_dpex.core.types import DpnpNdArray, NdRangeType, RangeType +from numba_dpex.core.utils import kernel_launcher as kl +from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl +from numba_dpex.experimental.kernel_dispatcher import ( + _KernelCompileResult, + _KernelModule, +) +from numba_dpex.utils import create_null_ptr + + +def _get_queue_ref_val( + targetctx: cpu.CPUContext, + kernel_targetctx, + builder: llvmir.IRBuilder, + kernel_argtys, + kernel_args, +): + """ + Get the sycl queue from the first DpnpNdArray argument. Prior passes + before lowering make sure that compute-follows-data is enforceable + for a specific call to a kernel. As such, at the stage of lowering + the queue from the first DpnpNdArray argument can be extracted. + """ + + for arg_num, argty in enumerate(kernel_argtys): + if isinstance(argty, DpnpNdArray): + llvm_val = kernel_args[arg_num] + datamodel = kernel_targetctx.data_model_manager.lookup(argty) + sycl_queue_attr_pos = datamodel.get_field_position("sycl_queue") + ptr_to_queue_ref = builder.extract_value( + llvm_val, sycl_queue_attr_pos + ) + break + + return ptr_to_queue_ref + + +def _create_kernel_bundle_from_spirv( + builder: llvmir.IRBuilder, + targetctx: cpu.CPUContext, + queue_ref: llvmir.PointerType, + kernel_bc: llvmir.Constant, + kernel_bc_size_in_bytes: int, +): + dref = sycl.dpctl_queue_get_device(builder, queue_ref) + cref = sycl.dpctl_queue_get_context(builder, queue_ref) + args = [ + cref, + dref, + kernel_bc, + llvmir.Constant(llvmir.IntType(64), kernel_bc_size_in_bytes), + builder.load(create_null_ptr(builder, targetctx)), + ] + kbref = sycl.dpctl_kernel_bundle_create_from_spirv(builder, *args) + sycl.dpctl_context_delete(builder, cref) + sycl.dpctl_device_delete(builder, dref) + + return kbref + + +def _get_num_flattened_kernel_args( + kernel_targetctx: DpexKernelTargetContext, + kernel_argtys: tuple[types.Type, ...], +): + num_flattened_kernel_args = 0 + for arg_type in kernel_argtys: + if isinstance(arg_type, DpnpNdArray): + datamodel = kernel_targetctx.data_model_manager.lookup(arg_type) + num_flattened_kernel_args += datamodel.flattened_field_count + elif arg_type == types.complex64 or arg_type == types.complex128: + num_flattened_kernel_args += 2 + else: + num_flattened_kernel_args += 1 + + return num_flattened_kernel_args + + +def _create_kernel_launcher_body( + codegen_targetctx: cpu.CPUContext, + kernel_targetctx: DpexKernelTargetContext, + builder: llvmir.IRBuilder, + indexer_argty: RangeType | NdRangeType, + kernel_argtys: tuple[types.Type, ...], + kernel_module: _KernelModule, + args: [llvmir.Instruction, ...], +): + klbuilder = kl.KernelLaunchIRBuilder(kernel_targetctx, builder) + + if config.DEBUG_KERNEL_LAUNCHER: + cgutils.printf( + builder, "DPEX-DEBUG: Inside the kernel launcher function\n" + ) + + kernel_bc_byte_str: llvmir.Constant = codegen_targetctx.insert_const_bytes( + builder.module, + bytes=kernel_module.kernel_bitcode, + ) + + num_flattened_kernel_args = _get_num_flattened_kernel_args( + kernel_targetctx=kernel_targetctx, kernel_argtys=kernel_argtys + ) + + # Create LLVM values for the kernel args list and kernel arg types list + args_list = klbuilder.allocate_kernel_arg_array(num_flattened_kernel_args) + args_ty_list = klbuilder.allocate_kernel_arg_ty_array( + num_flattened_kernel_args + ) + # args[0] is the kernel fn + # args[1] is the index_space + # args[2:] are the kernel args + kernel_args_ptrs = [] + for arg in args[2:]: + ptr = builder.alloca(arg.type) + builder.store(arg, ptr) + kernel_args_ptrs.append(ptr) + + # Populate the args_list and the args_ty_list LLVM arrays + klbuilder.populate_kernel_args_and_args_ty_arrays( + callargs_ptrs=kernel_args_ptrs, + kernel_argtys=kernel_argtys, + args_list=args_list, + args_ty_list=args_ty_list, + datamodel_mgr=kernel_targetctx.data_model_manager, + ) + + if config.DEBUG_KERNEL_LAUNCHER: + cgutils.printf( + builder, "DPEX-DEBUG: Populated kernel args and arg type arrays.\n" + ) + + qref = _get_queue_ref_val( + targetctx=codegen_targetctx, + kernel_targetctx=kernel_targetctx, + builder=builder, + kernel_argtys=kernel_argtys, + kernel_args=args[2:], + ) + + if config.DEBUG_KERNEL_LAUNCHER: + cgutils.printf( + builder, + "DPEX-DEBUG: Extracted queue pointer from first dpnp array.\n", + ) + + kbref = _create_kernel_bundle_from_spirv( + builder=builder, + targetctx=codegen_targetctx, + queue_ref=qref, + kernel_bc=kernel_bc_byte_str, + kernel_bc_size_in_bytes=len(kernel_module.kernel_bitcode), + ) + + if config.DEBUG_KERNEL_LAUNCHER: + cgutils.printf( + builder, "DPEX-DEBUG: Generated kernel_bundle from SPIR-V.\n" + ) + + # Get the pointer to the sycl::kernel object in the sycl::kernel_bundle + kernel_name = codegen_targetctx.insert_const_string( + builder.module, kernel_module.kernel_name + ) + kref = sycl.dpctl_kernel_bundle_get_kernel(builder, kbref, kernel_name) + + # Submit synchronous kernel + # FIXME: Needs to change once we support returning a SyclEvent back to + # caller. + if isinstance(indexer_argty, RangeType): + range_ndim = indexer_argty.ndim + range_arg = args[1] + range_extents = [] + datamodel = kernel_targetctx.data_model_manager.lookup(indexer_argty) + for dim_num in range(range_ndim): + dim_pos = datamodel.get_field_position("dim" + str(dim_num)) + range_extents.append(builder.extract_value(range_arg, dim_pos)) + + if config.DEBUG_KERNEL_LAUNCHER: + cgutils.printf(builder, "DPEX-DEBUG: Submit sync range kernel.\n") + + eref = klbuilder.submit_sycl_kernel( + sycl_kernel_ref=kref, + sycl_queue_ref=qref, + total_kernel_args=num_flattened_kernel_args, + arg_list=args_list, + arg_ty_list=args_ty_list, + global_range=range_extents, + local_range=[], + wait_before_return=False, + ) + + if config.DEBUG_KERNEL_LAUNCHER: + cgutils.printf(builder, "DPEX-DEBUG: Wait on event.\n") + + sycl.dpctl_event_wait(builder, eref) + sycl.dpctl_event_delete(builder, eref) + + elif isinstance(indexer_argty, NdRangeType): + ndrange_ndim = indexer_argty.ndim + ndrange_arg = args[1] + grange_extents = [] + lrange_extents = [] + datamodel = kernel_targetctx.data_model_manager.lookup(indexer_argty) + for dim_num in range(ndrange_ndim): + gdim_pos = datamodel.get_field_position("gdim" + str(dim_num)) + grange_extents.append(builder.extract_value(ndrange_arg, gdim_pos)) + ldim_pos = datamodel.get_field_position("ldim" + str(dim_num)) + lrange_extents.append(builder.extract_value(ndrange_arg, ldim_pos)) + + eref = klbuilder.submit_sycl_kernel( + sycl_kernel_ref=kref, + sycl_queue_ref=qref, + total_kernel_args=num_flattened_kernel_args, + arg_list=args_list, + arg_ty_list=args_ty_list, + global_range=grange_extents, + local_range=lrange_extents, + wait_before_return=False, + ) + if config.DEBUG_KERNEL_LAUNCHER: + cgutils.printf(builder, "DPEX-DEBUG: Wait on event.\n") + + sycl.dpctl_event_wait(builder, eref) + sycl.dpctl_event_delete(builder, eref) + else: + raise UnreachableError + + # Delete the kernel ref + sycl.dpctl_kernel_delete(builder, kref) + # Delete the kernel bundle pointer + sycl.dpctl_kernel_bundle_delete(builder, kbref) + + +@intrinsic +def launch_trampoline(typingctx, kernel_fn, index_space, a, b, c): + # signature of this intrinsic + sig = types.void(kernel_fn, index_space, a, b, c) + # signature of the kernel_fn + kernel_sig = types.void(a, b, c) + kmodule: _KernelModule = kernel_fn.dispatcher.compile(kernel_sig) + kernel_targetctx = kernel_fn.dispatcher.targetctx + + def codegen(cgctx, builder, sig, llargs): + # # do something here to fetch the pointer to the kernel, or bitcode, or + # # llvm IR that you need for doing the actual kernel launch (should be + # # available from the kernel_inst compiled above). + # kernel_work_details = + # # now generate the call to the driver to launch the kernel + # builder.call(, kernel_work_details, llargs) + kernel_argtys = kernel_sig.args + _create_kernel_launcher_body( + codegen_targetctx=cgctx, + kernel_targetctx=kernel_targetctx, + builder=builder, + indexer_argty=sig.args[1], + kernel_argtys=kernel_argtys, + kernel_module=kmodule, + args=llargs, + ) + + return sig, codegen + + +@dpjit +def call_kernel(kernel_fn, index_space, a, b, c): + launch_trampoline(kernel_fn, index_space, a, b, c) diff --git a/numba_dpex/experimental/models.py b/numba_dpex/experimental/models.py new file mode 100644 index 0000000000..8a541cde5d --- /dev/null +++ b/numba_dpex/experimental/models.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +from numba.core.datamodel import models +from numba.core.extending import register_model + +from numba_dpex.core.datamodel.models import dpex_data_model_manager as dmm + +from .types import KernelDispatcherType + +# Register the types and datamodel in the DpexKernelTargetContext +dmm.register(KernelDispatcherType, models.OpaqueModel) + +# Register the types and datamodel in the DpexTargetContext +register_model(KernelDispatcherType)(models.OpaqueModel) diff --git a/numba_dpex/experimental/types.py b/numba_dpex/experimental/types.py new file mode 100644 index 0000000000..04185a0381 --- /dev/null +++ b/numba_dpex/experimental/types.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +from numba.core import types + + +class KernelDispatcherType(types.Dispatcher): + """The type of KernelDispatcher dispatchers"""