From 6832cbbcf4044cd58f188c08700ab502b327ad94 Mon Sep 17 00:00:00 2001 From: Subhash Bhushan Date: Mon, 20 May 2024 11:24:42 -0700 Subject: [PATCH 1/3] Preserve owner and root linkages in child entities This commit introduces two fields in every entity to track their lineage: root and owner. The owner is the immediate parent of an entity. This is the entity (or aggregate) that has the association field. The root is the aggregate root. This is the first step to being able to satisfy invariants across an aggregate's clustes of entities. --- src/protean/container.py | 11 +- src/protean/core/aggregate.py | 4 + src/protean/core/entity.py | 25 +++ src/protean/fields/association.py | 20 +- .../associations/test_owner_and_root.py | 187 ++++++++++++++++++ .../value_object/test_vo_field_properties.py | 22 ++- 6 files changed, 261 insertions(+), 8 deletions(-) create mode 100644 tests/entity/associations/test_owner_and_root.py diff --git a/src/protean/container.py b/src/protean/container.py index c0534c5f..e0736928 100644 --- a/src/protean/container.py +++ b/src/protean/container.py @@ -320,7 +320,16 @@ def __setattr__(self, name, value): if ( name in attributes(self) or name in fields(self) - or name in ["errors", "state_", "_temp_cache", "_events", "_initialized"] + or name + in [ + "errors", # Errors in state transition + "state_", # Tracking dirty state of the entity + "_temp_cache", # Temporary cache (Assocations) for storing data befor persisting + "_events", # Temp placeholder for events raised by the entity + "_initialized", # Flag to indicate if the entity has been initialized + "_root", # Root entity in the hierarchy + "_owner", # Owner entity in the hierarchy + ] or name.startswith(("add_", "remove_")) ): super().__setattr__(name, value) diff --git a/src/protean/core/aggregate.py b/src/protean/core/aggregate.py index 03ada138..da6dc2cb 100644 --- a/src/protean/core/aggregate.py +++ b/src/protean/core/aggregate.py @@ -50,6 +50,10 @@ class Meta: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # Set root in all child elements + # This is where we kick-off the process of setting the owner and root + self._set_root_and_owner(self, self) + @classmethod def _default_options(cls): return [ diff --git a/src/protean/core/entity.py b/src/protean/core/entity.py index 3397e032..324cd94d 100644 --- a/src/protean/core/entity.py +++ b/src/protean/core/entity.py @@ -145,6 +145,10 @@ def __init__(self, *template, **kwargs): # noqa: C901 # Placeholder for HasMany change tracking self._temp_cache = defaultdict(lambda: defaultdict(dict)) + # Attributes to preserve heirarchy of element instances + self._owner = None + self._root = None + # Collect Reference field attribute names to prevent accidental overwriting # of shadow fields. reference_attributes = { @@ -421,6 +425,27 @@ def _extract_options(cls, **opts): ) setattr(cls.meta_, key, value) + def _set_root_and_owner(self, root, owner): + """Set the root and owner entities on all child entities + + This is a recursive process set in motion by the aggregate's `__init__` method. + """ + self._root = root + self._owner = owner + + # Set `_root` on all child entities + for field_name, field_obj in declared_fields(self).items(): + # We care only about enclosed fields (associations) + if isinstance(field_obj, Association): + # Get current assigned value + value = getattr(self, field_name) + if value is not None: + # Link child entities to own root + items = value if isinstance(value, list) else [value] + for item in items: + if not item._root: + item._set_root_and_owner(self._root, self) + def entity_factory(element_cls, **kwargs): element_cls = derive_element_class(element_cls, BaseEntity, **kwargs) diff --git a/src/protean/fields/association.py b/src/protean/fields/association.py index 5a3f91c5..070eb16a 100644 --- a/src/protean/fields/association.py +++ b/src/protean/fields/association.py @@ -326,11 +326,15 @@ def as_dict(self): raise NotImplementedError def __set__(self, instance, value): - """Cannot set values through an association""" - raise exceptions.NotSupportedError( - "Object does not support the operation being performed", - self.field_name, - ) + """Set the value of the association field""" + # Preserve heirarchy of entities. + # + # Owner: is the entity that owns the association field + # Root: is the entity that is at the top of the hierarchy, an Aggregate Root + if value is not None: + items = value if isinstance(value, list) else [value] + for item in items: + item._set_root_and_owner(instance._root, instance) def __delete__(self, instance): """Cannot pop values for an association""" @@ -374,6 +378,8 @@ def __set__(self, instance, value): the changes to be persisted. """ + super().__set__(instance, value) + if value is not None and not isinstance(value, self.to_cls): raise ValidationError( { @@ -466,6 +472,8 @@ def __set__(self, instance, value): """This supports direct assignment of values to HasMany fields, like: `order.items = [item1, item2, item3]` """ + super().__set__(instance, value) + if value is not None: self.add(instance, value) @@ -498,6 +506,8 @@ def add(self, instance, items) -> None: instance: The source entity instance. items: The linked entity or entities to be added. """ + super().__set__(instance, items) + data = getattr(instance, self.field_name) # Convert a single item into a list of items, if necessary diff --git a/tests/entity/associations/test_owner_and_root.py b/tests/entity/associations/test_owner_and_root.py new file mode 100644 index 00000000..35dd38f5 --- /dev/null +++ b/tests/entity/associations/test_owner_and_root.py @@ -0,0 +1,187 @@ +import pytest + +from protean import BaseAggregate, BaseEntity +from protean.fields import Integer, String, HasOne, HasMany + + +class University(BaseAggregate): + name = String(max_length=50) + departments = HasMany("Department") + + +class Department(BaseEntity): + name = String(max_length=50) + dean = HasOne("Dean") + + class Meta: + part_of = University + + +class Dean(BaseEntity): + name = String(max_length=50) + age = Integer(min_value=21) + office = HasOne("Office") + + class Meta: + part_of = Department + + +class Office(BaseEntity): + building = String(max_length=25) + room = Integer(min_value=1) + + class Meta: + part_of = Dean + + +@pytest.fixture(autouse=True) +def register_elements(test_domain): + test_domain.register(University) + test_domain.register(Department) + test_domain.register(Dean) + test_domain.register(Office) + test_domain.init(traverse=False) + + +def test_owner_linkage(): + office = Office(building="Main", room=101) + dean = Dean(name="John Doe", age=45, office=office) + department = Department(name="Computer Science", dean=dean) + university = University(name="MIT", departments=[department]) + + assert university._owner == university + assert department._owner == university + assert dean._owner == department + assert office._owner == dean + + +def test_root_linkage_when_entities_are_constructed_in_advance(): + office = Office(building="Main", room=101) + dean = Dean(name="John Doe", age=45, office=office) + department = Department(name="Computer Science", dean=dean) + university = University(name="MIT", departments=[department]) + + assert university._root == university + assert department._root == university + assert dean._root == university + assert office._root == university + + +def test_root_linkage_when_aggregate_and_entities_are_constructed_together(): + university = University( + name="MIT", + departments=[ + Department( + name="Computer Science", + dean=Dean( + name="John Doe", age=45, office=Office(building="Main", room=101) + ), + ) + ], + ) + + # Test owner linkages + assert university._owner == university + assert university.departments[0]._owner == university + assert university.departments[0].dean._owner == university.departments[0] + assert ( + university.departments[0].dean.office._owner == university.departments[0].dean + ) + + # Test root linkages + assert university._root == university + assert university.departments[0]._root == university + assert university.departments[0].dean._root == university + assert university.departments[0].dean.office._root == university + + +def test_root_linkage_is_preserved_after_persistence_and_retrieval(test_domain): + university = University( + name="MIT", + departments=[ + Department( + name="Computer Science", + dean=Dean( + name="John Doe", age=45, office=Office(building="Main", room=101) + ), + ) + ], + ) + + test_domain.repository_for(University).add(university) + + refreshed_university = test_domain.repository_for(University).get(university.id) + + # Test owner linkages + assert refreshed_university._owner == university + assert refreshed_university.departments[0]._owner == refreshed_university + assert ( + refreshed_university.departments[0].dean._owner + == refreshed_university.departments[0] + ) + assert ( + refreshed_university.departments[0].dean.office._owner + == refreshed_university.departments[0].dean + ) + + # Test root linkages + assert refreshed_university._root == refreshed_university + assert refreshed_university.departments[0]._root == refreshed_university + assert refreshed_university.departments[0].dean._root == refreshed_university + assert refreshed_university.departments[0].dean.office._root == refreshed_university + + +def test_root_linkage_on_newly_added_entity(test_domain): + university = University( + name="MIT", + departments=[ + Department( + name="Computer Science", + dean=Dean( + name="John Doe", age=45, office=Office(building="Main", room=101) + ), + ) + ], + ) + + new_department = Department( + name="Electrical Engineering", + dean=Dean(name="Jane Doe", age=42, office=Office(building="Main", room=102)), + ) + + assert new_department._root is None + assert new_department.dean._root is None + assert new_department.dean.office._root is None + + university.add_departments(new_department) + + # Test owner linkages + assert new_department._owner == university + assert new_department.dean._owner == new_department + assert new_department.dean.office._owner == new_department.dean + + # Test root linkages + assert new_department._root == university + assert new_department.dean._root == university + assert new_department.dean.office._root == university + + assert university.departments[1]._root == university + assert university.departments[1].dean._root == university + assert university.departments[1].dean.office._root == university + + test_domain.repository_for(University).add(university) + + refreshed_university = test_domain.repository_for(University).get(university.id) + + # Test owner linkages + assert refreshed_university._owner == refreshed_university + assert refreshed_university.departments[0]._owner == refreshed_university + assert ( + refreshed_university.departments[0].dean._owner + == refreshed_university.departments[0] + ) + + # Test root linkages + assert refreshed_university.departments[1]._root == university + assert refreshed_university.departments[1].dean._root == university + assert refreshed_university.departments[1].dean.office._root == university diff --git a/tests/value_object/test_vo_field_properties.py b/tests/value_object/test_vo_field_properties.py index 91e133c9..3c76f5ef 100644 --- a/tests/value_object/test_vo_field_properties.py +++ b/tests/value_object/test_vo_field_properties.py @@ -1,8 +1,8 @@ import pytest -from protean import BaseValueObject +from protean import BaseEntity, BaseValueObject from protean.exceptions import IncorrectUsageError -from protean.fields import Float, String +from protean.fields import Float, HasMany, String def test_vo_cannot_contain_fields_marked_unique(): @@ -35,3 +35,21 @@ class Balance(BaseValueObject): ] } ) + + +def test_vo_cannot_have_association_fields(): + with pytest.raises(IncorrectUsageError) as exception: + + class Address(BaseEntity): + street_address = String() + + class Office(BaseValueObject): + addresses = HasMany(Address) + + assert str(exception.value) == str( + { + "_value_object": [ + "Value Objects can only contain basic field types. Remove addresses (HasMany) from class Office" + ] + } + ) From d57833bbe9967d895d9d5748e366b7bc399f74e3 Mon Sep 17 00:00:00 2001 From: Subhash Bhushan Date: Mon, 20 May 2024 16:20:36 -0700 Subject: [PATCH 2/3] Introduce @invariant decorator Add @invariant decorator, and configure aggregate and element classes to parse through methods and record them for later use. --- src/protean/__init__.py | 3 +- src/protean/core/aggregate.py | 13 +++++++- src/protean/core/entity.py | 28 +++++++++++++++++ tests/entity/test_invariants.py | 55 +++++++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 tests/entity/test_invariants.py diff --git a/src/protean/__init__.py b/src/protean/__init__.py index c5ba12db..e31a1165 100644 --- a/src/protean/__init__.py +++ b/src/protean/__init__.py @@ -6,7 +6,7 @@ from .core.command_handler import BaseCommandHandler from .core.domain_service import BaseDomainService from .core.email import BaseEmail -from .core.entity import BaseEntity +from .core.entity import BaseEntity, invariant from .core.event import BaseEvent from .core.event_handler import BaseEventHandler from .core.event_sourced_aggregate import BaseEventSourcedAggregate, apply @@ -53,4 +53,5 @@ "current_uow", "get_version", "handle", + "invariant", ] diff --git a/src/protean/core/aggregate.py b/src/protean/core/aggregate.py index da6dc2cb..470f99fe 100644 --- a/src/protean/core/aggregate.py +++ b/src/protean/core/aggregate.py @@ -1,5 +1,6 @@ """Aggregate Functionality and Classes""" +import inspect import logging from protean.container import EventedMixin @@ -65,4 +66,14 @@ def _default_options(cls): def aggregate_factory(element_cls, **kwargs): - return derive_element_class(element_cls, BaseAggregate, **kwargs) + element_cls = derive_element_class(element_cls, BaseAggregate, **kwargs) + + # Iterate through methods marked as `@invariant` and record them for later use + methods = inspect.getmembers(element_cls, predicate=inspect.isroutine) + for method_name, method in methods: + if not ( + method_name.startswith("__") and method_name.endswith("__") + ) and hasattr(method, "_invariant"): + element_cls._invariants.append(method) + + return element_cls diff --git a/src/protean/core/entity.py b/src/protean/core/entity.py index 324cd94d..c9da569a 100644 --- a/src/protean/core/entity.py +++ b/src/protean/core/entity.py @@ -1,6 +1,8 @@ """Entity Functionality and Classes""" import copy +import functools +import inspect import logging from collections import defaultdict @@ -446,6 +448,12 @@ def _set_root_and_owner(self, root, owner): if not item._root: item._set_root_and_owner(self._root, self) + def __init_subclass__(subclass) -> None: + super().__init_subclass__() + + # Record invariant methods + setattr(subclass, "_invariants", []) + def entity_factory(element_cls, **kwargs): element_cls = derive_element_class(element_cls, BaseEntity, **kwargs) @@ -505,4 +513,24 @@ def entity_factory(element_cls, **kwargs): shadow_field_name, shadow_field = field.get_shadow_field() shadow_field.__set_name__(element_cls, shadow_field_name) + # Iterate through methods marked as `@invariant` and record them for later use + methods = inspect.getmembers(element_cls, predicate=inspect.isroutine) + for method_name, method in methods: + if not ( + method_name.startswith("__") and method_name.endswith("__") + ) and hasattr(method, "_invariant"): + element_cls._invariants.append(method) + return element_cls + + +def invariant(fn): + """Decorator to mark invariant methods in an Entity""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + setattr(wrapper, "_invariant", True) + + return wrapper diff --git a/tests/entity/test_invariants.py b/tests/entity/test_invariants.py new file mode 100644 index 00000000..972543aa --- /dev/null +++ b/tests/entity/test_invariants.py @@ -0,0 +1,55 @@ +from protean import BaseAggregate, BaseEntity, invariant +from protean.exceptions import ValidationError +from protean.fields import Date, Float, Integer, String, HasMany + + +class Order(BaseAggregate): + ordered_on = Date() + total = Float() + items = HasMany("OrderItem") + + @invariant + def total_should_be_sum_of_item_prices(self): + if self.items: + if self.total != sum([item.price for item in self.items]): + raise ValidationError("Total should be sum of item prices") + + @invariant + def must_have_at_least_one_item(self): + if not self.items or len(self.items) == 0: + raise ValidationError("Order must contain at least one item") + + @invariant + def item_quantities_should_be_positive(self): + for item in self.items: + if item.quantity <= 0: + raise ValidationError("Item quantities should be positive") + + +class OrderItem(BaseEntity): + product_id = String(max_length=50) + quantity = Integer() + price = Float() + + class Meta: + part_of = Order + + @invariant + def price_should_be_non_negative(self): + if self.price < 0: + raise ValidationError("Item price should be non-negative") + + +def test_that_entity_has_recorded_invariants(test_domain): + test_domain.register(OrderItem) + test_domain.register(Order) + test_domain.init(traverse=False) + + assert len(Order._invariants) == 3 + # Methods are presented in ascending order (alphabetical order) of member names. + assert Order._invariants[0].__name__ == "item_quantities_should_be_positive" + assert Order._invariants[1].__name__ == "must_have_at_least_one_item" + assert Order._invariants[2].__name__ == "total_should_be_sum_of_item_prices" + + assert len(OrderItem._invariants) == 1 + assert OrderItem._invariants[0].__name__ == "price_should_be_non_negative" From 331cc74f42a2ad54bd543760d3006db23829f302 Mon Sep 17 00:00:00 2001 From: Subhash Bhushan Date: Tue, 21 May 2024 13:47:57 -0700 Subject: [PATCH 3/3] Run all invariants on init and attribute changes When an aggregate or entity object is initialized, or its attribute value is changed, this commit ensures that all `clean()` methods (which in turn run all invariant validations) are run across the aggregate. This commit also adds `atomic_change` context manager for scenarios where multiple aggregates in an aggregate need to be changed before it becomes valid again. --- docs/guides/compose-a-domain/object-model.md | 4 + .../domain-definition/fields/simple-fields.md | 20 ++ src/protean/__init__.py | 3 +- src/protean/container.py | 4 + src/protean/core/aggregate.py | 18 +- src/protean/core/command.py | 7 - src/protean/core/entity.py | 48 ++- src/protean/core/event.py | 10 - src/protean/fields/association.py | 13 + src/protean/fields/base.py | 9 + src/protean/fields/basic.py | 7 +- tests/aggregate/test_atomic_change.py | 45 +++ tests/entity/elements.py | 15 +- .../test_invariant_decorator.py} | 8 +- .../invariants/test_invariant_triggerring.py | 287 ++++++++++++++++++ tests/entity/test_lifecycle_methods.py | 8 +- 16 files changed, 459 insertions(+), 47 deletions(-) create mode 100644 tests/aggregate/test_atomic_change.py rename tests/entity/{test_invariants.py => invariants/test_invariant_decorator.py} (83%) create mode 100644 tests/entity/invariants/test_invariant_triggerring.py diff --git a/docs/guides/compose-a-domain/object-model.md b/docs/guides/compose-a-domain/object-model.md index 6a2673e6..5b7d24c4 100644 --- a/docs/guides/compose-a-domain/object-model.md +++ b/docs/guides/compose-a-domain/object-model.md @@ -9,6 +9,10 @@ document outlines generic aspects that apply to every domain element. `Element` is a base class inherited by all domain elements. Currently, it does not have any data structures or behavior associated with it. +## Element Type + +.element_type + ## Data Containers Protean provides data container elements, aligned with DDD principles to model diff --git a/docs/guides/domain-definition/fields/simple-fields.md b/docs/guides/domain-definition/fields/simple-fields.md index 7f24ff0e..8adef2b9 100644 --- a/docs/guides/domain-definition/fields/simple-fields.md +++ b/docs/guides/domain-definition/fields/simple-fields.md @@ -74,6 +74,26 @@ Out[2]: 'id': '88a21815-7d9b-4138-9cac-5a06889d4318'} ``` +Protean will intelligently convert a valid date string into a date object, with +the help of the venerable +[`dateutil`](https://dateutil.readthedocs.io/en/stable/) module. + +```shell +In [1]: post = Post(title='Foo', published_on="2020-01-01") + +In [2]: post.to_dict() +Out[2]: +{'title': 'Foo', + 'published_on': '2020-01-01', + 'id': 'ffcb3b26-71f0-45d0-8ca0-b71a9603f792'} + +In [3]: Post(title='Foo', published_on="2019-02-29") +ERROR: Error during initialization: {'published_on': ['"2019-02-29" has an invalid date format.']} +... +ValidationError: {'published_on': ['"2019-02-29" has an invalid date format.']} +``` + + ## DateTime A date and time, represented in Python by a `datetime.datetime` instance. diff --git a/src/protean/__init__.py b/src/protean/__init__.py index e31a1165..70cf8ab3 100644 --- a/src/protean/__init__.py +++ b/src/protean/__init__.py @@ -1,6 +1,6 @@ __version__ = "0.11.0" -from .core.aggregate import BaseAggregate +from .core.aggregate import BaseAggregate, atomic_change from .core.application_service import BaseApplicationService from .core.command import BaseCommand from .core.command_handler import BaseCommandHandler @@ -54,4 +54,5 @@ "get_version", "handle", "invariant", + "atomic_change", ] diff --git a/src/protean/container.py b/src/protean/container.py index e0736928..4c596305 100644 --- a/src/protean/container.py +++ b/src/protean/container.py @@ -195,6 +195,7 @@ def __init__(self, *template, **kwargs): # noqa: C901 This initialization technique supports keyword arguments as well as dictionaries. You can even use a template for initial data. """ + self._initialized = False if self.meta_.abstract is True: raise NotSupportedError( @@ -265,6 +266,8 @@ def __init__(self, *template, **kwargs): # noqa: C901 self.defaults() + self._initialized = True + # `clean()` will return a `defaultdict(list)` if errors are to be raised custom_errors = self.clean() or {} for field in custom_errors: @@ -329,6 +332,7 @@ def __setattr__(self, name, value): "_initialized", # Flag to indicate if the entity has been initialized "_root", # Root entity in the hierarchy "_owner", # Owner entity in the hierarchy + "_disable_invariant_checks", # Flag to disable invariant checks ] or name.startswith(("add_", "remove_")) ): diff --git a/src/protean/core/aggregate.py b/src/protean/core/aggregate.py index 470f99fe..2d35ea40 100644 --- a/src/protean/core/aggregate.py +++ b/src/protean/core/aggregate.py @@ -69,11 +69,27 @@ def aggregate_factory(element_cls, **kwargs): element_cls = derive_element_class(element_cls, BaseAggregate, **kwargs) # Iterate through methods marked as `@invariant` and record them for later use + # `_invariants` is a dictionary initialized in BaseEntity.__init_subclass__ methods = inspect.getmembers(element_cls, predicate=inspect.isroutine) for method_name, method in methods: if not ( method_name.startswith("__") and method_name.endswith("__") ) and hasattr(method, "_invariant"): - element_cls._invariants.append(method) + element_cls._invariants[method_name] = method return element_cls + + +# Context manager to temporarily disable invariant checks on aggregate +class atomic_change: + def __init__(self, aggregate): + self.aggregate = aggregate + + def __enter__(self): + # Temporary disable invariant checks + self.aggregate._disable_invariant_checks = True + + def __exit__(self, *args): + # Run clean() on exit to trigger invariant checks + self.aggregate._disable_invariant_checks = False + self.aggregate.clean() diff --git a/src/protean/core/command.py b/src/protean/core/command.py index d3870799..f95440c3 100644 --- a/src/protean/core/command.py +++ b/src/protean/core/command.py @@ -24,18 +24,11 @@ def __init_subclass__(subclass) -> None: subclass.__track_id_field() def __init__(self, *args, **kwargs): - # Set the flag to prevent any further modifications - self._initialized = False - try: super().__init__(*args, **kwargs) except ValidationError as exception: raise InvalidDataError(exception.messages) - # If we made it this far, the Value Object is initialized - # and should be marked as such - self._initialized = True - def __setattr__(self, name, value): if not hasattr(self, "_initialized") or not self._initialized: return super().__setattr__(name, value) diff --git a/src/protean/core/entity.py b/src/protean/core/entity.py index c9da569a..78513aa0 100644 --- a/src/protean/core/entity.py +++ b/src/protean/core/entity.py @@ -133,6 +133,8 @@ def __init__(self, *template, **kwargs): # noqa: C901 user = User(base_user.to_dict(), first_name='John', last_name='Doe') """ + self._initialized = False + if self.meta_.abstract is True: raise NotSupportedError( f"{self.__class__.__name__} class has been marked abstract" @@ -151,6 +153,9 @@ def __init__(self, *template, **kwargs): # noqa: C901 self._owner = None self._root = None + # To control invariant checks + self._disable_invariant_checks = False + # Collect Reference field attribute names to prevent accidental overwriting # of shadow fields. reference_attributes = { @@ -287,8 +292,10 @@ def __init__(self, *template, **kwargs): # noqa: C901 self.defaults() + self._initialized = True + # `clean()` will return a `defaultdict(list)` if errors are to be raised - custom_errors = self.clean() or {} + custom_errors = self.clean(return_errors=True) or {} for field in custom_errors: self.errors[field].extend(custom_errors[field]) @@ -302,11 +309,36 @@ def defaults(self): To be overridden in concrete Containers, when an attribute's default depends on other attribute values. """ - def clean(self): - """Placeholder method for validations. - To be overridden in concrete Containers, when complex validations spanning multiple fields are required. - """ - return defaultdict(list) + def clean(self, return_errors=False): + """Invoked after initialization to perform additional validations.""" + # Call all methods marked as invariants + if self._initialized and not self._disable_invariant_checks: + errors = defaultdict(list) + + for invariant_method in self._invariants.values(): + try: + invariant_method(self) + except ValidationError as err: + for field_name in err.messages: + errors[field_name].extend(err.messages[field_name]) + + # Run through all associations and trigger their clean method + for field_name, field_obj in declared_fields(self).items(): + if isinstance(field_obj, Association): + value = getattr(self, field_name) + if value is not None: + items = value if isinstance(value, list) else [value] + for item in items: + item_errors = item.clean(return_errors=True) + if item_errors: + for sub_field_name, error_list in item_errors.items(): + errors[sub_field_name].extend(error_list) + + if return_errors: + return errors + + if errors: + raise ValidationError(errors) def __eq__(self, other): """Equivalence check to be based only on Identity""" @@ -452,7 +484,7 @@ def __init_subclass__(subclass) -> None: super().__init_subclass__() # Record invariant methods - setattr(subclass, "_invariants", []) + setattr(subclass, "_invariants", {}) def entity_factory(element_cls, **kwargs): @@ -519,7 +551,7 @@ def entity_factory(element_cls, **kwargs): if not ( method_name.startswith("__") and method_name.endswith("__") ) and hasattr(method, "_invariant"): - element_cls._invariants.append(method) + element_cls._invariants[method_name] = method return element_cls diff --git a/src/protean/core/event.py b/src/protean/core/event.py index 79ca0c7e..b1116b26 100644 --- a/src/protean/core/event.py +++ b/src/protean/core/event.py @@ -27,16 +27,6 @@ def __init_subclass__(subclass) -> None: if not subclass.meta_.abstract: subclass.__track_id_field() - def __init__(self, *args, **kwargs): - # Set the flag to prevent any further modifications - self._initialized = False - - super().__init__(*args, **kwargs) - - # If we made it this far, the Value Object is initialized - # and should be marked as such - self._initialized = True - def __setattr__(self, name, value): if not hasattr(self, "_initialized") or not self._initialized: return super().__setattr__(name, value) diff --git a/src/protean/fields/association.py b/src/protean/fields/association.py index 070eb16a..8c5f1eef 100644 --- a/src/protean/fields/association.py +++ b/src/protean/fields/association.py @@ -437,6 +437,9 @@ def __set__(self, instance, value): elif isinstance(field_obj, HasOne): setattr(old_value, field_name, None) + if instance._initialized and instance._root is not None: + instance._root.clean() # Trigger validations from the top + def _fetch_objects(self, instance, key, identifier): """Fetch single linked object""" try: @@ -526,6 +529,10 @@ def add(self, instance, items) -> None: current_value_ids = [value.id for value in data] + # Remove items when set to empty + if len(items) == 0 and len(current_value_ids) > 0: + self.remove(instance, data) + for item in items: # Items to add if item.id not in current_value_ids: @@ -563,6 +570,9 @@ def add(self, instance, items) -> None: # Reset Cache self.delete_cached_value(instance) + if instance._initialized and instance._root is not None: + instance._root.clean() # Trigger validations from the top + def remove(self, instance, items) -> None: """ Available as `add_` method on the entity instance. @@ -609,6 +619,9 @@ def remove(self, instance, items) -> None: elif isinstance(field_obj, HasOne): setattr(item, field_name, None) + if instance._initialized and instance._root is not None: + instance._root.clean() # Trigger validations from the top + def _fetch_objects(self, instance, key, value) -> list: """ Fetch linked entities. diff --git a/src/protean/fields/base.py b/src/protean/fields/base.py index 70b56383..74309dad 100644 --- a/src/protean/fields/base.py +++ b/src/protean/fields/base.py @@ -140,8 +140,17 @@ def __get__(self, instance, owner): def __set__(self, instance, value): value = self._load(value) + instance.__dict__[self.field_name] = value + # The hasattr check is necessary to avoid running clean on unrelated elements + if ( + instance._initialized + and hasattr(instance, "_root") + and instance._root is not None + ): + instance._root.clean() # Trigger validations from the top + # Mark Entity as Dirty if hasattr(instance, "state_"): instance.state_.mark_changed() diff --git a/src/protean/fields/basic.py b/src/protean/fields/basic.py index dca0805c..d8b18c0a 100644 --- a/src/protean/fields/basic.py +++ b/src/protean/fields/basic.py @@ -421,12 +421,7 @@ def __set__(self, instance, value): if existing_value is not None and value != existing_value: raise InvalidOperationError("Identifiers cannot be changed once set") - value = self._load(value) - instance.__dict__[self.field_name] = value - - if hasattr(instance, "state_"): - # Mark Entity as Dirty - instance.state_.mark_changed() + super().__set__(instance, value) def as_dict(self, value): """Return JSON-compatible value of self""" diff --git a/tests/aggregate/test_atomic_change.py b/tests/aggregate/test_atomic_change.py new file mode 100644 index 00000000..b6c88a3f --- /dev/null +++ b/tests/aggregate/test_atomic_change.py @@ -0,0 +1,45 @@ +"""Test `atomic_change` context manager""" + +import pytest + +from protean import BaseAggregate, atomic_change, invariant +from protean.fields import Integer +from protean.exceptions import ValidationError + + +class TestAtomicChange: + def test_atomic_change_context_manager(self): + class TestAggregate(BaseAggregate): + pass + + aggregate = TestAggregate() + + with atomic_change(aggregate): + assert aggregate._disable_invariant_checks is True + + assert aggregate._disable_invariant_checks is False + + def test_clean_is_not_triggered_within_context_manager(self, test_domain): + class TestAggregate(BaseAggregate): + value1 = Integer() + value2 = Integer() + + @invariant + def raise_error(self): + if self.value2 != self.value1 + 1: + raise ValidationError({"_entity": ["Invariant error"]}) + + test_domain.register(TestAggregate) + test_domain.init(traverse=False) + + aggregate = TestAggregate(value1=1, value2=2) + + # This raises an error because of the invariant + with pytest.raises(ValidationError): + aggregate.value1 = 2 + aggregate.value2 = 3 + + # This should not raise an error because of the context manager + with atomic_change(aggregate): + aggregate.value1 = 2 + aggregate.value2 = 3 diff --git a/tests/entity/elements.py b/tests/entity/elements.py index 0446939e..df1a458f 100644 --- a/tests/entity/elements.py +++ b/tests/entity/elements.py @@ -1,7 +1,7 @@ -from collections import defaultdict from enum import Enum -from protean import BaseAggregate, BaseEntity +from protean import BaseAggregate, BaseEntity, invariant +from protean.exceptions import ValidationError from protean.fields import Auto, HasOne, Integer, String @@ -149,10 +149,9 @@ def defaults(self): else: self.status = BuildingStatus.WIP.value - def clean(self): - errors = defaultdict(list) - + @invariant + def test_building_status_to_be_done_if_floors_above_4(self): if self.floors >= 4 and self.status != BuildingStatus.DONE.value: - errors["status"].append("should be DONE") - - return errors + raise ValidationError( + {"_entity": ["Building status should be DONE if floors are above 4"]} + ) diff --git a/tests/entity/test_invariants.py b/tests/entity/invariants/test_invariant_decorator.py similarity index 83% rename from tests/entity/test_invariants.py rename to tests/entity/invariants/test_invariant_decorator.py index 972543aa..847324be 100644 --- a/tests/entity/test_invariants.py +++ b/tests/entity/invariants/test_invariant_decorator.py @@ -47,9 +47,9 @@ def test_that_entity_has_recorded_invariants(test_domain): assert len(Order._invariants) == 3 # Methods are presented in ascending order (alphabetical order) of member names. - assert Order._invariants[0].__name__ == "item_quantities_should_be_positive" - assert Order._invariants[1].__name__ == "must_have_at_least_one_item" - assert Order._invariants[2].__name__ == "total_should_be_sum_of_item_prices" + assert "item_quantities_should_be_positive" in Order._invariants + assert "must_have_at_least_one_item" in Order._invariants + assert "total_should_be_sum_of_item_prices" in Order._invariants assert len(OrderItem._invariants) == 1 - assert OrderItem._invariants[0].__name__ == "price_should_be_non_negative" + assert "price_should_be_non_negative" in OrderItem._invariants diff --git a/tests/entity/invariants/test_invariant_triggerring.py b/tests/entity/invariants/test_invariant_triggerring.py new file mode 100644 index 00000000..b3970db3 --- /dev/null +++ b/tests/entity/invariants/test_invariant_triggerring.py @@ -0,0 +1,287 @@ +import pytest + +from datetime import date +from enum import Enum + +from protean import BaseAggregate, BaseEntity, invariant, atomic_change +from protean.exceptions import ValidationError +from protean.fields import Date, Float, Identifier, Integer, String, HasMany + + +class OrderStatus(Enum): + PENDING = "PENDING" + SHIPPED = "SHIPPED" + DELIVERED = "DELIVERED" + + +class Order(BaseAggregate): + customer_id = Identifier() + order_date = Date() + total_amount = Float() + status = String(max_length=50, choices=OrderStatus) + items = HasMany("OrderItem") + + @invariant + def total_amount_of_order_must_equal_sum_of_subtotal_of_all_items(self): + if self.total_amount != sum(item.subtotal for item in self.items): + raise ValidationError({"_entity": ["Total should be sum of item prices"]}) + + @invariant + def order_date_must_be_within_the_last_30_days_if_status_is_pending(self): + if self.status == OrderStatus.PENDING.value and self.order_date < date( + 2020, 1, 1 + ): + raise ValidationError( + { + "_entity": [ + "Order date must be within the last 30 days if status is PENDING" + ] + } + ) + + @invariant + def customer_id_must_be_non_null_and_the_order_must_contain_at_least_one_item(self): + if not self.customer_id or not self.items: + raise ValidationError( + { + "_entity": [ + "Customer ID must be non-null and the order must contain at least one item" + ] + } + ) + + +class OrderItem(BaseEntity): + product_id = Identifier() + quantity = Integer() + price = Float() + subtotal = Float() + + class Meta: + part_of = Order + + @invariant + def the_quantity_must_be_a_positive_integer_and_the_subtotal_must_be_correctly_calculated( + self, + ): + if self.quantity <= 0 or self.subtotal != self.quantity * self.price: + raise ValidationError( + { + "_entity": [ + "Quantity must be a positive integer and the subtotal must be correctly calculated" + ] + } + ) + + +@pytest.fixture(autouse=True) +def register_elements(test_domain): + test_domain.register(OrderItem) + test_domain.register(Order) + test_domain.init(traverse=False) + + +class TestEntityInvariantsOnInitialization: + def test_with_valid_data(self): + order = Order( + customer_id="1", + order_date="2020-01-01", + total_amount=100.0, + status="PENDING", + items=[ + OrderItem(product_id="1", quantity=4, price=10.0, subtotal=40.0), + OrderItem(product_id="2", quantity=3, price=20.0, subtotal=60.0), + ], + ) + assert order is not None + + def test_when_total_amount_is_not_sum_of_item_subtotals(self): + with pytest.raises(ValidationError) as exc: + Order( + customer_id="1", + order_date="2020-01-01", + total_amount=100.0, + status="PENDING", + items=[ + OrderItem(product_id="1", quantity=2, price=10.0, subtotal=20.0), + OrderItem(product_id="2", quantity=3, price=20.0, subtotal=60.0), + ], + ) + + assert exc.value.messages["_entity"] == ["Total should be sum of item prices"] + + def test_when_order_date_is_not_within_the_last_30_days(self): + with pytest.raises(ValidationError) as exc: + Order( + customer_id="1", + order_date="2019-12-01", + total_amount=100.0, + status="PENDING", + items=[ + OrderItem(product_id="1", quantity=4, price=10.0, subtotal=40.0), + OrderItem(product_id="2", quantity=3, price=20.0, subtotal=60.0), + ], + ) + + assert exc.value.messages["_entity"] == [ + "Order date must be within the last 30 days if status is PENDING" + ] + + def test_when_customer_ID_is_null(self): + with pytest.raises(ValidationError) as exc: + Order( + customer_id=None, + order_date="2020-01-01", + total_amount=100.0, + status="PENDING", + items=[ + OrderItem(product_id="1", quantity=4, price=10.0, subtotal=40.0), + OrderItem(product_id="2", quantity=3, price=20.0, subtotal=60.0), + ], + ) + + assert exc.value.messages["_entity"] == [ + "Customer ID must be non-null and the order must contain at least one item" + ] + + def test_when_items_are_empty(self): + with pytest.raises(ValidationError) as exc: + Order( + customer_id="1", + order_date="2020-01-01", + total_amount=100.0, + status="PENDING", + items=[], + ) + + assert exc.value.messages["_entity"] == [ + "Customer ID must be non-null and the order must contain at least one item", + "Total should be sum of item prices", + ] + + def test_when_quantity_is_negative(self): + with pytest.raises(ValidationError) as exc: + Order( + customer_id="1", + order_date="2020-01-01", + total_amount=100.0, + status="PENDING", + items=[ + OrderItem(product_id="1", quantity=-1, price=10.0, subtotal=10.0), + OrderItem(product_id="2", quantity=3, price=20.0, subtotal=60.0), + ], + ) + + assert exc.value.messages["_entity"] == [ + "Quantity must be a positive integer and the subtotal must be correctly calculated" + ] + + def test_when_subtotal_is_incorrect(self): + with pytest.raises(ValidationError) as exc: + Order( + customer_id="1", + order_date="2020-01-01", + total_amount=100.0, + status="PENDING", + items=[ + OrderItem(product_id="1", quantity=1, price=10.0, subtotal=20.0), + OrderItem(product_id="2", quantity=3, price=20.0, subtotal=60.0), + ], + ) + + assert exc.value.messages["_entity"] == [ + "Quantity must be a positive integer and the subtotal must be correctly calculated" + ] + + +@pytest.fixture +def order(): + return Order( + customer_id="1", + order_date="2020-01-01", + total_amount=100.0, + status="PENDING", + items=[ + OrderItem(product_id="1", quantity=4, price=10.0, subtotal=40.0), + OrderItem(product_id="2", quantity=3, price=20.0, subtotal=60.0), + ], + ) + + +class TestEntityInvariantsOnAttributeChanges: + def test_when_total_amount_is_not_sum_of_item_subtotals(self, order): + with pytest.raises(ValidationError) as exc: + order.total_amount = 50.0 + + assert exc.value.messages["_entity"] == ["Total should be sum of item prices"] + + def test_when_order_date_is_not_within_the_last_30_days(self, order): + with pytest.raises(ValidationError) as exc: + order.order_date = "2019-12-01" + + assert exc.value.messages["_entity"] == [ + "Order date must be within the last 30 days if status is PENDING" + ] + + def test_when_customer_ID_is_null(self, order): + with pytest.raises(ValidationError) as exc: + order.customer_id = None + + assert exc.value.messages["_entity"] == [ + "Customer ID must be non-null and the order must contain at least one item" + ] + + def test_when_items_are_empty(self, order): + with pytest.raises(ValidationError) as exc: + order.items = [] + + assert exc.value.messages["_entity"] == [ + "Customer ID must be non-null and the order must contain at least one item", + "Total should be sum of item prices", + ] + + def test_when_invalid_item_is_added(self, order): + with pytest.raises(ValidationError) as exc: + order.add_items( + OrderItem(product_id="3", quantity=2, price=10.0, subtotal=40.0) + ) + + assert exc.value.messages["_entity"] == [ + "Quantity must be a positive integer and the subtotal must be correctly calculated" + ] + + def test_when_item_is_added_along_with_total_amount(self, order): + try: + with atomic_change(order): + order.total_amount = 120.0 + order.add_items( + OrderItem(product_id="3", quantity=2, price=10.0, subtotal=20.0) + ) + except ValidationError: + pytest.fail("Failed to batch update attributes") + + def test_when_quantity_is_negative(self, order): + with pytest.raises(ValidationError) as exc: + order.items[0].quantity = -1 + + assert exc.value.messages["_entity"] == [ + "Quantity must be a positive integer and the subtotal must be correctly calculated" + ] + + def test_when_invalid_item_is_added_after_initialization(self, order): + with pytest.raises(ValidationError) as exc: + order.add_items( + OrderItem(product_id="3", quantity=2, price=10.0, subtotal=40.0) + ) + + assert exc.value.messages["_entity"] == [ + "Quantity must be a positive integer and the subtotal must be correctly calculated" + ] + + def test_when_item_price_is_changed_to_negative(self, order): + with pytest.raises(ValidationError) as exc: + order.items[0].price = -10.0 + + assert exc.value.messages["_entity"] == [ + "Quantity must be a positive integer and the subtotal must be correctly calculated" + ] diff --git a/tests/entity/test_lifecycle_methods.py b/tests/entity/test_lifecycle_methods.py index d6e3436e..cf6810bd 100644 --- a/tests/entity/test_lifecycle_methods.py +++ b/tests/entity/test_lifecycle_methods.py @@ -2,7 +2,7 @@ from protean.exceptions import ValidationError -from .elements import Building, BuildingStatus +from .elements import Area, Building, BuildingStatus class TestDefaults: @@ -18,6 +18,10 @@ def test_that_building_is_marked_as_done_if_below_4_floors(self): class TestClean: - def test_that_building_cannot_be_WIP_if_above_4_floors(self): + def test_that_building_cannot_be_WIP_if_above_4_floors(self, test_domain): + test_domain.register(Building) + test_domain.register(Area) + test_domain.init(traverse=False) + with pytest.raises(ValidationError): Building(name="Foo", floors=4, status=BuildingStatus.WIP.value)