Skip to content

Commit

Permalink
Merge branch 'main' of github.com:proteanhq/protean
Browse files Browse the repository at this point in the history
  • Loading branch information
subhashb committed Jun 7, 2024
2 parents ad69a2e + c0b793a commit efa01dd
Show file tree
Hide file tree
Showing 10 changed files with 1,141 additions and 123 deletions.
1,015 changes: 1,015 additions & 0 deletions mega_change.patch

Large diffs are not rendered by default.

130 changes: 69 additions & 61 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ werkzeug = ">=2.0.0"
elasticsearch = {version = "~7.17.9", optional = true}
elasticsearch-dsl = {version = "~7.4.1", optional = true}
redis = {version = "~3.5.2", optional = true}
sqlalchemy = {version = "~1.4.50", optional = true}
sqlalchemy = {version = "~2.0.30", optional = true}
psycopg2 = {version = ">=2.9.9", optional = true}
celery = { version = "~5.2.7", extras = ["redis"], optional = true}
flask = {version = ">=1.1.1", optional = true}
Expand Down
86 changes: 36 additions & 50 deletions src/protean/adapters/repository/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,15 @@
import logging
import uuid

from abc import ABCMeta
from enum import Enum
from typing import Any

import sqlalchemy.dialects.postgresql as psql

from sqlalchemy import Column, MetaData, and_, create_engine, or_, orm
from sqlalchemy import Column, MetaData, and_, create_engine, or_, orm, text
from sqlalchemy import types as sa_types
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import DatabaseError
from sqlalchemy.ext import declarative as sa_dec
from sqlalchemy.ext.declarative import as_declarative, declared_attr
from sqlalchemy.types import CHAR, TypeDecorator

from protean.core.value_object import BaseValueObject
Expand Down Expand Up @@ -134,10 +131,21 @@ def _custom_json_dumps(value):
return json.dumps(value, default=_default)


class DeclarativeMeta(sa_dec.DeclarativeMeta, ABCMeta):
"""Metaclass for the Sqlalchemy declarative schema"""
def derive_schema_name(model_cls):
# Retain schema name if already present, otherwise derive from entity class
if (
hasattr(model_cls.meta_, "schema_name")
and model_cls.meta_.schema_name is not None
):
return model_cls.meta_.schema_name
else:
return model_cls.meta_.entity_cls.meta_.schema_name

def __init__(cls, classname, bases, dict_): # noqa: C901

class SqlalchemyModel(orm.DeclarativeBase, BaseModel):
"""Model representation for the Sqlalchemy Database"""

def __init_subclass__(subclass, **kwargs): # noqa: C901
field_mapping = {
Boolean: sa_types.Boolean,
Date: sa_types.Date,
Expand Down Expand Up @@ -166,13 +174,13 @@ def field_mapping_for(field_obj: Field):
return field_mapping.get(field_cls)

# Update the class attrs with the entity attributes
if "meta_" in dict_:
entity_cls = dict_["meta_"].entity_cls
if "meta_" in subclass.__dict__:
entity_cls = subclass.__dict__["meta_"].entity_cls
for _, field_obj in attributes(entity_cls).items():
attribute_name = field_obj.attribute_name

# Map the field if not in attributes
if attribute_name not in cls.__dict__:
if attribute_name not in subclass.__dict__:
# Derive field based on field enclosed within ShadowField
if isinstance(field_obj, _ShadowField):
field_obj = field_obj.field_obj
Expand All @@ -185,7 +193,7 @@ def field_mapping_for(field_obj: Field):
sa_type_cls = field_mapping_for(field_obj)

# Upgrade to Postgresql specific Data Types
if cls.metadata.bind.dialect.name == "postgresql":
if subclass.__dict__["engine"].dialect.name == "postgresql":
if field_cls == Dict and not field_obj.pickled:
sa_type_cls = psql.JSON

Expand Down Expand Up @@ -228,37 +236,12 @@ def field_mapping_for(field_obj: Field):
type_kwargs["length"] = field_obj.max_length

# Update the attributes of the class
# SQLAlchemy changed where `DeclarativeMeta.__init__` looks for class attributes
# between versions 1.3 and 1.4, and then fixated on `cls.__dict__` in 1.4.36.
# While the `setattr` below works with the latest SQLAlchemy version of 1.4.36,
# it's better to populate both `dict_` as well as `cls.__dict__` to be compatible
# with all 1.4.10+ versions of SQLAlchemy.
# More info:
# * https://github.com/sqlalchemy/sqlalchemy/issues/6791
# * https://github.com/sqlalchemy/sqlalchemy/issues/7900
column = Column(sa_type_cls(*type_args, **type_kwargs), **col_args)
setattr(cls, attribute_name, column) # Set class attribute
dict_[attribute_name] = column # Set in dict_

super().__init__(classname, bases, dict_)
setattr(subclass, attribute_name, column) # Set class attribute

super().__init_subclass__(**kwargs)

def derive_schema_name(model_cls):
# Retain schema name if already present, otherwise derive from entity class
if (
hasattr(model_cls.meta_, "schema_name")
and model_cls.meta_.schema_name is not None
):
return model_cls.meta_.schema_name
else:
return model_cls.meta_.entity_cls.meta_.schema_name


@as_declarative(metaclass=DeclarativeMeta)
class SqlalchemyModel(BaseModel):
"""Model representation for the Sqlalchemy Database"""

@declared_attr
@orm.declared_attr
def __tablename__(cls):
return derive_schema_name(cls)

Expand Down Expand Up @@ -516,7 +499,7 @@ def _raw(self, query: Any, data: Any = None):

conn = self._get_session()
try:
results = conn.execute(query)
results = conn.execute(text(query))

entity_items = []
for item in results:
Expand Down Expand Up @@ -578,9 +561,9 @@ def __init__(self, *args, **kwargs):
self.conn_info["SCHEMA"] if "SCHEMA" in self.conn_info else "public"
)

self._metadata = MetaData(bind=self._engine, schema=schema)
self._metadata = MetaData(schema=schema)
else:
self._metadata = MetaData(bind=self._engine)
self._metadata = MetaData()

# A temporary cache of already constructed model classes
self._model_classes = {}
Expand All @@ -604,7 +587,7 @@ def _get_database_specific_session_args(self):
Return: a dictionary with additional arguments and values.
"""
if self.conn_info["database"] == self.databases.postgresql.value:
return {"autocommit": True, "autoflush": False}
return {"autoflush": False}

return {}

Expand Down Expand Up @@ -632,7 +615,7 @@ def _execute_database_specific_connection_statements(self, conn):
Return: None
"""
if self.conn_info["database"] == self.databases.sqlite.value:
conn.execute("PRAGMA case_sensitive_like = ON;")
conn.execute(text("PRAGMA case_sensitive_like = ON;"))

return conn

Expand All @@ -655,13 +638,13 @@ def _data_reset(self):
transaction = conn.begin()

if self.conn_info["database"] == self.databases.sqlite.value:
conn.execute("PRAGMA foreign_keys = OFF;")
conn.execute(text("PRAGMA foreign_keys = OFF;"))

for table in self._metadata.sorted_tables:
conn.execute(table.delete())

if self.conn_info["database"] == self.databases.sqlite.value:
conn.execute("PRAGMA foreign_keys = ON;")
conn.execute(text("PRAGMA foreign_keys = ON;"))

transaction.commit()

Expand All @@ -673,10 +656,10 @@ def _create_database_artifacts(self):
for _, aggregate_record in self.domain.registry.aggregates.items():
self.domain.repository_for(aggregate_record.cls)._dao

self._metadata.create_all()
self._metadata.create_all(self._engine)

def _drop_database_artifacts(self):
self._metadata.drop_all()
self._metadata.drop_all(self._engine)
self._metadata.clear()

def decorate_model_class(self, entity_cls, model_cls):
Expand Down Expand Up @@ -720,7 +703,9 @@ def decorate_model_class(self, entity_cls, model_cls):
schema_name if meta_.schema_name is None else meta_.schema_name
)

custom_attrs.update({"meta_": meta_, "metadata": self._metadata})
custom_attrs.update(
{"meta_": meta_, "engine": self._engine, "metadata": self._metadata}
)
# FIXME Ensure the custom model attributes are constructed properly
decorated_model_cls = type(
model_cls.__name__, (SqlalchemyModel, model_cls), custom_attrs
Expand Down Expand Up @@ -752,6 +737,7 @@ def construct_model_class(self, entity_cls):

attrs = {
"meta_": meta_,
"engine": self._engine,
"metadata": self._metadata,
}
# FIXME Ensure the custom model attributes are constructed properly
Expand All @@ -774,7 +760,7 @@ def raw(self, query: Any, data: Any = None):
assert isinstance(query, str)
assert isinstance(data, (dict, None))

return self.get_connection().execute(query, data)
return self.get_connection().execute(text(query), data)


operators = {
Expand Down
6 changes: 4 additions & 2 deletions tests/adapters/model/sqlalchemy_model/postgresql/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ def setup_db():
domain.repository_for(ListUser)._dao
domain.repository_for(IntegerListUser)._dao

domain.providers["default"]._metadata.create_all()
default_provider = domain.providers["default"]
default_provider._metadata.create_all(default_provider._engine)

yield

# Drop all tables at the end of test suite
domain.providers["default"]._metadata.drop_all()
default_provider = domain.providers["default"]
default_provider._metadata.drop_all(default_provider._engine)
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ class ReceiverInlineModel:

test_domain.repository_for(Receiver)._dao # Registers and refreshes DB objects

provider = test_domain.providers["default"]
provider._metadata.create_all()
default_provider = test_domain.providers["default"]
default_provider._metadata.create_all(default_provider._engine)

model_cls = test_domain.repository_for(Receiver)._model
assert model_cls.__name__ == "ReceiverInlineModel"
Expand Down
6 changes: 4 additions & 2 deletions tests/adapters/model/sqlalchemy_model/sqlite/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@ def setup_db():
domain.repository_for(Provider)._dao
domain.repository_for(User)._dao

domain.providers["default"]._metadata.create_all()
default_provider = domain.providers["default"]
default_provider._metadata.create_all(default_provider._engine)

yield

# Drop all tables at the end of test suite
domain.providers["default"]._metadata.drop_all()
default_provider = domain.providers["default"]
default_provider._metadata.drop_all(default_provider._engine)
3 changes: 2 additions & 1 deletion tests/adapters/model/sqlalchemy_model/sqlite/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ class ReceiverInlineModel:

test_domain.repository_for(Receiver)._dao

test_domain.providers["default"]._metadata.create_all()
default_provider = test_domain.providers["default"]
default_provider._metadata.create_all(default_provider._engine)

model_cls = test_domain.repository_for(Receiver)._model
assert model_cls.__name__ == "ReceiverInlineModel"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ def setup_db():
domain.repository_for(Customer)._dao
domain.repository_for(Order)._dao

domain.providers["default"]._metadata.create_all()
default_provider = domain.providers["default"]
default_provider._metadata.create_all(default_provider._engine)

yield

# Drop all tables at the end of test suite
domain.providers["default"]._metadata.drop_all()
default_provider = domain.providers["default"]
default_provider._metadata.drop_all(default_provider._engine)
6 changes: 4 additions & 2 deletions tests/adapters/repository/sqlalchemy_repo/sqlite/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ def setup_db():
domain.repository_for(User)._dao
domain.repository_for(ComplexUser)._dao

domain.providers["default"]._metadata.create_all()
default_provider = domain.providers["default"]
default_provider._metadata.create_all(default_provider._engine)

yield

# Drop all tables at the end of test suite
domain.providers["default"]._metadata.drop_all()
default_provider = domain.providers["default"]
default_provider._metadata.drop_all(default_provider._engine)

0 comments on commit efa01dd

Please sign in to comment.