Skip to content

Commit

Permalink
use -1 for filtering all results
Browse files Browse the repository at this point in the history
repositories must expose the close_connection method
  • Loading branch information
abhishekram committed Dec 4, 2018
1 parent b933040 commit 7b64cd5
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 10 deletions.
5 changes: 5 additions & 0 deletions src/protean/core/repository/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,11 @@ def to_entity(cls, *args, **kwargs):

class BaseConnectionHandler(metaclass=ABCMeta):
""" Interface to manage connections to the database """

@abstractmethod
def get_connection(self):
""" Get the connection object for the repository"""

@abstractmethod
def close_connection(self, conn):
""" Close the connection object for the repository"""
11 changes: 10 additions & 1 deletion src/protean/core/repository/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def register(self, schema_cls, repo_cls=None):

# If no connection exists then build it
if schema_cls.opts_.bind not in self.connections:
conn_handler = provider.ConnectionHandler(conn_info)
conn_handler = provider.ConnectionHandler(
schema_cls.opts_.bind, conn_info)
self._connections.connections[schema_cls.opts_.bind] = \
conn_handler.get_connection()

Expand All @@ -95,5 +96,13 @@ def __getattr__(self, schema):
except KeyError:
raise AssertionError('Unregistered Schema')

def close_connections(self):
""" Close all connections registered with the repository """
for conn_name, conn_obj in self.connections.items():
conn_info = self.repositories[conn_name]
provider = importlib.import_module(conn_info['PROVIDER'])
conn_handler = provider.ConnectionHandler(conn_name, conn_info)
conn_handler.close_connection(conn_obj)


repo = RepositoryFactory()
35 changes: 28 additions & 7 deletions src/protean/impl/repository/dict_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,21 @@
from collections import defaultdict
from itertools import count

from threading import Lock

from operator import itemgetter

from protean.core.field import Auto
from protean.core.repository import BaseRepository, BaseSchema, \
Pagination, BaseConnectionHandler


# Global in-memory store of dict data. Keyed by name, to provide
# multiple named local memory caches.
_databases = {}
_locks = {}


class Repository(BaseRepository):
""" A repository for storing data in a dictionary """

Expand All @@ -35,7 +43,8 @@ def _create_object(self, schema_obj):

# Add the entity to the repository
identifier = schema_obj[self.entity_cls.meta_.id_field.field_name]
self.conn['data'][self.schema_name][identifier] = schema_obj
with self.conn['lock']:
self.conn['data'][self.schema_name][identifier] = schema_obj
return schema_obj

def _filter_objects(self, page: int = 1, per_page: int = 10,
Expand Down Expand Up @@ -71,7 +80,7 @@ def _filter_objects(self, page: int = 1, per_page: int = 10,

# Build the pagination results for the filtered items
cur_offset, cur_limit = None, None
if per_page is not None:
if per_page > 0:
cur_offset = (page - 1) * per_page
cur_limit = page * per_page

Expand All @@ -85,7 +94,8 @@ def _filter_objects(self, page: int = 1, per_page: int = 10,
def _update_object(self, schema_obj):
""" Update the entity record in the dictionary """
identifier = schema_obj[self.entity_cls.meta_.id_field.field_name]
self.conn['data'][self.schema_name][identifier] = schema_obj
with self.conn['lock']:
self.conn['data'][self.schema_name][identifier] = schema_obj
return schema_obj

def _delete_objects(self, **filters):
Expand All @@ -111,7 +121,8 @@ def _delete_objects(self, **filters):

def delete_all(self):
""" Delete all objects in this schema """
del self.conn['data'][self.schema_name]
with self.conn['lock']:
del self.conn['data'][self.schema_name]


class DictSchema(BaseSchema):
Expand All @@ -134,10 +145,20 @@ def to_entity(cls, item):
class ConnectionHandler(BaseConnectionHandler):
""" Handle connections to the dict repository """

def __init__(self, conn_info):
def __init__(self, conn_name, conn_info):
self.conn_info = conn_info
self.conn_name = conn_name

def get_connection(self):
""" Return the dictionary database object """
return {'data': defaultdict(dict),
'counters': defaultdict(count)}
database = {
'data': _databases.setdefault(self.conn_name, defaultdict(dict)),
'lock': _locks.setdefault(self.conn_name, Lock()),
'counters': defaultdict(count)
}
return database

def close_connection(self, conn):
""" Remove the dictionary database object """
del _databases[self.conn_name]
del _locks[self.conn_name]
8 changes: 7 additions & 1 deletion tests/core/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from protean.core.repository import repo
from protean.core.exceptions import ValidationError, ObjectNotFoundError
from protean.core import field
from protean.impl.repository.dict_repo import DictSchema
from protean.impl.repository.dict_repo import DictSchema, _databases


class Dog(Entity):
Expand Down Expand Up @@ -128,3 +128,9 @@ def test_delete(self):

with pytest.raises(ObjectNotFoundError):
repo.DogSchema.get(1)

def test_close_connections(self):
""" Test closing all connections to the repository"""
assert 'default' in _databases
repo.close_connections()
assert _databases == {}
2 changes: 1 addition & 1 deletion tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def run_create_task(thread_name, name, sleep=0):
t.join()

# Get the list of dogs and validate the results
dogs = repo.ThreadedDogSchema.filter(per_page=None)
dogs = repo.ThreadedDogSchema.filter(per_page=-1)
assert dogs.total == 5
for dog in dogs.items:
if dog.name == 'Johnny':
Expand Down

0 comments on commit 7b64cd5

Please sign in to comment.