From 1cdc4b3437635c41318fe1e99099e0d2ebc27cc1 Mon Sep 17 00:00:00 2001 From: Clari Date: Sun, 12 Nov 2023 00:57:23 -0600 Subject: [PATCH] Add SlimChannel and SlimThread converters --- jishaku/features/invocation.py | 43 ++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/jishaku/features/invocation.py b/jishaku/features/invocation.py index b2640ca3..9abf8606 100644 --- a/jishaku/features/invocation.py +++ b/jishaku/features/invocation.py @@ -29,6 +29,8 @@ from jishaku.types import ContextA, ContextT UserIDConverter = commands.IDConverter[typing.Union[discord.Member, discord.User]] +ChannelIDConverter = commands.IDConverter[discord.TextChannel] +ThreadIDConverter = commands.IDConverter[discord.Thread] class SlimUserConverter(UserIDConverter): # pylint: disable=too-few-public-methods @@ -54,15 +56,46 @@ async def convert(self, ctx: ContextA, argument: str) -> typing.Union[discord.Me raise commands.UserNotFound(argument) +class SlimChannelConverter(ChannelIDConverter): # pylint: disable=too-few-public-methods + """ + Identical to the stock TextChannelConverter, but does not perform plaintext name checks. + """ + + async def convert(self, ctx: ContextA, argument: str) -> discord.TextChannel: + """Converter method""" + match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) + + if match is not None: + channel_id = int(match.group(1)) + result = ctx.bot.get_channel(channel_id) or discord.utils.get(ctx.message.channel_mentions, id=channel_id) + if result is not None: + return result + raise commands.ChannelNotFound(argument) + + +class SlimThreadConverter(ThreadIDConverter): # pylint: disable=too-few-public-methods + """ + Identical to the stock ThreadConverter, but does not perform plaintext name checks. + """ + + async def convert(self, ctx: ContextA, argument: str) -> discord.Thread: + """Converter method""" + match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) + + if match is not None: + thread_id = int(match.group(1)) + result = ctx.guild.get_thread(thread_id) + if result is not None: + return result + raise commands.ThreadNotFound(argument) + + class InvocationFeature(Feature): """ Feature containing the command invocation related commands """ - if typing.TYPE_CHECKING or hasattr(discord, 'Thread'): - OVERRIDE_SIGNATURE = typing.Union[SlimUserConverter, discord.TextChannel, discord.Thread] # pylint: disable=no-member - else: - OVERRIDE_SIGNATURE = typing.Union[SlimUserConverter, discord.TextChannel] + OVERRIDE_SIGNATURE = typing.Union[SlimUserConverter, SlimChannelConverter, SlimThreadConverter] @Feature.Command(parent="jsk", name="override", aliases=["execute", "exec", "override!", "execute!", "exec!"]) async def jsk_override( @@ -160,7 +193,7 @@ async def jsk_debug(self, ctx: ContextT, *, command_string: str): start = time.perf_counter() - async with ReplResponseReactor(ctx.message): + async with ReplResponseReactor(ctx): with self.submit(ctx): await alt_ctx.command.invoke(alt_ctx)