Skip to content

Commit

Permalink
wip...
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Oct 20, 2023
1 parent 2faa610 commit 3967c31
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 233 deletions.
234 changes: 1 addition & 233 deletions driver.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import dpctl
import dpnp
from llvmlite import ir as llvmir
from numba.core import cgutils, cpu, types
from numba.extending import intrinsic

import numba_dpex.dpctl_iface.libsyclinterface_bindings as sycl
import numba_dpex.experimental as exp_dpex
from numba_dpex import NdRange, Range, config, dpjit
from numba_dpex.core import DpnpNdArray
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
from numba_dpex.utils import create_null_ptr


@exp_dpex.kernel(
Expand All @@ -32,232 +25,7 @@ def add(a, b, c):

r = Range(1)


def _create_kernel_bundle_from_spirv(
    builder: llvmir.IRBuilder,
    targetctx: cpu.CPUContext,
    queue_ref: llvmir.PointerType,
    kernel_bc: llvmir.Constant,
    kernel_bc_size_in_bytes: int,
):
    """Emit the calls that create a kernel bundle from a SPIR-V constant.

    The device and context refs are fetched from ``queue_ref``, passed to
    ``sycl.dpctl_kernel_bundle_create_from_spirv`` together with the kernel
    byte code and its size, and then released again.

    Args:
        builder: IR builder used to emit every call.
        targetctx: Context handed to ``create_null_ptr``.
        queue_ref: LLVM value for the queue the device/context are read from.
        kernel_bc: LLVM constant holding the kernel byte code.
        kernel_bc_size_in_bytes: Size of ``kernel_bc``; emitted as an i64
            constant.

    Returns:
        The value produced by ``sycl.dpctl_kernel_bundle_create_from_spirv``.
    """
    dref = sycl.dpctl_queue_get_device(builder, queue_ref)
    cref = sycl.dpctl_queue_get_context(builder, queue_ref)
    args = [
        cref,
        dref,
        kernel_bc,
        llvmir.Constant(llvmir.IntType(64), kernel_bc_size_in_bytes),
        # A loaded null pointer is passed as the last argument
        # (presumably "no extra options" -- confirm against the binding).
        builder.load(create_null_ptr(builder, targetctx)),
    ]
    kbref = sycl.dpctl_kernel_bundle_create_from_spirv(builder, *args)

    # Release both temporaries. The device delete must receive the device
    # ref; the previous code passed ``cref`` here, deleting the context
    # twice and never releasing the device ref.
    sycl.dpctl_context_delete(builder, cref)
    sycl.dpctl_device_delete(builder, dref)

    return kbref


def _get_num_flattened_kernel_args(
    kernel_targetctx: DpexKernelTargetContext,
    kernel_argtys: tuple[types.Type, ...],
):
    """Return how many flattened arguments the kernel argument types need.

    A ``DpnpNdArray`` contributes the ``flattened_field_count`` of its data
    model, ``complex64``/``complex128`` contribute two, and every other type
    contributes one.

    Args:
        kernel_targetctx: Context whose ``data_model_manager`` is used to
            look up array data models.
        kernel_argtys: Types of the kernel arguments.

    Returns:
        The summed per-argument counts (``0`` for an empty tuple).
    """

    def _width(argty):
        # Arrays are expanded into the fields of their data model.
        if isinstance(argty, DpnpNdArray):
            model = kernel_targetctx.data_model_manager.lookup(argty)
            return model.flattened_field_count
        # Complex values count as two arguments.
        if argty == types.complex64 or argty == types.complex128:
            return 2
        return 1

    return sum(_width(argty) for argty in kernel_argtys)


def _create_kernel_launcher_body(
    kernel_targetctx: DpexKernelTargetContext,
    targetctx: cpu.CPUContext,
    builder: llvmir.IRBuilder,
    klbuilder: "kl.KernelLaunchIRBuilder",
    kernel_argtys: tuple[types.Type, ...],
    fndesc: "funcdesc.FunctionDescriptor",
    kernel_bc: llvmir.Constant,
    kernel_func_name: str,
    kernel_bc_size_in_bytes: int,
    kernel_launcher_fn: llvmir.Function,
):
    """Emit the LLVM IR body of the kernel launcher wrapper function.

    The emitted code packs the kernel arguments into the argument and
    argument-type arrays, reads the queue from the first ``DpnpNdArray``
    argument, creates a kernel bundle from ``kernel_bc``, submits the kernel
    synchronously over the range described by the indexer argument, releases
    the kernel and kernel bundle, and returns an i32 zero.

    Args:
        kernel_targetctx: Kernel target context; supplies the data model
            manager used for arrays and for the indexer.
        targetctx: Context used for the arg packer, constants and the kernel
            name string.
        builder: IR builder positioned inside ``kernel_launcher_fn``.
        klbuilder: Helper that allocates/populates the argument arrays and
            submits the kernel.
        kernel_argtys: Types of the kernel arguments (indexer excluded).
        fndesc: Function descriptor of the wrapper; ``argtypes[0]`` is the
            indexer type.
        kernel_bc: LLVM constant holding the kernel byte code.
        kernel_func_name: Name of the kernel looked up in the bundle.
        kernel_bc_size_in_bytes: Size of ``kernel_bc``.
        kernel_launcher_fn: The wrapper function whose arguments are read.

    Raises:
        UnreachableError: If the indexer is neither a ``RangeType`` nor an
            ``NdRangeType``.
    """
    # These names are not among the module-level imports visible for this
    # file, so they are imported locally to keep the function self-contained.
    from numba_dpex.core.exceptions import UnreachableError
    from numba_dpex.core.types import NdRangeType, RangeType

    zero = llvmir.Constant(llvmir.IntType(32), 0)

    if config.DEBUG_KERNEL_LAUNCHER:
        cgutils.printf(
            builder, "DPEX-DEBUG: Inside the kernel launcher function"
        )

    # The number of arguments actually passed to the kernel is larger than
    # len(kernel_argtys) because arrays get flattened. Use the shared helper
    # rather than repeating the counting loop here.
    num_flattened_kernel_args = _get_num_flattened_kernel_args(
        kernel_targetctx=kernel_targetctx, kernel_argtys=kernel_argtys
    )

    # Create LLVM values for the kernel args list and kernel arg types list
    args_list = klbuilder.allocate_kernel_arg_array(num_flattened_kernel_args)
    args_ty_list = klbuilder.allocate_kernel_arg_ty_array(
        num_flattened_kernel_args
    )

    kernel_arginfo = targetctx.get_arg_packer(kernel_argtys)

    # The cpu_target_context.call_conv.get_function_type sets the argument
    # list of the wrapper function as:
    # [resptr, ir.PointerType(excinfo_ptr_t)] + argtypes. That is, a void*
    # for the return value pointer, a pointer to Numba's exec_info struct,
    # and then the actual function arguments. For this reason, to get the
    # actual arguments we take the sub-list kernel_launcher_fn.args[2:].
    wrapper_args = list(kernel_launcher_fn.args)[2:]

    # The first argument of the kernel launcher wrapper is the indexer
    # object that is either a RangeType or an NdRangeType. Drop all the
    # flattened members of the indexer to get the arguments of the
    # underlying kernel.
    indexer_argty = fndesc.argtypes[0]
    indexer_argmodel = kernel_targetctx.data_model_manager.lookup(indexer_argty)
    num_indexer_flattened_members = indexer_argmodel.flattened_field_count

    kernel_args = wrapper_args[num_indexer_flattened_members:]

    # Spill every unpacked argument to a stack slot so that a pointer to it
    # can be handed to the argument-array population step.
    callargs = kernel_arginfo.from_arguments(builder, kernel_args)
    callargs_ptrs = []
    for arg in callargs:
        ptr = builder.alloca(arg.type)
        builder.store(arg, ptr)
        callargs_ptrs.append(ptr)

    # Populate the args_list and the args_ty_list LLVM arrays
    klbuilder.populate_kernel_args_and_args_ty_arrays(
        callargs_ptrs=callargs_ptrs,
        kernel_argtys=kernel_argtys,
        args_list=args_list,
        args_ty_list=args_ty_list,
        datamodel_mgr=kernel_targetctx.data_model_manager,
    )

    # Get the sycl queue from the first DpnpNdArray argument. Prior passes
    # before lowering make sure that compute-follows-data is enforceable
    # for a specific call to a kernel. As such, at the stage of lowering
    # the queue from the first DpnpNdArray argument can be extracted.
    # NOTE(review): if no argument is a DpnpNdArray, ptr_to_queue_ref stays
    # None and the load below fails -- confirm callers always pass an array.
    ptr_to_queue_ref = None
    for arg_num, argty in enumerate(kernel_argtys):
        if isinstance(argty, DpnpNdArray):
            llvm_val = callargs_ptrs[arg_num]
            datamodel = kernel_targetctx.data_model_manager.lookup(argty)
            sycl_queue_attr_pos = datamodel.get_field_position("sycl_queue")
            ptr_to_queue_ref = builder.gep(
                llvm_val,
                [
                    targetctx.get_constant(types.int32, 0),
                    targetctx.get_constant(types.int32, sycl_queue_attr_pos),
                ],
            )
            break

    qref = builder.load(ptr_to_queue_ref)

    # Generate kernel bundle by calling libsyclinterface. This is a
    # module-level helper; the previous ``self.`` prefix was a NameError in
    # this free function.
    kbref = _create_kernel_bundle_from_spirv(
        builder=builder,
        targetctx=targetctx,
        queue_ref=qref,
        kernel_bc=kernel_bc,
        kernel_bc_size_in_bytes=kernel_bc_size_in_bytes,
    )
    # Get the pointer to the sycl::kernel object in the sycl::kernel_bundle
    kernel_name = targetctx.insert_const_string(
        builder.module, kernel_func_name
    )
    kref = sycl.dpctl_kernel_bundle_get_kernel(builder, kbref, kernel_name)

    # Submit synchronous kernel
    # FIXME: Needs to change once we support returning a SyclEvent back to
    # caller.
    if isinstance(indexer_argty, RangeType):
        range_dim = indexer_argty.ndim
        # Refer: RangeModel definition. The first attribute is `ndim`
        # followed by the actual range extents
        range_extents = wrapper_args[1 : range_dim + 1]  # noqa E203

        klbuilder.submit_sync_kernel(
            sycl_kernel_ref=kref,
            sycl_queue_ref=qref,
            total_kernel_args=num_flattened_kernel_args,
            arg_list=args_list,
            arg_ty_list=args_ty_list,
            global_range=range_extents,
            local_range=[],
        )
    elif isinstance(indexer_argty, NdRangeType):
        range_dim = indexer_argty.ndim
        # Refer: NdRangeModel definition. The first attribute is `ndim`
        # followed by the actual range extents. The [1:4] are the global
        # range extents and the attributes [4:7] are the local range extents
        grange_extents = wrapper_args[1 : range_dim + 1]  # noqa E203
        lrange_extents = wrapper_args[4 : 4 + range_dim]  # noqa E203

        klbuilder.submit_sync_kernel(
            sycl_kernel_ref=kref,
            sycl_queue_ref=qref,
            total_kernel_args=num_flattened_kernel_args,
            arg_list=args_list,
            arg_ty_list=args_ty_list,
            global_range=grange_extents,
            local_range=lrange_extents,
        )
    else:
        raise UnreachableError

    # Delete the kernel ref
    sycl.dpctl_kernel_delete(builder, kref)
    # Delete the kernel bundle pointer
    sycl.dpctl_kernel_bundle_delete(builder, kbref)

    # FIXME: Needs to change once we support returning a SyclEvent back to
    # caller.
    builder.ret(zero)


@intrinsic
def launch_trampoline(typingctx, kernel_fn, index_space, a, b, c):
    """Intrinsic stub that compiles ``kernel_fn`` for the ``(a, b, c)`` types.

    At typing time the kernel is compiled through ``kernel_fn.dispatcher``
    for the signature ``void(a, b, c)``. The code generator does not emit a
    launch yet; it is a placeholder.

    Args:
        typingctx: Typing context supplied to the intrinsic.
        kernel_fn: Type of the kernel; must expose ``dispatcher``.
        index_space: Type of the index-space argument.
        a, b, c: Types of the three kernel arguments.

    Returns:
        A ``(signature, codegen)`` pair, where the signature is
        ``void(kernel_fn, index_space, a, b, c)``.
    """
    # signature of this intrinsic
    sig = types.void(kernel_fn, index_space, a, b, c)
    # signature of the kernel_fn
    kernel_sig = types.void(a, b, c)
    kernel_bitcode = kernel_fn.dispatcher.compile(kernel_sig)
    kernel_targetctx = kernel_fn.dispatcher.targetctx

    def codegen(cgctx, builder, sig, llargs):
        """Placeholder code generator; emits no instructions yet."""
        # TODO: fetch what is needed for the launch (pointer to the kernel,
        # bitcode or LLVM IR -- should be available from the kernel compiled
        # above) and emit the call into the driver with it and ``llargs``.
        # The interactive ``breakpoint()`` that used to sit here was removed
        # because it stalls any non-interactive compilation.

    return sig, codegen


@dpjit
def call_kernel(a, b, c, kernel_fn, index_space):
    # Forwards all five arguments, in the order received, to
    # ``launch_trampoline``.
    # NOTE(review): ``launch_trampoline`` declares its parameters as
    # (kernel_fn, index_space, a, b, c), so the value bound to ``a`` here
    # arrives as its ``kernel_fn`` and ``b`` as its ``index_space``. The
    # module-level call ``call_kernel(add, r, a, b, c)`` relies on that
    # positional order; the parameter names of this function are misleading.
    launch_trampoline(a, b, c, kernel_fn, index_space)


call_kernel(add, r, a, b, c)
exp_dpex.call_kernel(add, r, a, b, c)

# ndr = NdRange(r, r)
# print(r)
Expand Down
1 change: 1 addition & 0 deletions numba_dpex/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .decorators import kernel
from .kernel_dispatcher import KernelDispatcher
from .launcher import call_kernel
from .models import *
from .types import KernelDispatcherType

Expand Down
121 changes: 121 additions & 0 deletions numba_dpex/experimental/launcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from llvmlite import ir as llvmir
from numba.core import cgutils, cpu, types
from numba.extending import intrinsic

from numba_dpex import config, dpjit
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
from numba_dpex.core.types import DpnpNdArray
from numba_dpex.core.utils import kernel_launcher as kl
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
from numba_dpex.utils import create_null_ptr


def _create_kernel_bundle_from_spirv(
    builder: llvmir.IRBuilder,
    targetctx: cpu.CPUContext,
    queue_ref: llvmir.PointerType,
    kernel_bc: llvmir.Constant,
    kernel_bc_size_in_bytes: int,
):
    """Emit the calls that create a kernel bundle from a SPIR-V constant.

    The device and context refs are fetched from ``queue_ref``, passed to
    ``sycl.dpctl_kernel_bundle_create_from_spirv`` together with the kernel
    byte code and its size, and then released again.

    Args:
        builder: IR builder used to emit every call.
        targetctx: Context handed to ``create_null_ptr``.
        queue_ref: LLVM value for the queue the device/context are read from.
        kernel_bc: LLVM constant holding the kernel byte code.
        kernel_bc_size_in_bytes: Size of ``kernel_bc``; emitted as an i64
            constant.

    Returns:
        The value produced by ``sycl.dpctl_kernel_bundle_create_from_spirv``.
    """
    dref = sycl.dpctl_queue_get_device(builder, queue_ref)
    cref = sycl.dpctl_queue_get_context(builder, queue_ref)
    args = [
        cref,
        dref,
        kernel_bc,
        llvmir.Constant(llvmir.IntType(64), kernel_bc_size_in_bytes),
        # A loaded null pointer is passed as the last argument
        # (presumably "no extra options" -- confirm against the binding).
        builder.load(create_null_ptr(builder, targetctx)),
    ]
    kbref = sycl.dpctl_kernel_bundle_create_from_spirv(builder, *args)

    # Release both temporaries. The device delete must receive the device
    # ref; the previous code passed ``cref`` here, deleting the context
    # twice and never releasing the device ref.
    sycl.dpctl_context_delete(builder, cref)
    sycl.dpctl_device_delete(builder, dref)

    return kbref


def _get_num_flattened_kernel_args(
    kernel_targetctx: DpexKernelTargetContext,
    kernel_argtys: tuple[types.Type, ...],
):
    """Return how many flattened arguments the kernel argument types need.

    A ``DpnpNdArray`` contributes the ``flattened_field_count`` of its data
    model, ``complex64``/``complex128`` contribute two, and every other type
    contributes one.

    Args:
        kernel_targetctx: Context whose ``data_model_manager`` is used to
            look up array data models.
        kernel_argtys: Types of the kernel arguments.

    Returns:
        The summed per-argument counts (``0`` for an empty tuple).
    """

    def _width(argty):
        # Arrays are expanded into the fields of their data model.
        if isinstance(argty, DpnpNdArray):
            model = kernel_targetctx.data_model_manager.lookup(argty)
            return model.flattened_field_count
        # Complex values count as two arguments.
        if argty == types.complex64 or argty == types.complex128:
            return 2
        return 1

    return sum(_width(argty) for argty in kernel_argtys)


def _create_kernel_launcher_body(
    codegen_targetctx: cpu.CPUContext,
    kernel_targetctx: DpexKernelTargetContext,
    builder: llvmir.IRBuilder,
    kernel_argtys: tuple[types.Type, ...],
    kernel_bc: llvmir.Constant,
):
    """Emit the beginning of the kernel launch sequence (work in progress).

    Builds a ``KernelLaunchIRBuilder``, allocates the kernel argument and
    argument-type arrays sized by ``_get_num_flattened_kernel_args`` and then
    tries to populate them.

    Args:
        codegen_targetctx: Context used to create the struct proxy for the
            first kernel argument type.
        kernel_targetctx: Kernel target context; supplies the data model
            manager and is passed to the launch IR builder.
        builder: IR builder the instructions are emitted with.
        kernel_argtys: Types of the kernel arguments.
        kernel_bc: Kernel byte code constant; not used by the body yet.
    """
    klbuilder = kl.KernelLaunchIRBuilder(kernel_targetctx, builder)

    if config.DEBUG_KERNEL_LAUNCHER:
        cgutils.printf(
            builder, "DPEX-DEBUG: Inside the kernel launcher function"
        )

    num_flattened_kernel_args = _get_num_flattened_kernel_args(
        kernel_targetctx=kernel_targetctx, kernel_argtys=kernel_argtys
    )

    # Create LLVM values for the kernel args list and kernel arg types list
    args_list = klbuilder.allocate_kernel_arg_array(num_flattened_kernel_args)
    args_ty_list = klbuilder.allocate_kernel_arg_ty_array(
        num_flattened_kernel_args
    )

    # NOTE(review): the struct proxy created here is discarded, so this line
    # currently has no lasting effect on the function's result.
    cgutils.create_struct_proxy(kernel_argtys[0])(codegen_targetctx, builder)

    # NOTE(review): leftover interactive debugger stop; it pauses every
    # compilation that reaches this function and should go before merging.
    breakpoint()

    # Populate the args_list and the args_ty_list LLVM arrays
    # NOTE(review): `callargs_ptrs` is never assigned in this function, so
    # this call raises NameError unless a module-level name exists -- none is
    # visible in this file.
    klbuilder.populate_kernel_args_and_args_ty_arrays(
        callargs_ptrs=callargs_ptrs,
        kernel_argtys=kernel_argtys,
        args_list=args_list,
        args_ty_list=args_ty_list,
        datamodel_mgr=kernel_targetctx.data_model_manager,
    )


@intrinsic
def launch_trampoline(typingctx, kernel_fn, index_space, a, b, c):
    """Intrinsic that compiles ``kernel_fn`` and emits its launcher body.

    At typing time the kernel is compiled through ``kernel_fn.dispatcher``
    for the signature ``void(a, b, c)``. The returned code generator hands
    the compiled result and the kernel argument types to
    ``_create_kernel_launcher_body``.

    Args:
        typingctx: Typing context supplied to the intrinsic.
        kernel_fn: Type of the kernel; must expose ``dispatcher``.
        index_space: Type of the index-space argument.
        a, b, c: Types of the three kernel arguments.

    Returns:
        A ``(signature, codegen)`` pair, where the signature is
        ``void(kernel_fn, index_space, a, b, c)``.
    """
    # The kernel only receives the data arguments ...
    kernel_sig = types.void(a, b, c)
    # ... whereas the intrinsic also takes the kernel and the index space.
    sig = types.void(kernel_fn, index_space, a, b, c)

    dispatcher = kernel_fn.dispatcher
    kernel_bitcode = dispatcher.compile(kernel_sig)
    kernel_targetctx = dispatcher.targetctx

    def codegen(cgctx, builder, signature, llargs):
        """Emit the launcher body for the kernel compiled above."""
        # Everything needed for the launch was captured at typing time; the
        # lowered call arguments (``llargs``) are not consumed yet.
        _create_kernel_launcher_body(
            codegen_targetctx=cgctx,
            kernel_targetctx=kernel_targetctx,
            builder=builder,
            kernel_argtys=kernel_sig.args,
            kernel_bc=kernel_bitcode,
        )

    return sig, codegen


@dpjit
def call_kernel(a, b, c, kernel_fn, index_space):
    # Forwards all five arguments, in the order received, to
    # ``launch_trampoline``.
    # NOTE(review): ``launch_trampoline`` declares its parameters as
    # (kernel_fn, index_space, a, b, c), so the value bound to ``a`` here
    # arrives as its ``kernel_fn`` and ``b`` as its ``index_space``. Callers
    # are therefore expected to pass (kernel, index space, args...)
    # positionally; the parameter names of this function are misleading.
    launch_trampoline(a, b, c, kernel_fn, index_space)

0 comments on commit 3967c31

Please sign in to comment.