Skip to content

Commit

Permalink
Introduce @invariant decorator
Browse files Browse the repository at this point in the history
Add @invariant decorator, and configure aggregate and element classes
to parse through methods and record them for later use.
  • Loading branch information
subhashb committed May 21, 2024
1 parent 7171e33 commit 0e89328
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/protean/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .core.command_handler import BaseCommandHandler
from .core.domain_service import BaseDomainService
from .core.email import BaseEmail
from .core.entity import BaseEntity
from .core.entity import BaseEntity, invariant
from .core.event import BaseEvent
from .core.event_handler import BaseEventHandler
from .core.event_sourced_aggregate import BaseEventSourcedAggregate, apply
Expand Down Expand Up @@ -53,4 +53,5 @@
"current_uow",
"get_version",
"handle",
"invariant",
]
13 changes: 12 additions & 1 deletion src/protean/core/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Aggregate Functionality and Classes"""

import inspect
import logging

from protean.container import EventedMixin
Expand Down Expand Up @@ -65,4 +66,14 @@ def _default_options(cls):


def aggregate_factory(element_cls, **kwargs):
return derive_element_class(element_cls, BaseAggregate, **kwargs)
element_cls = derive_element_class(element_cls, BaseAggregate, **kwargs)

# Iterate through methods marked as `@invariant` and record them for later use
methods = inspect.getmembers(element_cls, predicate=inspect.isroutine)
for method_name, method in methods:
if not (
method_name.startswith("__") and method_name.endswith("__")
) and hasattr(method, "_invariant"):
element_cls._invariants.append(method)

return element_cls
28 changes: 28 additions & 0 deletions src/protean/core/entity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Entity Functionality and Classes"""

import copy
import functools
import inspect
import logging

from collections import defaultdict
Expand Down Expand Up @@ -446,6 +448,12 @@ def _set_root_and_owner(self, root, owner):
if not item._root:
item._set_root_and_owner(self._root, self)

def __init_subclass__(subclass) -> None:
super().__init_subclass__()

# Record invariant methods
setattr(subclass, "_invariants", [])


def entity_factory(element_cls, **kwargs):
element_cls = derive_element_class(element_cls, BaseEntity, **kwargs)
Expand Down Expand Up @@ -505,4 +513,24 @@ def entity_factory(element_cls, **kwargs):
shadow_field_name, shadow_field = field.get_shadow_field()
shadow_field.__set_name__(element_cls, shadow_field_name)

# Iterate through methods marked as `@invariant` and record them for later use
methods = inspect.getmembers(element_cls, predicate=inspect.isroutine)
for method_name, method in methods:
if not (
method_name.startswith("__") and method_name.endswith("__")
) and hasattr(method, "_invariant"):
element_cls._invariants.append(method)

return element_cls


def invariant(fn):
"""Decorator to mark invariant methods in an Entity"""

@functools.wraps(fn)
def wrapper(*args, **kwargs):
return fn(*args, **kwargs)

setattr(wrapper, "_invariant", True)

return wrapper
55 changes: 55 additions & 0 deletions tests/entity/test_invariants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from protean import BaseAggregate, BaseEntity, invariant
from protean.exceptions import ValidationError
from protean.fields import Date, Float, Integer, String, HasMany


class Order(BaseAggregate):
ordered_on = Date()
total = Float()
items = HasMany("OrderItem")

@invariant
def total_should_be_sum_of_item_prices(self):
if self.items:
if self.total != sum([item.price for item in self.items]):
raise ValidationError("Total should be sum of item prices")

@invariant
def must_have_at_least_one_item(self):
if not self.items or len(self.items) == 0:
raise ValidationError("Order must contain at least one item")

@invariant
def item_quantities_should_be_positive(self):
for item in self.items:
if item.quantity <= 0:
raise ValidationError("Item quantities should be positive")


class OrderItem(BaseEntity):
product_id = String(max_length=50)
quantity = Integer()
price = Float()

class Meta:
part_of = Order

@invariant
def price_should_be_non_negative(self):
if self.price < 0:
raise ValidationError("Item price should be non-negative")


def test_that_entity_has_recorded_invariants(test_domain):
test_domain.register(OrderItem)
test_domain.register(Order)
test_domain.init(traverse=False)

assert len(Order._invariants) == 3
# Methods are presented in ascending order (alphabetical order) of member names.
assert Order._invariants[0].__name__ == "item_quantities_should_be_positive"
assert Order._invariants[1].__name__ == "must_have_at_least_one_item"
assert Order._invariants[2].__name__ == "total_should_be_sum_of_item_prices"

assert len(OrderItem._invariants) == 1
assert OrderItem._invariants[0].__name__ == "price_should_be_non_negative"

0 comments on commit 0e89328

Please sign in to comment.