Skip to content

Commit

Permalink
Unit tests for flattened_member_count data model property.
Browse files Browse the repository at this point in the history
  • Loading branch information
Diptorup Deb committed Oct 12, 2023
1 parent 6e74aea commit c4d2f40
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 0 deletions.
21 changes: 21 additions & 0 deletions numba_dpex/tests/core/types/DpnpNdArray/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

from numba import types
from numba.core.datamodel import default_manager, models
from numba.core.registry import cpu_target

from numba_dpex.core.datamodel.models import (
DpnpNdArrayModel,
USMArrayModel,
dpex_data_model_manager,
)
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray


Expand All @@ -31,3 +33,22 @@ def test_dpnp_ndarray_Model():
"""

assert issubclass(DpnpNdArrayModel, models.StructModel)


def test_flattened_member_count():
"""Test that the number of flattened member count matches the number of
flattened args generated by the CpuTarget's ArgPacker.
"""

cputargetctx = cpu_target.target_context
kerneltargetctx = dpex_kernel_target.target_context
dpex_dmm = kerneltargetctx.data_model_manager

for ndim in range(4):
dty = DpnpNdArray(ndim)
argty_tuple = tuple([dty])
datamodel = dpex_dmm.lookup(dty)
num_flattened_args = datamodel.flattened_field_count
ap = cputargetctx.get_arg_packer(argty_tuple)

assert num_flattened_args == len(ap._be_args)
27 changes: 27 additions & 0 deletions numba_dpex/tests/core/types/USMNdArray/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from numba.core.registry import cpu_target

from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.types.usm_ndarray_type import USMNdArray


def test_flattened_member_count():
"""Test that the number of flattened member count matches the number of
flattened args generated by the CpuTarget's ArgPacker.
"""

cputargetctx = cpu_target.target_context
kerneltargetctx = dpex_kernel_target.target_context
dpex_dmm = kerneltargetctx.data_model_manager

for ndim in range(4):
dty = USMNdArray(ndim)
argty_tuple = tuple([dty])
datamodel = dpex_dmm.lookup(dty)
num_flattened_args = datamodel.flattened_field_count
ap = cputargetctx.get_arg_packer(argty_tuple)

assert num_flattened_args == len(ap._be_args)
23 changes: 23 additions & 0 deletions numba_dpex/tests/core/types/range_types/test_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@

import pytest
from numba.core.datamodel import default_manager
from numba.core.registry import cpu_target

from numba_dpex.core.datamodel.models import (
NdRangeModel,
RangeModel,
dpex_data_model_manager,
)
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.core.types.range_types import NdRangeType, RangeType

rfields = ["ndim", "dim0", "dim1", "dim2"]
ndrfields = ["ndim", "gdim0", "gdim1", "gdim2", "ldim0", "ldim1", "ldim2"]
range_tys = [RangeType, NdRangeType]


def test_datamodel_registration():
Expand Down Expand Up @@ -58,3 +61,23 @@ def test_ndrange_model_fields(field):
dm.get_field_position(field)
except:
pytest.fail(f"Expected field {field} not present in NdRangeModel")


@pytest.mark.parametrize("range_type", range_tys)
def test_flattened_member_count(range_type):
"""Test that the number of flattened member count matches the number of
flattened args generated by the CpuTarget's ArgPacker.
"""

cputargetctx = cpu_target.target_context
kerneltargetctx = dpex_kernel_target.target_context
dpex_dmm = kerneltargetctx.data_model_manager

for ndim in range(1, 3):
dty = range_type(ndim)
argty_tuple = tuple([dty])
datamodel = dpex_dmm.lookup(dty)
num_flattened_args = datamodel.flattened_field_count
ap = cputargetctx.get_arg_packer(argty_tuple)

assert num_flattened_args == len(ap._be_args)

0 comments on commit c4d2f40

Please sign in to comment.