Skip to content

Commit

Permalink
API gateway: cache topic producer for the same gateway
Browse files Browse the repository at this point in the history
  • Loading branch information
nicoloboschi committed Oct 27, 2023
1 parent 7897c34 commit 5f38cd8
Show file tree
Hide file tree
Showing 15 changed files with 314 additions and 31 deletions.
4 changes: 4 additions & 0 deletions langstream-api-gateway/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@
<artifactId>langstream-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>

<dependency>
<groupId>ai.langstream</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties;
import ai.langstream.apigateway.config.StorageProperties;
import ai.langstream.apigateway.config.TopicProperties;
import ai.langstream.apigateway.runner.CodeConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -29,7 +30,8 @@
@EnableConfigurationProperties({
StorageProperties.class,
GatewayTestAuthenticationProperties.class,
CodeConfiguration.class
CodeConfiguration.class,
TopicProperties.class
})
public class LangStreamApiGateway {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.apigateway.config;

import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.HashMap;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.boot.context.properties.ConfigurationProperties;

/**
 * Topic-related settings of the API gateway, bound from configuration keys under the
 * {@code application.topics} prefix.
 *
 * <p>Neither field has an explicit initializer, so an unset key leaves the Java default
 * ({@code false} / {@code 0}).
 */
@ConfigurationProperties(prefix = "application.topics")
@Data
@NoArgsConstructor
@AllArgsConstructor
public class TopicProperties {

/**
 * Whether topic producers are shared through a cache. {@code TopicProducerCacheFactory} reads
 * this flag to choose between an {@code LRUTopicProducerCache} and a pass-through
 * implementation that invokes the supplier on every request.
 */
@JsonProperty("producers-cache-enabled")
private boolean producersCacheEnabled;
/**
 * Size handed to the {@code LRUTopicProducerCache} constructor, which uses it as the maximum
 * number of cached entries. Only read when {@link #producersCacheEnabled} is {@code true}.
 */
@JsonProperty("producers-cache-size")
private int producersCacheSize;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package ai.langstream.apigateway.gateways;

import ai.langstream.api.runner.code.Record;
import ai.langstream.api.runner.topics.TopicProducer;
import ai.langstream.api.runtime.Topic;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.RemovalNotification;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

/**
 * {@link TopicProducerCache} backed by a size-bounded, time-expiring Guava {@link Cache}.
 *
 * <p>Producers are handed out wrapped in a reference-counted {@link SharedTopicProducer}. The
 * underlying producer is closed only once the entry has been removed from the cache AND every
 * caller that obtained it has called {@link TopicProducer#close()}.
 */
public class LRUTopicProducerCache implements TopicProducerCache {

    /**
     * Reference-counted view over a real {@link TopicProducer}.
     *
     * <p>{@link #close()} releases one reference instead of closing the delegate directly; the
     * delegate is closed when the last reference is released after the cache dropped the entry,
     * or by {@link #removedFromCache()} when nobody holds a reference at eviction time.
     */
    private static class SharedTopicProducer implements TopicProducer {
        private final TopicProducer producer;
        // Both fields are only mutated while holding this instance's monitor.
        private volatile int referenceCount;
        private volatile boolean cached = true;

        public SharedTopicProducer(TopicProducer producer) {
            this.producer = producer;
        }

        /** Registers one more holder of this producer. */
        public void acquire() {
            synchronized (this) {
                referenceCount++;
            }
        }

        /**
         * Called by the cache removal listener. Marks the producer as no longer cached and
         * closes the delegate right away when no holder is left.
         */
        public void removedFromCache() {
            synchronized (this) {
                cached = false;
                if (referenceCount == 0) {
                    producer.close();
                }
            }
        }

        @Override
        public void start() {
            producer.start();
        }

        /**
         * Releases one reference. The delegate is closed only when this was the last reference
         * and the entry has already left the cache; otherwise it stays open for reuse.
         */
        @Override
        public void close() {
            synchronized (this) {
                referenceCount--;
                if (referenceCount == 0 && !cached) {
                    producer.close();
                }
            }
        }

        @Override
        public CompletableFuture<?> write(Record record) {
            return producer.write(record);
        }

        @Override
        public Object getNativeProducer() {
            return producer.getNativeProducer();
        }

        @Override
        public Object getInfo() {
            return producer.getInfo();
        }

        @Override
        public long getTotalIn() {
            // Delegate like every other accessor; a constant 0 would hide the real counter.
            return producer.getTotalIn();
        }
    }

    final Cache<Key, SharedTopicProducer> cache;

    /**
     * Creates the cache.
     *
     * @param size maximum number of entries, passed to {@link CacheBuilder#maximumSize(long)}
     */
    public LRUTopicProducerCache(int size) {
        this.cache =
                CacheBuilder.newBuilder()
                        .maximumSize(size)
                        // Both a write-based and an access-based expiration of 10 minutes are
                        // configured on the builder.
                        .expireAfterWrite(10, TimeUnit.MINUTES)
                        .expireAfterAccess(10, TimeUnit.MINUTES)
                        .removalListener(
                                (RemovalNotification<Key, SharedTopicProducer> notification) -> {
                                    SharedTopicProducer resource = notification.getValue();
                                    resource.removedFromCache();
                                })
                        .build();
    }

    /**
     * Returns the shared producer for {@code key}, creating it through
     * {@code topicProducerSupplier} when no entry is cached. The returned producer holds one
     * reference on behalf of the caller, who must release it with {@link TopicProducer#close()}.
     *
     * @param key identifies the gateway the producer belongs to
     * @param topicProducerSupplier creates the real producer on a cache miss
     * @return a reference-counted wrapper around the cached producer
     * @throws RuntimeException wrapping the {@link ExecutionException} raised by the cache loader
     */
    @Override
    public synchronized TopicProducer getOrCreate(
            TopicProducerCache.Key key, Supplier<TopicProducer> topicProducerSupplier) {
        // Set when this call loaded the entry. In that case the reference is taken inside the
        // loader, i.e. before the cache gets a chance to evict the new entry and run the removal
        // listener (which closes producers that have no holder). Acquiring only after
        // cache.get() returns would hand out an already closed producer whenever the fresh
        // entry is evicted immediately, e.g. with a maximum size of 0.
        final boolean[] acquiredByLoader = {false};
        try {
            final SharedTopicProducer sharedTopicProducer =
                    cache.get(
                            key,
                            () -> {
                                final SharedTopicProducer created =
                                        new SharedTopicProducer(topicProducerSupplier.get());
                                created.acquire();
                                acquiredByLoader[0] = true;
                                return created;
                            });
            if (!acquiredByLoader[0]) {
                sharedTopicProducer.acquire();
            }
            return sharedTopicProducer;
        } catch (ExecutionException ex) {
            throw new RuntimeException(ex);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.stream.Collectors;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.map.LRUMap;

@Slf4j
public class ProduceGateway implements Closeable {
Expand Down Expand Up @@ -77,12 +78,14 @@ public void validateOptions(Map<String, String> options) {
}

private final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry;
private final TopicProducerCache topicProducerCache;
private TopicProducer producer;
private List<Header> commonHeaders;
private String logRef;

public ProduceGateway(TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry) {
public ProduceGateway(TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry, TopicProducerCache topicProducerCache) {
this.topicConnectionsRuntimeRegistry = topicConnectionsRuntimeRegistry;
this.topicProducerCache = topicProducerCache;
}

public void start(
Expand All @@ -96,33 +99,30 @@ public void start(
requestContext.applicationId(),
requestContext.gateway().getId());
this.commonHeaders = commonHeaders == null ? List.of() : commonHeaders;

setupProducer(
topic,
requestContext.application().getInstance().streamingCluster(),
requestContext.tenant(),
requestContext.applicationId(),
requestContext.gateway().getId());
final TopicProducerCache.Key key =
new TopicProducerCache.Key(requestContext.tenant(), requestContext.applicationId(),
requestContext.gateway().getId());
producer = topicProducerCache.getOrCreate(key, () -> setupProducer(topic, requestContext.application().getInstance()
.streamingCluster()));
}

protected void setupProducer(
protected TopicProducer setupProducer(
String topic,
StreamingCluster streamingCluster,
final String tenant,
final String applicationId,
final String gatewayId) {
StreamingCluster streamingCluster) {

final TopicConnectionsRuntime topicConnectionsRuntime =
topicConnectionsRuntimeRegistry
.getTopicConnectionsRuntime(streamingCluster)
.asTopicConnectionsRuntime();

topicConnectionsRuntime.init(streamingCluster);

producer =
topicConnectionsRuntime.createProducer(
null, streamingCluster, Map.of("topic", topic));
producer.start();
final TopicProducer topicProducer = topicConnectionsRuntime.createProducer(
null, streamingCluster, Map.of("topic", topic));
topicProducer.start();
log.debug("[{}] Started producer on topic {}", logRef, topic);
return topicProducer;

}

public void produceMessage(String payload) throws ProduceException {
Expand Down Expand Up @@ -185,6 +185,7 @@ public void close() {
} catch (Exception e) {
log.debug("[{}] Error closing producer: {}", logRef, e.getMessage(), e);
}
producer = null;
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package ai.langstream.apigateway.gateways;

import ai.langstream.api.runner.topics.TopicProducer;
import java.util.function.Supplier;

/**
 * Strategy for obtaining the {@link TopicProducer} used by a gateway, allowing producers to be
 * shared between requests instead of being created each time.
 */
public interface TopicProducerCache {
/** Cache key: a producer is identified by tenant, application and gateway id. */
record Key(String tenant, String application, String gatewayId){}

/**
 * Returns a producer for {@code key}.
 *
 * @param key identifies the gateway the producer is requested for
 * @param topicProducerSupplier creates a new producer; whether it is invoked on every call or
 *     only when nothing is cached for {@code key} is up to the implementation
 * @return the producer to use for this request
 */
TopicProducer getOrCreate(Key key, Supplier<TopicProducer> topicProducerSupplier);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package ai.langstream.apigateway.gateways;

import ai.langstream.api.runner.topics.TopicProducer;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties;
import ai.langstream.apigateway.config.TopicProperties;
import java.util.function.Supplier;
import org.apache.commons.collections4.map.LRUMap;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/** Spring configuration that exposes the {@link TopicProducerCache} bean used by the gateways. */
@Configuration
public class TopicProducerCacheFactory {

    /**
     * Builds the producer cache according to {@link TopicProperties}.
     *
     * @param topicProperties topic settings carrying the cache flag and size
     * @return a pass-through cache that always calls the supplier when caching is disabled,
     *     otherwise an {@link LRUTopicProducerCache} sized from the properties
     */
    @Bean
    public TopicProducerCache topicProducerCache(TopicProperties topicProperties) {
        if (!topicProperties.isProducersCacheEnabled()) {
            // Caching disabled: every request gets a brand new producer.
            return (cacheKey, producerFactory) -> producerFactory.get();
        }
        final int maxEntries = topicProperties.getProducersCacheSize();
        return new LRUTopicProducerCache(maxEntries);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
import ai.langstream.api.gateway.GatewayRequestContext;
import ai.langstream.api.model.Gateway;
import ai.langstream.api.runner.code.Header;
import ai.langstream.apigateway.config.TopicProperties;
import ai.langstream.apigateway.gateways.GatewayRequestHandler;
import ai.langstream.apigateway.gateways.ProduceGateway;
import ai.langstream.apigateway.gateways.TopicProducerCache;
import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import ai.langstream.apigateway.websocket.api.ProduceRequest;
Expand Down Expand Up @@ -48,6 +50,7 @@
public class GatewayResource {

private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider;
private final TopicProducerCache topicProducerCache;
private final GatewayRequestHandler gatewayRequestHandler;

@PostMapping(
Expand Down Expand Up @@ -86,7 +89,8 @@ ProduceResponse produce(
final ProduceGateway produceGateway =
new ProduceGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry());
.getTopicConnectionsRuntimeRegistry(),
topicProducerCache);
try {
final List<Header> commonHeaders =
ProduceGateway.getProducerCommonHeaders(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.gateways.GatewayRequestHandler;
import ai.langstream.apigateway.gateways.TopicProducerCache;
import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean;
import ai.langstream.apigateway.websocket.handlers.ChatHandler;
import ai.langstream.apigateway.websocket.handlers.ConsumeHandler;
Expand Down Expand Up @@ -48,6 +49,7 @@ public class WebSocketConfig implements WebSocketConfigurer {
private final ApplicationStore applicationStore;
private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider;
private final GatewayRequestHandler gatewayRequestHandler;
private final TopicProducerCache topicProducerCache;
private final ExecutorService consumeThreadPool = Executors.newCachedThreadPool();

@Override
Expand All @@ -58,16 +60,19 @@ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
new ConsumeHandler(
applicationStore,
consumeThreadPool,
topicConnectionsRuntimeRegistry),
topicConnectionsRuntimeRegistry,
topicProducerCache),
CONSUME_PATH)
.addHandler(
new ProduceHandler(applicationStore, topicConnectionsRuntimeRegistry),
new ProduceHandler(applicationStore, topicConnectionsRuntimeRegistry,
topicProducerCache),
PRODUCE_PATH)
.addHandler(
new ChatHandler(
applicationStore,
consumeThreadPool,
topicConnectionsRuntimeRegistry),
topicConnectionsRuntimeRegistry,
topicProducerCache),
CHAT_PATH)
.setAllowedOrigins("*")
.addInterceptors(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import ai.langstream.apigateway.gateways.ConsumeGateway;
import ai.langstream.apigateway.gateways.GatewayRequestHandler;
import ai.langstream.apigateway.gateways.ProduceGateway;
import ai.langstream.apigateway.gateways.TopicProducerCache;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import ai.langstream.apigateway.websocket.api.ProduceResponse;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -53,12 +54,15 @@ public abstract class AbstractHandler extends TextWebSocketHandler {
protected static final String ATTRIBUTE_CONSUME_GATEWAY = "__consume_gateway";
protected final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry;
protected final ApplicationStore applicationStore;
private final TopicProducerCache topicProducerCache;

public AbstractHandler(
ApplicationStore applicationStore,
TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry) {
TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry,
TopicProducerCache topicProducerCache) {
this.topicConnectionsRuntimeRegistry = topicConnectionsRuntimeRegistry;
this.applicationStore = applicationStore;
this.topicProducerCache = topicProducerCache;
}

public abstract String path();
Expand Down Expand Up @@ -295,7 +299,7 @@ protected void setupReader(
protected void setupProducer(
String topic, List<Header> commonHeaders, AuthenticatedGatewayRequestContext context)
throws Exception {
final ProduceGateway produceGateway = new ProduceGateway(topicConnectionsRuntimeRegistry);
final ProduceGateway produceGateway = new ProduceGateway(topicConnectionsRuntimeRegistry, topicProducerCache);

try {
produceGateway.start(topic, commonHeaders, context);
Expand Down
Loading

0 comments on commit 5f38cd8

Please sign in to comment.