Skip to content

Commit

Permalink
Update database tables and queries
Browse files Browse the repository at this point in the history
  • Loading branch information
novanai committed Jan 16, 2025
1 parent 8bf423c commit aea1ee7
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 106 deletions.
2 changes: 1 addition & 1 deletion src/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ async def error_handler(ctx: arc.GatewayContext, exc: Exception) -> None:
raise exc


@client.set_startup_hook
@client.add_startup_hook
async def startup_hook(_: arc.GatewayClient) -> None:
await init_db()
34 changes: 20 additions & 14 deletions src/database.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from sqlalchemy import BigInteger, Column, Integer, SmallInteger
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import BigInteger, Integer, SmallInteger
from sqlalchemy.ext.asyncio import AsyncAttrs, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from src.config import DB_HOST, DB_NAME, DB_PASSWORD, DB_USER

engine = create_async_engine(
f"postgresql+asyncpg://{DB_USER}:{DB_PASSWORD}@{DB_HOST}/{DB_NAME}", echo=False
)

Base = declarative_base()

class Base(AsyncAttrs, DeclarativeBase):
pass


Session = async_sessionmaker(bind=engine)


Expand All @@ -23,18 +27,20 @@ async def init_db() -> None:
class StarboardSettings(Base):
__tablename__ = "starboard_settings"

guild = Column(BigInteger, nullable=False, primary_key=True)
channel = Column(BigInteger, nullable=True)
threshold = Column(SmallInteger, nullable=False, default=3)
guild_id: Mapped[int] = mapped_column(BigInteger, nullable=False, primary_key=True)
channel_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
threshold: Mapped[int] = mapped_column(SmallInteger, nullable=False, default=3)
error: Mapped[int | None] = mapped_column(SmallInteger, nullable=True, default=None)


class Starboard(Base):
__tablename__ = "starboard"

id = Column(Integer, nullable=False, primary_key=True, autoincrement=True)
channel = Column(BigInteger, nullable=False)
message = Column(BigInteger, nullable=False)
stars = Column(SmallInteger, nullable=False)
starboard_channel = Column(BigInteger, nullable=False)
starboard_message = Column(BigInteger, nullable=False)
starboard_stars = Column(SmallInteger, nullable=False)
id: Mapped[int] = mapped_column(
Integer, nullable=False, primary_key=True, autoincrement=True
)
channel_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
message_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
stars: Mapped[int] = mapped_column(SmallInteger, nullable=False)
starboard_channel_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
starboard_message_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
252 changes: 161 additions & 91 deletions src/extensions/starboard.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,44 @@
from __future__ import annotations

import enum
import logging

import arc
import hikari
from sqlalchemy import delete, insert, select, update
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy import insert, select, update

from src.database import Starboard, StarboardSettings
from src.database import Session, Starboard, StarboardSettings

logger = logging.getLogger(__name__)

plugin = arc.GatewayPlugin("Starboard")


class StarboardSettingsError(enum.IntEnum):
CHANNEL_FORBIDDEN = 0
CHANNEL_NOT_FOUND = 1


# TODO: handle star remove
@plugin.listen()
@plugin.inject_dependencies
async def on_reaction(
event: hikari.GuildReactionAddEvent,
session: AsyncEngine = arc.inject(),
) -> None:
logger.info("Received guild reaction add event")

if event.emoji_name != "⭐":
return

message = await plugin.client.rest.fetch_message(event.channel_id, event.message_id)
star_count = sum(r.emoji == "⭐" for r in message.reactions)

stmt = select(StarboardSettings).where(StarboardSettings.guild == event.guild_id)
async with session.connect() as conn:
result = await conn.execute(stmt)

settings = result.first()
# get starboard settings
async with Session() as session:
stmt = select(StarboardSettings).where(
StarboardSettings.guild_id == event.guild_id
)
result = await session.execute(stmt)
settings = result.scalars().first()

# TODO: remove temporary logging and merge into one if statement
if not settings:
Expand All @@ -40,125 +47,188 @@ async def on_reaction(
if star_count < settings.threshold:
logger.info("Not enough stars to post to starboard")
return
if not settings.channel:
if not settings.channel_id:
logger.info("No starboard channel set")
return

async with session.connect() as conn:
stmt = select(Starboard).where(Starboard.message == event.message_id)
result = await conn.execute(stmt)
starboard = result.first()

logger.info(starboard)

if not starboard:
stmt = select(Starboard).where(Starboard.starboard_message == event.message_id)
result = await conn.execute(stmt)
starboard = result.first()

logger.info(starboard)

embed = hikari.Embed(description=f"⭐ {star_count}\n[link]({message.make_link(event.guild_id)})")

# TODO: handle starring the starboard message
# i.e. don't create a starboard message for the starboard message
if settings.error is not None:
logger.info("Error with starboard channel")
return

if not starboard:
try:
# TODO: consider ignoring stars reacted to a starboard message

# get starred message
async with Session() as session:
stmt = select(Starboard).where(Starboard.message_id == event.message_id)
result = await session.execute(stmt)
starboard = result.scalars().first()

embed = hikari.Embed(
title=f"⭐ {star_count} - *jump to message*",
url=message.make_link(event.guild_id),
description=message.content,
timestamp=message.timestamp,
).set_author(
name=message.author.username,
icon=message.author.display_avatar_url,
)

images = [
att
for att in message.attachments
if att.media_type and att.media_type.startswith("image")
]
if images:
embed.set_image(images[0])

embeds = [embed, *message.embeds]

try:
if not starboard:
logger.info("Creating message")
message = await plugin.client.rest.create_message(
settings.channel,
embed,
)
stmt = insert(Starboard).values(
channel=event.channel_id,
message=event.message_id,
stars=star_count,
starboard_channel=settings.channel,
starboard_message=message.id,
starboard_stars=0,
settings.channel_id,
embeds=embeds,
)

async with session.begin() as conn:
await conn.execute(stmt)
await conn.commit()
except hikari.ForbiddenError:
logger.info("Can't access starboard channel")
stmt = update(StarboardSettings).where(StarboardSettings.guild == event.guild_id).values(
channel=None)

async with session.begin() as conn:
await conn.execute(stmt)
await conn.commit()

else:
try:
logger.info("Editing message")
await plugin.client.rest.edit_message(
starboard.starboard_channel,
starboard.starboard_message,
embed
async with Session() as session:
session.add(
Starboard(
channel_id=event.channel_id,
message_id=event.message_id,
stars=star_count,
starboard_channel_id=settings.channel_id,
starboard_message_id=message.id,
)
)
await session.commit()
else:
try:
logger.info("Editing message")
await plugin.client.rest.edit_message(
starboard.starboard_channel_id,
starboard.starboard_message_id,
embeds=embeds,
)
except hikari.NotFoundError:
logger.info("Starboard message does not exist, creating new")
message = await plugin.client.rest.create_message(
settings.channel_id,
embeds=embeds,
)
async with Session() as session:
stmt = (
update(Starboard)
.where(
Starboard.starboard_message_id
== starboard.starboard_message_id
)
.values(
starboard_message_id=message.id,
)
)
await session.execute(stmt)
await session.commit()

except hikari.ForbiddenError:
logger.info("Can't access starboard channel")

async with Session() as session:
stmt = (
update(StarboardSettings)
.where(StarboardSettings.guild_id == event.guild_id)
.values(error=StarboardSettingsError.CHANNEL_FORBIDDEN)
)
await session.execute(stmt)
await session.commit()
except hikari.NotFoundError:
logger.info("Can't find starboard channel")

async with Session() as session:
stmt = (
update(StarboardSettings)
.where(StarboardSettings.guild_id == event.guild_id)
.values(error=StarboardSettingsError.CHANNEL_NOT_FOUND)
)
except hikari.ForbiddenError:
logger.info("Can't edit starboard message")
stmt = delete(StarboardSettings).where(StarboardSettings.guild == event.guild_id)
await session.execute(stmt)
await session.commit()

async with session.begin() as conn:
await conn.execute(stmt)
await conn.commit()

# TODO: add permission hook
@plugin.include
@arc.slash_command("starboard", "Edit or view starboard settings.", default_permissions=hikari.Permissions.MANAGE_GUILD)
@arc.slash_command(
"starboard",
"Edit or view starboard settings.",
default_permissions=hikari.Permissions.MANAGE_GUILD,
)
async def starboard_settings(
ctx: arc.GatewayContext,
channel: arc.Option[hikari.TextableGuildChannel | None, arc.ChannelParams("The channel to post starboard messages to.")] = None,
threshold: arc.Option[int | None, arc.IntParams("The minimum number of stars before this message is posted to the starboard", min=1)] = None,
session: AsyncEngine = arc.inject(),
channel: arc.Option[
hikari.TextableGuildChannel | None,
arc.ChannelParams("The channel to post starboard messages to."),
] = None,
threshold: arc.Option[
int | None,
arc.IntParams(
"The minimum number of stars before this message is posted to the starboard",
min=1,
),
] = None,
) -> None:
assert ctx.guild_id

stmt = select(StarboardSettings).where(StarboardSettings.guild == ctx.guild_id)
async with session.connect() as conn:
result = await conn.execute(stmt)

settings = result.first()
async with Session() as session:
stmt = select(StarboardSettings).where(
StarboardSettings.guild_id == ctx.guild_id
)
result = await session.execute(stmt)
settings = result.scalars().first()

if not channel and not threshold:
if not settings:
await ctx.respond("This server has no starboard settings.", flags=hikari.MessageFlag.EPHEMERAL)
await ctx.respond(
"This server has no starboard settings.",
flags=hikari.MessageFlag.EPHEMERAL,
)
else:
# TODO: `channel` and `threshold` can be None
embed = hikari.Embed(
title="Starboard Settings",
description=(
f"**Channel:** <#{settings.channel}>\n"
f"**Threshold:** {settings.threshold}"
f"**Channel:** <#{settings.channel_id}>\n"
f"**Threshold:** {settings.threshold}\n"
+ (f"**Error:** {settings.error}" if settings.error else "")
),
)
await ctx.respond(embed)

return


# TODO: use returning statement to get back new row
# then send embed

if not settings:
stmt = insert(StarboardSettings).values(guild=ctx.guild_id)
# TODO: use add not insert
stmt = insert(StarboardSettings).values(guild_id=ctx.guild_id)
else:
stmt = update(StarboardSettings).where(StarboardSettings.guild == ctx.guild_id)
stmt = update(StarboardSettings).where(
StarboardSettings.guild_id == ctx.guild_id
)

# TODO: simplify logic
if channel and threshold:
stmt = stmt.values(channel=channel.id, threshold=threshold)
stmt = stmt.values(channel_id=channel.id, threshold=threshold)
elif channel:
stmt = stmt.values(channel=channel.id)
stmt = stmt.values(channel_id=channel.id)
elif threshold:
stmt = stmt.values(threshold=threshold)

async with session.begin() as conn:
await conn.execute(stmt)
await conn.commit()

# TODO: respond with embed of new settings?
async with Session() as session:
await session.execute(stmt)
await session.commit()

await ctx.respond("Settings updated.", flags=hikari.MessageFlag.EPHEMERAL)


@arc.loader
def loader(client: arc.GatewayClient) -> None:
client.add_plugin(plugin)

0 comments on commit aea1ee7

Please sign in to comment.