Skip to content

Commit

Permalink
Add ExecutorServiceExtension
Browse files Browse the repository at this point in the history
  • Loading branch information
rpost committed Jun 21, 2024
1 parent 7bfdd7f commit fbdcfc3
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions src/main/java/de/cronn/testutils/ExecutorServiceExtension.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package de.cronn.testutils;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;

import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.springframework.scheduling.concurrent.CustomizableThreadFactory;

public class ExecutorServiceExtension implements BeforeEachCallback, AfterEachCallback {

private final long testTimeoutMillis;
private ExecutorService executorService;
private List<Future<?>> futures;

public ExecutorServiceExtension(long testTimeoutMillis) {
this.testTimeoutMillis = testTimeoutMillis;
}

@Override
public void afterEach(ExtensionContext context) {
ExecutorServiceUtils.shutdownOrThrow(executorService, getTestName(context), testTimeoutMillis);
}

@Override
public void beforeEach(ExtensionContext context) {
ThreadFactory threadFactory = new CustomizableThreadFactory(getTestName(context));
executorService = Executors.newCachedThreadPool(threadFactory);
futures = new ArrayList<>();
}

private String getTestName(ExtensionContext context) {
return TestNameUtils.getTestName(context.getRequiredTestClass(), context.getRequiredTestMethod().getName());
}

public Future<Void> submit(Runnable runnable) {
return submit(() -> {
runnable.run();
return null;
});
}

public <T> Future<T> submit(Callable<T> callable) {
Future<T> future = executorService.submit(callable);
futures.add(future);
return future;
}

public List<Future<?>> getFutures() {
return futures;
}

public void awaitAllFutures() throws Exception {
for (Future<?> future : getFutures()) {
future.get();
}
}

class TestNameUtils {

private TestNameUtils() {
}

public static String getTestName(Class<?> aClass, String methodName) {
return join(enclosingClassesUpstream(aClass), methodName);
}

private static String enclosingClassesUpstream(Class<?> aClass) {
String classHierarchy = aClass.getSimpleName();
Class<?> enclosingClass = aClass.getEnclosingClass();
while (enclosingClass != null) {
classHierarchy = join(enclosingClass.getSimpleName(), classHierarchy);
enclosingClass = enclosingClass.getEnclosingClass();
}
return classHierarchy;
}

private static String join(String element, String other) {
return other.startsWith("_") ? (element + other) : (element + "_" + other);
}
}

}

0 comments on commit fbdcfc3

Please sign in to comment.