Skip to content

Commit

Permalink
Make it possible to call any kernel using call_kernel.
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Oct 24, 2023
1 parent 9ac4dfb commit cb8292b
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 33 deletions.
2 changes: 1 addition & 1 deletion numba_dpex/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ def dpex_dispatcher_const(context, builder, ty, pyval):
return context.get_dummy_value()


__all__ = ["kernel", "KernelDispatcher", "dpex_dispatcher_const"]
__all__ = ["kernel", "KernelDispatcher", "call_kernel"]
84 changes: 52 additions & 32 deletions numba_dpex/experimental/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,22 @@

from llvmlite import ir as llvmir
from numba.core import cgutils, cpu, types
from numba.extending import intrinsic
from numba.extending import intrinsic, overload

from numba_dpex import config, dpjit
from numba_dpex.core.exceptions import UnreachableError
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
from numba_dpex.core.types import DpnpNdArray, NdRangeType, RangeType
from numba_dpex.core.utils import kernel_launcher as kl
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
from numba_dpex.experimental.kernel_dispatcher import (
_KernelCompileResult,
_KernelModule,
)
from numba_dpex.experimental.kernel_dispatcher import _KernelModule
from numba_dpex.utils import create_null_ptr


def _get_queue_ref_val(
targetctx: cpu.CPUContext,
kernel_targetctx,
kernel_targetctx: DpexKernelTargetContext,
builder: llvmir.IRBuilder,
kernel_argtys,
kernel_argtys: [types.Type, ...],
kernel_args,
):
"""
Expand Down Expand Up @@ -93,7 +89,8 @@ def _create_kernel_launcher_body(
indexer_argty: RangeType | NdRangeType,
kernel_argtys: tuple[types.Type, ...],
kernel_module: _KernelModule,
args: [llvmir.Instruction, ...],
index_space_arg: llvmir.BaseStructType,
kernel_args: [llvmir.Instruction, ...],
):
klbuilder = kl.KernelLaunchIRBuilder(kernel_targetctx, builder)

Expand All @@ -116,11 +113,8 @@ def _create_kernel_launcher_body(
args_ty_list = klbuilder.allocate_kernel_arg_ty_array(
num_flattened_kernel_args
)
# args[0] is the kernel fn
# args[1] is the index_space
# args[2:] are the kernel args
kernel_args_ptrs = []
for arg in args[2:]:
for arg in kernel_args:
ptr = builder.alloca(arg.type)
builder.store(arg, ptr)
kernel_args_ptrs.append(ptr)
Expand All @@ -140,11 +134,10 @@ def _create_kernel_launcher_body(
)

qref = _get_queue_ref_val(
targetctx=codegen_targetctx,
kernel_targetctx=kernel_targetctx,
builder=builder,
kernel_argtys=kernel_argtys,
kernel_args=args[2:],
kernel_args=kernel_args,
)

if config.DEBUG_KERNEL_LAUNCHER:
Expand Down Expand Up @@ -177,12 +170,13 @@ def _create_kernel_launcher_body(
# caller.
if isinstance(indexer_argty, RangeType):
range_ndim = indexer_argty.ndim
range_arg = args[1]
range_extents = []
datamodel = kernel_targetctx.data_model_manager.lookup(indexer_argty)
for dim_num in range(range_ndim):
dim_pos = datamodel.get_field_position("dim" + str(dim_num))
range_extents.append(builder.extract_value(range_arg, dim_pos))
range_extents.append(
builder.extract_value(index_space_arg, dim_pos)
)

if config.DEBUG_KERNEL_LAUNCHER:
cgutils.printf(builder, "DPEX-DEBUG: Submit sync range kernel.\n")
Expand All @@ -206,15 +200,18 @@ def _create_kernel_launcher_body(

elif isinstance(indexer_argty, NdRangeType):
ndrange_ndim = indexer_argty.ndim
ndrange_arg = args[1]
grange_extents = []
lrange_extents = []
datamodel = kernel_targetctx.data_model_manager.lookup(indexer_argty)
for dim_num in range(ndrange_ndim):
gdim_pos = datamodel.get_field_position("gdim" + str(dim_num))
grange_extents.append(builder.extract_value(ndrange_arg, gdim_pos))
grange_extents.append(
builder.extract_value(index_space_arg, gdim_pos)
)
ldim_pos = datamodel.get_field_position("ldim" + str(dim_num))
lrange_extents.append(builder.extract_value(ndrange_arg, ldim_pos))
lrange_extents.append(
builder.extract_value(index_space_arg, ldim_pos)
)

eref = klbuilder.submit_sycl_kernel(
sycl_kernel_ref=kref,
Expand All @@ -241,35 +238,58 @@ def _create_kernel_launcher_body(


@intrinsic
def launch_trampoline(typingctx, kernel_fn, index_space, a, b, c):
def intrin_launch_trampoline(typingctx, kernel_fn, index_space, kernel_args):
kernel_args_list = [arg for arg in kernel_args]
# signature of this intrinsic
sig = types.void(kernel_fn, index_space, a, b, c)
sig = types.void(kernel_fn, index_space, kernel_args)
# signature of the kernel_fn
kernel_sig = types.void(a, b, c)
kernel_sig = types.void(*kernel_args_list)
kmodule: _KernelModule = kernel_fn.dispatcher.compile(kernel_sig)
kernel_targetctx = kernel_fn.dispatcher.targetctx

def codegen(cgctx, builder, sig, llargs):
# # do something here to fetch the pointer to the kernel, or bitcode, or
# # llvm IR that you need for doing the actual kernel launch (should be
# # available from the kernel_inst compiled above).
# kernel_work_details = <get that thing above>
# # now generate the call to the driver to launch the kernel
# builder.call(<driver function>, kernel_work_details, llargs)
kernel_argtys = kernel_sig.args
kernel_args_unpacked = []
for pos in range(len(kernel_args)):
kernel_args_unpacked.append(builder.extract_value(llargs[2], pos))
_create_kernel_launcher_body(
codegen_targetctx=cgctx,
kernel_targetctx=kernel_targetctx,
builder=builder,
indexer_argty=sig.args[1],
kernel_argtys=kernel_argtys,
kernel_module=kmodule,
args=llargs,
index_space_arg=llargs[1],
kernel_args=kernel_args_unpacked,
)

return sig, codegen


def _launch_trampoline():
pass


@overload(_launch_trampoline)
def _ol_launch_trampoline(kernel_fn, index_space, *kernel_args):
def impl(kernel_fn, index_space, *kernel_args):
intrin_launch_trampoline(kernel_fn, index_space, kernel_args)

return impl


@dpjit
def call_kernel(kernel_fn, index_space, a, b, c):
launch_trampoline(kernel_fn, index_space, a, b, c)
def call_kernel(kernel_fn, index_space, *kernel_args):
"""Calls a numba_dpex.kernel decorated function from CPython or from another
dpjit function.
Args:
kernel_fn (numba_dpex.experimental.KernelDispatcher): A
numba_dpex.kernel decorated function that is compiled to a
KernelDispatcher by numba_dpex.
index_space (Range | NdRange): A numba_dpex.Range or numba_dpex.NdRange
type object that specifies the index space for the kernel.
kernel_args : List of objects that are passed to the numba_dpex.kernel
decorated function.
"""
_launch_trampoline(kernel_fn, index_space, *kernel_args)

0 comments on commit cb8292b

Please sign in to comment.