diff --git a/src/protean/core/repository/base.py b/src/protean/core/repository/base.py
index 5df4c115..a4db272e 100644
--- a/src/protean/core/repository/base.py
+++ b/src/protean/core/repository/base.py
@@ -260,6 +260,11 @@ def to_entity(cls, *args, **kwargs):
 
 
 class BaseConnectionHandler(metaclass=ABCMeta):
     """ Interface to manage connections to the database """
+    @abstractmethod
     def get_connection(self):
         """ Get the connection object for the repository"""
+
+    @abstractmethod
+    def close_connection(self, conn):
+        """ Close the connection object for the repository"""
diff --git a/src/protean/core/repository/factory.py b/src/protean/core/repository/factory.py
index a2e0d0b6..8542f5a9 100644
--- a/src/protean/core/repository/factory.py
+++ b/src/protean/core/repository/factory.py
@@ -77,7 +77,8 @@ def register(self, schema_cls, repo_cls=None):
 
         # If no connection exists then build it
         if schema_cls.opts_.bind not in self.connections:
-            conn_handler = provider.ConnectionHandler(conn_info)
+            conn_handler = provider.ConnectionHandler(
+                schema_cls.opts_.bind, conn_info)
             self._connections.connections[schema_cls.opts_.bind] = \
                 conn_handler.get_connection()
 
@@ -95,5 +96,13 @@ def __getattr__(self, schema):
         except KeyError:
             raise AssertionError('Unregistered Schema')
 
+    def close_connections(self):
+        """ Close all connections registered with the repository """
+        for conn_name, conn_obj in self.connections.items():
+            conn_info = self.repositories[conn_name]
+            provider = importlib.import_module(conn_info['PROVIDER'])
+            conn_handler = provider.ConnectionHandler(conn_name, conn_info)
+            conn_handler.close_connection(conn_obj)
+
 
 repo = RepositoryFactory()
diff --git a/src/protean/impl/repository/dict_repo.py b/src/protean/impl/repository/dict_repo.py
index 118c2456..060d480f 100644
--- a/src/protean/impl/repository/dict_repo.py
+++ b/src/protean/impl/repository/dict_repo.py
@@ -3,6 +3,8 @@
 
 from collections import defaultdict
 from itertools import count
+from threading import Lock
+
 from operator import itemgetter
 
 from protean.core.field import Auto
@@ -10,6 +12,12 @@
     Pagination, BaseConnectionHandler
 
 
+# Global in-memory store of dict data. Keyed by name, to provide
+# multiple named local memory caches.
+_databases = {}
+_locks = {}
+
+
 class Repository(BaseRepository):
     """ A repository for storing data in a dictionary """
 
@@ -35,7 +43,8 @@ def _create_object(self, schema_obj):
 
         # Add the entity to the repository
        identifier = schema_obj[self.entity_cls.meta_.id_field.field_name]
-        self.conn['data'][self.schema_name][identifier] = schema_obj
+        with self.conn['lock']:
+            self.conn['data'][self.schema_name][identifier] = schema_obj
         return schema_obj
 
     def _filter_objects(self, page: int = 1, per_page: int = 10,
@@ -71,7 +80,7 @@ def _filter_objects(self, page: int = 1, per_page: int = 10,
 
         # Build the pagination results for the filtered items
         cur_offset, cur_limit = None, None
-        if per_page is not None:
+        if per_page > 0:
             cur_offset = (page - 1) * per_page
             cur_limit = page * per_page
 
@@ -85,7 +94,8 @@ def _filter_objects(self, page: int = 1, per_page: int = 10,
     def _update_object(self, schema_obj):
         """ Update the entity record in the dictionary """
         identifier = schema_obj[self.entity_cls.meta_.id_field.field_name]
-        self.conn['data'][self.schema_name][identifier] = schema_obj
+        with self.conn['lock']:
+            self.conn['data'][self.schema_name][identifier] = schema_obj
         return schema_obj
 
     def _delete_objects(self, **filters):
@@ -111,7 +121,8 @@ def _delete_objects(self, **filters):
 
     def delete_all(self):
         """ Delete all objects in this schema """
-        del self.conn['data'][self.schema_name]
+        with self.conn['lock']:
+            del self.conn['data'][self.schema_name]
 
 
 class DictSchema(BaseSchema):
@@ -134,10 +145,20 @@ def to_entity(cls, item):
 
 class ConnectionHandler(BaseConnectionHandler):
     """ Handle connections to the dict repository """
-    def __init__(self, conn_info):
+    def __init__(self, conn_name, conn_info):
         self.conn_info = conn_info
+        self.conn_name = conn_name
 
     def get_connection(self):
         """ Return the dictionary database object """
-        return {'data': defaultdict(dict),
-                'counters': defaultdict(count)}
+        database = {
+            'data': _databases.setdefault(self.conn_name, defaultdict(dict)),
+            'lock': _locks.setdefault(self.conn_name, Lock()),
+            'counters': defaultdict(count)
+        }
+        return database
+
+    def close_connection(self, conn):
+        """ Remove the dictionary database object """
+        del _databases[self.conn_name]
+        del _locks[self.conn_name]
diff --git a/tests/core/test_repository.py b/tests/core/test_repository.py
index 09248744..dcb16221 100644
--- a/tests/core/test_repository.py
+++ b/tests/core/test_repository.py
@@ -6,7 +6,7 @@
 from protean.core.repository import repo
 from protean.core.exceptions import ValidationError, ObjectNotFoundError
 from protean.core import field
-from protean.impl.repository.dict_repo import DictSchema
+from protean.impl.repository.dict_repo import DictSchema, _databases
 
 
 class Dog(Entity):
@@ -128,3 +128,9 @@ def test_delete(self):
 
         with pytest.raises(ObjectNotFoundError):
             repo.DogSchema.get(1)
+
+    def test_close_connections(self):
+        """ Test closing all connections to the repository"""
+        assert 'default' in _databases
+        repo.close_connections()
+        assert _databases == {}
diff --git a/tests/test_context.py b/tests/test_context.py
index c081e3c2..f11bbc74 100644
--- a/tests/test_context.py
+++ b/tests/test_context.py
@@ -76,7 +76,7 @@ def run_create_task(thread_name, name, sleep=0):
         t.join()
 
     # Get the list of dogs and validate the results
-    dogs = repo.ThreadedDogSchema.filter(per_page=None)
+    dogs = repo.ThreadedDogSchema.filter(per_page=-1)
     assert dogs.total == 5
     for dog in dogs.items:
         if dog.name == 'Johnny':
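
Usage note (illustrative sketch, not part of the patch): the snippet below mirrors the
new test_close_connections test and shows how the connection lifecycle added here is
expected to be driven. It assumes a schema bound to the 'default' dict repository has
already been registered, as tests/core/test_repository.py does.

    # Registering a schema lazily opens a named connection ('default' here),
    # backed by the module-level _databases/_locks stores in dict_repo.
    from protean.core.repository import repo
    from protean.impl.repository.dict_repo import _databases

    assert 'default' in _databases

    # close_connections() walks every open connection, rebuilds the provider's
    # ConnectionHandler from the configured PROVIDER module, and calls
    # close_connection(), dropping the named in-memory database and its lock.
    repo.close_connections()
    assert _databases == {}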