Skip to content

Commit

Permalink
Refactor image generation to support different models and fix prompt …
Browse files Browse the repository at this point in the history
…handling
  • Loading branch information
Zingzy committed Sep 4, 2024
1 parent 4e41317 commit 9fdca73
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 39 deletions.
10 changes: 10 additions & 0 deletions cogs/imagine_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,11 @@ async def imagine_command(
if width < 16 or height < 16:
raise DimensionTooSmallError("Width and Height must be greater than 16")

try:
model = model.value
except:
pass

start = datetime.datetime.now()

dic, image = await generate_image(
Expand Down Expand Up @@ -408,6 +413,11 @@ async def imagine_command(
inline=True,
)

embed.add_field(name="", value="", inline=False)

embed.add_field(name="Model", value=f"```{model}```", inline=True)
embed.add_field(name="Dimensions", value=f"```{width}x{height}```", inline=True)

embed.set_image(url=f"attachment://image.png")

embed.set_footer(text=f"Generated by {interaction.user}")
Expand Down
30 changes: 9 additions & 21 deletions cogs/multi_pollinate_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ async def cog_load(self):

async def get_info(interaction: discord.Interaction, index: int):

indexes = ["1st", "2nd", "3rd", "4th"]

data = get_multi_imagined_prompt_data(interaction.message.id)
seed = data["urls"][index].split("?")[-1].split("&")[0]
Expand All @@ -29,11 +28,7 @@ async def get_info(interaction: discord.Interaction, index: int):
url += f"&negative={data['negative']}" if 'negative' in data and data['negative'] else ''
url += f"&nologo=true"
url += f"&enhance={data['enhance']}" if 'enhance' in data and data['enhance'] else ''

if index % 2 == 0:
url += f"&model=turbo"
else:
url += f"&model=flux"
url += f"&model={MODELS[index]}"

async with aiohttp.ClientSession() as session:
async with session.get(data['urls'][index]) as response:
Expand All @@ -49,9 +44,6 @@ async def get_info(interaction: discord.Interaction, index: int):
image_file = discord.File(image_file, "image.png")

try:
user_comment = user_comment[user_comment.find("{") :]
user_comment = json.loads(user_comment)

if user_comment["has_nsfw_concept"]:
image_file.filename = f"SPOILER_{image_file.filename}"
data["nsfw"] = True
Expand All @@ -65,8 +57,8 @@ async def get_info(interaction: discord.Interaction, index: int):
pass

embed = discord.Embed(
title=f"{indexes[index]} Image",
description="",
title=f"{ordinal(index+1)} Image",
description=f"Model: {MODELS[index]}",
)

embed.set_image(url=f"attachment://image.png")
Expand Down Expand Up @@ -214,15 +206,11 @@ async def multiimagine_command(

embeds = []
files = []
positions = ["1st", "2nd", "3rd", "4th"]

for i in range(4):
for i in range(len(MODELS)):
try:
time = datetime.datetime.now()
if i % 2 == 0:
model = "turbo"
else:
model = "flux"
model = MODELS[i]

dic, image = await generate_image(
prompt, width, height, model, negative, cached, nologo, enhance, private
Expand All @@ -234,11 +222,11 @@ async def multiimagine_command(

if private:
await interaction.followup.send(
f"Generated **{positions[i]} Image** in `{round(time_taken.total_seconds(), 2)}` seconds ✅",
f"Generated **{ordinal(i+1)} Image** in `{round(time_taken.total_seconds(), 2)}` seconds ✅",
ephemeral=True,
)
else:
embed = discord.Embed(title="Generating Image", description=f"Generated **{positions[i]} Image** in `{round(time_taken.total_seconds(), 2)}` seconds ✅", color=discord.Color.blurple())
embed = discord.Embed(title="Generating Image", description=f"Generated **{ordinal(i+1)} Image** in `{round(time_taken.total_seconds(), 2)}` seconds ✅", color=discord.Color.blurple())
await response.edit(embeds=[embed])

image_file = discord.File(image, f"image_{i}.png")
Expand All @@ -253,7 +241,7 @@ async def multiimagine_command(
if private:
await interaction.followup.send(
embed=discord.Embed(
title=f"Error generating image of `{i}` model",
title=f"Error generating image of `{MODELS[i]}` model",
description=f"{e}",
color=discord.Color.red(),
),
Expand All @@ -262,7 +250,7 @@ async def multiimagine_command(
else:
await response.edit(
embeds=[discord.Embed(
title=f"Error generating image of `{i}` model",
title=f"Error generating image of `{MODELS[i]}` model",
description=f"{e}",
color=discord.Color.red(),
)])
Expand Down
7 changes: 6 additions & 1 deletion cogs/random_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,15 @@ async def random_image_command(
if width < 16 or height < 16:
raise DimensionTooSmallError("Width and Height must be greater than 16")

try:
model = model.value
except:
pass

start = datetime.datetime.now()

dic, image = await generate_image(
"Random Prompt", width, height, model, negative, False, nologo, None, private
"Random Prompt", width, height, model, negative, False, nologo, True, private
)

image_file = discord.File(image, filename="image.png")
Expand Down
9 changes: 5 additions & 4 deletions constants.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
from dotenv import load_dotenv
import requests
import json

load_dotenv(override=True)

TOKEN = os.environ["TOKEN"]
MONGODB_URI = os.environ["MONGODB_URI"]
MODELS = [
"flux",
"turbo",
]

r = requests.get("https://image.pollinations.ai/models")
MODELS = json.loads(r.text)
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ async def on_ready(self):

print(f"Logged in as {self.user.name} (ID: {self.user.id})")
print(f"Connected to {len(self.guilds)} guilds")
print(f"Available MODELS: {MODELS}")


bot = pollinationsBot()
Expand All @@ -103,7 +104,7 @@ async def on_message(message):
if bot.user in message.mentions:
if message.type is not discord.MessageType.reply:
embed = discord.Embed(
description="Hello, I am the Pollinations.ai Bot. I am here to help you with your AI needs. **To Generate Images click </pollinate:1223762317359976519> or </multi-imagine:1187375074722975837> or type `/help` for more commands**.",
description="Hello, I am the Pollinations.ai Bot. I am here to help you with your AI needs. **To Generate Images click </pollinate:1223762317359976519> or </multi-pollinate:1264942861800050891> or type `/help` for more commands**.",
color=discord.Color.og_blurple(),
)

Expand Down
32 changes: 20 additions & 12 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,14 @@ def generate_global_leaderboard():
print(e)
return None

def ordinal(n):
suffix = ["th", "st", "nd", "rd"] + ["th"] * 6
if 10 <= n % 100 <= 20:
suffix_choice = "th"
else:
suffix_choice = suffix[n % 10]
return f"{n}{suffix_choice}"


async def generate_error_message(
interaction: discord.Interaction,
Expand Down Expand Up @@ -206,14 +214,16 @@ async def generate_error_message(

def extract_user_comment(image_bytes):
image = Image.open(io.BytesIO(image_bytes))
exif_data = piexif.load(image.info["exif"])
user_comment = exif_data.get("Exif", {}).get(piexif.ExifIFD.UserComment, None)

try:
exif = image.info['exif'].decode('latin-1', errors='ignore')
user_comment = json.loads(exif[exif.find("{"):exif.rfind("}") + 1])
except Exception as e:
print(e)
return "No user comment found."

if user_comment:
try:
return user_comment.decode("utf-8")
except UnicodeDecodeError:
return "No user comment found."
return user_comment
else:
return "No user comment found."

Expand All @@ -222,7 +232,7 @@ async def generate_image(
prompt: str,
width: int = 800,
height: int = 800,
model: str = "flux",
model: str = f"{MODELS[0]}",
negative: str | None = None,
cached: bool = False,
nologo: bool = False,
Expand All @@ -231,7 +241,7 @@ async def generate_image(
**kwargs,
):
print(
f"Generating image with prompt: {prompt}, width: {width}, height: {height}, negative: {negative}, cached: {cached}, nologo: {nologo}, enhance: {enhance}",
f"Generating image with prompt: {prompt}, width: {width}, height: {height}, negative: {negative}, cached: {cached}, nologo: {nologo}, enhance: {enhance}, model: {model}",
file=sys.stderr,
)

Expand Down Expand Up @@ -262,7 +272,7 @@ async def generate_image(
dic["seed"] = seed if not cached else None

async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
async with session.get(url, allow_redirects=True) as response:
response.raise_for_status() # Raise an exception for non-2xx status codes
image_data = await response.read()

Expand All @@ -273,12 +283,10 @@ async def generate_image(
image_file.seek(0)

try:
user_comment = user_comment[user_comment.find("{") :]
user_comment = json.loads(user_comment)
dic["nsfw"] = user_comment["has_nsfw_concept"]
if enhance or len(prompt) < 80:
enhance_prompt = user_comment["prompt"]
enhance_prompt = enhance_prompt[: enhance_prompt.rfind("\n")]
enhance_prompt = enhance_prompt[: enhance_prompt.rfind("\n")].strip()
dic["enhanced_prompt"] = enhance_prompt
except Exception as e:
print(e)
Expand Down

0 comments on commit 9fdca73

Please sign in to comment.