Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New gateway 'service' - part 1 #663

Merged
merged 3 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.apigateway.websocket.api;
package ai.langstream.apigateway.api;

import java.util.Map;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.apigateway.websocket.api;
package ai.langstream.apigateway.api;

import java.util.Map;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.apigateway.websocket.api;
package ai.langstream.apigateway.api;

public record ProduceResponse(Status status, String reason) {
public static ProduceResponse OK = new ProduceResponse(Status.OK, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
import ai.langstream.api.runner.topics.TopicOffsetPosition;
import ai.langstream.api.runner.topics.TopicReadResult;
import ai.langstream.api.runner.topics.TopicReader;
import ai.langstream.apigateway.api.ConsumePushMessage;
import ai.langstream.apigateway.api.ProduceResponse;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import ai.langstream.apigateway.websocket.api.ConsumePushMessage;
import ai.langstream.apigateway.websocket.api.ProduceResponse;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.Closeable;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.HashMap;
Expand All @@ -46,7 +46,7 @@
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class ConsumeGateway implements Closeable {
public class ConsumeGateway implements AutoCloseable {

protected static final ObjectMapper mapper = new ObjectMapper();

Expand Down Expand Up @@ -237,4 +237,42 @@ public void close() {
closeReader();
}
}

public static List<Function<Record, Boolean>> createMessageFilters(
List<Gateway.KeyValueComparison> headersFilters,
Map<String, String> passedParameters,
Map<String, String> principalValues) {
List<Function<Record, Boolean>> filters = new ArrayList<>();
if (headersFilters == null) {
return filters;
}
for (Gateway.KeyValueComparison comparison : headersFilters) {
if (comparison.key() == null) {
throw new IllegalArgumentException("Key cannot be null");
}
filters.add(
record -> {
final Header header = record.getHeader(comparison.key());
if (header == null) {
return false;
}
final String expectedValue = header.valueAsString();
if (expectedValue == null) {
return false;
}
String value = comparison.value();
if (value == null && comparison.valueFromParameters() != null) {
value = passedParameters.get(comparison.valueFromParameters());
}
if (value == null && comparison.valueFromAuthentication() != null) {
value = principalValues.get(comparison.valueFromAuthentication());
}
if (value == null) {
return false;
}
return expectedValue.equals(value);
});
}
return filters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
import ai.langstream.api.runner.topics.TopicConnectionsRuntime;
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runner.topics.TopicProducer;
import ai.langstream.apigateway.api.ProduceRequest;
import ai.langstream.apigateway.api.ProduceResponse;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import ai.langstream.apigateway.websocket.api.ProduceRequest;
import ai.langstream.apigateway.websocket.api.ProduceResponse;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.Closeable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
Expand All @@ -38,7 +37,7 @@
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class ProduceGateway implements Closeable {
public class ProduceGateway implements AutoCloseable {

protected static final ObjectMapper mapper = new ObjectMapper();

Expand Down Expand Up @@ -190,11 +189,29 @@ public void close() {

public static List<Header> getProducerCommonHeaders(
Gateway.ProduceOptions produceOptions, AuthenticatedGatewayRequestContext context) {
if (produceOptions != null) {
return getProducerCommonHeaders(
produceOptions.headers(), context.userParameters(), context.principalValues());
if (produceOptions == null) {
return null;
}
return null;
return getProducerCommonHeaders(
produceOptions.headers(), context.userParameters(), context.principalValues());
}

public static List<Header> getProducerCommonHeaders(
Gateway.ChatOptions chatOptions, AuthenticatedGatewayRequestContext context) {
if (chatOptions == null) {
return null;
}
return getProducerCommonHeaders(
chatOptions.getHeaders(), context.userParameters(), context.principalValues());
}

public static List<Header> getProducerCommonHeaders(
Gateway.ServiceOptions serviceOptions, AuthenticatedGatewayRequestContext context) {
if (serviceOptions == null) {
return null;
}
return getProducerCommonHeaders(
serviceOptions.getHeaders(), context.userParameters(), context.principalValues());
}

public static List<Header> getProducerCommonHeaders(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,24 @@
import ai.langstream.api.gateway.GatewayRequestContext;
import ai.langstream.api.model.Gateway;
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.apigateway.api.ProduceRequest;
import ai.langstream.apigateway.api.ProduceResponse;
import ai.langstream.apigateway.gateways.ConsumeGateway;
import ai.langstream.apigateway.gateways.GatewayRequestHandler;
import ai.langstream.apigateway.gateways.ProduceGateway;
import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import ai.langstream.apigateway.websocket.api.ProduceRequest;
import ai.langstream.apigateway.websocket.api.ProduceResponse;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.validation.constraints.NotBlank;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -49,6 +57,7 @@ public class GatewayResource {

private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider;
private final GatewayRequestHandler gatewayRequestHandler;
private final ExecutorService consumeThreadPool = Executors.newCachedThreadPool();

@PostMapping(
value = "/produce/{tenant}/{application}/{gateway}",
Expand All @@ -61,12 +70,8 @@ ProduceResponse produce(
@RequestBody ProduceRequest produceRequest)
throws ProduceGateway.ProduceException {

final Map<String, String> queryString =
request.getParameterMap().keySet().stream()
.collect(Collectors.toMap(k -> k, k -> request.getParameter(k)));
final Map<String, String> headers = new HashMap<>();
request.getHeaderNames()
.forEachRemaining(name -> headers.put(name, request.getHeader(name)));
final Map<String, String> queryString = computeQueryString(request);
final Map<String, String> headers = computeHeaders(request);
final GatewayRequestContext context =
gatewayRequestHandler.validateRequest(
tenant,
Expand All @@ -83,19 +88,102 @@ ProduceResponse produce(
throw new ResponseStatusException(HttpStatus.UNAUTHORIZED, e.getMessage());
}

final ProduceGateway produceGateway =
try (final ProduceGateway produceGateway =
new ProduceGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry());
try {
.getTopicConnectionsRuntimeRegistry()); ) {
final List<Header> commonHeaders =
ProduceGateway.getProducerCommonHeaders(
context.gateway().getProduceOptions(), authContext);
produceGateway.start(context.gateway().getTopic(), commonHeaders, authContext);
produceGateway.produceMessage(produceRequest);
return ProduceResponse.OK;
} finally {
produceGateway.close();
}
}

private Map<String, String> computeHeaders(WebRequest request) {
final Map<String, String> headers = new HashMap<>();
request.getHeaderNames()
.forEachRemaining(name -> headers.put(name, request.getHeader(name)));
return headers;
}

@PostMapping(
value = "/service/{tenant}/{application}/{gateway}",
consumes = MediaType.APPLICATION_JSON_VALUE)
void service(
WebRequest request,
HttpServletResponse response,
@NotBlank @PathVariable("tenant") String tenant,
@NotBlank @PathVariable("application") String application,
@NotBlank @PathVariable("gateway") String gateway,
@RequestBody ProduceRequest produceRequest)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where do you set the timeout ?
I guess that we should give the client the ability to set the timeout for waiting for the message to arrive

throws ProduceGateway.ProduceException {

final Map<String, String> queryString = computeQueryString(request);
final Map<String, String> headers = computeHeaders(request);
final GatewayRequestContext context =
gatewayRequestHandler.validateRequest(
tenant,
application,
gateway,
Gateway.GatewayType.service,
queryString,
headers,
new ProduceGateway.ProduceGatewayRequestValidator());
final AuthenticatedGatewayRequestContext authContext;
try {
authContext = gatewayRequestHandler.authenticate(context);
} catch (GatewayRequestHandler.AuthFailedException e) {
throw new ResponseStatusException(HttpStatus.UNAUTHORIZED, e.getMessage());
}

try (final ConsumeGateway consumeGateway =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have to follow up with a cache for this

new ConsumeGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry());
final ProduceGateway produceGateway =
new ProduceGateway(
topicConnectionsRuntimeRegistryProvider
.getTopicConnectionsRuntimeRegistry()); ) {

final Gateway.ServiceOptions serviceOptions = authContext.gateway().getServiceOptions();
try {
final List<Function<Record, Boolean>> messageFilters =
ConsumeGateway.createMessageFilters(
serviceOptions.getHeaders(),
authContext.userParameters(),
authContext.principalValues());
consumeGateway.setup(serviceOptions.getInputTopic(), messageFilters, authContext);
final AtomicBoolean stop = new AtomicBoolean(false);
consumeGateway.startReadingAsync(
consumeThreadPool,
() -> stop.get(),
record -> {
stop.set(true);
try {
response.getWriter().print(record);
response.getWriter().flush();
response.getWriter().close();
} catch (IOException ioException) {
throw new RuntimeException(ioException);
}
});
} catch (Exception ex) {
log.error("Error while setting up consume gateway", ex);
throw new RuntimeException(ex);
}
final List<Header> commonHeaders =
ProduceGateway.getProducerCommonHeaders(serviceOptions, authContext);
produceGateway.start(serviceOptions.getOutputTopic(), commonHeaders, authContext);
produceGateway.produceMessage(produceRequest);
}
}

private Map<String, String> computeQueryString(WebRequest request) {
final Map<String, String> queryString =
request.getParameterMap().keySet().stream()
.collect(Collectors.toMap(k -> k, k -> request.getParameter(k)));
return queryString;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,13 @@
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runner.topics.TopicProducer;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.api.ProduceResponse;
import ai.langstream.apigateway.gateways.ConsumeGateway;
import ai.langstream.apigateway.gateways.GatewayRequestHandler;
import ai.langstream.apigateway.gateways.ProduceGateway;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import ai.langstream.apigateway.websocket.api.ProduceResponse;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
Expand Down Expand Up @@ -238,44 +237,6 @@ protected void startReadingMessages(WebSocketSession webSocketSession, Executor
});
}

protected static List<Function<Record, Boolean>> createMessageFilters(
List<Gateway.KeyValueComparison> headersFilters,
Map<String, String> passedParameters,
Map<String, String> principalValues) {
List<Function<Record, Boolean>> filters = new ArrayList<>();
if (headersFilters == null) {
return filters;
}
for (Gateway.KeyValueComparison comparison : headersFilters) {
if (comparison.key() == null) {
throw new IllegalArgumentException("Key cannot be null");
}
filters.add(
record -> {
final Header header = record.getHeader(comparison.key());
if (header == null) {
return false;
}
final String expectedValue = header.valueAsString();
if (expectedValue == null) {
return false;
}
String value = comparison.value();
if (value == null && comparison.valueFromParameters() != null) {
value = passedParameters.get(comparison.valueFromParameters());
}
if (value == null && comparison.valueFromAuthentication() != null) {
value = principalValues.get(comparison.valueFromAuthentication());
}
if (value == null) {
return false;
}
return expectedValue.equals(value);
});
}
return filters;
}

protected void setupReader(
String topic,
List<Function<Record, Boolean>> filters,
Expand Down
Loading
Loading