Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CDAP-20832] Enable periodic restart when task workers are running concurrent requests #15575

Merged
merged 5 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.cdap.common.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import java.io.IOException;
import java.net.ConnectException;
import java.net.HttpURLConnection;
import java.net.InetSocketAddress;
import java.net.URI;
Expand All @@ -44,6 +45,7 @@
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
Expand All @@ -60,6 +62,7 @@
* Unit test for {@link TaskWorkerService}.
*/
public class TaskWorkerServiceTest {

@ClassRule
public static final TemporaryFolder TEMP_FOLDER = new TemporaryFolder();

Expand All @@ -69,7 +72,7 @@
private TaskWorkerService taskWorkerService;
private CompletableFuture<Service.State> serviceCompletionFuture;

private CConfiguration createCConf() {

Check warning on line 75 in cdap-app-fabric/src/test/java/io/cdap/cdap/internal/app/worker/TaskWorkerServiceTest.java

View workflow job for this annotation

GitHub Actions / Checkstyle

com.puppycrawl.tools.checkstyle.checks.naming.AbbreviationAsWordInNameCheck

Abbreviation in name 'createCConf' must contain no more than '1' consecutive capital letters.
CConfiguration cConf = CConfiguration.create();
cConf.set(Constants.TaskWorker.ADDRESS, "localhost");
cConf.setInt(Constants.TaskWorker.PORT, 0);
Expand All @@ -80,7 +83,7 @@
return cConf;
}

private SConfiguration createSConf() {

Check warning on line 86 in cdap-app-fabric/src/test/java/io/cdap/cdap/internal/app/worker/TaskWorkerServiceTest.java

View workflow job for this annotation

GitHub Actions / Checkstyle

com.puppycrawl.tools.checkstyle.checks.naming.AbbreviationAsWordInNameCheck

Abbreviation in name 'createSConf' must contain no more than '1' consecutive capital letters.
return SConfiguration.create();
}

Expand All @@ -91,9 +94,11 @@

InMemoryDiscoveryService discoveryService = new InMemoryDiscoveryService();
TaskWorkerService taskWorkerService = new TaskWorkerService(
cConf, sConf, discoveryService, discoveryService, metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(taskWorkerService);
cConf, sConf, discoveryService, discoveryService,
metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(
taskWorkerService);
// start the service
taskWorkerService.startAndWait();
this.taskWorkerService = taskWorkerService;
Expand All @@ -116,9 +121,11 @@

InMemoryDiscoveryService discoveryService = new InMemoryDiscoveryService();
TaskWorkerService taskWorkerService = new TaskWorkerService(
cConf, sConf, discoveryService, discoveryService, metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(taskWorkerService);
cConf, sConf, discoveryService, discoveryService,
metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(
taskWorkerService);
// start the service
taskWorkerService.startAndWait();

Expand All @@ -135,24 +142,28 @@

InMemoryDiscoveryService discoveryService = new InMemoryDiscoveryService();
TaskWorkerService taskWorkerService = new TaskWorkerService(
cConf, sConf, discoveryService, discoveryService, metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(taskWorkerService);
cConf, sConf, discoveryService, discoveryService,
metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(
taskWorkerService);
// start the service
taskWorkerService.startAndWait();

InetSocketAddress addr = taskWorkerService.getBindAddress();
URI uri = URI.create(String.format("http://%s:%s", addr.getHostName(), addr.getPort()));
URI uri = URI.create(
String.format("http://%s:%s", addr.getHostName(), addr.getPort()));

// Post valid request
String want = "5000";
RunnableTaskRequest req = RunnableTaskRequest.getBuilder(TestRunnableClass.class.getName())
RunnableTaskRequest req = RunnableTaskRequest.getBuilder(
TestRunnableClass.class.getName())
.withParam(want).withNamespace("testNamespace").build();
String reqBody = GSON.toJson(req);
HttpResponse response = HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));

Assert.assertEquals(HttpURLConnection.HTTP_OK, response.getResponseCode());
Assert.assertEquals(want, response.getResponseBodyAsString());
Expand All @@ -161,11 +172,12 @@
}

@Test
public void testPeriodicRestartWithNeverEndingInflightRequest() throws IOException {
public void testPeriodicRestartWithNeverEndingInflightRequest() {
CConfiguration cConf = createCConf();
SConfiguration sConf = createSConf();
cConf.setInt(Constants.TaskWorker.CONTAINER_KILL_AFTER_REQUEST_COUNT, 10);
cConf.setInt(Constants.TaskWorker.CONTAINER_KILL_AFTER_DURATION_SECOND, 2);
cConf.setInt(TaskWorker.TASK_EXECUTION_DEADLINE_SECOND, -1);

InMemoryDiscoveryService discoveryService = new InMemoryDiscoveryService();
TaskWorkerService taskWorkerService =
Expand All @@ -176,18 +188,21 @@
discoveryService,
metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(taskWorkerService);
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(
taskWorkerService);
// start the service
taskWorkerService.startAndWait();

new Thread(
() -> {
InetSocketAddress addr = taskWorkerService.getBindAddress();
URI uri = URI.create(String.format("http://%s:%s", addr.getHostName(), addr.getPort()));
URI uri = URI.create(String.format("http://%s:%s", addr.getHostName(),
addr.getPort()));
// Post valid request
RunnableTaskRequest req =
RunnableTaskRequest.getBuilder(TestRunnableClass.class.getName())
.withParam("200000")
.withNamespace("testNamespace")
.build();
String reqBody = GSON.toJson(req);
try {
Expand All @@ -205,6 +220,7 @@
TaskWorkerTestUtil.waitForServiceCompletion(serviceCompletionFuture);
Assert.assertEquals(Service.State.TERMINATED, taskWorkerService.state());
}

@Test
public void testRestartAfterMultipleExecutions() throws IOException {
CConfiguration cConf = createCConf();
Expand All @@ -214,29 +230,33 @@

InMemoryDiscoveryService discoveryService = new InMemoryDiscoveryService();
TaskWorkerService taskWorkerService = new TaskWorkerService(
cConf, sConf, discoveryService, discoveryService, metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(taskWorkerService);
cConf, sConf, discoveryService, discoveryService,
metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(
taskWorkerService);
// start the service
taskWorkerService.startAndWait();

InetSocketAddress addr = taskWorkerService.getBindAddress();
URI uri = URI.create(String.format("http://%s:%s", addr.getHostName(), addr.getPort()));
URI uri = URI.create(
String.format("http://%s:%s", addr.getHostName(), addr.getPort()));

// Post valid request
String want = "100";
RunnableTaskRequest req = RunnableTaskRequest.getBuilder(TestRunnableClass.class.getName())
RunnableTaskRequest req = RunnableTaskRequest.getBuilder(
TestRunnableClass.class.getName())
.withParam(want).withNamespace("testNamespace").build();
String reqBody = GSON.toJson(req);
HttpResponse response = HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));
HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));

response = HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));
HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));

TaskWorkerTestUtil.waitForServiceCompletion(serviceCompletionFuture);
Assert.assertEquals(Service.State.TERMINATED, taskWorkerService.state());
Expand All @@ -245,17 +265,19 @@
@Test
public void testStartAndStopWithValidRequest() throws IOException {
InetSocketAddress addr = taskWorkerService.getBindAddress();
URI uri = URI.create(String.format("http://%s:%s", addr.getHostName(), addr.getPort()));
URI uri = URI.create(
String.format("http://%s:%s", addr.getHostName(), addr.getPort()));

// Post valid request
String want = "100";
RunnableTaskRequest req = RunnableTaskRequest.getBuilder(TestRunnableClass.class.getName())
RunnableTaskRequest req = RunnableTaskRequest.getBuilder(
TestRunnableClass.class.getName())
.withParam(want).withNamespace("testNamespace").build();
String reqBody = GSON.toJson(req);
HttpResponse response = HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));
TaskWorkerTestUtil.waitForServiceCompletion(serviceCompletionFuture);
Assert.assertEquals(HttpURLConnection.HTTP_OK, response.getResponseCode());
Assert.assertEquals(want, response.getResponseBodyAsString());
Expand All @@ -265,20 +287,24 @@
@Test
public void testStartAndStopWithInvalidRequest() throws Exception {
InetSocketAddress addr = taskWorkerService.getBindAddress();
URI uri = URI.create(String.format("http://%s:%s", addr.getHostName(), addr.getPort()));
URI uri = URI.create(
String.format("http://%s:%s", addr.getHostName(), addr.getPort()));

// Post invalid request
RunnableTaskRequest noClassReq = RunnableTaskRequest.getBuilder("NoClass")
.withNamespace("testNamespace").withParam("100").build();
String reqBody = GSON.toJson(noClassReq);
HttpResponse response = HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));
Assert.assertEquals(HttpURLConnection.HTTP_BAD_REQUEST, response.getResponseCode());
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));
Assert.assertEquals(HttpURLConnection.HTTP_BAD_REQUEST,
response.getResponseCode());
BasicThrowable basicThrowable;
basicThrowable = GSON.fromJson(response.getResponseBodyAsString(), BasicThrowable.class);
Assert.assertTrue(basicThrowable.getClassName().contains("java.lang.ClassNotFoundException"));
basicThrowable = GSON.fromJson(response.getResponseBodyAsString(),
BasicThrowable.class);
Assert.assertTrue(basicThrowable.getClassName()
.contains("java.lang.ClassNotFoundException"));
Assert.assertNotNull(basicThrowable.getMessage());
Assert.assertTrue(basicThrowable.getMessage().contains("NoClass"));
Assert.assertNotEquals(basicThrowable.getStackTraces().length, 0);
Expand All @@ -300,18 +326,20 @@

for (int i = 0; i < concurrentRequests; i++) {
calls.add(
() -> HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false))
() -> HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false))
);
}

List<Future<HttpResponse>> responses = Executors.newFixedThreadPool(concurrentRequests).invokeAll(calls);
List<Future<HttpResponse>> responses = Executors.newFixedThreadPool(
concurrentRequests).invokeAll(calls);
int okResponse = 0;
int conflictResponse = 0;
for (int i = 0; i < concurrentRequests; i++) {
if (responses.get(i).get().getResponseCode() == HttpResponseStatus.OK.code()) {
if (responses.get(i).get().getResponseCode()
== HttpResponseStatus.OK.code()) {
okResponse++;
} else if (responses.get(i).get().getResponseCode()
== HttpResponseStatus.TOO_MANY_REQUESTS.code()) {
Expand Down Expand Up @@ -371,7 +399,7 @@
}
// Verify that the task worker service doesn't stop automatically.
try {
Tasks.waitFor(false, () -> taskWorkerService.isRunning(), 1,
Tasks.waitFor(false, taskWorkerService::isRunning, 1,
TimeUnit.SECONDS);
Assert.fail();
} catch (TimeoutException e) {
Expand All @@ -383,6 +411,81 @@
Assert.assertEquals(Service.State.TERMINATED, taskWorkerService.state());
}

@Test
public void testRestartWithConcurrentRequests() throws Exception {
CConfiguration cConf = createCConf();
cConf.setInt(TaskWorker.REQUEST_LIMIT, 3);
cConf.setBoolean(TaskWorker.USER_CODE_ISOLATION_ENABLED, false);
cConf.setInt(Constants.TaskWorker.CONTAINER_KILL_AFTER_DURATION_SECOND, 2);
cConf.setInt(Constants.TaskWorker.TASK_EXECUTION_DEADLINE_SECOND, 1);
InMemoryDiscoveryService discoveryService = new InMemoryDiscoveryService();
TaskWorkerService taskWorkerService = new TaskWorkerService(cConf,
createSConf(), discoveryService, discoveryService,
metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(
taskWorkerService);
taskWorkerService.startAndWait();
InetSocketAddress addr = taskWorkerService.getBindAddress();
URI uri = URI.create(
String.format("http://%s:%s", addr.getHostName(), addr.getPort()));

List<Callable<HttpResponse>> calls = new ArrayList<>();
int concurrentRequests = 2;

for (int i = 0; i < concurrentRequests; i++) {
RunnableTaskRequest request = RunnableTaskRequest.getBuilder(
TestRunnableClass.class.getName())
.withParam("100").withNamespace("testNamespace").build();
calls.add(
() -> HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(GSON.toJson(request)).build(),
new DefaultHttpRequestConfig(false))
);
}

// Send a request that never ends.
RunnableTaskRequest slowRequest = RunnableTaskRequest.getBuilder(
TestRunnableClass.class.getName())
.withParam("1000000").withNamespace("testNamespace").build();

calls.add(
() -> HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(GSON.toJson(slowRequest)).build(),
new DefaultHttpRequestConfig(false))
);

List<Future<HttpResponse>> responses = Executors.newFixedThreadPool(
concurrentRequests).invokeAll(calls);

int okResponse = 0;
int connectionRefusedCount = 0;
for (Future<HttpResponse> response : responses) {
try {
final int responseCode = response.get().getResponseCode();
if (responseCode == HttpResponseStatus.OK.code()) {
okResponse++;
}
} catch (ExecutionException ex) {
if (ex.getCause() instanceof ConnectException) {
connectionRefusedCount++;
} else {
throw ex;
}
}
}

// Verify that the task worker service has stopped automatically.
Assert.assertEquals(2, okResponse);
// The slow request will receive a "connection refused" response once the task
// worker service stops.
Assert.assertEquals(connectionRefusedCount, 1);
TaskWorkerTestUtil.waitForServiceCompletion(serviceCompletionFuture);
Assert.assertEquals(Service.State.TERMINATED, taskWorkerService.state());
}

public static class TestRunnableClass implements RunnableTask {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,8 @@ public static final class TaskWorker {
"task.worker.container.kill.after.duration.second";
public static final String REQUEST_LIMIT = "task.worker.request.limit";
public static final String USER_CODE_ISOLATION_ENABLED = "task.worker.request.userCodeIsolation.enabled";
public static final String TASK_EXECUTION_DEADLINE_SECOND =
"task.worker.taskExecutionDeadline.second";
public static final String CONTAINER_RUN_AS_USER = "task.worker.container.run.as.user";
public static final String CONTAINER_RUN_AS_GROUP = "task.worker.container.run.as.group";
public static final String CONTAINER_DISK_READONLY = "task.worker.container.disk.readonly";
Expand Down
Loading
Loading