diff --git a/src/protean/fields/association.py b/src/protean/fields/association.py index c4e3ee64..5a3f91c5 100644 --- a/src/protean/fields/association.py +++ b/src/protean/fields/association.py @@ -288,6 +288,9 @@ def _linked_attribute(self, owner): + id_field(owner).attribute_name ) + def _linked_reference(self, owner): + return utils.inflection.underscore(owner.__name__) + def __get__(self, instance, owner): """Retrieve associated objects""" @@ -391,6 +394,10 @@ def __set__(self, instance, value): value, linked_attribute, id_value ) # This overwrites any existing linkage, which is correct + # Add the parent to the child entity cache + # Temporarily set linkage to parent in child entity + setattr(value, self._linked_reference(type(instance)), instance) + # 2. Determine and store the change in the relationship current_value = getattr(instance, self.field_name) if current_value is None: @@ -521,6 +528,9 @@ def add(self, instance, items) -> None: getattr(instance, id_field(instance).field_name), ) + # Temporarily set linkage to parent in child entity + setattr(item, self._linked_reference(type(instance)), instance) + # Reset Cache self.delete_cached_value(instance) # Items to update @@ -535,6 +545,9 @@ def add(self, instance, items) -> None: getattr(instance, id_field(instance).field_name), ) + # Temporarily set linkage to parent in child entity + setattr(item, self._linked_reference(type(instance)), instance) + instance._temp_cache[self.field_name]["updated"][item.id] = item # Reset Cache diff --git a/tests/entity/associations/test_has_many_recursive.py b/tests/entity/associations/test_has_many_recursive.py index a409ab96..53ddd2e4 100644 --- a/tests/entity/associations/test_has_many_recursive.py +++ b/tests/entity/associations/test_has_many_recursive.py @@ -57,11 +57,11 @@ def test_customers_basic_structure(): len(customer.orders) == 1 customer.orders[0] == order customer.orders[0].customer_id == customer.id - # customer.orders[0].customer == customer # FIXME + customer.orders[0].customer == customer len(customer.orders[0].items) == 2 customer.orders[0].items[0] == items[0] customer.orders[0].items[0].order_id == order.id - # customer.orders[0].items[0].order == order # FIXME + customer.orders[0].items[0].order == order @pytest.fixture @@ -89,12 +89,12 @@ def test_all_associations_are_persisted_on_direct_initialization(self, customer) assert len(customer.orders) == 1 assert customer.orders[0].ordered_on == datetime.today().date() customer.orders[0].customer_id == customer.id - # customer.orders[0].customer == customer # FIXME + customer.orders[0].customer == customer assert len(customer.orders[0].items) == 2 assert customer.orders[0].items[0].product_id == "1" customer.orders[0].items[0].order_id == customer.orders[0].id - # customer.orders[0].items[0].order == order # FIXME + customer.orders[0].items[0].order == customer.orders[0] def test_all_associations_are_persisted_on_1st_level_nested_entity_addition( self, test_domain, customer diff --git a/tests/entity/associations/test_has_one_recursive.py b/tests/entity/associations/test_has_one_recursive.py index 857a04e8..c08701d7 100644 --- a/tests/entity/associations/test_has_one_recursive.py +++ b/tests/entity/associations/test_has_one_recursive.py @@ -50,6 +50,7 @@ def test_university_basic_structure(): assert dean.university_id == university.id assert university.dean.office == office assert university.dean.office.dean_id == dean.id + assert university.dean.office.dean == dean @pytest.fixture diff --git a/tests/entity/associations/test_multiple_has_many.py b/tests/entity/associations/test_multiple_has_many.py new file mode 100644 index 00000000..03b5e535 --- /dev/null +++ b/tests/entity/associations/test_multiple_has_many.py @@ -0,0 +1,110 @@ +import pytest + +from datetime import datetime, timedelta + +from protean import BaseAggregate, BaseEntity +from protean.fields import Date, String, HasMany +from protean.reflection import declared_fields + + +class Customer(BaseAggregate): + name = String(max_length=50) + orders = HasMany("Order") + addresses = HasMany("Address") + + +class Order(BaseEntity): + ordered_on = Date() + + class Meta: + part_of = Customer + + +class Address(BaseEntity): + street = String(max_length=50) + city = String(max_length=50) + state = String(max_length=50) + zip_code = String(max_length=10) + + class Meta: + part_of = Customer + + +@pytest.fixture(autouse=True) +def register_elements(test_domain): + test_domain.register(Customer) + test_domain.register(Order) + test_domain.register(Address) + test_domain.init(traverse=False) + + +def test_multiple_has_many_associations(): + assert declared_fields(Customer)["orders"].__class__.__name__ == "HasMany" + assert declared_fields(Customer)["orders"].field_name == "orders" + assert declared_fields(Customer)["orders"].to_cls == Order + + assert declared_fields(Customer)["addresses"].__class__.__name__ == "HasMany" + assert declared_fields(Customer)["addresses"].field_name == "addresses" + assert declared_fields(Customer)["addresses"].to_cls == Address + + +def test_customer_basic_structure_with_multiple_items_in_associations(): + order1 = Order(ordered_on=datetime.today().date()) + order2 = Order(ordered_on=datetime.today().date() - timedelta(days=1)) + address1 = Address( + street="123 Main St", city="Anytown", state="NY", zip_code="12345" + ) + address2 = Address( + street="456 Elm St", city="Anytown", state="NY", zip_code="12345" + ) + customer = Customer( + name="John Doe", orders=[order1, order2], addresses=[address1, address2] + ) + + assert len(customer.orders) == 2 + assert customer.orders[0] == order1 + assert customer.orders[0].customer_id == customer.id + assert customer.orders[0].customer == customer + + assert len(customer.addresses) == 2 + assert customer.addresses[0] == address1 + assert customer.addresses[0].customer_id == customer.id + assert customer.addresses[0].customer == customer + + +def test_basic_persistence(test_domain): + order1 = Order(ordered_on=datetime.today().date()) + order2 = Order(ordered_on=datetime.today().date() - timedelta(days=1)) + address1 = Address( + street="123 Main St", city="Anytown", state="NY", zip_code="12345" + ) + address2 = Address( + street="456 Elm St", city="Anytown", state="NY", zip_code="12345" + ) + customer = Customer( + name="John Doe", orders=[order1, order2], addresses=[address1, address2] + ) + + assert customer.id is not None + assert customer.orders[0].id is not None + assert customer.orders[1].id is not None + assert customer.addresses[0].id is not None + assert customer.addresses[1].id is not None + assert customer.orders[0].customer_id == customer.id + assert customer.orders[1].customer_id == customer.id + assert customer.addresses[0].customer_id == customer.id + assert customer.addresses[1].customer_id == customer.id + + test_domain.repository_for(Customer).add(customer) + + fetched_customer = test_domain.repository_for(Customer).get(customer.id) + + assert fetched_customer.name == "John Doe" + assert len(fetched_customer.orders) == 2 + assert fetched_customer.orders[0].ordered_on == datetime.today().date() + assert fetched_customer.orders[1].ordered_on == datetime.today().date() - timedelta( + days=1 + ) + assert len(fetched_customer.addresses) == 2 + assert fetched_customer.addresses[0].street == "123 Main St" + assert fetched_customer.addresses[1].street == "456 Elm St" diff --git a/tests/entity/associations/test_multiple_has_one.py b/tests/entity/associations/test_multiple_has_one.py new file mode 100644 index 00000000..584d4c85 --- /dev/null +++ b/tests/entity/associations/test_multiple_has_one.py @@ -0,0 +1,74 @@ +import pytest + +from protean import BaseEntity, BaseAggregate +from protean.fields import HasOne, Integer, String +from protean.reflection import declared_fields + + +class Department(BaseAggregate): + name = String(max_length=50) + dean = HasOne("Dean") + location = HasOne("Location") + + +class Dean(BaseEntity): + name = String(max_length=50) + age = Integer(min_value=21) + + class Meta: + part_of = Department + + +class Location(BaseEntity): + building = String(max_length=50) + + class Meta: + part_of = Department + + +@pytest.fixture(autouse=True) +def register_elements(test_domain): + test_domain.register(Department) + test_domain.register(Dean) + test_domain.register(Location) + test_domain.init(traverse=False) + + +def test_multiple_has_one_associations(): + assert declared_fields(Department)["dean"].__class__.__name__ == "HasOne" + assert declared_fields(Department)["dean"].field_name == "dean" + assert declared_fields(Department)["dean"].to_cls == Dean + + assert declared_fields(Department)["location"].__class__.__name__ == "HasOne" + assert declared_fields(Department)["location"].field_name == "location" + assert declared_fields(Department)["location"].to_cls == Location + + +def test_department_basic_structure(): + location = Location(building="Main Building") + dean = Dean(name="John Doe", age=45) + department = Department(name="Computer Science", dean=dean, location=location) + + assert department.dean == dean + assert dean.department_id == department.id + assert department.location == location + assert location.department_id == department.id + assert department.dean.department == department + assert department.location.department == department + + +def test_basic_persistence(test_domain): + location = Location(building="Main Building") + dean = Dean(name="John Doe", age=45) + department = Department(name="Computer Science", dean=dean, location=location) + + test_domain.repository_for(Department).add(department) + + persisted_department = test_domain.repository_for(Department).get(department.id) + + assert persisted_department.dean == dean + assert persisted_department.location == location + assert persisted_department.dean.department == persisted_department + assert persisted_department.location.department == persisted_department + assert persisted_department.dean.department_id == persisted_department.id + assert persisted_department.location.department_id == persisted_department.id