diff --git a/docs/guides/cli/shell.md b/docs/guides/cli/shell.md
index 0b2f2ce9..b1c51439 100644
--- a/docs/guides/cli/shell.md
+++ b/docs/guides/cli/shell.md
@@ -13,10 +13,11 @@ protean shell [OPTIONS]
 
 ## Options
 
-| Option      | Description                                | Default |
-|-------------|--------------------------------------------|---------|
-| `--domain`  | Sets the domain context for the shell.     | `.`     |
-| `--help`    | Shows the help message and exits.          |         |
+| Option        | Description                                | Default |
+|---------------|--------------------------------------------|---------|
+| `--domain`    | Sets the domain context for the shell.     | `.`     |
+| `--traverse`  | Auto-traverse domain elements.             | `False` |
+| `--help`      | Shows the help message and exits.          |         |
 
 ## Launching the Shell
 
@@ -37,4 +38,14 @@ protean shell --domain auth
 
 This command will initiate the shell in the context of `auth` domain, allowing
 you to perform domain-specific operations more conveniently. Read [Domain
-Discovery](discovery.md) for options to specify the domain.
\ No newline at end of file
+Discovery](discovery.md) for options to specify the domain.
+
+### Traversing subdirectories
+
+By default, only the domain and elements in the specified module will be loaded
+into the shell context. If you want to traverse files in the folder and its
+subdirectories, you can specify the `--traverse` option.
+
+```shell
+protean shell --domain auth --traverse
+```
diff --git a/docs/guides/compose-a-domain/activate-domain.md b/docs/guides/compose-a-domain/activate-domain.md
index 2faa0cc1..4d36eb77 100644
--- a/docs/guides/compose-a-domain/activate-domain.md
+++ b/docs/guides/compose-a-domain/activate-domain.md
@@ -4,18 +4,77 @@
 A `Domain` in protean is always associated with a domain context, which can be
 used to bind an domain object implicitly to the current thread or greenlet. We
 refer to the act of binding the domain object as **activating the domain**.
-
-A `DomainContext` helps manage the active domain object for the duration of a
-thread's execution. It also provides a namespace for storing data for the
-duration a domain context is active.
+## Domain Context
 
-You activate a domain by pushing up its context to the top of the domain stack.
+A Protean Domain object has attributes, such as config, that are useful to
+access within domain elements. However, importing the domain instance within
+the modules in your project is prone to circular import issues.
 
-## Using a Context Manager
+Protean solves this issue with the domain context. Rather than passing the
+domain around to each method, or referring to a domain directly, you can use
+the `current_domain` proxy instead. The `current_domain` proxy points to the
+domain handling the current activity.
+
+The `DomainContext` helps manage the active domain object for the duration of a
+thread's execution. The domain context keeps track of the domain-level data
+during the lifetime of a domain object, and is used while processing handlers,
+CLI commands, or other activities.
+
+## Storing Data
+
+The domain context also provides a `g` object for storing data. It is a simple
+namespace object that has the same lifetime as a domain context.
+
+!!! note
+    The `g` name stands for "global", but that is referring to the data
+    being global within a context. The data on `g` is lost after the context
+    ends, and it is not an appropriate place to store data between domain
+    calls. Use a session or a database to store data across domain model calls.
+
+A common use for g is to manage resources during a domain call.
+
+1. 
`get_X()` creates resource X if it does not exist, caching it as g.X. + +2. `teardown_X()` closes or otherwise deallocates the resource if it exists. +It is registered as a `teardown_domain_context()` handler. + +Using this pattern, you can, for example, manage a file connection for the +lifetime of a domain call: + +```python +from protean.globals import g + +def get_log(): + if 'log' not in g: + g.log = open_log_file() + + return g.log + +@domain.teardown_appcontext +def teardown_log_file(exception): + file_obj = g.pop('log', None) + + if not file_obj.closed: + file_obj.close() +``` + +Now, every call to `get_log()` during the domain call will return the same file +object, and it will be closed automatically at the end of processing. + +## Pushing up the Domain Context + +A Protean domain is activated close to the application's entrypoint, like an +API request. In many other cases, like Protean's server processing commands and +events, or the CLI accessing the domain, Protean automatically activates a +domain context for the duration of the task. + +You activate a domain by pushing up its context to the top of the domain stack: + +### With Context Manager Protean provides a helpful context manager to nest the domain operations under. -```Python hl_lines="18-21" +```python hl_lines="18-21" {! docs_src/guides/composing-a-domain/018.py !} ``` @@ -27,7 +86,7 @@ This is a convenient pattern to use in conjunction with most API frameworks. The domain’s context is pushed up at the beginning of a request and popped out once the request is processed. -## Without the Context Manager +### Manually You can also activate the context manually by using the `push` and `pop` methods of the domain context: diff --git a/docs/guides/compose-a-domain/element-decorators.md b/docs/guides/compose-a-domain/element-decorators.md index 041be473..599b4584 100644 --- a/docs/guides/compose-a-domain/element-decorators.md +++ b/docs/guides/compose-a-domain/element-decorators.md @@ -6,7 +6,7 @@ Each element is explored in detail in its own section. ## `Domain.aggregate` -```Python hl_lines="7-11" +```python hl_lines="7-11" {! docs_src/guides/composing-a-domain/002.py !} ``` @@ -15,7 +15,7 @@ Read more at Aggregates. ## `Domain.entity` -```Python hl_lines="14-17" +```python hl_lines="14-17" {! docs_src/guides/composing-a-domain/003.py !} ``` @@ -24,7 +24,7 @@ Read more at Entities. ## `Domain.value_object` -```Python hl_lines="7-15 23" +```python hl_lines="7-15 23" {! docs_src/guides/composing-a-domain/004.py !} ``` @@ -33,7 +33,7 @@ Read more at Value Objects. ## `Domain.domain_service` -```Python hl_lines="33-37" +```python hl_lines="33-37" {! docs_src/guides/composing-a-domain/005.py !} ``` @@ -42,7 +42,7 @@ Read more at Domain Services. ## `Domain.event_sourced_aggregate` -```Python hl_lines="7-10" +```python hl_lines="7-10" {! docs_src/guides/composing-a-domain/006.py !} ``` @@ -51,7 +51,7 @@ Read more at Event Sourced Aggregates. ## `Domain.command` -```Python hl_lines="18-23" +```python hl_lines="18-23" {! docs_src/guides/composing-a-domain/007.py !} ``` @@ -60,7 +60,7 @@ Read more at Commands. ## `Domain.command_handler` -```Python hl_lines="26-34" +```python hl_lines="26-34" {! docs_src/guides/composing-a-domain/008.py !} ``` @@ -69,7 +69,7 @@ Read more at Command Handlers. ## `Domain.event` -```Python hl_lines="18-23" +```python hl_lines="18-23" {! docs_src/guides/composing-a-domain/009.py !} ``` @@ -78,7 +78,7 @@ Read more at Events. 
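Since the highlighted snippets above live in external `docs_src` files, here is a small inline sketch of the same pattern for quick reference. It uses explicit `BaseEvent`/`BaseEventHandler` subclasses with manual registration, which the registration guide describes as equivalent to the decorators; the `UserLoggedIn` fixture mirrors the tests added later in this diff:

```python
from protean import BaseEvent, BaseEventHandler, Domain, handle
from protean.fields import Identifier

domain = Domain(__name__)


class UserLoggedIn(BaseEvent):
    user_id = Identifier(identifier=True)

    class Meta:
        stream_name = "authentication"


class UserEventHandler(BaseEventHandler):
    @handle(UserLoggedIn)
    def track_login(self, event: UserLoggedIn) -> None:
        ...  # react to the event here

    class Meta:
        stream_name = "authentication"


# Manual registration; the decorator forms on this page achieve the same effect
domain.register(UserLoggedIn)
domain.register(UserEventHandler)
```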
## `Domain.event_handler` -```Python hl_lines="28-32" +```python hl_lines="28-32" {! docs_src/guides/composing-a-domain/010.py !} ``` @@ -87,7 +87,7 @@ Read more at Event Handlers. ## `Domain.model` -```Python hl_lines="18-25" +```python hl_lines="18-25" {! docs_src/guides/composing-a-domain/011.py !} ``` @@ -96,7 +96,7 @@ Read more at Models. ## `Domain.repository` -```Python hl_lines="17-22" +```python hl_lines="17-22" {! docs_src/guides/composing-a-domain/012.py !} ``` @@ -105,7 +105,7 @@ Read more at Repositories. ## `Domain.view` -```Python hl_lines="20-24" +```python hl_lines="20-24" {! docs_src/guides/composing-a-domain/013.py !} ``` diff --git a/docs/guides/compose-a-domain/initialize-domain.md b/docs/guides/compose-a-domain/initialize-domain.md index ca4bf43a..9a066862 100644 --- a/docs/guides/compose-a-domain/initialize-domain.md +++ b/docs/guides/compose-a-domain/initialize-domain.md @@ -27,7 +27,7 @@ with Protean explicitly. Protean constructs a graph of all elements registered with a domain and exposes them in a registry. -```Python hl_lines="28-35" +```python hl_lines="28-35" {! docs_src/guides/composing-a-domain/016.py !} ``` @@ -41,7 +41,7 @@ a database that actually persists data. Calling `domain.init()` establishes connectivity with the underlying infra, testing access, and making them available for use by the rest of the system. -```Python hl_lines="5-11" +```python hl_lines="5-11" {! docs_src/guides/composing-a-domain/017.py !} ``` diff --git a/docs/guides/compose-a-domain/object-model.md b/docs/guides/compose-a-domain/object-model.md index 97b219fb..0e2c208d 100644 --- a/docs/guides/compose-a-domain/object-model.md +++ b/docs/guides/compose-a-domain/object-model.md @@ -17,7 +17,7 @@ Additional options can be passed to a domain element in two ways: You can specify options within a nested inner class called `Meta`: -```Python hl_lines="13-14" +```python hl_lines="13-14" {! docs_src/guides/composing-a-domain/020.py !} ``` @@ -25,6 +25,6 @@ You can specify options within a nested inner class called `Meta`: You can also pass options as parameters to the decorator: -```Python hl_lines="7" +```python hl_lines="7" {! docs_src/guides/composing-a-domain/021.py !} ``` diff --git a/docs/guides/compose-a-domain/register-elements.md b/docs/guides/compose-a-domain/register-elements.md index 0f8b8495..d52f577d 100644 --- a/docs/guides/compose-a-domain/register-elements.md +++ b/docs/guides/compose-a-domain/register-elements.md @@ -5,7 +5,7 @@ the domain. ## With decorators -```Python hl_lines="7-11" +```python hl_lines="7-11" {! docs_src/guides/composing-a-domain/002.py !} ``` @@ -16,7 +16,7 @@ A full list of domain decorators along with examples are available in the You can also choose to register elements manually. -```Python hl_lines="7-13" +```python hl_lines="7-13" {! docs_src/guides/composing-a-domain/014.py !} ``` @@ -31,7 +31,7 @@ of element in Protean has a distinct base class of its own. There might be additional options you will pass in a `Meta` inner class, depending upon the element being registered. -```Python hl_lines="12-13" +```python hl_lines="12-13" {! 
docs_src/guides/composing-a-domain/015.py !} ``` diff --git a/docs/guides/compose-a-domain/when-to-compose.md b/docs/guides/compose-a-domain/when-to-compose.md index 4a806031..509eed90 100644 --- a/docs/guides/compose-a-domain/when-to-compose.md +++ b/docs/guides/compose-a-domain/when-to-compose.md @@ -16,6 +16,6 @@ You would compose the domain along with the app object, and activate it (push up the context) before processing a request. -```Python hl_lines="29 33 35 38" +```python hl_lines="29 33 35 38" {! docs_src/guides/composing-a-domain/019.py !} ``` diff --git a/src/protean/cli/__init__.py b/src/protean/cli/__init__.py index d4602b08..47792ca0 100644 --- a/src/protean/cli/__init__.py +++ b/src/protean/cli/__init__.py @@ -15,6 +15,7 @@ Also see (1) from http://click.pocoo.org/5/setuptools/#setuptools-integration """ +import logging import subprocess from enum import Enum @@ -32,6 +33,8 @@ from protean.exceptions import NoDomainException from protean.utils.domain_discovery import derive_domain +logger = logging.getLogger(__name__) + # Create the Typer app # `no_args_is_help=True` will show the help message when no arguments are passed app = typer.Typer(no_args_is_help=True) @@ -125,20 +128,26 @@ def test( @app.command() def server( - domain: Annotated[str, typer.Option("--domain")] = ".", + domain: Annotated[str, typer.Option()] = ".", test_mode: Annotated[Optional[bool], typer.Option()] = False, + debug: Annotated[Optional[bool], typer.Option()] = False, ): """Run Async Background Server""" # FIXME Accept MAX_WORKERS as command-line input as well - from protean.server import Engine - - domain = derive_domain(domain) - if not domain: - raise NoDomainException( + try: + domain = derive_domain(domain) + except NoDomainException: + logger.error( "Could not locate a Protean domain. You should provide a domain in" '"PROTEAN_DOMAIN" environment variable or pass a domain file in options ' 'and a "domain.py" module was not found in the current directory.' ) + raise typer.Abort() - engine = Engine(domain, test_mode=test_mode) + from protean.server import Engine + + engine = Engine(domain, test_mode=test_mode, debug=debug) engine.run() + + if engine.exit_code != 0: + raise typer.Exit(code=engine.exit_code) diff --git a/src/protean/cli/generate.py b/src/protean/cli/generate.py index 33881a93..bf5b2600 100644 --- a/src/protean/cli/generate.py +++ b/src/protean/cli/generate.py @@ -1,3 +1,5 @@ +import logging + import typer from typing_extensions import Annotated @@ -5,6 +7,8 @@ from protean.exceptions import NoDomainException from protean.utils.domain_discovery import derive_domain +logger = logging.getLogger(__name__) + app = typer.Typer(no_args_is_help=True) @@ -28,13 +32,16 @@ def docker_compose( domain: Annotated[str, typer.Option()] = ".", ): """Generate a `docker-compose.yml` from Domain config""" - print(f"Generating docker-compose.yml for domain at {domain}") - domain_instance = derive_domain(domain) - if not domain_instance: - raise NoDomainException( + try: + domain_instance = derive_domain(domain) + except NoDomainException: + logger.error( "Could not locate a Protean domain. 
You should provide a domain in" '"PROTEAN_DOMAIN" environment variable or pass a domain file in options' ) + raise typer.Abort() + + print(f"Generating docker-compose.yml for domain at {domain}") with domain_instance.domain_context(): domain_instance.init() diff --git a/src/protean/cli/shell.py b/src/protean/cli/shell.py index ca253b39..22dfd63b 100644 --- a/src/protean/cli/shell.py +++ b/src/protean/cli/shell.py @@ -10,6 +10,7 @@ https://github.com/pallets/flask/blob/b90a4f1f4a370e92054b9cc9db0efcb864f87ebe/src/flask/cli.py#L984 """ +import logging import sys import typing @@ -21,17 +22,27 @@ from protean.exceptions import NoDomainException from protean.utils.domain_discovery import derive_domain +logger = logging.getLogger(__name__) -def shell(domain: Annotated[str, typer.Option()] = "."): - domain_instance = derive_domain(domain) - if not domain_instance: - raise NoDomainException( + +def shell( + domain: Annotated[str, typer.Option()] = ".", + traverse: Annotated[bool, typer.Option()] = False, +): + try: + domain_instance = derive_domain(domain) + except NoDomainException: + logger.error( "Could not locate a Protean domain. You should provide a domain in" '"PROTEAN_DOMAIN" environment variable or pass a domain file in options' ) + raise typer.Abort() + + if traverse: + print("Traversing directory to load all modules...") with domain_instance.domain_context(): - domain_instance.init() + domain_instance.init(traverse=traverse) ctx: dict[str, typing.Any] = {} ctx.update(domain_instance.make_shell_context()) diff --git a/src/protean/container.py b/src/protean/container.py index 71ebb7e8..fe1c025e 100644 --- a/src/protean/container.py +++ b/src/protean/container.py @@ -320,7 +320,7 @@ def __setattr__(self, name, value): if ( name in attributes(self) or name in fields(self) - or name in ["errors", "state_", "_temp_cache", "_events"] + or name in ["errors", "state_", "_temp_cache", "_events", "_initialized"] or name.startswith(("add_", "remove_", "_mark_changed_")) ): super().__setattr__(name, value) diff --git a/src/protean/core/entity.py b/src/protean/core/entity.py index 5324d46a..55c3ca6c 100644 --- a/src/protean/core/entity.py +++ b/src/protean/core/entity.py @@ -1,4 +1,5 @@ """Entity Functionality and Classes""" + import copy import logging @@ -202,22 +203,27 @@ def __init__(self, *template, **kwargs): # noqa: C901 (embedded_field.field_name, embedded_field.attribute_name) for embedded_field in field_obj.embedded_fields.values() ] - values = {name: kwargs.get(attr) for name, attr in attrs} - try: - # Pass the `required` option value as defined at the parent - value_object = field_obj.value_object_cls( - **values, required=field_obj.required - ) + kwargs_values = {name: kwargs.get(attr) for name, attr in attrs} + + # Check if any of the values in `values` are not None + # If all values are None, it means that the value object is not being set + # and we should set it to None + # + # If any of the values are not None, we should set the value object and its attributes + # to the values provided and let it trigger validations + if any(kwargs_values.values()): + try: + value_object = field_obj.value_object_cls(**kwargs_values) - # Set VO value only if the value object is not None/Empty - if value_object: - setattr(self, field_name, value_object) - loaded_fields.append(field_name) - except ValidationError as err: - for sub_field_name in err.messages: - self.errors["{}_{}".format(field_name, sub_field_name)].extend( - err.messages[sub_field_name] - ) + # Set VO value only if the value object 
is not None/Empty + if value_object: + setattr(self, field_name, value_object) + loaded_fields.append(field_name) + except ValidationError as err: + for sub_field_name in err.messages: + self.errors[ + "{}_{}".format(field_name, sub_field_name) + ].extend(err.messages[sub_field_name]) # Load Identities for field_name, field_obj in declared_fields(self).items(): diff --git a/src/protean/core/value_object.py b/src/protean/core/value_object.py index 69ebcd98..f58cb333 100644 --- a/src/protean/core/value_object.py +++ b/src/protean/core/value_object.py @@ -1,4 +1,5 @@ """Value Object Functionality and Classes""" + import logging from collections import defaultdict @@ -22,6 +23,8 @@ def __init_subclass__(subclass) -> None: super().__init_subclass__() subclass.__validate_for_basic_field_types() + subclass.__validate_for_non_identifier_fields() + subclass.__validate_for_non_unique_fields() @classmethod def __validate_for_basic_field_types(subclass): @@ -29,13 +32,37 @@ def __validate_for_basic_field_types(subclass): if isinstance(field_obj, (Reference, Association, ValueObject)): raise IncorrectUsageError( { - "_entity": [ - f"Views can only contain basic field types. " + "_value_object": [ + f"Value Objects can only contain basic field types. " f"Remove {field_name} ({field_obj.__class__.__name__}) from class {subclass.__name__}" ] } ) + @classmethod + def __validate_for_non_identifier_fields(subclass): + for field_name, field_obj in fields(subclass).items(): + if field_obj.identifier: + raise IncorrectUsageError( + { + "_value_object": [ + f"Value Objects cannot contain fields marked 'identifier' (field '{field_name}')" + ] + } + ) + + @classmethod + def __validate_for_non_unique_fields(subclass): + for field_name, field_obj in fields(subclass).items(): + if field_obj.unique: + raise IncorrectUsageError( + { + "_value_object": [ + f"Value Objects cannot contain fields marked 'unique' (field '{field_name}')" + ] + } + ) + def __init__(self, *template, **kwargs): # noqa: C901 """ Initialise the container. 
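To illustrate the effect of the two new checks (`__validate_for_non_identifier_fields` and `__validate_for_non_unique_fields`), a hypothetical test along these lines would see the error raised at class-definition time; the `Balance` value object is invented for the example:

```python
import pytest

from protean import BaseValueObject
from protean.exceptions import IncorrectUsageError
from protean.fields import Float, String


def test_value_object_rejects_identifier_fields():
    # __init_subclass__ fires as soon as the class body is evaluated, so the
    # invalid declaration itself raises IncorrectUsageError.
    with pytest.raises(IncorrectUsageError):

        class Balance(BaseValueObject):
            currency = String(max_length=3, identifier=True)
            amount = Float(required=True)
```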
@@ -56,7 +83,8 @@ def __init__(self, *template, **kwargs): # noqa: C901 self.errors = defaultdict(list) - required = kwargs.pop("required", False) + # Set the flag to prevent any further modifications + self._initialized = False # Load the attributes based on the template loaded_fields = [] @@ -73,17 +101,17 @@ def __init__(self, *template, **kwargs): # noqa: C901 # Now load against the keyword arguments for field_name, val in kwargs.items(): + # Record that a field was encountered by appending to `loaded_fields` + # When it fails validations, we want it's errors to be recorded + # + # Not remembering the field was recorded will result in it being set to `None` + # which will raise a ValidationError of its own for the wrong reasons (required field not set) + loaded_fields.append(field_name) try: setattr(self, field_name, val) except ValidationError as err: - # Ignore mandatory errors if VO is marked optional at the parent level - if "is required" in err.messages[field_name] and not required: - loaded_fields.append(field_name) - else: - for field_name in err.messages: - self.errors[field_name].extend(err.messages[field_name]) - else: - loaded_fields.append(field_name) + for field_name in err.messages: + self.errors[field_name].extend(err.messages[field_name]) # Now load the remaining fields with a None value, which will fail # for required fields @@ -103,6 +131,38 @@ def __init__(self, *template, **kwargs): # noqa: C901 logger.error(self.errors) raise ValidationError(self.errors) + # If we made it this far, the Value Object is initialized + # and should be marked as such + self._initialized = True + + def __setattr__(self, name, value): + if not hasattr(self, "_initialized") or not self._initialized: + return super().__setattr__(name, value) + else: + raise IncorrectUsageError( + { + "_value_object": [ + "Value Objects are immutable and cannot be modified once created" + ] + } + ) + + def _run_validators(self, value): + """Collect validators from enclosed fields and run them. + + This method is called during initialization of the Value Object + at the Entity level. + """ + errors = defaultdict(list) + for field_name, field_obj in fields(self).items(): + try: + field_obj._run_validators(getattr(self, field_name), value) + except ValidationError as err: + errors[field_name].extend(err.messages) + + if errors: + raise ValidationError(errors) + def value_object_factory(element_cls, **kwargs): element_cls = derive_element_class(element_cls, BaseValueObject, **kwargs) diff --git a/src/protean/domain/context.py b/src/protean/domain/context.py index a31c2bd3..70c22e2d 100644 --- a/src/protean/domain/context.py +++ b/src/protean/domain/context.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) -class _DomainContextGlobals(object): +class _DomainContextGlobals: """A plain object. Used as a namespace for storing data for the duration of a domain context. 
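The `_initialized` flag and `__setattr__` guard added above make value objects effectively read-only after construction. A sketch of the intended behaviour, reusing the `Balance` value object from the reflection test added later in this diff:

```python
import pytest

from protean import BaseValueObject
from protean.exceptions import IncorrectUsageError
from protean.fields import Float, String


class Balance(BaseValueObject):
    currency = String(max_length=3, required=True)
    amount = Float(required=True)


def test_value_objects_are_immutable_once_created():
    balance = Balance(currency="USD", amount=100.0)

    # _initialized is True once __init__ completes, so any further
    # attribute assignment is rejected.
    with pytest.raises(IncorrectUsageError):
        balance.amount = 200.0
```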
@@ -78,7 +78,7 @@ class DomainContext(object): def __init__(self, domain, **kwargs): self.domain = domain - self.g = domain.domain_context_globals_class() + self.g: _DomainContextGlobals = domain.domain_context_globals_class() # Set any additional kwargs as attributes in globals for kw in kwargs.items(): diff --git a/src/protean/fields/base.py b/src/protean/fields/base.py index f09d0fc7..55fe90a8 100644 --- a/src/protean/fields/base.py +++ b/src/protean/fields/base.py @@ -7,8 +7,7 @@ from typing import Any, Callable, Iterable, List from protean import exceptions - -from .mixins import FieldDescriptorMixin +from protean.fields.mixins import FieldDescriptorMixin MISSING_ERROR_MESSAGE = ( "ValidationError raised by `{class_name}`, but error key `{key}` does " @@ -91,9 +90,6 @@ def __init__( self._validators = validators - # Hold a reference to Entity registering the field - self._entity_cls = None - # Collect default error message from self and parent classes messages = {} for cls in reversed(self.__class__.__mro__): @@ -101,6 +97,26 @@ def __init__( messages.update(error_messages or {}) self.error_messages = messages + def _generic_param_values_for_repr(self): + """Return the generic parameter values for the Field's repr""" + values = [] + if self.required: + values.append("required=True") + if self.default is not None: + # If default is a callable, use its name + if callable(self.default): + values.append(f"default={self.default.__name__}") + else: + values.append(f"default='{self.default}'") + return values + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + + ", ".join(self._generic_param_values_for_repr()) + + ")" + ) + def __get__(self, instance, owner): if hasattr(instance, "__dict__"): return instance.__dict__.get(self.field_name) diff --git a/src/protean/fields/basic.py b/src/protean/fields/basic.py index a2bbe50e..5cbd408b 100644 --- a/src/protean/fields/basic.py +++ b/src/protean/fields/basic.py @@ -48,6 +48,18 @@ def as_dict(self, value): """Return JSON-compatible value of self""" return value + def __repr__(self): + # Generate repr values specific to this field + values = self._generic_param_values_for_repr() + if self.max_length != 255: + values.append(f"max_length={self.max_length}") + if self.min_length: + values.append(f"min_length={self.min_length}") + if not self.sanitize: + values.append("sanitize=False") + + return f"{self.__class__.__name__}(" + ", ".join(values) + ")" + class Text(Field): """Concrete field implementation for the text type.""" @@ -73,6 +85,14 @@ def as_dict(self, value): """Return JSON-compatible value of self""" return value + def __repr__(self): + # Generate repr values specific to this field + values = self._generic_param_values_for_repr() + if not self.sanitize: + values.append("sanitize=False") + + return f"{self.__class__.__name__}(" + ", ".join(values) + ")" + class Integer(Field): """Concrete field implementation for the Integer type. @@ -109,6 +129,16 @@ def as_dict(self, value): """Return JSON-compatible value of self""" return value + def __repr__(self): + # Generate repr values specific to this field + values = self._generic_param_values_for_repr() + if self.max_value: + values.append(f"max_value={self.max_length}") + if self.min_value: + values.append(f"min_value={self.min_value}") + + return f"{self.__class__.__name__}(" + ", ".join(values) + ")" + class Float(Field): """Concrete field implementation for the Floating type. 
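For reference, the `__repr__` overrides above (and the matching ones added below for `Float` and `Auto`) are meant to produce output like the following, copied from the `String`/`Text` repr tests added later in this diff:

```python
from protean.fields import String, Text

assert repr(String(max_length=50)) == "String(max_length=50)"
assert repr(Text(required=True, sanitize=False)) == "Text(required=True, sanitize=False)"
assert (
    repr(String(required=True, default="John Doe"))
    == "String(required=True, default='John Doe')"
)
```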
@@ -142,6 +172,16 @@ def as_dict(self, value): """Return JSON-compatible value of self""" return value + def __repr__(self): + # Generate repr values specific to this field + values = self._generic_param_values_for_repr() + if self.max_value: + values.append(f"max_value={self.max_length}") + if self.min_value: + values.append(f"min_value={self.min_value}") + + return f"{self.__class__.__name__}(" + ", ".join(values) + ")" + class Boolean(Field): """Concrete field implementation for the Boolean type.""" @@ -289,6 +329,14 @@ def as_dict(self, value): return value if isinstance(value, int) else str(value) + def __repr__(self): + # Generate repr values specific to this field + values = self._generic_param_values_for_repr() + if self.increment: + values.append("increment=True") + + return f"{self.__class__.__name__}(" + ", ".join(values) + ")" + class Identifier(Field): """Concrete field implementation for Identifiers. diff --git a/src/protean/port/event_store.py b/src/protean/port/event_store.py index c05c9974..0bc38392 100644 --- a/src/protean/port/event_store.py +++ b/src/protean/port/event_store.py @@ -33,6 +33,8 @@ def _write( ) -> int: """Write a message to the event store. + Returns the position of the message in the stream. + Implemented by the concrete event store adapter. """ diff --git a/src/protean/reflection.py b/src/protean/reflection.py index bc8133f2..609c68b6 100644 --- a/src/protean/reflection.py +++ b/src/protean/reflection.py @@ -1,5 +1,7 @@ from typing import Any +from protean.exceptions import IncorrectUsageError + _FIELDS = "__container_fields__" _ID_FIELD_NAME = "__container_id_field_name__" @@ -15,7 +17,9 @@ def fields(class_or_instance): try: fields_dict = getattr(class_or_instance, _FIELDS) except AttributeError: - raise TypeError("must be called with a dataclass type or instance") + raise IncorrectUsageError( + {"field": [f"{class_or_instance} does not have fields"]} + ) return fields_dict @@ -24,7 +28,9 @@ def id_field(class_or_instance): try: field_name = getattr(class_or_instance, _ID_FIELD_NAME) except AttributeError: - raise TypeError("must be called with a dataclass type or instance") + raise IncorrectUsageError( + {"identity": [f"{class_or_instance} does not have identity fields"]} + ) return fields(class_or_instance)[field_name] @@ -93,6 +99,8 @@ def declared_fields(class_or_instance): fields_dict = dict(getattr(class_or_instance, _FIELDS)) fields_dict.pop("_version", None) except AttributeError: - raise TypeError("must be called with a dataclass type or instance") + raise IncorrectUsageError( + {"field": [f"{class_or_instance} does not have fields"]} + ) return fields_dict diff --git a/src/protean/server/engine.py b/src/protean/server/engine.py index 92f91142..e36818ed 100644 --- a/src/protean/server/engine.py +++ b/src/protean/server/engine.py @@ -3,14 +3,13 @@ import asyncio import logging import signal +import traceback -from typing import Union +from typing import Type, Union from protean.core.command_handler import BaseCommandHandler from protean.core.event_handler import BaseEventHandler -from protean.exceptions import ConfigurationError from protean.globals import g -from protean.utils import import_from_full_path from protean.utils.mixins import Message from .subscription import Subscription @@ -25,15 +24,37 @@ class Engine: - def __init__(self, domain, test_mode: bool = False) -> None: + """ + The Engine class represents the Protean Engine that handles message processing and subscription management. 
+ """ + + def __init__(self, domain, test_mode: bool = False, debug: bool = False) -> None: + """ + Initialize the Engine. + + Modes: + - Test Mode: If set to True, the engine will run in test mode and will exit after all tasks are completed. + - Debug Mode: If set to True, the engine will run in debug mode and will log additional information. + + Args: + domain (Domain): The domain object associated with the engine. + test_mode (bool, optional): Flag to indicate if the engine is running in test mode. Defaults to False. + debug (bool, optional): Flag to indicate if debug mode is enabled. Defaults to False. + """ self.domain = domain - self.test_mode = test_mode + self.test_mode = ( + test_mode # Flag to indicate if the engine is running in test mode + ) + self.debug = debug # Flag to indicate if debug mode is enabled + self.exit_code = 0 + self.shutting_down = False # Flag to indicate the engine is shutting down self.loop = asyncio.get_event_loop() # FIXME Gather all handlers self._subscriptions = {} for handler_name, record in self.domain.registry.event_handlers.items(): + # Create a subscription for each event handler self._subscriptions[handler_name] = Subscription( self, handler_name, @@ -44,6 +65,7 @@ def __init__(self, domain, test_mode: bool = False) -> None: ) for handler_name, record in self.domain.registry.command_handlers.items(): + # Create a subscription for each command handler self._subscriptions[handler_name] = Subscription( self, handler_name, @@ -51,18 +73,28 @@ def __init__(self, domain, test_mode: bool = False) -> None: record.cls, ) - @classmethod - def from_domain_file(cls, domain: str, domain_file: str, **kwargs) -> Engine: - domain = import_from_full_path(domain=domain, path=domain_file) - return cls(domain=domain, **kwargs) - - def handle_results(self, results, message): - # FIXME Implement handling of results - pass - async def handle_message( - self, handler_cls: Union[BaseCommandHandler, BaseEventHandler], message: Message + self, + handler_cls: Type[Union[BaseCommandHandler, BaseEventHandler]], + message: Message, ) -> None: + """ + Handle a message by invoking the appropriate handler class. + + Args: + handler_cls (Type[Union[BaseCommandHandler, BaseEventHandler]]): The handler class to invoke. + message (Message): The message to be handled. + + Returns: + None + + Raises: + Exception: If an error occurs while handling the message. + + """ + if self.shutting_down: + return # Skip handling if shutdown is in progress + with self.domain.domain_context(): # Set context from current message, so that further processes # carry the metadata forward. @@ -74,34 +106,62 @@ async def handle_message( logger.info( f"{handler_cls.__name__} processed {message.type}-{message.id} successfully." 
) - except ConfigurationError as exc: - logger.error( - f"Error while handling message {message.stream_name}-{message.id} in {handler_cls.__name__} - {str(exc)}" - ) - raise - except Exception as exc: + except Exception as exc: # Includes handling `ConfigurationError` logger.error( - f"Error while handling message {message.stream_name}-{message.id} in {handler_cls.__name__} - {str(exc)}" + f"Error handling message {message.stream_name}-{message.id} " + f"in {handler_cls.__name__}" ) - # FIXME Implement mechanisms to track errors + logger.error(f"{str(exc)}") + handler_cls.handle_error(exc, message) + + await self.shutdown(exit_code=1) + return # Reset message context g.pop("message_in_context") - async def shutdown(self, signal=None): - """Cleanup tasks tied to the service's shutdown.""" - if signal: - logger.info(f"Received exit signal {signal.name}...") + async def shutdown(self, signal=None, exit_code=0): + """ + Cleanup tasks tied to the service's shutdown. + + Args: + signal (Optional[signal]): The exit signal received. Defaults to None. + exit_code (int): The exit code to be stored. Defaults to 0. + """ + self.shutting_down = True # Set shutdown flag + + try: + if signal: + logger.info(f"Received exit signal {signal.name}...") + + # Store the exit code + self.exit_code = exit_code - tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] - [task.cancel() for task in tasks] + # Shutdown subscriptions + subscription_shutdown_tasks = [ + subscription.shutdown() + for _, subscription in self._subscriptions.items() + ] - logger.info(f"Cancelling {len(tasks)} outstanding tasks") - await asyncio.gather(*tasks, return_exceptions=True) - self.loop.stop() + # Cancel outstanding tasks + [task.cancel() for task in tasks] + logger.info(f"Cancelling {len(tasks)} outstanding tasks") + await asyncio.gather(*tasks, return_exceptions=True) + + # Wait for subscriptions to shut down + await asyncio.gather(*subscription_shutdown_tasks, return_exceptions=True) + logger.info("All subscriptions have been shut down.") + finally: + if self.loop.is_running(): + self.loop.stop() def run(self): + """ + Start the Protean Engine and run the subscriptions. 
+ """ + logger.info("Starting Protean Engine...") # Handle Signals signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) for s in signals: @@ -114,13 +174,13 @@ def handle_exception(loop, context): # context["message"] will always be there; but context["exception"] may not msg = context.get("exception", context["message"]) - import traceback - + # Print the stack trace traceback.print_stack(context.get("exception")) logger.error(f"Caught exception: {msg}") logger.info("Shutting down...") - asyncio.create_task(self.shutdown(loop)) + if loop.is_running(): + asyncio.create_task(self.shutdown(exit_code=1)) self.loop.set_exception_handler(handle_exception) @@ -129,10 +189,18 @@ def handle_exception(loop, context): # Start consumption, one per subscription try: - for _, subscription in self._subscriptions.items(): + tasks = [ self.loop.create_task(subscription.start()) - - self.loop.run_forever() + for _, subscription in self._subscriptions.items() + ] + + if self.test_mode: + # If in test mode, run until all tasks complete + self.loop.run_until_complete(asyncio.gather(*tasks)) + # Then immediately call and await the shutdown directly + self.loop.run_until_complete(self.shutdown()) + else: + self.loop.run_forever() finally: self.loop.close() logger.debug("Successfully shutdown Protean Engine.") diff --git a/src/protean/server/subscription.py b/src/protean/server/subscription.py index d4b28e8b..76c633a2 100644 --- a/src/protean/server/subscription.py +++ b/src/protean/server/subscription.py @@ -5,7 +5,7 @@ from protean import BaseCommandHandler, BaseEventHandler from protean.port import BaseEventStore -from protean.utils.mixins import Message +from protean.utils.mixins import Message, MessageType logging.basicConfig( level=logging.INFO, @@ -17,7 +17,12 @@ class Subscription: - """Subscriber implementation.""" + """ + Represents a subscription to a stream in the Protean event-driven architecture. + + A subscription allows a subscriber to receive and process messages from a specific stream. + It provides methods to start and stop the subscription, as well as process messages in batches. + """ def __init__( self, @@ -30,6 +35,19 @@ def __init__( origin_stream_name: str | None = None, tick_interval: int = 1, ) -> None: + """ + Initialize the Subscription object. + + Args: + engine: The Protean engine instance. + subscriber_id (str): The unique identifier for the subscriber. + stream_name (str): The name of the stream to subscribe to. + handler (Union[BaseEventHandler, BaseCommandHandler]): The event or command handler. + messages_per_tick (int, optional): The number of messages to process per tick. Defaults to 10. + position_update_interval (int, optional): The interval at which to update the current position. Defaults to 10. + origin_stream_name (str | None, optional): The name of the origin stream to filter messages. Defaults to None. + tick_interval (int, optional): The interval between ticks. Defaults to 1. + """ self.engine = engine self.store: BaseEventStore = engine.domain.event_store.store @@ -43,40 +61,178 @@ def __init__( self.origin_stream_name = origin_stream_name self.tick_interval = tick_interval - self.subscriber_stream_name = f"subscriber_position-${subscriber_id}" + self.subscriber_stream_name = f"position-${subscriber_id}" self.current_position: int = -1 self.messages_since_last_position_write: int = 0 - self.keep_going: bool = not engine.test_mode + self.keep_going = True # Initially set to keep going + + async def start(self) -> None: + """ + Start the subscription. 
+ + This method initializes the subscription by loading the last position from the event store + and starting the polling loop. + + Returns: + None + """ + logger.debug(f"Starting {self.subscriber_id}") + + # Load own position from Event store + await self.load_position_on_start() + + # Start the polling loop + self.loop.create_task(self.poll()) + + async def poll(self) -> None: + """ + Polling loop for processing messages. + + This method continuously polls for new messages and processes them by calling the `tick` method. + It sleeps for the specified `tick_interval` between each tick. + + Returns: + None + """ + await self.tick() + + if self.keep_going and not self.engine.shutting_down: + await asyncio.sleep(self.tick_interval) + self.loop.create_task(self.poll()) + + async def tick(self): + """ + This method retrieves the next batch of messages to process and calls the `process_batch` method + to handle each message. It also updates the read position after processing each message. + + Returns: + None + """ + messages = await self.get_next_batch_of_messages() + if messages: + await self.process_batch(messages) + + async def shutdown(self): + """ + Shutdown the subscription. + + This method signals the subscription to stop polling and updates the current position to the store. + It also logs a message indicating the shutdown of the subscription. + + Returns: + None + """ + self.keep_going = False # Signal to stop polling + await self.update_current_position_to_store() + logger.info(f"Shutting down subscription {self.subscriber_id}") - async def load_position(self): + async def fetch_last_position(self): + """ + Fetch the last read position from the store. + + Returns: + int: The last read position from the store. + """ message = self.store._read_last_message(self.subscriber_stream_name) if message: - self.current_position = message["data"]["position"] + return message["data"]["position"] + + return -1 + + async def load_position_on_start(self) -> None: + """ + Load the last position from the store when starting. + + This method retrieves the last read position from the event store and updates the current position + of the subscription. If there is no previous position, it logs a message indicating that the + subscription will start at the beginning of the stream. + + Returns: + None + """ + last_position = await self.fetch_last_position() + if last_position > -1: + self.current_position = last_position logger.debug(f"Loaded position {self.current_position} from last message") else: - self.current_position = 0 - logger.debug("No previous messages - Starting at position 0") + logger.debug( + "No previous messages - Starting at the beginning of the stream" + ) + + async def update_current_position_to_store(self) -> int: + """Update the current position to the store, only if out of sync. + + This method updates the current position of the subscription to the event store, but only if the + current position is greater than the last written position. - async def update_read_position(self, position): + Returns: + int: The last written position. + """ + last_written_position = await self.fetch_last_position() + if last_written_position < self.current_position: + self.write_position(self.current_position) + + return last_written_position + + async def update_read_position(self, position) -> int: + """ + Update the current read position. + + If at or beyond the configured interval, write position to the store. + + Args: + position (int): The new read position. 
+ + Returns: + int: The updated read position. + """ self.current_position = position self.messages_since_last_position_write += 1 - if self.messages_since_last_position_write == self.position_update_interval: - return self.write_position(position) + if self.messages_since_last_position_write >= self.position_update_interval: + self.write_position(position) + + return self.current_position + + def write_position(self, position: int) -> int: + """ + Write the position to the store. + + This method writes the current read position to the event store. It updates the read position + of the subscriber and resets the counter for messages since the last position write. - return + Args: + position (int): The read position to be written. - def write_position(self, position): + Returns: + int: The position that was written. + """ logger.debug(f"Updating Read Position of {self.subscriber_id} to {position}") - self.messages_since_last_position_write = 0 + self.messages_since_last_position_write = 0 # Reset counter + return self.store._write( - self.subscriber_stream_name, "Read", {"position": position} + self.subscriber_stream_name, + "Read", + {"position": position}, + metadata={ + "kind": MessageType.READ_POSITION.value, + "origin_stream_name": self.stream_name, + }, ) def filter_on_origin(self, messages: List[Message]) -> List[Message]: + """ + Filter messages based on the origin stream name. + + Args: + messages (List[Message]): The list of messages to filter. + + Returns: + List[Message]: The filtered list of messages. + """ if not self.origin_stream_name: return messages @@ -94,6 +250,15 @@ def filter_on_origin(self, messages: List[Message]) -> List[Message]: return filtered_messages async def get_next_batch_of_messages(self): + """ + Get the next batch of messages to process. + + This method reads messages from the event store starting from the current position + 1. + It retrieves a specified number of messages per tick and applies filtering based on the origin stream name. + + Returns: + List[Message]: The next batch of messages to process. + """ messages = self.store.read( self.stream_name, position=self.current_position + 1, @@ -103,36 +268,23 @@ async def get_next_batch_of_messages(self): return self.filter_on_origin(messages) async def process_batch(self, messages): + """ + Process a batch of messages. + + This method takes a batch of messages and processes each message by calling the `handle_message` method + of the engine. It also updates the read position after processing each message. If an exception occurs + during message processing, it logs the error using the `log_error` method. + + Args: + messages (List[Message]): The batch of messages to process. + + Returns: + int: The number of messages processed. 
+ """ logging.debug(f"Processing {len(messages)} messages...") for message in messages: logging.info(f"{message.type}-{message.id} : {message.to_dict()}") - try: - await self.engine.handle_message(self.handler, message) - await self.update_read_position(message.global_position) - except Exception as exc: - self.log_error(message, exc) + await self.engine.handle_message(self.handler, message) + await self.update_read_position(message.global_position) return len(messages) - - def log_error(self, last_message, error): - logger.error(str(error)) - # FIXME Better Debug : print(f"{str(error) - {last_message}}") - - async def start(self): - logger.debug(f"Starting {self.subscriber_id}") - - # Load own position from Event store - await self.load_position() - self.loop.create_task(self.poll()) - - async def poll(self): - await self.tick() - - if self.keep_going: - await asyncio.sleep(self.tick_interval) - self.loop.create_task(self.poll()) - - async def tick(self): - messages = await self.get_next_batch_of_messages() - if messages: - return await self.process_batch(messages) diff --git a/src/protean/utils/mixins.py b/src/protean/utils/mixins.py index ee0e850c..662d0f32 100644 --- a/src/protean/utils/mixins.py +++ b/src/protean/utils/mixins.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +import logging from collections import defaultdict from enum import Enum @@ -19,15 +20,18 @@ from protean.reflection import has_id_field, id_field from protean.utils import fully_qualified_name +logger = logging.getLogger(__name__) + class MessageType(Enum): EVENT = "EVENT" COMMAND = "COMMAND" + READ_POSITION = "READ_POSITION" class MessageMetadata(BaseValueObject): # Marks message as a `COMMAND` or an `EVENT` - kind = fields.String(required=True, max_length=7, choices=MessageType) + kind = fields.String(required=True, max_length=15, choices=MessageType) # Name of service that owns the contract of the message owner = fields.String(max_length=50) @@ -41,6 +45,9 @@ class MessageMetadata(BaseValueObject): # Events raised subsequently by the commands also carry forward the original stream name. origin_stream_name = fields.String() + # FIXME Provide mechanism to add custom metadata fields/structure + # Can come handy in case of multi-tenancy, etc. + class MessageRecord(BaseContainer): """ @@ -252,3 +259,10 @@ def _handle(cls, message: Message) -> None: for handler_method in handlers: handler_method(cls(), message.to_object()) + + @classmethod + def handle_error(cls, exc: Exception, message: Message) -> None: + """Default error handler for messages. Can be overridden in subclasses. + + By default, this method logs the error and raises it. 
+ """ diff --git a/tests/cli/test_server.py b/tests/cli/test_server.py new file mode 100644 index 00000000..2a9d2c74 --- /dev/null +++ b/tests/cli/test_server.py @@ -0,0 +1,67 @@ +import os +import sys + +from pathlib import Path + +import pytest + +from typer.testing import CliRunner + +from protean.cli import app +from tests.shared import change_working_directory_to + +runner = CliRunner() + + +class TestServerCommand: + @pytest.fixture(autouse=True) + def reset_path(self): + """Reset sys.path after every test run""" + original_path = sys.path[:] + cwd = Path.cwd() + + yield + + sys.path[:] = original_path + os.chdir(cwd) + + def test_server_with_invalid_domain(self): + """Test that the server command fails when domain is not provided""" + args = ["server", "--domain", "foobar"] + result = runner.invoke(app, args) + assert result.exit_code != 0 + assert isinstance(result.exception, SystemExit) + assert "Aborted" in result.output + + def test_server_start_successfully(self): + change_working_directory_to("test7") + + args = ["shell", "--domain", "publishing7.py"] + + # Run the shell command + result = runner.invoke(app, args) + + # Assertions + assert result.exit_code == 0 + + def test_server_start_failure(self): + pass + + def test_that_server_processes_messages_on_start(self): + # Start in non-test mode + # Ensure messages are processed + # Manually shutdown with `asyncio.create_task(engine.shutdown())` + pass + + def test_debug_mode(self): + # Test debug mode is saved and correct logger level is set + pass + + def test_that_server_processes_messages_in_test_mode(self): + pass + + def test_that_server_handles_exceptions_elegantly(self): + pass + + def test_that_last_read_positions_are_saved(self): + pass diff --git a/tests/cli/test_shell.py b/tests/cli/test_shell.py index 901ea921..d7a52f1a 100644 --- a/tests/cli/test_shell.py +++ b/tests/cli/test_shell.py @@ -8,7 +8,6 @@ from typer.testing import CliRunner from protean.cli import app -from protean.exceptions import NoDomainException from tests.shared import change_working_directory_to runner = CliRunner() @@ -35,7 +34,6 @@ def test_shell_command_success(self): result = runner.invoke(app, args) # Assertions - print(result.output) assert result.exit_code == 0 def test_shell_command_with_no_explicit_domain_and_domain_py_file(self): @@ -47,7 +45,6 @@ def test_shell_command_with_no_explicit_domain_and_domain_py_file(self): result = runner.invoke(app, args) # Assertions - print(result.output) assert result.exit_code == 0 def test_shell_command_with_no_explicit_domain_and_subdomain_py_file(self): @@ -59,7 +56,6 @@ def test_shell_command_with_no_explicit_domain_and_subdomain_py_file(self): result = runner.invoke(app, args) # Assertions - print(result.output) assert result.exit_code == 0 def test_shell_command_with_domain_attribute_name_as_domain(self): @@ -71,7 +67,6 @@ def test_shell_command_with_domain_attribute_name_as_domain(self): result = runner.invoke(app, args) # Assertions - print(result.output) assert result.exit_code == 0 def test_shell_command_with_domain_attribute_name_as_subdomain(self): @@ -83,14 +78,23 @@ def test_shell_command_with_domain_attribute_name_as_subdomain(self): result = runner.invoke(app, args) # Assertions - print(result.output) assert result.exit_code == 0 def test_shell_command_raises_no_domain_exception_when_no_domain_is_found(self): - change_working_directory_to("test7") - args = ["shell", "--domain", "foobar"] # Run the shell command and expect it to raise an exception - with pytest.raises(NoDomainException): 
- runner.invoke(app, args, catch_exceptions=False) + result = runner.invoke(app, args, catch_exceptions=False) + assert result.exit_code == 1 + assert isinstance(result.exception, SystemExit) + assert "Aborted" in result.output + + def test_shell_command_with_traverse_option(self): + change_working_directory_to("test7") + + args = ["shell", "--domain", "publishing7.py", "--traverse"] + + # Run the shell command + result = runner.invoke(app, args) + + assert "Traversing directory to load all modules..." in result.stdout diff --git a/tests/context/tests.py b/tests/context/tests.py index 0e0889f4..bb732c11 100644 --- a/tests/context/tests.py +++ b/tests/context/tests.py @@ -130,11 +130,11 @@ def test_domain_context_globals_methods(self, test_domain, test_domain_context): assert repr(g) == "" def test_custom_domain_ctx_globals_class(self, test_domain): - class CustomRequestGlobals: + class CustomGlobals: def __init__(self): self.spam = "eggs" - test_domain.domain_context_globals_class = CustomRequestGlobals + test_domain.domain_context_globals_class = CustomGlobals with test_domain.domain_context(): assert g.spam == "eggs" @@ -149,5 +149,6 @@ def test_domain_context_kwargs(self, test_domain): def test_domain_context_globals_not_shared(self, test_domain): with test_domain.domain_context(foo="bar"): assert g.foo == "bar" - with test_domain.domain_context(foo="baz"): - assert g.foo == "baz" + + with test_domain.domain_context(foo="baz"): + assert g.foo == "baz" diff --git a/tests/field/test_datetime.py b/tests/field/test_datetime.py new file mode 100644 index 00000000..9d1d290c --- /dev/null +++ b/tests/field/test_datetime.py @@ -0,0 +1,19 @@ +from datetime import datetime, timezone + +from protean.fields import DateTime + + +def utc_now(): + return datetime.now(timezone.utc) + + +def test_datetime_repr_and_str(): + dt_obj1 = DateTime() + dt_obj2 = DateTime(required=True) + dt_obj3 = DateTime(default="2020-01-01T00:00:00") + dt_obj4 = DateTime(required=True, default=utc_now) + + assert repr(dt_obj1) == str(dt_obj1) == "DateTime()" + assert repr(dt_obj2) == str(dt_obj2) == "DateTime(required=True)" + assert repr(dt_obj3) == str(dt_obj3) == "DateTime(default='2020-01-01T00:00:00')" + assert repr(dt_obj4) == str(dt_obj4) == "DateTime(required=True, default=utc_now)" diff --git a/tests/field/test_field_types.py b/tests/field/test_field_types.py index 1bc1fa92..f050f4f0 100644 --- a/tests/field/test_field_types.py +++ b/tests/field/test_field_types.py @@ -162,15 +162,6 @@ def test_max_value(self): score = Float(max_value=5.5) score._load(5.6) - @pytest.mark.xfail - def test_none_value(self): - """Test None value treatment for the float field""" - - score = Float(max_value=5.5) - score._load(None) - - assert score.value == 0.0 - class TestBooleanField: """Test the Boolean Field Implementation""" diff --git a/tests/field/test_string.py b/tests/field/test_string.py index 38b9b5b0..53e29cbb 100644 --- a/tests/field/test_string.py +++ b/tests/field/test_string.py @@ -1,6 +1,34 @@ from protean.fields import String +def test_string_repr_and_str(): + str_obj1 = String(max_length=50) + str_obj2 = String(min_length=50) + str_obj3 = String(sanitize=False) + str_obj4 = String(max_length=50, min_length=50, sanitize=False) + str_obj5 = String(required=True, default="John Doe") + str_obj6 = String( + required=True, default="John Doe", min_length=50, max_length=50, sanitize=False + ) + + assert repr(str_obj1) == str(str_obj1) == "String(max_length=50)" + assert repr(str_obj2) == str(str_obj2) == 
"String(min_length=50)" + assert repr(str_obj3) == str(str_obj3) == "String(sanitize=False)" + assert ( + repr(str_obj4) + == str(str_obj4) + == "String(max_length=50, min_length=50, sanitize=False)" + ) + assert ( + repr(str_obj5) == str(str_obj5) == "String(required=True, default='John Doe')" + ) + assert ( + repr(str_obj6) + == str(str_obj6) + == "String(required=True, default='John Doe', max_length=50, min_length=50, sanitize=False)" + ) + + def test_sanitization_option_for_string_fields(): str_field1 = String() assert str_field1.sanitize is True @@ -13,7 +41,7 @@ def test_that_string_values_are_automatically_cleaned(): str_field = String() value = str_field._load("an example") - assert value == u"an <script>evil()</script> example" + assert value == "an <script>evil()</script> example" def test_that_sanitization_can_be_optionally_switched_off(): diff --git a/tests/field/test_text.py b/tests/field/test_text.py index a6d7e864..7a51314e 100644 --- a/tests/field/test_text.py +++ b/tests/field/test_text.py @@ -1,6 +1,24 @@ from protean.fields import Text +def test_text_repr_and_str(): + text_obj1 = Text(sanitize=False) + text_obj2 = Text(required=True, default="John Doe") + text_obj3 = Text(required=True, sanitize=False) + text_obj4 = Text(required=True, default="John Doe", sanitize=False) + + assert repr(text_obj1) == str(text_obj1) == "Text(sanitize=False)" + assert ( + repr(text_obj2) == str(text_obj2) == "Text(required=True, default='John Doe')" + ) + assert repr(text_obj3) == str(text_obj3) == "Text(required=True, sanitize=False)" + assert ( + repr(text_obj4) + == str(text_obj4) + == "Text(required=True, default='John Doe', sanitize=False)" + ) + + def test_sanitization_option_for_text_fields(): text_field1 = Text() assert text_field1.sanitize is True @@ -13,7 +31,7 @@ def test_that_text_values_are_automatically_cleaned(): text_field = Text() value = text_field._load("an example") - assert value == u"an <script>evil()</script> example" + assert value == "an <script>evil()</script> example" def test_that_sanitization_can_be_optionally_switched_off(): diff --git a/tests/field/tests.py b/tests/field/tests.py index c8164c51..a4b35952 100644 --- a/tests/field/tests.py +++ b/tests/field/tests.py @@ -116,3 +116,18 @@ def medium_string_validator(value): with pytest.raises(ValidationError): name = DummyStringField() name._load("Dummy Dummy Dummy") + + def test_repr(self): + """Test that Field repr is generated correctly""" + + name = DummyStringField() + assert repr(name) == "DummyStringField()" + + name = DummyStringField(required=True) + assert repr(name) == "DummyStringField(required=True)" + + name = DummyStringField(default="dummy") + assert repr(name) == "DummyStringField(default='dummy')" + + name = DummyStringField(required=True, default="dummy") + assert repr(name) == "DummyStringField(required=True, default='dummy')" diff --git a/tests/message/test_origin_stream_name_in_metadata.py b/tests/message/test_origin_stream_name_in_metadata.py index d1355a00..49ad6c1d 100644 --- a/tests/message/test_origin_stream_name_in_metadata.py +++ b/tests/message/test_origin_stream_name_in_metadata.py @@ -6,7 +6,7 @@ from protean.fields import String from protean.fields.basic import Identifier from protean.globals import g -from protean.utils.mixins import Message +from protean.utils.mixins import Message, MessageMetadata class User(BaseEventSourcedAggregate): @@ -74,7 +74,10 @@ def test_origin_stream_name_in_event_from_command_without_origin_stream_name(use def 
test_origin_stream_name_in_event_from_command_with_origin_stream_name(user_id): command_message = register_command_message(user_id) - command_message.metadata.origin_stream_name = "foo" + + command_message.metadata = MessageMetadata( + command_message.metadata.to_dict(), origin_stream_name="foo" + ) # MessageMetadata is a VO and immutable, so creating a copy with updated value g.message_in_context = command_message event_message = Message.to_message( @@ -113,7 +116,10 @@ def test_origin_stream_name_in_aggregate_event_from_command_with_origin_stream_n user_id, ): command_message = register_command_message(user_id) - command_message.metadata.origin_stream_name = "foo" + + command_message.metadata = MessageMetadata( + command_message.metadata.to_dict(), origin_stream_name="foo" + ) # MessageMetadata is a VO and immutable, so creating a copy with updated value g.message_in_context = command_message user = User( diff --git a/tests/reflection/test_id_field.py b/tests/reflection/test_id_field.py new file mode 100644 index 00000000..ee021c51 --- /dev/null +++ b/tests/reflection/test_id_field.py @@ -0,0 +1,26 @@ +import pytest + +from protean import BaseValueObject, Domain +from protean.exceptions import IncorrectUsageError +from protean.fields import Float, String +from protean.reflection import id_field + +domain = Domain(__name__) + + +class Balance(BaseValueObject): + currency = String(max_length=3, required=True) + amount = Float(required=True) + + +def test_value_objects_do_not_have_id_fields(): + with pytest.raises(IncorrectUsageError) as exception: + id_field(Balance) + + assert str(exception.value) == str( + { + "identity": [ + " does not have identity fields" + ] + } + ) diff --git a/tests/server/test_engine_initialization.py b/tests/server/test_engine_initialization.py index 8af56fc8..9eb22677 100644 --- a/tests/server/test_engine_initialization.py +++ b/tests/server/test_engine_initialization.py @@ -3,14 +3,6 @@ from protean import Engine -def test_that_domain_is_loaded_from_domain_file(): - engine = Engine.from_domain_file( - domain="baz", domain_file="tests/server/dummy_domain.py" - ) - assert engine.domain is not None - assert engine.domain.name == "FooBar" - - def test_that_engine_can_be_initialized_from_a_domain_object(test_domain): engine = Engine(test_domain) assert engine.domain == test_domain diff --git a/tests/server/test_engine_run.py b/tests/server/test_engine_run.py index e81cd14f..6c242c77 100644 --- a/tests/server/test_engine_run.py +++ b/tests/server/test_engine_run.py @@ -1,11 +1,103 @@ +import asyncio + +from uuid import uuid4 + import pytest -from protean import Engine +from protean import BaseEvent, BaseEventHandler, Engine, handle +from protean.fields import Identifier + +counter = 0 + + +def count_up(): + global counter + counter += 1 + + +class UserLoggedIn(BaseEvent): + user_id = Identifier(identifier=True) + + class Meta: + stream_name = "authentication" + +class UserEventHandler(BaseEventHandler): + @handle(UserLoggedIn) + def count_users(self, event: UserLoggedIn) -> None: + count_up() -@pytest.mark.skip(reason="Yet to implement") -def test_running_subscriptions_on_engine_start(): - engine = Engine.from_domain_file( - domain="baz", domain_file="tests/server/dummy_domain.py", test_mode=True - ) + class Meta: + stream_name = "authentication" + + +@pytest.fixture(autouse=True) +def auto_set_and_close_loop(): + # Create and set a new loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + yield + + # Close the loop after the test + if not 
loop.is_closed(): + loop.close() + asyncio.set_event_loop(None) # Explicitly unset the loop + + +def test_processing_messages_on_start(test_domain): + test_domain.register(UserLoggedIn) + test_domain.register(UserEventHandler) + + identifier = str(uuid4()) + event = UserLoggedIn(user_id=identifier) + test_domain.event_store.store.append(event) + + engine = Engine(domain=test_domain, test_mode=True) engine.run() + + global counter + assert counter == 1 + + +def test_that_read_position_is_updated_after_engine_run(test_domain): + test_domain.register(UserLoggedIn) + test_domain.register(UserEventHandler) + + identifier = str(uuid4()) + event = UserLoggedIn(user_id=identifier) + test_domain.event_store.store.append(event) + + messages = test_domain.event_store.store.read("authentication") + assert len(messages) == 1 + + engine = Engine(domain=test_domain, test_mode=True) + engine.run() + + messages = test_domain.event_store.store.read("$all") + assert len(messages) == 2 + + +def test_processing_messages_from_beginning_the_first_time(test_domain): + test_domain.register(UserLoggedIn) + test_domain.register(UserEventHandler) + + identifier = str(uuid4()) + event = UserLoggedIn(user_id=identifier) + test_domain.event_store.store.append(event) + + engine = Engine(domain=test_domain, test_mode=True) + engine.run() + + messages = test_domain.event_store.store.read("$all") + assert len(messages) == 2 + + # Create and set a new loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + engine = Engine(domain=test_domain, test_mode=True) + engine.run() + + messages = test_domain.event_store.store.read("$all") + assert len(messages) == 2 diff --git a/tests/server/test_error_handling.py b/tests/server/test_error_handling.py new file mode 100644 index 00000000..bb0688be --- /dev/null +++ b/tests/server/test_error_handling.py @@ -0,0 +1,101 @@ +import asyncio + +from uuid import uuid4 + +import pytest + +from protean import BaseEvent, BaseEventHandler, BaseEventSourcedAggregate, handle +from protean.fields import Identifier, String +from protean.server import Engine +from protean.utils.mixins import Message + + +class Registered(BaseEvent): + id = Identifier() + email = String() + name = String() + password_hash = String() + + +class User(BaseEventSourcedAggregate): + email = String() + name = String() + password_hash = String() + + +def some_function(): + raise Exception("Some exception") + + +class UserEventHandler(BaseEventHandler): + @handle(Registered) + def send_notification(self, event: Registered) -> None: + some_function() + + +@pytest.fixture(autouse=True) +def auto_set_and_close_loop(): + # Create and set a new loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + yield + + # Close the loop after the test + if not loop.is_closed(): + loop.close() + asyncio.set_event_loop(None) # Explicitly unset the loop + + +@pytest.mark.asyncio +async def test_that_exception_is_raised(test_domain): + test_domain.register(User) + test_domain.register(Registered) + test_domain.register(UserEventHandler, aggregate_cls=User) + + identifier = str(uuid4()) + user = User( + id=identifier, + email="john.doe@example.com", + name="John Doe", + password_hash="hash", + ) + event = Registered( + id=identifier, + email="john.doe@example.com", + name="John Doe", + password_hash="hash", + ) + message = Message.to_aggregate_event_message(user, event) + + engine = Engine(domain=test_domain, test_mode=True) + + await engine.handle_message(UserEventHandler, message) + + assert engine.exit_code 
== 1 + + +def test_exceptions_stop_processing(test_domain): + test_domain.register(User) + test_domain.register(Registered) + test_domain.register(UserEventHandler, aggregate_cls=User) + + identifier = str(uuid4()) + user = User( + id=identifier, + email="john.doe@example.com", + name="John Doe", + password_hash="hash", + ) + event = Registered( + id=identifier, + email="john.doe@example.com", + name="John Doe", + password_hash="hash", + ) + test_domain.event_store.store.append_aggregate_event(user, event) + + engine = Engine(domain=test_domain) + engine.run() + + assert engine.exit_code == 1 diff --git a/tests/server/test_event_handler_subscription.py b/tests/server/test_event_handler_subscription.py index 075ba075..38f0815d 100644 --- a/tests/server/test_event_handler_subscription.py +++ b/tests/server/test_event_handler_subscription.py @@ -1,5 +1,9 @@ from __future__ import annotations +import asyncio + +import pytest + from protean import BaseEvent, BaseEventHandler, BaseEventSourcedAggregate, handle from protean.fields import DateTime, Identifier, String from protean.server import Engine @@ -53,6 +57,20 @@ def record_sent_email(self, event: Sent) -> None: pass +@pytest.fixture(autouse=True) +def setup_event_loop(): + """Ensure an Event Loop Exists in Tests. + + Otherwise tests are attempting to access the asyncio event loop from a non-async context + where no event loop is running or set as the current event loop. + """ + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + yield + loop.close() + asyncio.set_event_loop(None) + + def test_event_subscriptions(test_domain): test_domain.register(UserEventHandler, aggregate_cls=User) engine = Engine(test_domain, test_mode=True) diff --git a/tests/subscription/test_message_filtering_with_origin_stream.py b/tests/subscription/test_message_filtering_with_origin_stream.py index 6f334ad5..de2e6ec3 100644 --- a/tests/subscription/test_message_filtering_with_origin_stream.py +++ b/tests/subscription/test_message_filtering_with_origin_stream.py @@ -10,7 +10,7 @@ from protean.fields import DateTime, Identifier, String from protean.server import Engine from protean.utils import fqn -from protean.utils.mixins import Message +from protean.utils.mixins import Message, MessageMetadata class User(BaseEventSourcedAggregate): @@ -91,7 +91,10 @@ async def test_message_filtering_for_event_handlers_with_defined_origin_stream( email, Sent(email="john.doe@gmail.com", sent_at=datetime.now(UTC)) ), ] - messages[2].metadata.origin_stream_name = f"user-{identifier}" + + messages[2].metadata = MessageMetadata( + messages[2].metadata.to_dict(), origin_stream_name=f"user-{identifier}" + ) # MessageMetadata is a VO and immutable, so creating a copy with updated value # Mock `read` method and have it return the 3 messages mock_store_read = mock.Mock() diff --git a/tests/subscription/test_no_message_filtering.py b/tests/subscription/test_no_message_filtering.py index 10b59c38..add7f095 100644 --- a/tests/subscription/test_no_message_filtering.py +++ b/tests/subscription/test_no_message_filtering.py @@ -10,7 +10,7 @@ from protean.fields import DateTime, Identifier, String from protean.server import Engine from protean.utils import fqn -from protean.utils.mixins import Message +from protean.utils.mixins import Message, MessageMetadata class User(BaseEventSourcedAggregate): @@ -91,7 +91,10 @@ async def test_no_filtering_for_event_handlers_without_defined_origin_stream( email, Sent(email="john.doe@gmail.com", sent_at=datetime.now(UTC)) ), ] - 
messages[2].metadata.origin_stream_name = f"user-{identifier}" + + messages[2].metadata = MessageMetadata( + messages[2].metadata.to_dict(), origin_stream_name=f"user-{identifier}" + ) # MessageMetadata is a VO and immutable, so creating a copy with updated value # Mock `read` method and have it return the 3 messages mock_store_read = mock.Mock() diff --git a/tests/subscription/test_read_position_updates.py b/tests/subscription/test_read_position_updates.py new file mode 100644 index 00000000..a91c7427 --- /dev/null +++ b/tests/subscription/test_read_position_updates.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from uuid import uuid4 + +import pytest + +from protean import BaseEvent, BaseEventHandler, BaseEventSourcedAggregate, handle +from protean.fields import DateTime, Identifier, String +from protean.server import Engine +from protean.utils import fqn + + +class User(BaseEventSourcedAggregate): + email = String() + name = String() + password_hash = String() + + +class Email(BaseEventSourcedAggregate): + email = String() + sent_at = DateTime() + + +def dummy(*args): + pass + + +class Registered(BaseEvent): + id = Identifier() + email = String() + name = String() + password_hash = String() + + +class Activated(BaseEvent): + id = Identifier() + activated_at = DateTime() + + +class Sent(BaseEvent): + email = String() + sent_at = DateTime() + + +class UserEventHandler(BaseEventHandler): + @handle(Registered) + def send_activation_email(self, event: Registered) -> None: + dummy(event) + + @handle(Activated) + def provision_user(self, event: Activated) -> None: + dummy(event) + + @handle(Activated) + def send_welcome_email(self, event: Activated) -> None: + dummy(event) + + +class EmailEventHandler(BaseEventHandler): + @handle(Sent) + def record_sent_email(self, event: Sent) -> None: + pass + + +@pytest.fixture(autouse=True) +def register_elements(test_domain): + test_domain.register(User) + test_domain.register(Email) + test_domain.register(Registered) + test_domain.register(Activated) + test_domain.register(Sent) + test_domain.register(UserEventHandler, aggregate_cls=User) + test_domain.register(EmailEventHandler, stream_name="email") + + +@pytest.mark.asyncio +async def test_initial_read_position(test_domain): + engine = Engine(test_domain, test_mode=True) + email_event_handler_subscription = engine._subscriptions[fqn(EmailEventHandler)] + + assert email_event_handler_subscription.current_position == -1 + + last_written_position = await email_event_handler_subscription.fetch_last_position() + assert last_written_position == -1 + + +@pytest.mark.asyncio +async def test_write_position_after_interval(test_domain): + engine = Engine(test_domain, test_mode=True) + email_event_handler_subscription = engine._subscriptions[fqn(EmailEventHandler)] + + await email_event_handler_subscription.load_position_on_start() + await email_event_handler_subscription.update_current_position_to_store() + + email_address = "john.doe@gmail.com" + sent_at = datetime.now(UTC) + email = Email(id=str(uuid4()), email=email_address, sent_at=sent_at) + event = Sent(email=email_address, sent_at=sent_at) + + # ASSERT Initial state + last_written_position = await email_event_handler_subscription.fetch_last_position() + assert last_written_position == -1 # Default value + + test_domain.event_store.store.append_aggregate_event(email, event) + + await email_event_handler_subscription.tick() + + # ASSERT Positions after reading 1 message + last_written_position = await 
email_event_handler_subscription.fetch_last_position()
+    assert email_event_handler_subscription.current_position == 1
+    assert last_written_position == -1  # Remains -1 because interval is not reached
+
+    # Populate 15 messages (5 more than default interval)
+    for _ in range(15):
+        test_domain.event_store.store.append_aggregate_event(email, event)
+
+    await email_event_handler_subscription.tick()
+    last_written_position = await email_event_handler_subscription.fetch_last_position()
+
+    # ASSERT Positions after reading 10 messages
+    # Current position should be 12 because, even though we read 10 messages,
+    # there is a position update message in the middle
+    assert email_event_handler_subscription.current_position == 12
+    assert last_written_position == 11  # We just completed reading 10 messages
+
+    # ASSERT Positions after reading to end of messages
+    await email_event_handler_subscription.tick()
+    last_written_position = await email_event_handler_subscription.fetch_last_position()
+    assert (
+        email_event_handler_subscription.current_position == 16
+    )  # Continued reading until end
+    assert last_written_position == 11  # Remains 11 because interval is not reached
+
+
+@pytest.mark.asyncio
+async def test_that_positions_are_not_written_when_already_in_sync(test_domain):
+    engine = Engine(test_domain, test_mode=True)
+    email_event_handler_subscription = engine._subscriptions[fqn(EmailEventHandler)]
+
+    await email_event_handler_subscription.load_position_on_start()
+
+    email_address = "john.doe@gmail.com"
+    sent_at = datetime.now(UTC)
+    email = Email(id=str(uuid4()), email=email_address, sent_at=sent_at)
+    event = Sent(email=email_address, sent_at=sent_at)
+
+    # Populate 15 messages (5 more than default interval)
+    for _ in range(15):
+        test_domain.event_store.store.append_aggregate_event(email, event)
+
+    # Consume messages (By default, 10 messages per tick)
+    await email_event_handler_subscription.tick()
+
+    # Fetch the current event store state
+    # total_no_of_messages should be 16, including the position update message
+    total_no_of_messages = len(test_domain.event_store.store.read("$all"))
+    assert total_no_of_messages == 16
+
+    # Simulating server shutdown
+    # Try to manually update the position to the store
+    await email_event_handler_subscription.update_current_position_to_store()
+
+    # Ensure that the event store state did not change
+    # This means that we did not add duplicate position update messages
+    assert len(test_domain.event_store.store.read("$all")) == total_no_of_messages
+    # Ensure last read message remains at 10
+    assert await email_event_handler_subscription.fetch_last_position() == 10
diff --git a/tests/value_object/test_immutability.py b/tests/value_object/test_immutability.py
new file mode 100644
index 00000000..20702dce
--- /dev/null
+++ b/tests/value_object/test_immutability.py
@@ -0,0 +1,57 @@
+import pytest
+
+from protean import Domain
+from protean.exceptions import IncorrectUsageError
+from protean.fields import Float, String, ValueObject
+
+domain = Domain(__name__)
+
+
+@domain.value_object
+class Balance:
+    currency = String(max_length=3, required=True)
+    amount = Float(required=True)
+
+
+@domain.aggregate
+class Account:
+    balance = ValueObject(Balance)
+    name = String(max_length=30)
+
+
+def test_value_objects_are_immutable():
+    balance = Balance(currency="USD", amount=100.0)
+
+    with pytest.raises(IncorrectUsageError) as exception:
+        balance.currency = "INR"
+
+    assert str(exception.value) == str(
+        {
+            "_value_object": [
+                "Value Objects are immutable and 
cannot be modified once created" + ] + } + ) + + +def test_value_objects_can_be_switched(): + balance = Balance(currency="USD", amount=100.0) + account = Account(balance=balance, name="John Doe") + + assert account.balance.currency == "USD" + assert account.balance_currency == "USD" + + account.balance = Balance(currency="INR", amount=100.0) + assert account.balance.currency == "INR" + assert account.balance_currency == "INR" + + +def test_that_updating_attributes_linked_to_value_objects_has_no_impact(): + balance = Balance(currency="USD", amount=100.0) + account = Account(balance=balance, name="John Doe") + + # This is a dummy attribute that is linked to the value object + # Updating this attribute should not impact the value object + account.balance_currency = "INR" + + assert account.balance.currency == "USD" diff --git a/tests/value_object/test_vo_custom_validators.py b/tests/value_object/test_vo_custom_validators.py new file mode 100644 index 00000000..00f759ba --- /dev/null +++ b/tests/value_object/test_vo_custom_validators.py @@ -0,0 +1,89 @@ +# FIXME Use the file at docs_src/guides/domain-definition/009.py +import pytest + +from protean import BaseAggregate, BaseValueObject +from protean.exceptions import ValidationError +from protean.fields import String, ValueObject + + +class EmailValidator: + def __init__(self): + self.error = f"Invalid email address" + + def __call__(self, value): + """Business rules of Email address""" + if ( + # should contain one "@" symbol + value.count("@") != 1 + # should not start with "@" or "." + or value.startswith("@") + or value.startswith(".") + # should not end with "@" or "." + or value.endswith("@") + or value.endswith(".") + # should not contain consecutive dots + or value in ["..", ".@", "@."] + # local part should not be more than 64 characters + or len(value.split("@")[0]) > 64 + # Each label can be up to 63 characters long. + or any(len(label) > 63 for label in value.split("@")[1].split(".")) + # Labels must start and end with a letter (a-z, A-Z) or a digit (0-9), and can contain hyphens (-), + # but cannot start or end with a hyphen. + or not all( + label[0].isalnum() + and label[-1].isalnum() + and all(c.isalnum() or c == "-" for c in label) + for label in value.split("@")[1].split(".") + ) + # No spaces or unprintable characters are allowed. 
+ or not all(c.isprintable() and not c.isspace() for c in value) + ): + raise ValidationError(self.error) + + +class Email(BaseValueObject): + """An email address value object, with two identified parts: + * local_part + * domain_part + """ + + # This is the external facing data attribute + address = String(max_length=254, required=True, validators=[EmailValidator()]) + + +class User(BaseAggregate): + email = ValueObject(Email) + name = String(max_length=30) + timezone = String(max_length=30) + + +def test_vo_with_correct_email_address(): + email = Email(address="john.doe@gmail.com") + assert email.address == "john.doe@gmail.com" + + +def test_vo_with_incorrect_email_address(): + with pytest.raises(ValidationError) as exc: + Email(address="john.doegmail.com") + + assert str(exc.value) == str({"address": ["Invalid email address"]}) + + +def test_embedded_vo_with_correct_email_address(): + user = User( + email_address="john.doe@gmail.com", + name="John Doe", + timezone="America/Los_Angeles", + ) + assert user.email.address == "john.doe@gmail.com" + + +def test_embedded_vo_with_incorrect_email_address(): + with pytest.raises(ValidationError) as exc: + User( + email_address="john.doegmail.com", + name="John Doe", + timezone="America/Los_Angeles", + ) + + assert str(exc.value) == str({"email_address": ["Invalid email address"]}) diff --git a/tests/value_object/test_vo_field_properties.py b/tests/value_object/test_vo_field_properties.py new file mode 100644 index 00000000..91e133c9 --- /dev/null +++ b/tests/value_object/test_vo_field_properties.py @@ -0,0 +1,37 @@ +import pytest + +from protean import BaseValueObject +from protean.exceptions import IncorrectUsageError +from protean.fields import Float, String + + +def test_vo_cannot_contain_fields_marked_unique(): + with pytest.raises(IncorrectUsageError) as exception: + + class Balance(BaseValueObject): + currency = String(max_length=3, required=True, unique=True) + amount = Float(required=True) + + assert str(exception.value) == str( + { + "_value_object": [ + "Value Objects cannot contain fields marked 'unique' (field 'currency')" + ] + } + ) + + +def test_vo_cannot_contain_fields_marked_as_identifiers(): + with pytest.raises(IncorrectUsageError) as exception: + + class Balance(BaseValueObject): + currency = String(max_length=3, required=True, identifier=True) + amount = Float(required=True) + + assert str(exception.value) == str( + { + "_value_object": [ + "Value Objects cannot contain fields marked 'identifier' (field 'currency')" + ] + } + ) diff --git a/tests/value_object/tests.py b/tests/value_object/tests.py index a64ddb27..dfc88ee1 100644 --- a/tests/value_object/tests.py +++ b/tests/value_object/tests.py @@ -1,6 +1,6 @@ import pytest -from protean.exceptions import InvalidOperationError, ValidationError +from protean.exceptions import IncorrectUsageError, ValidationError from protean.reflection import attributes, declared_fields from .elements import ( @@ -66,10 +66,9 @@ def test_str_output_of_value_object(self): email = Email.from_address("john.doe@gmail.com") assert str(email) == "Email object ({'address': 'john.doe@gmail.com'})" - @pytest.mark.xfail def test_that_value_objects_are_immutable(self): email = Email.from_address(address="john.doe@gmail.com") - with pytest.raises(InvalidOperationError): + with pytest.raises(IncorrectUsageError): email.address = "jane.doe@gmail.com" @@ -192,14 +191,12 @@ def test_that_mandatory_fields_are_validated(self): with pytest.raises(ValidationError) as multi_exceptions: User() - assert 
"email_address" in multi_exceptions.value.messages - assert multi_exceptions.value.messages["email_address"] == ["is required"] + assert multi_exceptions.value.messages["email"] == ["is required"] with pytest.raises(ValidationError) as email_exception: User(name="John Doe") - assert "email_address" in email_exception.value.messages - assert email_exception.value.messages["email_address"] == ["is required"] + assert email_exception.value.messages["email"] == ["is required"] class TestBalanceVOEmbedding: