Skip to content

Commit

Permalink
Store event and command types in domain for easy retrieval
Browse files Browse the repository at this point in the history
Use the `__type__` value present in events and commands to retrieve
event and command classes from domain, instead of fetching by
fully qualified name. This change is required to support cross-domain
event processing.

Also:
- Fix bug with collecting identity value in associations
- Remove unused delist functionality in registry
  • Loading branch information
subhashb committed Jul 16, 2024
1 parent ff2f84b commit 83f0159
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 87 deletions.
27 changes: 16 additions & 11 deletions src/protean/domain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ def __init__(
self._models: Dict[str, BaseModel] = {}
self._constructed_models: Dict[str, BaseModel] = {}

# Cache for holding events and commands by their types
self._events_and_commands: Dict[str, Union[BaseCommand, BaseEvent]] = {}

#: A list of functions that are called when the domain context
#: is destroyed. This is the place to store code that cleans up and
#: disconnects from databases, for example.
Expand Down Expand Up @@ -259,7 +262,7 @@ def init(self, traverse=True): # noqa: C901
self._generate_fact_event_classes()

# Generate and set event/command `__type__` value
self._set_event_and_command_type()
self._set_and_record_event_and_command_type()

# Parse and setup handler methods in Command Handlers
self._setup_command_handlers()
Expand Down Expand Up @@ -827,20 +830,22 @@ def _set_aggregate_cluster_options(self):
element.cls.meta_.aggregate_cluster.meta_.provider,
)

def _set_event_and_command_type(self):
def _set_and_record_event_and_command_type(self):
for element_type in [DomainObjects.EVENT, DomainObjects.COMMAND]:
for _, element in self.registry._elements[element_type.value].items():
setattr(
element.cls,
"__type__",
(
f"{self.camel_case_name}."
# f"{element.cls.meta_.aggregate_cluster.__class__.__name__}."
f"{element.cls.__name__}."
f"{element.cls.__version__}"
),
# Type is <Domain Name>.<Event or Command Name>.<Version>
# E.g. `Authentication.UserRegistered.v1`, `Ecommerce.OrderPlaced.v1`
type_string = (
f"{self.camel_case_name}."
# f"{element.cls.meta_.aggregate_cluster.__class__.__name__}."
f"{element.cls.__name__}."
f"{element.cls.__version__}"
)

setattr(element.cls, "__type__", type_string)

self._events_and_commands[type_string] = element.cls

def _setup_command_handlers(self):
for element_type in [DomainObjects.COMMAND_HANDLER]:
for _, element in self.registry._elements[element_type.value].items():
Expand Down
9 changes: 1 addition & 8 deletions src/protean/domain/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def register_element(self, element_cls):
f"Element `{element_cls.__name__}` is not a valid element class"
)

# Element name is always the fully qualified name of the class
element_name = fully_qualified_name(element_cls)

element = self._elements[element_cls.element_type.value][element_name]
Expand Down Expand Up @@ -111,14 +112,6 @@ def register_element(self, element_cls):
f"Registered Element {element_name} with Domain as a {element_cls.element_type.value}"
)

def delist_element(self, element_cls):
if self._is_invalid_element_cls(element_cls):
raise NotImplementedError

element_name = fully_qualified_name(element_cls)

self._elements[element_cls.element_type.value].pop(element_name, None)

@property
def elements(self):
elems = {}
Expand Down
47 changes: 33 additions & 14 deletions src/protean/fields/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,18 +406,24 @@ def __set__(self, instance, value):

# 2. Determine and store the change in the relationship
current_value = getattr(instance, self.field_name)
current_value_id = (
getattr(current_value, id_field(current_value).field_name)
if current_value
else None
)
value_id = getattr(value, id_field(value).field_name) if value else None
if current_value is None:
# Entity was not associated earlier
instance._temp_cache[self.field_name]["change"] = "ADDED"
elif value is None:
# Entity was associated earlier, but now being removed
instance._temp_cache[self.field_name]["change"] = "DELETED"
instance._temp_cache[self.field_name]["old_value"] = current_value
elif current_value.id != value.id:
elif current_value_id != value_id:
# A New Entity is being associated replacing the old one
instance._temp_cache[self.field_name]["change"] = "UPDATED"
instance._temp_cache[self.field_name]["old_value"] = current_value
elif current_value.id == value.id and value.state_.is_changed:
elif current_value_id == value_id and value.state_.is_changed:
# Entity was associated earlier, but now being updated
instance._temp_cache[self.field_name]["change"] = "UPDATED"
else:
Expand Down Expand Up @@ -527,17 +533,20 @@ def add(self, instance, items) -> None:
}
)

current_value_ids = [value.id for value in data]
current_value_ids = [
getattr(value, id_field(value).field_name) for value in data
]

# Remove items when set to empty
if len(items) == 0 and len(current_value_ids) > 0:
self.remove(instance, data)

for item in items:
# Items to add
if item.id not in current_value_ids:
identity = getattr(item, id_field(item).field_name)
if identity not in current_value_ids:
# If the same item is added multiple times, the last item added will win
instance._temp_cache[self.field_name]["added"][item.id] = item
instance._temp_cache[self.field_name]["added"][identity] = item

setattr(
item,
Expand All @@ -552,7 +561,7 @@ def add(self, instance, items) -> None:
self.delete_cached_value(instance)
# Items to update
elif (
item.id in current_value_ids
identity in current_value_ids
and item.state_.is_persisted
and item.state_.is_changed
):
Expand All @@ -565,7 +574,7 @@ def add(self, instance, items) -> None:
# Temporarily set linkage to parent in child entity
setattr(item, self._linked_reference(type(instance)), instance)

instance._temp_cache[self.field_name]["updated"][item.id] = item
instance._temp_cache[self.field_name]["updated"][identity] = item

# Reset Cache
self.delete_cached_value(instance)
Expand Down Expand Up @@ -601,12 +610,15 @@ def remove(self, instance, items) -> None:
}
)

current_value_ids = [value.id for value in data]
current_value_ids = [
getattr(value, id_field(value).field_name) for value in data
]

for item in items:
if item.id in current_value_ids:
if item.id not in instance._temp_cache[self.field_name]["removed"]:
instance._temp_cache[self.field_name]["removed"][item.id] = item
identity = getattr(item, id_field(item).field_name)
if identity in current_value_ids:
if identity not in instance._temp_cache[self.field_name]["removed"]:
instance._temp_cache[self.field_name]["removed"][identity] = item

# Reset Cache
self.delete_cached_value(instance)
Expand Down Expand Up @@ -648,17 +660,24 @@ def _fetch_objects(self, instance, key, value) -> list:
# Update objects from temporary cache if present
updated_objects = []
for value in data:
if value.id in instance._temp_cache[self.field_name]["updated"]:
identity = getattr(value, id_field(value).field_name)
if identity in instance._temp_cache[self.field_name]["updated"]:
updated_objects.append(
instance._temp_cache[self.field_name]["updated"][value.id]
instance._temp_cache[self.field_name]["updated"][identity]
)
else:
updated_objects.append(value)
data = updated_objects

# Remove objects marked as removed in temporary cache
for _, item in instance._temp_cache[self.field_name]["removed"].items():
data[:] = [value for value in data if value.id != item.id]
# Retain data that is not among deleted items
data[:] = [
value
for value in data
if getattr(value, id_field(value).field_name)
!= getattr(item, id_field(item).field_name)
]

return data

Expand Down
3 changes: 2 additions & 1 deletion src/protean/server/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ async def handle_message(
f"Error handling message {message.stream_name}-{message.id} "
f"in {handler_cls.__name__}"
)
logger.error(f"{str(exc)}")
# Print the stack trace
logger.error(traceback.format_exc())
handler_cls.handle_error(exc, message)

await self.shutdown(exit_code=1)
Expand Down
7 changes: 6 additions & 1 deletion src/protean/server/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,12 @@ async def poll(self) -> None:
await self.tick()

if self.keep_going and not self.engine.shutting_down:
await asyncio.sleep(self.tick_interval)
# Keep control of the loop if in test mode
# Otherwise `asyncio.sleep` will give away control and
# the loop will be able to be stopped with `shutdown()`
if not self.engine.test_mode:
await asyncio.sleep(self.tick_interval)

self.loop.create_task(self.poll())

async def tick(self):
Expand Down
18 changes: 12 additions & 6 deletions src/protean/utils/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,23 @@ def from_dict(cls, message: Dict) -> Message:

def to_object(self) -> Union[BaseEvent, BaseCommand]:
"""Reconstruct the event/command object from the message data."""
if self.metadata.kind == MessageType.EVENT.value:
element_record = current_domain.registry.events[self.metadata.fqn]
elif self.metadata.kind == MessageType.COMMAND.value:
element_record = current_domain.registry.commands[self.metadata.fqn]
else:
if self.metadata.kind not in [
MessageType.COMMAND.value,
MessageType.EVENT.value,
]:
# We are dealing with a malformed or unknown message
raise InvalidDataError(
{"_message": ["Message type is not supported for deserialization"]}
)

return element_record.cls(_metadata=self.metadata, **self.data)
element_cls = current_domain._events_and_commands.get(self.metadata.type, None)

if element_cls is None:
raise ConfigurationError(
f"Message type {self.metadata.type} is not registered with the domain."
)

return element_cls(_metadata=self.metadata, **self.data)

@classmethod
def to_message(cls, message_object: Union[BaseEvent, BaseCommand]) -> Message:
Expand Down
21 changes: 21 additions & 0 deletions tests/command/test_command_basics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from protean import BaseCommand, BaseEventSourcedAggregate
from protean.fields import Identifier, String


class User(BaseEventSourcedAggregate):
id = Identifier(identifier=True)
email = String()
name = String()


class Register(BaseCommand):
user_id = Identifier(identifier=True)
email = String()
name = String()


def test_domain_stores_command_type_for_easy_retrieval(test_domain):
test_domain.register(Register, part_of=User)
test_domain.init(traverse=False)

assert Register.__type__ in test_domain._events_and_commands
8 changes: 5 additions & 3 deletions tests/domain/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,12 @@ def test_domain_init_constructs_fact_events(self, test_domain):
mock_generate_fact_event_classes.assert_called_once()

def test_domain_init_sets_event_command_types(self, test_domain):
mock_set_event_and_command_type = Mock()
test_domain._set_event_and_command_type = mock_set_event_and_command_type
mock_set_and_record_event_and_command_type = Mock()
test_domain._set_and_record_event_and_command_type = (
mock_set_and_record_event_and_command_type
)
test_domain.init(traverse=False)
mock_set_event_and_command_type.assert_called_once()
mock_set_and_record_event_and_command_type.assert_called_once()

def test_domain_init_sets_up_command_handlers(self, test_domain):
mock_setup_command_handlers = Mock()
Expand Down
6 changes: 6 additions & 0 deletions tests/event/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ def special_method(self):

assert fully_qualified_name(AnnotatedDomainEvent) in test_domain.registry.events

def test_domain_stores_event_type_for_easy_retrieval(self, test_domain):
test_domain.register(PersonAdded, part_of=Person)
test_domain.init(traverse=False)

assert PersonAdded.__type__ in test_domain._events_and_commands


class TestDomainEventEquivalence:
@pytest.fixture(autouse=True)
Expand Down
6 changes: 4 additions & 2 deletions tests/message/test_message_to_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,11 @@ def test_construct_event_from_message():
assert reconstructed_event.id == identifier


def test_construct_command_from_message():
def test_construct_command_from_message(test_domain):
identifier = str(uuid4())
command = Register(id=identifier, email="[email protected]", name="John Doe")
command = test_domain._enrich_command(
Register(id=identifier, email="[email protected]", name="John Doe")
)
message = Message.to_message(command)

reconstructed_command = message.to_object()
Expand Down
41 changes: 0 additions & 41 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,6 @@ def test_element_registration():
)


def test_delisting_element():
register = _DomainRegistry()
register.register_element(User)

assert (
"tests.test_registry.User" in register._elements[DomainObjects.AGGREGATE.value]
)

register.delist_element(User)

assert (
"tests.test_registry.User"
not in register._elements[DomainObjects.AGGREGATE.value]
)


def test_fetching_elements_from_registry():
register = _DomainRegistry()
register.register_element(User)
Expand Down Expand Up @@ -89,31 +73,6 @@ class FooBar3:
register.register_element(FooBar3)


def test_that_delisting_an_unknown_element_type_triggers_an_error():
class DummyEnum(Enum):
UNKNOWN = "UNKNOWN"

class FooBar1:
element_type = "FOOBAR"

class FooBar2:
element_type = DummyEnum.UNKNOWN

class FooBar3:
pass

register = _DomainRegistry()

with pytest.raises(NotImplementedError):
register.delist_element(FooBar1)

with pytest.raises(NotImplementedError):
register.delist_element(FooBar2)

with pytest.raises(NotImplementedError):
register.delist_element(FooBar3)


def test_that_re_registering_an_element_has_no_effect():
register = _DomainRegistry()
register.register_element(User)
Expand Down

0 comments on commit 83f0159

Please sign in to comment.