Skip to content

Commit

Permalink
Allow filtering on HasMany entities
Browse files Browse the repository at this point in the history
This commit introduces `get_one_from_<>` and `filter_<>` helper methods
to HasMany relationships.
  • Loading branch information
subhashb committed Jun 5, 2024
1 parent 29830a4 commit 0a7798c
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 7 deletions.
27 changes: 27 additions & 0 deletions docs/guides/domain-definition/fields/association-fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,30 @@ Out[4]:
{'content': 'Qux', 'id': 'b1a7aeda-81ca-4d0b-9d7e-6fe0c000b8af'}],
'id': '29943ac9-a9eb-497b-b6d2-466b30ecd5f5'}
```

You can also use helper methods that begin with `get_one_from_` and `filter_` to filter
for specific entities within the instances.

`get_one_from_` returns a single entity. It raises `ObjectNotFoundError` if no matching
entity for the criteria is found and `TooManyObjectsError` if more than
one entity is found.

`filter` returns a `list` of zero or more matching entities.

```shell hl_lines="9 12"
In [1]: post = Post(
...: title="Foo",
...: comments=[
...: Comment(content="Bar", rating=2.5),
...: Comment(content="Baz", rating=5)
...: ]
...: )

In [2]: post.filter_comments(content="Bar", rating=2.5)
Out[2]: [<Comment: Comment object (id: 3b7fd92e-be11-4b3b-96e9-1caf02779f14)>]

In [3]: comments = post.filter_comments(content="Bar", rating=2.5)

In [4]: comments[0].to_dict()
Out[4]: {'content': 'Bar', 'rating': 2.5, 'id': '3b7fd92e-be11-4b3b-96e9-1caf02779f14'}
```
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from protean import Domain
from protean.fields import HasMany, String, Text
from protean.fields import HasMany, Float, String, Text

domain = Domain(__file__)
domain = Domain(__file__, load_toml=False)


@domain.aggregate
Expand All @@ -14,3 +14,4 @@ class Post:
@domain.entity(part_of=Post)
class Comment:
content = String(required=True, max_length=50)
rating = Float(max_value=5)
2 changes: 1 addition & 1 deletion src/protean/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def __setattr__(self, name, value):
"_owner", # Owner entity in the hierarchy
"_disable_invariant_checks", # Flag to disable invariant checks
]
or name.startswith(("add_", "remove_"))
or name.startswith(("add_", "remove_", "get_one_from_", "filter_"))
):
super().__setattr__(name, value)
else:
Expand Down
14 changes: 10 additions & 4 deletions src/protean/core/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,15 +290,21 @@ def __init__(self, *template, **kwargs): # noqa: C901
if isinstance(field_obj, Association):
getattr(self, field_name) # This refreshes the values in associations

# Set up add and remove methods. These are pseudo methods, `add_*` and
# `remove_*` that point to the HasMany field's `add` and `remove`
# methods. They are wrapped to ensure we pass the object that holds
# the values and temp_cache.
# Set up add and remove methods. These are pseudo methods: `add_*`,
# `remove_*` and `filter_*` that point to the HasMany field's `add`,
# `remove` and `filter` methods. They are wrapped to ensure we pass
# the object that holds the values and temp_cache.
if isinstance(field_obj, HasMany):
setattr(self, f"add_{field_name}", partial(field_obj.add, self))
setattr(
self, f"remove_{field_name}", partial(field_obj.remove, self)
)
setattr(
self, f"get_one_from_{field_name}", partial(field_obj.get, self)
)
setattr(
self, f"filter_{field_name}", partial(field_obj.filter, self)
)

# Now load the remaining fields with a None value, which will fail
# for required fields
Expand Down
41 changes: 41 additions & 0 deletions src/protean/fields/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,3 +673,44 @@ def as_dict(self, value) -> list:
list: A list of dictionaries representing the linked entities.
"""
return [item.to_dict() for item in value]

def get(self, instance, **kwargs):
"""Fetch a single linked entity based on the provided criteria.
Available as `get_one_from_<HasMany Field Name>` method on the entity instance.
Args:
**kwargs: The filtering criteria.
"""
data = self.filter(instance, **kwargs)

if len(data) == 0:
raise exceptions.ObjectNotFoundError(
{"self.field_name": ["No linked entities matching criteria found"]}
)

if len(data) > 1:
raise exceptions.TooManyObjectsError(
{
"self.field_name": [
"Multiple linked entities matching criteria found"
]
}
)

return data[0]

def filter(self, instance, **kwargs):
"""Filter the linked entities based on the provided criteria.
Available as `filter_<HasMany Field Name>` method on the entity instance.
Args:
**kwargs: The filtering criteria.
"""
data = getattr(instance, self.field_name)
return [
item
for item in data
if all(getattr(item, key) == value for key, value in kwargs.items())
]
74 changes: 74 additions & 0 deletions tests/entity/associations/test_has_many_filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest

from datetime import datetime

from protean import BaseEntity, BaseAggregate
from protean.exceptions import ObjectNotFoundError, TooManyObjectsError
from protean.fields import Date, Float, Integer, String, HasMany


class Order(BaseAggregate):
ordered_on = Date()
items = HasMany("OrderItem")


class OrderItem(BaseEntity):
product_id = String(max_length=50)
quantity = Integer()
price = Float()

class Meta:
part_of = Order


@pytest.fixture(autouse=True)
def register_elements(test_domain):
test_domain.register(Order)
test_domain.register(OrderItem)
test_domain.init(traverse=False)


def test_get():
order = Order(
ordered_on=datetime.today().date(),
items=[
OrderItem(product_id="1", quantity=2, price=10.0),
OrderItem(product_id="2", quantity=3, price=15.0),
OrderItem(product_id="3", quantity=2, price=20.0),
],
)

assert order.get_one_from_items(product_id="1").id == order.items[0].id

with pytest.raises(ObjectNotFoundError):
order.get_one_from_items(product_id="4")

with pytest.raises(TooManyObjectsError):
order.get_one_from_items(quantity=2)


def test_filtering():
order = Order(
ordered_on=datetime.today().date(),
items=[
OrderItem(product_id="1", quantity=2, price=10.0),
OrderItem(product_id="2", quantity=3, price=15.0),
OrderItem(product_id="3", quantity=2, price=20.0),
],
)

filtered_item = order.filter_items(product_id="1")
assert len(filtered_item) == 1
assert filtered_item[0].id == order.items[0].id

filtered_items = order.filter_items(quantity=2)
assert len(filtered_items) == 2
assert filtered_items[0].id == order.items[0].id
assert filtered_items[1].id == order.items[2].id

filtered_items = order.filter_items(quantity=2, price=20.0)
assert len(filtered_items) == 1
assert filtered_items[0].id == order.items[2].id

filtered_items = order.filter_items(quantity=3, price=40.0)
assert len(filtered_items) == 0

0 comments on commit 0a7798c

Please sign in to comment.