Skip to content

Commit

Permalink
Add tests and increase coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
subhashb committed Jul 7, 2024
1 parent 828a4b1 commit ac9e8b2
Show file tree
Hide file tree
Showing 22 changed files with 513 additions and 190 deletions.
4 changes: 3 additions & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ omit =
show_missing = true
precision = 2
omit = *migrations*

exclude_lines =
pragma: no cover
if TYPE_CHECKING:
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
"module": "pytest",
"justMyCode": false,
"args": [
"tests/adapters/model/elasticsearch_model/tests.py::TestDefaultModel::test_dynamically_constructed_model_attributes",
"tests/adapters/model/elasticsearch_model/tests.py::TestModelWithVO::test_conversion_from_model_to_entity",
"--elasticsearch"
]
},
Expand Down
62 changes: 34 additions & 28 deletions src/protean/adapters/event_store/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import importlib
import logging
from collections import defaultdict
from typing import List, Optional, Type
from typing import TYPE_CHECKING, DefaultDict, List, Optional, Set, Type

from protean import BaseEvent, BaseEventHandler
from protean.core.command import BaseCommand
Expand All @@ -14,6 +16,10 @@
from protean.utils import fqn
from protean.utils.mixins import Message

if TYPE_CHECKING:
from protean.domain import Domain
from protean.port.event_store import BaseEventStore

logger = logging.getLogger(__name__)

EVENT_STORE_PROVIDERS = {
Expand All @@ -24,45 +30,47 @@

class EventStore:
def __init__(self, domain):
self.domain = domain
self._event_store = None
self._event_streams = None
self._command_streams = None
self.domain: Domain = domain
self._event_store: BaseEventStore = None
self._event_streams: DefaultDict[str, Set[BaseEventHandler]] = defaultdict(set)
self._command_streams: DefaultDict[str, Set[BaseCommandHandler]] = defaultdict(
set
)

@property
def store(self):
return self._event_store

def _initialize(self):
logger.debug("Initializing Event Store...")

def _initialize_event_store(self) -> BaseEventStore:
configured_event_store = self.domain.config["event_store"]
if configured_event_store and isinstance(configured_event_store, dict):
event_store_full_path = EVENT_STORE_PROVIDERS[
configured_event_store["provider"]
]
event_store_module, event_store_class = event_store_full_path.rsplit(
".", maxsplit=1
)
event_store_full_path = EVENT_STORE_PROVIDERS[
configured_event_store["provider"]
]
event_store_module, event_store_class = event_store_full_path.rsplit(
".", maxsplit=1
)

event_store_cls = getattr(
importlib.import_module(event_store_module), event_store_class
)
event_store_cls = getattr(
importlib.import_module(event_store_module), event_store_class
)

store = event_store_cls(self.domain, configured_event_store)

store = event_store_cls(self.domain, configured_event_store)
else:
raise ConfigurationError("Configure at least one event store in the domain")
return store

def _initialize(self) -> None:
logger.debug("Initializing Event Store...")

self._event_store = store
# Initialize the Event Store
#
# An event store is always present by default. If not configured explicitly,
# a memory-based event store is used.
self._event_store = self._initialize_event_store()

self._initialize_event_streams()
self._initialize_command_streams()

return self._event_store

def _initialize_event_streams(self):
self._event_streams = defaultdict(set)

for _, record in self.domain.registry.event_handlers.items():
stream_name = (
record.cls.meta_.stream_name
Expand All @@ -71,8 +79,6 @@ def _initialize_event_streams(self):
self._event_streams[stream_name].add(record.cls)

def _initialize_command_streams(self):
self._command_streams = defaultdict(set)

for _, record in self.domain.registry.command_handlers.items():
self._command_streams[record.cls.meta_.part_of.meta_.stream_name].add(
record.cls
Expand Down
17 changes: 17 additions & 0 deletions src/protean/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class Category(str, Enum):
CORE = "CORE"
EVENTSTORE = "EVENTSTORE"
DATABASE = "DATABASE"
COVERAGE = "COVERAGE"
FULL = "FULL"


Expand Down Expand Up @@ -119,6 +120,22 @@ def test(
for store in ["MESSAGE_DB"]:
print(f"Running tests for EVENTSTORE: {store}...")
subprocess.call(commands + ["-m", "eventstore", f"--store={store}"])
case "COVERAGE":
subprocess.call(
commands
+ [
"--slow",
"--sqlite",
"--postgresql",
"--elasticsearch",
"--redis",
"--message_db",
"--cov=protean",
"--cov-config",
".coveragerc",
"tests",
]
)
case _:
print("Running core tests...")
subprocess.call(commands)
Expand Down
10 changes: 10 additions & 0 deletions src/protean/core/command_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,16 @@ def command_handler_factory(element_cls, **kwargs):
}
)

if method._target_cls.meta_.part_of != element_cls.meta_.part_of:
raise IncorrectUsageError(
{
"_command_handler": [
f"Command `{method._target_cls.__name__}` in Command Handler `{element_cls.__name__}` "
"is not associated with the same aggregate as the Command Handler"
]
}
)

# Associate Command with the handler's stream
# Order of preference:
# 1. Stream name defined in command
Expand Down
105 changes: 53 additions & 52 deletions src/protean/core/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,21 +138,6 @@ def _default_options(cls):
("schema_name", inflection.underscore(cls.__name__)),
]

@classmethod
def _extract_options(cls, **opts):
"""A stand-in method for setting customized options on the Domain Element
Empty by default. To be overridden in each Element that expects or needs
specific options.
"""
for key, default in cls._default_options():
value = (
opts.pop(key, None)
or (hasattr(cls.meta_, key) and getattr(cls.meta_, key))
or default
)
setattr(cls.meta_, key, value)

def __init__(self, *template, **kwargs): # noqa: C901
"""
Initialise the entity object.
Expand Down Expand Up @@ -217,6 +202,9 @@ def __init__(self, *template, **kwargs): # noqa: C901
id_field_obj = id_field(self)
id_field_name = id_field_obj.field_name

############
# ID Value #
############
# Look for id field in kwargs and load value if present
if kwargs and id_field_name in kwargs:
setattr(self, id_field_name, kwargs.pop(id_field_name))
Expand All @@ -225,7 +213,7 @@ def __init__(self, *template, **kwargs): # noqa: C901
# Look for id field in template dictionary and load value if present
for dictionary in template:
if id_field_name in dictionary:
setattr(self, id_field_name, dictionary[id_field_name])
setattr(self, id_field_name, dictionary.pop(id_field_name))
loaded_fields.append(id_field_name)
break
else:
Expand All @@ -242,33 +230,40 @@ def __init__(self, *template, **kwargs): # noqa: C901
)
loaded_fields.append(id_field_name)

# Load the attributes based on the template
########################
# Load supplied values #
########################
# Gather values from template
template_values = {}
for dictionary in template:
if not isinstance(dictionary, dict):
raise AssertionError(
f'Positional argument "{dictionary}" passed must be a dict.'
f"Positional argument {dictionary} passed must be a dict. "
f"This argument serves as a template for loading common "
f"values.",
)
for field_name, val in dictionary.items():
if field_name not in kwargs and field_name not in loaded_fields:
kwargs[field_name] = val

# Now load against the keyword arguments
for field_name, val in kwargs.items():
if field_name not in loaded_fields:
try:
setattr(self, field_name, val)
except ValidationError as err:
for field_name in err.messages:
self.errors[field_name].extend(err.messages[field_name])
finally:
loaded_fields.append(field_name)
template_values[field_name] = val

# Also note reference field name if its attribute was loaded
if field_name in reference_attributes:
loaded_fields.append(reference_attributes[field_name])
supplied_values = {**template_values, **kwargs}

# Now load the attributes from template and kwargs
for field_name, val in supplied_values.items():
try:
setattr(self, field_name, val)
except ValidationError as err:
for field_name in err.messages:
self.errors[field_name].extend(err.messages[field_name])
finally:
loaded_fields.append(field_name)

# Also note reference field name if its attribute was loaded
if field_name in reference_attributes:
loaded_fields.append(reference_attributes[field_name])

######################
# Load value objects #
######################
# Load Value Objects from associated fields
# This block will dynamically construct value objects from field values
# and associate the VO with the entity
Expand All @@ -279,7 +274,9 @@ def __init__(self, *template, **kwargs): # noqa: C901
(embedded_field.field_name, embedded_field.attribute_name)
for embedded_field in field_obj.embedded_fields.values()
]
kwargs_values = {name: kwargs.get(attr) for name, attr in attrs}
kwargs_values = {
name: supplied_values.get(attr) for name, attr in attrs
}

# Check if any of the values in `values` are not None
# If all values are None, it means that the value object is not being set
Expand All @@ -290,37 +287,38 @@ def __init__(self, *template, **kwargs): # noqa: C901
if any(kwargs_values.values()):
try:
value_object = field_obj.value_object_cls(**kwargs_values)

# Set VO value only if the value object is not None/Empty
if value_object:
setattr(self, field_name, value_object)
loaded_fields.append(field_name)
setattr(self, field_name, value_object)
loaded_fields.append(field_name)
except ValidationError as err:
for sub_field_name in err.messages:
self.errors[
"{}_{}".format(field_name, sub_field_name)
].extend(err.messages[sub_field_name])

#############################
# Generate other identities #
#############################
# Load other identities
for field_name, field_obj in declared_fields(self).items():
if (
field_name not in loaded_fields
and type(field_obj) is Auto
and not field_obj.increment
):
if not getattr(self, field_obj.field_name, None):
setattr(
self,
field_obj.field_name,
generate_identity(
field_obj.identity_strategy,
field_obj.identity_function,
field_obj.identity_type,
),
)
setattr(
self,
field_obj.field_name,
generate_identity(
field_obj.identity_strategy,
field_obj.identity_function,
field_obj.identity_type,
),
)
loaded_fields.append(field_obj.field_name)

# Load Associations
#####################
# Load Associations #
#####################
for field_name, field_obj in declared_fields(self).items():
if isinstance(field_obj, Association):
getattr(self, field_name) # This refreshes the values in associations
Expand All @@ -343,6 +341,9 @@ def __init__(self, *template, **kwargs): # noqa: C901

self.defaults()

#################################
# Mark remaining fields as None #
#################################
# Now load the remaining fields with a None value, which will fail
# for required fields
for field_name, field_obj in fields(self).items():
Expand Down Expand Up @@ -529,7 +530,7 @@ def _update_data(self, *data_dict, **kwargs):
for data in data_dict:
if not isinstance(data, dict):
raise AssertionError(
f'Positional argument "{data}" passed must be a dict.'
f"Positional argument {data} passed must be a dict. "
f"This argument serves as a template for loading common "
f"values.",
)
Expand Down
Loading

0 comments on commit ac9e8b2

Please sign in to comment.