From e32251e41872fe8da3144bf6cb82b535af4759ab Mon Sep 17 00:00:00 2001 From: Subhash Bhushan Date: Mon, 4 Oct 2021 16:50:25 -0700 Subject: [PATCH] Fix issue with deriving embedded field types Fields inside `ValueObject` fields were always considered Strings when deriving Postgres mapping. This commit considers the field object embedded within `_ShadowField` to derive the appropriate field type when generating model mappings. --- src/protean/adapters/repository/sqlalchemy.py | 8 +++++- .../sqlalchemy_repo/postgresql/conftest.py | 4 ++- .../postgresql/test_associations.py | 25 +++++++++++++++++-- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/protean/adapters/repository/sqlalchemy.py b/src/protean/adapters/repository/sqlalchemy.py index 32d45da1..7cab5b72 100644 --- a/src/protean/adapters/repository/sqlalchemy.py +++ b/src/protean/adapters/repository/sqlalchemy.py @@ -31,6 +31,7 @@ Text, ) from protean.fields.association import Reference, _ReferenceField +from protean.fields.embedded import _ShadowField from protean.globals import current_domain, current_uow from protean.port.dao import BaseDAO, BaseLookup, ResultSet from protean.port.provider import BaseProvider @@ -51,6 +52,7 @@ class GUID(TypeDecorator): """ impl = CHAR + cache_ok = True def load_dialect_impl(self, dialect): if dialect.name == "postgresql": @@ -104,7 +106,7 @@ def _get_identity_type(): class DeclarativeMeta(sa_dec.DeclarativeMeta, ABCMeta): """ Metaclass for the Sqlalchemy declarative schema """ - def __init__(cls, classname, bases, dict_): + def __init__(cls, classname, bases, dict_): # noqa: C901 # Update the class attrs with the entity attributes field_mapping = { @@ -129,6 +131,10 @@ def __init__(cls, classname, bases, dict_): # Map the field if not in attributes if attribute_name not in cls.__dict__: + # Derive field based on field enclosed within ShadowField + if isinstance(field_obj, _ShadowField): + field_obj = field_obj.field_obj + field_cls = type(field_obj) type_args = [] type_kwargs = {} diff --git a/tests/adapters/repository/sqlalchemy_repo/postgresql/conftest.py b/tests/adapters/repository/sqlalchemy_repo/postgresql/conftest.py index d032690b..816aefa9 100644 --- a/tests/adapters/repository/sqlalchemy_repo/postgresql/conftest.py +++ b/tests/adapters/repository/sqlalchemy_repo/postgresql/conftest.py @@ -31,7 +31,7 @@ def setup_db(): with domain.domain_context(): # Create all associated tables from .elements import Alien, ComplexUser, Person, User - from .test_associations import Comment, Post + from .test_associations import Comment, Post, Audit from .test_persistence import Event domain.register(Alien) @@ -41,6 +41,7 @@ def setup_db(): domain.register(User) domain.register(Post) domain.register(Comment) + domain.register(Audit) domain.get_dao(Alien) domain.get_dao(ComplexUser) @@ -49,6 +50,7 @@ def setup_db(): domain.get_dao(User) domain.get_dao(Post) domain.get_dao(Comment) + domain.get_dao(Audit) for provider in domain.providers_list(): provider._metadata.create_all() diff --git a/tests/adapters/repository/sqlalchemy_repo/postgresql/test_associations.py b/tests/adapters/repository/sqlalchemy_repo/postgresql/test_associations.py index 9f4649c9..8069119a 100644 --- a/tests/adapters/repository/sqlalchemy_repo/postgresql/test_associations.py +++ b/tests/adapters/repository/sqlalchemy_repo/postgresql/test_associations.py @@ -2,8 +2,8 @@ import pytest -from protean import BaseAggregate, BaseEntity, UnitOfWork -from protean.fields import DateTime, HasMany, Reference, Text +from protean import BaseAggregate, BaseEntity, BaseValueObject, UnitOfWork +from protean.fields import DateTime, Dict, HasMany, Reference, Text, ValueObject class Comment(BaseEntity): @@ -16,11 +16,19 @@ class Meta: aggregate_cls = "Post" +class Permission(BaseValueObject): + dict_object = Dict() + + class Post(BaseAggregate): content = Text(required=True) comments = HasMany(Comment) +class Audit(BaseAggregate): + permission = ValueObject(Permission) + + @pytest.mark.postgresql def test_updating_a_has_many_association(test_domain): test_domain.register(Post) @@ -41,3 +49,16 @@ def test_updating_a_has_many_association(test_domain): refreshed_comment.content = "baz" refreshed_post.add_comments(refreshed_comment) post_repo.add(refreshed_post) + + +@pytest.mark.postgresql +def test_embedded_dict_field_in_value_object(test_domain): + test_domain.register(Audit) + + audit_repo = test_domain.repository_for(Audit) + audit = Audit(permission=Permission(dict_object={"foo": "bar"})) + audit_repo.add(audit) + + assert test_domain.get_dao(Audit).get(audit.id).permission_dict_object == { + "foo": "bar" + }