Skip to content

Commit

Permalink
Refactor image generation to support different models such as flux an…
Browse files Browse the repository at this point in the history
…d turbo
  • Loading branch information
Zingzy committed Aug 10, 2024
1 parent 76085b0 commit ccbb233
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 8 deletions.
10 changes: 8 additions & 2 deletions cogs/imagine_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,15 @@ async def regenerate(
prompt = message_data["prompt"]
width = message_data["width"]
height = message_data["height"]
model = message_data.get("model", "flux")
negative = message_data["negative"]
cached = message_data["cached"]
nologo = message_data["nologo"]
enhance = message_data["enhance"]

try:
dic, image = await generate_image(
prompt, width, height, negative, cached, nologo, enhance
prompt, width, height, model, negative, cached, nologo, enhance
)
except Exception as e:
print(e)
Expand Down Expand Up @@ -333,12 +334,16 @@ async def cog_load(self):
self.bot.add_view(ImagineButtonView())

@app_commands.command(name="pollinate", description="Generate AI Images")
@app_commands.choices(
model=[app_commands.Choice(name=choice, value=choice) for choice in MODELS],
)
@app_commands.guild_only()
@app_commands.checks.cooldown(1, 15)
@app_commands.describe(
prompt="Prompt of the Image you want want to generate",
height="Height of the Image",
width="Width of the Image",
model="Model to use for generating the Image",
enhance="Enables AI Prompt Enhancement",
negative="The things not to include in the Image",
cached="Uses the Default seed",
Expand All @@ -351,6 +356,7 @@ async def imagine_command(
prompt: str,
width: int = 1000,
height: int = 1000,
model: app_commands.Choice[str] = MODELS[0],
enhance: bool | None = None,
negative: str | None = None,
cached: bool = False,
Expand All @@ -368,7 +374,7 @@ async def imagine_command(
start = datetime.datetime.now()

dic, image = await generate_image(
prompt, width, height, negative, cached, nologo, enhance, private
prompt, width, height, model, negative, cached, nologo, enhance, private
)

image_file = discord.File(image, filename="image.png")
Expand Down
12 changes: 11 additions & 1 deletion cogs/multi_pollinate_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ async def get_info(interaction: discord.Interaction, index: int):
url += f"&nologo=true"
url += f"&enhance={data['enhance']}" if 'enhance' in data and data['enhance'] else ''

if index % 2 == 0:
url += f"&model=turbo"
else:
url += f"&model=flux"

async with aiohttp.ClientSession() as session:
async with session.get(data['urls'][index]) as response:
response.raise_for_status()
Expand Down Expand Up @@ -214,8 +219,13 @@ async def multiimagine_command(
for i in range(4):
try:
time = datetime.datetime.now()
if i % 2 == 0:
model = "turbo"
else:
model = "flux"

dic, image = await generate_image(
prompt, width, height, negative, cached, nologo, enhance, private
prompt, width, height, model, negative, cached, nologo, enhance, private
)

image_urls.append(dic["bookmark_url"])
Expand Down
7 changes: 6 additions & 1 deletion cogs/random_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@ async def cog_load(self):
await self.bot.wait_until_ready()

@app_commands.command(name="random", description="Generate Random AI Images")
@app_commands.choices(
model=[app_commands.Choice(name=choice, value=choice) for choice in MODELS],
)
@app_commands.guild_only()
@app_commands.checks.cooldown(1, 15)
@app_commands.describe(
height="Height of the image",
width="Width of the image",
model="The model to use for generating the image",
negative="The things not to include in the image",
nologo="Remove the logo",
private="Only you can see the generated Image if set to True",
Expand All @@ -28,6 +32,7 @@ async def random_image_command(
interaction: discord.Interaction,
width: int = 1000,
height: int = 1000,
model: app_commands.Choice[str] = MODELS[0],
negative: str | None = None,
nologo: bool = True,
private: bool = False,
Expand All @@ -40,7 +45,7 @@ async def random_image_command(
start = datetime.datetime.now()

dic, image = await generate_image(
"Random Prompt", width, height, negative, False, nologo, None, private
"Random Prompt", width, height, model, negative, False, nologo, None, private
)

image_file = discord.File(image, filename="image.png")
Expand Down
6 changes: 2 additions & 4 deletions constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
TOKEN = os.environ["TOKEN"]
MONGODB_URI = os.environ["MONGODB_URI"]
MODELS = [
"dreamshaper",
"swizz8",
"deliberate",
"juggernaut",
"flux",
"turbo",
]
2 changes: 2 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ async def generate_image(
prompt: str,
width: int = 800,
height: int = 800,
model: str = "flux",
negative: str | None = None,
cached: bool = False,
nologo: bool = False,
Expand All @@ -240,6 +241,7 @@ async def generate_image(
url += f"?seed={seed}" if not cached else ""
url += f"&width={width}"
url += f"&height={height}"
url += f"&model={model}"
url += f"&negative={negative}" if negative else ""
url += f"&nologo={nologo}" if nologo else ""
url += f"&enhance={enhance}" if enhance else ""
Expand Down

0 comments on commit ccbb233

Please sign in to comment.