Skip to content

Commit

Permalink
Improved task exceptions (#24)
Browse files Browse the repository at this point in the history
* improved task exceptions

---------

Co-authored-by: IB <[email protected]>
  • Loading branch information
baitcode and IB authored Nov 28, 2023
1 parent 4f1557a commit 2407a97
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 27 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,12 @@ Currently, supported brokers are:
- RabbitMQ (GCP, AWS)
- Azure Service Bus

### Migrations

The correct way to migrate the workload is:
1) Create new task with new workload
2) Maintain 2 tasks until messages for the old one are depleted
3) Delete old task

Important do not change task input as Json deserializer will break causing the queue to block.
To mitigate the mentioned problem we introduced default behavior to automatically drop tasks on SerialisationError
16 changes: 7 additions & 9 deletions src/main/kotlin/RabbitMQBroker.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import com.rabbitmq.client.impl.MicrometerMetricsCollector
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micrometer.core.instrument.logging.LoggingMeterRegistry
import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import loggingScope
import withLogCtx
import kotlin.time.Duration.Companion.minutes
Expand All @@ -26,7 +24,7 @@ class RabbitMQBroker(
) : IMessageBroker {
private val createdQueues = mutableSetOf<QueueName>()
private val createdDelayQueues = mutableSetOf<Long>()
private var queues: RabbitMQQueues
private var queues: QueueDefinitions
var connection: Connection
var channel: Channel

Expand Down Expand Up @@ -54,7 +52,7 @@ class RabbitMQBroker(
factory.isAutomaticRecoveryEnabled = true
connection = factory.newConnection()
channel = connection.createChannel()
queues = RabbitMQQueues(channel)
queues = QueueDefinitions(channel)

// Create DELAYED exchange
channel.exchangeDeclare(delayExchangeName, "headers", true)
Expand Down Expand Up @@ -197,18 +195,18 @@ class RabbitMQConsumer(val consumer: DefaultConsumer) : IConsumer {
}


class RabbitMQQueues(
private class QueueDefinitions(
var channel: Channel,
) {
data class RabbitMQQueueDeclaration(
data class QueueDeclaration(
val queueName: String,
val durable: Boolean,
val exclusive: Boolean,
val autoDelete: Boolean,
val arguments: Map<String, Any>?,
)

val declarations = mutableSetOf<RabbitMQQueueDeclaration>()
val declarations = mutableSetOf<QueueDeclaration>()

fun declare(
queueName: String,
Expand All @@ -218,12 +216,12 @@ class RabbitMQQueues(
arguments: Map<String, Any>?,
) {
val declaration =
RabbitMQQueueDeclaration(queueName, durable, exclusive, autoDelete, arguments)
QueueDeclaration(queueName, durable, exclusive, autoDelete, arguments)
declarations.add(declaration)
this.declare(declaration)
}

fun declare(d: RabbitMQQueueDeclaration): AMQP.Queue.DeclareOk? {
fun declare(d: QueueDeclaration): AMQP.Queue.DeclareOk? {
return channel.queueDeclare(d.queueName, d.durable, d.exclusive, d.autoDelete, d.arguments)
}
}
47 changes: 29 additions & 18 deletions src/main/kotlin/Task.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.zamna.kotask

import io.github.oshai.kotlinlogging.KLogger
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
Expand Down Expand Up @@ -36,14 +37,14 @@ class Task<T : Any> @PublishedApi internal constructor(
private var logger = KotlinLogging.logger { }

companion object {

inline fun <reified T : Any> create(
name: String, retry: IRetryPolicy? = null, noinline handler: TaskHandler<T>,
) = Task(serializer(), name, retry, handler)

inline fun <reified T : Any> create(
name: String, retry: IRetryPolicy? = null, noinline handler: OnlyInputTaskHandler<T>,
) = create(name, retry, handler.toTaskHandler())

fun create(
name: String, retry: IRetryPolicy? = null, handler: NoArgTaskHandler,
) = create(name, retry, handler.toTaskHandler())
Expand Down Expand Up @@ -77,6 +78,7 @@ class Task<T : Any> @PublishedApi internal constructor(
return TaskCallFactory(this as Task<NoInput>, NoInput, manager)
}


fun createTaskCall(
input: T,
params: CallParams = CallParams(),
Expand All @@ -86,6 +88,8 @@ class Task<T : Any> @PublishedApi internal constructor(
return manager.createTaskCall(this, inputStr, params)
}

// }

suspend fun execute(inputStr: String, params: CallParams, manager: TaskManager) {
val logCtx = mapOf(
"task" to name,
Expand All @@ -106,29 +110,32 @@ class Task<T : Any> @PublishedApi internal constructor(
withLogCtx("action" to TaskEvents.MESSAGE_COMPLETE) {
logger.info { "Complete task $name with callId=${params.callId} with $inputStr" }
}
} catch (e: RepeatTask) {
withLogCtx("action" to TaskEvents.MESSAGE_SUBMIT_RETRY) {
logger.info { "Received RepeatTask from task $name with callId=${params.callId} with $inputStr" }
manager.enqueueTaskCall(this, inputStr, e.getRetryCallParams(params))
}
} catch (e: ForceRetry) {
withLogCtx("action" to TaskEvents.MESSAGE_SUBMIT_RETRY) {
logger.info { "Received ForceRetry from task $name with callId=${params.callId} with $inputStr" }
manager.enqueueTaskCall(this, inputStr, e.getRetryCallParams(params))
}
} catch (e: FailNoRetry) {
withLogCtx("action" to TaskEvents.MESSAGE_FAIL_NO_RETRY) {
logger.info { "Received FailNoRetry from task $name with callId=${params.callId} with $inputStr" }
}
} catch (e: SerializationException) {
withLogCtx("action" to TaskEvents.MESSAGE_FAIL_NO_RETRY) {
logger.error { "Task got bad json" }
} catch (e: RetryControlException) {
when (e) {
is RepeatTask -> withLogCtx("action" to TaskEvents.MESSAGE_SUBMIT_RETRY) {
logger.info { "Received RepeatTask from task $name with callId=${params.callId} with $inputStr" }
manager.enqueueTaskCall(this, inputStr, e.getRetryCallParams(params))
}
is ForceRetry -> withLogCtx("action" to TaskEvents.MESSAGE_SUBMIT_RETRY) {
logger.info { "Received ForceRetry from task $name with callId=${params.callId} with $inputStr" }
manager.enqueueTaskCall(this, inputStr, e.getRetryCallParams(params))
}
is FailNoRetry -> withLogCtx("action" to TaskEvents.MESSAGE_FAIL_NO_RETRY) {
logger.info { "Received FailNoRetry from task $name with callId=${params.callId} with $inputStr" }
}
}
} catch (e: Throwable) {
withLogCtx("action" to TaskEvents.MESSAGE_FAIL) {
logger.error(e) { "Task $name failed with callId=${params.callId} with $inputStr" }
}

for ((exception, handler) in manager.taskErrorHandlers) {
if (e::class.java == exception) {
handler.invoke(logger, inputStr)
return
}
}

if (getRetryPolicy(manager).shouldRetry(params)) {
withLogCtx("action" to TaskEvents.MESSAGE_FAIL_RETRY) {
logger.info(e) { "Retry task $name with callId=${params.callId} with $inputStr" }
Expand Down Expand Up @@ -197,3 +204,7 @@ data class CallParams(
fun nextAttempt() = copy(attemptNum = attemptNum + 1)
}


fun interface TaskErrorHandler {
fun invoke(logger: KLogger, inputMessage: String)
}
14 changes: 14 additions & 0 deletions src/main/kotlin/TaskManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import kotlinx.datetime.Clock
import kotlinx.datetime.Instant
import kotlinx.serialization.Serializable
import cleanScheduleWorker
import io.github.oshai.kotlinlogging.KLogger
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.serialization.SerializationException
import loggingScope
import withLogCtx
import kotlin.time.Duration
Expand All @@ -16,14 +18,26 @@ import kotlin.time.Duration.Companion.seconds
import kotlin.time.DurationUnit
import kotlin.time.toDuration


public val discardTaskOnSerialisationProblem: Pair<Class<*>, TaskErrorHandler> = SerializationException::class.java to
TaskErrorHandler { logger, inputStr ->
withLogCtx("action" to TaskEvents.MESSAGE_FAIL_NO_RETRY) {
logger.error { "Can't deserialize json as task input. Json: $inputStr" }
}
}

// TODO(baitcode): TaskManager is getting huge
class TaskManager(
private val broker: IMessageBroker,
val scheduler: IScheduleTracker = InMemoryScheduleTracker(),
private val queueNamePrefix: String = "kotask-",
val defaultRetryPolicy: IRetryPolicy = RetryPolicy(4.seconds, 20, expBackoff = true, maxDelay = 1.hours),
schedulersScope: CoroutineScope? = null,
val taskErrorHandlers: List<Pair<Class<*>, TaskErrorHandler>> = listOf(
discardTaskOnSerialisationProblem, // deafault
)
): AutoCloseable {

private val knownTasks: MutableMap<String, Task<*>> = mutableMapOf()
// TODO(baitcode): Why use list? When we always have single consumer. Is it for concurrency.
internal val tasksConsumers: MutableMap<String, MutableList<IConsumer>> = mutableMapOf()
Expand Down
37 changes: 37 additions & 0 deletions src/test/kotlin/TaskManagerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import io.kotest.framework.concurrency.eventually
import io.kotest.framework.concurrency.until
import io.kotest.matchers.shouldBe
import io.kotest.matchers.shouldNotBe
import io.mockk.every
import io.mockk.mockk
import io.mockk.verify
import kotlinx.coroutines.launch
import kotlinx.datetime.Clock
import org.testcontainers.containers.wait.strategy.Wait
Expand Down Expand Up @@ -212,3 +215,37 @@ fun taskManagerTest(taskManager: TaskManager) = funSpec {
}

}




class TaskManagerErrorHandling: FunSpec({
class UnhadledError : Exception()

val errorHandler = mockk<TaskErrorHandler>()
every { errorHandler.invoke(any(), any()) } returns Unit

val tm = TaskManager(
LocalBroker(),
taskErrorHandlers = listOf(
UnhadledError::class.java to errorHandler
)
)

val testTask1 =
Task.create("failing-task-${randomSuffix()}",) { ctx, input: TaskTrackExecutionWithContextCountInput ->
throw UnhadledError()
}

test("test error") {
tm.startWorkers(testTask1)

TaskTrackExecutionWithContextCountInput.new().let {
testTask1.callLater(it)
eventually(1000) {
verify(exactly = 1) { errorHandler.invoke(any(), any()) }
}

}
}
})

0 comments on commit 2407a97

Please sign in to comment.