Skip to content

Commit

Permalink
Merge pull request #211 from ZaphodBeebblebrox/discord-bot
Browse files Browse the repository at this point in the history
Discord Bot MVP
  • Loading branch information
ZaphodBeebblebrox authored Dec 5, 2024
2 parents c5d3437 + c5a3627 commit f319ad0
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 10 deletions.
4 changes: 4 additions & 0 deletions config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ password =
oauth_key =
oauth_secret =

[discord]
token =
guild =

[service.mal]
username =
password =
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ requests>=2.22.0
Unidecode>=1.1.1
PyYAML>=5.1.2
python-dateutil>=2.8.0
sqlite-spellfix>=1.1.0
discord.py>=2.4.0
8 changes: 8 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def __init__(self):
self.r_oauth_key = None
self.r_oauth_secret = None

self.d_token = None
self.d_guild = None

self.services = dict()

self.new_show_types = list()
Expand Down Expand Up @@ -68,6 +71,11 @@ def from_file(file_path):
config.r_password = sec.get("password", None)
config.r_oauth_key = sec.get("oauth_key", None)
config.r_oauth_secret = sec.get("oauth_secret", None)

if "discord" in parsed:
sec = parsed["discord"]
config.d_token = sec.get("token", None)
config.d_guild = sec.get("guild", None)

if "options" in parsed:
sec = parsed["options"]
Expand Down
41 changes: 38 additions & 3 deletions src/data/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from logging import debug, error, exception
import sqlite3, re
import sqlite3, sqlite_spellfix, re
from functools import wraps, lru_cache
from unidecode import unidecode
from typing import Set, List, Optional
Expand All @@ -15,6 +15,8 @@ def living_in(the_database):
"""
try:
db = sqlite3.connect(the_database)
db.enable_load_extension(True)
db.load_extension(sqlite_spellfix.extension_path())
db.execute("PRAGMA foreign_keys=ON")
except sqlite3.OperationalError:
error("Failed to open database, {}".format(the_database))
Expand Down Expand Up @@ -181,6 +183,14 @@ def setup_tables(self):
UNIQUE(show, episode) ON CONFLICT REPLACE
)""")

# The two inserts take minimal time < 0.005 when we have 1300 shows, so running them on setup is fine.
self.q.executescript("""CREATE VIRTUAL TABLE IF NOT EXISTS FuzzySearch;
INSERT INTO FuzzySearch (word) SELECT name FROM Shows
WHERE name NOT IN (SELECT word FROM FuzzySearch);
INSERT INTO FuzzySearch (word) SELECT name_en FROM Shows
WHERE name_en NOT NULL AND name_en NOT IN (SELECT word FROM FuzzySearch);
""")

self.commit()

def register_services(self, services):
Expand Down Expand Up @@ -510,11 +520,32 @@ def get_show_by_name(self, name) -> Optional[Show]:
WHERE name = ?", (name,))
show = self.q.fetchone()
if show is None:
return None
self.q.execute(
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
WHERE name_en = ?", (name,))
show = self.q.fetchone()
if show is None:
return None

show = Show(*show)
show.aliases = self.get_aliases(show)
return show

@db_error_default(None)
def get_show_by_name_fuzzy(self, text, count=1) -> Optional[Show]:
# The trailing * makes it a prefix search,
# which is more likely to match watch people are looking for.
text = text + "*"
self.q.execute("SELECT word FROM FuzzySearch WHERE word MATCH ? AND top=?", (text, count,))
names = self.q.fetchall()
if names is None:
return None

shows = []
for name in names:
shows.append(self.get_show_by_name(name[0]))
return shows[0] if count == 1 else shows

@db_error_default(list())
def get_aliases(self, show: Show) -> [str]:
self.q.execute("SELECT alias FROM Aliases where show = ?", (show.id,))
Expand All @@ -531,6 +562,8 @@ def add_show(self, raw_show: UnprocessedShow, commit=True) -> int:
has_source = raw_show.has_source
is_nsfw = raw_show.is_nsfw
self.q.execute("INSERT INTO Shows (name, name_en, length, type, has_source, is_nsfw) VALUES (?, ?, ?, ?, ?, ?)", (name, name_en, length, show_type, has_source, is_nsfw))
self.q.execute("INSERT INTO FuzzySearch (word) VALUES (?)", (name))
self.q.execute("INSERT INTO FuzzySearch (word) VALUES (?)", (name_en))
show_id = self.q.lastrowid
self.add_show_names(raw_show.name, *raw_show.more_names, id=show_id, commit=commit)

Expand All @@ -556,7 +589,9 @@ def update_show(self, show_id: str, raw_show: UnprocessedShow, commit=True):
is_nsfw = raw_show.is_nsfw

if name_en:
self.q.execute("UPDATE Shows SET name_en = ? WHERE id = ?", (name_en, show_id))
self.q.execute("DELETE FROM FuzzySearch WHERE word = (SELECT name_en FROM Shows WHERE id = ?)", (show_id,))
self.q.execute("INSERT INTO FuzzySearch (word) VALUES (?)", (name_en,))
self.q.execute("UPDATE Shows SET name_en = ? WHERE id = ?", (name_en, show_id))
if length != 0:
self.q.execute("UPDATE Shows SET length = ? WHERE id = ?", (length, show_id))
self.q.execute("UPDATE Shows SET type = ?, has_source = ?, is_nsfw = ? WHERE id = ?", (show_type, has_source, is_nsfw, show_id))
Expand Down
6 changes: 5 additions & 1 deletion src/holo.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def main(config, args, extra_args):
info("Batch creating threads")
import module_batch_create as m
m.main(config, db, *extra_args)
elif config.module == "discord":
info("starting discord bot")
import module_discord as m
m.main(config, db, *extra_args)
else:
warning("This should never happen or you broke it!")
except:
Expand All @@ -80,7 +84,7 @@ def main(config, args, extra_args):
import argparse
parser = argparse.ArgumentParser(description="{}, {}".format(name, description))
parser.add_argument("--no-input", dest="no_input", action="store_true", help="run without stdin and write to a log file")
parser.add_argument("-m", "--module", dest="module", nargs=1, choices=["setup", "edit", "episode", "update", "find", "create", "batch"], default=["episode"], help="runs the specified module")
parser.add_argument("-m", "--module", dest="module", nargs=1, choices=["setup", "edit", "episode", "update", "find", "create", "batch", "discord"], default=["episode"], help="runs the specified module")
parser.add_argument("-c", "--config", dest="config_file", nargs=1, default=["config.ini"], help="use or create the specified database location")
parser.add_argument("-d", "--database", dest="db_name", nargs=1, default=None, help="use or create the specified database location")
parser.add_argument("-s", "--subreddit", dest="subreddit", nargs=1, default=None, help="set the subreddit on which to make posts")
Expand Down
3 changes: 2 additions & 1 deletion src/module_batch_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def main(config, db, show_name, episode_count):
if not config.debug:
megathread_post = reddit.submit_text_post(config.subreddit, megathread_title, megathread_body)
else:
megathread_post = None
megathread_post = None

if megathread_post is not None:
debug("Post successful")
Expand All @@ -50,6 +50,7 @@ def main(config, db, show_name, episode_count):
for i, url in enumerate(post_urls):
info(f"Episode {i}: {url}")
info(f"Megathread: {megathread_url}")
return megathread_post


def _create_megathread_content(config, db, show, stream, episode_count):
Expand Down
4 changes: 2 additions & 2 deletions src/module_create_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def main(config, db, show_name, episode):
db.set_show_delayed(show, False)
for editing_episode in db.get_episodes(show):
_edit_reddit_post(config, db, show, stream, editing_episode, editing_episode.link, submit=not config.debug)
return True
return post_url
else:
error(" Episode not submitted")
return False
return None
120 changes: 120 additions & 0 deletions src/module_discord.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from logging import debug, info, warning, error
from functools import partial

import discord
from discord import app_commands


def main(config, db):
intents = discord.Intents.default()
guild = discord.Object(config.d_guild)
client = Lovepon(intents=intents, guild=guild)

@client.tree.command(guild=guild)
@app_commands.describe(
anime='The anime you want to post a thread for.',
episode='The episode number.',
)
async def post(interaction: discord.Interaction, anime: str, episode: int):
if anime.isdigit():
await post_thread(interaction, config, db, db.get_show(anime), episode)
return

anime = anime.lower()
show = db.get_show_by_name_fuzzy(anime)
if show is None:
await interaction.response.send_message(f'Cannot identify {anime}', ephemeral=True)
return

if anime != show.name.lower() and (not show.name_en or anime != show.name_en.lower()):
view = Confirmation_Button(partial(post_thread, config=config, db=db, show=show, episode=episode))
await interaction.response.send_message(f'Post a discussion thread for epsiode {episode} of {show.name}?', view=view, ephemeral=True)
else:
await post_thread(interaction, config, db, show, episode)

@client.tree.command(guild=guild)
@app_commands.describe(
anime='The anime you want to post a thread for.',
count='The number of episodes to post.',
)
async def batch(interaction: discord.Interaction, anime: str, count: int):
if anime.isdigit():
await post_batch(interaction, config, db, db.get_show(anime), count)
return

anime = anime.lower()
show = db.get_show_by_name_fuzzy(anime)
if show is None:
await interaction.response.send_message(f'Cannot identify {anime}', ephemeral=True)
return

if anime != show.name.lower() and (not show.name_en or anime != show.name_en.lower()):
view = Confirmation_Button(partial(post_batch, config=config, db=db, show=show, count=count))
await interaction.response.send_message(f'Create a {count} episode batch for {show.name}?', view=view, ephemeral=True)
else:
await post_batch(interaction, config, db, show, count)

@client.tree.command(guild=guild)
@app_commands.describe(
anime='The anime title you want to search.'
)
async def search(interaction: discord.Interaction, anime: str):
shows = db.get_show_by_name_fuzzy(anime, 10)
if shows is None:
await interaction.response.send_message(f'Cannot identify {anime}')
return

format = "### Matching Shows:\n"
for show in shows:
format += f"**{show.id}**: {show.name}"
format += f" • {show.name_en}\n" if show.name_en else "\n"
await interaction.response.send_message(format, ephemeral=True)


client.run(config.d_token)

class Lovepon(discord.Client):
def __init__(self, *, intents, guild):
super().__init__(intents=intents)

self.tree = app_commands.CommandTree(self)
self.guild = guild

async def setup_hook(self):
await self.tree.sync(guild=self.guild)

async def on_ready(self):
info(f'Logged on to discord as {self.user}.')

class Confirmation_Button(discord.ui.View):
def __init__(self, yesfunc):
super().__init__()
self.yesfunc = yesfunc

@discord.ui.button(label='Yes', style=discord.ButtonStyle.success)
async def yes(self, interaction: discord.Interaction, button: discord.ui.Button):
await self.yesfunc(interaction)
self.stop()

@discord.ui.button(label='No', style=discord.ButtonStyle.danger)
async def no(self, interaction: discord.Interaction, button: discord.ui.Button):
await interaction.response.send_message("Post Canceled", ephemeral=True)
self.stop()

async def post_thread(interaction, config, db, show, episode):
info(f"Creating new thread for {show.name} episode {episode}.")

# When there are a lot of other threads to edit, we exceed discord's 3 second window.
await interaction.response.defer()
import module_create_threads as m
post_url = m.main(config, db, show.name, episode)
await interaction.followup.send(f'Created thread for epsiode {episode} of {show.name}: {post_url}')

async def post_batch(interaction, config, db, show, count):
info(f"Creating {count} episode batch for {show.name}.")

# Creating a large batch exceeds discord's 3 second window.
await interaction.response.defer()
import module_batch_create as m
post_url = m.main(config, db, show.name, count)
await interaction.followup.send(f'Created {count} episode batch for {show.name}: https://redd.it/{post_url}')
7 changes: 4 additions & 3 deletions src/reddit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def _connect_reddit():
return praw.Reddit(client_id=_config.r_oauth_key, client_secret=_config.r_oauth_secret,
username=_config.r_username, password=_config.r_password,
user_agent=_config.useragent,
check_for_updates=False)
check_for_updates=False,
check_for_async=False)

def _ensure_connection():
global _r
Expand All @@ -38,10 +39,10 @@ def submit_text_post(subreddit, title, body):
else:
warning('Flair not selectable, flairing will be disabled')
flair_id, flair_text = None, None

info("Submitting post to {}".format(subreddit))
new_post = _r.subreddit(subreddit).submit(title,
selftext=body,
selftext=body,
flair_id=flair_id,
flair_text=flair_text,
send_replies=False)
Expand Down

0 comments on commit f319ad0

Please sign in to comment.