Skip to content

Commit

Permalink
Fix #29 and #30
Browse files Browse the repository at this point in the history
Signed-off-by: shedaniel <[email protected]>
  • Loading branch information
shedaniel committed Jun 6, 2022
1 parent b0787e6 commit 260b970
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ import discord4j.core.`object`.entity.channel.MessageChannel
import discord4j.core.event.domain.message.MessageCreateEvent
import discord4j.core.spec.EmbedCreateSpec
import discord4j.rest.util.Color
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import me.shedaniel.linkie.discord.scommands.splitArgs
import me.shedaniel.linkie.discord.utils.CommandContext
import me.shedaniel.linkie.discord.utils.MessageBasedCommandContext
Expand All @@ -36,14 +34,19 @@ import me.shedaniel.linkie.discord.utils.event
import me.shedaniel.linkie.discord.utils.replyComplex
import me.shedaniel.linkie.discord.utils.sendEmbedMessage
import java.time.Duration
import java.util.concurrent.ExecutionException
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException

class CommandHandler(
private val client: GatewayDiscordClient,
private val commandAcceptor: CommandAcceptor,
private val throwableHandler: ThrowableHandler,
val rateLimiter: RateLimiter = RateLimiter(Int.MAX_VALUE),
) {
val scope = CoroutineScope(Dispatchers.Default)

val executor: ExecutorService = Executors.newCachedThreadPool()
fun register() {
client.event(this::onMessageCreate)
}
Expand All @@ -53,17 +56,23 @@ class CommandHandler(
val user = event.message.author.orElse(null)?.takeUnless { it.isBot } ?: return
val message: String = event.message.content
val prefix = commandAcceptor.getPrefix(event)
scope.launch {
try {
try {
executor.submit {
if (message.lowercase().startsWith(prefix)) {
val content = message.substring(prefix.length)
if (!rateLimiter.allow(user.id.asLong())) {
throwableHandler.generateErrorMessage(event.message, RateLimitException(rateLimiter.maxRequestPer10Sec), channel, user)
return@submit
}
val split = content.splitArgs()
if (split.isNotEmpty()) {
val cmd = split[0].lowercase()
val ctx = MessageBasedCommandContext(event, prefix, cmd, channel)
val args = split.drop(1).toMutableList()
try {
commandAcceptor.execute(event, ctx, args)
runBlocking {
commandAcceptor.execute(event, ctx, args)
}
} catch (throwable: Throwable) {
if (throwableHandler.shouldError(throwable)) {
try {
Expand All @@ -79,9 +88,14 @@ class CommandHandler(
}
}
}
} catch (throwable: Throwable) {
throwableHandler.generateErrorMessage(event.message, throwable, channel, user)
}
}.get(10, TimeUnit.SECONDS)
} catch (throwable: TimeoutException) {
val newThrowable = TimeoutException("The command took too long to execute, the maximum execution time is 10 seconds.")
throwableHandler.generateErrorMessage(event.message, newThrowable, channel, user)
} catch (throwable: ExecutionException) {
throwableHandler.generateErrorMessage(event.message, throwable.cause ?: throwable, channel, user)
} catch (throwable: Throwable) {
throwableHandler.generateErrorMessage(event.message, throwable, channel, user)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright (c) 2019, 2020, 2021 shedaniel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package me.shedaniel.linkie.discord.handler

class RateLimitException(limit: Int) : RuntimeException("You have reached the rate limit of $limit per 10 seconds.")
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) 2019, 2020, 2021 shedaniel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package me.shedaniel.linkie.discord.handler

import java.util.*

class RateLimiter(val maxRequestPer10Sec: Int) {
data class Entry(
val time: Long,
val userId: Long,
)

private val log: Queue<Entry> = LinkedList()

fun allow(userId: Long): Boolean {
val curTime = System.currentTimeMillis()
val boundary = curTime - 10000
synchronized(log) {
while (!log.isEmpty() && log.element().time <= boundary) {
log.poll()
}
log.add(Entry(curTime, userId))
return log.count { it.userId == userId } <= maxRequestPer10Sec
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,22 @@ import discord4j.discordjson.json.ApplicationCommandData
import discord4j.discordjson.json.ApplicationCommandOptionChoiceData
import discord4j.discordjson.json.ApplicationCommandRequest
import discord4j.discordjson.possible.Possible
import me.shedaniel.linkie.discord.handler.RateLimitException
import me.shedaniel.linkie.discord.handler.RateLimiter
import me.shedaniel.linkie.discord.handler.ThrowableHandler
import me.shedaniel.linkie.discord.utils.CommandContext
import me.shedaniel.linkie.discord.utils.SlashCommandBasedContext
import me.shedaniel.linkie.discord.utils.dismissButton
import me.shedaniel.linkie.discord.utils.event
import me.shedaniel.linkie.discord.utils.extensions.getOrNull
import me.shedaniel.linkie.discord.utils.replyComplex
import me.shedaniel.linkie.discord.utils.user
import reactor.core.publisher.Mono
import java.util.concurrent.ExecutionException
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
import java.util.stream.Collectors

data class SlashCommandHandler(
Expand All @@ -50,6 +58,7 @@ class SlashCommands(
private val errorHandler: (String) -> Unit = { println("Error: $it") },
private val defaultEphemeral: Boolean = false,
private val debug: Boolean = false,
val rateLimiter: RateLimiter = RateLimiter(Int.MAX_VALUE),
) {
val applicationId: Long by lazy { client.restClient.applicationId.block() }
val handlers = mutableMapOf<String, SlashCommandHandler>()
Expand All @@ -65,6 +74,7 @@ class SlashCommands(
.collect(Collectors.toList())
}
val guildCommands = mutableMapOf<Snowflake, MutableList<ApplicationCommandData>>()
val executor: ExecutorService = Executors.newCachedThreadPool()

data class GuildCommandKey(val guildId: Snowflake, val commandName: String)

Expand Down Expand Up @@ -240,13 +250,38 @@ class SlashCommands(

private fun buildHandler(command: SlashCommand, cmd: String): SlashCommandHandler = SlashCommandHandler(responder = { event ->
var sentAny = false
val ctx = SlashCommandBasedContext(command, cmd, event, defaultEphemeral) {
it.subscribe()
sentAny = true
val ctx = SlashCommandBasedContext(command, cmd, event, defaultEphemeral) { acknowledge, mono ->
mono.subscribe()
if (!acknowledge) {
sentAny = true
}
}
if (!rateLimiter.allow(event.user.id.asLong())) {
val exception = RateLimitException(rateLimiter.maxRequestPer10Sec)
if (throwableHandler.shouldError(exception)) {
try {
ctx.message.replyComplex {
layout { dismissButton() }
embed { throwableHandler.generateThrowable(this, exception, ctx.user) }
}
} catch (throwable2: Exception) {
throwable2.addSuppressed(exception)
throwable2.printStackTrace()
}
}
return@SlashCommandHandler
}
val optionsGetter = OptionsGetter.of(command, ctx, event)
runCatching {
if (!executeOptions(command, ctx, optionsGetter, command.options, event.options) && !command.execute(command, ctx, optionsGetter)) {
try {
executor.submit {
if (!executeOptions(command, ctx, optionsGetter, command.options, event.options) && !command.execute(command, ctx, optionsGetter)) {
}
}.get(3, TimeUnit.SECONDS)
} catch (throwable: TimeoutException) {
throw TimeoutException("The command took too long to execute, the maximum execution time is 3 seconds.")
} catch (throwable: ExecutionException) {
throw throwable.cause ?: throwable
}
if (!sentAny) {
throw IllegalStateException("Command was not resolved!")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class SlashCommandBasedContext(
override val cmd: String,
val event: ChatInputInteractionEvent,
val defaultEphemeral: Boolean = false,
val send: (Mono<*>) -> Unit,
val send: (acknowledge: Boolean, Mono<*>) -> Unit,
) : CommandContext {
override val message: MessageCreator by lazy { SlashCommandMessageCreator(event, this, send).let {
if (defaultEphemeral) it.ephemeral(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,20 +142,20 @@ private fun <T> Mono<T>.toFuturePossible(): FuturePossible<T> = object : FutureP
class SlashCommandMessageCreator(
val event: ChatInputInteractionEvent,
val ctx: CommandContext,
val send: (Mono<*>) -> Unit,
val send: (acknowledge: Boolean, Mono<*>) -> Unit,
) : MessageCreator {
var sent: Boolean = false
override fun _acknowledge(ephemeral: Boolean?, content: MessageContent?) {
if (!sent) {
sent = true
send(event.deferReply().withEphemeral(ephemeral ?: false))
send(true, event.deferReply().withEphemeral(ephemeral ?: false))
}
}

override fun _reply(blockIfPossible: Boolean, ephemeral: Boolean?, spec: MessageCreatorComplex): FuturePossible<Message> {
if (!sent) {
sent = true
send(event.replyMessage {
send(false, event.replyMessage {
ephemeral(ephemeral.possible())
embeds(listOf())
content(Possible.absent())
Expand All @@ -169,7 +169,7 @@ class SlashCommandMessageCreator(
spec.compile(ctx)?.also(this::components)
})
} else {
send(event.sendOriginalEdit {
send(false, event.sendOriginalEdit {
embedsOrNull(listOf())
content(Possible.absent())
spec.text?.content?.also(this::contentOrNull)
Expand All @@ -183,7 +183,7 @@ class SlashCommandMessageCreator(
}

override fun presentModal(spec: PresentableModalSpec) {
send(event.presentModal(spec.compile(ctx.client, ctx.user)))
send(false, event.presentModal(spec.compile(ctx.client, ctx.user)))
}
}

Expand Down
6 changes: 4 additions & 2 deletions src/main/kotlin/me/shedaniel/linkie/discord/LinkieBot.kt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ import me.shedaniel.linkie.discord.commands.legacy.RemapAWATCommand
import me.shedaniel.linkie.discord.config.ConfigManager
import me.shedaniel.linkie.discord.handler.CommandHandler
import me.shedaniel.linkie.discord.handler.CommandManager
import me.shedaniel.linkie.discord.handler.RateLimiter
import me.shedaniel.linkie.discord.scommands.SlashCommands
import me.shedaniel.linkie.discord.scommands.sub
import me.shedaniel.linkie.discord.tricks.TricksManager
Expand Down Expand Up @@ -131,7 +132,8 @@ fun main() {
)
)
) {
val slashCommands = SlashCommands(this, LinkieThrowableHandler, ::warn, debug = isDebug)
val rateLimiter = RateLimiter(3)
val slashCommands = SlashCommands(this, LinkieThrowableHandler, ::warn, debug = isDebug, rateLimiter = rateLimiter)
TricksManager.listen(slashCommands)
val commandManager = object : CommandManager(if (isDebug) "@" else "!") {
override fun getPrefix(event: MessageCreateEvent): String {
Expand All @@ -149,7 +151,7 @@ fun main() {
// register the commands
registerCommands(commandManager)
registerSlashCommands(slashCommands)
CommandHandler(this, commandManager, LinkieThrowableHandler).register()
CommandHandler(this, commandManager, LinkieThrowableHandler, rateLimiter = rateLimiter).register()
commandManager.registerToSlashCommands(slashCommands)
gateway.event<ThreadMembersUpdateEvent> { event ->
val dispatch: ThreadMembersUpdate = ThreadMembersUpdateEvent::class.java.getDeclaredField("dispatch").also {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import me.shedaniel.linkie.discord.scripting.EvalContext
import me.shedaniel.linkie.discord.scripting.LinkieScripting
import me.shedaniel.linkie.discord.scripting.push
import me.shedaniel.linkie.discord.utils.CommandContext
import me.shedaniel.linkie.discord.utils.acknowledge
import me.shedaniel.linkie.discord.utils.use

object EvaluateCommand : SimpleCommand<String> {
Expand All @@ -42,6 +43,7 @@ object EvaluateCommand : SimpleCommand<String> {
var string = script
if (string.startsWith("```")) string = string.substring(3)
if (string.endsWith("```")) string = string.substring(0, string.length - 3)
ctx.message.acknowledge()
LinkieScripting.eval(LinkieScripting.simpleContext.push {
ContextExtensions.commandContexts(EvalContext(
ctx,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ open class QueryMappingsCommand(
ns.validateGuild(ctx)
val nsVersion = options.opt(version, VersionNamespaceConfig(ns))
val searchTermStr = options.opt(searchTerm).replace('.', '/').replace('#', '/')
require(searchTermStr.length < 50) { "Search term must be under 50 characters" }
execute(ctx, ns, nsVersion.version!!, searchTermStr, types)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class QueryTranslateMappingsCommand(
}

val searchTermStr = options.opt(searchTerm).replace('.', '/').replace('#', '/')
require(searchTermStr.length < 50) { "Search term must be under 50 characters" }
execute(ctx, src, dst, srcVersion.version!!, searchTermStr, types)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@
package me.shedaniel.linkie.discord.scripting

import discord4j.core.event.domain.message.MessageCreateEvent
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeout
import me.shedaniel.linkie.discord.config.ConfigManager
import me.shedaniel.linkie.discord.tricks.ContentType
import me.shedaniel.linkie.discord.tricks.TrickBase
Expand Down Expand Up @@ -112,29 +108,15 @@ object LinkieScripting {
.option("log.level", "OFF")
.build()
try {
var t: Throwable? = null
withContext(Dispatchers.IO) {
withTimeout(3000) {
launch {
try {
engine.getBindings("js").also {
it.removeMember("load")
it.removeMember("loadWithNewGlobal")
it.removeMember("eval")
it.removeMember("exit")
it.removeMember("quit")
context.applyTo(it)
}
engine.eval("js", script)
} catch (throwable: Throwable) {
t = throwable
}
}.join()
}
engine.getBindings("js").also {
it.removeMember("load")
it.removeMember("loadWithNewGlobal")
it.removeMember("eval")
it.removeMember("exit")
it.removeMember("quit")
context.applyTo(it)
}
t?.let { throw it }
} catch (throwable: Throwable) {
throw throwable
engine.eval("js", script)
} finally {
engine.close(true)
}
Expand Down

0 comments on commit 260b970

Please sign in to comment.