From 1febe44b2b3099636c5da2ee15ea4f27709b13de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Boschi?= Date: Mon, 30 Oct 2023 12:18:41 +0100 Subject: [PATCH] New gateway 'service' - part 1 (#663) --- .../api/ConsumePushMessage.java | 2 +- .../{websocket => }/api/ProduceRequest.java | 2 +- .../{websocket => }/api/ProduceResponse.java | 2 +- .../apigateway/gateways/ConsumeGateway.java | 46 ++++++- .../apigateway/gateways/ProduceGateway.java | 33 +++-- .../apigateway/http/GatewayResource.java | 114 ++++++++++++++++-- .../websocket/handlers/AbstractHandler.java | 41 +------ .../websocket/handlers/ChatHandler.java | 29 ++--- .../websocket/handlers/ConsumeHandler.java | 3 +- .../apigateway/http/GatewayResourceTest.java | 58 +++++++++ .../handlers/ProduceConsumeHandlerTest.java | 6 +- .../java/ai/langstream/api/model/Gateway.java | 25 +++- .../ApplicationPlaceholderResolver.java | 1 + .../langstream/impl/parser/ModelBuilder.java | 41 +++++++ .../impl/parser/ModelBuilderTest.java | 87 +++++++++++-- 15 files changed, 384 insertions(+), 106 deletions(-) rename langstream-api-gateway/src/main/java/ai/langstream/apigateway/{websocket => }/api/ConsumePushMessage.java (94%) rename langstream-api-gateway/src/main/java/ai/langstream/apigateway/{websocket => }/api/ProduceRequest.java (93%) rename langstream-api-gateway/src/main/java/ai/langstream/apigateway/{websocket => }/api/ProduceResponse.java (94%) diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/api/ConsumePushMessage.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/api/ConsumePushMessage.java similarity index 94% rename from langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/api/ConsumePushMessage.java rename to langstream-api-gateway/src/main/java/ai/langstream/apigateway/api/ConsumePushMessage.java index a34d54b89..1d207e4da 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/api/ConsumePushMessage.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/api/ConsumePushMessage.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package ai.langstream.apigateway.websocket.api; +package ai.langstream.apigateway.api; import java.util.Map; diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/api/ProduceRequest.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/api/ProduceRequest.java similarity index 93% rename from langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/api/ProduceRequest.java rename to langstream-api-gateway/src/main/java/ai/langstream/apigateway/api/ProduceRequest.java index fb52b6ef1..7238f5047 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/api/ProduceRequest.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/api/ProduceRequest.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package ai.langstream.apigateway.websocket.api; +package ai.langstream.apigateway.api; import java.util.Map; diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/api/ProduceResponse.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/api/ProduceResponse.java similarity index 94% rename from langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/api/ProduceResponse.java rename to langstream-api-gateway/src/main/java/ai/langstream/apigateway/api/ProduceResponse.java index 67e6b65fe..ea8af13e1 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/api/ProduceResponse.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/api/ProduceResponse.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package ai.langstream.apigateway.websocket.api; +package ai.langstream.apigateway.api; public record ProduceResponse(Status status, String reason) { public static ProduceResponse OK = new ProduceResponse(Status.OK, null); diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java index 827bdc489..bbc17b20d 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java @@ -24,11 +24,11 @@ import ai.langstream.api.runner.topics.TopicOffsetPosition; import ai.langstream.api.runner.topics.TopicReadResult; import ai.langstream.api.runner.topics.TopicReader; +import ai.langstream.apigateway.api.ConsumePushMessage; +import ai.langstream.apigateway.api.ProduceResponse; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; -import ai.langstream.apigateway.websocket.api.ConsumePushMessage; -import ai.langstream.apigateway.websocket.api.ProduceResponse; import com.fasterxml.jackson.databind.ObjectMapper; -import java.io.Closeable; +import java.util.ArrayList; import java.util.Base64; import java.util.Collection; import java.util.HashMap; @@ -46,7 +46,7 @@ import lombok.extern.slf4j.Slf4j; @Slf4j -public class ConsumeGateway implements Closeable { +public class ConsumeGateway implements AutoCloseable { protected static final ObjectMapper mapper = new ObjectMapper(); @@ -237,4 +237,42 @@ public void close() { closeReader(); } } + + public static List> createMessageFilters( + List headersFilters, + Map passedParameters, + Map principalValues) { + List> filters = new ArrayList<>(); + if (headersFilters == null) { + return filters; + } + for (Gateway.KeyValueComparison comparison : headersFilters) { + if (comparison.key() == null) { + throw new IllegalArgumentException("Key cannot be null"); + } + filters.add( + record -> { + final Header header = record.getHeader(comparison.key()); + if (header == null) { + return false; + } + final String expectedValue = header.valueAsString(); + if (expectedValue == null) { + return false; + } + String value = comparison.value(); + if (value == null && comparison.valueFromParameters() != null) { + value = passedParameters.get(comparison.valueFromParameters()); + } + if (value == null && comparison.valueFromAuthentication() != null) { + value = principalValues.get(comparison.valueFromAuthentication()); + } + if (value == null) { + return false; + } + return expectedValue.equals(value); + }); + } + return filters; + } } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java index 0f6871425..5f741b8d4 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ProduceGateway.java @@ -22,12 +22,11 @@ import ai.langstream.api.runner.topics.TopicConnectionsRuntime; import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; import ai.langstream.api.runner.topics.TopicProducer; +import ai.langstream.apigateway.api.ProduceRequest; +import ai.langstream.apigateway.api.ProduceResponse; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; -import ai.langstream.apigateway.websocket.api.ProduceRequest; -import ai.langstream.apigateway.websocket.api.ProduceResponse; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; -import java.io.Closeable; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -38,7 +37,7 @@ import lombok.extern.slf4j.Slf4j; @Slf4j -public class ProduceGateway implements Closeable { +public class ProduceGateway implements AutoCloseable { protected static final ObjectMapper mapper = new ObjectMapper(); @@ -190,11 +189,29 @@ public void close() { public static List
getProducerCommonHeaders( Gateway.ProduceOptions produceOptions, AuthenticatedGatewayRequestContext context) { - if (produceOptions != null) { - return getProducerCommonHeaders( - produceOptions.headers(), context.userParameters(), context.principalValues()); + if (produceOptions == null) { + return null; } - return null; + return getProducerCommonHeaders( + produceOptions.headers(), context.userParameters(), context.principalValues()); + } + + public static List
getProducerCommonHeaders( + Gateway.ChatOptions chatOptions, AuthenticatedGatewayRequestContext context) { + if (chatOptions == null) { + return null; + } + return getProducerCommonHeaders( + chatOptions.getHeaders(), context.userParameters(), context.principalValues()); + } + + public static List
getProducerCommonHeaders( + Gateway.ServiceOptions serviceOptions, AuthenticatedGatewayRequestContext context) { + if (serviceOptions == null) { + return null; + } + return getProducerCommonHeaders( + serviceOptions.getHeaders(), context.userParameters(), context.principalValues()); } public static List
getProducerCommonHeaders( diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java index d8cca19df..874c32c31 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java @@ -18,16 +18,24 @@ import ai.langstream.api.gateway.GatewayRequestContext; import ai.langstream.api.model.Gateway; import ai.langstream.api.runner.code.Header; +import ai.langstream.api.runner.code.Record; +import ai.langstream.apigateway.api.ProduceRequest; +import ai.langstream.apigateway.api.ProduceResponse; +import ai.langstream.apigateway.gateways.ConsumeGateway; import ai.langstream.apigateway.gateways.GatewayRequestHandler; import ai.langstream.apigateway.gateways.ProduceGateway; import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; -import ai.langstream.apigateway.websocket.api.ProduceRequest; -import ai.langstream.apigateway.websocket.api.ProduceResponse; +import jakarta.servlet.http.HttpServletResponse; import jakarta.validation.constraints.NotBlank; +import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; import java.util.stream.Collectors; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -49,6 +57,7 @@ public class GatewayResource { private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider; private final GatewayRequestHandler gatewayRequestHandler; + private final ExecutorService consumeThreadPool = Executors.newCachedThreadPool(); @PostMapping( value = "/produce/{tenant}/{application}/{gateway}", @@ -61,12 +70,8 @@ ProduceResponse produce( @RequestBody ProduceRequest produceRequest) throws ProduceGateway.ProduceException { - final Map queryString = - request.getParameterMap().keySet().stream() - .collect(Collectors.toMap(k -> k, k -> request.getParameter(k))); - final Map headers = new HashMap<>(); - request.getHeaderNames() - .forEachRemaining(name -> headers.put(name, request.getHeader(name))); + final Map queryString = computeQueryString(request); + final Map headers = computeHeaders(request); final GatewayRequestContext context = gatewayRequestHandler.validateRequest( tenant, @@ -83,19 +88,102 @@ ProduceResponse produce( throw new ResponseStatusException(HttpStatus.UNAUTHORIZED, e.getMessage()); } - final ProduceGateway produceGateway = + try (final ProduceGateway produceGateway = new ProduceGateway( topicConnectionsRuntimeRegistryProvider - .getTopicConnectionsRuntimeRegistry()); - try { + .getTopicConnectionsRuntimeRegistry()); ) { final List
commonHeaders = ProduceGateway.getProducerCommonHeaders( context.gateway().getProduceOptions(), authContext); produceGateway.start(context.gateway().getTopic(), commonHeaders, authContext); produceGateway.produceMessage(produceRequest); return ProduceResponse.OK; - } finally { - produceGateway.close(); } } + + private Map computeHeaders(WebRequest request) { + final Map headers = new HashMap<>(); + request.getHeaderNames() + .forEachRemaining(name -> headers.put(name, request.getHeader(name))); + return headers; + } + + @PostMapping( + value = "/service/{tenant}/{application}/{gateway}", + consumes = MediaType.APPLICATION_JSON_VALUE) + void service( + WebRequest request, + HttpServletResponse response, + @NotBlank @PathVariable("tenant") String tenant, + @NotBlank @PathVariable("application") String application, + @NotBlank @PathVariable("gateway") String gateway, + @RequestBody ProduceRequest produceRequest) + throws ProduceGateway.ProduceException { + + final Map queryString = computeQueryString(request); + final Map headers = computeHeaders(request); + final GatewayRequestContext context = + gatewayRequestHandler.validateRequest( + tenant, + application, + gateway, + Gateway.GatewayType.service, + queryString, + headers, + new ProduceGateway.ProduceGatewayRequestValidator()); + final AuthenticatedGatewayRequestContext authContext; + try { + authContext = gatewayRequestHandler.authenticate(context); + } catch (GatewayRequestHandler.AuthFailedException e) { + throw new ResponseStatusException(HttpStatus.UNAUTHORIZED, e.getMessage()); + } + + try (final ConsumeGateway consumeGateway = + new ConsumeGateway( + topicConnectionsRuntimeRegistryProvider + .getTopicConnectionsRuntimeRegistry()); + final ProduceGateway produceGateway = + new ProduceGateway( + topicConnectionsRuntimeRegistryProvider + .getTopicConnectionsRuntimeRegistry()); ) { + + final Gateway.ServiceOptions serviceOptions = authContext.gateway().getServiceOptions(); + try { + final List> messageFilters = + ConsumeGateway.createMessageFilters( + serviceOptions.getHeaders(), + authContext.userParameters(), + authContext.principalValues()); + consumeGateway.setup(serviceOptions.getInputTopic(), messageFilters, authContext); + final AtomicBoolean stop = new AtomicBoolean(false); + consumeGateway.startReadingAsync( + consumeThreadPool, + () -> stop.get(), + record -> { + stop.set(true); + try { + response.getWriter().print(record); + response.getWriter().flush(); + response.getWriter().close(); + } catch (IOException ioException) { + throw new RuntimeException(ioException); + } + }); + } catch (Exception ex) { + log.error("Error while setting up consume gateway", ex); + throw new RuntimeException(ex); + } + final List
commonHeaders = + ProduceGateway.getProducerCommonHeaders(serviceOptions, authContext); + produceGateway.start(serviceOptions.getOutputTopic(), commonHeaders, authContext); + produceGateway.produceMessage(produceRequest); + } + } + + private Map computeQueryString(WebRequest request) { + final Map queryString = + request.getParameterMap().keySet().stream() + .collect(Collectors.toMap(k -> k, k -> request.getParameter(k))); + return queryString; + } } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java index 7f5615a14..773318334 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java @@ -27,14 +27,13 @@ import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; import ai.langstream.api.runner.topics.TopicProducer; import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.api.ProduceResponse; import ai.langstream.apigateway.gateways.ConsumeGateway; import ai.langstream.apigateway.gateways.GatewayRequestHandler; import ai.langstream.apigateway.gateways.ProduceGateway; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; -import ai.langstream.apigateway.websocket.api.ProduceResponse; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.Executor; @@ -238,44 +237,6 @@ protected void startReadingMessages(WebSocketSession webSocketSession, Executor }); } - protected static List> createMessageFilters( - List headersFilters, - Map passedParameters, - Map principalValues) { - List> filters = new ArrayList<>(); - if (headersFilters == null) { - return filters; - } - for (Gateway.KeyValueComparison comparison : headersFilters) { - if (comparison.key() == null) { - throw new IllegalArgumentException("Key cannot be null"); - } - filters.add( - record -> { - final Header header = record.getHeader(comparison.key()); - if (header == null) { - return false; - } - final String expectedValue = header.valueAsString(); - if (expectedValue == null) { - return false; - } - String value = comparison.value(); - if (value == null && comparison.valueFromParameters() != null) { - value = passedParameters.get(comparison.valueFromParameters()); - } - if (value == null && comparison.valueFromAuthentication() != null) { - value = principalValues.get(comparison.valueFromAuthentication()); - } - if (value == null) { - return false; - } - return expectedValue.equals(value); - }); - } - return filters; - } - protected void setupReader( String topic, List> filters, diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java index c224fc240..9d8285459 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ChatHandler.java @@ -22,6 +22,7 @@ import ai.langstream.api.runner.code.Record; import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.gateways.ConsumeGateway; import ai.langstream.apigateway.gateways.GatewayRequestHandler; import ai.langstream.apigateway.gateways.ProduceGateway; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; @@ -136,35 +137,21 @@ public void onBeforeHandshakeCompleted( } private void setupProducer(AuthenticatedGatewayRequestContext context) throws Exception { - final Gateway.ChatOptions chatOptions = context.gateway().getChatOptions(); - - List headerConfig = new ArrayList<>(); - final List gwHeaders = chatOptions.getHeaders(); - if (gwHeaders != null) { - for (Gateway.KeyValueComparison gwHeader : gwHeaders) { - headerConfig.add(gwHeader); - } - } final List
commonHeaders = ProduceGateway.getProducerCommonHeaders( - headerConfig, context.userParameters(), context.principalValues()); + context.gateway().getChatOptions(), context); - setupProducer(chatOptions.getQuestionsTopic(), commonHeaders, context); + setupProducer( + context.gateway().getChatOptions().getQuestionsTopic(), commonHeaders, context); } private void setupReader(AuthenticatedGatewayRequestContext context) throws Exception { final Gateway.ChatOptions chatOptions = context.gateway().getChatOptions(); - - List headerFilters = new ArrayList<>(); - final List gwHeaders = chatOptions.getHeaders(); - if (gwHeaders != null) { - for (Gateway.KeyValueComparison gwHeader : gwHeaders) { - headerFilters.add(gwHeader); - } - } final List> messageFilters = - createMessageFilters( - headerFilters, context.userParameters(), context.principalValues()); + ConsumeGateway.createMessageFilters( + chatOptions.getHeaders(), + context.userParameters(), + context.principalValues()); setupReader(chatOptions.getAnswersTopic(), messageFilters, context); } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java index efcdc77e9..eeeba648d 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ConsumeHandler.java @@ -21,6 +21,7 @@ import ai.langstream.api.runner.code.Record; import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.gateways.ConsumeGateway; import ai.langstream.apigateway.gateways.GatewayRequestHandler; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; import java.util.List; @@ -106,7 +107,7 @@ public void onBeforeHandshakeCompleted( final List> messageFilters; if (consumeOptions != null && consumeOptions.filters() != null) { messageFilters = - createMessageFilters( + ConsumeGateway.createMessageFilters( consumeOptions.filters().headers(), context.userParameters(), context.principalValues()); diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java index db2f5d2db..c9502d7c9 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java @@ -30,6 +30,7 @@ import ai.langstream.api.runtime.ClusterRuntimeRegistry; import ai.langstream.api.runtime.PluginsRegistry; import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.api.ConsumePushMessage; import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean; import ai.langstream.impl.deploy.ApplicationDeployer; @@ -198,6 +199,19 @@ void produceAndExpectOk(String url, String content) { {"status":"OK","reason":null}""", response.body()); } + @SneakyThrows + String produceAndGetBody(String url, String content) { + final HttpRequest request = + HttpRequest.newBuilder(URI.create(url)) + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(content)) + .build(); + final HttpResponse response = + CLIENT.send(request, HttpResponse.BodyHandlers.ofString()); + assertEquals(200, response.statusCode()); + return response.body(); + } + @SneakyThrows void produceAndExpectBadRequest(String url, String content, String errorMessage) { final HttpRequest request = @@ -388,6 +402,50 @@ void testTestCredentials() throws Exception { "{\"value\": \"my-value\"}"); } + @Test + void testService() throws Exception { + final String topic = genTopic(); + prepareTopicsForTest(topic); + testGateways = + new Gateways( + List.of( + Gateway.builder() + .id("svc") + .type(Gateway.GatewayType.service) + .serviceOptions( + new Gateway.ServiceOptions(topic, topic, List.of())) + .build())); + + final String url = + "http://localhost:%d/api/gateways/service/tenant1/application1/svc".formatted(port); + + assertMessageContent( + new MsgRecord("my-key", "my-value", Map.of()), + produceAndGetBody(url, "{\"key\": \"my-key\", \"value\": \"my-value\"}")); + assertMessageContent( + new MsgRecord("my-key2", "my-value", Map.of()), + produceAndGetBody(url, "{\"key\": \"my-key2\", \"value\": \"my-value\"}")); + assertMessageContent( + new MsgRecord("my-key2", "my-value", Map.of("header1", "value1")), + produceAndGetBody( + url, + "{\"key\": \"my-key2\", \"value\": \"my-value\", \"headers\": {\"header1\":\"value1\"}}")); + } + + private record MsgRecord(Object key, Object value, Map headers) {} + + @SneakyThrows + private void assertMessageContent(MsgRecord expected, String actual) { + ConsumePushMessage consume = MAPPER.readValue(actual, ConsumePushMessage.class); + final MsgRecord actualMsgRecord = + new MsgRecord( + consume.record().key(), + consume.record().value(), + consume.record().headers()); + + assertEquals(expected, actualMsgRecord); + } + protected abstract StreamingCluster getStreamingCluster(); private void prepareTopicsForTest(String... topic) throws Exception { diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java index 4c4b60784..bc37aad8c 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java @@ -37,11 +37,11 @@ import ai.langstream.api.runtime.ClusterRuntimeRegistry; import ai.langstream.api.runtime.PluginsRegistry; import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.api.ConsumePushMessage; +import ai.langstream.apigateway.api.ProduceRequest; +import ai.langstream.apigateway.api.ProduceResponse; import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean; -import ai.langstream.apigateway.websocket.api.ConsumePushMessage; -import ai.langstream.apigateway.websocket.api.ProduceRequest; -import ai.langstream.apigateway.websocket.api.ProduceResponse; import ai.langstream.impl.deploy.ApplicationDeployer; import ai.langstream.impl.parser.ModelBuilder; import com.fasterxml.jackson.core.JsonProcessingException; diff --git a/langstream-api/src/main/java/ai/langstream/api/model/Gateway.java b/langstream-api/src/main/java/ai/langstream/api/model/Gateway.java index 8ce97462d..d3cd09771 100644 --- a/langstream-api/src/main/java/ai/langstream/api/model/Gateway.java +++ b/langstream-api/src/main/java/ai/langstream/api/model/Gateway.java @@ -45,13 +45,17 @@ public final class Gateway { @JsonProperty("chat-options") private ChatOptions chatOptions; + @JsonProperty("service-options") + private ServiceOptions serviceOptions; + @JsonProperty("events-topic") private String eventsTopic; public enum GatewayType { produce, consume, - chat + chat, + service } @Data @@ -117,11 +121,6 @@ public static KeyValueComparison valueFromAuthentication( String key, String valueFromAuthentication) { return new KeyValueComparison(key, null, null, valueFromAuthentication); } - - private String getKeyWithDefaultValue() { - - throw new IllegalStateException(); - } } public record ProduceOptions(List headers) {} @@ -143,4 +142,18 @@ public static class ChatOptions { List headers; } + + @Data + @NoArgsConstructor + @AllArgsConstructor + public static class ServiceOptions { + + @JsonProperty("input-topic") + private String inputTopic; + + @JsonProperty("output-topic") + private String outputTopic; + + List headers; + } } diff --git a/langstream-core/src/main/java/ai/langstream/impl/common/ApplicationPlaceholderResolver.java b/langstream-core/src/main/java/ai/langstream/impl/common/ApplicationPlaceholderResolver.java index b7d705a60..81c0f100f 100644 --- a/langstream-core/src/main/java/ai/langstream/impl/common/ApplicationPlaceholderResolver.java +++ b/langstream-core/src/main/java/ai/langstream/impl/common/ApplicationPlaceholderResolver.java @@ -204,6 +204,7 @@ private static Gateways resolveGateways(Application instance, Map gateways = applicationInstance.getGateways().gateways(); - Assertions.assertEquals(1, gateways.size()); - final Gateway gateway = gateways.get(0); - assertEquals("gw", gateway.getId()); - assertEquals("t1", gateway.getTopic()); - assertEquals("google", gateway.getAuthentication().getProvider()); - assertTrue(gateway.getAuthentication().isAllowTestMode()); + Assertions.assertEquals(4, gateways.size()); + final Gateway gateway1 = gateways.get(0); + assertEquals("g1", gateway1.getId()); + assertNull(gateway1.getParameters()); + assertEquals("t1", gateway1.getTopic()); + assertEquals("google", gateway1.getAuthentication().getProvider()); + assertTrue(gateway1.getAuthentication().isAllowTestMode()); + assertEquals( + List.of(new Gateway.KeyValueComparison(null, null, "v1", null)), + gateway1.getProduceOptions().headers()); + + final Gateway gateway2 = gateways.get(1); + assertEquals("g2", gateway2.getId()); + assertEquals(List.of("p1"), gateway2.getParameters()); + assertEquals("t1", gateway2.getTopic()); + assertEquals("github", gateway2.getAuthentication().getProvider()); + assertTrue(gateway2.getAuthentication().isAllowTestMode()); + assertEquals( + List.of(new Gateway.KeyValueComparison(null, null, "v1", null)), + gateway2.getConsumeOptions().filters().headers()); + + final Gateway gateway3 = gateways.get(2); + assertEquals("g3", gateway3.getId()); + assertNull(gateway3.getParameters()); + assertNull(gateway3.getTopic()); + assertEquals("github", gateway3.getAuthentication().getProvider()); + assertTrue(gateway3.getAuthentication().isAllowTestMode()); + assertEquals("q", gateway3.getChatOptions().getQuestionsTopic()); + assertEquals("a", gateway3.getChatOptions().getAnswersTopic()); + assertEquals( + List.of(new Gateway.KeyValueComparison(null, null, "v1", null)), + gateway3.getChatOptions().getHeaders()); + + final Gateway gateway4 = gateways.get(3); + assertEquals("g4", gateway4.getId()); + assertNull(gateway4.getParameters()); + assertNull(gateway4.getTopic()); + assertEquals("github", gateway4.getAuthentication().getProvider()); + assertTrue(gateway4.getAuthentication().isAllowTestMode()); + assertEquals("q", gateway4.getServiceOptions().getInputTopic()); + assertEquals("a", gateway4.getServiceOptions().getOutputTopic()); + assertEquals( + List.of(new Gateway.KeyValueComparison(null, null, "v1", null)), + gateway4.getServiceOptions().getHeaders()); } @Test