Skip to content

Commit

Permalink
fix(NumpyModel): Equality operator when the fields are heterogeneous.
Browse files Browse the repository at this point in the history
patch(NumpyModel): Migrated to new __eq__ method in Pydantic
  • Loading branch information
caniko committed May 9, 2024
1 parent 155d56a commit 6fab8b2
Show file tree
Hide file tree
Showing 6 changed files with 317 additions and 236 deletions.
446 changes: 253 additions & 193 deletions poetry.lock

Large diffs are not rendered by default.

42 changes: 22 additions & 20 deletions pydantic_numpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,29 @@ class NumpyModel(BaseModel):
_directory_suffix: ClassVar[str] = ".pdnp"

def __eq__(self, other: Any) -> bool:
if isinstance(other, NumpyModel):
self_type = self.__pydantic_generic_metadata__["origin"] or self.__class__
other_type = other.__pydantic_generic_metadata__["origin"] or other.__class__
if not isinstance(other, BaseModel):
return NotImplemented # delegate to the other item in the comparison

self_type = self.__pydantic_generic_metadata__["origin"] or self.__class__
other_type = other.__pydantic_generic_metadata__["origin"] or other.__class__

if not (
self_type == other_type
and getattr(self, "__pydantic_private__", None) == getattr(other, "__pydantic_private__", None)
and self.__pydantic_extra__ == other.__pydantic_extra__
):
return False

if isinstance(other, NumpyModel):
self_ndarray_field_to_array, self_other_field_to_value = self._dump_numpy_split_dict()
other_ndarray_field_to_array, other_other_field_to_value = other._dump_numpy_split_dict()

return (
self_type == other_type
and self_other_field_to_value == other_other_field_to_value
and self.__pydantic_private__ == other.__pydantic_private__
and self.__pydantic_extra__ == other.__pydantic_extra__
and _compare_np_array_dicts(self_ndarray_field_to_array, other_ndarray_field_to_array)
return self_other_field_to_value == other_other_field_to_value and _compare_np_array_dicts(
self_ndarray_field_to_array, other_ndarray_field_to_array
)
elif isinstance(other, BaseModel):
return super().__eq__(other)
else:
return NotImplemented # delegate to the other item in the comparison

# Self is NumpyModel, other is not; likely unequal; checking anyway.
return super().__eq__(other)

@classmethod
@validate_call
Expand Down Expand Up @@ -156,10 +161,10 @@ def _dump_numpy_split_dict(self) -> tuple[dict, dict]:
ndarray_field_to_array = {}
other_field_to_value = {}

for k, v in self.model_dump(exclude_unset=True).items():
for k, v in self.model_dump().items():
if isinstance(v, np.ndarray):
ndarray_field_to_array[k] = v
else:
elif v:
other_field_to_value[k] = v

return ndarray_field_to_array, other_field_to_value
Expand Down Expand Up @@ -259,16 +264,13 @@ def _compare_np_array_dicts(
keys2 = frozenset(dict_b.keys())

if keys1 != keys2:
raise ValueError("Dictionaries have different keys")
return False

for key in keys1:
arr_a = dict_a[key]
arr_b = dict_b[key]

if arr_a.shape != arr_b.shape:
raise ValueError(f"Arrays for key '{key}' have different shapes")

if not np_general_all_close(arr_a, arr_b, rtol, atol):
if arr_a.shape != arr_b.shape or not np_general_all_close(arr_a, arr_b, rtol, atol):
return False

return True
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pydantic_numpy"
version = "5.0.1"
version = "5.0.2"
description = "Pydantic Model integration of the NumPy array"
authors = ["Can H. Tartanoglu", "Christoph Heindl"]
maintainers = ["Can H. Tartanoglu <[email protected]>"]
Expand Down Expand Up @@ -30,6 +30,7 @@ semver = "^3.0.1"
pytest = "^7.4.0"
parameterized = "^0.9.0"
orjson = "*"
coverage = "^7.5.1"

[tool.poetry.group.format.dependencies]
black = "^23.7.0"
Expand Down
4 changes: 2 additions & 2 deletions tests/helper/testing_groups.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import platform

import numpy as np

Expand Down Expand Up @@ -185,7 +185,7 @@
(np.array([[[0]]]), np.int64, Np3DArrayInt64, 3),
]

if os.name != "nt":
if platform.system() != "Windows":

def get_strict_data_type_nd_array_typing_dimensions_128_bit():
return [
Expand Down
54 changes: 36 additions & 18 deletions tests/test_np_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
import platform
import tempfile
from pathlib import Path

import numpy as np
import pytest

from pydantic_numpy.model import model_agnostic_load
from pydantic_numpy.model import NumpyModel, model_agnostic_load
from pydantic_numpy.typing import NpNDArray
from tests.model import (
NpNDArrayModelWithNonArray,
Expand All @@ -17,28 +17,19 @@
NON_ARRAY_VALUE = 5


def _numpy_model():
@pytest.fixture
def numpy_model() -> NpNDArrayModelWithNonArray:
return NpNDArrayModelWithNonArray(array=np.array([0.0]), non_array=NON_ARRAY_VALUE)


@pytest.fixture
def numpy_model():
return _numpy_model()


@pytest.fixture(
params=[
_numpy_model(),
NpNDArrayModelWithNonArrayWithArbitrary(
array=np.array([0.0]), non_array=NON_ARRAY_VALUE, my_arbitrary_slice=slice(0, 10)
),
]
)
def numpy_model_with_arbitrary(request):
return request.param
def numpy_model_with_arbitrary() -> NpNDArrayModelWithNonArrayWithArbitrary:
return NpNDArrayModelWithNonArrayWithArbitrary(
array=np.array([0.0]), non_array=NON_ARRAY_VALUE, my_arbitrary_slice=slice(0, 1)
)


if os.name != "nt":
if platform.system() != "Windows":

def test_io_yaml(numpy_model: NpNDArrayModelWithNonArray) -> None:
with tempfile.TemporaryDirectory() as tmp_dirname:
Expand Down Expand Up @@ -80,3 +71,30 @@ class NumpyModelBForTest(NpNDArrayModelWithNonArray):
models = [NumpyModelAForTest, NumpyModelBForTest]
assert model_a == model_agnostic_load(tmp_dir_path, TEST_MODEL_OBJECT_ID, models=models)
assert model_b == model_agnostic_load(tmp_dir_path, OTHER_TEST_MODEL_OBJECT_ID, models=models)

def test_simple_eq(numpy_model: NpNDArrayModelWithNonArray) -> None:
assert numpy_model == numpy_model

def test_not_eq_different_fields(numpy_model, numpy_model_with_arbitrary) -> None:
assert numpy_model != numpy_model_with_arbitrary

class AnotherModel(NumpyModel):
yarra: NpNDArray

assert numpy_model != AnotherModel(yarra=np.array([0.0]))

def test_not_eq_different_inner(numpy_model: NpNDArrayModelWithNonArray) -> None:
assert numpy_model != NpNDArrayModelWithNonArray(array=np.array([1.0]), non_array=NON_ARRAY_VALUE)

def test_not_eq_different_shape(numpy_model: NpNDArrayModelWithNonArray) -> None:
assert numpy_model != NpNDArrayModelWithNonArray(array=np.array([0.0, 1.0]), non_array=NON_ARRAY_VALUE)

def test_random_not_eq(numpy_model: NpNDArrayModelWithNonArray) -> None:
for r in (0, 5, 1.0, "1"):
assert numpy_model != r

def test_serde_eq(numpy_model: NpNDArrayModelWithNonArray) -> None:
ser = numpy_model.model_dump_json()
reread_data = numpy_model.model_validate_json(ser)

assert numpy_model == reread_data
4 changes: 2 additions & 2 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import platform
import tempfile
from pathlib import Path
from typing import Optional
Expand Down Expand Up @@ -54,7 +54,7 @@ def test_wrong_dimension():
get_numpy_type_model(Np1DArrayInt64)(array_field=np.array([[0]]))


if os.name != "nt":
if platform.system() != "Windows":
from tests.helper.testing_groups import (
get_strict_data_type_nd_array_typing_dimensions_128_bit,
)
Expand Down

0 comments on commit 6fab8b2

Please sign in to comment.