diff --git a/numba_dpex/core/runtime/kernels/tensor/include/sequences.hpp b/numba_dpex/core/runtime/kernels/tensor/include/sequences.hpp index dc87eefdc0..6c603f498c 100644 --- a/numba_dpex/core/runtime/kernels/tensor/include/sequences.hpp +++ b/numba_dpex/core/runtime/kernels/tensor/include/sequences.hpp @@ -19,6 +19,8 @@ #include "typeutils.hpp" +namespace dpexrt_tensor = dpex::rt::kernel::tensor; + namespace dpex { namespace rt @@ -47,8 +49,7 @@ template class SequenceStepFunctor void operator()(sycl::id<1> wiid) const { auto i = wiid.get(0); - if constexpr (dpex::rt::kernel::tensor::typeutils::is_complex::value) - { + if constexpr (dpexrt_tensor::typeutils::is_complex::value) { p[i] = T{start_v.real() + i * step_v.real(), start_v.imag() + i * step_v.imag()}; } @@ -78,8 +79,7 @@ template class AffineSequenceFunctor auto i = wiid.get(0); wT wc = wT(i) / n; wT w = wT(n - i) / n; - if constexpr (dpex::rt::kernel::tensor::typeutils::is_complex::value) - { + if constexpr (dpexrt_tensor::typeutils::is_complex::value) { using reT = typename T::value_type; auto _w = static_cast(w); auto _wc = static_cast(wc); @@ -104,7 +104,7 @@ template class AffineSequenceFunctor } else { auto affine_comb = start_v * w + end_v * wc; - p[i] = dpex::rt::kernel::tensor::typeutils::convert_impl< + p[i] = dpexrt_tensor::typeutils::convert_impl< T, decltype(affine_comb)>(affine_comb); } } @@ -119,14 +119,14 @@ sycl::event sequence_step_kernel(sycl::queue exec_q, const std::vector &depends) { std::cout << "sequqnce_step_kernel<" - << dpex::rt::kernel::tensor::typeutils::demangle() + << dpexrt_tensor::typeutils::demangle() << ">(): nelems = " << nelems << ", start_v = " << start_v << ", step_v = " << step_v << std::endl; - dpex::rt::kernel::tensor::typeutils::validate_type_for_device(exec_q); + dpexrt_tensor::typeutils::validate_type_for_device(exec_q); std::cout << "sequqnce_step_kernel<" - << dpex::rt::kernel::tensor::typeutils::demangle() + << dpexrt_tensor::typeutils::demangle() << ">(): 
validate_type_for_device(exec_q) = done" << std::endl; sycl::event seq_step_event = exec_q.submit([&](sycl::handler &cgh) { @@ -148,7 +148,7 @@ sycl::event affine_sequence_kernel(sycl::queue &exec_q, char *array_data, const std::vector &depends) { - dpex::rt::kernel::tensor::typeutils::validate_type_for_device(exec_q); + dpexrt_tensor::typeutils::validate_type_for_device(exec_q); bool device_supports_doubles = exec_q.get_device().has(sycl::aspect::fp64); sycl::event affine_seq_step_event = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); @@ -187,8 +187,7 @@ sycl::event sequence_step(sycl::queue &exec_q, std::cerr << e.what() << std::endl; } - std::cout << "sequqnce_step()<" - << dpex::rt::kernel::tensor::typeutils::demangle() + std::cout << "sequqnce_step()<" << dpexrt_tensor::typeutils::demangle() << ">: nelems = " << nelems << ", *start_v = " << (*start_v) << ", *step_v = " << (*step_v) << std::endl; diff --git a/numba_dpex/core/runtime/kernels/tensor/include/typeutils.hpp b/numba_dpex/core/runtime/kernels/tensor/include/typeutils.hpp index b25462a334..8ce3864f02 100644 --- a/numba_dpex/core/runtime/kernels/tensor/include/typeutils.hpp +++ b/numba_dpex/core/runtime/kernels/tensor/include/typeutils.hpp @@ -172,118 +172,6 @@ template void validate_type_for_device(const sycl::queue &q) validate_type_for_device(q.get_device()); } -// template -// auto vec_cast_impl(const Vec &v, std::index_sequence) -// { -// return Op{v[I]...}; -// } - -// template > -// auto vec_cast(const sycl::vec &s) -// { -// if constexpr (std::is_same_v) { -// return s; -// } -// else { -// return vec_cast_impl, sycl::vec>(s, -// Indices{}); -// } -// } - -// struct usm_ndarray_types -// { -// int typenum_to_lookup_id(int typenum) const -// { -// // using typenum_t = ::dpctl::tensor::type_dispatch::typenum_t; -// auto const &api = ::dpctl::detail::dpctl_capi::get(); - -// if (typenum == api.UAR_DOUBLE_) { -// return static_cast(typenum_t::DOUBLE); -// } -// else if (typenum == 
api.UAR_INT64_) { -// return static_cast(typenum_t::INT64); -// } -// else if (typenum == api.UAR_INT32_) { -// return static_cast(typenum_t::INT32); -// } -// else if (typenum == api.UAR_BOOL_) { -// return static_cast(typenum_t::BOOL); -// } -// else if (typenum == api.UAR_CDOUBLE_) { -// return static_cast(typenum_t::CDOUBLE); -// } -// else if (typenum == api.UAR_FLOAT_) { -// return static_cast(typenum_t::FLOAT); -// } -// else if (typenum == api.UAR_INT16_) { -// return static_cast(typenum_t::INT16); -// } -// else if (typenum == api.UAR_INT8_) { -// return static_cast(typenum_t::INT8); -// } -// else if (typenum == api.UAR_UINT64_) { -// return static_cast(typenum_t::UINT64); -// } -// else if (typenum == api.UAR_UINT32_) { -// return static_cast(typenum_t::UINT32); -// } -// else if (typenum == api.UAR_UINT16_) { -// return static_cast(typenum_t::UINT16); -// } -// else if (typenum == api.UAR_UINT8_) { -// return static_cast(typenum_t::UINT8); -// } -// else if (typenum == api.UAR_CFLOAT_) { -// return static_cast(typenum_t::CFLOAT); -// } -// else if (typenum == api.UAR_HALF_) { -// return static_cast(typenum_t::HALF); -// } -// else if (typenum == api.UAR_INT_ || typenum == api.UAR_UINT_) { -// switch (sizeof(int)) { -// case sizeof(int32_t): -// return ((typenum == api.UAR_INT_) -// ? static_cast(typenum_t::INT32) -// : static_cast(typenum_t::UINT32)); -// case sizeof(int64_t): -// return ((typenum == api.UAR_INT_) -// ? static_cast(typenum_t::INT64) -// : static_cast(typenum_t::UINT64)); -// default: -// throw_unrecognized_typenum_error(typenum); -// } -// } -// else if (typenum == api.UAR_LONGLONG_ || typenum == -// api.UAR_ULONGLONG_) -// { -// switch (sizeof(long long)) { -// case sizeof(int64_t): -// return ((typenum == api.UAR_LONGLONG_) -// ? 
static_cast(typenum_t::INT64) -// : static_cast(typenum_t::UINT64)); -// default: -// throw_unrecognized_typenum_error(typenum); -// } -// } -// else { -// throw_unrecognized_typenum_error(typenum); -// } -// // return code signalling error, should never be reached -// assert(false); -// return -1; -// } - -// private: -// void throw_unrecognized_typenum_error(int typenum) const -// { -// throw std::runtime_error("Unrecognized typenum " + -// std::to_string(typenum) + " encountered."); -// } -// }; - } // namespace typeutils } // namespace tensor } // namespace kernel diff --git a/numba_dpex/core/runtime/kernels/tensor/src/sequences.cpp b/numba_dpex/core/runtime/kernels/tensor/src/sequences.cpp index 361b87994a..739eb718ee 100644 --- a/numba_dpex/core/runtime/kernels/tensor/src/sequences.cpp +++ b/numba_dpex/core/runtime/kernels/tensor/src/sequences.cpp @@ -12,34 +12,31 @@ #include "../include/typeutils.hpp" #include "../include/api.h" -static dpex::rt::kernel::tensor::sequence_step_ptr_t - sequence_step_dispatch_vector - [dpex::rt::kernel::tensor::typeutils::num_types]; +namespace dpexrt_tensor = dpex::rt::kernel::tensor; -static dpex::rt::kernel::tensor::affine_sequence_ptr_t - affine_sequence_dispatch_vector - [dpex::rt::kernel::tensor::typeutils::num_types]; +static dpexrt_tensor::sequence_step_ptr_t + sequence_step_dispatch_vector[dpexrt_tensor::typeutils::num_types]; + +static dpexrt_tensor::affine_sequence_ptr_t + affine_sequence_dispatch_vector[dpexrt_tensor::typeutils::num_types]; extern "C" void NUMBA_DPEX_SYCL_KERNEL_init_sequence_step_dispatch_vectors() { - dpex::rt::kernel::tensor::dispatch::DispatchVectorBuilder< - dpex::rt::kernel::tensor::sequence_step_ptr_t, - dpex::rt::kernel::tensor::SequenceStepFactory, - dpex::rt::kernel::tensor::typeutils::num_types> + dpexrt_tensor::dispatch::DispatchVectorBuilder< + dpexrt_tensor::sequence_step_ptr_t, dpexrt_tensor::SequenceStepFactory, + dpexrt_tensor::typeutils::num_types> dvb; 
dvb.populate_dispatch_vector(sequence_step_dispatch_vector); - std::cout << "-----> init_sequence_dispatch_vectors()" << std::endl; } extern "C" void NUMBA_DPEX_SYCL_KERNEL_init_affine_sequence_dispatch_vectors() { - dpex::rt::kernel::tensor::dispatch::DispatchVectorBuilder< - dpex::rt::kernel::tensor::affine_sequence_ptr_t, - dpex::rt::kernel::tensor::AffineSequenceFactory, - dpex::rt::kernel::tensor::typeutils::num_types> + dpexrt_tensor::dispatch::DispatchVectorBuilder< + dpexrt_tensor::affine_sequence_ptr_t, + dpexrt_tensor::AffineSequenceFactory, + dpexrt_tensor::typeutils::num_types> dvb; dvb.populate_dispatch_vector(affine_sequence_dispatch_vector); - std::cout << "-----> init_affine_sequence_dispatch_vectors()" << std::endl; } extern "C" uint NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence( @@ -53,14 +50,12 @@ extern "C" uint NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence( { std::cout << "NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence:" << " start = " - << dpex::rt::kernel::tensor::typeutils::caste_using_typeid( - start, dst_typeid) + << dpexrt_tensor::typeutils::caste_using_typeid(start, dst_typeid) << std::endl; std::cout << "NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence:" << " dt = " - << dpex::rt::kernel::tensor::typeutils::caste_using_typeid( - dt, dst_typeid) + << dpexrt_tensor::typeutils::caste_using_typeid(dt, dst_typeid) << std::endl; if (ndim != 1) { @@ -98,7 +93,7 @@ extern "C" uint NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence( return 1; } -// uint dpex::rt::kernel::tensor::tensor::populate_arystruct_affine_sequence( +// uint dpexrt_tensor::tensor::populate_arystruct_affine_sequence( // void *start, // void *end, // arystruct_t *dst, diff --git a/numba_dpex/dpnp_iface/_intrinsic.py b/numba_dpex/dpnp_iface/_intrinsic.py index fad1b4e124..dc43703466 100644 --- a/numba_dpex/dpnp_iface/_intrinsic.py +++ b/numba_dpex/dpnp_iface/_intrinsic.py @@ -361,25 +361,13 @@ def alloc_empty_arrayobj(context, builder, sig, queue_ref, args, 
is_like=False): Returns: The LLVM IR value that stores the empty array """ - print("alloc_empty_arrayobj: sig =", sig) - print("alloc_empty_arrayobj: args =", args) - arrtype, shape = ( _parse_empty_like_args(context, builder, sig, args) if is_like else _parse_empty_args(context, builder, sig, args) ) - print( - "alloc_empty_arrayobj(): arrtype =", - arrtype, - "type(arrtype) =", - type(arrtype), - ) - print( - "alloc_empty_arrayobj(): shape =", shape, ", type(shape) =", type(shape) - ) ary = _empty_nd_impl(context, builder, arrtype, shape, queue_ref) - print("alloc_empty_arrayobj(): ary =", ary, ", type(ary) =", type(ary)) + return ary @@ -485,8 +473,6 @@ def impl_dpnp_empty( ty_retty_ref, ) - print("--- impl_dpnp_empty()") - sycl_queue_arg_pos = -2 def codegen(context, builder, sig, args): @@ -500,9 +486,6 @@ def codegen(context, builder, sig, args): sycl_queue_arg=sycl_queue_arg, ) - print("impl_dpnp_empty(): sig =", sig, type(sig)) - print("impl_dpnp_empty(): args =", args, type(args)) - ary = alloc_empty_arrayobj( context, builder, sig, qref_payload.queue_ref, args ) diff --git a/numba_dpex/dpnp_iface/array_sequence_ops.py b/numba_dpex/dpnp_iface/array_sequence_ops.py index 664b75fec5..93d2d278a1 100644 --- a/numba_dpex/dpnp_iface/array_sequence_ops.py +++ b/numba_dpex/dpnp_iface/array_sequence_ops.py @@ -2,40 +2,27 @@ # # SPDX-License-Identifier: Apache-2.0 -import math from collections import namedtuple -import dpctl.tensor as dpt import dpnp import numba import numpy as np -from dpctl.tensor._ctors import _coerce_and_infer_dt from llvmlite import ir as llvmir from numba import errors, types from numba.core import cgutils -from numba.core.types.misc import NoneType, UnicodeType -from numba.core.types.scalars import ( - Boolean, - Complex, - Float, - Integer, - IntegerLiteral, -) -from numba.core.typing.templates import Signature +from numba.core.types.misc import NoneType +from numba.core.types.scalars import Boolean, Complex, Float, Integer from 
numba.extending import intrinsic, overload import numba_dpex.utils as utils -from numba_dpex.core.runtime import context as dpexrt from numba_dpex.core.types import DpnpNdArray from numba_dpex.dpnp_iface._intrinsic import ( _ArgTyAndValue, _empty_nd_impl, _get_queue_ref, - alloc_empty_arrayobj, ) from numba_dpex.dpnp_iface.arrayobj import ( _parse_device_filter_string, - _parse_dim, _parse_dtype, _parse_usm_type, ) @@ -65,69 +52,25 @@ def _is_any_complex_type(value): return np.iscomplex(value) or isinstance(value, Complex) -def _compute_bitwidth(value): - print("_compute_bitwidth(): type(value) =", type(value)) - if ( - isinstance(value, Float) - or isinstance(value, Integer) - or isinstance(value, Complex) - ): - return value.bitwidth - elif ( - isinstance(value, np.floating) - or isinstance(value, np.integer) - or np.iscomplex(value) - ): - return value.itemsize * 8 - elif type(value) == float or type(value) == int: - return 64 - elif type(value) == complex: - return 128 - else: - msg = "dpnp_iface.array_sequence_ops._compute_bitwidth(): Unknwon type." 
- raise errors.NumbaValueError(msg) - - def _parse_dtype_from_range(start, stop, step): - max_bw = max( - _compute_bitwidth(start), - _compute_bitwidth(stop), - _compute_bitwidth(step), - ) if ( _is_any_complex_type(start) or _is_any_complex_type(stop) or _is_any_complex_type(step) ): - return ( - numba.from_dtype(dpnp.complex128) - if max_bw == 128 - else numba.from_dtype(dpnp.complex64) - ) + return numba.from_dtype(dpnp.complex128) elif ( _is_any_float_type(start) or _is_any_float_type(stop) or _is_any_float_type(step) ): - if max_bw == 64: - return numba.from_dtype(dpnp.float64) - elif max_bw == 32: - return numba.from_dtype(dpnp.float32) - elif max_bw == 16: - return numba.from_dtype(dpnp.float16) - else: - return numba.from_dtype(dpnp.float) + return numba.from_dtype(dpnp.float64) elif ( _is_any_int_type(start) or _is_any_int_type(stop) or _is_any_int_type(step) ): - if max_bw == 64: - return numba.from_dtype(dpnp.int64) - elif max_bw == 32: - return numba.from_dtype(dpnp.int32) - else: - return numba.from_dtype(dpnp.int) + return numba.from_dtype(dpnp.int64) else: msg = ( "dpnp_iface.array_sequence_ops._parse_dtype_from_range(): " @@ -160,36 +103,6 @@ def _get_llvm_type(numba_type): raise errors.NumbaTypeError(msg) -def _get_constant(context, dtype, bitwidth, value): - if isinstance(dtype, Integer): - if bitwidth == 64: - return context.get_constant(types.int64, value) - elif bitwidth == 32: - return context.get_constant(types.int32, value) - elif bitwidth == 16: - return context.get_constant(types.int16, value) - elif bitwidth == 8: - return context.get_constant(types.int8, value) - elif isinstance(dtype, Float): - if bitwidth == 64: - return context.get_constant(types.float64, value) - elif bitwidth == 32: - return context.get_constant(types.float32, value) - elif bitwidth == 16: - return context.get_constant(types.float16, value) - elif isinstance(dtype, Complex): - if bitwidth == 128: - return context.get_constant(types.complex128, value) - elif bitwidth == 64: 
- return context.get_constant(types.complex64, value) - else: - msg = ( - "dpnp_iface.array_sequence_ops._get_constant():" - + " Couldn't infer type for the requested constant." - ) - raise errors.NumbaTypeError(msg) - - def _get_dst_typeid(dtype): if isinstance(dtype, Boolean): return 0 @@ -240,6 +153,65 @@ def _get_dst_typeid(dtype): raise errors.NumbaTypeError(msg) +def _normalize(builder, src, src_type, dest_type): + dest_llvm_type = _get_llvm_type(dest_type) + if isinstance(src_type, Integer) and isinstance(dest_type, Integer): + if src_type.bitwidth < dest_type.bitwidth: + return (builder.sext if src_type.signed else builder.zext)(src, dest_llvm_type) + elif src_type.bitwidth > dest_type.bitwidth: + return builder.trunc(src, dest_llvm_type) + else: + return src + elif isinstance(src_type, Integer) and isinstance(dest_type, Float): + if src_type.signed: + return builder.sitofp(src, dest_llvm_type) + else: + return builder.uitofp(src, dest_llvm_type) + elif isinstance(src_type, Float) and isinstance(dest_type, Integer): + if dest_type.signed: + return builder.fptosi(src, dest_llvm_type) + else: + return builder.fptoui(src, dest_llvm_type) + elif isinstance(src_type, Float) and isinstance(dest_type, Float): + if src_type.bitwidth < dest_type.bitwidth: + return builder.fpext(src, dest_llvm_type) + elif src_type.bitwidth > dest_type.bitwidth: + return builder.fptrunc(src, dest_llvm_type) + else: + return src + else: + msg = f"{src}[{src_type}] is neither numba type 'Float' nor 'Integer'." 
+ raise errors.NumbaTypeError(msg) + + +def _compute_array_length_ir( + context, + builder, + start_ir, + stop_ir, + step_ir, + start_arg_type, + stop_arg_type, + step_arg_type, +): + u64 = llvmir.IntType(64) + one = context.get_constant(types.float64, 1) + + ll = _normalize(builder, start_ir, start_arg_type, types.float64) + ul = _normalize(builder, stop_ir, stop_arg_type, types.float64) + d = _normalize(builder, step_ir, step_arg_type, types.float64) + + # Doing ceil(a,b) = (a-1)/b + 1 to avoid overflow + array_length = builder.fptosi( + builder.fadd( + builder.fdiv(builder.fsub(builder.fsub(ul, ll), one), d), + one, + ), + u64, + ) + return array_length + + @intrinsic def impl_dpnp_arange( ty_context, @@ -268,12 +240,11 @@ def impl_dpnp_arange( def codegen(context, builder, sig, args): mod = builder.module - - start_ir, stop_ir, step_ir, dtype_ir, queue_ir = ( + # Rename variables for easy coding + start_ir, stop_ir, step_ir, queue_ir = ( args[0], args[1], args[2], - args[3], args[sycl_queue_arg_pos], ) ( @@ -290,43 +261,13 @@ def codegen(context, builder, sig, args): sig.args[sycl_queue_arg_pos], ) - # b = llvmir.IntType(1) # noqa: E800 - # u32 = llvmir.IntType(32) # noqa: E800 - u64 = llvmir.IntType(64) - # f32 = llvmir.FloatType() # noqa: E800 - f64 = llvmir.DoubleType() # noqa: E800 - # zero_u32 = context.get_constant(types.int32, 0) # noqa: E800 - # zero_u64 = context.get_constant(types.int64, 0) # noqa: E800 - # zero_f32 = context.get_constant(types.float32, 0) # noqa: E800 - zero_f64 = context.get_constant(types.float64, 0) - # one_u32 = context.get_constant(types.int32, 1) # noqa: E800 - # one_u64 = context.get_constant(types.int64, 1) # noqa: E800 - # one_f32 = context.get_constant(types.float32, 1) # noqa: E800 - one_f64 = context.get_constant(types.float64, 1) - - # ftype = _get_llvm_type(dtype_arg_type.dtype) # noqa: E800 - # utype = _get_llvm_type(dtype_arg_type.dtype) # noqa: E800 - # one = _get_constant( # noqa: E800 - # context, 
dtype_arg_type.dtype, dtype_arg_type.dtype.bitwidth, 1 # noqa: E800 - # ) # noqa: E800 - # zero = _get_constant( # noqa: E800 - # context, dtype_arg_type.dtype, dtype_arg_type.dtype.bitwidth, 0 # noqa: E800 - # ) # noqa: E800 - - print( - f"start_ir = {start_ir}, " - + f"start_ir.type = {start_ir.type}, " - + f"type(start_ir.type) = {type(start_ir.type)}" - ) - print( - f"step_ir = {step_ir}, " - + f"step_ir.type = {step_ir.type}, " - + f"type(step_ir.type) = {type(step_ir.type)}" - ) - print( - f"stop_ir = {stop_ir}, " - + f"stop_ir.type = {stop_ir.type}, " - + f"type(stop_ir.type) = {type(stop_ir.type)}" + # Get SYCL Queue ref + sycl_queue_arg = _ArgTyAndValue(queue_arg_type, queue_ir) + qref_payload: _QueueRefPayload = _get_queue_ref( + context=context, + builder=builder, + returned_sycl_queue_ty=sig.return_type.queue, + sycl_queue_arg=sycl_queue_arg, ) # Sanity check: @@ -335,89 +276,52 @@ def codegen(context, builder, sig, args): # stop <- 1 # if step is pointing to a null # step <- 1 - # TODO: do this either in LLVMIR or outside of intrinsic - print("type(stop_arg_type) =", type(stop_arg_type)) - print("type(step_arg_type) =", type(step_arg_type)) if isinstance(stop_arg_type, NoneType): - start_ir = zero_f64 - stop_ir = one_f64 + start_ir = context.get_constant(start_arg_type, 0) + stop_ir = context.get_constant(start_arg_type, 1) + stop_arg_type = start_arg_type if isinstance(step_arg_type, NoneType): - step_ir = one_f64 + step_ir = context.get_constant(start_arg_type, 1) + step_arg_type = start_arg_type - if isinstance(start_arg_type, Integer) and isinstance( - dtype_arg_type.dtype, Float - ): - if start_arg_type.signed: - start_ir = builder.sitofp(start_ir, f64) - step_ir = builder.sitofp(step_ir, f64) - else: - start_ir = builder.uitofp(start_ir, f64) - step_ir = builder.uitofp(step_ir, f64) - - print( - f"start_ir = {start_ir}, " - + f"start_ir.type = {start_ir.type}, " - + f"type(start_ir.type) = {type(start_ir.type)}" - ) - print( - f"step_ir = 
{step_ir}, " - + f"step_ir.type = {step_ir.type}, " - + f"type(step_ir.type) = {type(step_ir.type)}" + start_ir = _normalize( + builder, start_ir, start_arg_type, dtype_arg_type.dtype ) - print( - f"stop_ir = {stop_ir}, " - + f"stop_ir.type = {stop_ir.type}, " - + f"type(stop_ir.type) = {type(stop_ir.type)}" + start_arg_type = dtype_arg_type.dtype + stop_ir = _normalize( + builder, stop_ir, stop_arg_type, dtype_arg_type.dtype ) - print( - f"dtype_ir = {dtype_ir}, " - + f"dtype_ir.type = {dtype_ir.type}, " - + f"dtype_arg_type = {dtype_arg_type}, " - + f"dtype_arg_type.dtype = {dtype_arg_type.dtype}, " - + f"dtype_arg_type.dtype.bitwidth = {dtype_arg_type.dtype.bitwidth}" + stop_arg_type = dtype_arg_type.dtype + step_ir = _normalize( + builder, step_ir, step_arg_type, dtype_arg_type.dtype ) + step_arg_type = dtype_arg_type.dtype - # Get SYCL Queue ref - sycl_queue_arg = _ArgTyAndValue(queue_arg_type, queue_ir) - qref_payload: _QueueRefPayload = _get_queue_ref( - context=context, - builder=builder, - returned_sycl_queue_ty=sig.return_type.queue, - sycl_queue_arg=sycl_queue_arg, + # Allocate an empty array + t = _compute_array_length_ir( + context, + builder, + start_ir, + stop_ir, + step_ir, + start_arg_type, + stop_arg_type, + step_arg_type, ) + ary = _empty_nd_impl( + context, builder, sig.return_type, [t], qref_payload.queue_ref + ) + # Convert into void* + arrystruct_vptr = builder.bitcast(ary._getpointer(), cgutils.voidptr_t) + # Construct function parameters with builder.goto_entry_block(): start_ptr = cgutils.alloca_once(builder, start_ir.type) step_ptr = cgutils.alloca_once(builder, step_ir.type) - builder.store(start_ir, start_ptr) builder.store(step_ir, step_ptr) - start_vptr = builder.bitcast(start_ptr, cgutils.voidptr_t) step_vptr = builder.bitcast(step_ptr, cgutils.voidptr_t) - - ll = builder.sitofp(start_ir, f64) - ul = builder.sitofp(stop_ir, f64) - d = builder.sitofp(step_ir, f64) - - # Doing ceil(a,b) = (a-1)/b + 1 to avoid overflow - t = 
builder.fptosi( - builder.fadd( - builder.fdiv(builder.fsub(builder.fsub(ul, ll), one_f64), d), - one_f64, - ), - u64, - ) - - # Allocate an empty array - ary = _empty_nd_impl( - context, builder, sig.return_type, [t], qref_payload.queue_ref - ) - - # Convert into void* - arrystruct_vptr = builder.bitcast(ary._getpointer(), cgutils.voidptr_t) - - # Function parameters ndim = context.get_constant(types.intp, 1) is_c_contguous = context.get_constant(types.int8, 1) typeid_index = _get_dst_typeid(dtype_arg_type.dtype) @@ -469,29 +373,17 @@ def ol_dpnp_arange( usm_type="device", sycl_queue=None, ): - print("start =", start, ", type(start) =", type(start)) - print("stop =", stop, ", type(stop) =", type(stop)) - print("step =", step, ", type(step) =", type(step)) - print("dtype =", dtype, ", type(dtype) =", type(dtype)) - print("---") - if stop is None: start = 0 stop = 1 if step is None: step = 1 - print("start =", start, ", type(start) =", type(start)) - print("stop =", stop, ", type(stop) =", type(stop)) - print("step =", step, ", type(step) =", type(step)) - print("***") - _dtype = ( _parse_dtype(dtype) if dtype is not None else _parse_dtype_from_range(start, stop, step) ) - print("_dtype =", _dtype, ", type(_dtype) =", type(_dtype)) _device = _parse_device_filter_string(device) if device else None _usm_type = _parse_usm_type(usm_type) if usm_type else "device"