diff --git a/numba_dpex/dpnp_iface/_intrinsic.py b/numba_dpex/dpnp_iface/_intrinsic.py index dc43703466..fad1b4e124 100644 --- a/numba_dpex/dpnp_iface/_intrinsic.py +++ b/numba_dpex/dpnp_iface/_intrinsic.py @@ -361,13 +361,25 @@ def alloc_empty_arrayobj(context, builder, sig, queue_ref, args, is_like=False): Returns: The LLVM IR value that stores the empty array """ + print("alloc_empty_arrayobj: sig =", sig) + print("alloc_empty_arrayobj: args =", args) + arrtype, shape = ( _parse_empty_like_args(context, builder, sig, args) if is_like else _parse_empty_args(context, builder, sig, args) ) + print( + "alloc_empty_arrayobj(): arrtype =", + arrtype, + "type(arrtype) =", + type(arrtype), + ) + print( + "alloc_empty_arrayobj(): shape =", shape, ", type(shape) =", type(shape) + ) ary = _empty_nd_impl(context, builder, arrtype, shape, queue_ref) - + print("alloc_empty_arrayobj(): ary =", ary, ", type(ary) =", type(ary)) return ary @@ -473,6 +485,8 @@ def impl_dpnp_empty( ty_retty_ref, ) + print("--- impl_dpnp_empty()") + sycl_queue_arg_pos = -2 def codegen(context, builder, sig, args): @@ -486,6 +500,9 @@ def codegen(context, builder, sig, args): sycl_queue_arg=sycl_queue_arg, ) + print("impl_dpnp_empty(): sig =", sig, type(sig)) + print("impl_dpnp_empty(): args =", args, type(args)) + ary = alloc_empty_arrayobj( context, builder, sig, qref_payload.queue_ref, args ) diff --git a/numba_dpex/dpnp_iface/array_sequence_ops.py b/numba_dpex/dpnp_iface/array_sequence_ops.py index 5b52a1ee14..a389556040 100644 --- a/numba_dpex/dpnp_iface/array_sequence_ops.py +++ b/numba_dpex/dpnp_iface/array_sequence_ops.py @@ -1,21 +1,32 @@ +import math from collections import namedtuple import dpctl.tensor as dpt import dpnp +import numba import numpy as np from dpctl.tensor._ctors import _coerce_and_infer_dt from llvmlite import ir as llvmir from numba import errors, types from numba.core import cgutils -from numba.core.types.scalars import Complex, Float, Integer +from numba.core.types.misc import UnicodeType +from numba.core.types.scalars import Complex, Float, Integer, IntegerLiteral +from numba.core.typing.templates import Signature from numba.extending import intrinsic, overload import numba_dpex.utils as utils from numba_dpex.core.runtime import context as dpexrt from numba_dpex.core.types import DpnpNdArray -from numba_dpex.dpnp_iface._intrinsic import _get_queue_ref +from numba_dpex.dpnp_iface._intrinsic import ( + _ArgTyAndValue, + _empty_nd_impl, + _get_queue_ref, + alloc_empty_arrayobj, +) from numba_dpex.dpnp_iface.arrayobj import ( _parse_device_filter_string, + _parse_dim, + _parse_dtype, _parse_usm_type, ) @@ -24,21 +35,153 @@ ) -def _parse_dtype(a): - if isinstance(a.dtype, Complex): - v_type = a.dtype - w_type = dpnp.float64 if a.dtype.bitwidth == 128 else dpnp.float32 - elif isinstance(a.dtype, Float): - v_type = w_type = a.dtype - elif isinstance(a.dtype, Integer): - v_type = w_type = ( - dpnp.float32 if a.dtype.bitwidth == 32 else dpnp.float64 - ) - # elif a.queue.sycl_device.has_aspect_fp64: - # v_type = w_type = dpnp.float64 +def _parse_dtype_from_range(start, stop, step): + max_bw = max(start.bitwidth, stop.bitwidth, step.bitwidth) + if ( + isinstance(start, Complex) + or isinstance(stop, Complex) + or isinstance(step, Complex) + ): + if max_bw == 128: + return numba.from_dtype(dpnp.complex128) + else: + return numba.from_dtype(dpnp.complex64) + elif ( + isinstance(start, Float) + or isinstance(stop, Float) + or isinstance(step, Float) + ): + if max_bw == 64: + return numba.from_dtype(dpnp.float64) + elif max_bw == 32: + return numba.from_dtype(dpnp.float32) + elif max_bw == 16: + return numba.from_dtype(dpnp.float16) + else: + return numba.from_dtype(dpnp.float) + elif ( + isinstance(start, Integer) + or isinstance(stop, Integer) + or isinstance(step, Integer) + ): + if max_bw == 64: + return numba.from_dtype(dpnp.int64) + elif max_bw == 32: + return numba.from_dtype(dpnp.int32) + else: + return numba.from_dtype(dpnp.int) else: - v_type = w_type = dpnp.float64 - return (v_type, w_type) + msg = "Type couldn't be inferred from (start, stop, step)." + raise errors.NumbaValueError(msg) + + +@intrinsic +def impl_dpnp_arange( + ty_context, + ty_start, + ty_stop, + ty_step, + ty_dtype, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_ret_ty, +): + ty_retty_ = ty_ret_ty.instance_type + signature = ty_retty_( + ty_start, + ty_stop, + ty_step, + ty_dtype, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_ret_ty, + ) + + sycl_queue_arg_pos = -2 + + def codegen(context, builder, sig, args): + start_ir, stop_ir, step_ir, queue_ir = ( + args[0], + args[1], + args[2], + args[sycl_queue_arg_pos], + ) + queue_arg_type = sig.args[sycl_queue_arg_pos] + + u64 = llvmir.IntType(64) + b = llvmir.IntType(1) + # f64 = llvmir.DoubleType() # noqa: E800 + mod = builder.module + + sycl_queue_arg = _ArgTyAndValue(queue_arg_type, queue_ir) + qref_payload: _QueueRefPayload = _get_queue_ref( + context=context, + builder=builder, + returned_sycl_queue_ty=sig.return_type.queue, + sycl_queue_arg=sycl_queue_arg, + ) + + from numba.core.cpu import CPUContext + from numba.np.arrayobj import make_array + + from numba_dpex.core.datamodel.models import DpnpNdArrayModel + + # dt = builder.bitcast(builder.sdiv(t, builder.bitcast(step_ir, u64)), f64) # noqa: E800 + # dt = builder.sdiv(t, builder.bitcast(step_ir, u64)) # noqa: E800 + + with builder.goto_entry_block(): + start_ptr = cgutils.alloca_once(builder, start_ir.type) + step_ptr = cgutils.alloca_once(builder, step_ir.type) + # dt_ptr = cgutils.alloca_once(builder, dt.type) # noqa: E800 + + builder.store(start_ir, start_ptr) + builder.store(step_ir, step_ptr) + # builder.store(dt, dt_ptr) # noqa: E800 + + start_vptr = builder.bitcast(start_ptr, cgutils.voidptr_t) + step_vptr = builder.bitcast(step_ptr, cgutils.voidptr_t) + # dt_vptr = builder.bitcast(dt_ptr, cgutils.voidptr_t) # noqa: E800 + + t = builder.sub(stop_ir, start_ir) + ary = _empty_nd_impl( + context, builder, sig.return_type, [t], qref_payload.queue_ref + ) + arrystruct_vptr = builder.bitcast(ary._getpointer(), cgutils.voidptr_t) + + ndim = context.get_constant(types.intp, 1) + is_c_contguous = context.get_constant(types.boolean, 1) + + fnty = llvmir.FunctionType( + utils.LLVMTypes.int64_ptr_t, + [ + cgutils.voidptr_t, + cgutils.voidptr_t, + cgutils.voidptr_t, + u64, + b, + cgutils.voidptr_t, + ], + ) + fn = cgutils.get_or_insert_function( + mod, fnty, "NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence" + ) + builder.call( + fn, + [ + start_vptr, + step_vptr, + arrystruct_vptr, + ndim, + is_c_contguous, + qref_payload.queue_ref, + ], + ) + + return ary._getvalue() + + return signature, codegen @overload(dpnp.arange, prefer_literal=True) @@ -65,7 +208,15 @@ def ol_dpnp_arange( start = 0 if step is None: step = 1 - _dtype = _parse_dtype(dtype) if dtype is not None else type(start) + print("start =", start, ", type(start) =", type(start)) + print("stop =", stop, ", type(stop) =", type(stop)) + print("step =", step, ", type(step) =", type(step)) + print("-*-") + _dtype = ( + _parse_dtype(dtype) + if dtype is not None + else _parse_dtype_from_range(start, stop, step) + ) _device = _parse_device_filter_string(device) if device else None _usm_type = _parse_usm_type(usm_type) if usm_type else "device" @@ -98,25 +249,16 @@ def impl( usm_type="device", sycl_queue=None, ): - print("start =", start, ", type(start) =", type(start)) - print("stop =", stop, ", type(stop) =", type(stop)) - print("step =", step, ", type(step) =", type(step)) - print( - "dtype =", dtype - ) # , ", type(dtype) =", type(dtype) if dtype is not None else "Null") - print( - "device =", device - ) # , ", type(device) =", type(device) if device is not None else "Null") - print( - "usm_type =", usm_type - ) # , ", type(usm_type) =", type(usm_type) if usm_type is not None else "Null") - print( - "sycl_queue =", sycl_queue - ) # , ", type(sycl_queue) =", type(sycl_queue) if sycl_queue is not None else "Null") - print("###") - - v = dpnp.empty(10) - return v + return impl_dpnp_arange( + start, + stop, + step, + _dtype, + _device, + _usm_type, + sycl_queue, + ret_ty, + ) return impl else: