From 331cc74f42a2ad54bd543760d3006db23829f302 Mon Sep 17 00:00:00 2001 From: Subhash Bhushan Date: Tue, 21 May 2024 13:47:57 -0700 Subject: [PATCH] Run all invariants on init and attribute changes When an aggregate or entity object is initialized, or its attribute value is changed, this commit ensures that all `clean()` methods (which in turn run all invariant validations) are run across the aggregate. This commit also adds `atomic_change` context manager for scenarios where multiple aggregates in an aggregate need to be changed before it becomes valid again. --- docs/guides/compose-a-domain/object-model.md | 4 + .../domain-definition/fields/simple-fields.md | 20 ++ src/protean/__init__.py | 3 +- src/protean/container.py | 4 + src/protean/core/aggregate.py | 18 +- src/protean/core/command.py | 7 - src/protean/core/entity.py | 48 ++- src/protean/core/event.py | 10 - src/protean/fields/association.py | 13 + src/protean/fields/base.py | 9 + src/protean/fields/basic.py | 7 +- tests/aggregate/test_atomic_change.py | 45 +++ tests/entity/elements.py | 15 +- .../test_invariant_decorator.py} | 8 +- .../invariants/test_invariant_triggerring.py | 287 ++++++++++++++++++ tests/entity/test_lifecycle_methods.py | 8 +- 16 files changed, 459 insertions(+), 47 deletions(-) create mode 100644 tests/aggregate/test_atomic_change.py rename tests/entity/{test_invariants.py => invariants/test_invariant_decorator.py} (83%) create mode 100644 tests/entity/invariants/test_invariant_triggerring.py diff --git a/docs/guides/compose-a-domain/object-model.md b/docs/guides/compose-a-domain/object-model.md index 6a2673e6..5b7d24c4 100644 --- a/docs/guides/compose-a-domain/object-model.md +++ b/docs/guides/compose-a-domain/object-model.md @@ -9,6 +9,10 @@ document outlines generic aspects that apply to every domain element. `Element` is a base class inherited by all domain elements. Currently, it does not have any data structures or behavior associated with it. +## Element Type + +.element_type + ## Data Containers Protean provides data container elements, aligned with DDD principles to model diff --git a/docs/guides/domain-definition/fields/simple-fields.md b/docs/guides/domain-definition/fields/simple-fields.md index 7f24ff0e..8adef2b9 100644 --- a/docs/guides/domain-definition/fields/simple-fields.md +++ b/docs/guides/domain-definition/fields/simple-fields.md @@ -74,6 +74,26 @@ Out[2]: 'id': '88a21815-7d9b-4138-9cac-5a06889d4318'} ``` +Protean will intelligently convert a valid date string into a date object, with +the help of the venerable +[`dateutil`](https://dateutil.readthedocs.io/en/stable/) module. + +```shell +In [1]: post = Post(title='Foo', published_on="2020-01-01") + +In [2]: post.to_dict() +Out[2]: +{'title': 'Foo', + 'published_on': '2020-01-01', + 'id': 'ffcb3b26-71f0-45d0-8ca0-b71a9603f792'} + +In [3]: Post(title='Foo', published_on="2019-02-29") +ERROR: Error during initialization: {'published_on': ['"2019-02-29" has an invalid date format.']} +... +ValidationError: {'published_on': ['"2019-02-29" has an invalid date format.']} +``` + + ## DateTime A date and time, represented in Python by a `datetime.datetime` instance. diff --git a/src/protean/__init__.py b/src/protean/__init__.py index e31a1165..70cf8ab3 100644 --- a/src/protean/__init__.py +++ b/src/protean/__init__.py @@ -1,6 +1,6 @@ __version__ = "0.11.0" -from .core.aggregate import BaseAggregate +from .core.aggregate import BaseAggregate, atomic_change from .core.application_service import BaseApplicationService from .core.command import BaseCommand from .core.command_handler import BaseCommandHandler @@ -54,4 +54,5 @@ "get_version", "handle", "invariant", + "atomic_change", ] diff --git a/src/protean/container.py b/src/protean/container.py index e0736928..4c596305 100644 --- a/src/protean/container.py +++ b/src/protean/container.py @@ -195,6 +195,7 @@ def __init__(self, *template, **kwargs): # noqa: C901 This initialization technique supports keyword arguments as well as dictionaries. You can even use a template for initial data. """ + self._initialized = False if self.meta_.abstract is True: raise NotSupportedError( @@ -265,6 +266,8 @@ def __init__(self, *template, **kwargs): # noqa: C901 self.defaults() + self._initialized = True + # `clean()` will return a `defaultdict(list)` if errors are to be raised custom_errors = self.clean() or {} for field in custom_errors: @@ -329,6 +332,7 @@ def __setattr__(self, name, value): "_initialized", # Flag to indicate if the entity has been initialized "_root", # Root entity in the hierarchy "_owner", # Owner entity in the hierarchy + "_disable_invariant_checks", # Flag to disable invariant checks ] or name.startswith(("add_", "remove_")) ): diff --git a/src/protean/core/aggregate.py b/src/protean/core/aggregate.py index 470f99fe..2d35ea40 100644 --- a/src/protean/core/aggregate.py +++ b/src/protean/core/aggregate.py @@ -69,11 +69,27 @@ def aggregate_factory(element_cls, **kwargs): element_cls = derive_element_class(element_cls, BaseAggregate, **kwargs) # Iterate through methods marked as `@invariant` and record them for later use + # `_invariants` is a dictionary initialized in BaseEntity.__init_subclass__ methods = inspect.getmembers(element_cls, predicate=inspect.isroutine) for method_name, method in methods: if not ( method_name.startswith("__") and method_name.endswith("__") ) and hasattr(method, "_invariant"): - element_cls._invariants.append(method) + element_cls._invariants[method_name] = method return element_cls + + +# Context manager to temporarily disable invariant checks on aggregate +class atomic_change: + def __init__(self, aggregate): + self.aggregate = aggregate + + def __enter__(self): + # Temporary disable invariant checks + self.aggregate._disable_invariant_checks = True + + def __exit__(self, *args): + # Run clean() on exit to trigger invariant checks + self.aggregate._disable_invariant_checks = False + self.aggregate.clean() diff --git a/src/protean/core/command.py b/src/protean/core/command.py index d3870799..f95440c3 100644 --- a/src/protean/core/command.py +++ b/src/protean/core/command.py @@ -24,18 +24,11 @@ def __init_subclass__(subclass) -> None: subclass.__track_id_field() def __init__(self, *args, **kwargs): - # Set the flag to prevent any further modifications - self._initialized = False - try: super().__init__(*args, **kwargs) except ValidationError as exception: raise InvalidDataError(exception.messages) - # If we made it this far, the Value Object is initialized - # and should be marked as such - self._initialized = True - def __setattr__(self, name, value): if not hasattr(self, "_initialized") or not self._initialized: return super().__setattr__(name, value) diff --git a/src/protean/core/entity.py b/src/protean/core/entity.py index c9da569a..78513aa0 100644 --- a/src/protean/core/entity.py +++ b/src/protean/core/entity.py @@ -133,6 +133,8 @@ def __init__(self, *template, **kwargs): # noqa: C901 user = User(base_user.to_dict(), first_name='John', last_name='Doe') """ + self._initialized = False + if self.meta_.abstract is True: raise NotSupportedError( f"{self.__class__.__name__} class has been marked abstract" @@ -151,6 +153,9 @@ def __init__(self, *template, **kwargs): # noqa: C901 self._owner = None self._root = None + # To control invariant checks + self._disable_invariant_checks = False + # Collect Reference field attribute names to prevent accidental overwriting # of shadow fields. reference_attributes = { @@ -287,8 +292,10 @@ def __init__(self, *template, **kwargs): # noqa: C901 self.defaults() + self._initialized = True + # `clean()` will return a `defaultdict(list)` if errors are to be raised - custom_errors = self.clean() or {} + custom_errors = self.clean(return_errors=True) or {} for field in custom_errors: self.errors[field].extend(custom_errors[field]) @@ -302,11 +309,36 @@ def defaults(self): To be overridden in concrete Containers, when an attribute's default depends on other attribute values. """ - def clean(self): - """Placeholder method for validations. - To be overridden in concrete Containers, when complex validations spanning multiple fields are required. - """ - return defaultdict(list) + def clean(self, return_errors=False): + """Invoked after initialization to perform additional validations.""" + # Call all methods marked as invariants + if self._initialized and not self._disable_invariant_checks: + errors = defaultdict(list) + + for invariant_method in self._invariants.values(): + try: + invariant_method(self) + except ValidationError as err: + for field_name in err.messages: + errors[field_name].extend(err.messages[field_name]) + + # Run through all associations and trigger their clean method + for field_name, field_obj in declared_fields(self).items(): + if isinstance(field_obj, Association): + value = getattr(self, field_name) + if value is not None: + items = value if isinstance(value, list) else [value] + for item in items: + item_errors = item.clean(return_errors=True) + if item_errors: + for sub_field_name, error_list in item_errors.items(): + errors[sub_field_name].extend(error_list) + + if return_errors: + return errors + + if errors: + raise ValidationError(errors) def __eq__(self, other): """Equivalence check to be based only on Identity""" @@ -452,7 +484,7 @@ def __init_subclass__(subclass) -> None: super().__init_subclass__() # Record invariant methods - setattr(subclass, "_invariants", []) + setattr(subclass, "_invariants", {}) def entity_factory(element_cls, **kwargs): @@ -519,7 +551,7 @@ def entity_factory(element_cls, **kwargs): if not ( method_name.startswith("__") and method_name.endswith("__") ) and hasattr(method, "_invariant"): - element_cls._invariants.append(method) + element_cls._invariants[method_name] = method return element_cls diff --git a/src/protean/core/event.py b/src/protean/core/event.py index 79ca0c7e..b1116b26 100644 --- a/src/protean/core/event.py +++ b/src/protean/core/event.py @@ -27,16 +27,6 @@ def __init_subclass__(subclass) -> None: if not subclass.meta_.abstract: subclass.__track_id_field() - def __init__(self, *args, **kwargs): - # Set the flag to prevent any further modifications - self._initialized = False - - super().__init__(*args, **kwargs) - - # If we made it this far, the Value Object is initialized - # and should be marked as such - self._initialized = True - def __setattr__(self, name, value): if not hasattr(self, "_initialized") or not self._initialized: return super().__setattr__(name, value) diff --git a/src/protean/fields/association.py b/src/protean/fields/association.py index 070eb16a..8c5f1eef 100644 --- a/src/protean/fields/association.py +++ b/src/protean/fields/association.py @@ -437,6 +437,9 @@ def __set__(self, instance, value): elif isinstance(field_obj, HasOne): setattr(old_value, field_name, None) + if instance._initialized and instance._root is not None: + instance._root.clean() # Trigger validations from the top + def _fetch_objects(self, instance, key, identifier): """Fetch single linked object""" try: @@ -526,6 +529,10 @@ def add(self, instance, items) -> None: current_value_ids = [value.id for value in data] + # Remove items when set to empty + if len(items) == 0 and len(current_value_ids) > 0: + self.remove(instance, data) + for item in items: # Items to add if item.id not in current_value_ids: @@ -563,6 +570,9 @@ def add(self, instance, items) -> None: # Reset Cache self.delete_cached_value(instance) + if instance._initialized and instance._root is not None: + instance._root.clean() # Trigger validations from the top + def remove(self, instance, items) -> None: """ Available as `add_` method on the entity instance. @@ -609,6 +619,9 @@ def remove(self, instance, items) -> None: elif isinstance(field_obj, HasOne): setattr(item, field_name, None) + if instance._initialized and instance._root is not None: + instance._root.clean() # Trigger validations from the top + def _fetch_objects(self, instance, key, value) -> list: """ Fetch linked entities. diff --git a/src/protean/fields/base.py b/src/protean/fields/base.py index 70b56383..74309dad 100644 --- a/src/protean/fields/base.py +++ b/src/protean/fields/base.py @@ -140,8 +140,17 @@ def __get__(self, instance, owner): def __set__(self, instance, value): value = self._load(value) + instance.__dict__[self.field_name] = value + # The hasattr check is necessary to avoid running clean on unrelated elements + if ( + instance._initialized + and hasattr(instance, "_root") + and instance._root is not None + ): + instance._root.clean() # Trigger validations from the top + # Mark Entity as Dirty if hasattr(instance, "state_"): instance.state_.mark_changed() diff --git a/src/protean/fields/basic.py b/src/protean/fields/basic.py index dca0805c..d8b18c0a 100644 --- a/src/protean/fields/basic.py +++ b/src/protean/fields/basic.py @@ -421,12 +421,7 @@ def __set__(self, instance, value): if existing_value is not None and value != existing_value: raise InvalidOperationError("Identifiers cannot be changed once set") - value = self._load(value) - instance.__dict__[self.field_name] = value - - if hasattr(instance, "state_"): - # Mark Entity as Dirty - instance.state_.mark_changed() + super().__set__(instance, value) def as_dict(self, value): """Return JSON-compatible value of self""" diff --git a/tests/aggregate/test_atomic_change.py b/tests/aggregate/test_atomic_change.py new file mode 100644 index 00000000..b6c88a3f --- /dev/null +++ b/tests/aggregate/test_atomic_change.py @@ -0,0 +1,45 @@ +"""Test `atomic_change` context manager""" + +import pytest + +from protean import BaseAggregate, atomic_change, invariant +from protean.fields import Integer +from protean.exceptions import ValidationError + + +class TestAtomicChange: + def test_atomic_change_context_manager(self): + class TestAggregate(BaseAggregate): + pass + + aggregate = TestAggregate() + + with atomic_change(aggregate): + assert aggregate._disable_invariant_checks is True + + assert aggregate._disable_invariant_checks is False + + def test_clean_is_not_triggered_within_context_manager(self, test_domain): + class TestAggregate(BaseAggregate): + value1 = Integer() + value2 = Integer() + + @invariant + def raise_error(self): + if self.value2 != self.value1 + 1: + raise ValidationError({"_entity": ["Invariant error"]}) + + test_domain.register(TestAggregate) + test_domain.init(traverse=False) + + aggregate = TestAggregate(value1=1, value2=2) + + # This raises an error because of the invariant + with pytest.raises(ValidationError): + aggregate.value1 = 2 + aggregate.value2 = 3 + + # This should not raise an error because of the context manager + with atomic_change(aggregate): + aggregate.value1 = 2 + aggregate.value2 = 3 diff --git a/tests/entity/elements.py b/tests/entity/elements.py index 0446939e..df1a458f 100644 --- a/tests/entity/elements.py +++ b/tests/entity/elements.py @@ -1,7 +1,7 @@ -from collections import defaultdict from enum import Enum -from protean import BaseAggregate, BaseEntity +from protean import BaseAggregate, BaseEntity, invariant +from protean.exceptions import ValidationError from protean.fields import Auto, HasOne, Integer, String @@ -149,10 +149,9 @@ def defaults(self): else: self.status = BuildingStatus.WIP.value - def clean(self): - errors = defaultdict(list) - + @invariant + def test_building_status_to_be_done_if_floors_above_4(self): if self.floors >= 4 and self.status != BuildingStatus.DONE.value: - errors["status"].append("should be DONE") - - return errors + raise ValidationError( + {"_entity": ["Building status should be DONE if floors are above 4"]} + ) diff --git a/tests/entity/test_invariants.py b/tests/entity/invariants/test_invariant_decorator.py similarity index 83% rename from tests/entity/test_invariants.py rename to tests/entity/invariants/test_invariant_decorator.py index 972543aa..847324be 100644 --- a/tests/entity/test_invariants.py +++ b/tests/entity/invariants/test_invariant_decorator.py @@ -47,9 +47,9 @@ def test_that_entity_has_recorded_invariants(test_domain): assert len(Order._invariants) == 3 # Methods are presented in ascending order (alphabetical order) of member names. - assert Order._invariants[0].__name__ == "item_quantities_should_be_positive" - assert Order._invariants[1].__name__ == "must_have_at_least_one_item" - assert Order._invariants[2].__name__ == "total_should_be_sum_of_item_prices" + assert "item_quantities_should_be_positive" in Order._invariants + assert "must_have_at_least_one_item" in Order._invariants + assert "total_should_be_sum_of_item_prices" in Order._invariants assert len(OrderItem._invariants) == 1 - assert OrderItem._invariants[0].__name__ == "price_should_be_non_negative" + assert "price_should_be_non_negative" in OrderItem._invariants diff --git a/tests/entity/invariants/test_invariant_triggerring.py b/tests/entity/invariants/test_invariant_triggerring.py new file mode 100644 index 00000000..b3970db3 --- /dev/null +++ b/tests/entity/invariants/test_invariant_triggerring.py @@ -0,0 +1,287 @@ +import pytest + +from datetime import date +from enum import Enum + +from protean import BaseAggregate, BaseEntity, invariant, atomic_change +from protean.exceptions import ValidationError +from protean.fields import Date, Float, Identifier, Integer, String, HasMany + + +class OrderStatus(Enum): + PENDING = "PENDING" + SHIPPED = "SHIPPED" + DELIVERED = "DELIVERED" + + +class Order(BaseAggregate): + customer_id = Identifier() + order_date = Date() + total_amount = Float() + status = String(max_length=50, choices=OrderStatus) + items = HasMany("OrderItem") + + @invariant + def total_amount_of_order_must_equal_sum_of_subtotal_of_all_items(self): + if self.total_amount != sum(item.subtotal for item in self.items): + raise ValidationError({"_entity": ["Total should be sum of item prices"]}) + + @invariant + def order_date_must_be_within_the_last_30_days_if_status_is_pending(self): + if self.status == OrderStatus.PENDING.value and self.order_date < date( + 2020, 1, 1 + ): + raise ValidationError( + { + "_entity": [ + "Order date must be within the last 30 days if status is PENDING" + ] + } + ) + + @invariant + def customer_id_must_be_non_null_and_the_order_must_contain_at_least_one_item(self): + if not self.customer_id or not self.items: + raise ValidationError( + { + "_entity": [ + "Customer ID must be non-null and the order must contain at least one item" + ] + } + ) + + +class OrderItem(BaseEntity): + product_id = Identifier() + quantity = Integer() + price = Float() + subtotal = Float() + + class Meta: + part_of = Order + + @invariant + def the_quantity_must_be_a_positive_integer_and_the_subtotal_must_be_correctly_calculated( + self, + ): + if self.quantity <= 0 or self.subtotal != self.quantity * self.price: + raise ValidationError( + { + "_entity": [ + "Quantity must be a positive integer and the subtotal must be correctly calculated" + ] + } + ) + + +@pytest.fixture(autouse=True) +def register_elements(test_domain): + test_domain.register(OrderItem) + test_domain.register(Order) + test_domain.init(traverse=False) + + +class TestEntityInvariantsOnInitialization: + def test_with_valid_data(self): + order = Order( + customer_id="1", + order_date="2020-01-01", + total_amount=100.0, + status="PENDING", + items=[ + OrderItem(product_id="1", quantity=4, price=10.0, subtotal=40.0), + OrderItem(product_id="2", quantity=3, price=20.0, subtotal=60.0), + ], + ) + assert order is not None + + def test_when_total_amount_is_not_sum_of_item_subtotals(self): + with pytest.raises(ValidationError) as exc: + Order( + customer_id="1", + order_date="2020-01-01", + total_amount=100.0, + status="PENDING", + items=[ + OrderItem(product_id="1", quantity=2, price=10.0, subtotal=20.0), + OrderItem(product_id="2", quantity=3, price=20.0, subtotal=60.0), + ], + ) + + assert exc.value.messages["_entity"] == ["Total should be sum of item prices"] + + def test_when_order_date_is_not_within_the_last_30_days(self): + with pytest.raises(ValidationError) as exc: + Order( + customer_id="1", + order_date="2019-12-01", + total_amount=100.0, + status="PENDING", + items=[ + OrderItem(product_id="1", quantity=4, price=10.0, subtotal=40.0), + OrderItem(product_id="2", quantity=3, price=20.0, subtotal=60.0), + ], + ) + + assert exc.value.messages["_entity"] == [ + "Order date must be within the last 30 days if status is PENDING" + ] + + def test_when_customer_ID_is_null(self): + with pytest.raises(ValidationError) as exc: + Order( + customer_id=None, + order_date="2020-01-01", + total_amount=100.0, + status="PENDING", + items=[ + OrderItem(product_id="1", quantity=4, price=10.0, subtotal=40.0), + OrderItem(product_id="2", quantity=3, price=20.0, subtotal=60.0), + ], + ) + + assert exc.value.messages["_entity"] == [ + "Customer ID must be non-null and the order must contain at least one item" + ] + + def test_when_items_are_empty(self): + with pytest.raises(ValidationError) as exc: + Order( + customer_id="1", + order_date="2020-01-01", + total_amount=100.0, + status="PENDING", + items=[], + ) + + assert exc.value.messages["_entity"] == [ + "Customer ID must be non-null and the order must contain at least one item", + "Total should be sum of item prices", + ] + + def test_when_quantity_is_negative(self): + with pytest.raises(ValidationError) as exc: + Order( + customer_id="1", + order_date="2020-01-01", + total_amount=100.0, + status="PENDING", + items=[ + OrderItem(product_id="1", quantity=-1, price=10.0, subtotal=10.0), + OrderItem(product_id="2", quantity=3, price=20.0, subtotal=60.0), + ], + ) + + assert exc.value.messages["_entity"] == [ + "Quantity must be a positive integer and the subtotal must be correctly calculated" + ] + + def test_when_subtotal_is_incorrect(self): + with pytest.raises(ValidationError) as exc: + Order( + customer_id="1", + order_date="2020-01-01", + total_amount=100.0, + status="PENDING", + items=[ + OrderItem(product_id="1", quantity=1, price=10.0, subtotal=20.0), + OrderItem(product_id="2", quantity=3, price=20.0, subtotal=60.0), + ], + ) + + assert exc.value.messages["_entity"] == [ + "Quantity must be a positive integer and the subtotal must be correctly calculated" + ] + + +@pytest.fixture +def order(): + return Order( + customer_id="1", + order_date="2020-01-01", + total_amount=100.0, + status="PENDING", + items=[ + OrderItem(product_id="1", quantity=4, price=10.0, subtotal=40.0), + OrderItem(product_id="2", quantity=3, price=20.0, subtotal=60.0), + ], + ) + + +class TestEntityInvariantsOnAttributeChanges: + def test_when_total_amount_is_not_sum_of_item_subtotals(self, order): + with pytest.raises(ValidationError) as exc: + order.total_amount = 50.0 + + assert exc.value.messages["_entity"] == ["Total should be sum of item prices"] + + def test_when_order_date_is_not_within_the_last_30_days(self, order): + with pytest.raises(ValidationError) as exc: + order.order_date = "2019-12-01" + + assert exc.value.messages["_entity"] == [ + "Order date must be within the last 30 days if status is PENDING" + ] + + def test_when_customer_ID_is_null(self, order): + with pytest.raises(ValidationError) as exc: + order.customer_id = None + + assert exc.value.messages["_entity"] == [ + "Customer ID must be non-null and the order must contain at least one item" + ] + + def test_when_items_are_empty(self, order): + with pytest.raises(ValidationError) as exc: + order.items = [] + + assert exc.value.messages["_entity"] == [ + "Customer ID must be non-null and the order must contain at least one item", + "Total should be sum of item prices", + ] + + def test_when_invalid_item_is_added(self, order): + with pytest.raises(ValidationError) as exc: + order.add_items( + OrderItem(product_id="3", quantity=2, price=10.0, subtotal=40.0) + ) + + assert exc.value.messages["_entity"] == [ + "Quantity must be a positive integer and the subtotal must be correctly calculated" + ] + + def test_when_item_is_added_along_with_total_amount(self, order): + try: + with atomic_change(order): + order.total_amount = 120.0 + order.add_items( + OrderItem(product_id="3", quantity=2, price=10.0, subtotal=20.0) + ) + except ValidationError: + pytest.fail("Failed to batch update attributes") + + def test_when_quantity_is_negative(self, order): + with pytest.raises(ValidationError) as exc: + order.items[0].quantity = -1 + + assert exc.value.messages["_entity"] == [ + "Quantity must be a positive integer and the subtotal must be correctly calculated" + ] + + def test_when_invalid_item_is_added_after_initialization(self, order): + with pytest.raises(ValidationError) as exc: + order.add_items( + OrderItem(product_id="3", quantity=2, price=10.0, subtotal=40.0) + ) + + assert exc.value.messages["_entity"] == [ + "Quantity must be a positive integer and the subtotal must be correctly calculated" + ] + + def test_when_item_price_is_changed_to_negative(self, order): + with pytest.raises(ValidationError) as exc: + order.items[0].price = -10.0 + + assert exc.value.messages["_entity"] == [ + "Quantity must be a positive integer and the subtotal must be correctly calculated" + ] diff --git a/tests/entity/test_lifecycle_methods.py b/tests/entity/test_lifecycle_methods.py index d6e3436e..cf6810bd 100644 --- a/tests/entity/test_lifecycle_methods.py +++ b/tests/entity/test_lifecycle_methods.py @@ -2,7 +2,7 @@ from protean.exceptions import ValidationError -from .elements import Building, BuildingStatus +from .elements import Area, Building, BuildingStatus class TestDefaults: @@ -18,6 +18,10 @@ def test_that_building_is_marked_as_done_if_below_4_floors(self): class TestClean: - def test_that_building_cannot_be_WIP_if_above_4_floors(self): + def test_that_building_cannot_be_WIP_if_above_4_floors(self, test_domain): + test_domain.register(Building) + test_domain.register(Area) + test_domain.init(traverse=False) + with pytest.raises(ValidationError): Building(name="Foo", floors=4, status=BuildingStatus.WIP.value)