From 81d6c8748a9781a4cdf6d61a0977fcd985f9fba9 Mon Sep 17 00:00:00 2001 From: Subhash Bhushan Date: Wed, 19 Apr 2023 12:22:08 -0700 Subject: [PATCH] Add multi-tenancy support for Postgres (#383) This PR addresses 382 and adds naive support for multi-tenancy with a "MULTITENANCY" flag. Domain context can now be initialized witn arbitrary attributes. In our case, we will be setting MULTITENANCY to True. Improvements to be done: - a better way to supply database info in a multi-tenancy use case - reset database connection to default after execution --- src/protean/adapters/repository/__init__.py | 9 +++++-- src/protean/domain/__init__.py | 6 ++--- src/protean/domain/context.py | 6 ++++- .../postgresql/test_schema_switch.py | 26 +++++++++++++++++++ tests/context/tests.py | 14 ++++++++++ 5 files changed, 55 insertions(+), 6 deletions(-) create mode 100644 tests/adapters/repository/sqlalchemy_repo/postgresql/test_schema_switch.py diff --git a/src/protean/adapters/repository/__init__.py b/src/protean/adapters/repository/__init__.py index 0ffe47f7..5ee8377d 100644 --- a/src/protean/adapters/repository/__init__.py +++ b/src/protean/adapters/repository/__init__.py @@ -7,6 +7,7 @@ from protean.core.repository import BaseRepository, repository_factory from protean.exceptions import ConfigurationError +from protean.globals import g from protean.utils import fully_qualified_name logger = logging.getLogger(__name__) @@ -119,7 +120,9 @@ def _initialize(self): def get_connection(self, provider_name="default"): """Fetch connection from Provider""" - if self._providers is None: + if ( + hasattr(g, "MULTITENANCY") and g.MULTITENANCY is True + ) or self._providers is None: self._initialize() try: @@ -129,7 +132,9 @@ def get_connection(self, provider_name="default"): def repository_for(self, aggregate_cls): """Retrieve a Repository registered for the Aggregate""" - if self._providers is None: + if ( + hasattr(g, "MULTITENANCY") and g.MULTITENANCY is True + ) or self._providers is None: self._initialize() provider_name = aggregate_cls.meta_.provider diff --git a/src/protean/domain/__init__.py b/src/protean/domain/__init__.py index 6aa5c3c6..24082933 100644 --- a/src/protean/domain/__init__.py +++ b/src/protean/domain/__init__.py @@ -261,7 +261,7 @@ def make_config(self, instance_relative=False): defaults["DEBUG"] = get_debug_flag() return self.config_class(root_path, defaults) - def domain_context(self): + def domain_context(self, **kwargs): """Create an :class:`~protean.context.DomainContext`. Use as a ``with`` block to push the context, which will make :data:`current_domain` point at this domain. @@ -271,7 +271,7 @@ def domain_context(self): with domain.domain_context(): init_db() """ - return DomainContext(self) + return DomainContext(self, **kwargs) def teardown_domain_context(self, f): """Registers a function to be called when the domain context @@ -754,7 +754,7 @@ def handlers_for(self, event: BaseEvent) -> List[BaseEventHandler]: # Repository Functionality # ############################ - @lru_cache(maxsize=None) + # FIXME Optimize calls to this method with cache, but also with support for Multitenancy def repository_for(self, aggregate_cls): if aggregate_cls.element_type == DomainObjects.EVENT_SOURCED_AGGREGATE: return self.event_store.repository_for(aggregate_cls) diff --git a/src/protean/domain/context.py b/src/protean/domain/context.py index 6005d478..c6c8c06b 100644 --- a/src/protean/domain/context.py +++ b/src/protean/domain/context.py @@ -75,10 +75,14 @@ class DomainContext(object): to the current thread or greenlet. """ - def __init__(self, domain): + def __init__(self, domain, **kwargs): self.domain = domain self.g = domain.domain_context_globals_class() + # Set any additional kwargs as attributes in globals + for kw in kwargs.items(): + setattr(self.g, *kw) + # Use a basic "refcount" to track number of domain contexts self._ref_count = 0 diff --git a/tests/adapters/repository/sqlalchemy_repo/postgresql/test_schema_switch.py b/tests/adapters/repository/sqlalchemy_repo/postgresql/test_schema_switch.py new file mode 100644 index 00000000..6ec72255 --- /dev/null +++ b/tests/adapters/repository/sqlalchemy_repo/postgresql/test_schema_switch.py @@ -0,0 +1,26 @@ +import pytest + +from protean.globals import current_domain + +from .elements import Person + + +# @pytest.mark.postgresql +class TestSchemaSwitch: + @pytest.fixture(autouse=True) + def register_elements(self, test_domain): + test_domain.register(Person) + + def test_schema_switch(self, test_domain): + repo = test_domain.repository_for(Person) + assert repo._provider._metadata.schema == "public" + + with current_domain.domain_context(MULTITENANCY=True): + current_domain.config["DATABASES"]["default"]["SCHEMA"] = "private" + + repo1 = current_domain.repository_for(Person) + assert repo1._provider._metadata.schema == "private" + + # FIXME Reset the database info to default outside the context + # repo2 = test_domain.repository_for(Person) + # assert repo2._provider._metadata.schema == "public" diff --git a/tests/context/tests.py b/tests/context/tests.py index 2445e00f..0e0889f4 100644 --- a/tests/context/tests.py +++ b/tests/context/tests.py @@ -137,3 +137,17 @@ def __init__(self): test_domain.domain_context_globals_class = CustomRequestGlobals with test_domain.domain_context(): assert g.spam == "eggs" + + # Test passing kwargs to domain context during activation + def test_domain_context_kwargs(self, test_domain): + with test_domain.domain_context(foo="bar"): + assert g.foo == "bar" + + assert "foo" not in g + + # Test global attributes are not shared between domain contexts + def test_domain_context_globals_not_shared(self, test_domain): + with test_domain.domain_context(foo="bar"): + assert g.foo == "bar" + with test_domain.domain_context(foo="baz"): + assert g.foo == "baz"