Skip to content

Commit

Permalink
Move atomic_ref overloads to experimental target.
Browse files Browse the repository at this point in the history
   - All overloads are now added to the experimental target.
   - The inline keyword is no longer set to "always" for the
     atomic_ref overloads.
  • Loading branch information
Diptorup Deb committed Nov 15, 2023
1 parent 5d5626e commit a316b23
Showing 1 changed file with 18 additions and 35 deletions.
53 changes: 18 additions & 35 deletions numba_dpex/experimental/dpcpp_iface/atomic_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
from numba.extending import intrinsic, overload, overload_method

from numba_dpex.core import itanium_mangler as ext_itanium_mangler
from numba_dpex.core.targets.kernel_target import (
CC_SPIR_FUNC,
DPEX_KERNEL_TARGET_NAME,
)
from numba_dpex.core.targets.kernel_target import CC_SPIR_FUNC
from numba_dpex.core.types import USMNdArray

from ..dpcpp_types import AtomicRefType
from ..target import DPEX_KERNEL_EXP_TARGET_NAME
from ._spv_atomic_helper import (
get_atomic_inst_name,
get_memory_semantics_mask,
Expand Down Expand Up @@ -148,42 +146,42 @@ def gen(context, builder, sig, args):
return sig, gen


@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_add(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_add")


@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_sub(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_sub")


@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_min(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_min")


@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_max(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_max")


@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_and(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_and")


@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_or(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_or")


@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_xor(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_xor")


@intrinsic(target=DPEX_KERNEL_TARGET_NAME)
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_atomic_ref_ctor(ty_context, ref, ty_retty_ref):
from ..target import dpex_exp_kernel_target

Expand Down Expand Up @@ -250,8 +248,7 @@ def _check_if_supported_ref(ref):
@overload(
AtomicRef,
prefer_literal=True,
inline="always",
target=DPEX_KERNEL_TARGET_NAME,
target=DPEX_KERNEL_EXP_TARGET_NAME,
)
def ol_atomic_ref(ref, memory_order, memory_scope, address_space):
_check_if_supported_ref(ref)
Expand All @@ -274,9 +271,7 @@ def ol_atomic_ref_ctor_impl(ref, memory_order, memory_scope, address_space):
return ol_atomic_ref_ctor_impl


@overload_method(
AtomicRefType, "fetch_add", inline="always", target=DPEX_KERNEL_TARGET_NAME
)
@overload_method(AtomicRefType, "fetch_add", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_add(atomic_ref, val):
if atomic_ref.dtype != val:
raise errors.TypingError(
Expand All @@ -290,9 +285,7 @@ def ol_fetch_add_impl(atomic_ref, val):
return ol_fetch_add_impl


@overload_method(
AtomicRefType, "fetch_sub", inline="always", target=DPEX_KERNEL_TARGET_NAME
)
@overload_method(AtomicRefType, "fetch_sub", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_sub(atomic_ref, val):
if atomic_ref.dtype != val:
raise errors.TypingError(
Expand All @@ -306,9 +299,7 @@ def ol_fetch_sub_impl(atomic_ref, val):
return ol_fetch_sub_impl


@overload_method(
AtomicRefType, "fetch_min", inline="always", target=DPEX_KERNEL_TARGET_NAME
)
@overload_method(AtomicRefType, "fetch_min", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_min(atomic_ref, val):
if atomic_ref.dtype != val:
raise errors.TypingError(
Expand All @@ -322,9 +313,7 @@ def ol_fetch_min_impl(atomic_ref, val):
return ol_fetch_min_impl


@overload_method(
AtomicRefType, "fetch_max", inline="always", target=DPEX_KERNEL_TARGET_NAME
)
@overload_method(AtomicRefType, "fetch_max", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_max(atomic_ref, val):
if atomic_ref.dtype != val:
raise errors.TypingError(
Expand All @@ -338,9 +327,7 @@ def ol_fetch_max_impl(atomic_ref, val):
return ol_fetch_max_impl


@overload_method(
AtomicRefType, "fetch_and", inline="always", target=DPEX_KERNEL_TARGET_NAME
)
@overload_method(AtomicRefType, "fetch_and", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_and(atomic_ref, val):
if atomic_ref.dtype != val:
raise errors.TypingError(
Expand All @@ -359,9 +346,7 @@ def ol_fetch_and_impl(atomic_ref, val):
return ol_fetch_and_impl


@overload_method(
AtomicRefType, "fetch_or", inline="always", target=DPEX_KERNEL_TARGET_NAME
)
@overload_method(AtomicRefType, "fetch_or", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_or(atomic_ref, val):
if atomic_ref.dtype != val:
raise errors.TypingError(
Expand All @@ -380,9 +365,7 @@ def ol_fetch_or_impl(atomic_ref, val):
return ol_fetch_or_impl


@overload_method(
AtomicRefType, "fetch_xor", inline="always", target=DPEX_KERNEL_TARGET_NAME
)
@overload_method(AtomicRefType, "fetch_xor", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_xor(atomic_ref, val):
if atomic_ref.dtype != val:
raise errors.TypingError(
Expand Down

0 comments on commit a316b23

Please sign in to comment.