Skip to content

Commit

Permalink
0.2.6
Browse files Browse the repository at this point in the history
  • Loading branch information
loRes228 committed Mar 27, 2024
1 parent 2df6f63 commit affbf02
Show file tree
Hide file tree
Showing 15 changed files with 181 additions and 136 deletions.
30 changes: 21 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ from aiogram_broadcaster import Broadcaster, DefaultMailerProperties
from aiogram_broadcaster.contents import MessageSendContent
from aiogram_broadcaster.event import EventRouter
from aiogram_broadcaster.mailer import Mailer
from aiogram_broadcaster.storage.redis import RedisBCRStorage
from aiogram_broadcaster.storage.file import FileBCRStorage

TOKEN = "1234:Abc" # noqa: S105
USER_IDS = {78238238, 78378343, 98765431, 12345678}
OWNER_ID = 61043901

router = Router(name=__name__)
event = EventRouter()
Expand All @@ -33,21 +32,34 @@ event = EventRouter()
@router.message()
async def on_any_message(message: Message, broadcaster: Broadcaster) -> Any:
content = MessageSendContent(message=message)
mailer = await broadcaster.create_mailer(content=content, chats=USER_IDS)
mailer = await broadcaster.create_mailer(
content=content,
chats=USER_IDS,
data={"publisher_id": message.chat.id, "message_id": message.message_id},
)
mailer.start()
await message.answer(text="Run broadcasting...")


@event.completed()
async def notify_complete(mailer: Mailer, bot: Bot) -> None:
async def notify_complete(
mailer: Mailer,
bot: Bot,
publisher_id: int,
message_id: int,
) -> None:
text = (
f"Broadcasting has been completed!\n"
f"Mailer ID: {mailer.id} | Bot ID: {bot.id}\n"
f"Total chats: {mailer.statistic.total_count}\n"
f"Failed chats: {mailer.statistic.failed_count}\n"
f"Success chats: {mailer.statistic.success_count}\n"
f"Total chats: {mailer.statistic.total_chats.total}\n"
f"Failed chats: {mailer.statistic.failed_chats.total}\n"
f"Success chats: {mailer.statistic.success_chats.total}\n"
)
await bot.send_message(
chat_id=publisher_id,
text=text,
reply_to_message_id=message_id,
)
await bot.send_message(chat_id=OWNER_ID, text=text)


def main() -> None:
Expand All @@ -56,7 +68,7 @@ def main() -> None:
dispatcher = Dispatcher()
dispatcher.include_router(router)

bcr_storage = RedisBCRStorage.from_url("redis://localhost:6379")
bcr_storage = FileBCRStorage()
default = DefaultMailerProperties(destroy_on_complete=True)
broadcaster = Broadcaster(bot, storage=bcr_storage, default=default)
broadcaster.event.include(event)
Expand Down
2 changes: 1 addition & 1 deletion aiogram_broadcaster/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.5"
__version__ = "0.2.6"
51 changes: 11 additions & 40 deletions aiogram_broadcaster/broadcaster.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Iterable, Iterator, List, Literal, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, Literal, Optional, Set, Tuple, Union
from uuid import uuid4

from aiogram import Bot, Dispatcher
Expand All @@ -11,14 +11,15 @@
from .logger import logger
from .mailer import Mailer, MailerStatus
from .mailer.chat_engine import ChatEngine, ChatState
from .mailer.multiple import MultipleMailers
from .mailer.container import MailerContainer
from .mailer.group import MailerGroup
from .mailer.settings import MailerSettings
from .placeholder import PlaceholderWizard
from .storage.base import BaseBCRStorage
from .storage.record import StorageRecord


class Broadcaster:
class Broadcaster(MailerContainer):
_bots: Dict[int, Bot]
storage: Optional[BaseBCRStorage]
language_getter: BaseLanguageGetter
Expand All @@ -27,7 +28,6 @@ class Broadcaster:
kwargs: Dict[str, Any]
event: EventManager
placeholder: PlaceholderWizard
_mailers: Dict[int, Mailer]

def __init__(
self,
Expand All @@ -38,6 +38,8 @@ def __init__(
context_key: str = "broadcaster",
**kwargs: Any,
) -> None:
super().__init__()

self._bots = {bot.id: bot for bot in bots}
self.storage = storage
self.language_getter = language_getter or DefaultLanguageGetter()
Expand All @@ -48,45 +50,13 @@ def __init__(

self.event = EventManager(name="root")
self.placeholder = PlaceholderWizard(name="root")
self._mailers = {}

def __repr__(self) -> str:
return f"Broadcaster(total_mailers={len(self._mailers)})"

def __str__(self) -> str:
mailers = ", ".join(map(repr, self))
return f"Broadcaster[{mailers}]"

def __contains__(self, item: int) -> bool:
return item in self._mailers

def __getitem__(self, item: int) -> Mailer:
if mailer := self._mailers.get(item):
return mailer
raise LookupError(f"Mailer with id={item} not exists.")

def __iter__(self) -> Iterator[Mailer]:
return iter(self._mailers.values())

def __len__(self) -> int:
return len(self._mailers)

@property
def bots(self) -> Tuple[Bot, ...]:
return tuple(self._bots.values())

@property
def mailers(self) -> Dict[int, Mailer]:
return self._mailers

def get_mailers(self) -> List[Mailer]:
return list(self._mailers.values())

def get_mailer(self, mailer_id: int) -> Optional[Mailer]:
return self._mailers.get(mailer_id)

def as_multiple(self) -> MultipleMailers:
return MultipleMailers(mailers=self._mailers.values())
def as_group(self) -> MailerGroup:
return MailerGroup(*self._mailers.values())

async def create_mailers(
self,
Expand All @@ -103,7 +73,7 @@ async def create_mailers(
exclude_placeholders: Optional[Union[Literal[True], Set[str]]] = None,
data: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> MultipleMailers:
) -> MailerGroup:
if not bots and not self._bots:
raise ValueError("At least one bot must be specified.")
if not bots:
Expand All @@ -126,7 +96,7 @@ async def create_mailers(
)
for bot in bots
]
return MultipleMailers(mailers=mailers)
return MailerGroup(*mailers)

async def create_mailer(
self,
Expand Down Expand Up @@ -260,4 +230,5 @@ def setup(self, dispatcher: Dispatcher, *, include_data: bool = True) -> None:
self.kwargs.update(dispatcher.workflow_data)
if self.storage:
dispatcher.startup.register(self.restore_mailers)
dispatcher.shutdown.register(self.storage.close)
dispatcher.startup.register(self.run_mailers)
2 changes: 1 addition & 1 deletion aiogram_broadcaster/contents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pydantic.functional_validators import ModelWrapValidatorHandler


VALIDATOR_KEY = "__validator"
VALIDATOR_KEY = "__V"


class BaseContent(BaseModel, ABC):
Expand Down
26 changes: 17 additions & 9 deletions aiogram_broadcaster/mailer/chat_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ChatEngine(BaseModel):
mailer_id: Optional[int] = Field(default=None, exclude=True)
storage: Optional[BaseBCRStorage] = Field(default=None, exclude=True)

def model_post_init(self, __context: Optional[Dict[str, Any]]) -> None:
def model_post_init(self, __context: Dict[str, Any]) -> None:
if not __context:
return
self.mailer_id = __context.get("mailer_id")
Expand Down Expand Up @@ -58,21 +58,29 @@ async def add_chats(self, chats: Iterable[int], state: ChatState) -> Set[int]:
if not difference:
return difference
self.chats[state].update(difference)
if self.storage and self.mailer_id:
async with self.storage.update_record(mailer_id=self.mailer_id) as record:
record.chats = self
await self._preserve()
return difference

async def set_chats_state(self, state: ChatState) -> None:
chats = self.get_chats()
self.chats.clear()
self.chats[state] = chats
await self._preserve()

async def set_chat_state(self, chat: int, state: ChatState) -> None:
from_state = self.resolve_chat_state(chat=chat)
from_state = self._resolve_chat_state(chat=chat)
self.chats[from_state].discard(chat)
self.chats[state].add(chat)
if self.storage and self.mailer_id:
async with self.storage.update_record(mailer_id=self.mailer_id) as record:
record.chats = self
await self._preserve()

def resolve_chat_state(self, chat: int) -> ChatState:
def _resolve_chat_state(self, chat: int) -> ChatState:
for state, chats in self.chats.items():
if chat in chats:
return state
raise LookupError(f"Chat={chats} state is undefined.")

async def _preserve(self) -> None:
if not self.storage or not self.mailer_id:
return
async with self.storage.update_record(mailer_id=self.mailer_id) as record:
record.chats = self
39 changes: 39 additions & 0 deletions aiogram_broadcaster/mailer/container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Dict, Iterator, List, Optional

from .mailer import Mailer


class MailerContainer:
_mailers: Dict[int, Mailer]

def __init__(self, *mailers: Mailer) -> None:
self._mailers = {mailer.id: mailer for mailer in mailers}

def __repr__(self) -> str:
return f"{type(self).__name__}(total_mailers={len(self._mailers)})"

def __str__(self) -> str:
mailers = ", ".join(map(repr, self))
return f"{type(self).__name__}[{mailers}]"

def __contains__(self, item: int) -> bool:
return item in self._mailers

def __getitem__(self, item: int) -> Mailer:
return self._mailers[item]

def __iter__(self) -> Iterator[Mailer]:
return iter(self._mailers.copy().values())

def __len__(self) -> int:
return len(self._mailers)

@property
def mailers(self) -> Dict[int, Mailer]:
return self._mailers.copy()

def get_mailer(self, mailer_id: int) -> Optional[Mailer]:
return self._mailers.get(mailer_id)

def get_mailers(self) -> List[Mailer]:
return list(self._mailers.values())
Original file line number Diff line number Diff line change
@@ -1,63 +1,50 @@
from asyncio import gather, wait
from typing import Any, Coroutine, Dict, Iterable, Iterator, List, Tuple
from typing import Any, Coroutine, Dict, Iterable, List

from aiogram_broadcaster.logger import logger

from .container import MailerContainer
from .mailer import Mailer


class MultipleMailers:
mailers: Tuple[Mailer, ...]

def __init__(self, mailers: Iterable[Mailer]) -> None:
self.mailers = tuple(mailers)

def __iter__(self) -> Iterator[Mailer]:
return iter(self.mailers)

def __len__(self) -> int:
return len(self.mailers)

def __repr__(self) -> str:
return f"MultipleMailers(total_mailers={len(self.mailers)})"

def __str__(self) -> str:
mailers = ", ".join(map(repr, self.mailers))
return f"MultipleMailers[{mailers}]"

class MailerGroup(MailerContainer):
def start(self, **kwargs: Any) -> None:
for mailer in self.mailers:
for mailer in self._mailers.values():
try:
mailer.start(**kwargs)
except RuntimeError: # noqa: PERF203
logger.exception("A start error occurred")

async def wait(self) -> None:
futures = [mailer.wait() for mailer in self.mailers]
futures = [mailer.wait() for mailer in self._mailers.values()]
await wait(futures)

async def run(self, **kwargs: Any) -> Dict[Mailer, Any]:
futures = [mailer.run(**kwargs) for mailer in self.mailers]
futures = [mailer.run(**kwargs) for mailer in self._mailers.values()]
return await self._gather_futures(futures=futures)

async def stop(self) -> Dict[Mailer, Any]:
futures = [mailer.stop() for mailer in self.mailers]
futures = [mailer.stop() for mailer in self._mailers.values()]
return await self._gather_futures(futures=futures)

async def destroy(self) -> Dict[Mailer, Any]:
futures = [mailer.destroy() for mailer in self.mailers]
futures = [mailer.destroy() for mailer in self._mailers.values()]
return await self._gather_futures(futures=futures)

async def add_chats(self, chats: Iterable[int]) -> Dict[Mailer, bool]:
futures = [mailer.add_chats(chats=chats) for mailer in self.mailers]
futures = [mailer.add_chats(chats=chats) for mailer in self._mailers.values()]
return await self._gather_futures(futures=futures)

async def send_content(self, chat_id: int) -> Dict[Mailer, Any]:
futures = [mailer.send_content(chat_id=chat_id) for mailer in self.mailers]
async def reset_chats(self) -> None:
futures = [mailer.reset_chats() for mailer in self._mailers.values()]
await self._gather_futures(futures=futures)

async def send(self, chat_id: int) -> Dict[Mailer, Any]:
futures = [mailer.send(chat_id=chat_id) for mailer in self._mailers.values()]
return await self._gather_futures(futures=futures)

async def _gather_futures(self, futures: List[Coroutine[Any, Any, Any]]) -> Dict[Mailer, Any]:
if not futures:
return {}
results = await gather(*futures, return_exceptions=True)
return dict(zip(self.mailers, results))
return dict(zip(self._mailers.values(), results))
Loading

0 comments on commit affbf02

Please sign in to comment.