-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Diptorup Deb
committed
Oct 18, 2023
1 parent
466dbf5
commit 8c60c75
Showing
3 changed files
with
802 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# SPDX-FileCopyrightText: 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .decorators import kernel | ||
from .kernel_dispatcher import KernelDispatcher | ||
|
||
__all__ = ["kernel", "KernelDispatcher"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# SPDX-FileCopyrightText: 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import inspect | ||
|
||
from numba.core import sigutils | ||
|
||
from .kernel_dispatcher import KernelDispatcher | ||
|
||
|
||
def kernel(func_or_sig=None, debug=False, cache=False, **options): | ||
"""A decorator to define a kernel function. | ||
A kernel function is conceptually equivalent to a SYCL kernel function, and | ||
gets compiled into either an OpenCL or a LevelZero SPIR-V binary kernel. | ||
A kernel decorated Python function has the following restrictions: | ||
* The function can not return any value. | ||
* All array arguments passed to a kernel should adhere to compute | ||
follows data programming model. | ||
""" | ||
# FIXME: The options need to be evaluated and checked here like it is | ||
# done in numba.core.decorators.jit | ||
|
||
def _kernel_dispatcher(pyfunc, sigs=None): | ||
return KernelDispatcher( | ||
pyfunc=pyfunc, | ||
debug_flags=debug, | ||
enable_cache=cache, | ||
specialization_sigs=sigs, | ||
targetoptions=options, | ||
) | ||
|
||
if func_or_sig is None: | ||
return _kernel_dispatcher | ||
elif isinstance(func_or_sig, str): | ||
raise NotImplementedError( | ||
"Specifying signatures as string is not yet supported by numba-dpex" | ||
) | ||
elif isinstance(func_or_sig, list) or sigutils.is_signature(func_or_sig): | ||
# String signatures are not supported as passing usm_ndarray type as | ||
# a string is not possible. Numba's sigutils relies on the type being | ||
# available in Numba's `types.__dict__` and dpex types are not | ||
# registered there yet. | ||
if isinstance(func_or_sig, list): | ||
for sig in func_or_sig: | ||
if isinstance(sig, str): | ||
raise NotImplementedError( | ||
"Specifying signatures as string is not yet supported " | ||
"by numba-dpex" | ||
) | ||
# Specialized signatures can either be a single signature or a list. | ||
# In case only one signature is provided convert it to a list | ||
if not isinstance(func_or_sig, list): | ||
func_or_sig = [func_or_sig] | ||
|
||
def _specialized_kernel_dispatcher(pyfunc): | ||
return KernelDispatcher( | ||
pyfunc=pyfunc, | ||
debug_flags=debug, | ||
enable_cache=cache, | ||
specialization_sigs=func_or_sig, | ||
) | ||
|
||
return _specialized_kernel_dispatcher | ||
else: | ||
func = func_or_sig | ||
if not inspect.isfunction(func): | ||
raise ValueError( | ||
"Argument passed to the kernel decorator is neither a " | ||
"function object, nor a signature. If you are trying to " | ||
"specialize the kernel that takes a single argument, specify " | ||
"the return type as void explicitly." | ||
) | ||
return _kernel_dispatcher(func) |
Oops, something went wrong.