-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds an experimental KernelDispacther to numba_dpex.
- The numba_dpex.experimental module adds a new dispatcher class for numba_dpex kernels. The new dispatcher is a numba.dispatcher.Dispathcer subclass. - Introduce a new compiler class that is used to compile a numba_dpex.kernel decorated function to spirv and then store the spirv module as the compiled "overload". - Adds an experimental `call_kernel` dpjit function that will be used to submit or launch kernels. The `call_kernel` function generates LLVM IR code for all the functionality currenty done in pure Python in JitKernel.__call__.
- Loading branch information
Diptorup Deb
committed
Oct 24, 2023
1 parent
6466a57
commit 9ac4dfb
Showing
7 changed files
with
724 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# SPDX-FileCopyrightText: 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from numba.core.imputils import Registry | ||
|
||
from .decorators import kernel | ||
from .kernel_dispatcher import KernelDispatcher | ||
from .launcher import call_kernel | ||
from .models import * | ||
from .types import KernelDispatcherType | ||
|
||
registry = Registry() | ||
lower_constant = registry.lower_constant | ||
|
||
|
||
@lower_constant(KernelDispatcherType) | ||
def dpex_dispatcher_const(context, builder, ty, pyval): | ||
return context.get_dummy_value() | ||
|
||
|
||
__all__ = ["kernel", "KernelDispatcher", "dpex_dispatcher_const"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# SPDX-FileCopyrightText: 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import inspect | ||
|
||
from numba.core import sigutils | ||
|
||
from .kernel_dispatcher import KernelDispatcher | ||
|
||
|
||
def kernel(func_or_sig=None, debug=False, cache=False, **options): | ||
"""A decorator to define a kernel function. | ||
A kernel function is conceptually equivalent to a SYCL kernel function, and | ||
gets compiled into either an OpenCL or a LevelZero SPIR-V binary kernel. | ||
A kernel decorated Python function has the following restrictions: | ||
* The function can not return any value. | ||
* All array arguments passed to a kernel should adhere to compute | ||
follows data programming model. | ||
""" | ||
# FIXME: The options need to be evaluated and checked here like it is | ||
# done in numba.core.decorators.jit | ||
|
||
def _kernel_dispatcher(pyfunc, sigs=None): | ||
return KernelDispatcher( | ||
pyfunc=pyfunc, | ||
debug_flags=debug, | ||
enable_cache=cache, | ||
specialization_sigs=sigs, | ||
targetoptions=options, | ||
) | ||
|
||
if func_or_sig is None: | ||
return _kernel_dispatcher | ||
elif isinstance(func_or_sig, str): | ||
raise NotImplementedError( | ||
"Specifying signatures as string is not yet supported by numba-dpex" | ||
) | ||
elif isinstance(func_or_sig, list) or sigutils.is_signature(func_or_sig): | ||
# String signatures are not supported as passing usm_ndarray type as | ||
# a string is not possible. Numba's sigutils relies on the type being | ||
# available in Numba's `types.__dict__` and dpex types are not | ||
# registered there yet. | ||
if isinstance(func_or_sig, list): | ||
for sig in func_or_sig: | ||
if isinstance(sig, str): | ||
raise NotImplementedError( | ||
"Specifying signatures as string is not yet supported " | ||
"by numba-dpex" | ||
) | ||
# Specialized signatures can either be a single signature or a list. | ||
# In case only one signature is provided convert it to a list | ||
if not isinstance(func_or_sig, list): | ||
func_or_sig = [func_or_sig] | ||
|
||
def _specialized_kernel_dispatcher(pyfunc): | ||
return KernelDispatcher( | ||
pyfunc=pyfunc, | ||
debug_flags=debug, | ||
enable_cache=cache, | ||
specialization_sigs=func_or_sig, | ||
) | ||
|
||
return _specialized_kernel_dispatcher | ||
else: | ||
func = func_or_sig | ||
if not inspect.isfunction(func): | ||
raise ValueError( | ||
"Argument passed to the kernel decorator is neither a " | ||
"function object, nor a signature. If you are trying to " | ||
"specialize the kernel that takes a single argument, specify " | ||
"the return type as void explicitly." | ||
) | ||
return _kernel_dispatcher(func) |
Oops, something went wrong.