diff --git a/src/protean/fields/association.py b/src/protean/fields/association.py index c0b2bd8d..52c3bcf9 100644 --- a/src/protean/fields/association.py +++ b/src/protean/fields/association.py @@ -420,6 +420,10 @@ class HasMany(Association): def __init__(self, to_cls, via=None, **kwargs): super().__init__(to_cls, via=via, **kwargs) + def __set__(self, instance, value): + if value is not None: + self.add(instance, value) + def add(self, instance, items): data = getattr(instance, self.field_name) diff --git a/tests/aggregate/test_aggregate_association.py b/tests/aggregate/test_aggregate_association.py index 5c550d48..abe6cefb 100644 --- a/tests/aggregate/test_aggregate_association.py +++ b/tests/aggregate/test_aggregate_association.py @@ -35,47 +35,41 @@ def register_elements(self, test_domain): def test_successful_initialization_of_entity_with_has_one_association( self, test_domain ): - account = Account(email="john.doe@gmail.com", password="a1b2c3") - test_domain.repository_for(Account)._dao.save(account) - author = Author(first_name="John", last_name="Doe", account=account) - test_domain.repository_for(Author)._dao.save(author) - - assert all(key in author.__dict__ for key in ["account", "account_email"]) - assert author.account.email == account.email - assert author.account_email == account.email + account = Account( + email="john.doe@gmail.com", + password="a1b2c3", + author=Author(first_name="John", last_name="Doe"), + ) + test_domain.repository_for(Account).add(account) - refreshed_account = test_domain.repository_for(Account)._dao.get(account.email) - assert refreshed_account.author.id == author.id - assert refreshed_account.author == author + updated_account = test_domain.repository_for(Account).get(account.email) + updated_author = updated_account.author - def test_successful_has_one_initialization_with_a_class_containing_via_and_no_reference( - self, test_domain - ): - account = AccountVia(email="john.doe@gmail.com", password="a1b2c3") - test_domain.repository_for(AccountVia)._dao.save(account) - profile = ProfileVia( - profile_id="12345", about_me="Lorem Ipsum", account_email=account.email + updated_author.account # To refresh and load the account # FIXME Auto-load child entities + assert all( + key in updated_author.__dict__ for key in ["account", "account_email"] ) - test_domain.repository_for(ProfileVia)._dao.save(profile) + assert updated_author.account.email == account.email + assert updated_author.account_email == account.email - refreshed_account = test_domain.repository_for(AccountVia)._dao.get( - account.email - ) - assert refreshed_account.profile == profile + assert updated_account.author.id == updated_author.id + assert updated_account.author == updated_author def test_successful_has_one_initialization_with_a_class_containing_via_and_reference( self, test_domain ): account = AccountViaWithReference( - email="john.doe@gmail.com", password="a1b2c3", username="johndoe" + email="john.doe@gmail.com", + password="a1b2c3", + username="johndoe", ) - test_domain.repository_for(AccountViaWithReference)._dao.save(account) profile = ProfileViaWithReference(about_me="Lorem Ipsum", ac=account) - test_domain.repository_for(ProfileViaWithReference)._dao.save(profile) + account.profile = profile + test_domain.repository_for(AccountViaWithReference).add(account) - refreshed_account = test_domain.repository_for( - AccountViaWithReference - )._dao.get(account.email) + refreshed_account = test_domain.repository_for(AccountViaWithReference).get( + account.email + ) assert refreshed_account.profile == profile @mock.patch("protean.fields.association.Association._fetch_objects") @@ -85,15 +79,15 @@ def test_that_subsequent_access_after_first_retrieval_do_not_fetch_record_again( account = AccountViaWithReference( email="john.doe@gmail.com", password="a1b2c3", username="johndoe" ) - test_domain.repository_for(AccountViaWithReference)._dao.save(account) profile = ProfileViaWithReference(about_me="Lorem Ipsum", ac=account) - test_domain.repository_for(ProfileViaWithReference)._dao.save(profile) + account.profile = profile + test_domain.repository_for(AccountViaWithReference).add(account) mock.return_value = profile - refreshed_account = test_domain.repository_for( - AccountViaWithReference - )._dao.get(account.email) + refreshed_account = test_domain.repository_for(AccountViaWithReference).get( + account.email + ) for _ in range(3): getattr(refreshed_account, "profile") assert ( @@ -113,20 +107,42 @@ def register_elements(self, test_domain): @pytest.fixture def persisted_post(self, test_domain): - post = test_domain.repository_for(Post)._dao.create(content="Do Re Mi Fa") + post = test_domain.repository_for(Post).add(Post(content="Do Re Mi Fa")) return post def test_successful_initialization_of_entity_with_has_many_association( self, test_domain ): - post = Post(content="Lorem Ipsum") + post = Post( + content="Lorem Ipsum", + comments=[ + Comment(id=101, content="First Comment"), + Comment(id=102, content="Second Comment"), + ], + ) test_domain.repository_for(Post).add(post) - comment1 = Comment(id=101, content="First Comment") - comment2 = Comment(id=102, content="Second Comment") + refreshed_post = test_domain.repository_for(Post).get(post.id) + assert len(refreshed_post.comments) == 2 + assert "comments" in refreshed_post.__dict__ # Available after access + assert refreshed_post.comments[0].post_id == post.id + assert refreshed_post.comments[1].post_id == post.id + + assert isinstance(refreshed_post.comments, list) + assert all( + comment.id in [101, 102] for comment in refreshed_post.comments + ) # `__iter__` magic here - post.add_comments(comment1) - post.add_comments(comment2) + def test_adding_multiple_associations_at_the_same_time_before_aggregate_save( + self, test_domain + ): + post = Post(content="Lorem Ipsum") + post.add_comments( + [ + Comment(id=101, content="First Comment"), + Comment(id=102, content="Second Comment"), + ], + ) test_domain.repository_for(Post).add(post) refreshed_post = test_domain.repository_for(Post).get(post.id) @@ -142,11 +158,13 @@ def test_successful_initialization_of_entity_with_has_many_association( def test_adding_multiple_associations_at_the_same_time(self, test_domain): post = Post(content="Lorem Ipsum") + # Save the aggregate first, which is what happens in reality test_domain.repository_for(Post).add(post) comment1 = Comment(id=101, content="First Comment") comment2 = Comment(id=102, content="Second Comment") + # Comments follow later post.add_comments([comment1, comment2]) test_domain.repository_for(Post).add(post) @@ -164,15 +182,14 @@ def test_adding_multiple_associations_at_the_same_time(self, test_domain): def test_successful_has_one_initialization_with_a_class_containing_via_and_no_reference( self, test_domain ): - post = PostVia(content="Lorem Ipsum") - test_domain.repository_for(PostVia)._dao.save(post) - comment1 = CommentVia(id=101, content="First Comment", posting_id=post.id) - comment2 = CommentVia(id=102, content="First Comment", posting_id=post.id) - test_domain.repository_for(CommentVia)._dao.save(comment1) - test_domain.repository_for(CommentVia)._dao.save(comment2) - - assert comment1.posting_id == post.id - assert comment2.posting_id == post.id + post = PostVia( + content="Lorem Ipsum", + comments=[ + CommentVia(id=101, content="First Comment"), + CommentVia(id=102, content="Second Comment"), + ], + ) + test_domain.repository_for(PostVia).add(post) refreshed_post = test_domain.repository_for(PostVia)._dao.get(post.id) assert len(refreshed_post.comments) == 2 @@ -182,23 +199,20 @@ def test_successful_has_one_initialization_with_a_class_containing_via_and_no_re assert all( comment.id in [101, 102] for comment in refreshed_post.comments ) # `__iter__` magic here + for comment in refreshed_post.comments: + assert comment.posting_id == post.id def test_successful_has_one_initialization_with_a_class_containing_via_and_reference( self, test_domain ): - post = PostViaWithReference(content="Lorem Ipsum") - test_domain.repository_for(PostViaWithReference)._dao.save(post) - comment1 = CommentViaWithReference( - id=101, content="First Comment", posting=post - ) - comment2 = CommentViaWithReference( - id=102, content="First Comment", posting=post + post = PostViaWithReference( + content="Lorem Ipsum", + comments=[ + CommentViaWithReference(id=101, content="First Comment"), + CommentViaWithReference(id=102, content="First Comment"), + ], ) - test_domain.repository_for(CommentViaWithReference)._dao.save(comment1) - test_domain.repository_for(CommentViaWithReference)._dao.save(comment2) - - assert comment1.posting_id == post.id - assert comment2.posting_id == post.id + test_domain.repository_for(PostViaWithReference).add(post) refreshed_post = test_domain.repository_for(PostViaWithReference)._dao.get( post.id @@ -214,16 +228,14 @@ def test_successful_has_one_initialization_with_a_class_containing_via_and_refer def test_that_subsequent_access_after_first_retrieval_do_not_fetch_record_again( self, test_domain ): - post = PostViaWithReference(content="Lorem Ipsum") - test_domain.repository_for(PostViaWithReference)._dao.save(post) - comment1 = CommentViaWithReference( - id=101, content="First Comment", posting=post - ) - comment2 = CommentViaWithReference( - id=102, content="First Comment", posting=post + post = PostViaWithReference( + content="Lorem Ipsum", + comments=[ + CommentViaWithReference(id=101, content="First Comment"), + CommentViaWithReference(id=102, content="First Comment"), + ], ) - test_domain.repository_for(CommentViaWithReference)._dao.save(comment1) - test_domain.repository_for(CommentViaWithReference)._dao.save(comment2) + test_domain.repository_for(PostViaWithReference).add(post) refreshed_post = test_domain.repository_for(PostViaWithReference)._dao.get( post.id diff --git a/tests/aggregate/test_aggregate_association_dao.py b/tests/aggregate/test_aggregate_association_dao.py index 5c550d48..8f391ba2 100644 --- a/tests/aggregate/test_aggregate_association_dao.py +++ b/tests/aggregate/test_aggregate_association_dao.py @@ -1,3 +1,8 @@ +"""This test file is a mirror image of `test_aggregate_association.py` but testing with DAOs. + +Accessing DAOs and persisting via them is not ideal. This test file is here only to highlight +breakages at the DAO level.""" + import mock import pytest diff --git a/tests/aggregate/test_aggregate_association_via.py b/tests/aggregate/test_aggregate_association_via.py new file mode 100644 index 00000000..06292dff --- /dev/null +++ b/tests/aggregate/test_aggregate_association_via.py @@ -0,0 +1,32 @@ +import pytest + +from protean import BaseAggregate, BaseEntity +from protean.fields import HasOne, Identifier, String + + +class Account(BaseAggregate): + email = Identifier(identifier=True) + profile = HasOne("Profile", via="parent_email") + + +class Profile(BaseEntity): + name = String() + parent_email = Identifier() + + class Meta: + aggregate_cls = Account + + +@pytest.fixture(autouse=True) +def register_elements(test_domain): + test_domain.register(Account) + test_domain.register(Profile) + + +def test_successful_has_one_initialization_with_a_class_containing_via(test_domain): + profile = Profile(name="John Doe") + account = Account(email="john.doe@gmail.com", profile=profile) + test_domain.repository_for(Account).add(account) + + refreshed_account = test_domain.repository_for(Account)._dao.get(account.email) + assert refreshed_account.profile == profile diff --git a/tests/field/test_has_many.py b/tests/field/test_has_many.py new file mode 100644 index 00000000..eb338e9c --- /dev/null +++ b/tests/field/test_has_many.py @@ -0,0 +1,119 @@ +import pytest + +from protean import BaseAggregate, BaseEntity +from protean.fields import HasMany, String +from protean.reflection import attributes, declared_fields + + +class Post(BaseAggregate): + content = String() + comments = HasMany("Comment") + + +class Comment(BaseEntity): + content = String() + + class Meta: + aggregate_cls = Post + + +@pytest.fixture(autouse=True) +def register_elements(test_domain): + test_domain.register(Post) + test_domain.register(Comment) + + +class TestHasManyFields: + def test_that_has_many_field_appears_in_fields(self): + assert "comments" in declared_fields(Post) + + def test_that_has_many_field_does_not_appear_in_attributes(self): + assert "comments" not in attributes(Post) + + def test_that_reference_field_appears_in_fields(self): + assert "post" in declared_fields(Comment) + + def test_that_reference_field_does_not_appear_in_attributes(self): + assert "post" not in attributes(Comment) + + +class TestHasManyPersistence: + def test_that_has_many_field_is_persisted_along_with_aggregate(self, test_domain): + comment = Comment(content="First Comment") + post = Post(content="My Post", comments=[comment]) + + test_domain.repository_for(Post).add(post) + + assert post.id is not None + assert post.comments[0].id is not None + + persisted_post = test_domain.repository_for(Post).get(post.id) + assert persisted_post.comments[0] == comment + assert persisted_post.comments[0].id == comment.id + assert persisted_post.comments[0].content == comment.content + + def test_that_has_many_field_is_persisted_on_aggregate_update(self, test_domain): + post = Post(content="My Post") + test_domain.repository_for(Post).add(post) + + assert post.id is not None + assert len(post.comments) == 0 + + comment = Comment(content="First Comment") + + # Fetch the persisted book and update its author + persisted_post = test_domain.repository_for(Post).get(post.id) + persisted_post.add_comments(comment) + test_domain.repository_for(Post).add(persisted_post) + + # Fetch it again to ensure the author is persisted + persisted_post = test_domain.repository_for(Post).get(post.id) + + # Ensure that the author is persisted along with the book + assert persisted_post.comments[0] == comment + assert persisted_post.comments[0].id == comment.id + assert persisted_post.comments[0].content == comment.content + + def test_that_has_many_field_is_updated_with_new_entity_on_aggregate_update( + self, test_domain + ): + comment = Comment(content="First Comment") + post = Post(content="My Post", comments=[comment]) + + test_domain.repository_for(Post).add(post) + + persisted_post = test_domain.repository_for(Post).get(post.id) + + new_comment = Comment(content="Second Comment") + persisted_post.add_comments(new_comment) + + test_domain.repository_for(Post).add(persisted_post) + + # Fetch the post again to ensure comments are updated + updated_book = test_domain.repository_for(Post).get(persisted_post.id) + assert len(updated_book.comments) == 2 + assert updated_book.comments[0] == comment + assert updated_book.comments[0].id == comment.id + assert updated_book.comments[0].content == comment.content + assert updated_book.comments[1] == new_comment + assert updated_book.comments[1].id == new_comment.id + assert updated_book.comments[1].content == new_comment.content + + def test_that_has_many_field_content_can_be_removed_on_aggregate_update( + self, test_domain + ): + comment = Comment(content="First Comment") + post = Post(content="My Post", comments=[comment]) + + test_domain.repository_for(Post).add(post) + + persisted_post = test_domain.repository_for(Post).get(post.id) + + # Remove the author from the book + persisted_post.remove_comments(comment) + + test_domain.repository_for(Post).add(persisted_post) + + # Fetch the book again to ensure the author is removed + updated_post = test_domain.repository_for(Post).get(persisted_post.id) + assert len(updated_post.comments) == 0