diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java index ddcbb9518..c62aef757 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java @@ -17,6 +17,7 @@ import ai.langstream.api.runner.code.AbstractAgentCode; import ai.langstream.api.runner.code.SimpleRecord; +import ai.langstream.api.runner.topics.TopicProducer; import ai.langstream.api.util.ConfigurationUtils; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; @@ -24,13 +25,16 @@ import com.google.protobuf.ByteString; import com.google.protobuf.Empty; import io.grpc.ManagedChannel; +import io.grpc.stub.StreamObserver; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -65,6 +69,12 @@ abstract class AbstractGrpcAgent extends AbstractAgentCode { protected final AtomicBoolean restarting = new AtomicBoolean(false); @Getter protected volatile boolean startFailedButDevelopmentMode = false; + protected AgentServiceGrpc.AgentServiceStub asyncStub; + + protected CompletableFuture> + topicProducerWriteResults = CompletableFuture.completedFuture(null); + + private final Map topicProducers = new ConcurrentHashMap<>(); protected record GrpcAgentRecord( Long id, @@ -90,6 +100,93 @@ public void start() throws Exception { } blockingStub = AgentServiceGrpc.newBlockingStub(channel).withDeadlineAfter(30, TimeUnit.SECONDS); + asyncStub = AgentServiceGrpc.newStub(channel).withWaitForReady(); + + topicProducerWriteResults = new CompletableFuture<>(); + topicProducerWriteResults.complete( + asyncStub.getTopicProducerRecords( + new StreamObserver<>() { + @Override + public void onNext(TopicProducerRecord topicProducerRecord) { + TopicProducer topicProducer = + topicProducers.computeIfAbsent( + topicProducerRecord.getTopic(), + topic -> { + TopicProducer tp = + agentContext + .getTopicConnectionProvider() + .createProducer( + agentContext + .getGlobalAgentId(), + topic, + Map.of()); + tp.start(); + return tp; + }); + try { + topicProducer + .write(fromGrpc(topicProducerRecord.getRecord())) + .whenComplete( + (r, e) -> { + if (e != null) { + log.error("Error writing record", e); + sendTopicProducerWriteResult( + TopicProducerWriteResult + .newBuilder() + .setError( + e + .getMessage())); + } else { + sendTopicProducerWriteResult( + TopicProducerWriteResult + .newBuilder() + .setRecordId( + topicProducerRecord + .getRecord() + .getRecordId())); + } + }); + } catch (IOException e) { + agentContext.criticalFailure(e); + } + } + + @Override + public void onError(Throwable throwable) { + if (!restarting.get()) { + agentContext.criticalFailure( + new RuntimeException( + "getTopicProducerRecords: gRPC server sent error: %s" + .formatted(throwable.getMessage()), + throwable)); + } else { + log.info( + "getTopicProducerRecords: ignoring error during restart {}", + throwable + ""); + } + } + + @Override + public void onCompleted() { + if (!restarting.get()) { + agentContext.criticalFailure( + new RuntimeException( + "getTopicProducerRecords: gRPC server completed the stream unexpectedly")); + } else { + log.info( + "getTopicProducerRecords: ignoring error server stop during restart"); + } + } + })); + } + + private synchronized void sendTopicProducerWriteResult( + TopicProducerWriteResult.Builder result) { + try { + topicProducerWriteResults.get().onNext(result.build()); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } } @Override @@ -103,6 +200,16 @@ protected Map buildAdditionalInfo() { } protected synchronized void stopBeforeRestart() throws Exception { + restarting.set(true); + StreamObserver topicProducerWriteResultStreamObserver = + topicProducerWriteResults.get(); + if (topicProducerWriteResultStreamObserver != null) { + try { + topicProducerWriteResultStreamObserver.onCompleted(); + } catch (IllegalStateException e) { + log.info("Ignoring error while stopping {}", e + ""); + } + } stopChannel(false); } @@ -123,6 +230,10 @@ public void stopChannel(boolean wait) throws Exception { public synchronized void close() throws Exception { stopBeforeRestart(); stopChannel(true); + for (TopicProducer topicProducer : topicProducers.values()) { + topicProducer.close(); + } + topicProducers.clear(); } protected Object fromGrpc(Value value) throws IOException { diff --git a/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto b/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto index da376e3c4..5edc60a25 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto +++ b/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto @@ -26,6 +26,7 @@ service AgentService { rpc read(stream SourceRequest) returns (stream SourceResponse) {} rpc process(stream ProcessorRequest) returns (stream ProcessorResponse) {} rpc write(stream SinkRequest) returns (stream SinkResponse) {} + rpc get_topic_producer_records(stream TopicProducerWriteResult) returns (stream TopicProducerRecord) {} } message InfoResponse { @@ -68,6 +69,16 @@ message Record { optional int64 timestamp = 6; } +message TopicProducerWriteResult { + int64 record_id = 1; + optional string error = 2; +} + +message TopicProducerRecord { + string topic = 1; + Record record = 3; +} + message PermanentFailure { int64 record_id = 1; string error_message = 2; diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/AbstractGrpcAgentTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/AbstractGrpcAgentTest.java new file mode 100644 index 000000000..c05b23755 --- /dev/null +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/AbstractGrpcAgentTest.java @@ -0,0 +1,278 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.agents.grpc; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import ai.langstream.api.runner.code.AgentContext; +import ai.langstream.api.runner.code.Record; +import ai.langstream.api.runner.topics.TopicAdmin; +import ai.langstream.api.runner.topics.TopicConnectionProvider; +import ai.langstream.api.runner.topics.TopicConsumer; +import ai.langstream.api.runner.topics.TopicProducer; +import com.google.protobuf.Empty; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.Status; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.StreamObserver; +import java.nio.file.Path; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +@Slf4j +public class AbstractGrpcAgentTest { + + private Server server; + private ManagedChannel channel; + private GrpcAgentProcessor processor; + + private final AgentServiceGrpc.AgentServiceImplBase testProcessorService = + new AgentServiceGrpc.AgentServiceImplBase() { + + @Override + public void agentInfo( + Empty request, StreamObserver responseObserver) { + responseObserver.onNext( + InfoResponse.newBuilder() + .setJsonInfo("{\"test-info-key\": \"test-info-value\"}") + .build()); + responseObserver.onCompleted(); + } + + @Override + public StreamObserver getTopicProducerRecords( + StreamObserver responseObserver) { + responseObserver.onNext( + TopicProducerRecord.newBuilder() + .setTopic("test-topic") + .setRecord( + ai.langstream.agents.grpc.Record.newBuilder() + .setRecordId(42) + .setValue( + Value.newBuilder() + .setStringValue("test-value1"))) + .build()); + return new StreamObserver<>() { + + @Override + public void onNext(TopicProducerWriteResult topicProducerWriteResult) { + if (topicProducerWriteResult.getError().equals("test-error")) { + responseObserver.onError( + Status.INTERNAL + .withDescription("test-error") + .asRuntimeException()); + } else if (topicProducerWriteResult + .getError() + .equals("test-complete")) { + responseObserver.onCompleted(); + } else if (topicProducerWriteResult.getRecordId() == 42) { + responseObserver.onNext( + TopicProducerRecord.newBuilder() + .setTopic("test-topic") + .setRecord( + ai.langstream.agents.grpc.Record + .newBuilder() + .setRecordId(43) + .setValue( + Value.newBuilder() + .setStringValue( + "test-value2"))) + .build()); + } + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + + @Override + public StreamObserver process( + StreamObserver responseObserver) { + return new StreamObserver<>() { + @Override + public void onNext(ProcessorRequest processorRequest) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + }; + + @BeforeEach + public void setUp() throws Exception { + String serverName = InProcessServerBuilder.generateName(); + server = + InProcessServerBuilder.forName(serverName) + .addService(testProcessorService) + .build() + .start(); + + channel = InProcessChannelBuilder.forName(serverName).build(); + processor = new GrpcAgentProcessor(channel); + } + + @AfterEach + public void tearDown() throws Exception { + processor.close(); + channel.shutdownNow(); + server.shutdownNow(); + channel.awaitTermination(30, TimeUnit.SECONDS); + server.awaitTermination(30, TimeUnit.SECONDS); + } + + @Test + void testInfo() throws Exception { + startProcessor(new TestAgentContextSuccessful()); + Map info = processor.buildAdditionalInfo(); + assertEquals("test-info-value", info.get("test-info-key")); + } + + @Test + void testTopicProducerSuccess() throws Exception { + TestAgentContext context = new TestAgentContextSuccessful(); + startProcessor(context); + LinkedBlockingQueue recordsToWrite = context.recordsToWrite; + assertEquals("test-value1", recordsToWrite.poll(5, TimeUnit.SECONDS).value().toString()); + assertEquals("test-value2", recordsToWrite.poll(5, TimeUnit.SECONDS).value().toString()); + } + + @Test + void testTopicProducerError() throws Exception { + TestAgentContext context = new TestAgentContextFailure(); + startProcessor(context); + assertEquals( + "getTopicProducerRecords: gRPC server sent error: INTERNAL: test-error", + context.failure.get(15, TimeUnit.SECONDS).getMessage()); + } + + @Test + void testTopicProducerComplete() throws Exception { + TestAgentContextCompleting context = new TestAgentContextCompleting(); + startProcessor(context); + assertEquals( + "getTopicProducerRecords: gRPC server completed the stream unexpectedly", + context.failure.get(5, TimeUnit.SECONDS).getMessage()); + } + + private void startProcessor(AgentContext context) throws Exception { + processor.setContext(context); + processor.start(); + } + + abstract static class TestAgentContext implements AgentContext { + protected final LinkedBlockingQueue recordsToWrite = new LinkedBlockingQueue<>(); + protected final CompletableFuture failure = new CompletableFuture<>(); + + @Override + public TopicConsumer getTopicConsumer() { + return null; + } + + @Override + public TopicProducer getTopicProducer() { + return null; + } + + @Override + public String getGlobalAgentId() { + return null; + } + + @Override + public TopicAdmin getTopicAdmin() { + return null; + } + + protected abstract CompletableFuture writeRecord(Record record); + + @Override + public TopicConnectionProvider getTopicConnectionProvider() { + return new TopicConnectionProvider() { + @Override + public TopicProducer createProducer( + String agentId, String topic, Map config) { + return new TopicProducer() { + @Override + public CompletableFuture write( + ai.langstream.api.runner.code.Record record) { + if (topic.equals("test-topic")) { + return writeRecord(record); + } + return CompletableFuture.completedFuture(null); + } + + @Override + public long getTotalIn() { + return 0; + } + }; + } + }; + } + + @Override + public void criticalFailure(Throwable error) { + log.info("TestAgentContext critical failure", error); + failure.complete(error); + } + + @Override + public Path getCodeDirectory() { + return null; + } + } + + static class TestAgentContextSuccessful extends TestAgentContext { + @Override + protected CompletableFuture writeRecord(Record record) { + recordsToWrite.add(record); + return CompletableFuture.completedFuture(null); + } + } + + static class TestAgentContextFailure extends TestAgentContext { + @Override + protected CompletableFuture writeRecord(Record record) { + return CompletableFuture.failedFuture(new RuntimeException("test-error")); + } + } + + static class TestAgentContextCompleting extends TestAgentContext { + @Override + protected CompletableFuture writeRecord(Record record) { + return CompletableFuture.failedFuture(new RuntimeException("test-complete")); + } + } +} diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java index 5be658d9d..cf8994324 100644 --- a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java @@ -16,7 +16,6 @@ package ai.langstream.agents.grpc; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertSame; @@ -31,7 +30,6 @@ import ai.langstream.api.runner.topics.TopicConsumer; import ai.langstream.api.runner.topics.TopicProducer; import com.google.protobuf.ByteString; -import com.google.protobuf.Empty; import io.grpc.ManagedChannel; import io.grpc.Server; import io.grpc.Status; @@ -41,9 +39,7 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.util.List; -import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -57,8 +53,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.MethodSource; -import org.junit.jupiter.params.provider.ValueSource; public class GrpcAgentProcessorTest { private Server server; @@ -70,16 +66,6 @@ public class GrpcAgentProcessorTest { private final AgentServiceGrpc.AgentServiceImplBase testProcessorService = new AgentServiceGrpc.AgentServiceImplBase() { - @Override - public void agentInfo( - Empty request, StreamObserver responseObserver) { - responseObserver.onNext( - InfoResponse.newBuilder() - .setJsonInfo("{\"test-info-key\": \"test-info-value\"}") - .build()); - responseObserver.onCompleted(); - } - @Override public StreamObserver process( StreamObserver response) { @@ -135,6 +121,23 @@ public void onCompleted() { } }; } + + @Override + public StreamObserver getTopicProducerRecords( + StreamObserver responseObserver) { + return new StreamObserver<>() { + @Override + public void onNext(TopicProducerWriteResult topicProducerWriteResult) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } }; @BeforeEach @@ -198,7 +201,6 @@ void testProcess(Object value, Object key, Object header) throws Exception { @Test void testEmpty() throws Exception { assertProcessSuccessful(processor, SimpleRecord.builder().build()); - assertFalse(context.failureCalled.await(1, TimeUnit.SECONDS)); } @Test @@ -222,14 +224,18 @@ void testFailingRecord() throws Exception { } @ParameterizedTest - @ValueSource( - strings = {"failing-server", "completing-server", "wrong-record-id", "wrong-schema-id"}) - void testServerError(String origin) throws Exception { + @CsvSource({ + "failing-server,gRPC server sent error: INTERNAL: server error", + "completing-server,gRPC server completed the stream unexpectedly", + "wrong-record-id,Received unknown record id 2", + "wrong-schema-id,Error while processing record 1: Unknown schema id 1" + }) + void testServerError(String origin, String error) throws Exception { Record inputRecord = SimpleRecord.builder().origin(origin).build(); processor.process(List.of(inputRecord), result -> {}); - assertTrue(context.failureCalled.await(1, TimeUnit.SECONDS)); + assertEquals(error, context.failure.get(1, TimeUnit.SECONDS).getMessage()); } @Test @@ -256,12 +262,6 @@ void testAvroAndSchema() throws Exception { assertEquals(1, schemaCounter.get()); } - @Test - void testInfo() throws Exception { - Map info = processor.buildAdditionalInfo(); - assertEquals("test-info-value", info.get("test-info-key")); - } - private static void assertProcessSuccessful(GrpcAgentProcessor processor, Record inputRecord) throws ExecutionException, InterruptedException, TimeoutException { CompletableFuture op = new CompletableFuture<>(); @@ -304,7 +304,7 @@ private static void assertValueEquals(Object expected, Object actual) { static class TestAgentContext implements AgentContext { - private final CountDownLatch failureCalled = new CountDownLatch(1); + private final CompletableFuture failure = new CompletableFuture<>(); @Override public TopicConsumer getTopicConsumer() { @@ -333,7 +333,7 @@ public TopicConnectionProvider getTopicConnectionProvider() { @Override public void criticalFailure(Throwable error) { - failureCalled.countDown(); + failure.complete(error); } @Override diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSinkTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSinkTest.java index d5dec8b11..7269ff858 100644 --- a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSinkTest.java +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSinkTest.java @@ -16,7 +16,6 @@ package ai.langstream.agents.grpc; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; import ai.langstream.api.runner.code.AgentContext; import ai.langstream.api.runner.code.SimpleRecord; @@ -32,8 +31,8 @@ import java.io.IOException; import java.nio.file.Path; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -95,13 +94,17 @@ void testWriteError() throws Exception { @Test void testSinkGrpcError() throws Exception { sink.write(SimpleRecord.builder().origin("failing-server").build()); - assertTrue(context.failureCalled.await(1, TimeUnit.SECONDS)); + assertEquals( + "gRPC server sent error: UNKNOWN", + context.failure.get(1, TimeUnit.SECONDS).getMessage()); } @Test void testSinkGrpcCompletedUnexpectedly() throws Exception { sink.write(SimpleRecord.builder().origin("completing-server").build()); - assertTrue(context.failureCalled.await(1, TimeUnit.SECONDS)); + assertEquals( + "gRPC server completed the stream unexpectedly", + context.failure.get(1, TimeUnit.SECONDS).getMessage()); } @Test @@ -181,6 +184,23 @@ public void onCompleted() { } }; } + + @Override + public StreamObserver getTopicProducerRecords( + StreamObserver responseObserver) { + return new StreamObserver<>() { + @Override + public void onNext(TopicProducerWriteResult topicProducerWriteResult) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } } private static GenericRecord deserializeGenericRecord( @@ -193,7 +213,7 @@ private static GenericRecord deserializeGenericRecord( static class TestAgentContext implements AgentContext { - private final CountDownLatch failureCalled = new CountDownLatch(1); + private final CompletableFuture failure = new CompletableFuture<>(); @Override public TopicConsumer getTopicConsumer() { @@ -222,7 +242,7 @@ public TopicConnectionProvider getTopicConnectionProvider() { @Override public void criticalFailure(Throwable error) { - failureCalled.countDown(); + failure.complete(error); } @Override diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java index 8b1a6f2aa..74159cb0c 100644 --- a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java @@ -17,8 +17,6 @@ import static org.awaitility.Awaitility.await; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; import ai.langstream.api.runner.code.AgentContext; import ai.langstream.api.runner.code.Record; @@ -37,8 +35,8 @@ import java.nio.file.Path; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import org.apache.avro.Conversions; import org.apache.avro.generic.GenericData; @@ -76,6 +74,7 @@ public void setUp() throws Exception { @AfterEach public void tearDown() throws Exception { + source.close(); channel.shutdownNow(); server.shutdownNow(); channel.awaitTermination(30, TimeUnit.SECONDS); @@ -86,26 +85,26 @@ public void tearDown() throws Exception { void testCommit() throws Exception { List read = readRecords(source, 3); source.commit(List.of(read.get(0))); - assertFalse(context.failureCalled.await(1, TimeUnit.SECONDS)); assertEquals(1, testSourceService.committedRecords.size()); assertEquals(42, testSourceService.committedRecords.get(0)); - source.close(); } @Test void testSourceGrpcError() throws Exception { List read = readRecords(source, 3); source.commit(List.of(read.get(1))); - assertTrue(context.failureCalled.await(1, TimeUnit.SECONDS)); - source.close(); + assertEquals( + "gRPC server sent error: UNKNOWN", + context.failure.get(1, TimeUnit.SECONDS).getMessage()); } @Test void testSourceGrpcCompletedUnexpectedly() throws Exception { List read = readRecords(source, 3); source.commit(List.of(read.get(2))); - assertTrue(context.failureCalled.await(1, TimeUnit.SECONDS)); - source.close(); + assertEquals( + "gRPC server completed the stream unexpectedly", + context.failure.get(1, TimeUnit.SECONDS).getMessage()); } @Test @@ -113,7 +112,6 @@ void testAvroAndSchema() throws Exception { List read = readRecords(source, 1); GenericRecord record = (GenericRecord) read.get(0).value(); assertEquals("test-string", record.get("testField").toString()); - source.close(); } @Test @@ -122,7 +120,6 @@ void testPermanentFailure() throws Exception { source.permanentFailure(read.get(0), new RuntimeException("permanent-failure")); assertEquals(testSourceService.permanentFailure.getRecordId(), 42); assertEquals(testSourceService.permanentFailure.getErrorMessage(), "permanent-failure"); - source.close(); } static List readRecords(GrpcAgentSource source, int numberOfRecords) { @@ -219,11 +216,28 @@ public void onCompleted() { } }; } + + @Override + public StreamObserver getTopicProducerRecords( + StreamObserver responseObserver) { + return new StreamObserver<>() { + @Override + public void onNext(TopicProducerWriteResult topicProducerWriteResult) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } } static class TestAgentContext implements AgentContext { - private final CountDownLatch failureCalled = new CountDownLatch(1); + private final CompletableFuture failure = new CompletableFuture<>(); @Override public TopicConsumer getTopicConsumer() { @@ -252,7 +266,7 @@ public TopicConnectionProvider getTopicConnectionProvider() { @Override public void criticalFailure(Throwable error) { - failureCalled.countDown(); + failure.complete(error); } @Override diff --git a/langstream-e2e-tests/src/test/java/ai/langstream/tests/PythonAgentsIT.java b/langstream-e2e-tests/src/test/java/ai/langstream/tests/PythonAgentsIT.java index 8a5706a98..c328ea104 100644 --- a/langstream-e2e-tests/src/test/java/ai/langstream/tests/PythonAgentsIT.java +++ b/langstream-e2e-tests/src/test/java/ai/langstream/tests/PythonAgentsIT.java @@ -31,7 +31,7 @@ public class PythonAgentsIT extends BaseEndToEndTest { @Test - public void testProcessor() throws Exception { + public void testProcessor() { installLangStreamCluster(true); final String tenant = "ten-" + System.currentTimeMillis(); setupTenant(tenant); diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py index edbf6da4e..a12d5772d 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py @@ -82,6 +82,11 @@ def agent_info(self, _, __): info = call_method_if_exists(self.agent, "agent_info") or {} return InfoResponse(json_info=json.dumps(info)) + def get_topic_producer_records(self, request_iterator, context): + # TODO: to be implementedbla + for _ in request_iterator: + yield None + def read(self, requests: Iterable[SourceRequest], _): read_records = {} op_result = [] diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py index 6889bb9e3..ee36bdb57 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py @@ -32,7 +32,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n!langstream_grpc/proto/agent.proto\x1a\x1bgoogle/protobuf/empty.proto"!\n\x0cInfoResponse\x12\x11\n\tjson_info\x18\x01 \x01(\t"\xa3\x02\n\x05Value\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\x15\n\x0b\x62ytes_value\x18\x02 \x01(\x0cH\x00\x12\x17\n\rboolean_value\x18\x03 \x01(\x08H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x14\n\nbyte_value\x18\x05 \x01(\x05H\x00\x12\x15\n\x0bshort_value\x18\x06 \x01(\x05H\x00\x12\x13\n\tint_value\x18\x07 \x01(\x05H\x00\x12\x14\n\nlong_value\x18\x08 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\t \x01(\x02H\x00\x12\x16\n\x0c\x64ouble_value\x18\n \x01(\x01H\x00\x12\x14\n\njson_value\x18\x0b \x01(\tH\x00\x12\x14\n\navro_value\x18\x0c \x01(\x0cH\x00\x42\x0c\n\ntype_oneof"-\n\x06Header\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x05value\x18\x02 \x01(\x0b\x32\x06.Value"*\n\x06Schema\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x0c"\xb3\x01\n\x06Record\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x18\n\x03key\x18\x02 \x01(\x0b\x32\x06.ValueH\x00\x88\x01\x01\x12\x1a\n\x05value\x18\x03 \x01(\x0b\x32\x06.ValueH\x01\x88\x01\x01\x12\x18\n\x07headers\x18\x04 \x03(\x0b\x32\x07.Header\x12\x0e\n\x06origin\x18\x05 \x01(\t\x12\x16\n\ttimestamp\x18\x06 \x01(\x03H\x02\x88\x01\x01\x42\x06\n\x04_keyB\x08\n\x06_valueB\x0c\n\n_timestamp"<\n\x10PermanentFailure\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x15\n\rerror_message\x18\x02 \x01(\t"X\n\rSourceRequest\x12\x19\n\x11\x63ommitted_records\x18\x01 \x03(\x03\x12,\n\x11permanent_failure\x18\x02 \x01(\x0b\x32\x11.PermanentFailure"C\n\x0eSourceResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"E\n\x10ProcessorRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"O\n\x11ProcessorResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12!\n\x07results\x18\x02 \x03(\x0b\x32\x10.ProcessorResult"\\\n\x0fProcessorResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x07records\x18\x03 \x03(\x0b\x32\x07.RecordB\x08\n\x06_error"?\n\x0bSinkRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x17\n\x06record\x18\x02 \x01(\x0b\x32\x07.Record"?\n\x0cSinkResponse\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error2\xd8\x01\n\x0c\x41gentService\x12\x35\n\nagent_info\x12\x16.google.protobuf.Empty\x1a\r.InfoResponse"\x00\x12-\n\x04read\x12\x0e.SourceRequest\x1a\x0f.SourceResponse"\x00(\x01\x30\x01\x12\x36\n\x07process\x12\x11.ProcessorRequest\x1a\x12.ProcessorResponse"\x00(\x01\x30\x01\x12*\n\x05write\x12\x0c.SinkRequest\x1a\r.SinkResponse"\x00(\x01\x30\x01\x42\x1d\n\x19\x61i.langstream.agents.grpcP\x01\x62\x06proto3' + b'\n!langstream_grpc/proto/agent.proto\x1a\x1bgoogle/protobuf/empty.proto"!\n\x0cInfoResponse\x12\x11\n\tjson_info\x18\x01 \x01(\t"\xa3\x02\n\x05Value\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\x15\n\x0b\x62ytes_value\x18\x02 \x01(\x0cH\x00\x12\x17\n\rboolean_value\x18\x03 \x01(\x08H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x14\n\nbyte_value\x18\x05 \x01(\x05H\x00\x12\x15\n\x0bshort_value\x18\x06 \x01(\x05H\x00\x12\x13\n\tint_value\x18\x07 \x01(\x05H\x00\x12\x14\n\nlong_value\x18\x08 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\t \x01(\x02H\x00\x12\x16\n\x0c\x64ouble_value\x18\n \x01(\x01H\x00\x12\x14\n\njson_value\x18\x0b \x01(\tH\x00\x12\x14\n\navro_value\x18\x0c \x01(\x0cH\x00\x42\x0c\n\ntype_oneof"-\n\x06Header\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x05value\x18\x02 \x01(\x0b\x32\x06.Value"*\n\x06Schema\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x0c"\xb3\x01\n\x06Record\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x18\n\x03key\x18\x02 \x01(\x0b\x32\x06.ValueH\x00\x88\x01\x01\x12\x1a\n\x05value\x18\x03 \x01(\x0b\x32\x06.ValueH\x01\x88\x01\x01\x12\x18\n\x07headers\x18\x04 \x03(\x0b\x32\x07.Header\x12\x0e\n\x06origin\x18\x05 \x01(\t\x12\x16\n\ttimestamp\x18\x06 \x01(\x03H\x02\x88\x01\x01\x42\x06\n\x04_keyB\x08\n\x06_valueB\x0c\n\n_timestamp"K\n\x18TopicProducerWriteResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error"=\n\x13TopicProducerRecord\x12\r\n\x05topic\x18\x01 \x01(\t\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record"<\n\x10PermanentFailure\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x15\n\rerror_message\x18\x02 \x01(\t"X\n\rSourceRequest\x12\x19\n\x11\x63ommitted_records\x18\x01 \x03(\x03\x12,\n\x11permanent_failure\x18\x02 \x01(\x0b\x32\x11.PermanentFailure"C\n\x0eSourceResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"E\n\x10ProcessorRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"O\n\x11ProcessorResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12!\n\x07results\x18\x02 \x03(\x0b\x32\x10.ProcessorResult"\\\n\x0fProcessorResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x07records\x18\x03 \x03(\x0b\x32\x07.RecordB\x08\n\x06_error"?\n\x0bSinkRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x17\n\x06record\x18\x02 \x01(\x0b\x32\x07.Record"?\n\x0cSinkResponse\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error2\xad\x02\n\x0c\x41gentService\x12\x35\n\nagent_info\x12\x16.google.protobuf.Empty\x1a\r.InfoResponse"\x00\x12-\n\x04read\x12\x0e.SourceRequest\x1a\x0f.SourceResponse"\x00(\x01\x30\x01\x12\x36\n\x07process\x12\x11.ProcessorRequest\x1a\x12.ProcessorResponse"\x00(\x01\x30\x01\x12*\n\x05write\x12\x0c.SinkRequest\x1a\r.SinkResponse"\x00(\x01\x30\x01\x12S\n\x1aget_topic_producer_records\x12\x19.TopicProducerWriteResult\x1a\x14.TopicProducerRecord"\x00(\x01\x30\x01\x42\x1d\n\x19\x61i.langstream.agents.grpcP\x01\x62\x06proto3' ) _globals = globals() @@ -53,22 +53,26 @@ _globals["_SCHEMA"]._serialized_end = 484 _globals["_RECORD"]._serialized_start = 487 _globals["_RECORD"]._serialized_end = 666 - _globals["_PERMANENTFAILURE"]._serialized_start = 668 - _globals["_PERMANENTFAILURE"]._serialized_end = 728 - _globals["_SOURCEREQUEST"]._serialized_start = 730 - _globals["_SOURCEREQUEST"]._serialized_end = 818 - _globals["_SOURCERESPONSE"]._serialized_start = 820 - _globals["_SOURCERESPONSE"]._serialized_end = 887 - _globals["_PROCESSORREQUEST"]._serialized_start = 889 - _globals["_PROCESSORREQUEST"]._serialized_end = 958 - _globals["_PROCESSORRESPONSE"]._serialized_start = 960 - _globals["_PROCESSORRESPONSE"]._serialized_end = 1039 - _globals["_PROCESSORRESULT"]._serialized_start = 1041 - _globals["_PROCESSORRESULT"]._serialized_end = 1133 - _globals["_SINKREQUEST"]._serialized_start = 1135 - _globals["_SINKREQUEST"]._serialized_end = 1198 - _globals["_SINKRESPONSE"]._serialized_start = 1200 - _globals["_SINKRESPONSE"]._serialized_end = 1263 - _globals["_AGENTSERVICE"]._serialized_start = 1266 - _globals["_AGENTSERVICE"]._serialized_end = 1482 + _globals["_TOPICPRODUCERWRITERESULT"]._serialized_start = 668 + _globals["_TOPICPRODUCERWRITERESULT"]._serialized_end = 743 + _globals["_TOPICPRODUCERRECORD"]._serialized_start = 745 + _globals["_TOPICPRODUCERRECORD"]._serialized_end = 806 + _globals["_PERMANENTFAILURE"]._serialized_start = 808 + _globals["_PERMANENTFAILURE"]._serialized_end = 868 + _globals["_SOURCEREQUEST"]._serialized_start = 870 + _globals["_SOURCEREQUEST"]._serialized_end = 958 + _globals["_SOURCERESPONSE"]._serialized_start = 960 + _globals["_SOURCERESPONSE"]._serialized_end = 1027 + _globals["_PROCESSORREQUEST"]._serialized_start = 1029 + _globals["_PROCESSORREQUEST"]._serialized_end = 1098 + _globals["_PROCESSORRESPONSE"]._serialized_start = 1100 + _globals["_PROCESSORRESPONSE"]._serialized_end = 1179 + _globals["_PROCESSORRESULT"]._serialized_start = 1181 + _globals["_PROCESSORRESULT"]._serialized_end = 1273 + _globals["_SINKREQUEST"]._serialized_start = 1275 + _globals["_SINKREQUEST"]._serialized_end = 1338 + _globals["_SINKRESPONSE"]._serialized_start = 1340 + _globals["_SINKRESPONSE"]._serialized_end = 1403 + _globals["_AGENTSERVICE"]._serialized_start = 1406 + _globals["_AGENTSERVICE"]._serialized_end = 1707 # @@protoc_insertion_point(module_scope) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi index 1fc887ddb..8ae844d8c 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi @@ -119,6 +119,28 @@ class Record(_message.Message): timestamp: _Optional[int] = ..., ) -> None: ... +class TopicProducerWriteResult(_message.Message): + __slots__ = ["record_id", "error"] + RECORD_ID_FIELD_NUMBER: _ClassVar[int] + ERROR_FIELD_NUMBER: _ClassVar[int] + record_id: int + error: str + def __init__( + self, record_id: _Optional[int] = ..., error: _Optional[str] = ... + ) -> None: ... + +class TopicProducerRecord(_message.Message): + __slots__ = ["topic", "record"] + TOPIC_FIELD_NUMBER: _ClassVar[int] + RECORD_FIELD_NUMBER: _ClassVar[int] + topic: str + record: Record + def __init__( + self, + topic: _Optional[str] = ..., + record: _Optional[_Union[Record, _Mapping]] = ..., + ) -> None: ... + class PermanentFailure(_message.Message): __slots__ = ["record_id", "error_message"] RECORD_ID_FIELD_NUMBER: _ClassVar[int] diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2_grpc.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2_grpc.py index fd1caf989..acffce84c 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2_grpc.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2_grpc.py @@ -51,6 +51,11 @@ def __init__(self, channel): request_serializer=langstream__grpc_dot_proto_dot_agent__pb2.SinkRequest.SerializeToString, response_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.SinkResponse.FromString, ) + self.get_topic_producer_records = channel.stream_stream( + "/AgentService/get_topic_producer_records", + request_serializer=langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerWriteResult.SerializeToString, + response_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerRecord.FromString, + ) class AgentServiceServicer(object): @@ -80,6 +85,12 @@ def write(self, request_iterator, context): context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def get_topic_producer_records(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def add_AgentServiceServicer_to_server(servicer, server): rpc_method_handlers = { @@ -103,6 +114,11 @@ def add_AgentServiceServicer_to_server(servicer, server): request_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.SinkRequest.FromString, response_serializer=langstream__grpc_dot_proto_dot_agent__pb2.SinkResponse.SerializeToString, ), + "get_topic_producer_records": grpc.stream_stream_rpc_method_handler( + servicer.get_topic_producer_records, + request_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerWriteResult.FromString, + response_serializer=langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerRecord.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( "AgentService", rpc_method_handlers @@ -229,3 +245,32 @@ def write( timeout, metadata, ) + + @staticmethod + def get_topic_producer_records( + request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.stream_stream( + request_iterator, + target, + "/AgentService/get_topic_producer_records", + langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerWriteResult.SerializeToString, + langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerRecord.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + )