diff --git a/docs/guides/domain-behavior/domain-services.md b/docs/guides/domain-behavior/domain-services.md index 75852504..020f684f 100644 --- a/docs/guides/domain-behavior/domain-services.md +++ b/docs/guides/domain-behavior/domain-services.md @@ -65,11 +65,20 @@ and invokes the service method to place order. The service method executes the business logic, mutates the aggregates, and returns them to the application service, which then persists them again with the help of repositories. +## Invariants + +Just like Aggregates and Entities, Domain Services can also have invariants. +These invariants are used to validate the state of the aggregates passed to +the service method. Unlike in Aggregates though, invariants in Domain Services +typically deal with validations that span across multiple aggregates. + +`pre` invariants check the state of the aggregates before they are mutated, +while `post` invariants check the state after the mutation. ## A full-blown example -```python hl_lines="67-82" ---8<-- "guides/domain-behavior/006.py:16:98" +```python hl_lines="142-149" +{! docs_src/guides/domain-behavior/006.py !} ``` When an order is placed, `Order` status has to be `CONFIRMED` _and_ the stock @@ -83,3 +92,8 @@ orders are placed at the same time. So a Domain Service works best here because it updates the states of both the `Order` aggregate as well as the `Inventory` aggregate in a single transaction. + +**IMPORTANT**: Even though the inventory aggregate is mutated here to ensure +all invariants are satisified, the Command Handler method invoking the Domain +Service should only persist the `Order` aggregate. The `Inventory` aggregate +will eventually be updated through the domain event `OrderConfirmed`. diff --git a/docs/guides/domain-behavior/invariants.md b/docs/guides/domain-behavior/invariants.md index ca0e51fd..c8a8b773 100644 --- a/docs/guides/domain-behavior/invariants.md +++ b/docs/guides/domain-behavior/invariants.md @@ -6,7 +6,8 @@ the concept, ensuring it remains unchanged even as other aspects evolve play a crucial role in ensuring business validations within a domain. Protean treats invariants as first-class citizens, to make them explicit and -visible, making it easier to maintain the integrity of the domain model. +visible, making it easier to maintain the integrity of the domain model. You +can define invariants on Aggregates, Entities, and Value Objects. ## Key Facts @@ -19,12 +20,12 @@ aggregate cluster. - **Domain-Driven:** Invariants stem from the business rules and policies specific to a domain. - **Enforced by the Domain Model:** Protean takes on the responsibility of -enforcing invariants. +enforcing invariants. ## `@invariant` decorator -Invariants are defined using the `@invariant` decorator in Aggregates and -Entities: +Invariants are defined using the `@invariant` decorator in Aggregates, +Entities, and Value Objects (plus in Domain Services, as we will soon see): ```python hl_lines="9-10 14-15" --8<-- "guides/domain-behavior/001.py:17:41" @@ -38,6 +39,34 @@ of individual item subtotals, and the other that the order date must be within All methods marked `@invariant` are associated with the domain element when the element is registered with the domain. +## `pre` and `post` Invariants + +The `@invariant` decorator has two flavors - **`pre`** and **`post`**. + +`pre` invariants are triggered before elements are updated, while `post` +invariants are triggered after the update. `pre` invariants are used to prevent +invalid state from being introduced, while `post` invariants ensure that the +aggregate remains in a valid state after the update. + +In Protean, we will mostly be using `post` invariants because the domain model +is expected to remain valid after any operation. You would typically start +with the domain in a good state, mutate the elements, and check if all +invariants are satisfied. + +`pre` invariants are useful in certain situations where you want to check state +before the elements are mutated. For instance, you might want to check if a +user has enough balance before deducting it. Also, some invariant checks may +be easier to add *before* changing an element. + +!!!note + `pre` invariants are not applicable when aggregates and entities are being + initialized. Their validations only kick in when an element is being + changed or updated from an existing state. + +!!!note + `pre` invariant checks are not applicable to `ValueObject` elements because + they are immutable - they cannot be changed once initialized. + ## Validation Invariant validations are triggered throughout the lifecycle of domain objects, @@ -129,4 +158,8 @@ in a `ValidationError` exception: In [4]: order.total_amount = 120.0 ... ValidationError: {'_entity': ['Total should be sum of item prices']} -``` \ No newline at end of file +``` + +!!!note + Atomic Changes context manager can only be applied when updating or + changing an already initialized element. \ No newline at end of file diff --git a/docs/guides/domain-definition/fields/index.md b/docs/guides/domain-definition/fields/index.md index 2214a28f..21231c68 100644 --- a/docs/guides/domain-definition/fields/index.md +++ b/docs/guides/domain-definition/fields/index.md @@ -51,5 +51,5 @@ manage related objects efficiently, preserving data integrity across the domain. \ No newline at end of file diff --git a/docs/guides/domain-definition/value-objects.md b/docs/guides/domain-definition/value-objects.md index a1720ffe..a7c189c0 100644 --- a/docs/guides/domain-definition/value-objects.md +++ b/docs/guides/domain-definition/value-objects.md @@ -165,6 +165,25 @@ satisfied at all times. It is recommended that you always deal with Value Objects by their class. Attributes are generally used by Protean during persistence and retrieval. +## Invariants + +When a validation spans across multiple fields, you can specify it in an +`invariant` method. These methods are executed every time the value object is +initialized. + +```python hl_lines="13-16" +{! docs_src/guides/domain-definition/012.py !} +``` + +```shell hl_lines="3" +In [1]: Balance(currency="USD", amount=-100) +... +ValidationError: {'balance': ['Balance cannot be negative for USD']} +``` + +Refer to [`invariants`](../domain-behavior/invariants.md) section for a +deeper explanation of invariants. + ## Equality Two value objects are considered to be equal if their values are equal. diff --git a/docs_src/guides/domain-behavior/001.py b/docs_src/guides/domain-behavior/001.py index 23a4e35c..3031aa81 100644 --- a/docs_src/guides/domain-behavior/001.py +++ b/docs_src/guides/domain-behavior/001.py @@ -22,12 +22,12 @@ class Order: status = String(max_length=50, choices=OrderStatus) items = HasMany("OrderItem") - @invariant + @invariant.post def total_amount_of_order_must_equal_sum_of_subtotal_of_all_items(self): if self.total_amount != sum(item.subtotal for item in self.items): raise ValidationError({"_entity": ["Total should be sum of item prices"]}) - @invariant + @invariant.post def order_date_must_be_within_the_last_30_days_if_status_is_pending(self): if self.status == OrderStatus.PENDING.value and self.order_date < date( 2020, 1, 1 @@ -40,7 +40,7 @@ def order_date_must_be_within_the_last_30_days_if_status_is_pending(self): } ) - @invariant + @invariant.post def customer_id_must_be_non_null_and_the_order_must_contain_at_least_one_item(self): if not self.customer_id or not self.items: raise ValidationError( @@ -62,7 +62,7 @@ class OrderItem: class Meta: part_of = Order - @invariant + @invariant.post def the_quantity_must_be_a_positive_integer_and_the_subtotal_must_be_correctly_calculated( self, ): diff --git a/docs_src/guides/domain-behavior/002.py b/docs_src/guides/domain-behavior/002.py index b453dde1..110f70c6 100644 --- a/docs_src/guides/domain-behavior/002.py +++ b/docs_src/guides/domain-behavior/002.py @@ -20,7 +20,7 @@ class Account: balance = Float() overdraft_limit = Float(default=0.0) - @invariant + @invariant.post def balance_must_be_greater_than_or_equal_to_overdraft_limit(self): if self.balance < -self.overdraft_limit: raise InsufficientFundsException("Balance cannot be below overdraft limit") diff --git a/docs_src/guides/domain-behavior/006.py b/docs_src/guides/domain-behavior/006.py index bd807cbe..9f1ccd13 100644 --- a/docs_src/guides/domain-behavior/006.py +++ b/docs_src/guides/domain-behavior/006.py @@ -1,7 +1,8 @@ from datetime import datetime, timezone from enum import Enum -from protean import Domain +from protean import Domain, invariant +from protean.exceptions import ValidationError from protean.fields import ( DateTime, Float, @@ -9,6 +10,7 @@ Integer, HasMany, String, + ValueObject, ) domain = Domain(__file__) @@ -21,14 +23,11 @@ class OrderStatus(Enum): DELIVERED = "DELIVERED" -@domain.event +@domain.event(part_of="Order") class OrderConfirmed: order_id = Identifier(required=True) confirmed_at = DateTime(required=True) - class Meta: - part_of = "Order" - @domain.aggregate class Order: @@ -37,6 +36,11 @@ class Order: status = String(choices=OrderStatus, default=OrderStatus.PENDING.value) payment_id = Identifier() + @invariant.post + def order_should_contain_items(self): + if not self.items or len(self.items) == 0: + raise ValidationError({"_entity": ["Order must contain at least one item"]}) + def confirm(self): self.status = OrderStatus.CONFIRMED.value self.raise_( @@ -44,30 +48,31 @@ def confirm(self): ) -@domain.entity +@domain.entity(part_of=Order) class OrderItem: product_id = Identifier(required=True) quantity = Integer() price = Float() - class Meta: - part_of = Order + +@domain.value_object(part_of="Inventory") +class Warehouse: + location = String() + contact = String() -@domain.event +@domain.event(part_of="Inventory") class StockReserved: product_id = Identifier(required=True) quantity = Integer(required=True) reserved_at = DateTime(required=True) - class Meta: - part_of = "Inventory" - @domain.aggregate class Inventory: product_id = Identifier(required=True) quantity = Integer() + warehouse = ValueObject(Warehouse) def reserve_stock(self, quantity: int): self.quantity -= quantity @@ -81,20 +86,64 @@ def reserve_stock(self, quantity: int): @domain.domain_service(part_of=[Order, Inventory]) -class OrderPlacementService: - @classmethod - def place_order( - cls, order: Order, inventories: list[Inventory] - ) -> tuple[Order, list[Inventory]]: - for item in order.items: +class place_order: + def __init__(self, order, inventories): + super().__init__(*(order, inventories)) + + self.order = order + self.inventories = inventories + + @invariant.pre + def inventory_should_have_sufficient_stock(self): + for item in self.order.items: inventory = next( - (i for i in inventories if i.product_id == item.product_id), None + (i for i in self.inventories if i.product_id == item.product_id), None ) if inventory is None or inventory.quantity < item.quantity: - raise Exception("Product is out of stock") + raise ValidationError({"_service": ["Product is out of stock"]}) - inventory.reserve_stock(item.quantity) + @invariant.pre + def order_payment_method_should_be_valid(self): + if not self.order.payment_id: + raise ValidationError( + {"_service": ["Order must have a valid payment method"]} + ) + + @invariant.post + def total_reserved_value_should_match_order_value(self): + order_total = sum(item.quantity * item.price for item in self.order.items) + reserved_total = 0 + for item in self.order.items: + inventory = next( + (i for i in self.inventories if i.product_id == item.product_id), None + ) + if inventory: + reserved_total += inventory._events[0].quantity * item.price + + if order_total != reserved_total: + raise ValidationError( + {"_service": ["Total reserved value does not match order value"]} + ) + + @invariant.post + def total_quantity_reserved_should_match_order_quantity(self): + order_quantity = sum(item.quantity for item in self.order.items) + reserved_quantity = sum( + inventory._events[0].quantity + for inventory in self.inventories + if inventory._events + ) - order.confirm() + if order_quantity != reserved_quantity: + raise ValidationError( + {"_service": ["Total reserved quantity does not match order quantity"]} + ) + + def __call__(self): + for item in self.order.items: + inventory = next( + (i for i in self.inventories if i.product_id == item.product_id), None + ) + inventory.reserve_stock(item.quantity) - return order, inventories + self.order.confirm() diff --git a/docs_src/guides/domain-definition/012.py b/docs_src/guides/domain-definition/012.py new file mode 100644 index 00000000..ebd094dc --- /dev/null +++ b/docs_src/guides/domain-definition/012.py @@ -0,0 +1,16 @@ +from protean import Domain, invariant +from protean.exceptions import ValidationError +from protean.fields import Float, String + +domain = Domain(__name__, load_toml=False) + + +@domain.value_object +class Balance: + currency = String(max_length=3, required=True) + amount = Float(required=True) + + @invariant.post + def check_balance_is_positive_if_currency_is_USD(self): + if self.amount < 0 and self.currency == "USD": + raise ValidationError({"balance": ["Balance cannot be negative for USD"]}) diff --git a/src/protean/container.py b/src/protean/container.py index 4c596305..35d106d3 100644 --- a/src/protean/container.py +++ b/src/protean/container.py @@ -268,11 +268,6 @@ def __init__(self, *template, **kwargs): # noqa: C901 self._initialized = True - # `clean()` will return a `defaultdict(list)` if errors are to be raised - custom_errors = self.clean() or {} - for field in custom_errors: - self.errors[field].extend(custom_errors[field]) - # Raise any errors found during load if self.errors: logger.error(self.errors) @@ -283,12 +278,6 @@ def defaults(self): To be overridden in concrete Containers, when an attribute's default depends on other attribute values. """ - def clean(self): - """Placeholder method for validations. - To be overridden in concrete Containers, when complex validations spanning multiple fields are required. - """ - return defaultdict(list) - def __eq__(self, other): """Equivalence check for containers is based only on data. diff --git a/src/protean/core/aggregate.py b/src/protean/core/aggregate.py index 2d35ea40..c1f8a2c6 100644 --- a/src/protean/core/aggregate.py +++ b/src/protean/core/aggregate.py @@ -75,7 +75,7 @@ def aggregate_factory(element_cls, **kwargs): if not ( method_name.startswith("__") and method_name.endswith("__") ) and hasattr(method, "_invariant"): - element_cls._invariants[method_name] = method + element_cls._invariants[method._invariant][method_name] = method return element_cls @@ -87,9 +87,10 @@ def __init__(self, aggregate): def __enter__(self): # Temporary disable invariant checks + self.aggregate._precheck() self.aggregate._disable_invariant_checks = True def __exit__(self, *args): - # Run clean() on exit to trigger invariant checks + # Validate on exit to trigger invariant checks self.aggregate._disable_invariant_checks = False - self.aggregate.clean() + self.aggregate._postcheck() diff --git a/src/protean/core/domain_service.py b/src/protean/core/domain_service.py index 39c1232e..4b722405 100644 --- a/src/protean/core/domain_service.py +++ b/src/protean/core/domain_service.py @@ -1,7 +1,13 @@ +import inspect import logging +from collections import defaultdict +from functools import wraps +from typing import List, Union + +from protean import BaseAggregate from protean.container import Element, OptionsMixin -from protean.exceptions import IncorrectUsageError +from protean.exceptions import IncorrectUsageError, ValidationError from protean.utils import DomainObjects, derive_element_class logger = logging.getLogger(__name__) @@ -22,7 +28,7 @@ class Meta: def __new__(cls, *args, **kwargs): if cls is BaseDomainService: raise TypeError("BaseDomainService cannot be instantiated") - return object.__new__(cls, *args, **kwargs) + return super().__new__(cls) @classmethod def _default_options(cls): @@ -30,6 +36,75 @@ def _default_options(cls): ("part_of", None), ] + def __init_subclass__(subclass) -> None: + super().__init_subclass__() + + # Record invariant methods + setattr(subclass, "_invariants", defaultdict(dict)) + + def __init__(self, *aggregates: Union[BaseAggregate, List[BaseAggregate]]): + """ + Initializes a DomainService with one or more aggregates. + + Args: + *aggregates (Union[BaseAggregate, List[BaseAggregate]]): One or more aggregates to be associated with this + DomainService. + """ + self._aggregates = aggregates + + +def wrap_call_method_with_invariants(cls): + """ + Wraps the __call__ method of a class with a function that executes the original __call__ method and then runs + any defined invariant methods on the object. If any of the invariant methods raise a `ValidationError`, + the wrapped `__call__` method will raise a ValidationError with the collected error messages. + """ + + # Protect against re-wrapping + # by checking whether __call__ has `__wrapped__` attribute + # which it would if it has been wrapped already + # + # FIXME Is there a better way to prevent re-wrapping the same class? + if not hasattr(cls.__call__, "__wrapped__"): + original_call = cls.__call__ + + @wraps(original_call) + def wrapped_call(self, *args, **kwargs): + # Run the invariant methods marked `pre` before the original __call__ method + errors = {} + for invariant_method in self._invariants["pre"].values(): + try: + invariant_method(self) + except ValidationError as err: + for field_name in err.messages: + if field_name not in errors: + errors[field_name] = [] + errors[field_name].extend(err.messages[field_name]) + + if errors: + raise ValidationError(errors) + + # Execute the original __call__ method + result = original_call(self, *args, **kwargs) + + # Run the invariant methods marked `post` after the original __call__ method + for invariant_method in self._invariants["post"].values(): + try: + invariant_method(self) + except ValidationError as err: + for field_name in err.messages: + if field_name not in errors: + errors[field_name] = [] + errors[field_name].extend(err.messages[field_name]) + + if errors: + raise ValidationError(errors) + + return result + + cls.__call__ = wrapped_call + return cls + def domain_service_factory(element_cls, **kwargs): element_cls = derive_element_class(element_cls, BaseDomainService, **kwargs) @@ -43,4 +118,15 @@ def domain_service_factory(element_cls, **kwargs): } ) + # Iterate through methods marked as `@invariant` and record them for later use + methods = inspect.getmembers(element_cls, predicate=inspect.isroutine) + for method_name, method in methods: + if not ( + method_name.startswith("__") and method_name.endswith("__") + ) and hasattr(method, "_invariant"): + element_cls._invariants[method._invariant][method_name] = method + + # Wrap the __call__ method with invariant checks + element_cls = wrap_call_method_with_invariants(element_cls) + return element_cls diff --git a/src/protean/core/entity.py b/src/protean/core/entity.py index 78513aa0..e2b9d2dd 100644 --- a/src/protean/core/entity.py +++ b/src/protean/core/entity.py @@ -113,6 +113,36 @@ class User(BaseEntity): class Meta: abstract = True + def __init_subclass__(subclass) -> None: + super().__init_subclass__() + + # Record invariant methods + setattr(subclass, "_invariants", defaultdict(dict)) + + @classmethod + def _default_options(cls): + return [ + ("provider", "default"), + ("model", None), + ("part_of", None), + ("schema_name", inflection.underscore(cls.__name__)), + ] + + @classmethod + def _extract_options(cls, **opts): + """A stand-in method for setting customized options on the Domain Element + + Empty by default. To be overridden in each Element that expects or needs + specific options. + """ + for key, default in cls._default_options(): + value = ( + opts.pop(key, None) + or (hasattr(cls.meta_, key) and getattr(cls.meta_, key)) + or default + ) + setattr(cls.meta_, key, value) + def __init__(self, *template, **kwargs): # noqa: C901 """ Initialise the entity object. @@ -294,8 +324,8 @@ def __init__(self, *template, **kwargs): # noqa: C901 self._initialized = True - # `clean()` will return a `defaultdict(list)` if errors are to be raised - custom_errors = self.clean(return_errors=True) or {} + # `_postcheck()` will return a `defaultdict(list)` if errors are to be raised + custom_errors = self._postcheck(return_errors=True) or {} for field in custom_errors: self.errors[field].extend(custom_errors[field]) @@ -309,27 +339,33 @@ def defaults(self): To be overridden in concrete Containers, when an attribute's default depends on other attribute values. """ - def clean(self, return_errors=False): - """Invoked after initialization to perform additional validations.""" - # Call all methods marked as invariants + def _run_invariants(self, stage, return_errors=False): + """Run invariants for a given stage.""" if self._initialized and not self._disable_invariant_checks: errors = defaultdict(list) - for invariant_method in self._invariants.values(): + for invariant_method in self._invariants[stage].values(): try: invariant_method(self) except ValidationError as err: for field_name in err.messages: errors[field_name].extend(err.messages[field_name]) - # Run through all associations and trigger their clean method + # Run through all associations and trigger their invariants for field_name, field_obj in declared_fields(self).items(): - if isinstance(field_obj, Association): + if isinstance(field_obj, (Association, ValueObject)): value = getattr(self, field_name) if value is not None: items = value if isinstance(value, list) else [value] for item in items: - item_errors = item.clean(return_errors=True) + # Pre-checks don't apply to ValueObjects, because VOs are immutable + # and therefore cannot be changed once initialized. + if stage == "pre" and not isinstance( + field_obj, ValueObject + ): + item_errors = item._precheck(return_errors=True) + else: + item_errors = item._postcheck(return_errors=True) if item_errors: for sub_field_name, error_list in item_errors.items(): errors[sub_field_name].extend(error_list) @@ -340,6 +376,14 @@ def clean(self, return_errors=False): if errors: raise ValidationError(errors) + def _precheck(self, return_errors=False): + """Invariant checks performed before entity changes""" + return self._run_invariants("pre", return_errors=return_errors) + + def _postcheck(self, return_errors=False): + """Invariant checks performed after initialization and attribute changes""" + return self._run_invariants("post", return_errors=return_errors) + def __eq__(self, other): """Equivalence check to be based only on Identity""" @@ -435,30 +479,6 @@ def clone(self): return clone_copy - @classmethod - def _default_options(cls): - return [ - ("provider", "default"), - ("model", None), - ("part_of", None), - ("schema_name", inflection.underscore(cls.__name__)), - ] - - @classmethod - def _extract_options(cls, **opts): - """A stand-in method for setting customized options on the Domain Element - - Empty by default. To be overridden in each Element that expects or needs - specific options. - """ - for key, default in cls._default_options(): - value = ( - opts.pop(key, None) - or (hasattr(cls.meta_, key) and getattr(cls.meta_, key)) - or default - ) - setattr(cls.meta_, key, value) - def _set_root_and_owner(self, root, owner): """Set the root and owner entities on all child entities @@ -480,12 +500,6 @@ def _set_root_and_owner(self, root, owner): if not item._root: item._set_root_and_owner(self._root, self) - def __init_subclass__(subclass) -> None: - super().__init_subclass__() - - # Record invariant methods - setattr(subclass, "_invariants", {}) - def entity_factory(element_cls, **kwargs): element_cls = derive_element_class(element_cls, BaseEntity, **kwargs) @@ -551,18 +565,26 @@ def entity_factory(element_cls, **kwargs): if not ( method_name.startswith("__") and method_name.endswith("__") ) and hasattr(method, "_invariant"): - element_cls._invariants[method_name] = method + element_cls._invariants[method._invariant][method_name] = method return element_cls -def invariant(fn): - """Decorator to mark invariant methods in an Entity""" +class invariant: + @staticmethod + def pre(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) - @functools.wraps(fn) - def wrapper(*args, **kwargs): - return fn(*args, **kwargs) + setattr(wrapper, "_invariant", "pre") + return wrapper - setattr(wrapper, "_invariant", True) + @staticmethod + def post(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) - return wrapper + setattr(wrapper, "_invariant", "post") + return wrapper diff --git a/src/protean/core/repository.py b/src/protean/core/repository.py index cc92751d..2fde49c7 100644 --- a/src/protean/core/repository.py +++ b/src/protean/core/repository.py @@ -3,7 +3,7 @@ from functools import lru_cache from protean.container import Element, OptionsMixin -from protean.exceptions import IncorrectUsageError, ValidationError +from protean.exceptions import IncorrectUsageError from protean.fields import HasMany, HasOne from protean.globals import current_domain from protean.reflection import association_fields, has_association_fields @@ -107,14 +107,6 @@ def add(self, aggregate): # noqa: C901 is part of the DAO's design, and is automatically used wherever one tries to persist data. """ - # Ensure that aggregate is clean and good to save - # FIXME Let `clean()` raise validation errors - errors = aggregate.clean() or {} - # Raise any errors found during load - if errors: - logger.error(errors) - raise ValidationError(errors) - # If there are HasMany/HasOne fields in the aggregate, sync child objects added/removed, if has_association_fields(aggregate): self._sync_children(aggregate) diff --git a/src/protean/core/value_object.py b/src/protean/core/value_object.py index f58cb333..54e37de5 100644 --- a/src/protean/core/value_object.py +++ b/src/protean/core/value_object.py @@ -1,5 +1,6 @@ """Value Object Functionality and Classes""" +import inspect import logging from collections import defaultdict @@ -22,6 +23,9 @@ class Meta: def __init_subclass__(subclass) -> None: super().__init_subclass__() + # Record invariant methods + setattr(subclass, "_invariants", defaultdict(dict)) + subclass.__validate_for_basic_field_types() subclass.__validate_for_non_identifier_fields() subclass.__validate_for_non_unique_fields() @@ -91,7 +95,7 @@ def __init__(self, *template, **kwargs): # noqa: C901 for dictionary in template: if not isinstance(dictionary, dict): raise AssertionError( - f'Positional argument "{dictionary}" passed must be a dict.' + f"Positional argument {dictionary} passed must be a dict. " f"This argument serves as a template for loading common " f"values.", ) @@ -121,8 +125,8 @@ def __init__(self, *template, **kwargs): # noqa: C901 self.defaults() - # `clean()` will return a `defaultdict(list)` if errors are to be raised - custom_errors = self.clean() or {} + # `_postcheck()` will return a `defaultdict(list)` if errors are to be raised + custom_errors = self._postcheck() or {} for field in custom_errors: self.errors[field].extend(custom_errors[field]) @@ -147,18 +151,19 @@ def __setattr__(self, name, value): } ) - def _run_validators(self, value): - """Collect validators from enclosed fields and run them. - - This method is called during initialization of the Value Object - at the Entity level. - """ + def _postcheck(self, return_errors=False): + """Invariant checks performed after initialization""" errors = defaultdict(list) - for field_name, field_obj in fields(self).items(): + + for invariant_method in self._invariants["post"].values(): try: - field_obj._run_validators(getattr(self, field_name), value) + invariant_method(self) except ValidationError as err: - errors[field_name].extend(err.messages) + for field_name in err.messages: + errors[field_name].extend(err.messages[field_name]) + + if return_errors: + return errors if errors: raise ValidationError(errors) @@ -167,4 +172,12 @@ def _run_validators(self, value): def value_object_factory(element_cls, **kwargs): element_cls = derive_element_class(element_cls, BaseValueObject, **kwargs) + # Iterate through methods marked as `@invariant` and record them for later use + methods = inspect.getmembers(element_cls, predicate=inspect.isroutine) + for method_name, method in methods: + if not ( + method_name.startswith("__") and method_name.endswith("__") + ) and hasattr(method, "_invariant"): + element_cls._invariants[method._invariant][method_name] = method + return element_cls diff --git a/src/protean/fields/association.py b/src/protean/fields/association.py index 8c5f1eef..dcd23aa2 100644 --- a/src/protean/fields/association.py +++ b/src/protean/fields/association.py @@ -438,7 +438,7 @@ def __set__(self, instance, value): setattr(old_value, field_name, None) if instance._initialized and instance._root is not None: - instance._root.clean() # Trigger validations from the top + instance._root._postcheck() # Trigger validations from the top def _fetch_objects(self, instance, key, identifier): """Fetch single linked object""" @@ -571,7 +571,7 @@ def add(self, instance, items) -> None: self.delete_cached_value(instance) if instance._initialized and instance._root is not None: - instance._root.clean() # Trigger validations from the top + instance._root._postcheck() # Trigger validations from the top def remove(self, instance, items) -> None: """ @@ -620,7 +620,7 @@ def remove(self, instance, items) -> None: setattr(item, field_name, None) if instance._initialized and instance._root is not None: - instance._root.clean() # Trigger validations from the top + instance._root._postcheck() # Trigger validations from the top def _fetch_objects(self, instance, key, value) -> list: """ diff --git a/src/protean/fields/base.py b/src/protean/fields/base.py index 74309dad..42678661 100644 --- a/src/protean/fields/base.py +++ b/src/protean/fields/base.py @@ -141,15 +141,23 @@ def __get__(self, instance, owner): def __set__(self, instance, value): value = self._load(value) + # The hasattr check is necessary to avoid running invariant checks on unrelated elements + if ( + instance._initialized + and hasattr(instance, "_root") + and instance._root is not None + ): + instance._root._precheck() # Trigger validations from the top + instance.__dict__[self.field_name] = value - # The hasattr check is necessary to avoid running clean on unrelated elements + # The hasattr check is necessary to avoid running invariant checks on unrelated elements if ( instance._initialized and hasattr(instance, "_root") and instance._root is not None ): - instance._root.clean() # Trigger validations from the top + instance._root._postcheck() # Trigger validations from the top # Mark Entity as Dirty if hasattr(instance, "state_"): diff --git a/tests/adapters/model/dict_model/elements.py b/tests/adapters/model/dict_model/elements.py index d1ae151b..eefd9954 100644 --- a/tests/adapters/model/dict_model/elements.py +++ b/tests/adapters/model/dict_model/elements.py @@ -1,9 +1,9 @@ import re -from collections import defaultdict from typing import List -from protean import BaseAggregate, BaseModel, BaseRepository, BaseValueObject +from protean import BaseAggregate, BaseModel, BaseRepository, BaseValueObject, invariant +from protean.exceptions import ValidationError from protean.fields import Integer, String, Text, ValueObject from protean.globals import current_domain @@ -28,14 +28,11 @@ class Email(BaseValueObject): # This is the external facing data attribute address = String(max_length=254, required=True) - def clean(self): + @invariant.post + def validate_email_address(self): """Business rules of Email address""" - errors = defaultdict(list) - if not bool(re.match(Email.REGEXP, self.address)): - errors["address"].append("is invalid") - - return errors + raise ValidationError({"address": ["email address"]}) class User(BaseAggregate): diff --git a/tests/adapters/model/elasticsearch_model/elements.py b/tests/adapters/model/elasticsearch_model/elements.py index a1af8ed8..7f4953c1 100644 --- a/tests/adapters/model/elasticsearch_model/elements.py +++ b/tests/adapters/model/elasticsearch_model/elements.py @@ -1,12 +1,12 @@ import re -from collections import defaultdict from datetime import datetime from elasticsearch_dsl import Keyword, Text -from protean import BaseAggregate, BaseValueObject +from protean import BaseAggregate, BaseValueObject, invariant from protean.core.model import BaseModel +from protean.exceptions import ValidationError from protean.fields import DateTime, Integer, String from protean.fields import Text as ProteanText from protean.fields import ValueObject @@ -36,14 +36,11 @@ class Email(BaseValueObject): # This is the external facing data attribute address = String(max_length=254, required=True) - def clean(self): + @invariant.post + def validate_email_address(self): """Business rules of Email address""" - errors = defaultdict(list) - if not bool(re.match(Email.REGEXP, self.address)): - errors["address"].append("is invalid") - - return errors + raise ValidationError({"address": ["email address"]}) class ComplexUser(BaseAggregate): diff --git a/tests/adapters/model/sqlalchemy_model/postgresql/elements.py b/tests/adapters/model/sqlalchemy_model/postgresql/elements.py index 69d8dba1..deebe00c 100644 --- a/tests/adapters/model/sqlalchemy_model/postgresql/elements.py +++ b/tests/adapters/model/sqlalchemy_model/postgresql/elements.py @@ -1,11 +1,11 @@ import re -from collections import defaultdict from datetime import datetime from sqlalchemy import Column, Text -from protean import BaseAggregate, BaseModel, BaseValueObject +from protean import BaseAggregate, BaseModel, BaseValueObject, invariant +from protean.exceptions import ValidationError from protean.fields import DateTime, Integer, List, String, ValueObject @@ -27,14 +27,11 @@ class Email(BaseValueObject): # This is the external facing data attribute address = String(max_length=254, required=True) - def clean(self): + @invariant.post + def validate_email_address(self): """Business rules of Email address""" - errors = defaultdict(list) - if not bool(re.match(Email.REGEXP, self.address)): - errors["address"].append("is invalid") - - return errors + raise ValidationError({"address": ["email address"]}) class ComplexUser(BaseAggregate): diff --git a/tests/adapters/model/sqlalchemy_model/sqlite/elements.py b/tests/adapters/model/sqlalchemy_model/sqlite/elements.py index 1e1c3280..689a6fd2 100644 --- a/tests/adapters/model/sqlalchemy_model/sqlite/elements.py +++ b/tests/adapters/model/sqlalchemy_model/sqlite/elements.py @@ -1,11 +1,11 @@ import re -from collections import defaultdict from datetime import datetime from sqlalchemy import Column, Text -from protean import BaseAggregate, BaseModel, BaseValueObject +from protean import BaseAggregate, BaseModel, BaseValueObject, invariant +from protean.exceptions import ValidationError from protean.fields import DateTime, Integer, String, ValueObject @@ -27,14 +27,11 @@ class Email(BaseValueObject): # This is the external facing data attribute address = String(max_length=254, required=True) - def clean(self): + @invariant.post + def validate_email_address(self): """Business rules of Email address""" - errors = defaultdict(list) - if not bool(re.match(Email.REGEXP, self.address)): - errors["address"].append("is invalid") - - return errors + raise ValidationError({"address": ["email address"]}) class ComplexUser(BaseAggregate): diff --git a/tests/adapters/repository/elasticsearch_repo/elements.py b/tests/adapters/repository/elasticsearch_repo/elements.py index 870868e1..b621959a 100644 --- a/tests/adapters/repository/elasticsearch_repo/elements.py +++ b/tests/adapters/repository/elasticsearch_repo/elements.py @@ -1,9 +1,9 @@ import re -from collections import defaultdict from datetime import datetime -from protean import BaseAggregate, BaseRepository, BaseValueObject +from protean import BaseAggregate, BaseRepository, BaseValueObject, invariant +from protean.exceptions import ValidationError from protean.fields import DateTime, Integer, String, ValueObject @@ -35,14 +35,11 @@ class Email(BaseValueObject): # This is the external facing data attribute address = String(max_length=254, required=True) - def clean(self): + @invariant.post + def validate_email_address(self): """Business rules of Email address""" - errors = defaultdict(list) - if not bool(re.match(Email.REGEXP, self.address)): - errors["address"].append("is invalid") - - return errors + raise ValidationError({"address": ["email address"]}) class ComplexUser(BaseAggregate): diff --git a/tests/adapters/repository/sqlalchemy_repo/postgresql/elements.py b/tests/adapters/repository/sqlalchemy_repo/postgresql/elements.py index 1abc4d83..f13e3751 100644 --- a/tests/adapters/repository/sqlalchemy_repo/postgresql/elements.py +++ b/tests/adapters/repository/sqlalchemy_repo/postgresql/elements.py @@ -1,9 +1,9 @@ import re -from collections import defaultdict from datetime import datetime -from protean import BaseAggregate, BaseRepository, BaseValueObject +from protean import BaseAggregate, BaseRepository, BaseValueObject, invariant +from protean.exceptions import ValidationError from protean.fields import DateTime, Integer, String, ValueObject @@ -36,14 +36,11 @@ class Email(BaseValueObject): # This is the external facing data attribute address = String(max_length=254, required=True) - def clean(self): + @invariant.post + def validate_email_address(self): """Business rules of Email address""" - errors = defaultdict(list) - if not bool(re.match(Email.REGEXP, self.address)): - errors["address"].append("is invalid") - - return errors + raise ValidationError({"address": ["email address"]}) class ComplexUser(BaseAggregate): diff --git a/tests/adapters/repository/sqlalchemy_repo/sqlite/elements.py b/tests/adapters/repository/sqlalchemy_repo/sqlite/elements.py index 870868e1..b621959a 100644 --- a/tests/adapters/repository/sqlalchemy_repo/sqlite/elements.py +++ b/tests/adapters/repository/sqlalchemy_repo/sqlite/elements.py @@ -1,9 +1,9 @@ import re -from collections import defaultdict from datetime import datetime -from protean import BaseAggregate, BaseRepository, BaseValueObject +from protean import BaseAggregate, BaseRepository, BaseValueObject, invariant +from protean.exceptions import ValidationError from protean.fields import DateTime, Integer, String, ValueObject @@ -35,14 +35,11 @@ class Email(BaseValueObject): # This is the external facing data attribute address = String(max_length=254, required=True) - def clean(self): + @invariant.post + def validate_email_address(self): """Business rules of Email address""" - errors = defaultdict(list) - if not bool(re.match(Email.REGEXP, self.address)): - errors["address"].append("is invalid") - - return errors + raise ValidationError({"address": ["email address"]}) class ComplexUser(BaseAggregate): diff --git a/tests/adapters/repository/test_generic.py b/tests/adapters/repository/test_generic.py index 03760d72..8374c69f 100644 --- a/tests/adapters/repository/test_generic.py +++ b/tests/adapters/repository/test_generic.py @@ -1,13 +1,18 @@ import re -from collections import defaultdict from typing import List from uuid import uuid4 import pytest -from protean import BaseAggregate, BaseRepository, BaseValueObject, UnitOfWork -from protean.exceptions import ExpectedVersionError +from protean import ( + BaseAggregate, + BaseRepository, + BaseValueObject, + UnitOfWork, + invariant, +) +from protean.exceptions import ExpectedVersionError, ValidationError from protean.fields import Integer, String, ValueObject from protean.globals import current_domain @@ -32,14 +37,11 @@ class Email(BaseValueObject): # This is the external facing data attribute address = String(max_length=254, required=True) - def clean(self): + @invariant.post + def validate_email_address(self): """Business rules of Email address""" - errors = defaultdict(list) - if not bool(re.match(Email.REGEXP, self.address)): - errors["address"].append("is invalid") - - return errors + raise ValidationError({"address": ["email address"]}) class User(BaseAggregate): diff --git a/tests/aggregate/test_atomic_change.py b/tests/aggregate/test_atomic_change.py index b6c88a3f..75ffe5c6 100644 --- a/tests/aggregate/test_atomic_change.py +++ b/tests/aggregate/test_atomic_change.py @@ -19,12 +19,12 @@ class TestAggregate(BaseAggregate): assert aggregate._disable_invariant_checks is False - def test_clean_is_not_triggered_within_context_manager(self, test_domain): + def test_validation_is_not_triggered_within_context_manager(self, test_domain): class TestAggregate(BaseAggregate): value1 = Integer() value2 = Integer() - @invariant + @invariant.post def raise_error(self): if self.value2 != self.value1 + 1: raise ValidationError({"_entity": ["Invariant error"]}) diff --git a/tests/domain_service/test_invariants_decorator.py b/tests/domain_service/test_invariants_decorator.py new file mode 100644 index 00000000..b89cc1b4 --- /dev/null +++ b/tests/domain_service/test_invariants_decorator.py @@ -0,0 +1,38 @@ +from protean import BaseAggregate, BaseDomainService, invariant + + +class Aggregate1(BaseAggregate): + pass + + +class Aggregate2(BaseAggregate): + pass + + +class test_handler(BaseDomainService): + class Meta: + part_of = [Aggregate1, Aggregate2] + + @invariant.pre + def some_invariant_1(self): + pass + + @invariant.post + def some_invariant_2(self): + pass + + def __call__(self): + pass + + +def test_that_domain_service_has_recorded_invariants(test_domain): + test_domain.register(Aggregate1) + test_domain.register(Aggregate2) + test_domain.register(test_handler) + test_domain.init(traverse=False) + + assert len(test_handler._invariants) == 2 + + # Methods are presented in ascending order (alphabetical order) of member names. + assert "some_invariant_1" in test_handler._invariants["pre"] + assert "some_invariant_2" in test_handler._invariants["post"] diff --git a/tests/domain_service/test_invariants_triggering.py b/tests/domain_service/test_invariants_triggering.py new file mode 100644 index 00000000..47228ab7 --- /dev/null +++ b/tests/domain_service/test_invariants_triggering.py @@ -0,0 +1,427 @@ +import pytest + +from datetime import datetime, timezone +from enum import Enum +from uuid import uuid4 + +from protean import ( + BaseAggregate, + BaseDomainService, + BaseEvent, + BaseValueObject, + BaseEntity, + invariant, +) +from protean.exceptions import ValidationError +from protean.fields import ( + DateTime, + Float, + Identifier, + Integer, + HasMany, + List, + String, + ValueObject, +) + + +class OrderStatus(Enum): + PENDING = "PENDING" + CONFIRMED = "CONFIRMED" + SHIPPED = "SHIPPED" + DELIVERED = "DELIVERED" + + +class OrderItem(BaseEntity): + product_id = Identifier(required=True) + quantity = Integer() + price = Float() + + class Meta: + part_of = "Order" + + +class OrderItemVO(BaseValueObject): + product_id = Identifier(required=True) + quantity = Integer() + price = Float() + + +class OrderConfirmed(BaseEvent): + order_id = Identifier(required=True) + customer_id = Identifier(required=True) + items = List(content_type=ValueObject(OrderItemVO), required=True) + confirmed_at = DateTime(required=True) + + class Meta: + part_of = "Order" + + +class Order(BaseAggregate): + customer_id = Identifier(required=True) + items = HasMany("OrderItem") + status = String(choices=OrderStatus, default=OrderStatus.PENDING.value) + payment_id = Identifier() + + @invariant.post + def order_should_contain_items(self): + if not self.items or len(self.items) == 0: + raise ValidationError({"_entity": ["Order must contain at least one item"]}) + + def confirm(self): + self.status = OrderStatus.CONFIRMED.value + self.raise_( + OrderConfirmed( + customer_id=self.customer_id, + order_id=self.id, + confirmed_at=datetime.now(timezone.utc), + items=[ + OrderItemVO( + product_id=item.product_id, + quantity=item.quantity, + price=item.price, + ) + for item in self.items + ], + ) + ) + + +class Warehouse(BaseValueObject): + location = String() + contact = String() + + class Meta: + part_of = "Inventory" + + +class StockReserved(BaseEvent): + product_id = Identifier(required=True) + quantity = Integer(required=True) + reserved_at = DateTime(required=True) + + class Meta: + part_of = "Inventory" + + +class Inventory(BaseAggregate): + product_id = Identifier(required=True) + quantity = Integer() + warehouse = ValueObject(Warehouse) + + def reserve_stock(self, quantity: int): + self.quantity -= quantity + self.raise_( + StockReserved( + product_id=self.product_id, + quantity=quantity, + reserved_at=datetime.now(timezone.utc), + ) + ) + + +class OrderPlacementService(BaseDomainService): + class Meta: + part_of = [Order, Inventory] + + def __init__(self, order, inventories): + super().__init__(*(order, inventories)) + + self.order = order + self.inventories = inventories + + @invariant.pre + def inventory_should_have_sufficient_stock(self): + for item in self.order.items: + inventory = next( + (i for i in self.inventories if i.product_id == item.product_id), None + ) + if inventory is None or inventory.quantity < item.quantity: + raise ValidationError({"_service": ["Product is out of stock"]}) + + @invariant.pre + def order_payment_method_should_be_valid(self): + if not self.order.payment_id: + raise ValidationError( + {"_service": ["Order must have a valid payment method"]} + ) + + @invariant.post + def total_reserved_value_should_match_order_value(self): + order_total = sum(item.quantity * item.price for item in self.order.items) + reserved_total = 0 + for item in self.order.items: + inventory = next( + (i for i in self.inventories if i.product_id == item.product_id), None + ) + if inventory: + reserved_total += inventory._events[0].quantity * item.price + + if order_total != reserved_total: + raise ValidationError( + {"_service": ["Total reserved value does not match order value"]} + ) + + @invariant.post + def total_quantity_reserved_should_match_order_quantity(self): + order_quantity = sum(item.quantity for item in self.order.items) + reserved_quantity = sum( + inventory._events[0].quantity + for inventory in self.inventories + if inventory._events + ) + + if order_quantity != reserved_quantity: + raise ValidationError( + {"_service": ["Total reserved quantity does not match order quantity"]} + ) + + def __call__(self): + for item in self.order.items: + inventory = next( + (i for i in self.inventories if i.product_id == item.product_id), None + ) + inventory.reserve_stock(item.quantity) + + self.order.confirm() + + +@pytest.fixture(autouse=True) +def register_elements(test_domain): + test_domain.register(Order) + test_domain.register(OrderItem) + test_domain.register(Warehouse) + test_domain.register(Inventory) + test_domain.register(OrderConfirmed) + test_domain.register(StockReserved) + test_domain.register(OrderPlacementService) + test_domain.init(traverse=False) + + +def test_order_placement_with_sufficient_inventory(): + order = Order( + customer_id=str(uuid4()), + payment_id=str(uuid4()), + items=[OrderItem(product_id=str(uuid4()), quantity=10, price=100)], + ) + + inventory = Inventory( + product_id=order.items[0].product_id, + quantity=100, + warehouse=Warehouse(location="NYC", contact="John Doe"), + ) + + OrderPlacementService(order, [inventory])() + + assert order.status == OrderStatus.CONFIRMED.value + assert inventory.quantity == 90 + assert len(order._events) == 1 + assert isinstance(order._events[0], OrderConfirmed) + assert len(inventory._events) == 1 + assert isinstance(inventory._events[0], StockReserved) + assert inventory._events[0].quantity == 10 + assert inventory._events[0].product_id == order.items[0].product_id + assert inventory._events[0].reserved_at is not None + + +def test_order_placement_with_insufficient_inventory(): + order = Order( + customer_id=str(uuid4()), + payment_id=str(uuid4()), + items=[OrderItem(product_id=str(uuid4()), quantity=10, price=100)], + ) + + inventory = Inventory( + product_id=order.items[0].product_id, + quantity=5, + warehouse=Warehouse(location="NYC", contact="John Doe"), + ) + + with pytest.raises(ValidationError) as exc_info: + OrderPlacementService(order, [inventory])() + + assert str(exc_info.value) == "{'_service': ['Product is out of stock']}" + + +def test_order_placement_with_exact_inventory_match(): + order = Order( + customer_id=str(uuid4()), + payment_id=str(uuid4()), + items=[OrderItem(product_id=str(uuid4()), quantity=10, price=100)], + ) + + inventory = Inventory( + product_id=order.items[0].product_id, + quantity=10, + warehouse=Warehouse(location="NYC", contact="John Doe"), + ) + + OrderPlacementService(order, [inventory])() + + assert order.status == OrderStatus.CONFIRMED.value + assert inventory.quantity == 0 + assert len(order._events) == 1 + assert isinstance(order._events[0], OrderConfirmed) + assert len(inventory._events) == 1 + assert isinstance(inventory._events[0], StockReserved) + assert inventory._events[0].quantity == 10 + assert inventory._events[0].product_id == order.items[0].product_id + assert inventory._events[0].reserved_at is not None + + +def test_order_placement_with_multiple_items(): + order = Order( + customer_id=str(uuid4()), + payment_id=str(uuid4()), + items=[ + OrderItem(product_id=str(uuid4()), quantity=5, price=100), + OrderItem(product_id=str(uuid4()), quantity=3, price=200), + ], + ) + + inventory1 = Inventory( + product_id=order.items[0].product_id, + quantity=10, + warehouse=Warehouse(location="NYC", contact="John Doe"), + ) + inventory2 = Inventory( + product_id=order.items[1].product_id, + quantity=5, + warehouse=Warehouse(location="NYC", contact="Jane Doe"), + ) + + OrderPlacementService(order, [inventory1, inventory2])() + + assert order.status == OrderStatus.CONFIRMED.value + assert inventory1.quantity == 5 + assert inventory2.quantity == 2 + assert len(order._events) == 1 + assert isinstance(order._events[0], OrderConfirmed) + assert len(inventory1._events) == 1 + assert isinstance(inventory1._events[0], StockReserved) + assert inventory1._events[0].quantity == 5 + assert inventory1._events[0].product_id == order.items[0].product_id + assert len(inventory2._events) == 1 + assert isinstance(inventory2._events[0], StockReserved) + assert inventory2._events[0].quantity == 3 + assert inventory2._events[0].product_id == order.items[1].product_id + + +def test_total_reserved_value_matches_order_value(): + order = Order( + customer_id=str(uuid4()), + payment_id=str(uuid4()), + items=[ + OrderItem(product_id=str(uuid4()), quantity=5, price=100), + OrderItem(product_id=str(uuid4()), quantity=3, price=200), + ], + ) + + inventory1 = Inventory( + product_id=order.items[0].product_id, + quantity=10, + warehouse=Warehouse(location="NYC", contact="John Doe"), + ) + inventory2 = Inventory( + product_id=order.items[1].product_id, + quantity=5, + warehouse=Warehouse(location="NYC", contact="Jane Doe"), + ) + + OrderPlacementService(order, [inventory1, inventory2])() + + assert order.status == OrderStatus.CONFIRMED.value + assert inventory1.quantity == 5 + assert inventory2.quantity == 2 + assert sum(item.quantity * item.price for item in order.items) == sum( + item.quantity * item.price for item in order.items + ) + + +def test_order_placement_with_mismatched_reserved_value(): + order = Order( + customer_id=str(uuid4()), + payment_id=str(uuid4()), + items=[ + OrderItem(product_id=str(uuid4()), quantity=5, price=100), + OrderItem(product_id=str(uuid4()), quantity=3, price=200), + ], + ) + + # Inventory quantities are sufficient, but we will manually create a mismatch + inventory1 = Inventory( + product_id=order.items[0].product_id, + quantity=10, + warehouse=Warehouse(location="NYC", contact="John Doe"), + ) + inventory2 = Inventory( + product_id=order.items[1].product_id, + quantity=5, + warehouse=Warehouse(location="NYC", contact="Jane Doe"), + ) + + # Manually tampering the inventory to create a mismatch in reserved value + inventory1.reserve_stock(5) + inventory2.reserve_stock(1) # This should be 3 to match order + + with pytest.raises(ValidationError) as exc_info: + OrderPlacementService(order, [inventory1, inventory2])() + + assert str(exc_info.value) == ( + "{'_service': ['Total reserved quantity does not match order quantity', " + "'Total reserved value does not match order value']}" + ) + + +def test_order_placement_with_multiple_pre_condition_errors(): + order = Order( + customer_id=str(uuid4()), + payment_id=None, # Invalid payment method + items=[OrderItem(product_id=str(uuid4()), quantity=10, price=100)], + ) + + inventory = Inventory( + product_id=order.items[0].product_id, + quantity=5, # Insufficient stock + warehouse=Warehouse(location="NYC", contact="John Doe"), + ) + + with pytest.raises(ValidationError) as exc_info: + OrderPlacementService(order, [inventory])() + + assert "Product is out of stock" in str(exc_info.value) + assert "Order must have a valid payment method" in str(exc_info.value) + + +def test_order_placement_with_multiple_post_condition_errors(): + order = Order( + customer_id=str(uuid4()), + payment_id=str(uuid4()), + items=[ + OrderItem(product_id=str(uuid4()), quantity=5, price=100), + OrderItem(product_id=str(uuid4()), quantity=3, price=200), + ], + ) + + inventory1 = Inventory( + product_id=order.items[0].product_id, + quantity=10, + warehouse=Warehouse(location="NYC", contact="John Doe"), + ) + inventory2 = Inventory( + product_id=order.items[1].product_id, + quantity=5, + warehouse=Warehouse(location="NYC", contact="Jane Doe"), + ) + + # Manually tampering the inventory to create mismatches in reserved value and quantity + inventory1.reserve_stock(5) + inventory2.reserve_stock(1) # This should be 3 to match order + + with pytest.raises(ValidationError) as exc_info: + OrderPlacementService(order, [inventory1, inventory2])() + + assert "Total reserved value does not match order value" in str(exc_info.value) + assert "Total reserved quantity does not match order quantity" in str( + exc_info.value + ) diff --git a/tests/domain_service/tests.py b/tests/domain_service/tests.py index 114b5382..4da813c2 100644 --- a/tests/domain_service/tests.py +++ b/tests/domain_service/tests.py @@ -1,82 +1,109 @@ +import mock import pytest -from protean import BaseDomainService + +from protean import BaseAggregate, BaseDomainService +from protean.core import domain_service +from protean.exceptions import IncorrectUsageError from protean.utils import fully_qualified_name -def Aggregate1(BaseAggregate): +class Aggregate1(BaseAggregate): pass -def Aggregate2(BaseAggregate): +class Aggregate2(BaseAggregate): pass -class DummyDomainService(BaseDomainService): +class perform_something(BaseDomainService): class Meta: part_of = [Aggregate1, Aggregate2] - def do_complex_process(self): + def __call__(self): print("Performing complex process...") -class TestDomainServiceInitialization: - def test_that_base_domain_service_class_cannot_be_instantiated(self): - with pytest.raises(TypeError): - BaseDomainService() +def test_that_base_domain_service_class_cannot_be_instantiated(): + with pytest.raises(TypeError): + BaseDomainService() - def test_that_domain_service_can_be_instantiated(self): - service = DummyDomainService() - assert service is not None +def test_that_domain_service_can_be_instantiated(): + service = perform_something(Aggregate1(), Aggregate2()) + assert service is not None -class TestDomainServiceRegistration: - def test_that_domain_service_can_be_registered_with_domain(self, test_domain): - test_domain.register(DummyDomainService) - assert ( - fully_qualified_name(DummyDomainService) - in test_domain.registry.domain_services - ) +def test_that_domain_service_needs_to_be_associated_with_at_least_2_aggregates( + test_domain, +): + with pytest.raises(IncorrectUsageError): - def test_that_domain_service_can_be_registered_via_annotations(self, test_domain): - @test_domain.domain_service(part_of=[Aggregate1, Aggregate2]) - class AnnotatedDomainService: - def special_method(self): - pass + class bad_domain_service(BaseDomainService): + class Meta: + part_of = [Aggregate1] - assert ( - fully_qualified_name(AnnotatedDomainService) - in test_domain.registry.domain_services - ) + test_domain.register(bad_domain_service) + + +def test_that_domain_service_is_a_callable_class(): + assert callable(perform_something(Aggregate1(), Aggregate2())) + + +def test_that_domain_service_can_be_registered_with_domain(test_domain): + test_domain.register(perform_something) + + assert ( + fully_qualified_name(perform_something) in test_domain.registry.domain_services + ) - def test_that_domain_service_is_associated_with_aggregates(self, test_domain): - @test_domain.aggregate - class Aggregate3: - pass - @test_domain.aggregate - class Aggregate4: +def test_that_domain_service_can_be_registered_via_annotations(test_domain): + @test_domain.domain_service(part_of=[Aggregate1, Aggregate2]) + class AnnotatedDomainService: + def special_method(self): pass - @test_domain.domain_service(part_of=[Aggregate3, Aggregate4]) - class AnnotatedDomainService: - def special_method(self): - pass - - assert ( - fully_qualified_name(AnnotatedDomainService) - in test_domain.registry.domain_services - ) - assert ( - Aggregate3 - in test_domain.registry.domain_services[ - fully_qualified_name(AnnotatedDomainService) - ].cls.meta_.part_of - ) - assert ( - Aggregate4 - in test_domain.registry.domain_services[ - fully_qualified_name(AnnotatedDomainService) - ].cls.meta_.part_of - ) + assert ( + fully_qualified_name(AnnotatedDomainService) + in test_domain.registry.domain_services + ) + + +def test_that_domain_service_is_associated_with_aggregates(test_domain): + @test_domain.aggregate + class Aggregate3: + pass + + @test_domain.aggregate + class Aggregate4: + pass + + @test_domain.domain_service(part_of=[Aggregate3, Aggregate4]) + class do_something: + pass + + assert fully_qualified_name(do_something) in test_domain.registry.domain_services + assert ( + Aggregate3 + in test_domain.registry.domain_services[ + fully_qualified_name(do_something) + ].cls.meta_.part_of + ) + assert ( + Aggregate4 + in test_domain.registry.domain_services[ + fully_qualified_name(do_something) + ].cls.meta_.part_of + ) + + +def test_wrap_call_method_with_invariants(test_domain): + # Mock the `wrap_call_method_with_invariants` method + # Ensure it returns a domain service element + mock_wrap = mock.Mock(return_value=perform_something) + domain_service.wrap_call_method_with_invariants = mock_wrap + + test_domain.register(perform_something) + + mock_wrap.assert_called_once() diff --git a/tests/entity/elements.py b/tests/entity/elements.py index df1a458f..f20db1c1 100644 --- a/tests/entity/elements.py +++ b/tests/entity/elements.py @@ -149,7 +149,7 @@ def defaults(self): else: self.status = BuildingStatus.WIP.value - @invariant + @invariant.post def test_building_status_to_be_done_if_floors_above_4(self): if self.floors >= 4 and self.status != BuildingStatus.DONE.value: raise ValidationError( diff --git a/tests/entity/invariants/test_invariant_decorator.py b/tests/entity/invariants/test_invariant_decorator.py index 847324be..49b272ec 100644 --- a/tests/entity/invariants/test_invariant_decorator.py +++ b/tests/entity/invariants/test_invariant_decorator.py @@ -1,25 +1,51 @@ +from datetime import datetime +from enum import Enum + from protean import BaseAggregate, BaseEntity, invariant from protean.exceptions import ValidationError from protean.fields import Date, Float, Integer, String, HasMany +class OrderStatus(Enum): + PENDING = "PENDING" + SHIPPED = "SHIPPED" + DELIVERED = "DELIVERED" + + class Order(BaseAggregate): ordered_on = Date() total = Float() items = HasMany("OrderItem") + status = String( + max_length=50, choices=OrderStatus, default=OrderStatus.PENDING.value + ) - @invariant + @invariant.pre + def order_date_must_be_in_the_past_and_status_pending_to_update_order(self): + if ( + self.status != OrderStatus.PENDING.value + or self.order_date >= datetime.today().date() + ): + raise ValidationError( + { + "_entity": [ + "Order date must be in the past and status PENDING to update order" + ] + } + ) + + @invariant.post def total_should_be_sum_of_item_prices(self): if self.items: if self.total != sum([item.price for item in self.items]): raise ValidationError("Total should be sum of item prices") - @invariant + @invariant.post def must_have_at_least_one_item(self): if not self.items or len(self.items) == 0: raise ValidationError("Order must contain at least one item") - @invariant + @invariant.post def item_quantities_should_be_positive(self): for item in self.items: if item.quantity <= 0: @@ -34,7 +60,7 @@ class OrderItem(BaseEntity): class Meta: part_of = Order - @invariant + @invariant.post def price_should_be_non_negative(self): if self.price < 0: raise ValidationError("Item price should be non-negative") @@ -45,11 +71,17 @@ def test_that_entity_has_recorded_invariants(test_domain): test_domain.register(Order) test_domain.init(traverse=False) - assert len(Order._invariants) == 3 + assert len(Order._invariants["pre"]) == 1 + assert len(Order._invariants["post"]) == 3 + + assert ( + "order_date_must_be_in_the_past_and_status_pending_to_update_order" + in Order._invariants["pre"] + ) # Methods are presented in ascending order (alphabetical order) of member names. - assert "item_quantities_should_be_positive" in Order._invariants - assert "must_have_at_least_one_item" in Order._invariants - assert "total_should_be_sum_of_item_prices" in Order._invariants + assert "item_quantities_should_be_positive" in Order._invariants["post"] + assert "must_have_at_least_one_item" in Order._invariants["post"] + assert "total_should_be_sum_of_item_prices" in Order._invariants["post"] - assert len(OrderItem._invariants) == 1 - assert "price_should_be_non_negative" in OrderItem._invariants + assert len(OrderItem._invariants["post"]) == 1 + assert "price_should_be_non_negative" in OrderItem._invariants["post"] diff --git a/tests/entity/invariants/test_invariant_triggerring.py b/tests/entity/invariants/test_invariant_triggerring.py index b3970db3..adfcfdf6 100644 --- a/tests/entity/invariants/test_invariant_triggerring.py +++ b/tests/entity/invariants/test_invariant_triggerring.py @@ -1,6 +1,6 @@ import pytest -from datetime import date +from datetime import date, datetime from enum import Enum from protean import BaseAggregate, BaseEntity, invariant, atomic_change @@ -21,12 +21,26 @@ class Order(BaseAggregate): status = String(max_length=50, choices=OrderStatus) items = HasMany("OrderItem") - @invariant + @invariant.pre + def order_date_must_be_in_the_past_and_status_pending_to_update_order(self): + if ( + self.status != OrderStatus.PENDING.value + or self.order_date >= datetime.today().date() + ): + raise ValidationError( + { + "_entity": [ + "Order date must be in the past and status PENDING to update order" + ] + } + ) + + @invariant.post def total_amount_of_order_must_equal_sum_of_subtotal_of_all_items(self): if self.total_amount != sum(item.subtotal for item in self.items): raise ValidationError({"_entity": ["Total should be sum of item prices"]}) - @invariant + @invariant.post def order_date_must_be_within_the_last_30_days_if_status_is_pending(self): if self.status == OrderStatus.PENDING.value and self.order_date < date( 2020, 1, 1 @@ -39,7 +53,7 @@ def order_date_must_be_within_the_last_30_days_if_status_is_pending(self): } ) - @invariant + @invariant.post def customer_id_must_be_non_null_and_the_order_must_contain_at_least_one_item(self): if not self.customer_id or not self.items: raise ValidationError( @@ -50,6 +64,9 @@ def customer_id_must_be_non_null_and_the_order_must_contain_at_least_one_item(se } ) + def mark_shipped(self): + self.status = OrderStatus.SHIPPED.value + class OrderItem(BaseEntity): product_id = Identifier() @@ -60,7 +77,7 @@ class OrderItem(BaseEntity): class Meta: part_of = Order - @invariant + @invariant.post def the_quantity_must_be_a_positive_integer_and_the_subtotal_must_be_correctly_calculated( self, ): @@ -285,3 +302,32 @@ def test_when_item_price_is_changed_to_negative(self, order): assert exc.value.messages["_entity"] == [ "Quantity must be a positive integer and the subtotal must be correctly calculated" ] + + +class TestEntityPreInvariantsChecks: + def test_order_date_must_be_in_the_past_and_status_pending_to_update_order( + self, order + ): + # This check is enclosed within atomic_change() + order.mark_shipped() + with pytest.raises(ValidationError) as exc: + with atomic_change(order): + order.add_items( + OrderItem(product_id="3", quantity=2, price=10.0, subtotal=20.0) + ) + order.total_amount = 120.0 + + assert exc.value.messages["_entity"] == [ + "Order date must be in the past and status PENDING to update order" + ] + + def test_triggering_pre_validation_with_attribute_change(self, order): + # This is the same check as above, but we're triggering the pre-validation + # with an attribute change. + order.mark_shipped() + with pytest.raises(ValidationError) as exc: + order.customer_id = "2" + + assert exc.value.messages["_entity"] == [ + "Order date must be in the past and status PENDING to update order" + ] diff --git a/tests/entity/test_lifecycle_methods.py b/tests/entity/test_lifecycle_methods.py index cf6810bd..ed4bbd3a 100644 --- a/tests/entity/test_lifecycle_methods.py +++ b/tests/entity/test_lifecycle_methods.py @@ -17,7 +17,7 @@ def test_that_building_is_marked_as_done_if_below_4_floors(self): assert building.status == BuildingStatus.WIP.value -class TestClean: +class TestInvariantValidation: def test_that_building_cannot_be_WIP_if_above_4_floors(self, test_domain): test_domain.register(Building) test_domain.register(Area) diff --git a/tests/repository/elements.py b/tests/repository/elements.py index add33682..0728a58b 100644 --- a/tests/repository/elements.py +++ b/tests/repository/elements.py @@ -1,9 +1,9 @@ import re -from collections import defaultdict from typing import List -from protean import BaseAggregate, BaseRepository, BaseValueObject +from protean import BaseAggregate, BaseRepository, BaseValueObject, invariant +from protean.exceptions import ValidationError from protean.fields import Integer, String, ValueObject from protean.globals import current_domain @@ -28,14 +28,11 @@ class Email(BaseValueObject): # This is the external facing data attribute address = String(max_length=254, required=True) - def clean(self): + @invariant.post + def validate_email_address(self): """Business rules of Email address""" - errors = defaultdict(list) - if not bool(re.match(Email.REGEXP, self.address)): - errors["address"].append("is invalid") - - return errors + raise ValidationError({"address": ["email address"]}) class User(BaseAggregate): diff --git a/tests/value_object/elements.py b/tests/value_object/elements.py index 5b0c6f1b..6d8b31e0 100644 --- a/tests/value_object/elements.py +++ b/tests/value_object/elements.py @@ -1,7 +1,8 @@ from collections import defaultdict from enum import Enum -from protean import BaseAggregate, BaseValueObject +from protean import BaseAggregate, BaseValueObject, invariant +from protean.exceptions import ValidationError from protean.fields import Float, Identifier, Integer, String, ValueObject @@ -66,11 +67,11 @@ class Balance(BaseValueObject): currency = String(max_length=3, choices=Currency) amount = Float() - def clean(self): - errors = defaultdict(list) + @invariant.post + def validate_balance_cannot_be_less_than_1_trillion(self): + """Business rules of Balance""" if self.amount and self.amount < -1000000000000.0: - errors["amount"].append("cannot be less than 1 Trillion") - return errors + raise ValidationError({"amount": ["cannot be less than 1 Trillion"]}) def replace(self, **kwargs): # FIXME Find a way to do this generically and move method to `BaseValueObject` @@ -101,7 +102,7 @@ def defaults(self): else: self.status = BuildingStatus.WIP.value - def clean(self): + def _postcheck(self): errors = defaultdict(list) if self.floors >= 4 and self.status != BuildingStatus.DONE.value: errors["status"].append("should be DONE") diff --git a/tests/value_object/test_lifecycle_methods.py b/tests/value_object/test_lifecycle_methods.py index d6e3436e..dde51d91 100644 --- a/tests/value_object/test_lifecycle_methods.py +++ b/tests/value_object/test_lifecycle_methods.py @@ -17,7 +17,7 @@ def test_that_building_is_marked_as_done_if_below_4_floors(self): assert building.status == BuildingStatus.WIP.value -class TestClean: +class TestInvariantValidation: def test_that_building_cannot_be_WIP_if_above_4_floors(self): with pytest.raises(ValidationError): Building(name="Foo", floors=4, status=BuildingStatus.WIP.value) diff --git a/tests/value_object/test_vo_invariants.py b/tests/value_object/test_vo_invariants.py new file mode 100644 index 00000000..057eafc5 --- /dev/null +++ b/tests/value_object/test_vo_invariants.py @@ -0,0 +1,24 @@ +import pytest + +from protean import BaseValueObject, invariant +from protean.exceptions import ValidationError +from protean.fields import Float, String + + +class Balance(BaseValueObject): + currency = String(max_length=3, required=True) + amount = Float(required=True) + + @invariant.post + def check_balance_is_positive_if_currency_is_USD(self): + if self.amount < 0 and self.currency == "USD": + raise ValidationError({"balance": ["Balance cannot be negative for USD"]}) + + +def test_vo_invariant_raises_error_on_initialization(test_domain): + test_domain.register(Balance) + + with pytest.raises(ValidationError) as exc: + Balance(currency="USD", amount=-100.0) + + assert str(exc.value) == "{'balance': ['Balance cannot be negative for USD']}" diff --git a/tests/value_object/tests.py b/tests/value_object/tests.py index dfc88ee1..415d8e16 100644 --- a/tests/value_object/tests.py +++ b/tests/value_object/tests.py @@ -1,6 +1,7 @@ import pytest -from protean.exceptions import IncorrectUsageError, ValidationError +from protean.exceptions import IncorrectUsageError, ValidationError, NotSupportedError +from protean.fields import Float from protean.reflection import attributes, declared_fields from .elements import ( @@ -15,6 +16,32 @@ ) +def test_vo_marked_abstract_cannot_be_instantiated(): + class AbstractBalance(Balance): + amount = Float() + + class Meta: + abstract = True + + with pytest.raises(NotSupportedError) as exc: + AbstractBalance(amount=100.0) + + assert ( + str(exc.value) + == "AbstractBalance class has been marked abstract and cannot be instantiated" + ) + + +def test_template_param_is_a_dict(): + with pytest.raises(AssertionError) as exc: + Balance([Currency.CAD.value, 0.0]) + + assert str(exc.value) == ( + "Positional argument ['CAD', 0.0] passed must be a dict. " + "This argument serves as a template for loading common values." + ) + + class TestEquivalence: def test_two_value_objects_with_equal_values_are_considered_equal(self): email1 = Email.from_address("john.doe@gmail.com") diff --git a/tests/views/tests.py b/tests/views/tests.py index 87d1c555..3071cf61 100644 --- a/tests/views/tests.py +++ b/tests/views/tests.py @@ -119,7 +119,7 @@ def defaults(self): else: self.status = BuildingStatus.WIP.value - def clean(self): + def _postcheck(self): errors = defaultdict(list) if self.floors >= 4 and self.status != BuildingStatus.DONE.value: