From 20b55e9fa7ff174ea79206192db5ae7c44e4bced Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Thu, 21 Sep 2023 12:41:14 +0200 Subject: [PATCH] Add Java gRPC Sink --- .../langstream/agents/grpc/GrpcAgentSink.java | 101 ++++++++ .../agents/grpc/GrpcAgentsCodeProvider.java | 6 +- .../agents/grpc/PythonGrpcAgentSink.java | 43 ++++ .../proto/langstream_grpc/proto/agent.proto | 17 +- .../META-INF/ai.langstream.agents.index | 3 +- .../agents/grpc/GrpcAgentSinkTest.java | 231 ++++++++++++++++++ .../impl/k8s/agents/GrpcAgentsProvider.java | 6 +- 7 files changed, 401 insertions(+), 6 deletions(-) create mode 100644 langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSink.java create mode 100644 langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSink.java create mode 100644 langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSinkTest.java diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSink.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSink.java new file mode 100644 index 000000000..678ec9a2f --- /dev/null +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSink.java @@ -0,0 +1,101 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.agents.grpc; + +import ai.langstream.api.runner.code.AgentSink; +import ai.langstream.api.runner.code.Record; +import io.grpc.ManagedChannel; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class GrpcAgentSink extends AbstractGrpcAgent implements AgentSink { + private StreamObserver request; + private final StreamObserver responseObserver; + + // For each record sent, we increment the recordId + protected final AtomicLong recordId = new AtomicLong(0); + private final Map> writeHandles = new ConcurrentHashMap<>(); + + public GrpcAgentSink() { + super(); + this.responseObserver = getResponseObserver(); + } + + public GrpcAgentSink(ManagedChannel channel) { + super(channel); + this.responseObserver = getResponseObserver(); + } + + @Override + public void onNewSchemaToSend(Schema schema) { + request.onNext(SinkRequest.newBuilder().setSchema(schema).build()); + } + + @Override + public void start() throws Exception { + super.start(); + request = AgentServiceGrpc.newStub(channel).withWaitForReady().write(responseObserver); + } + + @Override + public CompletableFuture write(Record record) { + CompletableFuture handle = new CompletableFuture<>(); + long rId = recordId.incrementAndGet(); + SinkRequest.Builder requestBuilder = SinkRequest.newBuilder(); + try { + requestBuilder.setRecord(toGrpc(record).setRecordId(rId)); + } catch (IOException e) { + agentContext.criticalFailure(new RuntimeException("Error while processing records", e)); + } + writeHandles.put(rId, handle); + request.onNext(requestBuilder.build()); + return handle; + } + + private StreamObserver getResponseObserver() { + return new StreamObserver<>() { + @Override + public void onNext(SinkResponse response) { + CompletableFuture handle = writeHandles.get(response.getRecordId()); + if (response.hasError()) { + handle.completeExceptionally(new RuntimeException(response.getError())); + } else { + handle.complete(null); + } + } + + @Override + public void onError(Throwable throwable) { + agentContext.criticalFailure( + new RuntimeException( + "gRPC server sent error: %s".formatted(throwable.getMessage()), + throwable)); + } + + @Override + public void onCompleted() { + agentContext.criticalFailure( + new RuntimeException("gRPC server completed the stream unexpectedly")); + } + }; + } +} diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentsCodeProvider.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentsCodeProvider.java index 9b1eebfdb..cf8d3f088 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentsCodeProvider.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentsCodeProvider.java @@ -22,7 +22,10 @@ public class GrpcAgentsCodeProvider implements AgentCodeProvider { private static final Set SUPPORTED_AGENT_TYPES = - Set.of("experimental-python-source", "experimental-python-processor"); + Set.of( + "experimental-python-source", + "experimental-python-processor", + "experimental-python-sink"); @Override public boolean supports(String agentType) { @@ -34,6 +37,7 @@ public AgentCode createInstance(String agentType) { return switch (agentType) { case "experimental-python-source" -> new PythonGrpcAgentSource(); case "experimental-python-processor" -> new PythonGrpcAgentProcessor(); + case "experimental-python-sink" -> new PythonGrpcAgentSink(); default -> throw new IllegalStateException("Unexpected agent type: " + agentType); }; } diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSink.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSink.java new file mode 100644 index 000000000..a9739c935 --- /dev/null +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSink.java @@ -0,0 +1,43 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.agents.grpc; + +import java.util.Map; + +public class PythonGrpcAgentSink extends GrpcAgentSink { + + private PythonGrpcServer server; + private Map configuration; + + @Override + public void init(Map configuration) throws Exception { + super.init(configuration); + this.configuration = configuration; + } + + @Override + public void start() throws Exception { + server = new PythonGrpcServer(agentContext.getCodeDirectory(), configuration); + channel = server.start(); + super.start(); + } + + @Override + public synchronized void close() throws Exception { + if (server != null) server.close(); + super.close(); + } +} diff --git a/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto b/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto index 215383850..da376e3c4 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto +++ b/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto @@ -25,6 +25,7 @@ service AgentService { rpc agent_info(google.protobuf.Empty) returns (InfoResponse) {} rpc read(stream SourceRequest) returns (stream SourceResponse) {} rpc process(stream ProcessorRequest) returns (stream ProcessorResponse) {} + rpc write(stream SinkRequest) returns (stream SinkResponse) {} } message InfoResponse { @@ -89,12 +90,22 @@ message ProcessorRequest { } message ProcessorResponse { - Schema schema = 1; - repeated ProcessorResult results = 2; + Schema schema = 1; + repeated ProcessorResult results = 2; } message ProcessorResult { int64 record_id = 1; optional string error = 2; repeated Record records = 3; -} \ No newline at end of file +} + +message SinkRequest { + Schema schema = 1; + Record record = 2; +} + +message SinkResponse { + int64 record_id = 1; + optional string error = 2; +} diff --git a/langstream-agents/langstream-agent-grpc/src/main/resources/META-INF/ai.langstream.agents.index b/langstream-agents/langstream-agent-grpc/src/main/resources/META-INF/ai.langstream.agents.index index babae276d..091efb372 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/resources/META-INF/ai.langstream.agents.index +++ b/langstream-agents/langstream-agent-grpc/src/main/resources/META-INF/ai.langstream.agents.index @@ -1,2 +1,3 @@ experimental-python-source -experimental-python-processor \ No newline at end of file +experimental-python-processor +experimental-python-sink \ No newline at end of file diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSinkTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSinkTest.java new file mode 100644 index 000000000..2672b1d55 --- /dev/null +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSinkTest.java @@ -0,0 +1,231 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.agents.grpc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import ai.langstream.api.runner.code.AgentContext; +import ai.langstream.api.runner.code.SimpleRecord; +import ai.langstream.api.runner.topics.TopicAdmin; +import ai.langstream.api.runner.topics.TopicConnectionProvider; +import ai.langstream.api.runner.topics.TopicConsumer; +import ai.langstream.api.runner.topics.TopicProducer; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import org.apache.avro.Conversions; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.Decoder; +import org.apache.avro.io.DecoderFactory; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class GrpcAgentSinkTest { + private Server server; + private ManagedChannel channel; + private final TestSinkService testSinkService = new TestSinkService(); + private GrpcAgentSink sink; + private TestAgentContext context; + + @BeforeEach + public void setUp() throws Exception { + String serverName = InProcessServerBuilder.generateName(); + server = + InProcessServerBuilder.forName(serverName) + .directExecutor() + .addService(testSinkService) + .build() + .start(); + + channel = InProcessChannelBuilder.forName(serverName).directExecutor().build(); + sink = new GrpcAgentSink(channel); + context = new TestAgentContext(); + sink.setContext(context); + sink.start(); + } + + @AfterEach + public void tearDown() throws Exception { + sink.close(); + channel.shutdownNow(); + server.shutdownNow(); + channel.awaitTermination(30, TimeUnit.SECONDS); + server.awaitTermination(30, TimeUnit.SECONDS); + } + + @Test + void testWriteError() throws Exception { + try { + sink.write(SimpleRecord.builder().origin("failing-record").build()) + .get(5, TimeUnit.SECONDS); + } catch (ExecutionException e) { + assertEquals("test-error", e.getCause().getMessage()); + } + } + + @Test + void testSinkGrpcError() throws Exception { + sink.write(SimpleRecord.builder().origin("failing-server").build()); + assertTrue(context.failureCalled.await(1, TimeUnit.SECONDS)); + } + + @Test + void testSinkGrpcCompletedUnexpectedly() throws Exception { + sink.write(SimpleRecord.builder().origin("completing-server").build()); + assertTrue(context.failureCalled.await(1, TimeUnit.SECONDS)); + } + + @Test + void testAvroAndSchema() throws Exception { + Schema schema = + SchemaBuilder.record("testRecord") + .fields() + .name("testField") + .type() + .stringType() + .noDefault() + .endRecord(); + GenericData.Record avroRecord = new GenericData.Record(schema); + avroRecord.put("testField", "test-string"); + + sink.write(SimpleRecord.of(null, avroRecord)).get(5, TimeUnit.SECONDS); + + GenericRecord writtenRecord = testSinkService.avroRecords.poll(5, TimeUnit.SECONDS); + assertEquals("test-string", writtenRecord.get("testField").toString()); + } + + static class TestSinkService extends AgentServiceGrpc.AgentServiceImplBase { + + private final Map schemas = new ConcurrentHashMap<>(); + private final LinkedBlockingQueue avroRecords = new LinkedBlockingQueue<>(); + + @Override + public StreamObserver write(StreamObserver responseObserver) { + return new StreamObserver<>() { + @Override + public void onNext(SinkRequest request) { + if (request.hasSchema()) { + Schema schema = + new Schema.Parser() + .parse(request.getSchema().getValue().toStringUtf8()); + schemas.put(request.getSchema().getSchemaId(), schema); + } + if (request.hasRecord()) { + ai.langstream.agents.grpc.Record record = request.getRecord(); + Value value = record.getValue(); + if (value.hasAvroValue()) { + Schema schema = schemas.get(value.getSchemaId()); + try { + GenericRecord genericRecord = + deserializeGenericRecord( + schema, value.getAvroValue().toByteArray()); + avroRecords.add(genericRecord); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + if (record.getOrigin().equals("failing-record")) { + responseObserver.onNext( + SinkResponse.newBuilder() + .setRecordId(record.getRecordId()) + .setError("test-error") + .build()); + } else if (record.getOrigin().equals("failing-server")) { + responseObserver.onError(new RuntimeException("test-error")); + } else if (record.getOrigin().equals("completing-server")) { + responseObserver.onCompleted(); + } else { + responseObserver.onNext( + SinkResponse.newBuilder() + .setRecordId(record.getRecordId()) + .build()); + } + } + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + }; + } + } + + private static GenericRecord deserializeGenericRecord( + org.apache.avro.Schema schema, byte[] data) throws IOException { + GenericDatumReader reader = new GenericDatumReader<>(schema); + reader.getData().addLogicalTypeConversion(new Conversions.DecimalConversion()); + Decoder decoder = DecoderFactory.get().binaryDecoder(data, null); + return reader.read(null, decoder); + } + + static class TestAgentContext implements AgentContext { + + private final CountDownLatch failureCalled = new CountDownLatch(1); + + @Override + public TopicConsumer getTopicConsumer() { + return null; + } + + @Override + public TopicProducer getTopicProducer() { + return null; + } + + @Override + public String getGlobalAgentId() { + return null; + } + + @Override + public TopicAdmin getTopicAdmin() { + return null; + } + + @Override + public TopicConnectionProvider getTopicConnectionProvider() { + return null; + } + + @Override + public void criticalFailure(Throwable error) { + failureCalled.countDown(); + } + + @Override + public Path getCodeDirectory() { + return null; + } + } +} diff --git a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/GrpcAgentsProvider.java b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/GrpcAgentsProvider.java index 433f83667..2d1ea8024 100644 --- a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/GrpcAgentsProvider.java +++ b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/GrpcAgentsProvider.java @@ -28,7 +28,10 @@ public class GrpcAgentsProvider extends AbstractComposableAgentProvider { private static final Set SUPPORTED_AGENT_TYPES = - Set.of("experimental-python-source", "experimental-python-processor"); + Set.of( + "experimental-python-source", + "experimental-python-processor", + "experimental-python-sink"); public GrpcAgentsProvider() { super(SUPPORTED_AGENT_TYPES, List.of(KubernetesClusterRuntime.CLUSTER_TYPE, "none")); @@ -39,6 +42,7 @@ protected ComponentType getComponentType(AgentConfiguration agentConfiguration) return switch (agentConfiguration.getType()) { case "experimental-python-source" -> ComponentType.SOURCE; case "experimental-python-processor" -> ComponentType.PROCESSOR; + case "experimental-python-sink" -> ComponentType.SINK; default -> throw new IllegalStateException( "Unexpected agent type: " + agentConfiguration.getType()); };