diff --git a/numba_dpex/tests/core/types/DpnpNdArray/test_models.py b/numba_dpex/tests/core/types/DpnpNdArray/test_models.py index 79a1645950..93fd52cc5f 100644 --- a/numba_dpex/tests/core/types/DpnpNdArray/test_models.py +++ b/numba_dpex/tests/core/types/DpnpNdArray/test_models.py @@ -4,12 +4,14 @@ from numba import types from numba.core.datamodel import default_manager, models +from numba.core.registry import cpu_target from numba_dpex.core.datamodel.models import ( DpnpNdArrayModel, USMArrayModel, dpex_data_model_manager, ) +from numba_dpex.core.descriptor import dpex_kernel_target from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray @@ -31,3 +33,22 @@ def test_dpnp_ndarray_Model(): """ assert issubclass(DpnpNdArrayModel, models.StructModel) + + +def test_flattened_member_count(): + """Test that the number of flattened member count matches the number of + flattened args generated by the CpuTarget's ArgPacker. + """ + + cputargetctx = cpu_target.target_context + kerneltargetctx = dpex_kernel_target.target_context + dpex_dmm = kerneltargetctx.data_model_manager + + for ndim in range(4): + dty = DpnpNdArray(ndim) + argty_tuple = tuple([dty]) + datamodel = dpex_dmm.lookup(dty) + num_flattened_args = datamodel.flattened_field_count + ap = cputargetctx.get_arg_packer(argty_tuple) + + assert num_flattened_args == len(ap._be_args) diff --git a/numba_dpex/tests/core/types/USMNdArray/test_models.py b/numba_dpex/tests/core/types/USMNdArray/test_models.py new file mode 100644 index 0000000000..0586ed13ca --- /dev/null +++ b/numba_dpex/tests/core/types/USMNdArray/test_models.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +from numba.core.registry import cpu_target + +from numba_dpex.core.descriptor import dpex_kernel_target +from numba_dpex.core.types.usm_ndarray_type import USMNdArray + + +def test_flattened_member_count(): + """Test that the number of flattened member count matches the number of + flattened args generated by the CpuTarget's ArgPacker. + """ + + cputargetctx = cpu_target.target_context + kerneltargetctx = dpex_kernel_target.target_context + dpex_dmm = kerneltargetctx.data_model_manager + + for ndim in range(4): + dty = USMNdArray(ndim) + argty_tuple = tuple([dty]) + datamodel = dpex_dmm.lookup(dty) + num_flattened_args = datamodel.flattened_field_count + ap = cputargetctx.get_arg_packer(argty_tuple) + + assert num_flattened_args == len(ap._be_args) diff --git a/numba_dpex/tests/core/types/range_types/test_data_model.py b/numba_dpex/tests/core/types/range_types/test_data_model.py index f4e585206b..bbdeba7f71 100644 --- a/numba_dpex/tests/core/types/range_types/test_data_model.py +++ b/numba_dpex/tests/core/types/range_types/test_data_model.py @@ -4,16 +4,19 @@ import pytest from numba.core.datamodel import default_manager +from numba.core.registry import cpu_target from numba_dpex.core.datamodel.models import ( NdRangeModel, RangeModel, dpex_data_model_manager, ) +from numba_dpex.core.descriptor import dpex_kernel_target from numba_dpex.core.types.range_types import NdRangeType, RangeType rfields = ["ndim", "dim0", "dim1", "dim2"] ndrfields = ["ndim", "gdim0", "gdim1", "gdim2", "ldim0", "ldim1", "ldim2"] +range_tys = [RangeType, NdRangeType] def test_datamodel_registration(): @@ -58,3 +61,23 @@ def test_ndrange_model_fields(field): dm.get_field_position(field) except: pytest.fail(f"Expected field {field} not present in NdRangeModel") + + +@pytest.mark.parametrize("range_type", range_tys) +def test_flattened_member_count(range_type): + """Test that the number of flattened member count matches the number of + flattened args generated by the CpuTarget's ArgPacker. + """ + + cputargetctx = cpu_target.target_context + kerneltargetctx = dpex_kernel_target.target_context + dpex_dmm = kerneltargetctx.data_model_manager + + for ndim in range(1, 3): + dty = range_type(ndim) + argty_tuple = tuple([dty]) + datamodel = dpex_dmm.lookup(dty) + num_flattened_args = datamodel.flattened_field_count + ap = cputargetctx.get_arg_packer(argty_tuple) + + assert num_flattened_args == len(ap._be_args)