Skip to content

Commit

Permalink
Integrate with experimental kernel launcher
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Nov 2, 2023
1 parent cecec87 commit 0d421dc
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 17 deletions.
52 changes: 36 additions & 16 deletions numba_dpex/core/kernel_interface/dpcpp_iface/atomic_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from numba_dpex.core import itanium_mangler as ext_itanium_mangler
from numba_dpex.core.datamodel.models import dpex_data_model_manager
from numba_dpex.core.exceptions import UnreachableError
from numba_dpex.core.targets.kernel_target import DPEX_KERNEL_TARGET_NAME
from numba_dpex.core.types import AtomicRefType, USMNdArray

from ._spv_atomic_helper import (
Expand Down Expand Up @@ -124,42 +125,42 @@ def gen(context, builder, sig, args):
return sig, gen


@intrinsic
@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
def _intrinsic_fetch_add(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_add")


@intrinsic
@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
def _intrinsic_fetch_sub(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_sub")


@intrinsic
@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
def _intrinsic_fetch_min(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_min")


@intrinsic
@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
def _intrinsic_fetch_max(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_max")


@intrinsic
@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
def _intrinsic_fetch_and(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_and")


@intrinsic
@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
def _intrinsic_fetch_or(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_or")


@intrinsic
@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
def _intrinsic_fetch_xor(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_xor")


@intrinsic
@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
def _intrinsic_atomic_ref_ctor(ty_context, ref, ty_retty_ref):
ty_retty = ty_retty_ref.instance_type
sig = ty_retty(ref, ty_retty_ref)
Expand Down Expand Up @@ -223,7 +224,12 @@ def _check_if_supported_ref(ref):
return supported


@overload(AtomicRef, prefer_literal=True, inline="always")
@overload(
AtomicRef,
prefer_literal=True,
inline="always",
target=DPEX_KERNEL_TARGET_NAME,
)
def ol_atomic_ref(ref, memory_order, memory_scope, address_space):
_check_if_supported_ref(ref)

Expand All @@ -241,7 +247,9 @@ def impl(ref, memory_order, memory_scope, address_space):
return impl


@overload_method(AtomicRefType, "fetch_add", inline="always")
@overload_method(
AtomicRefType, "fetch_add", inline="always", target=DPEX_KERNEL_TARGET_NAME
)
def ol_fetch_add(atomic_ref, val):
if atomic_ref.dtype != val:
raise errors.TypingError(
Expand All @@ -255,7 +263,9 @@ def impl(atomic_ref, val):
return impl


@overload_method(AtomicRefType, "fetch_sub", inline="always")
@overload_method(
AtomicRefType, "fetch_sub", inline="always", target=DPEX_KERNEL_TARGET_NAME
)
def ol_fetch_sub(atomic_ref, val):
if atomic_ref.dtype != val:
raise errors.TypingError(
Expand All @@ -269,7 +279,9 @@ def impl(atomic_ref, val):
return impl


@overload_method(AtomicRefType, "fetch_min", inline="always")
@overload_method(
AtomicRefType, "fetch_min", inline="always", target=DPEX_KERNEL_TARGET_NAME
)
def ol_fetch_min(atomic_ref, val):
if atomic_ref.dtype != val:
raise errors.TypingError(
Expand All @@ -283,7 +295,9 @@ def impl(atomic_ref, val):
return impl


@overload_method(AtomicRefType, "fetch_max", inline="always")
@overload_method(
AtomicRefType, "fetch_max", inline="always", target=DPEX_KERNEL_TARGET_NAME
)
def ol_fetch_max(atomic_ref, val):
if atomic_ref.dtype != val:
raise errors.TypingError(
Expand All @@ -297,7 +311,9 @@ def impl(atomic_ref, val):
return impl


@overload_method(AtomicRefType, "fetch_and", inline="always")
@overload_method(
AtomicRefType, "fetch_and", inline="always", target=DPEX_KERNEL_TARGET_NAME
)
def ol_fetch_and(atomic_ref, val):
if atomic_ref.dtype != val:
raise errors.TypingError(
Expand All @@ -316,7 +332,9 @@ def impl(atomic_ref, val):
return impl


@overload_method(AtomicRefType, "fetch_or", inline="always")
@overload_method(
AtomicRefType, "fetch_or", inline="always", target=DPEX_KERNEL_TARGET_NAME
)
def ol_fetch_or(atomic_ref, val):
if atomic_ref.dtype != val:
raise errors.TypingError(
Expand All @@ -335,7 +353,9 @@ def impl(atomic_ref, val):
return impl


@overload_method(AtomicRefType, "fetch_xor", inline="always")
@overload_method(
AtomicRefType, "fetch_xor", inline="always", target=DPEX_KERNEL_TARGET_NAME
)
def ol_fetch_xor(atomic_ref, val):
if atomic_ref.dtype != val:
raise errors.TypingError(
Expand Down
2 changes: 1 addition & 1 deletion numba_dpex/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,4 @@ def dpjit(*args, **kws):
# add it to the decorator registry, this is so e.g. @overload can look up a
# JIT function to do the compilation work.
jit_registry[target_registry["dpex"]] = dpjit
jit_registry[target_registry["dpex_kernel"]] = kernel
# jit_registry[target_registry["dpex_kernel"]] = kernel
1 change: 1 addition & 0 deletions numba_dpex/experimental/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import inspect

from numba.core import sigutils
from numba.core.target_extension import jit_registry, target_registry

from .kernel_dispatcher import KernelDispatcher

Expand Down

0 comments on commit 0d421dc

Please sign in to comment.