From 3c3842e634180197c48bc458bf4fa769c7c59bda Mon Sep 17 00:00:00 2001 From: taimast Date: Thu, 14 Dec 2023 09:13:47 +0300 Subject: [PATCH] add diffussial --- init_project.py | 3 +- modules/diffusion/__init__.py | 0 modules/diffusion/diffusion.py | 32 ++ .../diffusion/image_scale/tesseract_rec.py | 13 + modules/diffusion/manager.py | 72 ++++ modules/diffusion/math_api/mathpix.py | 79 ++++ modules/diffusion/midjourney/__init__.py | 0 modules/diffusion/midjourney/cache.py | 108 ++++++ modules/diffusion/midjourney/midjourney.py | 344 ++++++++++++++++++ modules/diffusion/midjourney/model.py | 8 + modules/diffusion/midjourney/queue_manager.py | 237 ++++++++++++ modules/diffusion/midjourney/request.py | 37 ++ modules/diffusion/midjourney/response.py | 92 +++++ modules/diffusion/openai_gen/__init__.py | 0 .../diffusion/openai_gen/chatgpt/__init__.py | 0 .../openai_gen/chatgpt/dialog/dialog.py | 83 +++++ .../openai_gen/chatgpt/dialog/message.py | 42 +++ .../openai_gen/chatgpt/group_manager.py | 52 +++ .../diffusion/openai_gen/chatgpt/manager.py | 58 +++ modules/diffusion/openai_gen/chatgpt/model.py | 73 ++++ .../diffusion/openai_gen/chatgpt/sender.py | 54 +++ modules/diffusion/openai_gen/models.py | 74 ++++ modules/diffusion/replicate_gen/__init__.py | 0 modules/diffusion/replicate_gen/base.py | 171 +++++++++ modules/diffusion/replicate_gen/models.py | 46 +++ todo.txt | 3 +- 26 files changed, 1679 insertions(+), 2 deletions(-) create mode 100644 modules/diffusion/__init__.py create mode 100644 modules/diffusion/diffusion.py create mode 100644 modules/diffusion/image_scale/tesseract_rec.py create mode 100644 modules/diffusion/manager.py create mode 100644 modules/diffusion/math_api/mathpix.py create mode 100644 modules/diffusion/midjourney/__init__.py create mode 100644 modules/diffusion/midjourney/cache.py create mode 100644 modules/diffusion/midjourney/midjourney.py create mode 100644 modules/diffusion/midjourney/model.py create mode 100644 
modules/diffusion/midjourney/queue_manager.py create mode 100644 modules/diffusion/midjourney/request.py create mode 100644 modules/diffusion/midjourney/response.py create mode 100644 modules/diffusion/openai_gen/__init__.py create mode 100644 modules/diffusion/openai_gen/chatgpt/__init__.py create mode 100644 modules/diffusion/openai_gen/chatgpt/dialog/dialog.py create mode 100644 modules/diffusion/openai_gen/chatgpt/dialog/message.py create mode 100644 modules/diffusion/openai_gen/chatgpt/group_manager.py create mode 100644 modules/diffusion/openai_gen/chatgpt/manager.py create mode 100644 modules/diffusion/openai_gen/chatgpt/model.py create mode 100644 modules/diffusion/openai_gen/chatgpt/sender.py create mode 100644 modules/diffusion/openai_gen/models.py create mode 100644 modules/diffusion/replicate_gen/__init__.py create mode 100644 modules/diffusion/replicate_gen/base.py create mode 100644 modules/diffusion/replicate_gen/models.py diff --git a/init_project.py b/init_project.py index 87ee7ef..a2ab3d6 100644 --- a/init_project.py +++ b/init_project.py @@ -109,7 +109,8 @@ def install_dependencies(project_path: Path): "pydantic_settings", "pydantic>=2.1.1", "bs4", - "lxml" + "lxml", + "alembic" ] utils = [ "watchdog", diff --git a/modules/diffusion/__init__.py b/modules/diffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/diffusion/diffusion.py b/modules/diffusion/diffusion.py new file mode 100644 index 0000000..753d4b3 --- /dev/null +++ b/modules/diffusion/diffusion.py @@ -0,0 +1,32 @@ +import typing +from dataclasses import dataclass, field +from enum import StrEnum +from openai import AsyncClient + +class DiffusionType(StrEnum): + TEXT_TO_IMAGE = "text2image" + TEXT_TO_TEXT = "text2text" + TEXT_TO_VIDEO = "text2video" + TEXT_TO_AUDIO = "text2audio" + AUDIO_TO_TEXT = "audio2text" + IMAGE_TO_IMAGE = "image2image" + IMAGE_TO_TEXT = "image2text" + + def get_text(self, l10n) -> str: + return 
l10n.get(f"diffusion-type-button-{self}") + + @classmethod + def to_text_types(cls) -> list['DiffusionType']: + return [cls.TEXT_TO_TEXT, cls.AUDIO_TO_TEXT, cls.IMAGE_TO_TEXT] + + @classmethod + def to_image_types(cls) -> list['DiffusionType']: + return [cls.TEXT_TO_IMAGE, cls.IMAGE_TO_IMAGE] + + +@dataclass +class DiffusionModel: + model: str + name: str | None = None + type: DiffusionType = DiffusionType.TEXT_TO_IMAGE + default_inputs: dict[str, typing.Any] = field(default_factory=dict) diff --git a/modules/diffusion/image_scale/tesseract_rec.py b/modules/diffusion/image_scale/tesseract_rec.py new file mode 100644 index 0000000..c157145 --- /dev/null +++ b/modules/diffusion/image_scale/tesseract_rec.py @@ -0,0 +1,13 @@ +import pytesseract +from PIL import Image + +# If you don't have tesseract executable in your PATH, include the following: +pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe' +# Example tesseract_cmd = r'C:\Program Files (x86)\Tesseract-OCR\tesseract' +file = open("screen.png", "rb") +# Simple image to string +image = Image.open(file) +print(pytesseract.image_to_string( + image, + # lang="rus" +)) diff --git a/modules/diffusion/manager.py b/modules/diffusion/manager.py new file mode 100644 index 0000000..bf5201e --- /dev/null +++ b/modules/diffusion/manager.py @@ -0,0 +1,72 @@ +from functools import cache +from typing import Any + +from diffusion_bot.apps.diffusion.diffusion import DiffusionModel, DiffusionType +from diffusion_bot.apps.diffusion.math_api.mathpix import MathPix +from diffusion_bot.apps.diffusion.midjourney.midjourney import MidjourneyWorker +from diffusion_bot.apps.diffusion.midjourney.model import MidjourneyDiffusionModel +from diffusion_bot.apps.diffusion.openai_gen.chatgpt.group_manager import GPTGroupManager +from diffusion_bot.apps.diffusion.openai_gen.chatgpt.model import GPTModel +from diffusion_bot.apps.diffusion.openai_gen.models import Whisper, DallE, OpenAIDiffusionModel +from 
class DiffusionManager:
    """Registry of every configured model plus a dispatcher to its backend.

    The four model lists are concatenated into ``self.models`` in the order
    midjourney, openai, replicate, photo-scaler; lookups scan that list.
    """

    def __init__(
            self,
            gpt_group_manager: GPTGroupManager,
            replicate: Replicate,
            midjourney_worker: MidjourneyWorker,
            mathpix: MathPix,

            openai_models: list[GPTModel | DallE | Whisper],
            replicate_models: list[ReplicateDiffusionModel],
            photo_scaler_models: list[ReplicateDiffusionModel],
            midjourney_models: list[MidjourneyDiffusionModel],
    ):
        self.replicate = replicate
        self.midjourney_worker = midjourney_worker
        self.gpt_group_manager = gpt_group_manager
        self.mathpix = mathpix

        self.openai_models = openai_models
        self.replicate_models = replicate_models
        self.photo_scaler_models = photo_scaler_models
        self.midjourney_models = midjourney_models

        self.models = (
                self.midjourney_models
                + self.openai_models
                + self.replicate_models
                + self.photo_scaler_models
        )

    # NOTE: these lookups are deliberately not wrapped in ``functools.cache``.
    # On a method the cache key includes ``self``, which keeps the instance
    # alive for the life of the process, and a cached list would be shared
    # (and mutable) between all callers. The list is short; a scan is enough.
    def get_model(self, model_name: str) -> DiffusionModel | None:
        """Return the first model whose ``name`` equals *model_name*, else ``None``."""
        return next((model for model in self.models if model.name == model_name), None)

    def get_models(self, type: DiffusionType | None = None) -> list[DiffusionModel]:
        """Return the models of the given *type*, or all models when it is ``None``.

        A new list is returned on every call, so callers may mutate the result
        without touching the registry.
        """
        if type is None:
            return list(self.models)
        return [model for model in self.models if model.type == type]

    async def predict(self, model_name: str, prompt: Any = None, **kwargs) -> Any:
        """Run *model_name* on its backend and return that backend's result.

        Replicate models receive ``prompt``; GPT models receive ``**kwargs``
        (the prompt is not forwarded to them); every other model is expected
        to expose its own ``predict(prompt)`` coroutine.

        Raises:
            ValueError: no model is registered under *model_name*.
        """
        model = self.get_model(model_name)
        if model is None:
            # Previously this surfaced as an AttributeError on ``None`` below.
            raise ValueError(f"Unknown diffusion model: {model_name!r}")

        if isinstance(model, ReplicateDiffusionModel):
            return await self.replicate.predict(model, prompt)
        if isinstance(model, GPTModel):
            return await self.gpt_group_manager.stream_completion(model.model, **kwargs)
        # Remaining case in the original code was typed as OpenAIDiffusionModel.
        return await model.predict(prompt)
@dataclass
class MathPix:
    """Small async client for the Mathpix ``v3/text`` endpoint.

    Use it as an async context manager, or call :meth:`init_session` yourself;
    :meth:`predict` also creates the session on first use.
    """

    url: ClassVar[str] = "https://api.mathpix.com/v3/text"
    app_id: str
    app_key: str
    # aiohttp.ClientSession carrying the credential headers, or None until created.
    session: Any = None

    @staticmethod
    def _to_data_uri(photo: bytes) -> str:
        """Encode raw image bytes as the base64 ``data:`` URI sent in ``src``."""
        encoded = base64.b64encode(photo).decode('utf-8')
        return f"data:image/png;base64,{encoded}"

    def init_session(self):
        """Create the HTTP session with the ``app_id``/``app_key`` headers."""
        self.session = aiohttp.ClientSession(headers={
            "app_id": self.app_id,
            "app_key": self.app_key
        })

    async def __aenter__(self):
        # Do not overwrite (and thereby leak) a session that already exists.
        if self.session is None:
            self.init_session()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Tolerate exit without a session, and reset so a later call can
        # create a fresh one instead of reusing a closed session.
        if self.session is not None:
            await self.session.close()
            self.session = None

    async def predict(self, photo: bytes) -> str:
        """POST *photo* to Mathpix and return the ``text`` field of the JSON reply.

        The value is ``None`` when the reply has no ``text`` key; the HTTP
        status is not inspected.
        """
        if self.session is None:
            self.init_session()
        payload = {
            "src": self._to_data_uri(photo),
            "formats": ["text"],
            "data_options": {
                "include_asciimath": True,
            }
        }
        async with self.session.post(self.url, json=payload) as res:
            result = await res.json()
            return result.get("text")


async def main():
    """Manual smoke test: send ``img.png`` to Mathpix and pretty-print the reply."""
    import os  # local import: only this debugging entry point needs it

    # SECURITY: an app id/key pair used to be hard-coded here. Anything that
    # was committed must be treated as leaked and revoked; credentials are now
    # taken from the environment only.
    app_id = os.environ.get("MATHPIX_APP_ID")
    app_key = os.environ.get("MATHPIX_APP_KEY")
    if not app_id or not app_key:
        raise SystemExit("Set MATHPIX_APP_ID and MATHPIX_APP_KEY to run this script")

    with open("img.png", "rb") as f:
        photo = f.read()

    payload = {
        "src": MathPix._to_data_uri(photo),
        "formats": ["text", "data", "html"],
        "data_options": {
            "include_asciimath": True,
        }
    }
    logger.info("Start request")
    async with MathPix(app_id, app_key) as client:
        async with client.session.post(MathPix.url, json=payload) as res:
            result = await res.json()
            pprint(result)
# In-process TTL caches consulted before Redis.
MJ_TRIGGERS_PROCESSING_CACHE: TTLCache[
    TriggerID, ProcessingTrigger
] = CACHE_MANAGER.midjourney_triggers_processing_cache
MJ_TRIGGERS_ID_CACHE: TTLCache[MsgID, Trigger] = CACHE_MANAGER.midjourney_triggers_id_cache


class PrefixedRedisCache(ctu.PrefixedRedisCache):
    """Prefixed Redis cache whose keys and values are encoded with orjson."""

    def _serialize(self, s):
        # Anything orjson cannot encode is written as JSON null.
        return orjson.dumps(s, default=lambda o: None)

    def _deserialize(self, s):
        return orjson.loads(s)

    def _key(self, key):
        return orjson.dumps(key)


Cache = PrefixedRedisCache


@dataclass
class CacheAnimation:
    """A remote animation plus the Telegram ``file_id`` obtained for it."""

    url: str
    file_id: str | None = None
    # Chat the animation is sent to once in order to obtain its file_id.
    # Default is the id that used to be hard-coded in get_file_id.
    chat_id: int = 269019356

    async def get_file_id(self, bot: Bot) -> str:
        """Return the stored ``file_id``, uploading the animation first if needed.

        Raises:
            RuntimeError: the animation could not be downloaded.
        """
        if self.file_id:
            return self.file_id

        payload = await download_file(self.url)
        if payload is None:
            # Previously None was handed to BufferedInputFile and failed there.
            raise RuntimeError(f"Could not download animation from {self.url}")

        sent = await bot.send_animation(
            chat_id=self.chat_id,
            animation=BufferedInputFile(payload, "animation.gif")
        )
        self.file_id = sent.animation.file_id
        return self.file_id


class MJCache:
    """Two-level store for Midjourney triggers: module TTL caches, then Redis."""

    def __init__(self, redis: Redis, ttl: int = 60 * 60 * 24 * 3):
        self.redis = redis
        self.processing_triggers: Cache[TriggerID, ProcessingTrigger] = Cache(
            redis,
            prefix="processing.triggers.",
            ttl=ttl
        )
        self.done_triggers: Cache[MsgID, Trigger] = Cache(
            redis,
            prefix="done.triggers.",
            ttl=ttl
        )
        self.imagine_animation = CacheAnimation(
            "https://cdn.dribbble.com/users/1514097/screenshots/3457456/media/b1cfddae9e7b9645b9cde7ad9ee4f6bf.gif"
        )
        self.variation_animation = CacheAnimation(
            "https://usagif.com/wp-content/uploads/loading-9.gif"
        )

    def get_processing_trigger(self, trigger_id: TriggerID) -> ProcessingTrigger | None:
        """Return the in-flight trigger for *trigger_id*, or ``None`` if unknown."""
        cached = MJ_TRIGGERS_PROCESSING_CACHE.get(trigger_id)
        if cached:
            return cached
        try:
            data = self.processing_triggers[trigger_id]
        except KeyError:
            return None
        # Rebuilt from the dict written by set_processing_trigger.
        return ProcessingTrigger(**data)

    def get_done_trigger(self, msg_id: MsgID) -> Trigger | None:
        """Return the finished trigger stored under *msg_id*, or ``None`` if unknown."""
        cached = MJ_TRIGGERS_ID_CACHE.get(msg_id)
        if cached:
            return cached
        try:
            data = self.done_triggers[msg_id]
        except KeyError:
            return None
        return Trigger(**data)

    def set_processing_trigger(self, processing_trigger: ProcessingTrigger):
        """Store an in-flight trigger in both the TTL cache and Redis."""
        MJ_TRIGGERS_PROCESSING_CACHE[processing_trigger.trigger_id] = processing_trigger
        self.processing_triggers[processing_trigger.trigger_id] = processing_trigger.__dict__

    def set_done_trigger(self, trigger: Trigger):
        """Store a finished trigger in both the TTL cache and Redis."""
        MJ_TRIGGERS_ID_CACHE[trigger.msg_id] = trigger
        self.done_triggers[trigger.msg_id] = trigger.__dict__


async def download_file(url: str) -> bytes | None:
    """Fetch *url* and return the body; ``None`` for any status other than 200."""
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            if resp.status == 200:
                return await resp.read()
    return None
# The mixin's methods use attributes (bot, cache, session, ...) that are
# declared on MidjourneyWorker, hence the union used to annotate ``self``.
MJClass = Union['MidjourneyServerMixin', 'MidjourneyWorker']

PERCENT_RE = re.compile(r"\((\d+%)")
PROMPT_RE = re.compile(r"#>(.*)\*\*")

# Signature of the worker's action coroutines.
MJActionMethodType = Callable[['MJRequest'], Awaitable['TriggerID']]


class MidjourneyServerMixin:
    """HTTP side of the Midjourney integration.

    Serves two POST routes: ``cb_path`` for progress/result callbacks and
    ``ban_path`` for banned-prompt notifications, and turns them into
    Telegram message updates.
    """

    host: str = "0.0.0.0"
    port: int = 8065
    cb_path: str = "/callback"
    ban_path = "/banned"

    # Strong references to background tasks. The event loop itself only keeps
    # weak ones, so an unreferenced task may be collected before it finishes.
    _background_tasks = set()

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule *coro* as a task and hold a reference until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def start(self: MJClass):
        """Start the aiohttp site on ``host``:``port`` with both routes."""
        app = web.Application()
        app.add_routes([web.post(self.cb_path, self.handle_callback)])
        app.add_routes([web.post(self.ban_path, self.handle_banned)])
        runner = web.AppRunner(app)
        await runner.setup()
        site = web.TCPSite(runner, host=self.host, port=self.port)
        await site.start()

    async def handle_callback(self, request: web.Request):
        """Acknowledge a callback immediately and process it in the background."""
        data: CallbackData = await request.json()
        logger.info(pformat(data))
        trigger_id = data["trigger_id"]
        content = data["content"]
        attachments = data.get("attachments")
        if attachments:
            self._spawn(self.update_message_status(
                data,
                trigger_id,
                content,
                attachments[0],
                data.get("type") == "end"
            ))
        elif data.get("embeds"):
            self._spawn(self.send_embed(data["embeds"][0], trigger_id))
        return web.Response(text="ok")

    @classmethod
    async def handle_banned(cls, request: web.Request):
        """Release every queued user whose prompt is in the reported word list."""
        data: BannedWordsResult = await request.json()
        logger.info(f"Banned Trigger {pformat(data)}")
        await MJQueueManager.release_by_banned_words(data["words"])
        return web.Response(text="ok")

    async def download_file(self: MJClass, url: str):
        """Return the body of *url*, or ``None`` for any status other than 200."""
        async with self.session.get(url) as resp:
            if resp.status == 200:
                return await resp.read()
        return None

    async def send_embed(self: MJClass, embed: Embed, trigger_id: TriggerID):
        """Replace the loading message with the embed's image and description."""
        processing_trigger = self.cache.get_processing_trigger(trigger_id)
        if processing_trigger is None:
            # The original dereferenced the trigger before checking it.
            logger.warning(f"Trigger {trigger_id} not found")
            return

        file_bytes = await self.download_file(embed["image"]["proxy_url"])
        if file_bytes is None:
            logger.warning(f"Could not download embed image for trigger {trigger_id}")
            return

        media = InputMediaPhoto(
            media=BufferedInputFile(file_bytes, filename="image.jpg"),
            caption=f"{embed['description']}",
            parse_mode="Markdown",
        )
        await self.bot.edit_message_media(
            media=media,
            message_id=processing_trigger.message_id,
            chat_id=processing_trigger.user_id,
        )

        processing_trigger.release()
        await self.resize_tokens(processing_trigger)

    async def resize_tokens(self: MJClass, pt: ProcessingTrigger):
        """Deduct the action's cost from the user's tokens and tell the user."""
        cost = pt.action.cost()
        async with self.db_sessionmaker() as session:
            user = await session.get(User, pt.user_id)
            if user is None:
                logger.warning(f"User {pt.user_id} not found, tokens not charged")
                return
            # Read what is needed before commit; depending on the session
            # configuration, attributes may be expired afterwards.
            user_id = user.id
            language_code = user.language_code
            user.tokens -= cost
            await session.commit()

        l10n: TranslatorRunner = self.translator_hub.get_translator_by_locale(language_code)
        await self.bot.send_message(
            user_id,
            l10n.dialog.used_tokens(spent_tokens=cost)
        )

    @classmethod
    def parse_status(cls, content: str):
        """Extract ``percentage`` and ``prompt`` from a callback's content text.

        Only the keys whose pattern matched are present in the returned dict.
        """
        patterns = (("percentage", PERCENT_RE), ("prompt", PROMPT_RE))
        found = {}
        for key, pattern in patterns:
            match = pattern.search(content)
            if match:
                found[key] = match.group(1)
        return found

    def get_l10n(self: MJClass, user: User) -> TranslatorRunner:
        """Return the translator for *user*'s language."""
        return self.translator_hub.get_translator_by_locale(user.language_code)

    async def get_l10n_async(self: MJClass, user_id: int) -> TranslatorRunner:
        """Load the user by id and return the translator for their language."""
        async with self.db_sessionmaker() as session:
            user = await session.get(User, user_id)
            return self.translator_hub.get_translator_by_locale(user.language_code)

    async def update_message_status(
            self: MJClass,
            data: CallbackData,
            trigger_id: TriggerID,
            content: str,
            attachment: Attachment,
            done: bool = False
    ):
        """Refresh the user's message with the newest preview or final image.

        When *done* is true the finished trigger is stored, the processing
        trigger is released, the image is also sent as a document and the
        user's tokens are charged.
        """
        if not (p_trigger := self.cache.get_processing_trigger(trigger_id)):
            logger.warning(f"Trigger {trigger_id} not found")
            return

        file_bytes = await self.download_file(attachment["proxy_url"])
        if file_bytes is None:
            logger.warning(f"Could not download attachment for trigger {trigger_id}")
            return
        file = BufferedInputFile(file_bytes, attachment["filename"])

        trigger_stats = self.parse_status(content)
        # Empty string instead of None so the caption never shows "None".
        prompt = trigger_stats.get("prompt", "")

        l10n = await self.get_l10n_async(p_trigger.user_id)

        if trigger_stats:
            percentage = md.hcode(trigger_stats.get('percentage', ''))
            content = l10n.diffusion.predicting()
            content = f"{md.hitalic(prompt)}\n\n{content} {percentage}"
            parse_mode = "HTML"
        else:
            parse_mode = None

        reply_markup = None

        if done:
            content = l10n.diffusion.predicted()
            content = f"{content}\n" \
                      f"{md.hcode(prompt)}"
            # The caption above always contains HTML markup.
            parse_mode = "HTML"
            trigger = Trigger.from_callback_data(data)
            if p_trigger.action in (
                    MidjourneyAction.VARIATION, MidjourneyAction.REFRESH, MidjourneyAction.IMAGINE):
                reply_markup = common_kbs.midjourney_attachment_scale(trigger.msg_id)
            trigger.prompt = prompt

            self.cache.set_done_trigger(trigger)
            p_trigger.release()

        if p_trigger.action in (MidjourneyAction.VARIATION, MidjourneyAction.REFRESH) and not done:
            # Variation/refresh show the loading animation until the result is final.
            animation_id = await self.cache.imagine_animation.get_file_id(self.bot)
            media = InputMediaAnimation(
                media=animation_id,
                caption=content,
                parse_mode=parse_mode,
            )
        else:
            media = InputMediaPhoto(
                media=file,
                caption=content,
                parse_mode=parse_mode,
            )

        await self.bot.edit_message_media(
            media=media,
            message_id=p_trigger.message_id,
            chat_id=p_trigger.user_id,
            reply_markup=reply_markup,
        )

        if done:
            await self.bot.send_document(
                chat_id=p_trigger.user_id,
                document=file,
                reply_to_message_id=p_trigger.message_id,
            )
            await self.resize_tokens(p_trigger)
+ if not trigger_response['message'] == "success": + raise Exception(trigger_response['message']) + + request.prompt = trigger.prompt + request.trigger_id = trigger.trigger_id + return await self._save_trigger(request) + + async def variation(self, request: MJRequest) -> TriggerID: + trigger = self.cache.get_done_trigger(request.msg_id) + trigger_response = await self._request( + "variation", + json={"index": request.index} | trigger.__dict__ + ) + if not trigger_response['message'] == "success": + raise Exception(trigger_response['message']) + + request.prompt = trigger.prompt + request.trigger_id = trigger.trigger_id + return await self._save_trigger(request) + + async def refresh(self, request: MJRequest) -> TriggerID: + trigger = self.cache.get_done_trigger(request.msg_id) + trigger_response = await self._request( + "reset", + json=trigger.__dict__ + ) + if not trigger_response['message'] == "success": + raise Exception(trigger_response['message']) + + request.prompt = trigger.prompt + request.trigger_id = trigger.trigger_id + return await self._save_trigger(request) + + async def describe(self, request: MJRequest) -> TriggerID: + response = await self._request( + "describe", + json={ + "upload_filename": request.filename, + "trigger_id": request.trigger_id + } + ) + if not response['message'] == "success": + raise Exception(response['message']) + trigger_id = response["trigger_id"] + request.trigger_id = trigger_id + return await self._save_trigger(request) + + async def upload(self, filename: str, file: BinaryIO) -> tuple[str, str]: + data = FormData() + data.add_field( + 'file', + file, + filename=filename, + content_type='image/jpeg' + ) + response = await self._request( + "upload", + data=data + ) + if not response['message'] == "success": + raise Exception(response['message']) + + return response['trigger_id'], response['upload_filename'] + + async def release(self, trigger_id: TriggerID): + response = await self._request( + "queue/release", + json={ + 
UserID = int


class MJQueueManager:
    """One user's place in the queue for the shared Midjourney worker.

    A manager is registered per user in ``managers``. Entering it as an async
    context manager shows a live progress message, waits for a slot on the
    shared ``semaphore`` and arms a timeout that releases the slot after
    ``release_delay`` seconds. ``release()`` gives the slot back; on a normal
    exit from the ``async with`` block the slot is kept and must be released
    by a later ``release()`` call.
    """

    # Limits how many jobs run at once; shared by every manager.
    semaphore: asyncio.Semaphore = asyncio.Semaphore(100)
    managers: dict[UserID, MJQueueManager] = {}
    queue_tasks: dict[UserID, asyncio.Task] = {}

    def __init__(
            self,
            user_id: UserID,
            edit_message: types.Message,
            bot: Bot,
            l10n: TranslatorRunner,
            mj_worker: MidjourneyWorker,
            release_delay: int = 60 * 8
    ):
        self.user_id = user_id
        self.edit_message = edit_message
        self.bot = bot
        # asyncio.Semaphore has no public way to count waiters, hence the
        # private attribute; 1 is used when nobody is waiting.
        self.place_in_queue = len(self.semaphore._waiters) if self.semaphore._waiters else 1
        self.l10n = l10n
        self.mj_worker = mj_worker

        self.trigger_id: str | None = None
        self.prompt: str | None = None

        self._progress_task: asyncio.Task | None = None
        self._release_task: asyncio.Task | None = None
        self._queue_task: asyncio.Task | None = None
        self.release_delay = release_delay

        self.can_cancel = True
        self.released = False
        self.queue_reduced = False
        # True only between a successful acquire() and the matching release().
        self._holds_slot = False

    @classmethod
    def get_manager(cls, user_id: int) -> MJQueueManager | None:
        """Return the registered manager for *user_id*, if any."""
        return cls.managers.get(user_id)

    @classmethod
    def get_manager_by_trigger_id(cls, trigger_id: str) -> MJQueueManager | None:
        """Return the registered manager whose ``trigger_id`` matches, if any."""
        for manager in cls.managers.values():
            if manager.trigger_id == trigger_id:
                return manager
        return None

    def add_queue_task(self, task: asyncio.Task):
        """Register *task* as the task ``cancel()`` should cancel for this user."""
        self._queue_task = task
        self.queue_tasks[self.user_id] = task

    @classmethod
    def get_queue_task(cls, user_id: int) -> asyncio.Task | None:
        """Return the registered queue task for *user_id*, if any."""
        return cls.queue_tasks.get(user_id)

    @classmethod
    async def try_create(
            cls,
            user_id: int,
            edit_message: types.Message,
            l10n: TranslatorRunner,
            mj_worker: MidjourneyWorker,
    ) -> MJQueueManager | None:
        """Create and register a manager unless the user already has one.

        If one exists, the message is edited to either offer cancellation or
        say that cancelling is no longer possible, and ``None`` is returned.
        """
        bot = Bot.get_current()
        if mj_queue_manager := cls.managers.get(user_id):
            if mj_queue_manager.can_cancel:
                await bot.edit_message_text(
                    l10n.diffusion.midjourney.queue.want_cancel(),
                    user_id,
                    edit_message.message_id,
                    reply_markup=common_kbs.mj_want_cancel(l10n)
                )
            else:
                await bot.edit_message_text(
                    l10n.diffusion.midjourney.queue.cannot_cancel(),
                    user_id,
                    edit_message.message_id,
                )
            return None
        mj_queue_manager = MJQueueManager(user_id, edit_message, bot, l10n, mj_worker)
        cls.managers[user_id] = mj_queue_manager
        return mj_queue_manager

    async def realtime_progress(self):
        """Keep editing the message with elapsed time (and queue position)."""
        wait_time_text = self.l10n.diffusion.midjourney.queue.wait_time()
        position_text = self.l10n.diffusion.midjourney.queue.position()
        counter = 0
        sleep_time = 3
        reply_markup = common_kbs.mj_want_cancel(self.l10n)
        while True:
            counter += sleep_time
            try:
                w_text = f"{wait_time_text} {md.hcode(counter)}s"
                rm = None
                if self.can_cancel:
                    w_text = f"{position_text} {md.hcode(self.place_in_queue)}\n{w_text}"
                    rm = reply_markup
                await self.bot.edit_message_text(w_text, self.user_id, self.edit_message.message_id, reply_markup=rm)
            except Exception as e:
                # Best effort: a failed edit must not stop the loop.
                # NOTE(review): the extra sleep looks like a back-off — confirm.
                logger.warning(e)
                await asyncio.sleep(sleep_time)
            await asyncio.sleep(sleep_time)
            if counter > 600:
                break

    def _reduce_queues(self):
        """Move every other manager one place forward; runs at most once."""
        if self.queue_reduced:
            return False
        for manager in list(self.managers.values()):
            if manager.user_id != self.user_id:
                manager.place_in_queue -= 1
        self.queue_reduced = True
        return True

    def _unregister(self):
        """Remove this manager from the registry, never a newer one for the same user."""
        if self.managers.get(self.user_id) is self:
            del self.managers[self.user_id]

    def _drop_queue_task(self) -> asyncio.Task | None:
        """Remove and return this manager's queue task, if one is registered."""
        task = self.queue_tasks.get(self.user_id)
        if task is None:
            return None
        if self._queue_task is not None and task is not self._queue_task:
            # The entry belongs to a newer manager of the same user.
            return None
        del self.queue_tasks[self.user_id]
        return task

    def release(self) -> bool:
        """Free everything this manager holds; return ``False`` if already released."""
        if self.released:
            return False
        self.released = True
        if self._progress_task:
            self._progress_task.cancel()
        if self._release_task:
            self._release_task.cancel()

        self._reduce_queues()

        # The entries may already be gone (e.g. after cancel()), so removal
        # must not raise.
        self._unregister()
        self._drop_queue_task()
        # Only give back a slot that was actually taken; releasing without a
        # prior acquire would raise the concurrency limit.
        if self._holds_slot:
            self._holds_slot = False
            self.semaphore.release()

        logger.info(f"User {self.user_id} released lock")
        return True

    def cancel(self) -> bool:
        """Cancel the wait; return ``False`` when cancelling is no longer allowed."""
        if not self.can_cancel:
            return False
        self.can_cancel = False
        task = self._drop_queue_task()
        if task is not None:
            task.cancel()

        if self._progress_task:
            self._progress_task.cancel()
        self._reduce_queues()
        self._unregister()
        return True

    def set_trigger_id(self, trigger_id: str):
        self.trigger_id = trigger_id

    def set_prompt(self, prompt: str):
        self.prompt = prompt

    async def release_trigger(self):
        """Ask the worker to drop this trigger and delete its Telegram message."""
        if not self.trigger_id:
            return
        try:
            result = await self.mj_worker.release(self.trigger_id)
            pt = self.mj_worker.cache.get_processing_trigger(self.trigger_id)
            if pt is not None:
                await self.bot.delete_message(
                    chat_id=pt.user_id,
                    message_id=pt.message_id,
                )
            logger.info(f"Release callback trigger {self.trigger_id} for user:{self.user_id}: {result}")
        except Exception as e:
            # Best effort: failures here are logged only.
            logger.warning(f"Release callback trigger {self.trigger_id} for user:{self.user_id}: {e}")

    async def delayed_release(self):
        """Release the slot after ``release_delay`` seconds and notify the user."""
        await asyncio.sleep(self.release_delay)
        # This coroutine runs as ``_release_task``. Clear the reference first:
        # otherwise release() would cancel the current task and the message
        # below would never be sent.
        self._release_task = None
        release = self.release()
        if release:
            logger.info(f"User {self.user_id} released after {self.release_delay} seconds")
            await self.bot.send_message(self.user_id, self.l10n.diffusion.predicting.timeout())

    @classmethod
    async def release_by_banned_words(cls, words: list[str]):
        """Release every manager whose prompt is an element of *words*."""
        # Snapshot first: banned_release() mutates ``cls.managers``.
        managers_for_release = [
            manager for manager in cls.managers.values() if manager.prompt in words
        ]

        released = 0
        for manager in managers_for_release:
            await manager.banned_release(manager.prompt)
            released += 1

        if released:
            logger.info(f"Released {released} users by banned words: {words}")
        else:
            logger.warning(f"No users released by banned words: {words}")

    async def banned_release(self, text: str):
        """Release this manager, tell the user the prompt was banned, drop the trigger."""
        self.release()
        logger.info(f"User {self.user_id} released when banned: {text}")
        await self.bot.send_message(self.user_id, self.l10n.diffusion.midjourney.prompt.banned())
        await self.release_trigger()

    def _after_semaphore(self):
        """Switch from waiting to running: no more cancel, arm the timeout."""
        self.can_cancel = False
        self._reduce_queues()
        self._release_task = asyncio.create_task(self.delayed_release())

    async def wait(self):
        """Wait for our turn in the queue (a free slot on the semaphore)."""
        self._progress_task = asyncio.create_task(self.realtime_progress())
        logger.info(f"User {self.user_id} wait for semaphore")
        await self.semaphore.acquire()
        self._holds_slot = True
        logger.info(f"User {self.user_id} acquired semaphore")
        self._after_semaphore()

    async def __aenter__(self):
        try:
            await self.wait()
        except (asyncio.CancelledError, asyncio.TimeoutError):
            # Must come before ``except Exception``: TimeoutError is an
            # Exception subclass and was previously swallowed by that clause.
            self.release()
            raise
        except Exception as e:
            # Unchanged from the original: logged and released, not re-raised.
            logger.exception(e)
            self.release()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Release only when the body raised; otherwise the slot stays held.
        logger.info(f"User {self.user_id} exit from context manager")
        if exc_type:
            self.release()
TriggerID = str
MsgID = UserID = MessageID = int


@dataclass
class ProcessingTrigger:
    """A Midjourney job in progress and the Telegram message that shows it."""

    trigger_id: TriggerID
    user_id: UserID
    message_id: MessageID
    # Called by release(). Annotated with the built-in union syntax so the
    # class does not depend on the ``Optional`` that was imported from
    # ``charset_normalizer.md``.
    unlocker: Callable | None = None
    action: MidjourneyAction = MidjourneyAction.VARIATION

    def __post_init__(self):
        # Accept the action as a plain string and convert it to the enum.
        if isinstance(self.action, str):
            self.action = MidjourneyAction(self.action)

    @classmethod
    def from_request(cls, request: MJRequest):
        """Build a trigger from *request*; ``message_id`` is the request message's id."""
        # Keyword arguments keep this correct if the fields are ever reordered.
        return cls(
            trigger_id=request.trigger_id,
            user_id=request.user.id,
            message_id=request.message.message_id,
            unlocker=request.unlocker,
            action=request.action,
        )

    def release(self):
        """Call the unlocker, if one is set."""
        if self.unlocker:
            self.unlocker()


@dataclass
class Trigger:
    """A finished Midjourney result, stored so follow-up actions can refer to it."""

    msg_id: MsgID
    msg_hash: str
    trigger_id: TriggerID
    prompt: str = ""

    @classmethod
    def from_callback_data(cls, callback_data: CallbackData):
        """Build a trigger from a callback payload that has at least one attachment.

        ``msg_hash`` is the part of the first attachment's filename after the
        last ``_`` and before the first following ``.``.
        """
        filename = callback_data["attachments"][0]["filename"]
        msg_hash = filename.split("_")[-1].split(".")[0]
        return cls(
            msg_id=callback_data["id"],
            msg_hash=msg_hash,
            trigger_id=callback_data["trigger_id"],
        )
def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
    """Return the number of tokens used by a list of chat messages.

    ``messages`` is an iterable of dicts whose values are encoded with the
    tiktoken encoding for ``model`` (falling back to ``cl100k_base`` with a
    warning when tiktoken does not know the model). The aliases
    ``gpt-3.5-turbo`` and ``gpt-4`` are counted as ``gpt-3.5-turbo-0301`` and
    ``gpt-4-0314`` respectively.

    Raises:
        NotImplementedError: for any other model name.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.warning("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    # Floating aliases are counted as their pinned snapshots.
    pinned_snapshot = {
        "gpt-3.5-turbo": "gpt-3.5-turbo-0301",
        "gpt-4": "gpt-4-0314",
    }.get(model)
    if pinned_snapshot is not None:
        return num_tokens_from_messages(messages, model=pinned_snapshot)

    # model -> (tokens_per_message, tokens_per_name)
    overheads = {
        # every message follows <|start|>{role/name}\n{content}<|end|>\n;
        # if there's a name, the role is omitted (hence -1)
        "gpt-3.5-turbo-0301": (4, -1),
        "gpt-4-0314": (3, 1),
    }
    if model not in overheads:
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""")
    per_message, per_name = overheads[model]

    total = 3  # every reply is primed with <|start|>assistant<|message|>
    for entry in messages:
        total += per_message
        total += sum(len(encoding.encode(text)) for text in entry.values())
        if "name" in entry:
            total += per_name
    return total
@dataclass
class Message:
    """A single chat message: role, text content and a creation timestamp."""
    role: Role
    content: str
    # Local time formatted as "%d.%m.%Y %H:%M:%S", filled in at construction.
    datetime: str = field(default_factory=lambda: datetime.datetime.now().strftime("%d.%m.%Y %H:%M:%S"))

    def pretty(self):
        """Return the message as "[role]" (rendered with ``md.hcode``) followed by the content."""
        role = f"[{self.role}]"
        return f"{md.hcode(role)}\n{self.content}"

    @classmethod
    def from_dict(cls, data: dict):
        """Build a Message from a dict with ``role`` and ``content`` keys.

        ``datetime`` is optional. When it is absent (as in dicts produced by
        ``to_dict()`` with its default arguments) or ``None``, the field
        default is used, so the attribute is always a string.

        Raises:
            KeyError: if ``role`` or ``content`` is missing.
        """
        kwargs = {'role': data['role'], 'content': data['content']}
        stamp = data.get('datetime')
        if stamp is not None:
            # Forward only a real value; passing None would override the default factory.
            kwargs['datetime'] = stamp
        return cls(**kwargs)

    def to_dict(self, exclude_datetime=True):
        """Return ``role`` and ``content`` as a dict, plus ``datetime`` when ``exclude_datetime`` is false."""
        data = {
            'role': self.role,
            'content': self.content,
        }
        if not exclude_datetime:
            data['datetime'] = self.datetime
        return data
@dataclasses.dataclass
class GPTModelManager:
    """Manage multiple GPTModel instances and rate limit them.

    Every registered model gets a ``BoundedSemaphore`` with
    ``requests_per_minute`` slots. A slot taken by ``acquire_model`` is handed
    back 60 seconds after the request finishes.
    """
    models: List[GPTModel] = dataclasses.field(default_factory=list)
    requests_per_minute: int = 200  # TODO(taima, 08.09.2023): increase to 60

    _available_models: asyncio.Queue[GPTModel] = dataclasses.field(init=False)
    _rate_limit_semaphores: dict[GPTModel, asyncio.BoundedSemaphore] = dataclasses.field(init=False)

    def __post_init__(self):
        """Register every model passed to the constructor."""
        self._available_models = asyncio.Queue()
        self._rate_limit_semaphores = {}
        # add_model() appends to self.models. Looping over that same list while
        # it grows would never terminate, so register from a snapshot and let
        # add_model() rebuild the list.
        initial_models = list(self.models)
        self.models = []
        for model in initial_models:
            self.add_model(model)

    def add_model(self, model: GPTModel):
        """Register ``model``: store it, create its semaphore and enqueue it."""
        self.models.append(model)
        self._rate_limit_semaphores[model] = asyncio.BoundedSemaphore(self.requests_per_minute)
        self._available_models.put_nowait(model)

    @asynccontextmanager
    async def acquire_model(self) -> AsyncGenerator[GPTModel, None]:
        """Yield a model once one of its rate-limit slots is free.

        Raises IndexError when no model has been registered.
        """
        # FIXME(taima, 08.08.2023): rework the logic to support several GPT
        # accounts. Until then the first model is always used and
        # _available_models is filled but never consumed.
        model = self.models[0]
        semaphore = self._rate_limit_semaphores[model]
        await semaphore.acquire()

        try:
            yield model
        finally:
            # Give the slot back a minute later rather than immediately, so
            # the semaphore caps requests per minute, not concurrent requests.
            asyncio.get_event_loop().call_later(60, semaphore.release)

    async def generate_completion(self, messages: List[dict], **kwargs):
        """Run ``generate_completion`` on an acquired model and return its result."""
        async with self.acquire_model() as model:
            return await model.generate_completion(messages, **kwargs)

    async def stream_completion(
            self,
            messages: List[dict],
            callback: Callable[[str], Awaitable[None]],
            new_chars_threshold: int = 50,
            **kwargs,
    ):
        """Run ``stream_completion`` on an acquired model, forwarding all arguments."""
        async with self.acquire_model() as model:
            logger.info(f"Using model {model.model}")
            return await model.stream_completion(messages, callback, new_chars_threshold, **kwargs)
class QueueSender:
    """Deliver messages to an async callback one at a time from a background task.

    Messages are queued by ``send`` and consumed by the ``sender`` task, which
    pauses ``sleep`` seconds after each delivery. Must be constructed while an
    event loop is running, because the worker task is created in ``__init__``.
    """

    def __init__(
            self,
            cb: Callable[..., Awaitable[object]],
            max_messages: int = 1,
            sleep: float = 0.4
    ):
        """
        :param cb: coroutine function called as ``cb(message)`` and, on
            failure, retried as ``cb(message, parse_mode=None)``
        :param max_messages: queue size above which ``send`` drops the oldest message
        :param sleep: pause in seconds after each delivery attempt
        """
        self.cb = cb
        self.max_messages = max_messages
        self.sleep = sleep
        self._queue = asyncio.Queue()
        self._task = asyncio.create_task(self.sender())

    async def sender(self):
        """Worker loop: take a message, deliver it, retry once without parse mode."""
        while True:
            message = await self._queue.get()
            try:
                try:
                    await self.cb(message)
                except Exception as e:
                    logger.warning(e)
                    await asyncio.sleep(self.sleep)
                    try:
                        await self.cb(message, parse_mode=None)
                    except Exception as retry_error:
                        # Give up on this message but keep the worker alive
                        # for the messages that follow.
                        logger.warning(retry_error)
                await asyncio.sleep(self.sleep)
            finally:
                # Always balance the get() above; otherwise close() would wait
                # on join() forever after a failed or cancelled delivery.
                self._queue.task_done()

    async def send(self, message: str):
        """Queue ``message``; if the queue is over ``max_messages``, drop the oldest entry first."""
        if self._queue.qsize() > self.max_messages:
            self._queue.get_nowait()
            self._queue.task_done()
            logger.debug("Queue overflow detected. Clearing queue")
        await self._queue.put(message)

    async def close(self):
        """
        Wait for queue to finish and cancel task
        :return:
        """
        await self._queue.join()
        if not self._task.done():
            self._task.cancel()

    async def force_close(self):
        """
        Cancel task without waiting for queue
        :return:
        """
        if not self._task.done():
            self._task.cancel()
def log_data(data: dict) -> str:
    """Pretty-format a request payload for logging, with image inputs masked.

    The values of ``data["input"]["image"]`` and ``data["input"]["img"]``
    (whichever are present) are replaced by a short placeholder in the
    formatted text. ``data`` and its nested ``input`` dict are left untouched,
    so the caller can keep using them after logging.
    """
    safe = data.copy()
    payload_input = safe.get("input")
    if isinstance(payload_input, dict):
        # data.copy() is shallow, so copy the nested dict too before masking;
        # masking in place would overwrite the caller's real input.
        masked = dict(payload_input)
        for key in ("image", "img"):
            if key in masked:
                masked[key] = "data:image/jpeg;base64, ..."
        safe["input"] = masked
    return pformat(safe)
{diffusion_model.model=}\n{prompt=}\n{log_data(data)}") + result = await res.json() + if "id" not in result: + logger.error(f"\nS1. Predicted: {diffusion_model.model=}\n{prompt=}\n{pformat(result)}") + if detail := result.get("detail"): + raise Exception(detail) + raise Exception(result["detail"]) + prediction_url = self.api_url + "/" + result["id"] + while True: + await asyncio.sleep(self.predict_status_timeout) + async with self.session.get(prediction_url, proxy=self.proxy) as res: + result = await res.json() + logger.debug(f"{pformat(result)}") + if result["status"] in ("canceled", "failed"): + logger.error(f"\nS2. Predicted: {diffusion_model.model=}\n{prompt=}\n{pformat(result)}") + if error := result.get("error"): + logger.error(error) + raise Exception( + f"❌ К сожалению, произошла ошибка при генерации изображения, повторите позже...") + try: + percent = result["logs"].split("\n")[-2] + # find 1 percent + percent = re.search(r"\d+%", percent) + if percent: + percent = percent.group(0) + if cb: + try: + await cb(percent) + except Exception as e: + logger.warning(e) + except IndexError: + pass + if result["status"] == "succeeded": + logger.success(f"\nS3. 
async def main():
    """Manual smoke test: run one prediction on ``screenshot.png``.

    The API token is read from the ``REPLICATE_API_TOKEN`` environment
    variable. Credentials must not live in source control; any token that was
    committed earlier should be treated as leaked and revoked.

    Raises:
        RuntimeError: if ``REPLICATE_API_TOKEN`` is not set.
    """
    import os  # local import: only this script entry point needs it

    token = os.environ.get("REPLICATE_API_TOKEN")
    if not token:
        raise RuntimeError("Set the REPLICATE_API_TOKEN environment variable to run this demo")

    # Keep the file open for the whole prediction and close it afterwards.
    with open("screenshot.png", "rb") as file:
        async with Replicate(token=token, predict_status_timeout=1) as replicate:
            diffusion_model = ReplicateDiffusionModel(
                type=DiffusionType.IMAGE_TO_IMAGE,
                model="j-min/clip-caption-reward:de37751f75135f7ebbe62548e27d6740d5155dfefdf6447db35c9865253d7e06",
                prompt_field="image",
            )
            await replicate.predict(diffusion_model, prompt=file, cb=cb)
def parse_output(self, prediction: dict) -> Url:
    """Extract the result from a prediction dict according to ``self.type``.

    - TEXT_TO_IMAGE: first element of ``prediction["output"]``
    - TEXT_TO_VIDEO: the ``"mp4"`` entry
    - TEXT_TO_AUDIO: the output itself when it is a string, otherwise its
      ``"audio_out"`` entry or, failing that, ``"audio"``
    - AUDIO_TO_TEXT: the ``"text"`` of every output item, concatenated
    - any other type: the last element if the output is a list; then, if
      the value is a dict, its first value; otherwise the value itself
    """
    kind = self.type
    output = prediction["output"]

    if kind == DiffusionType.TEXT_TO_IMAGE:
        return output[0]
    if kind == DiffusionType.TEXT_TO_VIDEO:
        return output["mp4"]
    if kind == DiffusionType.TEXT_TO_AUDIO:
        if isinstance(output, str):
            return output
        return output.get("audio_out") or output.get("audio")
    if kind == DiffusionType.AUDIO_TO_TEXT:
        return "".join(item["text"] for item in output)

    # TODO(taima, 16.06.2023): some models return several result variants;
    # it may be worth handling all of them.
    if isinstance(output, list):
        output = output[-1]
    if isinstance(output, dict):
        return next(iter(output.values()))
    return output