Skip to content

Commit

Permalink
New gateway 'service'
Browse files Browse the repository at this point in the history
  • Loading branch information
nicoloboschi committed Oct 27, 2023
1 parent 5f38cd8 commit e4bfcc5
Show file tree
Hide file tree
Showing 14 changed files with 268 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.apigateway.websocket.api;
package ai.langstream.apigateway.api;

import java.util.Map;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.apigateway.websocket.api;
package ai.langstream.apigateway.api;

import java.util.Map;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.apigateway.websocket.api;
package ai.langstream.apigateway.api;

public record ProduceResponse(Status status, String reason) {
public static ProduceResponse OK = new ProduceResponse(Status.OK, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
import ai.langstream.api.runner.topics.TopicReadResult;
import ai.langstream.api.runner.topics.TopicReader;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import ai.langstream.apigateway.websocket.api.ConsumePushMessage;
import ai.langstream.apigateway.websocket.api.ProduceResponse;
import ai.langstream.apigateway.api.ConsumePushMessage;
import ai.langstream.apigateway.api.ProduceResponse;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.Closeable;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.HashMap;
Expand All @@ -46,7 +47,7 @@
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class ConsumeGateway implements Closeable {
public class ConsumeGateway implements AutoCloseable {

protected static final ObjectMapper mapper = new ObjectMapper();

Expand Down Expand Up @@ -237,4 +238,44 @@ public void close() {
closeReader();
}
}



public static List<Function<Record, Boolean>> createMessageFilters(
List<Gateway.KeyValueComparison> headersFilters,
Map<String, String> passedParameters,
Map<String, String> principalValues) {
List<Function<Record, Boolean>> filters = new ArrayList<>();
if (headersFilters == null) {
return filters;
}
for (Gateway.KeyValueComparison comparison : headersFilters) {
if (comparison.key() == null) {
throw new IllegalArgumentException("Key cannot be null");
}
filters.add(
record -> {
final Header header = record.getHeader(comparison.key());
if (header == null) {
return false;
}
final String expectedValue = header.valueAsString();
if (expectedValue == null) {
return false;
}
String value = comparison.value();
if (value == null && comparison.valueFromParameters() != null) {
value = passedParameters.get(comparison.valueFromParameters());
}
if (value == null && comparison.valueFromAuthentication() != null) {
value = principalValues.get(comparison.valueFromAuthentication());
}
if (value == null) {
return false;
}
return expectedValue.equals(value);
});
}
return filters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runner.topics.TopicProducer;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import ai.langstream.apigateway.websocket.api.ProduceRequest;
import ai.langstream.apigateway.websocket.api.ProduceResponse;
import ai.langstream.apigateway.api.ProduceRequest;
import ai.langstream.apigateway.api.ProduceResponse;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.Closeable;
Expand All @@ -36,10 +36,9 @@
import java.util.stream.Collectors;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.map.LRUMap;

@Slf4j
public class ProduceGateway implements Closeable {
public class ProduceGateway implements AutoCloseable {

protected static final ObjectMapper mapper = new ObjectMapper();

Expand Down Expand Up @@ -83,7 +82,8 @@ public void validateOptions(Map<String, String> options) {
private List<Header> commonHeaders;
private String logRef;

public ProduceGateway(TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry, TopicProducerCache topicProducerCache) {
public ProduceGateway(TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry,
TopicProducerCache topicProducerCache) {
this.topicConnectionsRuntimeRegistry = topicConnectionsRuntimeRegistry;
this.topicProducerCache = topicProducerCache;
}
Expand All @@ -102,8 +102,9 @@ public void start(
final TopicProducerCache.Key key =
new TopicProducerCache.Key(requestContext.tenant(), requestContext.applicationId(),
requestContext.gateway().getId());
producer = topicProducerCache.getOrCreate(key, () -> setupProducer(topic, requestContext.application().getInstance()
.streamingCluster()));
producer = topicProducerCache.getOrCreate(key,
() -> setupProducer(topic, requestContext.application().getInstance()
.streamingCluster()));
}

protected TopicProducer setupProducer(
Expand Down Expand Up @@ -153,8 +154,8 @@ public void produceMessage(ProduceRequest produceRequest) throws ProduceExceptio
if (configuredHeaders.contains(messageHeader.getKey())) {
throw new ProduceException(
"Header "
+ messageHeader.getKey()
+ " is configured as parameter-level header.",
+ messageHeader.getKey()
+ " is configured as parameter-level header.",
ProduceResponse.Status.BAD_REQUEST);
}
headers.add(
Expand Down Expand Up @@ -191,11 +192,27 @@ public void close() {

public static List<Header> getProducerCommonHeaders(
Gateway.ProduceOptions produceOptions, AuthenticatedGatewayRequestContext context) {
if (produceOptions != null) {
return getProducerCommonHeaders(
produceOptions.headers(), context.userParameters(), context.principalValues());
if (produceOptions == null) {
return null;
}
return null;
return getProducerCommonHeaders(
produceOptions.headers(), context.userParameters(), context.principalValues());
}

public static List<Header> getProducerCommonHeaders(
Gateway.ChatOptions chatOptions, AuthenticatedGatewayRequestContext context) {
if (chatOptions == null) {
return null;
}
return getProducerCommonHeaders(chatOptions.getHeaders(), context.userParameters(), context.principalValues());
}

public static List<Header> getProducerCommonHeaders(
Gateway.ServiceOptions serviceOptions, AuthenticatedGatewayRequestContext context) {
if (serviceOptions == null) {
return null;
}
return getProducerCommonHeaders(serviceOptions.getHeaders(), context.userParameters(), context.principalValues());
}

public static List<Header> getProducerCommonHeaders(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,26 @@
import ai.langstream.api.gateway.GatewayRequestContext;
import ai.langstream.api.model.Gateway;
import ai.langstream.api.runner.code.Header;
import ai.langstream.apigateway.config.TopicProperties;
import ai.langstream.api.runner.code.Record;
import ai.langstream.apigateway.api.ConsumePushMessage;
import ai.langstream.apigateway.gateways.ConsumeGateway;
import ai.langstream.apigateway.gateways.GatewayRequestHandler;
import ai.langstream.apigateway.gateways.ProduceGateway;
import ai.langstream.apigateway.gateways.TopicProducerCache;
import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import ai.langstream.apigateway.websocket.api.ProduceRequest;
import ai.langstream.apigateway.websocket.api.ProduceResponse;
import ai.langstream.apigateway.api.ProduceRequest;
import ai.langstream.apigateway.api.ProduceResponse;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.validation.constraints.NotBlank;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -52,6 +60,7 @@ public class GatewayResource {
private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider;
private final TopicProducerCache topicProducerCache;
private final GatewayRequestHandler gatewayRequestHandler;
private final ExecutorService consumeThreadPool = Executors.newCachedThreadPool();

@PostMapping(
value = "/produce/{tenant}/{application}/{gateway}",
Expand Down Expand Up @@ -86,20 +95,93 @@ ProduceResponse produce(
throw new ResponseStatusException(HttpStatus.UNAUTHORIZED, e.getMessage());
}

final ProduceGateway produceGateway =
new ProduceGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry(),
topicProducerCache);
try {

try (final ProduceGateway produceGateway =
new ProduceGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry(),
topicProducerCache);) {
final List<Header> commonHeaders =
ProduceGateway.getProducerCommonHeaders(
context.gateway().getProduceOptions(), authContext);
produceGateway.start(context.gateway().getTopic(), commonHeaders, authContext);
produceGateway.produceMessage(produceRequest);
return ProduceResponse.OK;
} finally {
produceGateway.close();
}
}


@PostMapping(
value = "/service/{tenant}/{application}/{gateway}",
consumes = MediaType.APPLICATION_JSON_VALUE)
void service(
WebRequest request,
HttpServletResponse response,
@NotBlank @PathVariable("tenant") String tenant,
@NotBlank @PathVariable("application") String application,
@NotBlank @PathVariable("gateway") String gateway,
@RequestBody ProduceRequest produceRequest)
throws ProduceGateway.ProduceException {

final Map<String, String> queryString =
request.getParameterMap().keySet().stream()
.collect(Collectors.toMap(k -> k, k -> request.getParameter(k)));
final Map<String, String> headers = new HashMap<>();
request.getHeaderNames()
.forEachRemaining(name -> headers.put(name, request.getHeader(name)));
final GatewayRequestContext context =
gatewayRequestHandler.validateRequest(
tenant,
application,
gateway,
Gateway.GatewayType.service,
queryString,
headers,
new ProduceGateway.ProduceGatewayRequestValidator());
final AuthenticatedGatewayRequestContext authContext;
try {
authContext = gatewayRequestHandler.authenticate(context);
} catch (GatewayRequestHandler.AuthFailedException e) {
throw new ResponseStatusException(HttpStatus.UNAUTHORIZED, e.getMessage());
}


try (final ConsumeGateway consumeGateway = new ConsumeGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry());
final ProduceGateway produceGateway =
new ProduceGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry(),
topicProducerCache);) {

final Gateway.ServiceOptions serviceOptions = authContext.gateway().getServiceOptions();
try {
final List<Function<Record, Boolean>> messageFilters =
ConsumeGateway.createMessageFilters(
serviceOptions.getHeaders(), authContext.userParameters(),
authContext.principalValues());
consumeGateway.setup(serviceOptions.getInputTopic(), messageFilters, authContext);
final AtomicBoolean stop = new AtomicBoolean(false);
consumeGateway.startReadingAsync(consumeThreadPool, () -> stop.get(), record -> {
stop.set(true);
try {
response.getWriter().print(record);
response.getWriter().flush();
response.getWriter().close();
} catch (IOException ioException) {
throw new RuntimeException(ioException);
}
});
} catch (Exception ex) {
log.error("Error while setting up consume gateway", ex);
throw new RuntimeException(ex);
}
final List<Header> commonHeaders =
ProduceGateway.getProducerCommonHeaders(serviceOptions, authContext);
produceGateway.start(serviceOptions.getOutputTopic(), commonHeaders, authContext);
produceGateway.produceMessage(produceRequest);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import ai.langstream.apigateway.gateways.ProduceGateway;
import ai.langstream.apigateway.gateways.TopicProducerCache;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import ai.langstream.apigateway.websocket.api.ProduceResponse;
import ai.langstream.apigateway.api.ProduceResponse;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -242,43 +242,6 @@ protected void startReadingMessages(WebSocketSession webSocketSession, Executor
});
}

protected static List<Function<Record, Boolean>> createMessageFilters(
List<Gateway.KeyValueComparison> headersFilters,
Map<String, String> passedParameters,
Map<String, String> principalValues) {
List<Function<Record, Boolean>> filters = new ArrayList<>();
if (headersFilters == null) {
return filters;
}
for (Gateway.KeyValueComparison comparison : headersFilters) {
if (comparison.key() == null) {
throw new IllegalArgumentException("Key cannot be null");
}
filters.add(
record -> {
final Header header = record.getHeader(comparison.key());
if (header == null) {
return false;
}
final String expectedValue = header.valueAsString();
if (expectedValue == null) {
return false;
}
String value = comparison.value();
if (value == null && comparison.valueFromParameters() != null) {
value = passedParameters.get(comparison.valueFromParameters());
}
if (value == null && comparison.valueFromAuthentication() != null) {
value = principalValues.get(comparison.valueFromAuthentication());
}
if (value == null) {
return false;
}
return expectedValue.equals(value);
});
}
return filters;
}

protected void setupReader(
String topic,
Expand Down
Loading

0 comments on commit e4bfcc5

Please sign in to comment.