Skip to content

Commit

Permalink
Add Java gRPC Sink (#466)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet authored Sep 21, 2023
1 parent 493bced commit 407ac86
Show file tree
Hide file tree
Showing 7 changed files with 401 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.agents.grpc;

import ai.langstream.api.runner.code.AgentSink;
import ai.langstream.api.runner.code.Record;
import io.grpc.ManagedChannel;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class GrpcAgentSink extends AbstractGrpcAgent implements AgentSink {
private StreamObserver<SinkRequest> request;
private final StreamObserver<SinkResponse> responseObserver;

// For each record sent, we increment the recordId
protected final AtomicLong recordId = new AtomicLong(0);
private final Map<Long, CompletableFuture<?>> writeHandles = new ConcurrentHashMap<>();

public GrpcAgentSink() {
super();
this.responseObserver = getResponseObserver();
}

public GrpcAgentSink(ManagedChannel channel) {
super(channel);
this.responseObserver = getResponseObserver();
}

@Override
public void onNewSchemaToSend(Schema schema) {
request.onNext(SinkRequest.newBuilder().setSchema(schema).build());
}

@Override
public void start() throws Exception {
super.start();
request = AgentServiceGrpc.newStub(channel).withWaitForReady().write(responseObserver);
}

@Override
public CompletableFuture<?> write(Record record) {
CompletableFuture<?> handle = new CompletableFuture<>();
long rId = recordId.incrementAndGet();
SinkRequest.Builder requestBuilder = SinkRequest.newBuilder();
try {
requestBuilder.setRecord(toGrpc(record).setRecordId(rId));
} catch (IOException e) {
agentContext.criticalFailure(new RuntimeException("Error while processing records", e));
}
writeHandles.put(rId, handle);
request.onNext(requestBuilder.build());
return handle;
}

private StreamObserver<SinkResponse> getResponseObserver() {
return new StreamObserver<>() {
@Override
public void onNext(SinkResponse response) {
CompletableFuture<?> handle = writeHandles.get(response.getRecordId());
if (response.hasError()) {
handle.completeExceptionally(new RuntimeException(response.getError()));
} else {
handle.complete(null);
}
}

@Override
public void onError(Throwable throwable) {
agentContext.criticalFailure(
new RuntimeException(
"gRPC server sent error: %s".formatted(throwable.getMessage()),
throwable));
}

@Override
public void onCompleted() {
agentContext.criticalFailure(
new RuntimeException("gRPC server completed the stream unexpectedly"));
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
public class GrpcAgentsCodeProvider implements AgentCodeProvider {

private static final Set<String> SUPPORTED_AGENT_TYPES =
Set.of("experimental-python-source", "experimental-python-processor");
Set.of(
"experimental-python-source",
"experimental-python-processor",
"experimental-python-sink");

@Override
public boolean supports(String agentType) {
Expand All @@ -34,6 +37,7 @@ public AgentCode createInstance(String agentType) {
return switch (agentType) {
case "experimental-python-source" -> new PythonGrpcAgentSource();
case "experimental-python-processor" -> new PythonGrpcAgentProcessor();
case "experimental-python-sink" -> new PythonGrpcAgentSink();
default -> throw new IllegalStateException("Unexpected agent type: " + agentType);
};
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.agents.grpc;

import java.util.Map;

public class PythonGrpcAgentSink extends GrpcAgentSink {

private PythonGrpcServer server;
private Map<String, Object> configuration;

@Override
public void init(Map<String, Object> configuration) throws Exception {
super.init(configuration);
this.configuration = configuration;
}

@Override
public void start() throws Exception {
server = new PythonGrpcServer(agentContext.getCodeDirectory(), configuration);
channel = server.start();
super.start();
}

@Override
public synchronized void close() throws Exception {
if (server != null) server.close();
super.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ service AgentService {
rpc agent_info(google.protobuf.Empty) returns (InfoResponse) {}
rpc read(stream SourceRequest) returns (stream SourceResponse) {}
rpc process(stream ProcessorRequest) returns (stream ProcessorResponse) {}
rpc write(stream SinkRequest) returns (stream SinkResponse) {}
}

message InfoResponse {
Expand Down Expand Up @@ -89,12 +90,22 @@ message ProcessorRequest {
}

message ProcessorResponse {
Schema schema = 1;
repeated ProcessorResult results = 2;
Schema schema = 1;
repeated ProcessorResult results = 2;
}

message ProcessorResult {
int64 record_id = 1;
optional string error = 2;
repeated Record records = 3;
}
}

message SinkRequest {
Schema schema = 1;
Record record = 2;
}

message SinkResponse {
int64 record_id = 1;
optional string error = 2;
}
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
experimental-python-source
experimental-python-processor
experimental-python-processor
experimental-python-sink
Loading

0 comments on commit 407ac86

Please sign in to comment.