diff --git a/infra/common/src/main/java/cn/hippo4j/common/executor/support/SyncPutQueuePolicy.java b/infra/common/src/main/java/cn/hippo4j/common/executor/support/SyncPutQueuePolicy.java index 961b80f615..d1faee86c9 100644 --- a/infra/common/src/main/java/cn/hippo4j/common/executor/support/SyncPutQueuePolicy.java +++ b/infra/common/src/main/java/cn/hippo4j/common/executor/support/SyncPutQueuePolicy.java @@ -19,8 +19,10 @@ import lombok.extern.slf4j.Slf4j; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.RejectedExecutionHandler; import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; /** * Synchronous put queue policy. @@ -28,13 +30,39 @@ @Slf4j public class SyncPutQueuePolicy implements RejectedExecutionHandler { + // The timeout value for the offer method (ms). + private int timeout; + + private final boolean enableTimeout; + + public SyncPutQueuePolicy(int timeout){ + if (timeout < 0){ + throw new IllegalArgumentException("timeout must be greater than 0"); + } + this.timeout = timeout; + this.enableTimeout = true; + } + + public SyncPutQueuePolicy (){ + this.enableTimeout = false; + } + @Override public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) { if (executor.isShutdown()) { return; } try { - executor.getQueue().put(r); + if (enableTimeout) { + if (!executor.getQueue().offer(r, timeout, TimeUnit.MILLISECONDS)) { + throw new RejectedExecutionException("Task " + r.toString() + + " rejected from " + + executor.toString() + " with timeout " + timeout + "ms."); + } + } + else { + executor.getQueue().put(r); + } } catch (InterruptedException e) { log.error("Adding Queue task to thread pool failed.", e); } diff --git a/infra/common/src/test/java/cn/hippo4j/common/executor/support/SyncPutQueuePolicyTest.java b/infra/common/src/test/java/cn/hippo4j/common/executor/support/SyncPutQueuePolicyTest.java index c2f0a35987..fef14769ae 100644 --- a/infra/common/src/test/java/cn/hippo4j/common/executor/support/SyncPutQueuePolicyTest.java +++ b/infra/common/src/test/java/cn/hippo4j/common/executor/support/SyncPutQueuePolicyTest.java @@ -22,20 +22,23 @@ import org.junit.Test; import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.stream.IntStream; +import static org.junit.Assert.fail; + /** * Synchronous placement queue policy implementation test */ public class SyncPutQueuePolicyTest { /** - * test thread pool rejected execution + * test thread pool rejected execution without timeout */ @Test - public void testRejectedExecution() { + public void testRejectedExecutionWithoutTimeout() { SyncPutQueuePolicy syncPutQueuePolicy = new SyncPutQueuePolicy(); ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(1, 2, 60, TimeUnit.SECONDS, new ArrayBlockingQueue<>(1), syncPutQueuePolicy); @@ -50,4 +53,34 @@ public void testRejectedExecution() { } Assert.assertEquals(4, threadPoolExecutor.getCompletedTaskCount()); } + + /** + * test thread pool rejected execution with timeout + */ + @Test + public void testRejectedExecutionWithTimeout() { + SyncPutQueuePolicy syncPutQueuePolicy = new SyncPutQueuePolicy(300); + ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(1, 1, + 60, TimeUnit.SECONDS, new ArrayBlockingQueue<>(1), syncPutQueuePolicy); + threadPoolExecutor.prestartAllCoreThreads(); + + Assert.assertSame(syncPutQueuePolicy, threadPoolExecutor.getRejectedExecutionHandler()); + IntStream.range(0, 4).forEach(s -> { + threadPoolExecutor.execute(() -> ThreadUtil.sleep(200L)); + }); + IntStream.range(0, 2).forEach(s -> { + threadPoolExecutor.execute(() -> ThreadUtil.sleep(500L)); + }); + try { + threadPoolExecutor.execute(() -> ThreadUtil.sleep(100L)); + ThreadUtil.sleep(1000L); + fail("should throw RejectedExecutionException"); + } catch (Exception e) { + Assert.assertTrue(e instanceof RejectedExecutionException); + } + threadPoolExecutor.shutdown(); + while (!threadPoolExecutor.isTerminated()) { + } + Assert.assertEquals(6, threadPoolExecutor.getCompletedTaskCount()); + } }