Skip to content

Commit

Permalink
Added unit test code
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Nov 1, 2023
1 parent 842ef18 commit 0b1f6bd
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 9 deletions.
19 changes: 10 additions & 9 deletions numba_dpex/dpnp_iface/array_sequence_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from numba.core.types.misc import NoneType
from numba.core.types.scalars import Boolean, Complex, Float, Integer
from numba.extending import intrinsic, overload
from numba.np.numpy_support import is_nonelike

import numba_dpex.utils as utils
from numba_dpex.core.types import DpnpNdArray
Expand Down Expand Up @@ -168,7 +169,7 @@ def _normalize(builder, src, src_type, dest_type):
else:
return builder.uitofp(src, dest_llvm_type)
elif isinstance(src_type, Float) and isinstance(dest_type, Integer):
if src_type.signed:
if dest_type.signed:
return builder.fptosi(src, dest_llvm_type)
else:
return builder.fptoui(src, dest_llvm_type)
Expand Down Expand Up @@ -201,14 +202,14 @@ def _compute_array_length_ir(
u64 = llvmir.IntType(64)
one = context.get_constant(types.float64, 1)

ll = _normalize(builder, start_ir, start_arg_type, types.float64)
ul = _normalize(builder, stop_ir, stop_arg_type, types.float64)
lb = _normalize(builder, start_ir, start_arg_type, types.float64)
ub = _normalize(builder, stop_ir, stop_arg_type, types.float64)
d = _normalize(builder, step_ir, step_arg_type, types.float64)

# Doing ceil(a,b) = (a-1)/b + 1 to avoid overflow
array_length = builder.fptosi(
builder.fadd(
builder.fdiv(builder.fsub(builder.fsub(ul, ll), one), d),
builder.fdiv(builder.fsub(builder.fsub(ub, lb), one), d),
one,
),
u64,
Expand Down Expand Up @@ -379,27 +380,27 @@ def ol_dpnp_arange(
sycl_queue=None,
):
if isinstance(start, Complex) or (
dtype is not None and isinstance(dtype.dtype, Complex)
(not isinstance(dtype, NoneType)) and isinstance(dtype.dtype, Complex)
):
raise errors.NumbaNotImplementedError(
"Complex type is not supported yet."
)
if isinstance(start, Boolean) or (
dtype is not None and isinstance(dtype.dtype, Boolean)
(not isinstance(dtype, NoneType)) and isinstance(dtype.dtype, Boolean)
):
raise errors.NumbaTypeError(
"Boolean is not supported by dpnp.arange()."
)

if stop is None:
if is_nonelike(stop):
start = 0
stop = 1
if step is None:
if is_nonelike(step):
step = 1

_dtype = (
_parse_dtype(dtype)
if dtype is not None
if not is_nonelike(dtype)
else _parse_dtype_from_range(start, stop, step)
)

Expand Down
64 changes: 64 additions & 0 deletions numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_arange.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""Tests for dpnp.arange() constructor."""

import dpctl
import dpnp
import numpy as np
import pytest
from numba import errors

from numba_dpex import dpjit
from numba_dpex.tests._helper import get_all_dtypes

dtypes = get_all_dtypes(
no_bool=True, no_float16=True, no_none=False, no_complex=True
)
usm_types = ["device", "shared", "host"]
ranges = [
[1, None, None],
[1, 10, None],
[1, 10, 1],
pytest.param(
[-1.0, None, None],
marks=pytest.mark.xfail(
reason="numba-dpex can't allocate an empty array"
),
),
pytest.param(
[-1.0, 10, -2], marks=pytest.mark.xfail(reason="infinite loop")
),
]


@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("range", ranges)
@pytest.mark.parametrize("usm_type", usm_types)
def test_dpnp_arange_start(range, dtype, usm_type):
@dpjit
def func(lb, ub, dx, dt, ut, dv):
x = dpnp.arange(lb, ub, dx, dtype=dt, usm_type=ut, device=dv)
return x

device = dpctl.SyclDevice().filter_string
try:
c = func(*range, dtype, usm_type, device)
except Exception:
pytest.fail("Calling dpnp.arange() inside dpjit failed.")

a = dpnp.arange(*range, dtype=dtype, usm_type=usm_type, device=device)

print(a, c)

assert a.dtype == c.dtype
assert np.array_equal(a.asnumpy(), c.asnumpy())
assert c.usm_type == usm_type
assert c.sycl_device.filter_string == device
if c.sycl_queue != dpctl._sycl_queue_manager.get_device_cached_queue(
device
):
pytest.xfail(
"Returned queue does not have the same queue as cached against the device."
)

0 comments on commit 0b1f6bd

Please sign in to comment.