Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

382 Add multi-tenancy support for Postgres #383

Merged
merged 1 commit into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/protean/adapters/repository/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from protean.core.repository import BaseRepository, repository_factory
from protean.exceptions import ConfigurationError
from protean.globals import g
from protean.utils import fully_qualified_name

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -119,7 +120,9 @@ def _initialize(self):

def get_connection(self, provider_name="default"):
"""Fetch connection from Provider"""
if self._providers is None:
if (
hasattr(g, "MULTITENANCY") and g.MULTITENANCY is True
) or self._providers is None:
self._initialize()

try:
Expand All @@ -129,7 +132,9 @@ def get_connection(self, provider_name="default"):

def repository_for(self, aggregate_cls):
"""Retrieve a Repository registered for the Aggregate"""
if self._providers is None:
if (
hasattr(g, "MULTITENANCY") and g.MULTITENANCY is True
) or self._providers is None:
self._initialize()

provider_name = aggregate_cls.meta_.provider
Expand Down
6 changes: 3 additions & 3 deletions src/protean/domain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def make_config(self, instance_relative=False):
defaults["DEBUG"] = get_debug_flag()
return self.config_class(root_path, defaults)

def domain_context(self):
def domain_context(self, **kwargs):
"""Create an :class:`~protean.context.DomainContext`. Use as a ``with``
block to push the context, which will make :data:`current_domain`
point at this domain.
Expand All @@ -271,7 +271,7 @@ def domain_context(self):
with domain.domain_context():
init_db()
"""
return DomainContext(self)
return DomainContext(self, **kwargs)

def teardown_domain_context(self, f):
"""Registers a function to be called when the domain context
Expand Down Expand Up @@ -754,7 +754,7 @@ def handlers_for(self, event: BaseEvent) -> List[BaseEventHandler]:
# Repository Functionality #
############################

@lru_cache(maxsize=None)
# FIXME Optimize calls to this method with cache, but also with support for Multitenancy
def repository_for(self, aggregate_cls):
if aggregate_cls.element_type == DomainObjects.EVENT_SOURCED_AGGREGATE:
return self.event_store.repository_for(aggregate_cls)
Expand Down
6 changes: 5 additions & 1 deletion src/protean/domain/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,14 @@ class DomainContext(object):
to the current thread or greenlet.
"""

def __init__(self, domain):
def __init__(self, domain, **kwargs):
self.domain = domain
self.g = domain.domain_context_globals_class()

# Set any additional kwargs as attributes in globals
for kw in kwargs.items():
setattr(self.g, *kw)

# Use a basic "refcount" to track number of domain contexts
self._ref_count = 0

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest

from protean.globals import current_domain

from .elements import Person


# @pytest.mark.postgresql
class TestSchemaSwitch:
@pytest.fixture(autouse=True)
def register_elements(self, test_domain):
test_domain.register(Person)

def test_schema_switch(self, test_domain):
repo = test_domain.repository_for(Person)
assert repo._provider._metadata.schema == "public"

with current_domain.domain_context(MULTITENANCY=True):
current_domain.config["DATABASES"]["default"]["SCHEMA"] = "private"

repo1 = current_domain.repository_for(Person)
assert repo1._provider._metadata.schema == "private"

# FIXME Reset the database info to default outside the context
# repo2 = test_domain.repository_for(Person)
# assert repo2._provider._metadata.schema == "public"
14 changes: 14 additions & 0 deletions tests/context/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,17 @@ def __init__(self):
test_domain.domain_context_globals_class = CustomRequestGlobals
with test_domain.domain_context():
assert g.spam == "eggs"

# Test passing kwargs to domain context during activation
def test_domain_context_kwargs(self, test_domain):
with test_domain.domain_context(foo="bar"):
assert g.foo == "bar"

assert "foo" not in g

# Test global attributes are not shared between domain contexts
def test_domain_context_globals_not_shared(self, test_domain):
with test_domain.domain_context(foo="bar"):
assert g.foo == "bar"
with test_domain.domain_context(foo="baz"):
assert g.foo == "baz"