-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Diptorup Deb
committed
Oct 20, 2023
1 parent
2faa610
commit 3967c31
Showing
3 changed files
with
123 additions
and
233 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# SPDX-FileCopyrightText: 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from llvmlite import ir as llvmir | ||
from numba.core import cgutils, cpu, types | ||
from numba.extending import intrinsic | ||
|
||
from numba_dpex import config, dpjit | ||
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext | ||
from numba_dpex.core.types import DpnpNdArray | ||
from numba_dpex.core.utils import kernel_launcher as kl | ||
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl | ||
from numba_dpex.utils import create_null_ptr | ||
|
||
|
||
def _create_kernel_bundle_from_spirv( | ||
builder: llvmir.IRBuilder, | ||
targetctx: cpu.CPUContext, | ||
queue_ref: llvmir.PointerType, | ||
kernel_bc: llvmir.Constant, | ||
kernel_bc_size_in_bytes: int, | ||
): | ||
dref = sycl.dpctl_queue_get_device(builder, queue_ref) | ||
cref = sycl.dpctl_queue_get_context(builder, queue_ref) | ||
args = [ | ||
cref, | ||
dref, | ||
kernel_bc, | ||
llvmir.Constant(llvmir.IntType(64), kernel_bc_size_in_bytes), | ||
builder.load(create_null_ptr(builder, targetctx)), | ||
] | ||
kbref = sycl.dpctl_kernel_bundle_create_from_spirv(builder, *args) | ||
sycl.dpctl_context_delete(builder, cref) | ||
sycl.dpctl_device_delete(builder, cref) | ||
|
||
return kbref | ||
|
||
|
||
def _get_num_flattened_kernel_args( | ||
kernel_targetctx: DpexKernelTargetContext, | ||
kernel_argtys: tuple[types.Type, ...], | ||
): | ||
num_flattened_kernel_args = 0 | ||
for arg_type in kernel_argtys: | ||
if isinstance(arg_type, DpnpNdArray): | ||
datamodel = kernel_targetctx.data_model_manager.lookup(arg_type) | ||
num_flattened_kernel_args += datamodel.flattened_field_count | ||
elif arg_type == types.complex64 or arg_type == types.complex128: | ||
num_flattened_kernel_args += 2 | ||
else: | ||
num_flattened_kernel_args += 1 | ||
|
||
return num_flattened_kernel_args | ||
|
||
|
||
def _create_kernel_launcher_body( | ||
codegen_targetctx: cpu.CPUContext, | ||
kernel_targetctx: DpexKernelTargetContext, | ||
builder: llvmir.IRBuilder, | ||
kernel_argtys: tuple[types.Type, ...], | ||
kernel_bc: llvmir.Constant, | ||
): | ||
klbuilder = kl.KernelLaunchIRBuilder(kernel_targetctx, builder) | ||
|
||
if config.DEBUG_KERNEL_LAUNCHER: | ||
cgutils.printf( | ||
builder, "DPEX-DEBUG: Inside the kernel launcher function" | ||
) | ||
|
||
num_flattened_kernel_args = _get_num_flattened_kernel_args( | ||
kernel_targetctx=kernel_targetctx, kernel_argtys=kernel_argtys | ||
) | ||
|
||
# Create LLVM values for the kernel args list and kernel arg types list | ||
args_list = klbuilder.allocate_kernel_arg_array(num_flattened_kernel_args) | ||
args_ty_list = klbuilder.allocate_kernel_arg_ty_array( | ||
num_flattened_kernel_args | ||
) | ||
|
||
cgutils.create_struct_proxy(kernel_argtys[0])(codegen_targetctx, builder) | ||
|
||
breakpoint() | ||
|
||
# Populate the args_list and the args_ty_list LLVM arrays | ||
klbuilder.populate_kernel_args_and_args_ty_arrays( | ||
callargs_ptrs=callargs_ptrs, | ||
kernel_argtys=kernel_argtys, | ||
args_list=args_list, | ||
args_ty_list=args_ty_list, | ||
datamodel_mgr=kernel_targetctx.data_model_manager, | ||
) | ||
|
||
|
||
@intrinsic | ||
def launch_trampoline(typingctx, kernel_fn, index_space, a, b, c): | ||
# signature of this intrinsic | ||
sig = types.void(kernel_fn, index_space, a, b, c) | ||
# signature of the kernel_fn | ||
kernel_sig = types.void(a, b, c) | ||
kernel_bitcode = kernel_fn.dispatcher.compile(kernel_sig) | ||
kernel_targetctx = kernel_fn.dispatcher.targetctx | ||
|
||
def codegen(cgctx, builder, sig, llargs): | ||
# # do something here to fetch the pointer to the kernel, or bitcode, or | ||
# # llvm IR that you need for doing the actual kernel launch (should be | ||
# # available from the kernel_inst compiled above). | ||
# kernel_work_details = <get that thing above> | ||
# # now generate the call to the driver to launch the kernel | ||
# builder.call(<driver function>, kernel_work_details, llargs) | ||
kernel_args = kernel_sig.args | ||
_create_kernel_launcher_body( | ||
cgctx, kernel_targetctx, builder, kernel_args, kernel_bitcode | ||
) | ||
|
||
return sig, codegen | ||
|
||
|
||
@dpjit | ||
def call_kernel(a, b, c, kernel_fn, index_space): | ||
launch_trampoline(a, b, c, kernel_fn, index_space) |