diff --git a/src/main/java/de/cronn/testutils/ExecutorServiceExtension.java b/src/main/java/de/cronn/testutils/ExecutorServiceExtension.java new file mode 100644 index 0000000..b37ca0a --- /dev/null +++ b/src/main/java/de/cronn/testutils/ExecutorServiceExtension.java @@ -0,0 +1,89 @@ +package de.cronn.testutils; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadFactory; + +import org.junit.jupiter.api.extension.AfterEachCallback; +import org.junit.jupiter.api.extension.BeforeEachCallback; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.springframework.scheduling.concurrent.CustomizableThreadFactory; + +public class ExecutorServiceExtension implements BeforeEachCallback, AfterEachCallback { + + private final long testTimeoutMillis; + private ExecutorService executorService; + private List> futures; + + public ExecutorServiceExtension(long testTimeoutMillis) { + this.testTimeoutMillis = testTimeoutMillis; + } + + @Override + public void afterEach(ExtensionContext context) { + ExecutorServiceUtils.shutdownOrThrow(executorService, getTestName(context), testTimeoutMillis); + } + + @Override + public void beforeEach(ExtensionContext context) { + ThreadFactory threadFactory = new CustomizableThreadFactory(getTestName(context)); + executorService = Executors.newCachedThreadPool(threadFactory); + futures = new ArrayList<>(); + } + + private String getTestName(ExtensionContext context) { + return TestNameUtils.getTestName(context.getRequiredTestClass(), context.getRequiredTestMethod().getName()); + } + + public Future submit(Runnable runnable) { + return submit(() -> { + runnable.run(); + return null; + }); + } + + public Future submit(Callable callable) { + Future future = executorService.submit(callable); + futures.add(future); + return future; + } + + public List> getFutures() { + return futures; + } + + public void awaitAllFutures() throws Exception { + for (Future future : getFutures()) { + future.get(); + } + } + + class TestNameUtils { + + private TestNameUtils() { + } + + public static String getTestName(Class aClass, String methodName) { + return join(enclosingClassesUpstream(aClass), methodName); + } + + private static String enclosingClassesUpstream(Class aClass) { + String classHierarchy = aClass.getSimpleName(); + Class enclosingClass = aClass.getEnclosingClass(); + while (enclosingClass != null) { + classHierarchy = join(enclosingClass.getSimpleName(), classHierarchy); + enclosingClass = enclosingClass.getEnclosingClass(); + } + return classHierarchy; + } + + private static String join(String element, String other) { + return other.startsWith("_") ? (element + other) : (element + "_" + other); + } + } + +}