diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java index 85f8f0652..2dbb7185c 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java @@ -33,11 +33,11 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; import lombok.Getter; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; @@ -72,8 +72,8 @@ abstract class AbstractGrpcAgent extends AbstractAgentCode { @Getter protected volatile boolean startFailedButDevelopmentMode = false; protected AgentServiceGrpc.AgentServiceStub asyncStub; - AtomicReference> topicProducerWriteResults = - new AtomicReference<>(); + protected CompletableFuture> + topicProducerWriteResults = CompletableFuture.completedFuture(null); private final Map topicProducers = new ConcurrentHashMap<>(); @@ -103,7 +103,8 @@ public void start() throws Exception { AgentServiceGrpc.newBlockingStub(channel).withDeadlineAfter(30, TimeUnit.SECONDS); asyncStub = AgentServiceGrpc.newStub(channel).withWaitForReady(); - topicProducerWriteResults.set( + topicProducerWriteResults = new CompletableFuture<>(); + topicProducerWriteResults.complete( asyncStub.getTopicProducerRecords( new StreamObserver<>() { @Override @@ -111,14 +112,18 @@ public void onNext(TopicProducerRecord topicProducerRecord) { TopicProducer topicProducer = topicProducers.computeIfAbsent( topicProducerRecord.getTopic(), - topic -> - agentContext - .getTopicConnectionProvider() - .createProducer( - agentContext - .getGlobalAgentId(), - topic, - Map.of())); + topic -> { + TopicProducer tp = + agentContext + .getTopicConnectionProvider() + .createProducer( + agentContext + .getGlobalAgentId(), + topic, + Map.of()); + tp.start(); + return tp; + }); try { topicProducer .write(fromGrpc(topicProducerRecord.getRecord())) @@ -149,13 +154,29 @@ public void onNext(TopicProducerRecord topicProducerRecord) { @Override public void onError(Throwable throwable) { - agentContext.criticalFailure(throwable); + if (!restarting.get()) { + agentContext.criticalFailure( + new RuntimeException( + "getTopicProducerRecords: gRPC server sent error: %s" + .formatted(throwable.getMessage()), + throwable)); + } else { + log.info( + "getTopicProducerRecords: ignoring error during restart {}", + throwable + ""); + } } @Override public void onCompleted() { - agentContext.criticalFailure( - new RuntimeException("Unexpected completion")); + if (!restarting.get()) { + agentContext.criticalFailure( + new RuntimeException( + "getTopicProducerRecords: gRPC server completed the stream unexpectedly")); + } else { + log.info( + "getTopicProducerRecords: ignoring error server stop during restart"); + } } })); } @@ -181,6 +202,16 @@ protected Map buildAdditionalInfo() { } protected synchronized void stopBeforeRestart() throws Exception { + restarting.set(true); + StreamObserver topicProducerWriteResultStreamObserver = + topicProducerWriteResults.get(); + if (topicProducerWriteResultStreamObserver != null) { + try { + topicProducerWriteResultStreamObserver.onCompleted(); + } catch (IllegalStateException e) { + log.info("Ignoring error while stopping {}", e + ""); + } + } stopChannel(false); } @@ -199,7 +230,6 @@ public void stopChannel(boolean wait) throws Exception { @Override public synchronized void close() throws Exception { - topicProducerWriteResults.get().onCompleted(); stopBeforeRestart(); stopChannel(true); for (TopicProducer topicProducer : topicProducers.values()) { diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/AbstractGrpcAgentTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/AbstractGrpcAgentTest.java index 62665945a..c05b23755 100644 --- a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/AbstractGrpcAgentTest.java +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/AbstractGrpcAgentTest.java @@ -35,10 +35,12 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +@Slf4j public class AbstractGrpcAgentTest { private Server server; @@ -171,7 +173,8 @@ void testTopicProducerError() throws Exception { TestAgentContext context = new TestAgentContextFailure(); startProcessor(context); assertEquals( - "INTERNAL: test-error", context.failure.get(15, TimeUnit.SECONDS).getMessage()); + "getTopicProducerRecords: gRPC server sent error: INTERNAL: test-error", + context.failure.get(15, TimeUnit.SECONDS).getMessage()); } @Test @@ -179,7 +182,8 @@ void testTopicProducerComplete() throws Exception { TestAgentContextCompleting context = new TestAgentContextCompleting(); startProcessor(context); assertEquals( - "Unexpected completion", context.failure.get(5, TimeUnit.SECONDS).getMessage()); + "getTopicProducerRecords: gRPC server completed the stream unexpectedly", + context.failure.get(5, TimeUnit.SECONDS).getMessage()); } private void startProcessor(AgentContext context) throws Exception { @@ -240,6 +244,7 @@ public long getTotalIn() { @Override public void criticalFailure(Throwable error) { + log.info("TestAgentContext critical failure", error); failure.complete(error); }