diff --git a/src/protean/__init__.py b/src/protean/__init__.py index c5ba12db..e31a1165 100644 --- a/src/protean/__init__.py +++ b/src/protean/__init__.py @@ -6,7 +6,7 @@ from .core.command_handler import BaseCommandHandler from .core.domain_service import BaseDomainService from .core.email import BaseEmail -from .core.entity import BaseEntity +from .core.entity import BaseEntity, invariant from .core.event import BaseEvent from .core.event_handler import BaseEventHandler from .core.event_sourced_aggregate import BaseEventSourcedAggregate, apply @@ -53,4 +53,5 @@ "current_uow", "get_version", "handle", + "invariant", ] diff --git a/src/protean/core/aggregate.py b/src/protean/core/aggregate.py index da6dc2cb..470f99fe 100644 --- a/src/protean/core/aggregate.py +++ b/src/protean/core/aggregate.py @@ -1,5 +1,6 @@ """Aggregate Functionality and Classes""" +import inspect import logging from protean.container import EventedMixin @@ -65,4 +66,14 @@ def _default_options(cls): def aggregate_factory(element_cls, **kwargs): - return derive_element_class(element_cls, BaseAggregate, **kwargs) + element_cls = derive_element_class(element_cls, BaseAggregate, **kwargs) + + # Iterate through methods marked as `@invariant` and record them for later use + methods = inspect.getmembers(element_cls, predicate=inspect.isroutine) + for method_name, method in methods: + if not ( + method_name.startswith("__") and method_name.endswith("__") + ) and hasattr(method, "_invariant"): + element_cls._invariants.append(method) + + return element_cls diff --git a/src/protean/core/entity.py b/src/protean/core/entity.py index 324cd94d..c9da569a 100644 --- a/src/protean/core/entity.py +++ b/src/protean/core/entity.py @@ -1,6 +1,8 @@ """Entity Functionality and Classes""" import copy +import functools +import inspect import logging from collections import defaultdict @@ -446,6 +448,12 @@ def _set_root_and_owner(self, root, owner): if not item._root: item._set_root_and_owner(self._root, self) + def __init_subclass__(subclass) -> None: + super().__init_subclass__() + + # Record invariant methods + setattr(subclass, "_invariants", []) + def entity_factory(element_cls, **kwargs): element_cls = derive_element_class(element_cls, BaseEntity, **kwargs) @@ -505,4 +513,24 @@ def entity_factory(element_cls, **kwargs): shadow_field_name, shadow_field = field.get_shadow_field() shadow_field.__set_name__(element_cls, shadow_field_name) + # Iterate through methods marked as `@invariant` and record them for later use + methods = inspect.getmembers(element_cls, predicate=inspect.isroutine) + for method_name, method in methods: + if not ( + method_name.startswith("__") and method_name.endswith("__") + ) and hasattr(method, "_invariant"): + element_cls._invariants.append(method) + return element_cls + + +def invariant(fn): + """Decorator to mark invariant methods in an Entity""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + setattr(wrapper, "_invariant", True) + + return wrapper diff --git a/tests/entity/test_invariants.py b/tests/entity/test_invariants.py new file mode 100644 index 00000000..972543aa --- /dev/null +++ b/tests/entity/test_invariants.py @@ -0,0 +1,55 @@ +from protean import BaseAggregate, BaseEntity, invariant +from protean.exceptions import ValidationError +from protean.fields import Date, Float, Integer, String, HasMany + + +class Order(BaseAggregate): + ordered_on = Date() + total = Float() + items = HasMany("OrderItem") + + @invariant + def total_should_be_sum_of_item_prices(self): + if self.items: + if self.total != sum([item.price for item in self.items]): + raise ValidationError("Total should be sum of item prices") + + @invariant + def must_have_at_least_one_item(self): + if not self.items or len(self.items) == 0: + raise ValidationError("Order must contain at least one item") + + @invariant + def item_quantities_should_be_positive(self): + for item in self.items: + if item.quantity <= 0: + raise ValidationError("Item quantities should be positive") + + +class OrderItem(BaseEntity): + product_id = String(max_length=50) + quantity = Integer() + price = Float() + + class Meta: + part_of = Order + + @invariant + def price_should_be_non_negative(self): + if self.price < 0: + raise ValidationError("Item price should be non-negative") + + +def test_that_entity_has_recorded_invariants(test_domain): + test_domain.register(OrderItem) + test_domain.register(Order) + test_domain.init(traverse=False) + + assert len(Order._invariants) == 3 + # Methods are presented in ascending order (alphabetical order) of member names. + assert Order._invariants[0].__name__ == "item_quantities_should_be_positive" + assert Order._invariants[1].__name__ == "must_have_at_least_one_item" + assert Order._invariants[2].__name__ == "total_should_be_sum_of_item_prices" + + assert len(OrderItem._invariants) == 1 + assert OrderItem._invariants[0].__name__ == "price_should_be_non_negative"