Skip to content

Commit

Permalink
Streamline sqlalchemy providers config
Browse files Browse the repository at this point in the history
This commit changes the provider for Postgresql and SQLite databases from
`sqlalchemy` to `postgresql` and `sqlite` respectively. This allows for
greater freedom in changing underlying structures and API in the future,
and also streamlines the database config in sync with other database providers.
  • Loading branch information
subhashb committed Jul 8, 2024
1 parent 4bd2bd1 commit d48fa2d
Show file tree
Hide file tree
Showing 22 changed files with 55 additions and 65 deletions.
3 changes: 1 addition & 2 deletions docs/guides/compose-a-domain/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ provider = "memory"
provider = "memory"

[databases.sqlite]
provider = "sqlalchemy"
database = "sqlite"
provider = "sqlite"
database_uri = "sqlite:///test.db"
```

Expand Down
8 changes: 4 additions & 4 deletions src/protean/adapters/repository/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

DATABASE_PROVIDERS = {
"memory": "protean.adapters.MemoryProvider",
"sqlalchemy": "protean.adapters.repository.sqlalchemy.SAProvider",
"postgresql": "protean.adapters.repository.sqlalchemy.PostgresqlProvider",
"sqlite": "protean.adapters.repository.sqlalchemy.SqliteProvider",
"elasticsearch": "protean.adapters.repository.elasticsearch.ESProvider",
}

Expand Down Expand Up @@ -67,8 +68,7 @@ def _register_repository(self, part_of, repository_cls):
# For example, with the following PostgreSQL configuration:
# databases = {
# "default": {
# "provider": "sqlalchemy",
# "database": "postgresql",
# "provider": "postgresql",
# "database_uri": "postgresql://postgres:postgres@localhost:5432/postgres",
# },
# }
Expand Down Expand Up @@ -130,7 +130,7 @@ def repository_for(self, part_of) -> BaseRepository:

provider_name = part_of.meta_.provider
provider = self._providers[provider_name]
database = provider.conn_info["database"]
database = provider.__class__.__database__

aggregate_name = fully_qualified_name(part_of)

Expand Down
24 changes: 14 additions & 10 deletions src/protean/adapters/repository/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,12 @@ def _raw(self, query: Any, data: Any = None):


class ESProvider(BaseProvider):
__database__ = "elasticsearch"

def __init__(self, name, domain, conn_info: dict):
"""Initialize Provider with Connection/Adapter details"""

# In case of `ESProvider`, the `database` value will always be `ELASTICSEARCH`.
conn_info["database"] = "elasticsearch"
conn_info["database_uri"] = json.loads(conn_info["database_uri"])
super().__init__(name, domain, conn_info)

Expand Down Expand Up @@ -493,9 +494,10 @@ def _data_reset(self):
repo = self.domain.repository_for(element_record.cls)

model_cls = repo._model
if provider.conn_info[
"database"
] == "elasticsearch" and conn.indices.exists(index=model_cls._index._name):
if (
provider.__class__.__database__ == "elasticsearch"
and conn.indices.exists(index=model_cls._index._name)
):
conn.delete_by_query(
refresh=True,
index=model_cls._index._name,
Expand All @@ -513,9 +515,10 @@ def _create_database_artifacts(self):
for _, element_record in elements.items():
provider = current_domain.providers[element_record.cls.meta_.provider]
model_cls = current_domain.repository_for(element_record.cls)._model
if provider.conn_info[
"database"
] == "elasticsearch" and not model_cls._index.exists(using=conn):
if (
provider.__class__.__database__ == "elasticsearch"
and not model_cls._index.exists(using=conn)
):
# We use model_cls here to ensure the index is created along with mappings
model_cls.init(using=conn)

Expand All @@ -530,9 +533,10 @@ def _drop_database_artifacts(self):
for _, element_record in elements.items():
model_cls = self.domain.repository_for(element_record.cls)._model
provider = self.domain.providers[element_record.cls.meta_.provider]
if provider.conn_info[
"database"
] == "elasticsearch" and model_cls._index.exists(using=conn):
if (
provider.__class__.__database__ == "elasticsearch"
and model_cls._index.exists(using=conn)
):
conn.indices.delete(index=model_cls._index._name)


Expand Down
3 changes: 2 additions & 1 deletion src/protean/adapters/repository/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,12 @@ def close(self):
class MemoryProvider(BaseProvider):
"""Provider class for Dict Repositories"""

__database__ = "memory"

def __init__(self, name, domain, conn_info: dict):
"""Initialize Provider with Connection/Adapter details"""

# In case of `MemoryProvider`, the `database` value will always be `memory`.
conn_info["database"] = "memory"
super().__init__(name, domain, conn_info)

# Global in-memory store of dict data.
Expand Down
27 changes: 14 additions & 13 deletions src/protean/adapters/repository/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,13 +527,6 @@ class databases(Enum):

def __init__(self, *args, **kwargs):
"""Initialize and maintain Engine"""
# Since SQLAlchemyProvider can cater to multiple databases, it is important
# that we know which database we are dealing with, to run database-specific
# statements like `PRAGMA` for SQLite.
if "database" not in args[2]:
logger.error(f"Missing `database` information in conn_info: {args[2]}")
raise ConfigurationError("Missing `database` attribute in Connection info")

super().__init__(*args, **kwargs)

kwargs = self._get_database_specific_engine_args()
Expand All @@ -544,7 +537,7 @@ def __init__(self, *args, **kwargs):
**kwargs,
)

if self.conn_info["database"] == self.databases.postgresql.value:
if self.__database__ == self.databases.postgresql.value:
# Nest database tables under a schema, so that we have complete control
# on creating/dropping db structures. We cannot control structures in the
# the default `public` schema.
Expand All @@ -567,7 +560,7 @@ def _get_database_specific_engine_args(self):
Return: a dictionary with database-specific SQLAlchemy Engine arguments.
"""
if self.conn_info["database"] == self.databases.postgresql.value:
if self.__database__ == self.databases.postgresql.value:
return {"isolation_level": "AUTOCOMMIT"}

return {}
Expand All @@ -580,7 +573,7 @@ def _get_database_specific_session_args(self):
Return: a dictionary with additional arguments and values.
"""
if self.conn_info["database"] == self.databases.postgresql.value:
if self.__database__ == self.databases.postgresql.value:
return {"autoflush": False}

return {}
Expand Down Expand Up @@ -608,7 +601,7 @@ def _execute_database_specific_connection_statements(self, conn):
Return: None
"""
if self.conn_info["database"] == self.databases.sqlite.value:
if self.__database__ == self.databases.sqlite.value:
conn.execute(text("PRAGMA case_sensitive_like = ON;"))

return conn
Expand All @@ -631,13 +624,13 @@ def _data_reset(self):

transaction = conn.begin()

if self.conn_info["database"] == self.databases.sqlite.value:
if self.__database__ == self.databases.sqlite.value:
conn.execute(text("PRAGMA foreign_keys = OFF;"))

for table in self._metadata.sorted_tables:
conn.execute(table.delete())

if self.conn_info["database"] == self.databases.sqlite.value:
if self.__database__ == self.databases.sqlite.value:
conn.execute(text("PRAGMA foreign_keys = ON;"))

transaction.commit()
Expand Down Expand Up @@ -753,6 +746,14 @@ def raw(self, query: Any, data: Any = None):
return self.get_connection().execute(text(query), data)


class PostgresqlProvider(SAProvider):
__database__ = SAProvider.databases.postgresql.value


class SqliteProvider(SAProvider):
__database__ = SAProvider.databases.sqlite.value


operators = {
"exact": "__eq__",
"iexact": "ilike",
Expand Down
2 changes: 1 addition & 1 deletion src/protean/core/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _model(self):
# If no database is specified, model can be used for all databases
if custom_model_cls and (
custom_model_cls.meta_.database is None
or custom_model_cls.meta_.database == self._provider.conn_info["database"]
or custom_model_cls.meta_.database == self._provider.__class__.__database__
):
# Get the decorated model class.
# This is a no-op if the provider decides that the model is fully-baked
Expand Down
3 changes: 1 addition & 2 deletions tests/adapters/model/sqlalchemy_model/postgresql/domain.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
[databases.default]
provider = "sqlalchemy"
database = "postgresql"
provider = "postgresql"
database_uri = "postgresql://postgres:postgres@localhost:5432/postgres"
3 changes: 1 addition & 2 deletions tests/adapters/model/sqlalchemy_model/sqlite/domain.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
[databases.default]
provider = "sqlalchemy"
database = "sqlite"
provider = "sqlite"
database_uri = "sqlite:///test.db"
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
identity_type = "uuid"

[databases.default]
provider = "sqlalchemy"
database = "postgresql"
provider = "postgresql"
database_uri = "postgresql://postgres:postgres@localhost:5432/postgres"
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy.orm.session import Session

from protean.adapters.repository import Providers
from protean.adapters.repository.sqlalchemy import SAProvider
from protean.adapters.repository.sqlalchemy import PostgresqlProvider

from .elements import Alien, Person

Expand All @@ -30,7 +30,7 @@ def test_provider_detail(self, test_domain):
"""Test provider info loaded for tests"""

provider1 = test_domain.providers["default"]
assert isinstance(provider1, SAProvider)
assert isinstance(provider1, PostgresqlProvider)

def test_provider_get_connection(self, test_domain):
"""Test ``get_connection`` method and check for connection details"""
Expand Down
3 changes: 1 addition & 2 deletions tests/adapters/repository/sqlalchemy_repo/sqlite/domain.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
[databases.default]
provider = "sqlalchemy"
database = "sqlite"
provider = "sqlite"
database_uri = "sqlite:///test.db"
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy.orm.session import Session

from protean.adapters.repository import Providers
from protean.adapters.repository.sqlalchemy import SAProvider
from protean.adapters.repository.sqlalchemy import SqliteProvider

from .elements import Alien, Person

Expand All @@ -30,7 +30,7 @@ def test_provider_detail(self, test_domain):
"""Test provider info loaded for tests"""

provider1 = test_domain.providers["default"]
assert isinstance(provider1, SAProvider)
assert isinstance(provider1, SqliteProvider)

def test_provider_get_connection(self, test_domain):
"""Test ``get_connection`` method and check for connection details"""
Expand Down
3 changes: 1 addition & 2 deletions tests/cli/test_generate_docker_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def test_correct_config_is_loaded(self):
domain.init()
assert domain is not None
assert domain.name == "SQLite-Domain"
assert domain.providers["default"].conn_info["provider"] == "sqlalchemy"
assert domain.providers["default"].conn_info["database"] == "sqlite"
assert domain.providers["default"].conn_info["provider"] == "sqlite"
assert domain.providers["default"]._engine.url.database == ":memory:"
assert domain.providers["default"]._engine.url.drivername == "sqlite"

Expand Down
3 changes: 1 addition & 2 deletions tests/config/domain.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ provider = "memory"
provider = "memory"

[databases.sqlite]
provider = "sqlalchemy"
database = "sqlite"
provider = "sqlite"
database_uri = "sqlite:///test.db"

[brokers.default]
Expand Down
6 changes: 2 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,7 @@ def db_config(request):
return {
"MEMORY": {"provider": "memory"},
"POSTGRESQL": {
"provider": "sqlalchemy",
"database": "postgresql",
"provider": "postgresql",
"database_uri": "postgresql://postgres:postgres@localhost:5432/postgres",
},
"ELASTICSEARCH": {
Expand All @@ -163,8 +162,7 @@ def db_config(request):
"database_uri": {"hosts": ["localhost"]},
},
"SQLITE": {
"provider": "sqlalchemy",
"database": "sqlite",
"provider": "sqlite",
"database_uri": "sqlite:///test.db",
},
}[request.config.getoption("--db", "MEMORY")]
Expand Down
3 changes: 1 addition & 2 deletions tests/domain.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ provider = "memory"
provider = "memory"

[databases.sqlite]
provider = "sqlalchemy"
database = "sqlite"
provider = "sqlite"
database_uri = "sqlite:///test.db"

[brokers.default]
Expand Down
3 changes: 1 addition & 2 deletions tests/repository/test_custom_repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ class TestRepositoryConstructionAndRegistration:
@pytest.fixture
def custom_test_domain(self, test_domain):
test_domain.config["databases"]["sqlite"] = {
"provider": "sqlalchemy",
"database": "sqlite",
"provider": "sqlite",
"database_uri": "sqlite:///test.db",
}
test_domain._initialize()
Expand Down
3 changes: 1 addition & 2 deletions tests/support/domains/test18/domain.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
[databases.secondary]
provider = "sqlalchemy"
database = "sqlite"
provider = "sqlite"
database_uri = "${SQLITE_DB_LOCATION}"

[custom]
Expand Down
3 changes: 1 addition & 2 deletions tests/support/domains/test20/domain.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ provider = "memory"
provider = "memory"

[databases.sqlite]
provider = "sqlalchemy"
database = "sqlite"
provider = "sqlite"
database_uri = "sqlite:///test.db"

[brokers.default]
Expand Down
3 changes: 1 addition & 2 deletions tests/support/domains/test21/domain.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ provider = "memory"
provider = "memory"

[databases.sqlite]
provider = "sqlalchemy"
database = "sqlite"
provider = "sqlite"
database_uri = "sqlite:///test.db"

[brokers.default]
Expand Down
3 changes: 1 addition & 2 deletions tests/support/domains/test24/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ provider = "memory"
provider = "memory"

[tool.protean.databases.sqlite]
provider = "sqlalchemy"
database = "sqlite"
provider = "sqlite"
database_uri = "sqlite:///test.db"

[tool.protean.brokers.default]
Expand Down
3 changes: 1 addition & 2 deletions tests/support/domains/test8/sqlite_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

domain.config["databases"] = {
"default": {
"provider": "sqlalchemy",
"database": "sqlite",
"provider": "sqlite",
"database_uri": "sqlite:///:memory:",
}
}

0 comments on commit d48fa2d

Please sign in to comment.