Skip to content

Commit

Permalink
Temp...
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Oct 21, 2023
1 parent 48066fc commit 93c556e
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 24 deletions.
2 changes: 1 addition & 1 deletion numba_dpex/core/runtime/_dbg_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

/* Debugging facilities - enabled at compile-time */
/* #undef NDEBUG */
#if 0
#if 1
#include <stdio.h>
#define DPEXRT_DEBUG(X) \
{ \
Expand Down
6 changes: 3 additions & 3 deletions numba_dpex/experimental/kernel_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,10 +621,10 @@ def typeof_pyval(self, val):

def add_overload(self, cres, kernel_bitcode):
args = tuple(cres.signature.args)
sig = [a._code for a in args]
# sig = [a._code for a in args]
self.overloads[args] = kernel_bitcode

def compile(self, sig):
def compile(self, sig) -> _KernelCompileResult:
disp = self._get_dispatcher_for_current_target()
if disp is not self:
return disp.compile(sig)
Expand Down Expand Up @@ -689,7 +689,7 @@ def folded(args, kws):
self.add_overload(kcres.cres_or_error, kcres.kernel_bitcode)
# FIXME: enable caching
# self._cache.save_overload(sig, cres)
return kcres.kernel_bitcode
return kcres

def __getitem__(self, args):
"""Square-bracket notation for configuring the global_range and
Expand Down
175 changes: 155 additions & 20 deletions numba_dpex/experimental/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,41 @@

from numba_dpex import config, dpjit
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
from numba_dpex.core.types import DpnpNdArray
from numba_dpex.core.types import DpnpNdArray, NdRangeType, RangeType
from numba_dpex.core.utils import kernel_launcher as kl
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
from numba_dpex.utils import create_null_ptr
from numba_dpex.experimental.kernel_dispatcher import _KernelCompileResult
from numba_dpex.core.exceptions import (
UnreachableError,
)


def _get_queue_ref_val(
targetctx: cpu.CPUContext,
kernel_targetctx,
builder: llvmir.IRBuilder,
kernel_argtys,
kernel_args,
):
"""
Get the sycl queue from the first DpnpNdArray argument. Prior passes
before lowering make sure that compute-follows-data is enforceable
for a specific call to a kernel. As such, at the stage of lowering
the queue from the first DpnpNdArray argument can be extracted.
"""

for arg_num, argty in enumerate(kernel_argtys):
if isinstance(argty, DpnpNdArray):
llvm_val = kernel_args[arg_num]
datamodel = kernel_targetctx.data_model_manager.lookup(argty)
sycl_queue_attr_pos = datamodel.get_field_position("sycl_queue")
ptr_to_queue_ref = builder.extract_value(
llvm_val, sycl_queue_attr_pos
)
break

return ptr_to_queue_ref


def _create_kernel_bundle_from_spirv(
Expand Down Expand Up @@ -58,8 +89,10 @@ def _create_kernel_launcher_body(
codegen_targetctx: cpu.CPUContext,
kernel_targetctx: DpexKernelTargetContext,
builder: llvmir.IRBuilder,
indexer_argty: RangeType | NdRangeType,
kernel_argtys: tuple[types.Type, ...],
kernel_bc: llvmir.Constant,
kernel_func_name: str,
args: [llvmir.Instruction, ...],
):
klbuilder = kl.KernelLaunchIRBuilder(kernel_targetctx, builder)
Expand All @@ -69,6 +102,11 @@ def _create_kernel_launcher_body(
builder, "DPEX-DEBUG: Inside the kernel launcher function\n"
)

kernel_bc_byte_str: llvmir.Constant = codegen_targetctx.insert_const_bytes(
builder.module,
bytes=kernel_bc,
)

num_flattened_kernel_args = _get_num_flattened_kernel_args(
kernel_targetctx=kernel_targetctx, kernel_argtys=kernel_argtys
)
Expand All @@ -78,19 +116,114 @@ def _create_kernel_launcher_body(
args_ty_list = klbuilder.allocate_kernel_arg_ty_array(
num_flattened_kernel_args
)
# args[0] is the kernel fn
# args[1] is the index_space
# args[2:] are the kernel args
kernel_args_ptrs = []
for arg in args[2:]:
ptr = builder.alloca(arg.type)
builder.store(arg, ptr)
kernel_args_ptrs.append(ptr)

# breakpoint()
# Populate the args_list and the args_ty_list LLVM arrays
klbuilder.populate_kernel_args_and_args_ty_arrays(
callargs_ptrs=kernel_args_ptrs,
kernel_argtys=kernel_argtys,
args_list=args_list,
args_ty_list=args_ty_list,
datamodel_mgr=kernel_targetctx.data_model_manager,
)

# callargs = []
if config.DEBUG_KERNEL_LAUNCHER:
cgutils.printf(
builder, "DPEX-DEBUG: Populated kernel args and arg type arrays.\n"
)

# # Populate the args_list and the args_ty_list LLVM arrays
# klbuilder.populate_kernel_args_and_args_ty_arrays(
# callargs_ptrs=callargs_ptrs,
# kernel_argtys=kernel_argtys,
# args_list=args_list,
# args_ty_list=args_ty_list,
# datamodel_mgr=kernel_targetctx.data_model_manager,
# )
qref = _get_queue_ref_val(
targetctx=codegen_targetctx,
kernel_targetctx=kernel_targetctx,
builder=builder,
kernel_argtys=kernel_argtys,
kernel_args=args[2:],
)

if config.DEBUG_KERNEL_LAUNCHER:
cgutils.printf(
builder,
"DPEX-DEBUG: Extracted queue pointer from first dpnp array.\n",
)

kbref = _create_kernel_bundle_from_spirv(
builder=builder,
targetctx=codegen_targetctx,
queue_ref=qref,
kernel_bc=kernel_bc_byte_str,
kernel_bc_size_in_bytes=len(kernel_bc),
)

if config.DEBUG_KERNEL_LAUNCHER:
cgutils.printf(
builder, "DPEX-DEBUG: Generated kernel_bundle from SPIR-V.\n"
)

# Get the pointer to the sycl::kernel object in the sycl::kernel_bundle
kernel_name = codegen_targetctx.insert_const_string(
builder.module, kernel_func_name
)
kref = sycl.dpctl_kernel_bundle_get_kernel(builder, kbref, kernel_name)

# Submit synchronous kernel
# FIXME: Needs to change once we support returning a SyclEvent back to
# caller.
if isinstance(indexer_argty, RangeType):
range_ndim = indexer_argty.ndim
range_arg = args[1]
range_extents = []
datamodel = kernel_targetctx.data_model_manager.lookup(indexer_argty)
for dim_num in range(range_ndim):
dim_pos = datamodel.get_field_position("dim" + str(dim_num))
range_extents.append(builder.extract_value(range_arg, dim_pos))

if config.DEBUG_KERNEL_LAUNCHER:
cgutils.printf(builder, "DPEX-DEBUG: Submit sync range kernel.\n")

klbuilder.submit_sync_kernel(
sycl_kernel_ref=kref,
sycl_queue_ref=qref,
total_kernel_args=num_flattened_kernel_args,
arg_list=args_list,
arg_ty_list=args_ty_list,
global_range=range_extents,
local_range=[],
)
elif isinstance(indexer_argty, NdRangeType):
ndrange_ndim = indexer_argty.ndim
ndrange_arg = args[1]
grange_extents = []
lrange_extents = []
datamodel = kernel_targetctx.data_model_manager.lookup(indexer_argty)
for dim_num in range(ndrange_ndim):
gdim_pos = datamodel.get_field_position("gdim" + str(dim_num))
grange_extents.append(builder.extract_value(ndrange_arg, gdim_pos))
ldim_pos = datamodel.get_field_position("ldim" + str(dim_num))
lrange_extents.append(builder.extract_value(ndrange_arg, ldim_pos))

klbuilder.submit_sync_kernel(
sycl_kernel_ref=kref,
sycl_queue_ref=qref,
total_kernel_args=num_flattened_kernel_args,
arg_list=args_list,
arg_ty_list=args_ty_list,
global_range=grange_extents,
local_range=lrange_extents,
)
else:
raise UnreachableError

# Delete the kernel ref
sycl.dpctl_kernel_delete(builder, kref)
# Delete the kernel bundle pointer
sycl.dpctl_kernel_bundle_delete(builder, kbref)


@intrinsic
Expand All @@ -99,7 +232,7 @@ def launch_trampoline(typingctx, kernel_fn, index_space, a, b, c):
sig = types.void(kernel_fn, index_space, a, b, c)
# signature of the kernel_fn
kernel_sig = types.void(a, b, c)
kernel_bitcode = kernel_fn.dispatcher.compile(kernel_sig)
kcres: _KernelCompileResult = kernel_fn.dispatcher.compile(kernel_sig)
kernel_targetctx = kernel_fn.dispatcher.targetctx

def codegen(cgctx, builder, sig, llargs):
Expand All @@ -111,17 +244,19 @@ def codegen(cgctx, builder, sig, llargs):
# builder.call(<driver function>, kernel_work_details, llargs)
kernel_argtys = kernel_sig.args
_create_kernel_launcher_body(
cgctx,
kernel_targetctx,
builder,
kernel_argtys,
kernel_bitcode,
llargs,
codegen_targetctx=cgctx,
kernel_targetctx=kernel_targetctx,
builder=builder,
indexer_argty=sig.args[1],
kernel_argtys=kernel_argtys,
kernel_bc=kcres.kernel_bitcode,
kernel_func_name=kcres.cres_or_error.fndesc.llvm_func_name,
args=llargs,
)

return sig, codegen


@dpjit
def call_kernel(a, b, c, kernel_fn, index_space):
launch_trampoline(a, b, c, kernel_fn, index_space)
def call_kernel(kernel_fn, index_space, a, b, c):
launch_trampoline(kernel_fn, index_space, a, b, c)

0 comments on commit 93c556e

Please sign in to comment.