-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
26 changed files
with
1,679 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import typing | ||
from dataclasses import dataclass, field | ||
from enum import StrEnum | ||
from openai import AsyncClient | ||
|
||
class DiffusionType(StrEnum): | ||
TEXT_TO_IMAGE = "text2image" | ||
TEXT_TO_TEXT = "text2text" | ||
TEXT_TO_VIDEO = "text2video" | ||
TEXT_TO_AUDIO = "text2audio" | ||
AUDIO_TO_TEXT = "audio2text" | ||
IMAGE_TO_IMAGE = "image2image" | ||
IMAGE_TO_TEXT = "image2text" | ||
|
||
def get_text(self, l10n) -> str: | ||
return l10n.get(f"diffusion-type-button-{self}") | ||
|
||
@classmethod | ||
def to_text_types(cls) -> list['DiffusionType']: | ||
return [cls.TEXT_TO_TEXT, cls.AUDIO_TO_TEXT, cls.IMAGE_TO_TEXT] | ||
|
||
@classmethod | ||
def to_image_types(cls) -> list['DiffusionType']: | ||
return [cls.TEXT_TO_IMAGE, cls.IMAGE_TO_IMAGE] | ||
|
||
|
||
@dataclass | ||
class DiffusionModel: | ||
model: str | ||
name: str | None = None | ||
type: DiffusionType = DiffusionType.TEXT_TO_IMAGE | ||
default_inputs: dict[str, typing.Any] = field(default_factory=dict) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import pytesseract | ||
from PIL import Image | ||
|
||
# If you don't have tesseract executable in your PATH, include the following: | ||
pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe' | ||
# Example tesseract_cmd = r'C:\Program Files (x86)\Tesseract-OCR\tesseract' | ||
file = open("screen.png", "rb") | ||
# Simple image to string | ||
image = Image.open(file) | ||
print(pytesseract.image_to_string( | ||
image, | ||
# lang="rus" | ||
)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from functools import cache | ||
from typing import Any | ||
|
||
from diffusion_bot.apps.diffusion.diffusion import DiffusionModel, DiffusionType | ||
from diffusion_bot.apps.diffusion.math_api.mathpix import MathPix | ||
from diffusion_bot.apps.diffusion.midjourney.midjourney import MidjourneyWorker | ||
from diffusion_bot.apps.diffusion.midjourney.model import MidjourneyDiffusionModel | ||
from diffusion_bot.apps.diffusion.openai_gen.chatgpt.group_manager import GPTGroupManager | ||
from diffusion_bot.apps.diffusion.openai_gen.chatgpt.model import GPTModel | ||
from diffusion_bot.apps.diffusion.openai_gen.models import Whisper, DallE, OpenAIDiffusionModel | ||
from diffusion_bot.apps.diffusion.replicate_gen.base import Replicate | ||
from diffusion_bot.apps.diffusion.replicate_gen.models import ReplicateDiffusionModel | ||
|
||
|
||
class DiffusionManager: | ||
def __init__( | ||
self, | ||
gpt_group_manager: GPTGroupManager, | ||
replicate: Replicate, | ||
midjourney_worker: MidjourneyWorker, | ||
mathpix: MathPix, | ||
|
||
openai_models: list[GPTModel, DallE, Whisper], | ||
replicate_models: list[ReplicateDiffusionModel], | ||
photo_scaler_models: list[ReplicateDiffusionModel], | ||
midjourney_models: list[MidjourneyDiffusionModel], | ||
): | ||
self.replicate = replicate | ||
self.midjourney_worker = midjourney_worker | ||
self.gpt_group_manager = gpt_group_manager | ||
self.mathpix = mathpix | ||
|
||
self.openai_models = openai_models | ||
self.replicate_models = replicate_models | ||
self.photo_scaler_models = photo_scaler_models | ||
self.midjourney_models = midjourney_models | ||
|
||
self.models = ( | ||
self.midjourney_models | ||
+ self.openai_models | ||
+ self.replicate_models | ||
+ self.photo_scaler_models | ||
) | ||
|
||
@cache | ||
def get_model(self, model_name: str) -> DiffusionModel: | ||
for model in self.models: | ||
if model.name == model_name: | ||
return model | ||
|
||
@cache | ||
def get_models(self, type: DiffusionType = None) -> list[DiffusionModel]: | ||
# todo L1 TODO 11.06.2023 19:31 taima: Возможно затратная операция, добавить кеширование | ||
if type is None: | ||
return self.models | ||
|
||
models = [] | ||
for model in self.models: | ||
if model.type == type: | ||
models.append(model) | ||
return models | ||
|
||
async def predict(self, model_name: str, prompt: Any = None, **kwargs) -> Any: | ||
model = self.get_model(model_name) | ||
if isinstance(model, ReplicateDiffusionModel): | ||
return await self.replicate.predict(model, prompt) | ||
|
||
elif isinstance(model, GPTModel): | ||
return await self.gpt_group_manager.stream_completion(model.model, **kwargs) | ||
else: | ||
model: OpenAIDiffusionModel | ||
return await model.predict(prompt) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import asyncio | ||
import base64 | ||
from dataclasses import dataclass | ||
from pprint import pprint | ||
from typing import ClassVar, Any | ||
|
||
import aiohttp | ||
from loguru import logger | ||
|
||
|
||
@dataclass | ||
class MathPix: | ||
url: ClassVar[str] = "https://api.mathpix.com/v3/text" | ||
app_id: str | ||
app_key: str | ||
session: Any = None | ||
|
||
def init_session(self): | ||
self.session = aiohttp.ClientSession(headers={ | ||
"app_id": self.app_id, | ||
"app_key": self.app_key | ||
}) | ||
|
||
async def __aenter__(self): | ||
self.init_session() | ||
return self | ||
|
||
async def __aexit__(self, exc_type, exc_val, exc_tb): | ||
await self.session.close() | ||
|
||
async def predict(self, photo: bytes) -> str: | ||
src = base64.b64encode(photo).decode('utf-8') | ||
src = f"data:image/png;base64,{src}" | ||
json = { | ||
"src": src, | ||
"formats": ["text"], | ||
"data_options": { | ||
"include_asciimath": True, | ||
# "include_latex": True | ||
} | ||
} | ||
async with self.session.post(self.url, json=json) as res: | ||
result = await res.json() | ||
return result.get("text") | ||
|
||
|
||
async def main(): | ||
# print(url) | ||
with open("img.png", "rb") as f: | ||
photo = f.read() | ||
src = base64.b64encode(photo).decode('utf-8') | ||
src = f"data:image/png;base64,{src}" | ||
# src = await save_photo(f) | ||
# print(src) | ||
pass | ||
# base64_photo = base64.b64encode(photo) | ||
json = { | ||
"src": src, | ||
# "src": "https://mathpix-ocr-examples.s3.amazonaws.com/cases_hw.jpg", | ||
"formats": ["text", "data", "html"], | ||
"data_options": { | ||
"include_asciimath": True, | ||
# "include_latex": True | ||
} | ||
} | ||
headers = { | ||
"app_id": "pacmanbotai_gmail_com_e98796_e16211", | ||
"app_key": "f707dc9f517e0e7add256ed0256db9cf20868143d2fd64118b5bdfeca147d4f9" | ||
} | ||
url = "https://api.mathpix.com/v3/text" | ||
logger.info("Start request") | ||
async with aiohttp.ClientSession(headers=headers) as session: | ||
async with session.post(url, json=json) as res: | ||
result = await res.json() | ||
pprint(result) | ||
|
||
|
||
if __name__ == '__main__': | ||
asyncio.run(main()) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
from dataclasses import dataclass | ||
|
||
import CacheToolsUtils as ctu | ||
import aiohttp | ||
import orjson | ||
from aiogram import Bot | ||
from aiogram.types import BufferedInputFile | ||
from cachetools import TTLCache | ||
from redis import Redis | ||
|
||
from diffusion_bot.apps.diffusion.midjourney.response import TriggerID, ProcessingTrigger, Trigger, MsgID | ||
from diffusion_bot.apps.share.cache_manager import CACHE_MANAGER | ||
|
||
MJ_TRIGGERS_PROCESSING_CACHE: TTLCache[ | ||
TriggerID, ProcessingTrigger | ||
] = CACHE_MANAGER.midjourney_triggers_processing_cache | ||
MJ_TRIGGERS_ID_CACHE: TTLCache[MsgID, Trigger] = CACHE_MANAGER.midjourney_triggers_id_cache | ||
|
||
|
||
class PrefixedRedisCache(ctu.PrefixedRedisCache): | ||
def _serialize(self, s): | ||
# if not serializable set NOne | ||
return orjson.dumps(s, default=lambda o: None) | ||
|
||
def _deserialize(self, s): | ||
return orjson.loads(s) | ||
|
||
def _key(self, key): | ||
return orjson.dumps(key) | ||
|
||
|
||
Cache = PrefixedRedisCache | ||
|
||
|
||
@dataclass | ||
class CacheAnimation: | ||
url: str | ||
file_id: str | None = None | ||
|
||
async def get_file_id(self, bot: Bot) -> str: | ||
if self.file_id: | ||
return self.file_id | ||
file = await download_file(self.url) | ||
file_id = await bot.send_animation( | ||
chat_id=269019356, | ||
animation=BufferedInputFile(file, "animation.gif") | ||
) | ||
self.file_id = file_id.animation.file_id | ||
return self.file_id | ||
|
||
|
||
class MJCache: | ||
|
||
def __init__(self, redis: Redis, ttl: int = 60 * 60 * 24 * 3): | ||
self.redis = redis | ||
self.processing_triggers: Cache[TriggerID, ProcessingTrigger] = Cache( | ||
redis, | ||
prefix="processing.triggers.", | ||
ttl=ttl | ||
) | ||
self.done_triggers: Cache[MsgID, Trigger] = Cache( | ||
redis, | ||
prefix="done.triggers.", | ||
ttl=ttl | ||
) | ||
self.imagine_animation = CacheAnimation( | ||
"https://cdn.dribbble.com/users/1514097/screenshots/3457456/media/b1cfddae9e7b9645b9cde7ad9ee4f6bf.gif" | ||
) | ||
self.variation_animation = CacheAnimation( | ||
"https://usagif.com/wp-content/uploads/loading-9.gif" | ||
) | ||
|
||
def get_processing_trigger(self, trigger_id: TriggerID) -> ProcessingTrigger | None: | ||
try: | ||
pt = MJ_TRIGGERS_PROCESSING_CACHE.get(trigger_id) | ||
if pt: | ||
return pt | ||
|
||
data = self.processing_triggers[trigger_id] | ||
return ProcessingTrigger(**data) | ||
except KeyError: | ||
return None | ||
|
||
def get_done_trigger(self, msg_id: MsgID) -> Trigger | None: | ||
try: | ||
trigger = MJ_TRIGGERS_ID_CACHE.get(msg_id) | ||
if trigger: | ||
return trigger | ||
|
||
data = self.done_triggers[msg_id] | ||
return Trigger(**data) | ||
except KeyError: | ||
return None | ||
|
||
def set_processing_trigger(self, processing_trigger: ProcessingTrigger): | ||
MJ_TRIGGERS_PROCESSING_CACHE[processing_trigger.trigger_id] = processing_trigger | ||
self.processing_triggers[processing_trigger.trigger_id] = processing_trigger.__dict__ | ||
|
||
def set_done_trigger(self, trigger: Trigger): | ||
MJ_TRIGGERS_ID_CACHE[trigger.msg_id] = trigger | ||
self.done_triggers[trigger.msg_id] = trigger.__dict__ | ||
|
||
|
||
async def download_file(url: str): | ||
async with aiohttp.ClientSession() as session: | ||
async with session.get(url) as resp: | ||
if resp.status == 200: | ||
return await resp.read() |
Oops, something went wrong.