diff --git a/pyproject.toml b/pyproject.toml index ad6e56fe..baf8c099 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,6 +155,7 @@ markers = [ "message_db", "sendgrid", "database", + "broker_common", "eventstore", "no_test_domain", ] diff --git a/src/protean/adapters/broker/inline.py b/src/protean/adapters/broker/inline.py index a3368e3b..b2034c7e 100644 --- a/src/protean/adapters/broker/inline.py +++ b/src/protean/adapters/broker/inline.py @@ -29,6 +29,15 @@ def _get_next(self, channel: str) -> dict | None: return self._messages[channel].pop(0) return None + def read(self, channel: str, no_of_messages: int) -> list[dict]: + """Read messages from the broker""" + messages = [] + while no_of_messages > 0 and self._messages[channel]: + messages.append(self._messages[channel].pop(0)) + no_of_messages -= 1 + + return messages + def _data_reset(self) -> None: """Flush all data in broker instance""" self._messages.clear() diff --git a/src/protean/adapters/broker/redis.py b/src/protean/adapters/broker/redis.py index f4165be8..b3f11f7a 100644 --- a/src/protean/adapters/broker/redis.py +++ b/src/protean/adapters/broker/redis.py @@ -31,5 +31,14 @@ def _get_next(self, channel: str) -> dict | None: return None + def read(self, channel: str, no_of_messages: int) -> list[dict]: + messages = [] + for _ in range(no_of_messages): + bytes_message = self.redis_instance.lpop(channel) + if bytes_message: + messages.append(json.loads(bytes_message)) + + return messages + def _data_reset(self) -> None: self.redis_instance.flushall() diff --git a/src/protean/port/broker.py b/src/protean/port/broker.py index c5b4b5a8..e6c1c408 100644 --- a/src/protean/port/broker.py +++ b/src/protean/port/broker.py @@ -72,6 +72,18 @@ def get_next(self, channel: str) -> dict | None: def _get_next(self, channel: str) -> dict | None: """Overridden method to retrieve the next message to process from broker.""" + @abstractmethod + def read(self, channel: str, no_of_messages: int) -> list[dict]: + """Read 
messages from the broker. + + Args: + channel (str): The channel from which to read messages + no_of_messages (int): The number of messages to read + + Returns: + list[dict]: The list of messages + """ + @abstractmethod def _data_reset(self) -> None: """Flush all data in broker instance. diff --git a/src/protean/server/broker_subscription.py b/src/protean/server/broker_subscription.py new file mode 100644 index 00000000..6393c0e4 --- /dev/null +++ b/src/protean/server/broker_subscription.py @@ -0,0 +1,154 @@ +import asyncio +import logging +from typing import Type + +from protean.core.subscriber import BaseSubscriber +from protean.port.broker import BaseBroker + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s,%(msecs)d %(levelname)s: %(message)s", + datefmt="%H:%M:%S", +) + +logger = logging.getLogger(__name__) + + +class BrokerSubscription: + """ + Represents a subscription to a broker channel. + + A broker subscription allows a subscriber to receive and process messages from a specific channel. + It provides methods to start and stop the subscription, as well as process messages in batches. + """ + + def __init__( + self, + engine, + broker, + subscriber_id: str, + channel: str, + handler: Type[BaseSubscriber], + messages_per_tick: int = 10, + tick_interval: int = 1, + ) -> None: + """ + Initialize the Subscription object. + + Args: + engine: The Protean engine instance. + subscriber_id (str): The unique identifier for the subscriber. + channel (str): The name of the stream to subscribe to. + handler (Union[BaseEventHandler, BaseCommandHandler]): The event or command handler. + messages_per_tick (int, optional): The number of messages to process per tick. Defaults to 10. + tick_interval (int, optional): The interval between ticks. Defaults to 1. 
+ """ + self.engine = engine + self.broker: BaseBroker = broker + self.loop = engine.loop + + self.subscriber_id = subscriber_id + self.channel = channel + self.handler = handler + self.messages_per_tick = messages_per_tick + self.tick_interval = tick_interval + + self.keep_going = True # Initially set to keep going + + async def start(self) -> None: + """ + Start the subscription. + + This method initializes the subscription by loading the last position from the event store + and starting the polling loop. + + Returns: + None + """ + logger.debug(f"Starting {self.subscriber_id}") + + # Start the polling loop + self.loop.create_task(self.poll()) + + async def poll(self) -> None: + """ + Polling loop for processing messages. + + This method continuously polls for new messages and processes them by calling the `tick` method. + It sleeps for the specified `tick_interval` between each tick. + + Returns: + None + """ + await self.tick() + + if self.keep_going and not self.engine.shutting_down: + # Keep control of the loop if in test mode + # Otherwise `asyncio.sleep` will give away control and + # the loop will be able to be stopped with `shutdown()` + if not self.engine.test_mode: + await asyncio.sleep(self.tick_interval) + + self.loop.create_task(self.poll()) + + async def tick(self): + """ + This method retrieves the next batch of messages to process and calls the `process_batch` method + to handle each message. It also updates the read position after processing each message. + + Returns: + None + """ + messages = await self.get_next_batch_of_messages() + if messages: + await self.process_batch(messages) + + async def shutdown(self): + """ + Shutdown the subscription. + + This method signals the subscription to stop polling and updates the current position to the store. + It also logs a message indicating the shutdown of the subscription. 
+ + Returns: + None + """ + self.keep_going = False # Signal to stop polling + logger.info(f"Shutting down subscription {self.subscriber_id}") + + async def get_next_batch_of_messages(self): + """ + Get the next batch of messages to process. + + This method reads up to `messages_per_tick` messages from the subscribed broker channel. + Filtering of messages is not yet implemented (see FIXME below). + + Returns: + list[dict]: The next batch of messages to process. + """ + messages = self.broker.read( + self.channel, + no_of_messages=self.messages_per_tick, + ) # FIXME Implement filtering + + return messages + + async def process_batch(self, messages: list[dict]): + """ + Process a batch of messages. + + This method takes a batch of messages and processes each message by calling the `handle_broker_message` method + of the engine. If an exception occurs during message processing, + it logs the error and the engine initiates a shutdown. + + Args: + messages (list[dict]): The batch of messages to process. + + Returns: + int: The number of messages processed. 
+ """ + logger.debug(f"Processing {len(messages)} messages...") + for message in messages: + await self.engine.handle_broker_message(self.handler, message) + + return len(messages) diff --git a/src/protean/server/engine.py b/src/protean/server/engine.py index 90244003..85c3f611 100644 --- a/src/protean/server/engine.py +++ b/src/protean/server/engine.py @@ -8,9 +8,11 @@ from protean.core.command_handler import BaseCommandHandler from protean.core.event_handler import BaseEventHandler +from protean.core.subscriber import BaseSubscriber from protean.utils.globals import g from protean.utils.mixins import Message +from .broker_subscription import BrokerSubscription from .subscription import Subscription logging.basicConfig( @@ -50,7 +52,7 @@ def __init__(self, domain, test_mode: bool = False, debug: bool = False) -> None self.loop = asyncio.get_event_loop() - # FIXME Gather all handlers + # Gather all handlers self._subscriptions = {} for handler_name, record in self.domain.registry.event_handlers.items(): # Create a subscription for each event handler @@ -72,6 +74,54 @@ def __init__(self, domain, test_mode: bool = False, debug: bool = False) -> None record.cls, ) + # Gather broker subscriptions + self._broker_subscriptions = {} + + for ( + subscriber_name, + subscriber_record, + ) in self.domain.registry.subscribers.items(): + subscriber_cls = subscriber_record.cls + broker_name = subscriber_cls.meta_.broker + broker = self.domain.brokers[broker_name] + channel = subscriber_cls.meta_.channel + self._broker_subscriptions[subscriber_name] = BrokerSubscription( + self, + broker, + subscriber_name, + channel, + subscriber_cls, + ) + + async def handle_broker_message( + self, subscriber_cls: Type[BaseSubscriber], message: dict + ) -> None: + """ + Handle a message received from the broker. 
+ """ + + if self.shutting_down: + return # Skip handling if shutdown is in progress + + with self.domain.domain_context(): + try: + subscriber = subscriber_cls() + subscriber(message) + + logger.info( + f"{subscriber_cls.__name__} processed message successfully." + ) + except Exception as exc: + logger.error( + f"Error handling message in {subscriber_cls.__name__}: {str(exc)}" + ) + # Print the stack trace + logger.error(traceback.format_exc()) + # subscriber_cls.handle_error(exc, message) + + await self.shutdown(exit_code=1) + return + async def handle_message( self, handler_cls: Type[Union[BaseCommandHandler, BaseEventHandler]], @@ -81,7 +131,7 @@ async def handle_message( Handle a message by invoking the appropriate handler class. Args: - handler_cls (Type[Union[BaseCommandHandler, BaseEventHandler]]): The handler class to invoke. + handler_cls (Type[Union[BaseCommandHandler, BaseEventHandler]]): The handler class message (Message): The message to be handled. Returns: @@ -131,8 +181,12 @@ async def shutdown(self, signal=None, exit_code=0): self.shutting_down = True # Set shutdown flag try: - if signal: - logger.info(f"Received exit signal {signal.name}...") + msg = ( + f"Received exit signal {signal.name}. Shutting down..." + if signal + else "Shutting down..." + ) + logger.info(msg) # Store the exit code self.exit_code = exit_code @@ -154,8 +208,7 @@ async def shutdown(self, signal=None, exit_code=0): await asyncio.gather(*subscription_shutdown_tasks, return_exceptions=True) logger.info("All subscriptions have been shut down.") finally: - if self.loop.is_running(): - self.loop.stop() + self.loop.stop() def run(self): """ @@ -184,19 +237,26 @@ def handle_exception(loop, context): self.loop.set_exception_handler(handle_exception) - if len(self._subscriptions) == 0: + if len(self._subscriptions) == 0 and len(self._broker_subscriptions) == 0: logger.info("No subscriptions to start. 
Exiting...") + return - # Start consumption, one per subscription - try: - tasks = [ - self.loop.create_task(subscription.start()) - for _, subscription in self._subscriptions.items() - ] + subscription_tasks = [ + self.loop.create_task(subscription.start()) + for _, subscription in self._subscriptions.items() + ] + broker_subscription_tasks = [ + self.loop.create_task(subscription.start()) + for _, subscription in self._broker_subscriptions.items() + ] + + try: if self.test_mode: # If in test mode, run until all tasks complete - self.loop.run_until_complete(asyncio.gather(*tasks)) + self.loop.run_until_complete( + asyncio.gather(*subscription_tasks, *broker_subscription_tasks) + ) # Then immediately call and await the shutdown directly self.loop.run_until_complete(self.shutdown()) else: diff --git a/tests/adapters/broker/redis_broker/test_processing_subscriber_messages.py b/tests/adapters/broker/redis_broker/test_processing_subscriber_messages.py new file mode 100644 index 00000000..6bc54869 --- /dev/null +++ b/tests/adapters/broker/redis_broker/test_processing_subscriber_messages.py @@ -0,0 +1,81 @@ +import asyncio + +import pytest + +from protean.core.subscriber import BaseSubscriber +from protean.server import Engine + +terms = [] + + +def append_to_terms(term): + global terms + terms.append(term) + + +class DummySubscriber(BaseSubscriber): + def __call__(self, data: dict): + append_to_terms(data["foo"]) + + +@pytest.fixture(autouse=True) +def clear_terms(): + yield + + global terms + terms = [] + + +@pytest.fixture(autouse=True) +def auto_set_and_close_loop(): + # Create and set a new loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + yield + + # Close the loop after the test + if not loop.is_closed(): + loop.close() + asyncio.set_event_loop(None) # Explicitly unset the loop + + +@pytest.mark.redis +@pytest.mark.asyncio +async def test_handler_invocation(test_domain): + test_domain.register(DummySubscriber, channel="test_channel") + 
test_domain.init(traverse=False) + + with test_domain.domain_context(): + channel = "test_channel" + message = {"foo": "bar"} + + test_domain.brokers["default"].publish(channel, message) + + engine = Engine(domain=test_domain, test_mode=True) + await engine.handle_broker_message(DummySubscriber, message) + + global terms + assert len(terms) == 1 + assert terms[0] == "bar" + + +@pytest.mark.redis +def test_processing_broker_messages(test_domain): + test_domain.register(DummySubscriber, channel="test_channel") + test_domain.init(traverse=False) + + with test_domain.domain_context(): + channel = "test_channel" + message1 = {"foo": "bar"} + message2 = {"foo": "baz"} + test_domain.brokers["default"].publish(channel, message1) + test_domain.brokers["default"].publish(channel, message2) + + engine = Engine(domain=test_domain, test_mode=True) + engine.run() + + global terms + assert len(terms) == 2 + assert terms[0] == "bar" + assert terms[1] == "baz" diff --git a/tests/cli/test_server.py b/tests/cli/test_server.py index c014142a..8a08a906 100644 --- a/tests/cli/test_server.py +++ b/tests/cli/test_server.py @@ -42,24 +42,13 @@ def test_server_start_successfully(self): # Assertions assert result.exit_code == 0 - def test_server_start_failure(self): - pass - def test_that_server_processes_messages_on_start(self): # Start in non-test mode # Ensure messages are processed # Manually shutdown with `asyncio.create_task(engine.shutdown())` pass + @pytest.mark.skip(reason="Not implemented") def test_debug_mode(self): # Test debug mode is saved and correct logger level is set pass - - def test_that_server_processes_messages_in_test_mode(self): - pass - - def test_that_server_handles_exceptions_elegantly(self): - pass - - def test_that_last_read_positions_are_saved(self): - pass diff --git a/tests/conftest.py b/tests/conftest.py index c056d634..f4b0702f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,6 +64,20 @@ def pytest_addoption(parser): help="Run tests against a 
Eventstore type", ) + # Options to run Broker tests + parser.addoption( + "--broker_common", + action="store_true", + default=False, + help="Broker test marker", + ) + parser.addoption( + "--broker", + action="store", + default="INLINE", + help="Run tests against a Broker type", + ) + def pytest_collection_modifyitems(config, items): """Configure special markers on tests, so as to control execution""" diff --git a/tests/server/test_engine_exits.py b/tests/server/test_engine_exits.py new file mode 100644 index 00000000..d95d3f03 --- /dev/null +++ b/tests/server/test_engine_exits.py @@ -0,0 +1,23 @@ +import logging + +import pytest + +from protean.domain import Domain +from protean.server.engine import Engine + + +@pytest.mark.no_test_domain +def test_engine_exits_if_no_subscriptions(caplog): + # Configure the logger to capture INFO level messages + logger = logging.getLogger("protean.server.engine") + logger.setLevel(logging.INFO) + + domain = Domain("dummy", load_toml=False) + engine = Engine(domain, test_mode=True) + engine.run() + + assert any( + record.levelname == "INFO" + and "No subscriptions to start. Exiting..." 
in record.message + for record in caplog.records + ) diff --git a/tests/server/test_gathering_broker_subscriptions.py b/tests/server/test_gathering_broker_subscriptions.py new file mode 100644 index 00000000..a6ca8a44 --- /dev/null +++ b/tests/server/test_gathering_broker_subscriptions.py @@ -0,0 +1,27 @@ +import pytest + +from protean.core.subscriber import BaseSubscriber +from protean.server import Engine +from protean.utils import fqn + + +class DummySubscriber(BaseSubscriber): + def __call__(self, data: dict): + pass + + +@pytest.fixture(autouse=True) +def register_elements(test_domain): + test_domain.register(DummySubscriber, channel="test_channel") + test_domain.init(traverse=False) + + +@pytest.fixture +def engine(test_domain): + return Engine(test_domain, test_mode=True) + + +def test_broker_subscriptions(engine): + assert len(engine._broker_subscriptions) == 1 + + assert fqn(DummySubscriber) in engine._broker_subscriptions diff --git a/tests/server/test_message_handling.py b/tests/server/test_message_handling.py new file mode 100644 index 00000000..bc4f17b9 --- /dev/null +++ b/tests/server/test_message_handling.py @@ -0,0 +1,60 @@ +import pytest + +from protean.core.subscriber import BaseSubscriber +from protean.server import Engine + +counter = 0 + + +def count_up(): + global counter + counter += 1 + + +class DummySubscriber(BaseSubscriber): + def __call__(self, data: dict): + count_up() + + +class ExceptionSubscriber(BaseSubscriber): + def __call__(self, data: dict): + raise Exception("This is a dummy exception") + + +@pytest.mark.asyncio +async def test_handler_invocation(test_domain): + test_domain.register(DummySubscriber, channel="test_channel") + test_domain.init(traverse=False) + + channel = "test_channel" + message = {"foo": "bar"} + + test_domain.brokers["default"].publish(channel, message) + + engine = Engine(domain=test_domain, test_mode=True) + await engine.handle_broker_message(DummySubscriber, message) + + global counter + assert counter == 
1 + + +@pytest.mark.asyncio +async def test_handling_exception_raised_in_handler(test_domain, caplog): + test_domain.register(ExceptionSubscriber, channel="test_channel") + test_domain.init(traverse=False) + + channel = "test_channel" + message = {"foo": "bar"} + + test_domain.brokers["default"].publish(channel, message) + + engine = Engine(domain=test_domain, test_mode=True) + + await engine.handle_broker_message(ExceptionSubscriber, message) + + assert any( + record.levelname == "ERROR" and "Error handling message in " in record.message + for record in caplog.records + ) + + assert engine.shutting_down is True diff --git a/tests/server/test_processing_broker_messages.py b/tests/server/test_processing_broker_messages.py new file mode 100644 index 00000000..b40c929c --- /dev/null +++ b/tests/server/test_processing_broker_messages.py @@ -0,0 +1,77 @@ +import asyncio + +import pytest + +from protean.core.subscriber import BaseSubscriber +from protean.server import Engine + +terms = [] + + +def append_to_terms(term): + global terms + terms.append(term) + + +class DummySubscriber(BaseSubscriber): + def __call__(self, data: dict): + append_to_terms(data["foo"]) + + +@pytest.fixture(autouse=True) +def auto_set_and_close_loop(): + # Create and set a new loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + yield + + # Close the loop after the test + if not loop.is_closed(): + loop.close() + asyncio.set_event_loop(None) # Explicitly unset the loop + + +@pytest.fixture(autouse=True) +def clear_terms(): + yield + + global terms + terms = [] + + +@pytest.mark.broker_common +def test_processing_broker_messages(test_domain): + test_domain.register(DummySubscriber, channel="test_channel") + test_domain.init(traverse=False) + + channel = "test_channel" + message1 = {"foo": "bar"} + message2 = {"foo": "baz"} + test_domain.brokers["default"].publish(channel, message1) + test_domain.brokers["default"].publish(channel, message2) + + engine = 
Engine(domain=test_domain, test_mode=True) + engine.run() + + global terms + assert len(terms) == 2 + assert terms[0] == "bar" + assert terms[1] == "baz" + + +@pytest.mark.broker_common +def test_no_processing_when_shutting_down(test_domain): + test_domain.register(DummySubscriber, channel="test_channel") + test_domain.init(traverse=False) + + channel = "test_channel" + message = {"foo": "bar"} + test_domain.brokers["default"].publish(channel, message) + + engine = Engine(domain=test_domain, test_mode=True) + engine.shutting_down = True + engine.run() + + global terms + assert len(terms) == 0