Skip to content

Commit

Permalink
add diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
taimast committed Dec 14, 2023
1 parent 25d24a5 commit 3c3842e
Show file tree
Hide file tree
Showing 26 changed files with 1,679 additions and 2 deletions.
3 changes: 2 additions & 1 deletion init_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def install_dependencies(project_path: Path):
"pydantic_settings",
"pydantic>=2.1.1",
"bs4",
"lxml"
"lxml",
"alembic"
]
utils = [
"watchdog",
Expand Down
Empty file added modules/diffusion/__init__.py
Empty file.
32 changes: 32 additions & 0 deletions modules/diffusion/diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import typing
from dataclasses import dataclass, field
from enum import StrEnum
from openai import AsyncClient

class DiffusionType(StrEnum):
TEXT_TO_IMAGE = "text2image"
TEXT_TO_TEXT = "text2text"
TEXT_TO_VIDEO = "text2video"
TEXT_TO_AUDIO = "text2audio"
AUDIO_TO_TEXT = "audio2text"
IMAGE_TO_IMAGE = "image2image"
IMAGE_TO_TEXT = "image2text"

def get_text(self, l10n) -> str:
return l10n.get(f"diffusion-type-button-{self}")

@classmethod
def to_text_types(cls) -> list['DiffusionType']:
return [cls.TEXT_TO_TEXT, cls.AUDIO_TO_TEXT, cls.IMAGE_TO_TEXT]

@classmethod
def to_image_types(cls) -> list['DiffusionType']:
return [cls.TEXT_TO_IMAGE, cls.IMAGE_TO_IMAGE]


@dataclass
class DiffusionModel:
    """Description of one generation model that can be offered to users."""

    # Identifier of the model as the backend knows it.
    model: str
    # Name the model is looked up by; optional.
    name: typing.Optional[str] = None
    # Kind of generation performed; text-to-image unless stated otherwise.
    type: DiffusionType = DiffusionType.TEXT_TO_IMAGE
    # Per-model default inputs; every instance gets its own empty dict.
    default_inputs: typing.Dict[str, typing.Any] = field(default_factory=dict)
13 changes: 13 additions & 0 deletions modules/diffusion/image_scale/tesseract_rec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytesseract
from PIL import Image

# If you don't have tesseract executable in your PATH, include the following:
pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
# Example tesseract_cmd = r'C:\Program Files (x86)\Tesseract-OCR\tesseract'

# Simple image to string.
# Both the file handle and the decoded image are released as soon as the
# OCR call finishes (the original left "screen.png" open until interpreter
# exit).
with open("screen.png", "rb") as file, Image.open(file) as image:
    print(pytesseract.image_to_string(
        image,
        # lang="rus"
    ))
72 changes: 72 additions & 0 deletions modules/diffusion/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from functools import cache
from typing import Any

from diffusion_bot.apps.diffusion.diffusion import DiffusionModel, DiffusionType
from diffusion_bot.apps.diffusion.math_api.mathpix import MathPix
from diffusion_bot.apps.diffusion.midjourney.midjourney import MidjourneyWorker
from diffusion_bot.apps.diffusion.midjourney.model import MidjourneyDiffusionModel
from diffusion_bot.apps.diffusion.openai_gen.chatgpt.group_manager import GPTGroupManager
from diffusion_bot.apps.diffusion.openai_gen.chatgpt.model import GPTModel
from diffusion_bot.apps.diffusion.openai_gen.models import Whisper, DallE, OpenAIDiffusionModel
from diffusion_bot.apps.diffusion.replicate_gen.base import Replicate
from diffusion_bot.apps.diffusion.replicate_gen.models import ReplicateDiffusionModel


class DiffusionManager:
    """Registry of the configured models and dispatcher to their backends.

    The constructor stores the backends and the per-provider model lists it
    is given, concatenates the lists into ``self.models`` and indexes that
    list by name and by type.  The indexes are built once, so models are
    expected to be fixed after construction.
    """

    def __init__(
            self,
            gpt_group_manager: GPTGroupManager,
            replicate: Replicate,
            midjourney_worker: MidjourneyWorker,
            mathpix: MathPix,

            openai_models: list[GPTModel | DallE | Whisper],
            replicate_models: list[ReplicateDiffusionModel],
            photo_scaler_models: list[ReplicateDiffusionModel],
            midjourney_models: list[MidjourneyDiffusionModel],
    ):
        self.replicate = replicate
        self.midjourney_worker = midjourney_worker
        self.gpt_group_manager = gpt_group_manager
        self.mathpix = mathpix

        self.openai_models = openai_models
        self.replicate_models = replicate_models
        self.photo_scaler_models = photo_scaler_models
        self.midjourney_models = midjourney_models

        self.models = (
                self.midjourney_models
                + self.openai_models
                + self.replicate_models
                + self.photo_scaler_models
        )

        # Per-instance lookup tables.  They replace ``functools.cache`` on the
        # methods below: a cache on a method is keyed on ``self``, keeps every
        # manager alive for the lifetime of the process and hands the same
        # mutable list object to every caller.
        self._models_by_name: dict[Any, DiffusionModel] = {}
        self._models_by_type: dict[Any, list[DiffusionModel]] = {}
        for model in self.models:
            # setdefault keeps the FIRST model with a given name, matching
            # the original first-match linear scan.
            self._models_by_name.setdefault(model.name, model)
            self._models_by_type.setdefault(model.type, []).append(model)

    def get_model(self, model_name: str) -> DiffusionModel | None:
        """Return the first registered model named *model_name*, else ``None``."""
        return self._models_by_name.get(model_name)

    def get_models(self, type: DiffusionType | None = None) -> list[DiffusionModel]:
        """Return the registered models, optionally restricted to one *type*.

        With ``type=None`` the manager's own ``models`` list is returned.
        Otherwise a new list is returned on every call, so callers may mutate
        it without corrupting the index.
        """
        if type is None:
            return self.models
        return list(self._models_by_type.get(type, ()))

    async def predict(self, model_name: str, prompt: Any = None, **kwargs) -> Any:
        """Run the model registered as *model_name* and return its result.

        Replicate models receive *prompt* through ``self.replicate``; GPT
        models are streamed through ``self.gpt_group_manager`` with
        ``**kwargs`` (*prompt* is not forwarded there); any other model is
        asked to ``predict(prompt)`` itself.

        Raises:
            ValueError: if no model is registered under *model_name*.
        """
        model = self.get_model(model_name)
        if model is None:
            # Previously this fell through to ``None.predict`` and surfaced as
            # an AttributeError that did not mention the requested name.
            raise ValueError(f"Unknown diffusion model: {model_name!r}")

        if isinstance(model, ReplicateDiffusionModel):
            return await self.replicate.predict(model, prompt)

        if isinstance(model, GPTModel):
            return await self.gpt_group_manager.stream_completion(model.model, **kwargs)

        # Remaining models are expected to be OpenAIDiffusionModel-like and
        # to expose their own ``predict`` coroutine.
        return await model.predict(prompt)
79 changes: 79 additions & 0 deletions modules/diffusion/math_api/mathpix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import asyncio
import base64
from dataclasses import dataclass
from pprint import pprint
from typing import ClassVar, Any

import aiohttp
from loguru import logger


@dataclass
class MathPix:
url: ClassVar[str] = "https://api.mathpix.com/v3/text"
app_id: str
app_key: str
session: Any = None

def init_session(self):
self.session = aiohttp.ClientSession(headers={
"app_id": self.app_id,
"app_key": self.app_key
})

async def __aenter__(self):
self.init_session()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.session.close()

async def predict(self, photo: bytes) -> str:
src = base64.b64encode(photo).decode('utf-8')
src = f"data:image/png;base64,{src}"
json = {
"src": src,
"formats": ["text"],
"data_options": {
"include_asciimath": True,
# "include_latex": True
}
}
async with self.session.post(self.url, json=json) as res:
result = await res.json()
return result.get("text")


async def main():
    """Manual smoke test: OCR ``img.png`` through Mathpix and print the reply.

    Credentials are read from the ``MATHPIX_APP_ID`` and ``MATHPIX_APP_KEY``
    environment variables; API keys must not be committed to the repository.

    Raises:
        KeyError: if either environment variable is missing.
    """
    import os

    # Fail fast on missing credentials, before touching the file or network.
    headers = {
        "app_id": os.environ["MATHPIX_APP_ID"],
        "app_key": os.environ["MATHPIX_APP_KEY"],
    }

    with open("img.png", "rb") as f:
        photo = f.read()
    src = base64.b64encode(photo).decode('utf-8')
    src = f"data:image/png;base64,{src}"

    payload = {
        "src": src,
        # A hosted image was used here in an earlier experiment, e.g.
        # "https://mathpix-ocr-examples.s3.amazonaws.com/cases_hw.jpg"
        "formats": ["text", "data", "html"],
        "data_options": {
            "include_asciimath": True,
            # "include_latex": True
        }
    }
    url = "https://api.mathpix.com/v3/text"
    logger.info("Start request")
    async with aiohttp.ClientSession(headers=headers) as session:
        async with session.post(url, json=payload) as res:
            result = await res.json()
    pprint(result)


if __name__ == '__main__':
    asyncio.run(main())
Empty file.
108 changes: 108 additions & 0 deletions modules/diffusion/midjourney/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from dataclasses import dataclass

import CacheToolsUtils as ctu
import aiohttp
import orjson
from aiogram import Bot
from aiogram.types import BufferedInputFile
from cachetools import TTLCache
from redis import Redis

from diffusion_bot.apps.diffusion.midjourney.response import TriggerID, ProcessingTrigger, Trigger, MsgID
from diffusion_bot.apps.share.cache_manager import CACHE_MANAGER

# Module-level aliases for the two Midjourney caches owned by CACHE_MANAGER.
# Triggers still being processed, keyed by trigger id.
MJ_TRIGGERS_PROCESSING_CACHE: TTLCache[TriggerID, ProcessingTrigger] = (
    CACHE_MANAGER.midjourney_triggers_processing_cache
)
# Triggers keyed by message id.
MJ_TRIGGERS_ID_CACHE: TTLCache[MsgID, Trigger] = (
    CACHE_MANAGER.midjourney_triggers_id_cache
)


class PrefixedRedisCache(ctu.PrefixedRedisCache):
    """``ctu.PrefixedRedisCache`` that uses orjson for keys and values."""

    def _serialize(self, s):
        """Encode *s* with orjson; anything orjson cannot encode becomes null."""

        def _unserializable(_obj):
            return None

        return orjson.dumps(s, default=_unserializable)

    def _deserialize(self, s):
        """Decode a stored value back into Python objects."""
        decoded = orjson.loads(s)
        return decoded

    def _key(self, key):
        """Encode *key* with orjson.

        NOTE(review): this override builds the key from *key* alone and never
        references the ``prefix`` given to the constructor.  Confirm the base
        class still applies it; otherwise caches created with different
        prefixes share one key space.
        """
        encoded = orjson.dumps(key)
        return encoded


# Short name used by the rest of the module.
Cache = PrefixedRedisCache


@dataclass
class CacheAnimation:
    """A remote animation whose uploaded ``file_id`` is remembered after first use."""

    url: str
    file_id: str | None = None

    async def get_file_id(self, bot: Bot, chat_id: int = 269019356) -> str:
        """Return the ``file_id`` of the animation, uploading it once if needed.

        On the first call the file at ``self.url`` is downloaded and sent with
        ``bot.send_animation`` to *chat_id*; the ``file_id`` of the resulting
        animation is stored on the instance and returned by later calls.

        Args:
            bot: bot used to send the animation.
            chat_id: chat that receives the one-off upload.  Defaults to the
                previously hard-coded chat so existing callers are unaffected.

        Raises:
            RuntimeError: if the animation could not be downloaded.
        """
        if self.file_id:
            return self.file_id

        file = await download_file(self.url)
        if file is None:
            # download_file yields None on a non-200 reply; fail here with
            # the URL instead of passing None to BufferedInputFile.
            raise RuntimeError(f"Could not download animation from {self.url}")

        message = await bot.send_animation(
            chat_id=chat_id,
            animation=BufferedInputFile(file, "animation.gif")
        )
        self.file_id = message.animation.file_id
        return self.file_id


class MJCache:
    """Two-level store for Midjourney triggers.

    Writes go to both the module-level in-process caches and the Redis-backed
    caches; reads try the in-process cache first and fall back to Redis.
    """

    def __init__(self, redis: Redis, ttl: int = 60 * 60 * 24 * 3):
        """Create the Redis-backed stores (same *ttl* for both) and the two animations."""

        def _store(prefix: str):
            return Cache(redis, prefix=prefix, ttl=ttl)

        self.redis = redis
        self.processing_triggers: Cache[TriggerID, ProcessingTrigger] = _store("processing.triggers.")
        self.done_triggers: Cache[MsgID, Trigger] = _store("done.triggers.")
        self.imagine_animation = CacheAnimation(
            "https://cdn.dribbble.com/users/1514097/screenshots/3457456/media/b1cfddae9e7b9645b9cde7ad9ee4f6bf.gif"
        )
        self.variation_animation = CacheAnimation(
            "https://usagif.com/wp-content/uploads/loading-9.gif"
        )

    def get_processing_trigger(self, trigger_id: TriggerID) -> ProcessingTrigger | None:
        """Return the processing trigger for *trigger_id*, or ``None`` on ``KeyError``."""
        try:
            local = MJ_TRIGGERS_PROCESSING_CACHE.get(trigger_id)
            if not local:
                # Not held in-process: rebuild it from the stored field dict.
                stored = self.processing_triggers[trigger_id]
                return ProcessingTrigger(**stored)
            return local
        except KeyError:
            return None

    def get_done_trigger(self, msg_id: MsgID) -> Trigger | None:
        """Return the trigger stored for *msg_id*, or ``None`` on ``KeyError``."""
        try:
            local = MJ_TRIGGERS_ID_CACHE.get(msg_id)
            if not local:
                # Not held in-process: rebuild it from the stored field dict.
                stored = self.done_triggers[msg_id]
                return Trigger(**stored)
            return local
        except KeyError:
            return None

    def set_processing_trigger(self, processing_trigger: ProcessingTrigger):
        """Store *processing_trigger* in both cache levels under its ``trigger_id``."""
        key = processing_trigger.trigger_id
        MJ_TRIGGERS_PROCESSING_CACHE[key] = processing_trigger
        # The Redis-backed store receives the instance's attribute dict.
        self.processing_triggers[key] = vars(processing_trigger)

    def set_done_trigger(self, trigger: Trigger):
        """Store *trigger* in both cache levels under its ``msg_id``."""
        key = trigger.msg_id
        MJ_TRIGGERS_ID_CACHE[key] = trigger
        # The Redis-backed store receives the instance's attribute dict.
        self.done_triggers[key] = vars(trigger)


async def download_file(url: str):
    """Fetch *url* and return the response body as bytes.

    Returns ``None`` when the response status is anything other than 200.
    """
    async with aiohttp.ClientSession() as session, session.get(url) as resp:
        if resp.status != 200:
            return None
        return await resp.read()
Loading

0 comments on commit 3c3842e

Please sign in to comment.