From 0a7798cfddefd166d6dd85c48d9a22106dc0fd1c Mon Sep 17 00:00:00 2001 From: Subhash Bhushan Date: Wed, 5 Jun 2024 11:00:49 -0700 Subject: [PATCH] Allow filtering on `HasMany` entities This commit introduces `get_one_from_<>` and `filter_<>` helper methods to HasMany relationships. --- .../fields/association-fields.md | 27 +++++++ .../fields/association-fields/002.py | 5 +- src/protean/container.py | 2 +- src/protean/core/entity.py | 14 +++- src/protean/fields/association.py | 41 ++++++++++ .../associations/test_has_many_filtering.py | 74 +++++++++++++++++++ 6 files changed, 156 insertions(+), 7 deletions(-) create mode 100644 tests/entity/associations/test_has_many_filtering.py diff --git a/docs/guides/domain-definition/fields/association-fields.md b/docs/guides/domain-definition/fields/association-fields.md index b8afb8d8..eeb88bb9 100644 --- a/docs/guides/domain-definition/fields/association-fields.md +++ b/docs/guides/domain-definition/fields/association-fields.md @@ -109,3 +109,30 @@ Out[4]: {'content': 'Qux', 'id': 'b1a7aeda-81ca-4d0b-9d7e-6fe0c000b8af'}], 'id': '29943ac9-a9eb-497b-b6d2-466b30ecd5f5'} ``` + +You can also use helper methods that begin with `get_one_from_` and `filter_` to filter +for specific entities within the instances. + +`get_one_from_` returns a single entity. It raises `ObjectNotFoundError` if no matching +entity for the criteria is found and `TooManyObjectsError` if more than +one entity is found. + +`filter` returns a `list` of zero or more matching entities. + +```shell hl_lines="9 12" +In [1]: post = Post( + ...: title="Foo", + ...: comments=[ + ...: Comment(content="Bar", rating=2.5), + ...: Comment(content="Baz", rating=5) + ...: ] + ...: ) + +In [2]: post.filter_comments(content="Bar", rating=2.5) +Out[2]: [] + +In [3]: comments = post.filter_comments(content="Bar", rating=2.5) + +In [4]: comments[0].to_dict() +Out[4]: {'content': 'Bar', 'rating': 2.5, 'id': '3b7fd92e-be11-4b3b-96e9-1caf02779f14'} +``` \ No newline at end of file diff --git a/docs_src/guides/domain-definition/fields/association-fields/002.py b/docs_src/guides/domain-definition/fields/association-fields/002.py index 769da60d..67fc8f51 100644 --- a/docs_src/guides/domain-definition/fields/association-fields/002.py +++ b/docs_src/guides/domain-definition/fields/association-fields/002.py @@ -1,7 +1,7 @@ from protean import Domain -from protean.fields import HasMany, String, Text +from protean.fields import HasMany, Float, String, Text -domain = Domain(__file__) +domain = Domain(__file__, load_toml=False) @domain.aggregate @@ -14,3 +14,4 @@ class Post: @domain.entity(part_of=Post) class Comment: content = String(required=True, max_length=50) + rating = Float(max_value=5) diff --git a/src/protean/container.py b/src/protean/container.py index 35d106d3..920d0cb3 100644 --- a/src/protean/container.py +++ b/src/protean/container.py @@ -323,7 +323,7 @@ def __setattr__(self, name, value): "_owner", # Owner entity in the hierarchy "_disable_invariant_checks", # Flag to disable invariant checks ] - or name.startswith(("add_", "remove_")) + or name.startswith(("add_", "remove_", "get_one_from_", "filter_")) ): super().__setattr__(name, value) else: diff --git a/src/protean/core/entity.py b/src/protean/core/entity.py index e2b9d2dd..923cfc3f 100644 --- a/src/protean/core/entity.py +++ b/src/protean/core/entity.py @@ -290,15 +290,21 @@ def __init__(self, *template, **kwargs): # noqa: C901 if isinstance(field_obj, Association): getattr(self, field_name) # This refreshes the values in associations - # Set up add and remove methods. These are pseudo methods, `add_*` and - # `remove_*` that point to the HasMany field's `add` and `remove` - # methods. They are wrapped to ensure we pass the object that holds - # the values and temp_cache. + # Set up add and remove methods. These are pseudo methods: `add_*`, + # `remove_*` and `filter_*` that point to the HasMany field's `add`, + # `remove` and `filter` methods. They are wrapped to ensure we pass + # the object that holds the values and temp_cache. if isinstance(field_obj, HasMany): setattr(self, f"add_{field_name}", partial(field_obj.add, self)) setattr( self, f"remove_{field_name}", partial(field_obj.remove, self) ) + setattr( + self, f"get_one_from_{field_name}", partial(field_obj.get, self) + ) + setattr( + self, f"filter_{field_name}", partial(field_obj.filter, self) + ) # Now load the remaining fields with a None value, which will fail # for required fields diff --git a/src/protean/fields/association.py b/src/protean/fields/association.py index dcd23aa2..174dcfa1 100644 --- a/src/protean/fields/association.py +++ b/src/protean/fields/association.py @@ -673,3 +673,44 @@ def as_dict(self, value) -> list: list: A list of dictionaries representing the linked entities. """ return [item.to_dict() for item in value] + + def get(self, instance, **kwargs): + """Fetch a single linked entity based on the provided criteria. + + Available as `get_one_from_` method on the entity instance. + + Args: + **kwargs: The filtering criteria. + """ + data = self.filter(instance, **kwargs) + + if len(data) == 0: + raise exceptions.ObjectNotFoundError( + {"self.field_name": ["No linked entities matching criteria found"]} + ) + + if len(data) > 1: + raise exceptions.TooManyObjectsError( + { + "self.field_name": [ + "Multiple linked entities matching criteria found" + ] + } + ) + + return data[0] + + def filter(self, instance, **kwargs): + """Filter the linked entities based on the provided criteria. + + Available as `filter_` method on the entity instance. + + Args: + **kwargs: The filtering criteria. + """ + data = getattr(instance, self.field_name) + return [ + item + for item in data + if all(getattr(item, key) == value for key, value in kwargs.items()) + ] diff --git a/tests/entity/associations/test_has_many_filtering.py b/tests/entity/associations/test_has_many_filtering.py new file mode 100644 index 00000000..cc572b23 --- /dev/null +++ b/tests/entity/associations/test_has_many_filtering.py @@ -0,0 +1,74 @@ +import pytest + +from datetime import datetime + +from protean import BaseEntity, BaseAggregate +from protean.exceptions import ObjectNotFoundError, TooManyObjectsError +from protean.fields import Date, Float, Integer, String, HasMany + + +class Order(BaseAggregate): + ordered_on = Date() + items = HasMany("OrderItem") + + +class OrderItem(BaseEntity): + product_id = String(max_length=50) + quantity = Integer() + price = Float() + + class Meta: + part_of = Order + + +@pytest.fixture(autouse=True) +def register_elements(test_domain): + test_domain.register(Order) + test_domain.register(OrderItem) + test_domain.init(traverse=False) + + +def test_get(): + order = Order( + ordered_on=datetime.today().date(), + items=[ + OrderItem(product_id="1", quantity=2, price=10.0), + OrderItem(product_id="2", quantity=3, price=15.0), + OrderItem(product_id="3", quantity=2, price=20.0), + ], + ) + + assert order.get_one_from_items(product_id="1").id == order.items[0].id + + with pytest.raises(ObjectNotFoundError): + order.get_one_from_items(product_id="4") + + with pytest.raises(TooManyObjectsError): + order.get_one_from_items(quantity=2) + + +def test_filtering(): + order = Order( + ordered_on=datetime.today().date(), + items=[ + OrderItem(product_id="1", quantity=2, price=10.0), + OrderItem(product_id="2", quantity=3, price=15.0), + OrderItem(product_id="3", quantity=2, price=20.0), + ], + ) + + filtered_item = order.filter_items(product_id="1") + assert len(filtered_item) == 1 + assert filtered_item[0].id == order.items[0].id + + filtered_items = order.filter_items(quantity=2) + assert len(filtered_items) == 2 + assert filtered_items[0].id == order.items[0].id + assert filtered_items[1].id == order.items[2].id + + filtered_items = order.filter_items(quantity=2, price=20.0) + assert len(filtered_items) == 1 + assert filtered_items[0].id == order.items[2].id + + filtered_items = order.filter_items(quantity=3, price=40.0) + assert len(filtered_items) == 0