Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
New gateway 'service' - part 1 (LangStream#663)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicoloboschi authored Oct 30, 2023
1 parent b623381 commit 1febe44
Show file tree
Hide file tree
Showing 15 changed files with 384 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.apigateway.websocket.api;
package ai.langstream.apigateway.api;

import java.util.Map;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.apigateway.websocket.api;
package ai.langstream.apigateway.api;

import java.util.Map;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.apigateway.websocket.api;
package ai.langstream.apigateway.api;

public record ProduceResponse(Status status, String reason) {
public static ProduceResponse OK = new ProduceResponse(Status.OK, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
import ai.langstream.api.runner.topics.TopicOffsetPosition;
import ai.langstream.api.runner.topics.TopicReadResult;
import ai.langstream.api.runner.topics.TopicReader;
import ai.langstream.apigateway.api.ConsumePushMessage;
import ai.langstream.apigateway.api.ProduceResponse;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import ai.langstream.apigateway.websocket.api.ConsumePushMessage;
import ai.langstream.apigateway.websocket.api.ProduceResponse;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.Closeable;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.HashMap;
Expand All @@ -46,7 +46,7 @@
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class ConsumeGateway implements Closeable {
public class ConsumeGateway implements AutoCloseable {

protected static final ObjectMapper mapper = new ObjectMapper();

Expand Down Expand Up @@ -237,4 +237,42 @@ public void close() {
closeReader();
}
}

public static List<Function<Record, Boolean>> createMessageFilters(
List<Gateway.KeyValueComparison> headersFilters,
Map<String, String> passedParameters,
Map<String, String> principalValues) {
List<Function<Record, Boolean>> filters = new ArrayList<>();
if (headersFilters == null) {
return filters;
}
for (Gateway.KeyValueComparison comparison : headersFilters) {
if (comparison.key() == null) {
throw new IllegalArgumentException("Key cannot be null");
}
filters.add(
record -> {
final Header header = record.getHeader(comparison.key());
if (header == null) {
return false;
}
final String expectedValue = header.valueAsString();
if (expectedValue == null) {
return false;
}
String value = comparison.value();
if (value == null && comparison.valueFromParameters() != null) {
value = passedParameters.get(comparison.valueFromParameters());
}
if (value == null && comparison.valueFromAuthentication() != null) {
value = principalValues.get(comparison.valueFromAuthentication());
}
if (value == null) {
return false;
}
return expectedValue.equals(value);
});
}
return filters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
import ai.langstream.api.runner.topics.TopicConnectionsRuntime;
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runner.topics.TopicProducer;
import ai.langstream.apigateway.api.ProduceRequest;
import ai.langstream.apigateway.api.ProduceResponse;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import ai.langstream.apigateway.websocket.api.ProduceRequest;
import ai.langstream.apigateway.websocket.api.ProduceResponse;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.Closeable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
Expand All @@ -38,7 +37,7 @@
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class ProduceGateway implements Closeable {
public class ProduceGateway implements AutoCloseable {

protected static final ObjectMapper mapper = new ObjectMapper();

Expand Down Expand Up @@ -190,11 +189,29 @@ public void close() {

public static List<Header> getProducerCommonHeaders(
Gateway.ProduceOptions produceOptions, AuthenticatedGatewayRequestContext context) {
if (produceOptions != null) {
return getProducerCommonHeaders(
produceOptions.headers(), context.userParameters(), context.principalValues());
if (produceOptions == null) {
return null;
}
return null;
return getProducerCommonHeaders(
produceOptions.headers(), context.userParameters(), context.principalValues());
}

public static List<Header> getProducerCommonHeaders(
Gateway.ChatOptions chatOptions, AuthenticatedGatewayRequestContext context) {
if (chatOptions == null) {
return null;
}
return getProducerCommonHeaders(
chatOptions.getHeaders(), context.userParameters(), context.principalValues());
}

public static List<Header> getProducerCommonHeaders(
Gateway.ServiceOptions serviceOptions, AuthenticatedGatewayRequestContext context) {
if (serviceOptions == null) {
return null;
}
return getProducerCommonHeaders(
serviceOptions.getHeaders(), context.userParameters(), context.principalValues());
}

public static List<Header> getProducerCommonHeaders(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,24 @@
import ai.langstream.api.gateway.GatewayRequestContext;
import ai.langstream.api.model.Gateway;
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.apigateway.api.ProduceRequest;
import ai.langstream.apigateway.api.ProduceResponse;
import ai.langstream.apigateway.gateways.ConsumeGateway;
import ai.langstream.apigateway.gateways.GatewayRequestHandler;
import ai.langstream.apigateway.gateways.ProduceGateway;
import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import ai.langstream.apigateway.websocket.api.ProduceRequest;
import ai.langstream.apigateway.websocket.api.ProduceResponse;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.validation.constraints.NotBlank;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -49,6 +57,7 @@ public class GatewayResource {

private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider;
private final GatewayRequestHandler gatewayRequestHandler;
private final ExecutorService consumeThreadPool = Executors.newCachedThreadPool();

@PostMapping(
value = "/produce/{tenant}/{application}/{gateway}",
Expand All @@ -61,12 +70,8 @@ ProduceResponse produce(
@RequestBody ProduceRequest produceRequest)
throws ProduceGateway.ProduceException {

final Map<String, String> queryString =
request.getParameterMap().keySet().stream()
.collect(Collectors.toMap(k -> k, k -> request.getParameter(k)));
final Map<String, String> headers = new HashMap<>();
request.getHeaderNames()
.forEachRemaining(name -> headers.put(name, request.getHeader(name)));
final Map<String, String> queryString = computeQueryString(request);
final Map<String, String> headers = computeHeaders(request);
final GatewayRequestContext context =
gatewayRequestHandler.validateRequest(
tenant,
Expand All @@ -83,19 +88,102 @@ ProduceResponse produce(
throw new ResponseStatusException(HttpStatus.UNAUTHORIZED, e.getMessage());
}

final ProduceGateway produceGateway =
try (final ProduceGateway produceGateway =
new ProduceGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry());
try {
.getTopicConnectionsRuntimeRegistry()); ) {
final List<Header> commonHeaders =
ProduceGateway.getProducerCommonHeaders(
context.gateway().getProduceOptions(), authContext);
produceGateway.start(context.gateway().getTopic(), commonHeaders, authContext);
produceGateway.produceMessage(produceRequest);
return ProduceResponse.OK;
} finally {
produceGateway.close();
}
}

private Map<String, String> computeHeaders(WebRequest request) {
final Map<String, String> headers = new HashMap<>();
request.getHeaderNames()
.forEachRemaining(name -> headers.put(name, request.getHeader(name)));
return headers;
}

@PostMapping(
value = "/service/{tenant}/{application}/{gateway}",
consumes = MediaType.APPLICATION_JSON_VALUE)
void service(
WebRequest request,
HttpServletResponse response,
@NotBlank @PathVariable("tenant") String tenant,
@NotBlank @PathVariable("application") String application,
@NotBlank @PathVariable("gateway") String gateway,
@RequestBody ProduceRequest produceRequest)
throws ProduceGateway.ProduceException {

final Map<String, String> queryString = computeQueryString(request);
final Map<String, String> headers = computeHeaders(request);
final GatewayRequestContext context =
gatewayRequestHandler.validateRequest(
tenant,
application,
gateway,
Gateway.GatewayType.service,
queryString,
headers,
new ProduceGateway.ProduceGatewayRequestValidator());
final AuthenticatedGatewayRequestContext authContext;
try {
authContext = gatewayRequestHandler.authenticate(context);
} catch (GatewayRequestHandler.AuthFailedException e) {
throw new ResponseStatusException(HttpStatus.UNAUTHORIZED, e.getMessage());
}

try (final ConsumeGateway consumeGateway =
new ConsumeGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry());
final ProduceGateway produceGateway =
new ProduceGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry()); ) {

final Gateway.ServiceOptions serviceOptions = authContext.gateway().getServiceOptions();
try {
final List<Function<Record, Boolean>> messageFilters =
ConsumeGateway.createMessageFilters(
serviceOptions.getHeaders(),
authContext.userParameters(),
authContext.principalValues());
consumeGateway.setup(serviceOptions.getInputTopic(), messageFilters, authContext);
final AtomicBoolean stop = new AtomicBoolean(false);
consumeGateway.startReadingAsync(
consumeThreadPool,
() -> stop.get(),
record -> {
stop.set(true);
try {
response.getWriter().print(record);
response.getWriter().flush();
response.getWriter().close();
} catch (IOException ioException) {
throw new RuntimeException(ioException);
}
});
} catch (Exception ex) {
log.error("Error while setting up consume gateway", ex);
throw new RuntimeException(ex);
}
final List<Header> commonHeaders =
ProduceGateway.getProducerCommonHeaders(serviceOptions, authContext);
produceGateway.start(serviceOptions.getOutputTopic(), commonHeaders, authContext);
produceGateway.produceMessage(produceRequest);
}
}

private Map<String, String> computeQueryString(WebRequest request) {
final Map<String, String> queryString =
request.getParameterMap().keySet().stream()
.collect(Collectors.toMap(k -> k, k -> request.getParameter(k)));
return queryString;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,13 @@
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runner.topics.TopicProducer;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.api.ProduceResponse;
import ai.langstream.apigateway.gateways.ConsumeGateway;
import ai.langstream.apigateway.gateways.GatewayRequestHandler;
import ai.langstream.apigateway.gateways.ProduceGateway;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import ai.langstream.apigateway.websocket.api.ProduceResponse;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
Expand Down Expand Up @@ -238,44 +237,6 @@ protected void startReadingMessages(WebSocketSession webSocketSession, Executor
});
}

protected static List<Function<Record, Boolean>> createMessageFilters(
List<Gateway.KeyValueComparison> headersFilters,
Map<String, String> passedParameters,
Map<String, String> principalValues) {
List<Function<Record, Boolean>> filters = new ArrayList<>();
if (headersFilters == null) {
return filters;
}
for (Gateway.KeyValueComparison comparison : headersFilters) {
if (comparison.key() == null) {
throw new IllegalArgumentException("Key cannot be null");
}
filters.add(
record -> {
final Header header = record.getHeader(comparison.key());
if (header == null) {
return false;
}
final String expectedValue = header.valueAsString();
if (expectedValue == null) {
return false;
}
String value = comparison.value();
if (value == null && comparison.valueFromParameters() != null) {
value = passedParameters.get(comparison.valueFromParameters());
}
if (value == null && comparison.valueFromAuthentication() != null) {
value = principalValues.get(comparison.valueFromAuthentication());
}
if (value == null) {
return false;
}
return expectedValue.equals(value);
});
}
return filters;
}

protected void setupReader(
String topic,
List<Function<Record, Boolean>> filters,
Expand Down
Loading

0 comments on commit 1febe44

Please sign in to comment.