Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run subscribers within Engine #451

Merged
merged 2 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ markers = [
"message_db",
"sendgrid",
"database",
"broker_common",
"eventstore",
"no_test_domain",
]
Expand Down
9 changes: 9 additions & 0 deletions src/protean/adapters/broker/inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ def _get_next(self, channel: str) -> dict | None:
return self._messages[channel].pop(0)
return None

def read(self, channel: str, no_of_messages: int) -> list[dict]:
"""Read messages from the broker"""
messages = []
while no_of_messages > 0 and self._messages[channel]:
messages.append(self._messages[channel].pop(0))
no_of_messages -= 1

return messages

def _data_reset(self) -> None:
"""Flush all data in broker instance"""
self._messages.clear()
9 changes: 9 additions & 0 deletions src/protean/adapters/broker/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,14 @@ def _get_next(self, channel: str) -> dict | None:

return None

def read(self, channel: str, no_of_messages: int) -> list[dict]:
messages = []
for _ in range(no_of_messages):
bytes_message = self.redis_instance.lpop(channel)
if bytes_message:
messages.append(json.loads(bytes_message))

return messages

def _data_reset(self) -> None:
self.redis_instance.flushall()
12 changes: 12 additions & 0 deletions src/protean/port/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ def get_next(self, channel: str) -> dict | None:
def _get_next(self, channel: str) -> dict | None:
"""Overridden method to retrieve the next message to process from broker."""

@abstractmethod
def read(self, channel: str, no_of_messages: int) -> list[dict]:
"""Read messages from the broker.

Args:
channel (str): The channel from which to read messages
no_of_messages (int): The number of messages to read

Returns:
list[dict]: The list of messages
"""

@abstractmethod
def _data_reset(self) -> None:
"""Flush all data in broker instance.
Expand Down
154 changes: 154 additions & 0 deletions src/protean/server/broker_subscription.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import asyncio
import logging
from typing import Type

from protean.core.subscriber import BaseSubscriber
from protean.port.broker import BaseBroker

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s,%(msecs)d %(levelname)s: %(message)s",
datefmt="%H:%M:%S",
)

logger = logging.getLogger(__name__)


class BrokerSubscription:
"""
Represents a subscription to a broker channel.

A broker subscription allows a subscriber to receive and process messages from a specific channel.
It provides methods to start and stop the subscription, as well as process messages in batches.
"""

def __init__(
self,
engine,
broker,
subscriber_id: str,
channel: str,
handler: Type[BaseSubscriber],
messages_per_tick: int = 10,
tick_interval: int = 1,
) -> None:
"""
Initialize the Subscription object.

Args:
engine: The Protean engine instance.
subscriber_id (str): The unique identifier for the subscriber.
channel (str): The name of the stream to subscribe to.
handler (Union[BaseEventHandler, BaseCommandHandler]): The event or command handler.
messages_per_tick (int, optional): The number of messages to process per tick. Defaults to 10.
tick_interval (int, optional): The interval between ticks. Defaults to 1.
"""
self.engine = engine
self.broker: BaseBroker = broker
self.loop = engine.loop

self.subscriber_id = subscriber_id
self.channel = channel
self.handler = handler
self.messages_per_tick = messages_per_tick
self.tick_interval = tick_interval

self.keep_going = True # Initially set to keep going

async def start(self) -> None:
"""
Start the subscription.

This method initializes the subscription by loading the last position from the event store
and starting the polling loop.

Returns:
None
"""
logger.debug(f"Starting {self.subscriber_id}")

# Start the polling loop
self.loop.create_task(self.poll())

async def poll(self) -> None:
"""
Polling loop for processing messages.

This method continuously polls for new messages and processes them by calling the `tick` method.
It sleeps for the specified `tick_interval` between each tick.

Returns:
None
"""
await self.tick()

if self.keep_going and not self.engine.shutting_down:
# Keep control of the loop if in test mode
# Otherwise `asyncio.sleep` will give away control and
# the loop will be able to be stopped with `shutdown()`
if not self.engine.test_mode:
await asyncio.sleep(self.tick_interval)

Check warning on line 90 in src/protean/server/broker_subscription.py

View check run for this annotation

Codecov / codecov/patch

src/protean/server/broker_subscription.py#L90

Added line #L90 was not covered by tests

self.loop.create_task(self.poll())

async def tick(self):
"""
This method retrieves the next batch of messages to process and calls the `process_batch` method
to handle each message. It also updates the read position after processing each message.

Returns:
None
"""
messages = await self.get_next_batch_of_messages()
if messages:
await self.process_batch(messages)

async def shutdown(self):
"""
Shutdown the subscription.

This method signals the subscription to stop polling and updates the current position to the store.
It also logs a message indicating the shutdown of the subscription.

Returns:
None
"""
self.keep_going = False # Signal to stop polling
logger.info(f"Shutting down subscription {self.subscriber_id}")

Check warning on line 117 in src/protean/server/broker_subscription.py

View check run for this annotation

Codecov / codecov/patch

src/protean/server/broker_subscription.py#L116-L117

Added lines #L116 - L117 were not covered by tests

async def get_next_batch_of_messages(self):
"""
Get the next batch of messages to process.

This method reads messages from the event store starting from the current position + 1.
It retrieves a specified number of messages per tick and applies filtering based on the origin stream name.

Returns:
List[Message]: The next batch of messages to process.
"""
messages = self.broker.read(
self.channel,
no_of_messages=self.messages_per_tick,
) # FIXME Implement filtering

return messages

async def process_batch(self, messages: list[dict]):
"""
Process a batch of messages.

This method takes a batch of messages and processes each message by calling the `handle_message` method
of the engine. It also updates the read position after processing each message. If an exception occurs
during message processing, it logs the error using the `log_error` method.

Args:
messages (List[Message]): The batch of messages to process.

Returns:
int: The number of messages processed.
"""
logging.debug(f"Processing {len(messages)} messages...")
for message in messages:
await self.engine.handle_broker_message(self.handler, message)

return len(messages)
88 changes: 74 additions & 14 deletions src/protean/server/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

from protean.core.command_handler import BaseCommandHandler
from protean.core.event_handler import BaseEventHandler
from protean.core.subscriber import BaseSubscriber
from protean.utils.globals import g
from protean.utils.mixins import Message

from .broker_subscription import BrokerSubscription
from .subscription import Subscription

logging.basicConfig(
Expand Down Expand Up @@ -50,7 +52,7 @@ def __init__(self, domain, test_mode: bool = False, debug: bool = False) -> None

self.loop = asyncio.get_event_loop()

# FIXME Gather all handlers
# Gather all handlers
self._subscriptions = {}
for handler_name, record in self.domain.registry.event_handlers.items():
# Create a subscription for each event handler
Expand All @@ -72,6 +74,54 @@ def __init__(self, domain, test_mode: bool = False, debug: bool = False) -> None
record.cls,
)

# Gather broker subscriptions
self._broker_subscriptions = {}

for (
subscriber_name,
subscriber_record,
) in self.domain.registry.subscribers.items():
subscriber_cls = subscriber_record.cls
broker_name = subscriber_cls.meta_.broker
broker = self.domain.brokers[broker_name]
channel = subscriber_cls.meta_.channel
self._broker_subscriptions[subscriber_name] = BrokerSubscription(
self,
broker,
subscriber_name,
channel,
subscriber_cls,
)

async def handle_broker_message(
self, subscriber_cls: Type[BaseSubscriber], message: dict
) -> None:
"""
Handle a message received from the broker.
"""

if self.shutting_down:
return # Skip handling if shutdown is in progress

with self.domain.domain_context():
try:
subscriber = subscriber_cls()
subscriber(message)

logger.info(
f"{subscriber_cls.__name__} processed message successfully."
)
except Exception as exc:
logger.error(
f"Error handling message in {subscriber_cls.__name__}: {str(exc)}"
)
# Print the stack trace
logger.error(traceback.format_exc())
# subscriber_cls.handle_error(exc, message)

await self.shutdown(exit_code=1)
return

async def handle_message(
self,
handler_cls: Type[Union[BaseCommandHandler, BaseEventHandler]],
Expand All @@ -81,7 +131,7 @@ async def handle_message(
Handle a message by invoking the appropriate handler class.

Args:
handler_cls (Type[Union[BaseCommandHandler, BaseEventHandler]]): The handler class to invoke.
handler_cls (Type[Union[BaseCommandHandler, BaseEventHandler]]): The handler class
message (Message): The message to be handled.

Returns:
Expand Down Expand Up @@ -131,8 +181,12 @@ async def shutdown(self, signal=None, exit_code=0):
self.shutting_down = True # Set shutdown flag

try:
if signal:
logger.info(f"Received exit signal {signal.name}...")
msg = (
"Received exit signal {signal.name}. Shutting down..."
if signal
else "Shutting down..."
)
logger.info(msg)

# Store the exit code
self.exit_code = exit_code
Expand All @@ -154,8 +208,7 @@ async def shutdown(self, signal=None, exit_code=0):
await asyncio.gather(*subscription_shutdown_tasks, return_exceptions=True)
logger.info("All subscriptions have been shut down.")
finally:
if self.loop.is_running():
self.loop.stop()
self.loop.stop()

def run(self):
"""
Expand Down Expand Up @@ -184,19 +237,26 @@ def handle_exception(loop, context):

self.loop.set_exception_handler(handle_exception)

if len(self._subscriptions) == 0:
if len(self._subscriptions) == 0 and len(self._broker_subscriptions) == 0:
logger.info("No subscriptions to start. Exiting...")
return

# Start consumption, one per subscription
try:
tasks = [
self.loop.create_task(subscription.start())
for _, subscription in self._subscriptions.items()
]
subscription_tasks = [
self.loop.create_task(subscription.start())
for _, subscription in self._subscriptions.items()
]

broker_subscription_tasks = [
self.loop.create_task(subscription.start())
for _, subscription in self._broker_subscriptions.items()
]

try:
if self.test_mode:
# If in test mode, run until all tasks complete
self.loop.run_until_complete(asyncio.gather(*tasks))
self.loop.run_until_complete(
asyncio.gather(*subscription_tasks, *broker_subscription_tasks)
)
# Then immediately call and await the shutdown directly
self.loop.run_until_complete(self.shutdown())
else:
Expand Down
Loading