From cb8292b55a66635508d83e4e12a5a6a16e3903ad Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 22 Oct 2023 23:18:23 -0500 Subject: [PATCH] Make it possible to call any kernel using call_kernel. --- numba_dpex/experimental/__init__.py | 2 +- numba_dpex/experimental/launcher.py | 84 ++++++++++++++++++----------- 2 files changed, 53 insertions(+), 33 deletions(-) diff --git a/numba_dpex/experimental/__init__.py b/numba_dpex/experimental/__init__.py index 31b9b4c90d..ad96124166 100644 --- a/numba_dpex/experimental/__init__.py +++ b/numba_dpex/experimental/__init__.py @@ -19,4 +19,4 @@ def dpex_dispatcher_const(context, builder, ty, pyval): return context.get_dummy_value() -__all__ = ["kernel", "KernelDispatcher", "dpex_dispatcher_const"] +__all__ = ["kernel", "KernelDispatcher", "call_kernel"] diff --git a/numba_dpex/experimental/launcher.py b/numba_dpex/experimental/launcher.py index f887f8a1b1..25cb1305ae 100644 --- a/numba_dpex/experimental/launcher.py +++ b/numba_dpex/experimental/launcher.py @@ -4,7 +4,7 @@ from llvmlite import ir as llvmir from numba.core import cgutils, cpu, types -from numba.extending import intrinsic +from numba.extending import intrinsic, overload from numba_dpex import config, dpjit from numba_dpex.core.exceptions import UnreachableError @@ -12,18 +12,14 @@ from numba_dpex.core.types import DpnpNdArray, NdRangeType, RangeType from numba_dpex.core.utils import kernel_launcher as kl from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl -from numba_dpex.experimental.kernel_dispatcher import ( - _KernelCompileResult, - _KernelModule, -) +from numba_dpex.experimental.kernel_dispatcher import _KernelModule from numba_dpex.utils import create_null_ptr def _get_queue_ref_val( - targetctx: cpu.CPUContext, - kernel_targetctx, + kernel_targetctx: DpexKernelTargetContext, builder: llvmir.IRBuilder, - kernel_argtys, + kernel_argtys: [types.Type, ...], kernel_args, ): """ @@ -93,7 +89,8 @@ def _create_kernel_launcher_body( indexer_argty: RangeType | NdRangeType, kernel_argtys: tuple[types.Type, ...], kernel_module: _KernelModule, - args: [llvmir.Instruction, ...], + index_space_arg: llvmir.BaseStructType, + kernel_args: [llvmir.Instruction, ...], ): klbuilder = kl.KernelLaunchIRBuilder(kernel_targetctx, builder) @@ -116,11 +113,8 @@ def _create_kernel_launcher_body( args_ty_list = klbuilder.allocate_kernel_arg_ty_array( num_flattened_kernel_args ) - # args[0] is the kernel fn - # args[1] is the index_space - # args[2:] are the kernel args kernel_args_ptrs = [] - for arg in args[2:]: + for arg in kernel_args: ptr = builder.alloca(arg.type) builder.store(arg, ptr) kernel_args_ptrs.append(ptr) @@ -140,11 +134,10 @@ def _create_kernel_launcher_body( ) qref = _get_queue_ref_val( - targetctx=codegen_targetctx, kernel_targetctx=kernel_targetctx, builder=builder, kernel_argtys=kernel_argtys, - kernel_args=args[2:], + kernel_args=kernel_args, ) if config.DEBUG_KERNEL_LAUNCHER: @@ -177,12 +170,13 @@ def _create_kernel_launcher_body( # caller. if isinstance(indexer_argty, RangeType): range_ndim = indexer_argty.ndim - range_arg = args[1] range_extents = [] datamodel = kernel_targetctx.data_model_manager.lookup(indexer_argty) for dim_num in range(range_ndim): dim_pos = datamodel.get_field_position("dim" + str(dim_num)) - range_extents.append(builder.extract_value(range_arg, dim_pos)) + range_extents.append( + builder.extract_value(index_space_arg, dim_pos) + ) if config.DEBUG_KERNEL_LAUNCHER: cgutils.printf(builder, "DPEX-DEBUG: Submit sync range kernel.\n") @@ -206,15 +200,18 @@ def _create_kernel_launcher_body( elif isinstance(indexer_argty, NdRangeType): ndrange_ndim = indexer_argty.ndim - ndrange_arg = args[1] grange_extents = [] lrange_extents = [] datamodel = kernel_targetctx.data_model_manager.lookup(indexer_argty) for dim_num in range(ndrange_ndim): gdim_pos = datamodel.get_field_position("gdim" + str(dim_num)) - grange_extents.append(builder.extract_value(ndrange_arg, gdim_pos)) + grange_extents.append( + builder.extract_value(index_space_arg, gdim_pos) + ) ldim_pos = datamodel.get_field_position("ldim" + str(dim_num)) - lrange_extents.append(builder.extract_value(ndrange_arg, ldim_pos)) + lrange_extents.append( + builder.extract_value(index_space_arg, ldim_pos) + ) eref = klbuilder.submit_sycl_kernel( sycl_kernel_ref=kref, @@ -241,22 +238,20 @@ def _create_kernel_launcher_body( @intrinsic -def launch_trampoline(typingctx, kernel_fn, index_space, a, b, c): +def intrin_launch_trampoline(typingctx, kernel_fn, index_space, kernel_args): + kernel_args_list = [arg for arg in kernel_args] # signature of this intrinsic - sig = types.void(kernel_fn, index_space, a, b, c) + sig = types.void(kernel_fn, index_space, kernel_args) # signature of the kernel_fn - kernel_sig = types.void(a, b, c) + kernel_sig = types.void(*kernel_args_list) kmodule: _KernelModule = kernel_fn.dispatcher.compile(kernel_sig) kernel_targetctx = kernel_fn.dispatcher.targetctx def codegen(cgctx, builder, sig, llargs): - # # do something here to fetch the pointer to the kernel, or bitcode, or - # # llvm IR that you need for doing the actual kernel launch (should be - # # available from the kernel_inst compiled above). - # kernel_work_details = - # # now generate the call to the driver to launch the kernel - # builder.call(, kernel_work_details, llargs) kernel_argtys = kernel_sig.args + kernel_args_unpacked = [] + for pos in range(len(kernel_args)): + kernel_args_unpacked.append(builder.extract_value(llargs[2], pos)) _create_kernel_launcher_body( codegen_targetctx=cgctx, kernel_targetctx=kernel_targetctx, @@ -264,12 +259,37 @@ def codegen(cgctx, builder, sig, llargs): indexer_argty=sig.args[1], kernel_argtys=kernel_argtys, kernel_module=kmodule, - args=llargs, + index_space_arg=llargs[1], + kernel_args=kernel_args_unpacked, ) return sig, codegen +def _launch_trampoline(): + pass + + +@overload(_launch_trampoline) +def _ol_launch_trampoline(kernel_fn, index_space, *kernel_args): + def impl(kernel_fn, index_space, *kernel_args): + intrin_launch_trampoline(kernel_fn, index_space, kernel_args) + + return impl + + @dpjit -def call_kernel(kernel_fn, index_space, a, b, c): - launch_trampoline(kernel_fn, index_space, a, b, c) +def call_kernel(kernel_fn, index_space, *kernel_args): + """Calls a numba_dpex.kernel decorated function from CPython or from another + dpjit function. + + Args: + kernel_fn (numba_dpex.experimental.KernelDispatcher): A + numba_dpex.kernel decorated function that is compiled to a + KernelDispatcher by numba_dpex. + index_space (Range | NdRange): A numba_dpex.Range or numba_dpex.NdRange + type object that specifies the index space for the kernel. + kernel_args : List of objects that are passed to the numba_dpex.kernel + decorated function. + """ + _launch_trampoline(kernel_fn, index_space, *kernel_args)