Skip to content

Commit

Permalink
Enable periodic restart in non user code isolation mode
Browse files Browse the repository at this point in the history
  • Loading branch information
arjan-bal committed Apr 10, 2024
1 parent 038c1c7 commit a7fc249
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.cdap.common.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import java.io.IOException;
import java.net.ConnectException;
import java.net.HttpURLConnection;
import java.net.InetSocketAddress;
import java.net.URI;
Expand All @@ -44,6 +45,7 @@
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
Expand All @@ -60,6 +62,7 @@
* Unit test for {@link TaskWorkerService}.
*/
public class TaskWorkerServiceTest {

@ClassRule
public static final TemporaryFolder TEMP_FOLDER = new TemporaryFolder();

Expand Down Expand Up @@ -91,9 +94,11 @@ public void beforeTest() {

InMemoryDiscoveryService discoveryService = new InMemoryDiscoveryService();
TaskWorkerService taskWorkerService = new TaskWorkerService(
cConf, sConf, discoveryService, discoveryService, metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(taskWorkerService);
cConf, sConf, discoveryService, discoveryService,
metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(
taskWorkerService);
// start the service
taskWorkerService.startAndWait();
this.taskWorkerService = taskWorkerService;
Expand All @@ -116,9 +121,11 @@ public void testPeriodicRestart() {

InMemoryDiscoveryService discoveryService = new InMemoryDiscoveryService();
TaskWorkerService taskWorkerService = new TaskWorkerService(
cConf, sConf, discoveryService, discoveryService, metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(taskWorkerService);
cConf, sConf, discoveryService, discoveryService,
metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(
taskWorkerService);
// start the service
taskWorkerService.startAndWait();

Expand All @@ -135,24 +142,28 @@ public void testPeriodicRestartWithInflightRequest() throws IOException {

InMemoryDiscoveryService discoveryService = new InMemoryDiscoveryService();
TaskWorkerService taskWorkerService = new TaskWorkerService(
cConf, sConf, discoveryService, discoveryService, metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(taskWorkerService);
cConf, sConf, discoveryService, discoveryService,
metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(
taskWorkerService);
// start the service
taskWorkerService.startAndWait();

InetSocketAddress addr = taskWorkerService.getBindAddress();
URI uri = URI.create(String.format("http://%s:%s", addr.getHostName(), addr.getPort()));
URI uri = URI.create(
String.format("http://%s:%s", addr.getHostName(), addr.getPort()));

// Post valid request
String want = "5000";
RunnableTaskRequest req = RunnableTaskRequest.getBuilder(TestRunnableClass.class.getName())
RunnableTaskRequest req = RunnableTaskRequest.getBuilder(
TestRunnableClass.class.getName())
.withParam(want).withNamespace("testNamespace").build();
String reqBody = GSON.toJson(req);
HttpResponse response = HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));

Assert.assertEquals(HttpURLConnection.HTTP_OK, response.getResponseCode());
Assert.assertEquals(want, response.getResponseBodyAsString());
Expand All @@ -161,11 +172,12 @@ public void testPeriodicRestartWithInflightRequest() throws IOException {
}

@Test
public void testPeriodicRestartWithNeverEndingInflightRequest() throws IOException {
public void testPeriodicRestartWithNeverEndingInflightRequest() {
CConfiguration cConf = createCConf();
SConfiguration sConf = createSConf();
cConf.setInt(Constants.TaskWorker.CONTAINER_KILL_AFTER_REQUEST_COUNT, 10);
cConf.setInt(Constants.TaskWorker.CONTAINER_KILL_AFTER_DURATION_SECOND, 2);
cConf.setInt(TaskWorker.TASK_EXECUTION_DEADLINE_SECOND, -1);

InMemoryDiscoveryService discoveryService = new InMemoryDiscoveryService();
TaskWorkerService taskWorkerService =
Expand All @@ -176,18 +188,21 @@ public void testPeriodicRestartWithNeverEndingInflightRequest() throws IOExcepti
discoveryService,
metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(taskWorkerService);
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(
taskWorkerService);
// start the service
taskWorkerService.startAndWait();

new Thread(
() -> {
InetSocketAddress addr = taskWorkerService.getBindAddress();
URI uri = URI.create(String.format("http://%s:%s", addr.getHostName(), addr.getPort()));
URI uri = URI.create(String.format("http://%s:%s", addr.getHostName(),
addr.getPort()));
// Post valid request
RunnableTaskRequest req =
RunnableTaskRequest.getBuilder(TestRunnableClass.class.getName())
.withParam("200000")
.withNamespace("testNamespace")
.build();
String reqBody = GSON.toJson(req);
try {
Expand All @@ -205,6 +220,7 @@ public void testPeriodicRestartWithNeverEndingInflightRequest() throws IOExcepti
TaskWorkerTestUtil.waitForServiceCompletion(serviceCompletionFuture);
Assert.assertEquals(Service.State.TERMINATED, taskWorkerService.state());
}

@Test
public void testRestartAfterMultipleExecutions() throws IOException {
CConfiguration cConf = createCConf();
Expand All @@ -214,29 +230,33 @@ public void testRestartAfterMultipleExecutions() throws IOException {

InMemoryDiscoveryService discoveryService = new InMemoryDiscoveryService();
TaskWorkerService taskWorkerService = new TaskWorkerService(
cConf, sConf, discoveryService, discoveryService, metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(taskWorkerService);
cConf, sConf, discoveryService, discoveryService,
metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(
taskWorkerService);
// start the service
taskWorkerService.startAndWait();

InetSocketAddress addr = taskWorkerService.getBindAddress();
URI uri = URI.create(String.format("http://%s:%s", addr.getHostName(), addr.getPort()));
URI uri = URI.create(
String.format("http://%s:%s", addr.getHostName(), addr.getPort()));

// Post valid request
String want = "100";
RunnableTaskRequest req = RunnableTaskRequest.getBuilder(TestRunnableClass.class.getName())
RunnableTaskRequest req = RunnableTaskRequest.getBuilder(
TestRunnableClass.class.getName())
.withParam(want).withNamespace("testNamespace").build();
String reqBody = GSON.toJson(req);
HttpResponse response = HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));
HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));

response = HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));
HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));

TaskWorkerTestUtil.waitForServiceCompletion(serviceCompletionFuture);
Assert.assertEquals(Service.State.TERMINATED, taskWorkerService.state());
Expand All @@ -245,17 +265,19 @@ public void testRestartAfterMultipleExecutions() throws IOException {
@Test
public void testStartAndStopWithValidRequest() throws IOException {
InetSocketAddress addr = taskWorkerService.getBindAddress();
URI uri = URI.create(String.format("http://%s:%s", addr.getHostName(), addr.getPort()));
URI uri = URI.create(
String.format("http://%s:%s", addr.getHostName(), addr.getPort()));

// Post valid request
String want = "100";
RunnableTaskRequest req = RunnableTaskRequest.getBuilder(TestRunnableClass.class.getName())
RunnableTaskRequest req = RunnableTaskRequest.getBuilder(
TestRunnableClass.class.getName())
.withParam(want).withNamespace("testNamespace").build();
String reqBody = GSON.toJson(req);
HttpResponse response = HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));
TaskWorkerTestUtil.waitForServiceCompletion(serviceCompletionFuture);
Assert.assertEquals(HttpURLConnection.HTTP_OK, response.getResponseCode());
Assert.assertEquals(want, response.getResponseBodyAsString());
Expand All @@ -265,20 +287,24 @@ public void testStartAndStopWithValidRequest() throws IOException {
@Test
public void testStartAndStopWithInvalidRequest() throws Exception {
InetSocketAddress addr = taskWorkerService.getBindAddress();
URI uri = URI.create(String.format("http://%s:%s", addr.getHostName(), addr.getPort()));
URI uri = URI.create(
String.format("http://%s:%s", addr.getHostName(), addr.getPort()));

// Post invalid request
RunnableTaskRequest noClassReq = RunnableTaskRequest.getBuilder("NoClass")
.withNamespace("testNamespace").withParam("100").build();
String reqBody = GSON.toJson(noClassReq);
HttpResponse response = HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));
Assert.assertEquals(HttpURLConnection.HTTP_BAD_REQUEST, response.getResponseCode());
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false));
Assert.assertEquals(HttpURLConnection.HTTP_BAD_REQUEST,
response.getResponseCode());
BasicThrowable basicThrowable;
basicThrowable = GSON.fromJson(response.getResponseBodyAsString(), BasicThrowable.class);
Assert.assertTrue(basicThrowable.getClassName().contains("java.lang.ClassNotFoundException"));
basicThrowable = GSON.fromJson(response.getResponseBodyAsString(),
BasicThrowable.class);
Assert.assertTrue(basicThrowable.getClassName()
.contains("java.lang.ClassNotFoundException"));
Assert.assertNotNull(basicThrowable.getMessage());
Assert.assertTrue(basicThrowable.getMessage().contains("NoClass"));
Assert.assertNotEquals(basicThrowable.getStackTraces().length, 0);
Expand All @@ -300,18 +326,20 @@ public void testConcurrentRequestsWithIsolationEnabled() throws Exception {

for (int i = 0; i < concurrentRequests; i++) {
calls.add(
() -> HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false))
() -> HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(reqBody).build(),
new DefaultHttpRequestConfig(false))
);
}

List<Future<HttpResponse>> responses = Executors.newFixedThreadPool(concurrentRequests).invokeAll(calls);
List<Future<HttpResponse>> responses = Executors.newFixedThreadPool(
concurrentRequests).invokeAll(calls);
int okResponse = 0;
int conflictResponse = 0;
for (int i = 0; i < concurrentRequests; i++) {
if (responses.get(i).get().getResponseCode() == HttpResponseStatus.OK.code()) {
if (responses.get(i).get().getResponseCode()
== HttpResponseStatus.OK.code()) {
okResponse++;
} else if (responses.get(i).get().getResponseCode()
== HttpResponseStatus.TOO_MANY_REQUESTS.code()) {
Expand Down Expand Up @@ -371,7 +399,7 @@ public void testConcurrentRequestsWithIsolationDisabled() throws Exception {
}
// Verify that the task worker service doesn't stop automatically.
try {
Tasks.waitFor(false, () -> taskWorkerService.isRunning(), 1,
Tasks.waitFor(false, taskWorkerService::isRunning, 1,
TimeUnit.SECONDS);
Assert.fail();
} catch (TimeoutException e) {
Expand All @@ -383,6 +411,81 @@ public void testConcurrentRequestsWithIsolationDisabled() throws Exception {
Assert.assertEquals(Service.State.TERMINATED, taskWorkerService.state());
}

@Test
public void testRestartWithConcurrentRequests() throws Exception {
CConfiguration cConf = createCConf();
cConf.setInt(TaskWorker.REQUEST_LIMIT, 3);
cConf.setBoolean(TaskWorker.USER_CODE_ISOLATION_ENABLED, false);
cConf.setInt(Constants.TaskWorker.CONTAINER_KILL_AFTER_DURATION_SECOND, 2);
cConf.setInt(Constants.TaskWorker.TASK_EXECUTION_DEADLINE_SECOND, 1);
InMemoryDiscoveryService discoveryService = new InMemoryDiscoveryService();
TaskWorkerService taskWorkerService = new TaskWorkerService(cConf,
createSConf(), discoveryService, discoveryService,
metricsCollectionService,
new CommonNettyHttpServiceFactory(cConf, metricsCollectionService));
serviceCompletionFuture = TaskWorkerTestUtil.getServiceCompletionFuture(
taskWorkerService);
taskWorkerService.startAndWait();
InetSocketAddress addr = taskWorkerService.getBindAddress();
URI uri = URI.create(
String.format("http://%s:%s", addr.getHostName(), addr.getPort()));

List<Callable<HttpResponse>> calls = new ArrayList<>();
int concurrentRequests = 2;

for (int i = 0; i < concurrentRequests; i++) {
RunnableTaskRequest request = RunnableTaskRequest.getBuilder(
TestRunnableClass.class.getName())
.withParam("100").withNamespace("testNamespace").build();
calls.add(
() -> HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(GSON.toJson(request)).build(),
new DefaultHttpRequestConfig(false))
);
}

// Send a request that never ends.
RunnableTaskRequest slowRequest = RunnableTaskRequest.getBuilder(
TestRunnableClass.class.getName())
.withParam("1000000").withNamespace("testNamespace").build();

calls.add(
() -> HttpRequests.execute(
HttpRequest.post(uri.resolve("/v3Internal/worker/run").toURL())
.withBody(GSON.toJson(slowRequest)).build(),
new DefaultHttpRequestConfig(false))
);

List<Future<HttpResponse>> responses = Executors.newFixedThreadPool(
concurrentRequests).invokeAll(calls);

int okResponse = 0;
int connectionRefusedCount = 0;
for (Future<HttpResponse> response : responses) {
try {
final int responseCode = response.get().getResponseCode();
if (responseCode == HttpResponseStatus.OK.code()) {
okResponse++;
}
} catch (ExecutionException ex) {
if (ex.getCause() instanceof ConnectException) {
connectionRefusedCount++;
} else {
throw ex;
}
}
}

// Verify that the task worker service has stopped automatically.
Assert.assertEquals(2, okResponse);
// The slow request will receive a "connection refused" response once the task
// worker service stops.
Assert.assertEquals(connectionRefusedCount, 1);
TaskWorkerTestUtil.waitForServiceCompletion(serviceCompletionFuture);
Assert.assertEquals(Service.State.TERMINATED, taskWorkerService.state());
}

public static class TestRunnableClass implements RunnableTask {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,8 @@ public static final class TaskWorker {
"task.worker.container.kill.after.duration.second";
public static final String REQUEST_LIMIT = "task.worker.request.limit";
public static final String USER_CODE_ISOLATION_ENABLED = "task.worker.request.userCodeIsolation.enabled";
public static final String TASK_EXECUTION_DEADLINE_SECOND =
"task.worker.taskExecutionDeadline.second";
public static final String CONTAINER_RUN_AS_USER = "task.worker.container.run.as.user";
public static final String CONTAINER_RUN_AS_GROUP = "task.worker.container.run.as.group";
public static final String CONTAINER_DISK_READONLY = "task.worker.container.disk.readonly";
Expand Down
Loading

0 comments on commit a7fc249

Please sign in to comment.