diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java index e2d595d75..82ca0fcea 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java @@ -81,6 +81,8 @@ public void validateOptions(Map options) { private final TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry; + private volatile TopicConnectionsRuntime topicConnectionsRuntime; + private volatile TopicReader reader; private volatile boolean interrupted; private volatile String logRef; @@ -108,7 +110,7 @@ public void setup( final StreamingCluster streamingCluster = requestContext.application().getInstance().streamingCluster(); - final TopicConnectionsRuntime topicConnectionsRuntime = + topicConnectionsRuntime = topicConnectionsRuntimeRegistry .getTopicConnectionsRuntime(streamingCluster) .asTopicConnectionsRuntime(); @@ -220,6 +222,13 @@ private void closeReader() { log.warn("error closing reader", e); } } + if (topicConnectionsRuntime != null) { + try { + topicConnectionsRuntime.close(); + } catch (Exception e) { + log.warn("error closing runtime", e); + } + } } @Override diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java index 765e3f7ae..812fcfa59 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java @@ -18,6 +18,7 @@ import ai.langstream.api.model.Gateway; import ai.langstream.api.model.StreamingCluster; import ai.langstream.api.runner.code.Header; +import ai.langstream.api.runner.code.Record; import ai.langstream.api.runner.code.SimpleRecord; import ai.langstream.api.runner.topics.TopicConnectionsRuntime; import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; @@ -33,7 +34,10 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; + +import lombok.AllArgsConstructor; import lombok.Getter; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.tuple.Pair; @@ -126,6 +130,43 @@ public void start( topicProducerCache.getOrCreate(key, () -> setupProducer(topic, streamingCluster)); } + @AllArgsConstructor + static class TopicProducerAndRuntime implements TopicProducer { + private TopicProducer producer; + private TopicConnectionsRuntime runtime; + + @Override + public void start() { + producer.start(); + } + + @Override + public void close() { + producer.close(); + runtime.close(); + } + + @Override + public CompletableFuture write(Record record) { + return producer.write(record); + } + + @Override + public Object getNativeProducer() { + return producer.getNativeProducer(); + } + + @Override + public Object getInfo() { + return producer.getInfo(); + } + + @Override + public long getTotalIn() { + return producer.getTotalIn(); + } + } + protected TopicProducer setupProducer(String topic, StreamingCluster streamingCluster) { final TopicConnectionsRuntime topicConnectionsRuntime = @@ -140,7 +181,7 @@ protected TopicProducer setupProducer(String topic, StreamingCluster streamingCl null, streamingCluster, Map.of("topic", topic)); topicProducer.start(); log.debug("[{}] Started producer on topic {}", logRef, topic); - return topicProducer; + return new TopicProducerAndRuntime(topicProducer, topicConnectionsRuntime); } public void produceMessage(String payload) throws ProduceException { diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java index 0408ff726..cce226690 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java @@ -309,7 +309,6 @@ void testSimpleProduce() throws Exception { final String url = "http://localhost:%d/api/gateways/produce/tenant1/application1/produce" .formatted(port); - produceJsonAndExpectOk(url, "{\"key\": \"my-key\", \"value\": \"my-value\"}"); produceJsonAndExpectOk(url, "{\"key\": \"my-key\"}"); produceJsonAndExpectOk(url, "{\"key\": \"my-key\", \"headers\": {\"h1\": \"v1\"}}"); @@ -574,6 +573,19 @@ void testService() throws Exception { produceJsonAndGetBody( url, "{\"key\": \"my-key2\", \"value\": \"my-value\", \"headers\": {\"header1\":\"value1\"}}")); + + List> futures1 = new ArrayList<>(); + for (int i = 0; i < 30; i++) { + CompletableFuture future = CompletableFuture.runAsync(() -> { + for (int j = 0; j < 10; j++) { + assertMessageContent( + new MsgRecord("my-key", "my-value", Map.of()), + produceJsonAndGetBody(url, "{\"key\": \"my-key\", \"value\": \"my-value\"}")); + } + }); + futures1.add(future); + } + CompletableFuture.allOf(futures1.toArray(new CompletableFuture[]{})).get(2, TimeUnit.MINUTES); } private void startTopicExchange(String fromTopic, String toTopic) throws Exception {