diff --git a/src/main/java/de/cronn/testutils/ExecutorServiceExtension.java b/src/main/java/de/cronn/testutils/ExecutorServiceExtension.java index b37ca0a..754921a 100644 --- a/src/main/java/de/cronn/testutils/ExecutorServiceExtension.java +++ b/src/main/java/de/cronn/testutils/ExecutorServiceExtension.java @@ -1,5 +1,6 @@ package de.cronn.testutils; +import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; @@ -15,17 +16,21 @@ public class ExecutorServiceExtension implements BeforeEachCallback, AfterEachCallback { - private final long testTimeoutMillis; + private final Duration testTimeout; private ExecutorService executorService; private List> futures; public ExecutorServiceExtension(long testTimeoutMillis) { - this.testTimeoutMillis = testTimeoutMillis; + this(Duration.ofMillis(testTimeoutMillis)); + } + + public ExecutorServiceExtension(Duration testTimeout) { + this.testTimeout = testTimeout; } @Override public void afterEach(ExtensionContext context) { - ExecutorServiceUtils.shutdownOrThrow(executorService, getTestName(context), testTimeoutMillis); + ExecutorServiceUtils.shutdownOrThrow(executorService, getTestName(context), testTimeout); } @Override @@ -62,7 +67,7 @@ public void awaitAllFutures() throws Exception { } } - class TestNameUtils { + static class TestNameUtils { private TestNameUtils() { } diff --git a/src/main/java/de/cronn/testutils/ExecutorServiceUtils.java b/src/main/java/de/cronn/testutils/ExecutorServiceUtils.java index 96a5d70..4ef7e88 100644 --- a/src/main/java/de/cronn/testutils/ExecutorServiceUtils.java +++ b/src/main/java/de/cronn/testutils/ExecutorServiceUtils.java @@ -1,5 +1,6 @@ package de.cronn.testutils; +import java.time.Duration; import java.util.concurrent.BlockingQueue; import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadPoolExecutor; @@ -16,11 +17,15 @@ public final class ExecutorServiceUtils { private ExecutorServiceUtils() { } - public static void shutdownOrThrow(ExecutorService executor, String executorServiceName, long timeoutMs) { + public static void shutdownOrThrow(ExecutorService executor, String executorServiceName, long timeoutMillis) { + shutdownOrThrow(executor, executorServiceName, Duration.ofMillis(timeoutMillis)); + } + + public static void shutdownOrThrow(ExecutorService executor, String executorServiceName, Duration timeout) { if (executor != null) { try { - if (!shutdownGracefully(executor, executorServiceName, timeoutMs)) { - boolean success = shutdownNow(executor, executorServiceName, timeoutMs); + if (!shutdownGracefully(executor, executorServiceName, timeout)) { + boolean success = shutdownNow(executor, executorServiceName, timeout); Assertions.assertTrue(success, String.format("Failed to shutdown %s", executorServiceName)); } } catch (InterruptedException e) { @@ -31,11 +36,23 @@ public static void shutdownOrThrow(ExecutorService executor, String executorServ } public static boolean shutdownNow(ExecutorService executorService, String executorServiceName, long timeoutMillis) throws InterruptedException { - return shutdown(executorService, executorServiceName, timeoutMillis, true); + return shutdownNow(executorService, executorServiceName, Duration.ofMillis(timeoutMillis)); + } + + public static boolean shutdownNow(ExecutorService executorService, String executorServiceName, Duration timeout) throws InterruptedException { + return shutdown(executorService, executorServiceName, timeout, true); } public static boolean shutdownGracefully(ExecutorService executorService, String executorServiceName, long timeoutMillis) throws InterruptedException { - return shutdown(executorService, executorServiceName, timeoutMillis, false); + return shutdownGracefully(executorService, executorServiceName, Duration.ofMillis(timeoutMillis)); + } + + public static boolean shutdownGracefully(ExecutorService executorService, String executorServiceName, Duration timeout) throws InterruptedException { + return shutdown(executorService, executorServiceName, timeout, false); + } + + private static boolean shutdown(ExecutorService executorService, String executorServiceName, Duration timeout, boolean shutdownWithInterrupt) throws InterruptedException { + return shutdown(executorService, executorServiceName, timeout.toMillis(), shutdownWithInterrupt); } private static boolean shutdown(ExecutorService executorService, String executorServiceName, long timeoutMillis, boolean shutdownWithInterrupt) throws InterruptedException { @@ -53,8 +70,7 @@ private static boolean shutdown(ExecutorService executorService, String executor if (success) { log.info("Finished shutdown of '{}'", executorServiceName); } else { - if (executorService instanceof ThreadPoolExecutor) { - ThreadPoolExecutor threadPoolExecutor = (ThreadPoolExecutor) executorService; + if (executorService instanceof ThreadPoolExecutor threadPoolExecutor) { log.warn("Shutdown of '{}' timed out after {} ms. Active tasks: {}", executorServiceName, timeoutMillis, threadPoolExecutor.getActiveCount()); } else { log.warn("Shutdown of '{}' timed out after {} ms.", executorServiceName, timeoutMillis); @@ -64,8 +80,7 @@ private static boolean shutdown(ExecutorService executorService, String executor } private static void clearQueue(ExecutorService executorService, String executorServiceName) { - if (executorService instanceof ThreadPoolExecutor) { - ThreadPoolExecutor threadPoolExecutor = (ThreadPoolExecutor) executorService; + if (executorService instanceof ThreadPoolExecutor threadPoolExecutor) { BlockingQueue queue = threadPoolExecutor.getQueue(); if (!queue.isEmpty()) { int queueSize = queue.size(); diff --git a/src/main/java/de/cronn/testutils/spring/ResetClockExtension.java b/src/main/java/de/cronn/testutils/spring/ResetClockExtension.java index 44d8779..e124d35 100644 --- a/src/main/java/de/cronn/testutils/spring/ResetClockExtension.java +++ b/src/main/java/de/cronn/testutils/spring/ResetClockExtension.java @@ -45,8 +45,8 @@ public static boolean hasDeclaredMethodOrder(ExtensionContext context) { protected void resetClock(ExtensionContext context) { ApplicationContext applicationContext = SpringExtension.getApplicationContext(context); Clock clock = applicationContext.getBean(Clock.class); - if (clock instanceof TestClock) { - ((TestClock) clock).reset(); + if (clock instanceof TestClock testClock) { + testClock.reset(); } } diff --git a/src/test/java/de/cronn/testutils/ExecutorServiceExtensionTest.java b/src/test/java/de/cronn/testutils/ExecutorServiceExtensionTest.java new file mode 100644 index 0000000..89f1e4b --- /dev/null +++ b/src/test/java/de/cronn/testutils/ExecutorServiceExtensionTest.java @@ -0,0 +1,35 @@ +package de.cronn.testutils; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.Future; + +import static org.assertj.core.api.Assertions.assertThat; + +@ExtendWith(ThreadLeakCheck.class) +class ExecutorServiceExtensionTest { + + @RegisterExtension + ExecutorServiceExtension executorServiceExtension = new ExecutorServiceExtension(Duration.ofSeconds(10)); + + @Test + void testHappyCase() throws Exception { + executorServiceExtension.submit(() -> "one"); + executorServiceExtension.submit(() -> "two"); + executorServiceExtension.submit(() -> "three"); + + executorServiceExtension.awaitAllFutures(); + + List> futures = executorServiceExtension.getFutures(); + assertThat(futures).hasSize(3); + assertThat(futures) + .map(Future::get) + .map(Object::toString) + .containsExactlyInAnyOrder("one", "two", "three"); + } + +}