From 260b9708683d6d839a9623247725d411abffe6d5 Mon Sep 17 00:00:00 2001 From: shedaniel Date: Mon, 6 Jun 2022 20:07:16 +0800 Subject: [PATCH] Fix #29 and #30 Signed-off-by: shedaniel --- .../linkie/discord/handler/CommandHandler.kt | 36 +++++++++++----- .../discord/handler/RateLimitException.kt | 19 ++++++++ .../linkie/discord/handler/RateLimiter.kt | 40 +++++++++++++++++ .../linkie/discord/scommands/SlashCommands.kt | 43 +++++++++++++++++-- .../linkie/discord/utils/CommandContext.kt | 2 +- .../linkie/discord/utils/MessageCreator.kt | 10 ++--- .../me/shedaniel/linkie/discord/LinkieBot.kt | 6 ++- .../discord/commands/EvaluateCommand.kt | 2 + .../discord/commands/QueryMappingsCommand.kt | 1 + .../commands/QueryTranslateMappingsCommand.kt | 1 + .../discord/scripting/LinkieScripting.kt | 34 ++++----------- 11 files changed, 145 insertions(+), 49 deletions(-) create mode 100644 src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/RateLimitException.kt create mode 100644 src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/RateLimiter.kt diff --git a/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/CommandHandler.kt b/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/CommandHandler.kt index bd2c3b7..b231c4c 100644 --- a/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/CommandHandler.kt +++ b/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/CommandHandler.kt @@ -23,9 +23,7 @@ import discord4j.core.`object`.entity.channel.MessageChannel import discord4j.core.event.domain.message.MessageCreateEvent import discord4j.core.spec.EmbedCreateSpec import discord4j.rest.util.Color -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking import me.shedaniel.linkie.discord.scommands.splitArgs import me.shedaniel.linkie.discord.utils.CommandContext import me.shedaniel.linkie.discord.utils.MessageBasedCommandContext @@ -36,14 +34,19 @@ import me.shedaniel.linkie.discord.utils.event import me.shedaniel.linkie.discord.utils.replyComplex import me.shedaniel.linkie.discord.utils.sendEmbedMessage import java.time.Duration +import java.util.concurrent.ExecutionException +import java.util.concurrent.ExecutorService +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException class CommandHandler( private val client: GatewayDiscordClient, private val commandAcceptor: CommandAcceptor, private val throwableHandler: ThrowableHandler, + val rateLimiter: RateLimiter = RateLimiter(Int.MAX_VALUE), ) { - val scope = CoroutineScope(Dispatchers.Default) - + val executor: ExecutorService = Executors.newCachedThreadPool() fun register() { client.event(this::onMessageCreate) } @@ -53,17 +56,23 @@ class CommandHandler( val user = event.message.author.orElse(null)?.takeUnless { it.isBot } ?: return val message: String = event.message.content val prefix = commandAcceptor.getPrefix(event) - scope.launch { - try { + try { + executor.submit { if (message.lowercase().startsWith(prefix)) { val content = message.substring(prefix.length) + if (!rateLimiter.allow(user.id.asLong())) { + throwableHandler.generateErrorMessage(event.message, RateLimitException(rateLimiter.maxRequestPer10Sec), channel, user) + return@submit + } val split = content.splitArgs() if (split.isNotEmpty()) { val cmd = split[0].lowercase() val ctx = MessageBasedCommandContext(event, prefix, cmd, channel) val args = split.drop(1).toMutableList() try { - commandAcceptor.execute(event, ctx, args) + runBlocking { + commandAcceptor.execute(event, ctx, args) + } } catch (throwable: Throwable) { if (throwableHandler.shouldError(throwable)) { try { @@ -79,9 +88,14 @@ class CommandHandler( } } } - } catch (throwable: Throwable) { - throwableHandler.generateErrorMessage(event.message, throwable, channel, user) - } + }.get(10, TimeUnit.SECONDS) + } catch (throwable: TimeoutException) { + val newThrowable = TimeoutException("The command took too long to execute, the maximum execution time is 10 seconds.") + throwableHandler.generateErrorMessage(event.message, newThrowable, channel, user) + } catch (throwable: ExecutionException) { + throwableHandler.generateErrorMessage(event.message, throwable.cause ?: throwable, channel, user) + } catch (throwable: Throwable) { + throwableHandler.generateErrorMessage(event.message, throwable, channel, user) } } } diff --git a/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/RateLimitException.kt b/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/RateLimitException.kt new file mode 100644 index 0000000..a5dda8d --- /dev/null +++ b/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/RateLimitException.kt @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019, 2020, 2021 shedaniel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package me.shedaniel.linkie.discord.handler + +class RateLimitException(limit: Int) : RuntimeException("You have reached the rate limit of $limit per 10 seconds.") diff --git a/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/RateLimiter.kt b/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/RateLimiter.kt new file mode 100644 index 0000000..319e78f --- /dev/null +++ b/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/RateLimiter.kt @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2019, 2020, 2021 shedaniel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package me.shedaniel.linkie.discord.handler + +import java.util.* + +class RateLimiter(val maxRequestPer10Sec: Int) { + data class Entry( + val time: Long, + val userId: Long, + ) + + private val log: Queue = LinkedList() + + fun allow(userId: Long): Boolean { + val curTime = System.currentTimeMillis() + val boundary = curTime - 10000 + synchronized(log) { + while (!log.isEmpty() && log.element().time <= boundary) { + log.poll() + } + log.add(Entry(curTime, userId)) + return log.count { it.userId == userId } <= maxRequestPer10Sec + } + } +} \ No newline at end of file diff --git a/src/discord_api/kotlin/me/shedaniel/linkie/discord/scommands/SlashCommands.kt b/src/discord_api/kotlin/me/shedaniel/linkie/discord/scommands/SlashCommands.kt index 537e4f8..dbf127d 100644 --- a/src/discord_api/kotlin/me/shedaniel/linkie/discord/scommands/SlashCommands.kt +++ b/src/discord_api/kotlin/me/shedaniel/linkie/discord/scommands/SlashCommands.kt @@ -29,6 +29,8 @@ import discord4j.discordjson.json.ApplicationCommandData import discord4j.discordjson.json.ApplicationCommandOptionChoiceData import discord4j.discordjson.json.ApplicationCommandRequest import discord4j.discordjson.possible.Possible +import me.shedaniel.linkie.discord.handler.RateLimitException +import me.shedaniel.linkie.discord.handler.RateLimiter import me.shedaniel.linkie.discord.handler.ThrowableHandler import me.shedaniel.linkie.discord.utils.CommandContext import me.shedaniel.linkie.discord.utils.SlashCommandBasedContext @@ -36,7 +38,13 @@ import me.shedaniel.linkie.discord.utils.dismissButton import me.shedaniel.linkie.discord.utils.event import me.shedaniel.linkie.discord.utils.extensions.getOrNull import me.shedaniel.linkie.discord.utils.replyComplex +import me.shedaniel.linkie.discord.utils.user import reactor.core.publisher.Mono +import java.util.concurrent.ExecutionException +import java.util.concurrent.ExecutorService +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException import java.util.stream.Collectors data class SlashCommandHandler( @@ -50,6 +58,7 @@ class SlashCommands( private val errorHandler: (String) -> Unit = { println("Error: $it") }, private val defaultEphemeral: Boolean = false, private val debug: Boolean = false, + val rateLimiter: RateLimiter = RateLimiter(Int.MAX_VALUE), ) { val applicationId: Long by lazy { client.restClient.applicationId.block() } val handlers = mutableMapOf() @@ -65,6 +74,7 @@ class SlashCommands( .collect(Collectors.toList()) } val guildCommands = mutableMapOf>() + val executor: ExecutorService = Executors.newCachedThreadPool() data class GuildCommandKey(val guildId: Snowflake, val commandName: String) @@ -240,13 +250,38 @@ class SlashCommands( private fun buildHandler(command: SlashCommand, cmd: String): SlashCommandHandler = SlashCommandHandler(responder = { event -> var sentAny = false - val ctx = SlashCommandBasedContext(command, cmd, event, defaultEphemeral) { - it.subscribe() - sentAny = true + val ctx = SlashCommandBasedContext(command, cmd, event, defaultEphemeral) { acknowledge, mono -> + mono.subscribe() + if (!acknowledge) { + sentAny = true + } + } + if (!rateLimiter.allow(event.user.id.asLong())) { + val exception = RateLimitException(rateLimiter.maxRequestPer10Sec) + if (throwableHandler.shouldError(exception)) { + try { + ctx.message.replyComplex { + layout { dismissButton() } + embed { throwableHandler.generateThrowable(this, exception, ctx.user) } + } + } catch (throwable2: Exception) { + throwable2.addSuppressed(exception) + throwable2.printStackTrace() + } + } + return@SlashCommandHandler } val optionsGetter = OptionsGetter.of(command, ctx, event) runCatching { - if (!executeOptions(command, ctx, optionsGetter, command.options, event.options) && !command.execute(command, ctx, optionsGetter)) { + try { + executor.submit { + if (!executeOptions(command, ctx, optionsGetter, command.options, event.options) && !command.execute(command, ctx, optionsGetter)) { + } + }.get(3, TimeUnit.SECONDS) + } catch (throwable: TimeoutException) { + throw TimeoutException("The command took too long to execute, the maximum execution time is 3 seconds.") + } catch (throwable: ExecutionException) { + throw throwable.cause ?: throwable } if (!sentAny) { throw IllegalStateException("Command was not resolved!") diff --git a/src/discord_api/kotlin/me/shedaniel/linkie/discord/utils/CommandContext.kt b/src/discord_api/kotlin/me/shedaniel/linkie/discord/utils/CommandContext.kt index 0688e5c..3bf095d 100644 --- a/src/discord_api/kotlin/me/shedaniel/linkie/discord/utils/CommandContext.kt +++ b/src/discord_api/kotlin/me/shedaniel/linkie/discord/utils/CommandContext.kt @@ -88,7 +88,7 @@ class SlashCommandBasedContext( override val cmd: String, val event: ChatInputInteractionEvent, val defaultEphemeral: Boolean = false, - val send: (Mono<*>) -> Unit, + val send: (acknowledge: Boolean, Mono<*>) -> Unit, ) : CommandContext { override val message: MessageCreator by lazy { SlashCommandMessageCreator(event, this, send).let { if (defaultEphemeral) it.ephemeral(true) diff --git a/src/discord_api/kotlin/me/shedaniel/linkie/discord/utils/MessageCreator.kt b/src/discord_api/kotlin/me/shedaniel/linkie/discord/utils/MessageCreator.kt index 46d7bfc..5ef94b7 100644 --- a/src/discord_api/kotlin/me/shedaniel/linkie/discord/utils/MessageCreator.kt +++ b/src/discord_api/kotlin/me/shedaniel/linkie/discord/utils/MessageCreator.kt @@ -142,20 +142,20 @@ private fun Mono.toFuturePossible(): FuturePossible = object : FutureP class SlashCommandMessageCreator( val event: ChatInputInteractionEvent, val ctx: CommandContext, - val send: (Mono<*>) -> Unit, + val send: (acknowledge: Boolean, Mono<*>) -> Unit, ) : MessageCreator { var sent: Boolean = false override fun _acknowledge(ephemeral: Boolean?, content: MessageContent?) { if (!sent) { sent = true - send(event.deferReply().withEphemeral(ephemeral ?: false)) + send(true, event.deferReply().withEphemeral(ephemeral ?: false)) } } override fun _reply(blockIfPossible: Boolean, ephemeral: Boolean?, spec: MessageCreatorComplex): FuturePossible { if (!sent) { sent = true - send(event.replyMessage { + send(false, event.replyMessage { ephemeral(ephemeral.possible()) embeds(listOf()) content(Possible.absent()) @@ -169,7 +169,7 @@ class SlashCommandMessageCreator( spec.compile(ctx)?.also(this::components) }) } else { - send(event.sendOriginalEdit { + send(false, event.sendOriginalEdit { embedsOrNull(listOf()) content(Possible.absent()) spec.text?.content?.also(this::contentOrNull) @@ -183,7 +183,7 @@ class SlashCommandMessageCreator( } override fun presentModal(spec: PresentableModalSpec) { - send(event.presentModal(spec.compile(ctx.client, ctx.user))) + send(false, event.presentModal(spec.compile(ctx.client, ctx.user))) } } diff --git a/src/main/kotlin/me/shedaniel/linkie/discord/LinkieBot.kt b/src/main/kotlin/me/shedaniel/linkie/discord/LinkieBot.kt index 3f60ed8..c0b39a9 100644 --- a/src/main/kotlin/me/shedaniel/linkie/discord/LinkieBot.kt +++ b/src/main/kotlin/me/shedaniel/linkie/discord/LinkieBot.kt @@ -65,6 +65,7 @@ import me.shedaniel.linkie.discord.commands.legacy.RemapAWATCommand import me.shedaniel.linkie.discord.config.ConfigManager import me.shedaniel.linkie.discord.handler.CommandHandler import me.shedaniel.linkie.discord.handler.CommandManager +import me.shedaniel.linkie.discord.handler.RateLimiter import me.shedaniel.linkie.discord.scommands.SlashCommands import me.shedaniel.linkie.discord.scommands.sub import me.shedaniel.linkie.discord.tricks.TricksManager @@ -131,7 +132,8 @@ fun main() { ) ) ) { - val slashCommands = SlashCommands(this, LinkieThrowableHandler, ::warn, debug = isDebug) + val rateLimiter = RateLimiter(3) + val slashCommands = SlashCommands(this, LinkieThrowableHandler, ::warn, debug = isDebug, rateLimiter = rateLimiter) TricksManager.listen(slashCommands) val commandManager = object : CommandManager(if (isDebug) "@" else "!") { override fun getPrefix(event: MessageCreateEvent): String { @@ -149,7 +151,7 @@ fun main() { // register the commands registerCommands(commandManager) registerSlashCommands(slashCommands) - CommandHandler(this, commandManager, LinkieThrowableHandler).register() + CommandHandler(this, commandManager, LinkieThrowableHandler, rateLimiter = rateLimiter).register() commandManager.registerToSlashCommands(slashCommands) gateway.event { event -> val dispatch: ThreadMembersUpdate = ThreadMembersUpdateEvent::class.java.getDeclaredField("dispatch").also { diff --git a/src/main/kotlin/me/shedaniel/linkie/discord/commands/EvaluateCommand.kt b/src/main/kotlin/me/shedaniel/linkie/discord/commands/EvaluateCommand.kt index cc985fb..0e4bc58 100644 --- a/src/main/kotlin/me/shedaniel/linkie/discord/commands/EvaluateCommand.kt +++ b/src/main/kotlin/me/shedaniel/linkie/discord/commands/EvaluateCommand.kt @@ -26,6 +26,7 @@ import me.shedaniel.linkie.discord.scripting.EvalContext import me.shedaniel.linkie.discord.scripting.LinkieScripting import me.shedaniel.linkie.discord.scripting.push import me.shedaniel.linkie.discord.utils.CommandContext +import me.shedaniel.linkie.discord.utils.acknowledge import me.shedaniel.linkie.discord.utils.use object EvaluateCommand : SimpleCommand { @@ -42,6 +43,7 @@ object EvaluateCommand : SimpleCommand { var string = script if (string.startsWith("```")) string = string.substring(3) if (string.endsWith("```")) string = string.substring(0, string.length - 3) + ctx.message.acknowledge() LinkieScripting.eval(LinkieScripting.simpleContext.push { ContextExtensions.commandContexts(EvalContext( ctx, diff --git a/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryMappingsCommand.kt b/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryMappingsCommand.kt index c2abdde..f3d8b4a 100644 --- a/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryMappingsCommand.kt +++ b/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryMappingsCommand.kt @@ -151,6 +151,7 @@ open class QueryMappingsCommand( ns.validateGuild(ctx) val nsVersion = options.opt(version, VersionNamespaceConfig(ns)) val searchTermStr = options.opt(searchTerm).replace('.', '/').replace('#', '/') + require(searchTermStr.length < 50) { "Search term must be under 50 characters" } execute(ctx, ns, nsVersion.version!!, searchTermStr, types) } } diff --git a/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryTranslateMappingsCommand.kt b/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryTranslateMappingsCommand.kt index a4014f7..0d9308c 100644 --- a/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryTranslateMappingsCommand.kt +++ b/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryTranslateMappingsCommand.kt @@ -193,6 +193,7 @@ class QueryTranslateMappingsCommand( } val searchTermStr = options.opt(searchTerm).replace('.', '/').replace('#', '/') + require(searchTermStr.length < 50) { "Search term must be under 50 characters" } execute(ctx, src, dst, srcVersion.version!!, searchTermStr, types) } } diff --git a/src/main/kotlin/me/shedaniel/linkie/discord/scripting/LinkieScripting.kt b/src/main/kotlin/me/shedaniel/linkie/discord/scripting/LinkieScripting.kt index d73ee74..a9a8e3c 100644 --- a/src/main/kotlin/me/shedaniel/linkie/discord/scripting/LinkieScripting.kt +++ b/src/main/kotlin/me/shedaniel/linkie/discord/scripting/LinkieScripting.kt @@ -17,10 +17,6 @@ package me.shedaniel.linkie.discord.scripting import discord4j.core.event.domain.message.MessageCreateEvent -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.launch -import kotlinx.coroutines.withContext -import kotlinx.coroutines.withTimeout import me.shedaniel.linkie.discord.config.ConfigManager import me.shedaniel.linkie.discord.tricks.ContentType import me.shedaniel.linkie.discord.tricks.TrickBase @@ -112,29 +108,15 @@ object LinkieScripting { .option("log.level", "OFF") .build() try { - var t: Throwable? = null - withContext(Dispatchers.IO) { - withTimeout(3000) { - launch { - try { - engine.getBindings("js").also { - it.removeMember("load") - it.removeMember("loadWithNewGlobal") - it.removeMember("eval") - it.removeMember("exit") - it.removeMember("quit") - context.applyTo(it) - } - engine.eval("js", script) - } catch (throwable: Throwable) { - t = throwable - } - }.join() - } + engine.getBindings("js").also { + it.removeMember("load") + it.removeMember("loadWithNewGlobal") + it.removeMember("eval") + it.removeMember("exit") + it.removeMember("quit") + context.applyTo(it) } - t?.let { throw it } - } catch (throwable: Throwable) { - throw throwable + engine.eval("js", script) } finally { engine.close(true) }