Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: auto-upgrade flag for applications + force restart #763

Merged
merged 14 commits into from
May 20, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,22 @@ public String get(String archetype) {
private class ApplicationsImpl implements Applications {
@Override
public String deploy(String application, MultiPartBodyPublisher multiPartBodyPublisher) {
return deploy(application, multiPartBodyPublisher, false);
return deploy(application, multiPartBodyPublisher, false, false);
}

@Override
@SneakyThrows
public String deploy(
String application, MultiPartBodyPublisher multiPartBodyPublisher, boolean dryRun) {
final String path = tenantAppPath("/" + application) + "?dry-run=" + dryRun;
String application,
MultiPartBodyPublisher multiPartBodyPublisher,
boolean dryRun,
boolean autoUpgrade) {
final String path =
tenantAppPath("/" + application)
+ "?dry-run="
+ dryRun
+ "&auto-upgrade="
+ autoUpgrade;
final String contentType =
String.format(
"multipart/form-data; boundary=%s",
Expand Down Expand Up @@ -280,8 +288,17 @@ public String deployFromArchetype(

@Override
@SneakyThrows
public void update(String application, MultiPartBodyPublisher multiPartBodyPublisher) {
final String path = tenantAppPath("/" + application);
public void update(
String application,
MultiPartBodyPublisher multiPartBodyPublisher,
boolean autoUpgrade,
boolean forceRestart) {
final String path =
tenantAppPath("/" + application)
+ "?auto-upgrade="
+ autoUpgrade
+ "&force-restart="
+ forceRestart;
final String contentType =
String.format(
"multipart/form-data; boundary=%s",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,16 @@ public interface Applications {
String deploy(String application, MultiPartBodyPublisher multiPartBodyPublisher);

String deploy(
String application, MultiPartBodyPublisher multiPartBodyPublisher, boolean dryRun);
String application,
MultiPartBodyPublisher multiPartBodyPublisher,
boolean dryRun,
boolean autoUpgrade);

void update(String application, MultiPartBodyPublisher multiPartBodyPublisher);
void update(
String application,
MultiPartBodyPublisher multiPartBodyPublisher,
boolean autoUpgrade,
boolean forceRestart);

void delete(String application, boolean force);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,9 @@ public TopicProducer getOrCreate(
throw new RuntimeException(ex);
}
}

@Override
public void close() {
cache.invalidateAll();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import ai.langstream.api.runner.topics.TopicProducer;
import java.util.function.Supplier;

public interface TopicProducerCache {
public interface TopicProducerCache extends AutoCloseable {
record Key(
String tenant,
String application,
Expand All @@ -27,4 +27,7 @@ record Key(
String configString) {}

TopicProducer getOrCreate(Key key, Supplier<TopicProducer> topicProducerSupplier);

@Override
void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
*/
package ai.langstream.apigateway.gateways;

import ai.langstream.api.runner.topics.TopicProducer;
import ai.langstream.apigateway.MetricsNames;
import ai.langstream.apigateway.config.TopicProperties;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.binder.cache.GuavaCacheMetrics;
import java.util.function.Supplier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

Expand All @@ -34,7 +36,16 @@ public TopicProducerCache topicProducerCache(TopicProperties topicProperties) {
Metrics.globalRegistry, cache.getCache(), MetricsNames.TOPIC_PRODUCER_CACHE);
return cache;
} else {
return (key, topicProducerSupplier) -> topicProducerSupplier.get();
return new TopicProducerCache() {
@Override
public TopicProducer getOrCreate(
Key key, Supplier<TopicProducer> topicProducerSupplier) {
return topicProducerSupplier.get();
}

@Override
public void close() {}
};
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,7 @@ private CompletableFuture<ResponseEntity> handleServiceWithTopics(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry(),
clusterRuntimeRegistry);
completableFuture.thenRunAsync(
() -> {
if (consumeGateway != null) {
consumeGateway.close();
}
},
consumeThreadPool);
completableFuture.thenRunAsync(consumeGateway::close, consumeThreadPool);

final Gateway.ServiceOptions serviceOptions = authContext.gateway().getServiceOptions();
try {
Expand All @@ -297,7 +291,7 @@ record -> {
final AtomicBoolean stop = new AtomicBoolean(false);
consumeGateway.startReadingAsync(
consumeThreadPool,
() -> stop.get(),
stop::get,
record -> {
stop.set(true);
completableFuture.complete(ResponseEntity.ok(record));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ public ServletServerContainerFactoryBean createWebSocketContainer() {

@PreDestroy
public void onDestroy() {
consumeThreadPool.shutdown();
log.info("Shutting down WebSocket");
consumeThreadPool.shutdownNow();
clusterRuntimeRegistry.close();
topicProducerCache.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import ai.langstream.api.runner.topics.TopicConsumer;
import ai.langstream.api.runner.topics.TopicProducer;
import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import ai.langstream.api.runtime.DeployContext;
import ai.langstream.api.runtime.PluginsRegistry;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.api.ConsumePushMessage;
Expand Down Expand Up @@ -569,8 +570,7 @@ void testService() throws Exception {
url,
"{\"key\": \"my-key2\", \"value\": \"my-value\", \"headers\": {\"header1\":\"value1\"}}"));

// sorry but kafka can't keep up
final int numParallel = getStreamingCluster().type().equals("kafka") ? 5 : 30;
final int numParallel = 10;

List<CompletableFuture<Void>> futures1 = new ArrayList<>();
for (int i = 0; i < numParallel; i++) {
Expand All @@ -588,7 +588,7 @@ void testService() throws Exception {
futures1.add(future);
}
CompletableFuture.allOf(futures1.toArray(new CompletableFuture[] {}))
.get(2, TimeUnit.MINUTES);
.get(3, TimeUnit.MINUTES);
}

private void startTopicExchange(String logicalFromTopic, String logicalToTopic)
Expand Down Expand Up @@ -678,6 +678,7 @@ private void prepareTopicsForTest(String... topic) throws Exception {
.pluginsRegistry(new PluginsRegistry())
.registry(new ClusterRuntimeRegistry())
.topicConnectionsRuntimeRegistry(topicConnectionsRuntimeRegistry)
.deployContext(DeployContext.NO_DEPLOY_CONTEXT)
.build();
final StreamingCluster streamingCluster = getStreamingCluster();
topicConnectionsRuntimeRegistry
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import ai.langstream.api.model.StreamingCluster;
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runtime.ClusterRuntimeRegistry;
import ai.langstream.api.runtime.DeployContext;
import ai.langstream.api.runtime.PluginsRegistry;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.api.ConsumePushMessage;
Expand Down Expand Up @@ -267,6 +268,7 @@ private void prepareTopicsForTest(String... topic) throws Exception {
.pluginsRegistry(new PluginsRegistry())
.registry(new ClusterRuntimeRegistry())
.topicConnectionsRuntimeRegistry(topicConnectionsRuntimeRegistry)
.deployContext(DeployContext.NO_DEPLOY_CONTEXT)
.build();
final StreamingCluster streamingCluster = getStreamingCluster();
topicConnectionsRuntimeRegistry
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,62 @@
*/
package ai.langstream.api.runtime;

import ai.langstream.api.runner.assets.AssetManagerRegistry;
import ai.langstream.api.webservice.application.ApplicationCodeInfo;

public interface DeployContext extends AutoCloseable {

default ApplicationCodeInfo getApplicationCodeInfo(
String tenant, String applicationId, String codeArchiveId) {
throw new UnsupportedOperationException();
}
DeployContext NO_DEPLOY_CONTEXT = new NoOpDeployContext();

class NoOpDeployContext implements DeployContext {

@Override
public ApplicationCodeInfo getApplicationCodeInfo(
String tenant, String applicationId, String codeArchiveId) {
return null;
}

@Override
public boolean isAutoUpgradeRuntimeImage() {
return false;
}

@Override
public boolean isAutoUpgradeRuntimeImagePullPolicy() {
return false;
}

@Override
public boolean isAutoUpgradeAgentResources() {
return false;
}

@Override
public boolean isAutoUpgradeAgentPodTemplate() {
return false;
}

default AssetManagerRegistry getAssetManagerRegistry() {
throw new UnsupportedOperationException();
@Override
public long getApplicationSeed() {
return -1L;
}

@Override
public void close() {}
}

ApplicationCodeInfo getApplicationCodeInfo(
String tenant, String applicationId, String codeArchiveId);

boolean isAutoUpgradeRuntimeImage();

boolean isAutoUpgradeRuntimeImagePullPolicy();

boolean isAutoUpgradeAgentResources();

boolean isAutoUpgradeAgentPodTemplate();

long getApplicationSeed();

@Override
default void close() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ void put(
String applicationId,
Application applicationInstance,
String codeArchiveReference,
ExecutionPlan executionPlan);
ExecutionPlan executionPlan,
boolean autoUpgrade,
boolean forceRestart);

StoredApplication get(String tenant, String applicationId, boolean queryPods);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ public static class DeployApplicationCmd extends AbstractDeployApplicationCmd {
"Output format for dry-run mode. Formats are: yaml, json. Default value is yaml.")
private Formats format = Formats.yaml;

@CommandLine.Option(
names = {"--auto-upgrade"},
description =
"Whether to make the executors to automatically upgrades the environment (image, resources mapping etc.) when restarted")
private boolean autoUpgrade;

@Override
String applicationId() {
return name;
Expand Down Expand Up @@ -94,6 +100,16 @@ boolean isDryRun() {
return dryRun;
}

@Override
boolean isAutoUpgrade() {
return autoUpgrade;
}

@Override
boolean isForceRestart() {
return false;
}

@Override
Formats format() {
ensureFormatIn(format, Formats.json, Formats.yaml);
Expand Down Expand Up @@ -122,6 +138,17 @@ public static class UpdateApplicationCmd extends AbstractDeployApplicationCmd {
description = "Secrets file path")
private String secretFilePath;

@CommandLine.Option(
names = {"--auto-upgrade"},
description =
"Whether to make the executors to automatically upgrades the environment (image, resources mapping etc.) when restarted")
private boolean autoUpgrade;

@CommandLine.Option(
names = {"--force-restart"},
description = "Whether to make force restart all the executors of the application")
private boolean forceRestart;

@Override
String applicationId() {
return name;
Expand Down Expand Up @@ -152,6 +179,16 @@ boolean isDryRun() {
return false;
}

@Override
boolean isAutoUpgrade() {
return autoUpgrade;
}

@Override
boolean isForceRestart() {
return forceRestart;
}

@Override
Formats format() {
return null;
Expand All @@ -170,6 +207,10 @@ Formats format() {

abstract boolean isDryRun();

abstract boolean isAutoUpgrade();

abstract boolean isForceRestart();

abstract Formats format();

@Override
Expand Down Expand Up @@ -229,7 +270,9 @@ public void run() {

if (isUpdate()) {
log(String.format("updating application: %s (%d KB)", applicationId, size / 1024));
getClient().applications().update(applicationId, bodyPublisher);
getClient()
.applications()
.update(applicationId, bodyPublisher, isAutoUpgrade(), isForceRestart());
log(String.format("application %s updated", applicationId));
} else {
final boolean dryRun = isDryRun();
Expand All @@ -242,7 +285,9 @@ public void run() {
log(String.format("deploying application: %s (%d KB)", applicationId, size / 1024));
}
final String response =
getClient().applications().deploy(applicationId, bodyPublisher, dryRun);
getClient()
.applications()
.deploy(applicationId, bodyPublisher, dryRun, isAutoUpgrade());
if (dryRun) {
final Formats format = format();
print(format == Formats.raw ? Formats.yaml : format, response, null, null);
Expand Down
Loading
Loading