diff --git a/numba_dpex/dpnp_iface/array_sequence_ops.py b/numba_dpex/dpnp_iface/array_sequence_ops.py index 0a28f45cb4..815cb06f4b 100644 --- a/numba_dpex/dpnp_iface/array_sequence_ops.py +++ b/numba_dpex/dpnp_iface/array_sequence_ops.py @@ -13,6 +13,7 @@ from numba.core.types.misc import NoneType from numba.core.types.scalars import Boolean, Complex, Float, Integer from numba.extending import intrinsic, overload +from numba.np.numpy_support import is_nonelike import numba_dpex.utils as utils from numba_dpex.core.types import DpnpNdArray @@ -168,7 +169,7 @@ def _normalize(builder, src, src_type, dest_type): else: return builder.uitofp(src, dest_llvm_type) elif isinstance(src_type, Float) and isinstance(dest_type, Integer): - if src_type.signed: + if dest_type.signed: return builder.fptosi(src, dest_llvm_type) else: return builder.fptoui(src, dest_llvm_type) @@ -201,14 +202,14 @@ def _compute_array_length_ir( u64 = llvmir.IntType(64) one = context.get_constant(types.float64, 1) - ll = _normalize(builder, start_ir, start_arg_type, types.float64) - ul = _normalize(builder, stop_ir, stop_arg_type, types.float64) + lb = _normalize(builder, start_ir, start_arg_type, types.float64) + ub = _normalize(builder, stop_ir, stop_arg_type, types.float64) d = _normalize(builder, step_ir, step_arg_type, types.float64) # Doing ceil(a,b) = (a-1)/b + 1 to avoid overflow array_length = builder.fptosi( builder.fadd( - builder.fdiv(builder.fsub(builder.fsub(ul, ll), one), d), + builder.fdiv(builder.fsub(builder.fsub(ub, lb), one), d), one, ), u64, @@ -379,27 +380,27 @@ def ol_dpnp_arange( sycl_queue=None, ): if isinstance(start, Complex) or ( - dtype is not None and isinstance(dtype.dtype, Complex) + (not isinstance(dtype, NoneType)) and isinstance(dtype.dtype, Complex) ): raise errors.NumbaNotImplementedError( "Complex type is not supported yet." ) if isinstance(start, Boolean) or ( - dtype is not None and isinstance(dtype.dtype, Boolean) + (not isinstance(dtype, NoneType)) and isinstance(dtype.dtype, Boolean) ): raise errors.NumbaTypeError( "Boolean is not supported by dpnp.arange()." ) - if stop is None: + if is_nonelike(stop): start = 0 stop = 1 - if step is None: + if is_nonelike(step): step = 1 _dtype = ( _parse_dtype(dtype) - if dtype is not None + if not is_nonelike(dtype) else _parse_dtype_from_range(start, stop, step) ) diff --git a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_arange.py b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_arange.py new file mode 100644 index 0000000000..3dbb40ee2c --- /dev/null +++ b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_arange.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for dpnp.arange() constructor.""" + +import dpctl +import dpnp +import numpy as np +import pytest +from numba import errors + +from numba_dpex import dpjit +from numba_dpex.tests._helper import get_all_dtypes + +dtypes = get_all_dtypes( + no_bool=True, no_float16=True, no_none=False, no_complex=True +) +usm_types = ["device", "shared", "host"] +ranges = [ + [1, None, None], + [1, 10, None], + [1, 10, 1], + pytest.param( + [-1.0, None, None], + marks=pytest.mark.xfail( + reason="numba-dpex can't allocate an empty array" + ), + ), + pytest.param( + [-1.0, 10, -2], marks=pytest.mark.xfail(reason="infinite loop") + ), +] + + +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("range", ranges) +@pytest.mark.parametrize("usm_type", usm_types) +def test_dpnp_arange_start(range, dtype, usm_type): + @dpjit + def func(lb, ub, dx, dt, ut, dv): + x = dpnp.arange(lb, ub, dx, dtype=dt, usm_type=ut, device=dv) + return x + + device = dpctl.SyclDevice().filter_string + try: + c = func(*range, dtype, usm_type, device) + except Exception: + pytest.fail("Calling dpnp.arange() inside dpjit failed.") + + a = dpnp.arange(*range, dtype=dtype, usm_type=usm_type, device=device) + + print(a, c) + + assert a.dtype == c.dtype + assert np.array_equal(a.asnumpy(), c.asnumpy()) + assert c.usm_type == usm_type + assert c.sycl_device.filter_string == device + if c.sycl_queue != dpctl._sycl_queue_manager.get_device_cached_queue( + device + ): + pytest.xfail( + "Returned queue does not have the same queue as cached against the device." + )