From ccbb2333b477457c2a9f90486292c6299735db06 Mon Sep 17 00:00:00 2001 From: Zane <90309290+Zingzy@users.noreply.github.com> Date: Sat, 10 Aug 2024 20:18:49 +0530 Subject: [PATCH] Refactor image generation to support different models such as flux and turbo --- cogs/imagine_cog.py | 10 ++++++++-- cogs/multi_pollinate_cog.py | 12 +++++++++++- cogs/random_cog.py | 7 ++++++- constants.py | 6 ++---- utils.py | 2 ++ 5 files changed, 29 insertions(+), 8 deletions(-) diff --git a/cogs/imagine_cog.py b/cogs/imagine_cog.py index c6157d5..4daa00a 100644 --- a/cogs/imagine_cog.py +++ b/cogs/imagine_cog.py @@ -51,6 +51,7 @@ async def regenerate( prompt = message_data["prompt"] width = message_data["width"] height = message_data["height"] + model = message_data.get("model", "flux") negative = message_data["negative"] cached = message_data["cached"] nologo = message_data["nologo"] @@ -58,7 +59,7 @@ async def regenerate( try: dic, image = await generate_image( - prompt, width, height, negative, cached, nologo, enhance + prompt, width, height, model, negative, cached, nologo, enhance ) except Exception as e: print(e) @@ -333,12 +334,16 @@ async def cog_load(self): self.bot.add_view(ImagineButtonView()) @app_commands.command(name="pollinate", description="Generate AI Images") + @app_commands.choices( + model=[app_commands.Choice(name=choice, value=choice) for choice in MODELS], + ) @app_commands.guild_only() @app_commands.checks.cooldown(1, 15) @app_commands.describe( prompt="Prompt of the Image you want want to generate", height="Height of the Image", width="Width of the Image", + model="Model to use for generating the Image", enhance="Enables AI Prompt Enhancement", negative="The things not to include in the Image", cached="Uses the Default seed", @@ -351,6 +356,7 @@ async def imagine_command( prompt: str, width: int = 1000, height: int = 1000, + model: app_commands.Choice[str] = MODELS[0], enhance: bool | None = None, negative: str | None = None, cached: bool = False, @@ -368,7 +374,7 @@ async def imagine_command( start = datetime.datetime.now() dic, image = await generate_image( - prompt, width, height, negative, cached, nologo, enhance, private + prompt, width, height, model, negative, cached, nologo, enhance, private ) image_file = discord.File(image, filename="image.png") diff --git a/cogs/multi_pollinate_cog.py b/cogs/multi_pollinate_cog.py index ba6bde2..3247012 100644 --- a/cogs/multi_pollinate_cog.py +++ b/cogs/multi_pollinate_cog.py @@ -30,6 +30,11 @@ async def get_info(interaction: discord.Interaction, index: int): url += f"&nologo=true" url += f"&enhance={data['enhance']}" if 'enhance' in data and data['enhance'] else '' + if index % 2 == 0: + url += f"&model=turbo" + else: + url += f"&model=flux" + async with aiohttp.ClientSession() as session: async with session.get(data['urls'][index]) as response: response.raise_for_status() @@ -214,8 +219,13 @@ async def multiimagine_command( for i in range(4): try: time = datetime.datetime.now() + if i % 2 == 0: + model = "turbo" + else: + model = "flux" + dic, image = await generate_image( - prompt, width, height, negative, cached, nologo, enhance, private + prompt, width, height, model, negative, cached, nologo, enhance, private ) image_urls.append(dic["bookmark_url"]) diff --git a/cogs/random_cog.py b/cogs/random_cog.py index 7fcc68c..27ebb69 100644 --- a/cogs/random_cog.py +++ b/cogs/random_cog.py @@ -14,11 +14,15 @@ async def cog_load(self): await self.bot.wait_until_ready() @app_commands.command(name="random", description="Generate Random AI Images") + @app_commands.choices( + model=[app_commands.Choice(name=choice, value=choice) for choice in MODELS], + ) @app_commands.guild_only() @app_commands.checks.cooldown(1, 15) @app_commands.describe( height="Height of the image", width="Width of the image", + model="The model to use for generating the image", negative="The things not to include in the image", nologo="Remove the logo", private="Only you can see the generated Image if set to True", @@ -28,6 +32,7 @@ async def random_image_command( interaction: discord.Interaction, width: int = 1000, height: int = 1000, + model: app_commands.Choice[str] = MODELS[0], negative: str | None = None, nologo: bool = True, private: bool = False, @@ -40,7 +45,7 @@ async def random_image_command( start = datetime.datetime.now() dic, image = await generate_image( - "Random Prompt", width, height, negative, False, nologo, None, private + "Random Prompt", width, height, model, negative, False, nologo, None, private ) image_file = discord.File(image, filename="image.png") diff --git a/constants.py b/constants.py index 597a4f2..c6d8bbc 100644 --- a/constants.py +++ b/constants.py @@ -6,8 +6,6 @@ TOKEN = os.environ["TOKEN"] MONGODB_URI = os.environ["MONGODB_URI"] MODELS = [ - "dreamshaper", - "swizz8", - "deliberate", - "juggernaut", + "flux", + "turbo", ] diff --git a/utils.py b/utils.py index 87e1819..aaa35b2 100644 --- a/utils.py +++ b/utils.py @@ -222,6 +222,7 @@ async def generate_image( prompt: str, width: int = 800, height: int = 800, + model: str = "flux", negative: str | None = None, cached: bool = False, nologo: bool = False, @@ -240,6 +241,7 @@ async def generate_image( url += f"?seed={seed}" if not cached else "" url += f"&width={width}" url += f"&height={height}" + url += f"&model={model}" url += f"&negative={negative}" if negative else "" url += f"&nologo={nologo}" if nologo else "" url += f"&enhance={enhance}" if enhance else ""