From 8f151de409f825bac988a4ec7c20d37a9a24de10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Boschi?= Date: Wed, 7 Aug 2024 14:02:58 +0200 Subject: [PATCH] feat: make service gateway response configurable (#121) * Now when the agent fails by default sets a new header "langstream-error-type" which has two possible values: INTERNAL_ERROR or INVALID_RECORD. The default is INTERNAL_ERROR. * When the service gateway gets the message it looks for this header and decide the status code (INVALID_RECORD -> 400, INTERNAL_ERROR -> 500). **This is a breaking change** since before that it was always returning 200. It also looks the following header until one it's not blank: langstream-error-message, langstream-error-cause-message, langstream-error-root-cause-message. * These new headers are sent to the deadletter along with the existing headers: error-msg, cause-msg, root-cause-msg (note that those headers are still present to not break compatibility but they're not checked by the gateway) * From the agent perspective, to set an error as INVALID_RECORD it can be done in this way: 1. In Java agents, you can emit the result with appropriate type 2. In Python you can raise an Exception in the sink or processor ``` def process(self, record): logging.info("Processing record" + str(record)) from langstream import InvalidRecordError raise InvalidRecordError("record was not ok:" + str(record)) ``` Any other exception will treat the error as INTERNAL_ERROR --- .github/workflows/ci.yml | 3 + docker/build.sh | 5 +- .../langstream/agents/camel/CamelSource.java | 8 +- .../agents/grpc/GrpcAgentProcessor.java | 10 +- .../agents/grpc/GrpcAgentSource.java | 10 +- .../proto/langstream_grpc/proto/agent.proto | 2 + .../agents/grpc/GrpcAgentSourceTest.java | 2 +- .../agents/pulsardlq/PulsarDLQSource.java | 84 ++++----- .../langstream/agents/flow/DispatchAgent.java | 1 - .../apigateway/gateways/ConsumeGateway.java | 19 +- .../apigateway/http/GatewayResource.java | 81 ++++++++- .../websocket/handlers/AbstractHandler.java | 14 +- .../apigateway/http/GatewayResourceTest.java | 107 ++++++++--- .../api/runner/code/AgentProcessor.java | 12 +- .../api/runner/code/AgentSource.java | 6 +- .../api/runner/code/ErrorTypes.java | 21 +++ .../api/runner/code/SystemHeaders.java | 47 +++++ .../java/ai/langstream/cli/LangStreamCLI.java | 5 +- .../ai/langstream/cli/api/model/Gateways.java | 1 + .../cli/commands/RootGatewayCmd.java | 8 +- .../cli/commands/gateway/BaseGatewayCmd.java | 17 +- .../commands/gateway/ServiceGatewayCmd.java | 136 ++++++++++++++ .../applications/GatewaysCmdTest.java | 36 ++++ .../java/ai/langstream/tests/PythonDLQIT.java | 69 ++++++++ .../tests/util/BaseEndToEndTest.java | 51 +++--- .../dlq-pipeline.yaml | 34 ++++ .../python-processor-with-dlq/gateways.yaml | 23 +++ .../python-processor-with-dlq/pipeline.yaml | 38 ++++ .../python/example.py | 26 +++ .../src/test/resources/secrets/secret1.yaml | 4 + .../agents/PulsarDLQSourceAgentProvider.java | 10 ++ .../langstream-runtime-impl/pom.xml | 36 +++- .../langstream/runtime/agent/AgentRunner.java | 27 ++- .../runtime/agent/TopicConsumerSource.java | 57 ++++-- .../src/main/python/langstream/__init__.py | 3 +- .../src/main/python/langstream/api.py | 2 +- .../src/main/python/langstream/util.py | 6 +- .../python/langstream_grpc/grpc_service.py | 5 + .../python/langstream_grpc/proto/agent_pb2.py | 77 ++++---- .../langstream_grpc/proto/agent_pb2.pyi | 42 +++-- .../langstream_grpc/proto/agent_pb2_grpc.py | 56 +++--- .../tests/test_grpc_processor.py | 21 +++ .../src/main/python/langstream_grpc/util.py | 6 +- .../main/python/scripts/generate-grpc-code.sh | 26 +++ ...HandlingTest.java => ErrorHandlingIT.java} | 95 +++++++--- ...lAgentsTest.java => StatefulAgentsIT.java} | 6 +- .../langstream/agents/WebCrawlerSourceIT.java | 3 + .../langstream/pulsar/PulsarDLQSourceIT.java | 166 ++++++++++++++++++ .../runtime/agent/AgentRunnerTest.java | 108 ++++-------- .../AbstractApplicationRunner.java | 63 ++++++- requirements.txt | 3 - 51 files changed, 1338 insertions(+), 360 deletions(-) create mode 100644 langstream-api/src/main/java/ai/langstream/api/runner/code/ErrorTypes.java create mode 100644 langstream-api/src/main/java/ai/langstream/api/runner/code/SystemHeaders.java create mode 100644 langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ServiceGatewayCmd.java create mode 100644 langstream-e2e-tests/src/test/java/ai/langstream/tests/PythonDLQIT.java create mode 100644 langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/dlq-pipeline.yaml create mode 100644 langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/gateways.yaml create mode 100644 langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/pipeline.yaml create mode 100644 langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/python/example.py create mode 100755 langstream-runtime/langstream-runtime-impl/src/main/python/scripts/generate-grpc-code.sh rename langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/{ErrorHandlingTest.java => ErrorHandlingIT.java} (73%) rename langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/{StatefulAgentsTest.java => StatefulAgentsIT.java} (93%) create mode 100644 langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/pulsar/PulsarDLQSourceIT.java diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2b801e3c6..5522f9357 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -71,9 +71,12 @@ jobs: - name: Agents test_cmd: ./mvnw verify -f langstream-agents $MAVEN_COMMON_SKIP_FLAGS - name: Other + setup_python: "true" test_cmd: | exclude_modules="$(cd langstream-agents && ls -d langstream-* | sed 's/^/!:/g' | tr '\n' ',' | sed 's/,$//'),!langstream-agents,!langstream-webservice,!:langstream-api-gateway,!:langstream-k8s-deployer-operator,!:langstream-runtime-impl" ./mvnw verify -pl $exclude_modules $MAVEN_COMMON_SKIP_FLAGS + # python + unit tests for runtime-impl + ./mvnw package -pl ":langstream-runtime-impl" -ntp -Dspotless.skip -Dlicense.skip steps: - name: Free Disk Space (Ubuntu) diff --git a/docker/build.sh b/docker/build.sh index cba9d1bf8..990ce2c99 100755 --- a/docker/build.sh +++ b/docker/build.sh @@ -56,10 +56,13 @@ elif [ "$only_image" == "cli" ]; then build_docker_image langstream-cli elif [ "$only_image" == "api-gateway" ]; then build_docker_image langstream-api-gateway -else +elif [ "$only_image" == "" ]; then # Always clean to remove old NARs and cached docker images in the "target" directory ./mvnw clean install -Pdocker -Ddocker.platforms="$(docker_platforms)" $common_flags docker images | head -n 6 +else + echo "Unknown image type: $only_image. Valid values are: control-plane, operator, deployer, runtime-base-docker-image, runtime, runtime-tester, cli, api-gateway" + exit 1 fi diff --git a/langstream-agents/langstream-agent-camel/src/main/java/ai/langstream/agents/camel/CamelSource.java b/langstream-agents/langstream-agent-camel/src/main/java/ai/langstream/agents/camel/CamelSource.java index cb0dbe040..6469b6630 100644 --- a/langstream-agents/langstream-agent-camel/src/main/java/ai/langstream/agents/camel/CamelSource.java +++ b/langstream-agents/langstream-agent-camel/src/main/java/ai/langstream/agents/camel/CamelSource.java @@ -15,11 +15,8 @@ */ package ai.langstream.agents.camel; -import ai.langstream.api.runner.code.AbstractAgentCode; -import ai.langstream.api.runner.code.AgentSource; -import ai.langstream.api.runner.code.Header; +import ai.langstream.api.runner.code.*; import ai.langstream.api.runner.code.Record; -import ai.langstream.api.runner.code.SimpleRecord; import ai.langstream.api.util.ConfigurationUtils; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; @@ -240,7 +237,8 @@ public void commit(List records) throws Exception { } @Override - public void permanentFailure(Record record, Exception error) throws Exception { + public void permanentFailure(Record record, Exception error, ErrorTypes errorType) + throws Exception { CamelRecord camelRecord = (CamelRecord) record; log.info("Record {} failed", camelRecord); camelRecord.exchange.setException(error); diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentProcessor.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentProcessor.java index 9ee9433dd..070387fd6 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentProcessor.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentProcessor.java @@ -16,6 +16,7 @@ package ai.langstream.agents.grpc; import ai.langstream.api.runner.code.AgentProcessor; +import ai.langstream.api.runner.code.ErrorTypes; import ai.langstream.api.runner.code.RecordSink; import io.grpc.ManagedChannel; import io.grpc.stub.StreamObserver; @@ -112,9 +113,14 @@ private SourceRecordAndResult fromGrpc( throws IOException { List resultRecords = new ArrayList<>(); if (result.hasError()) { - // TODO: specialize exception ? + final ErrorTypes errorType; + if (result.hasErrorType()) { + errorType = ErrorTypes.valueOf(result.getErrorType().toUpperCase()); + } else { + errorType = ErrorTypes.INTERNAL_ERROR; + } return new SourceRecordAndResult( - sourceRecord, null, new RuntimeException(result.getError())); + sourceRecord, null, new RuntimeException(result.getError()), errorType); } for (Record record : result.getRecordsList()) { resultRecords.add(fromGrpc(record)); diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java index 65f672ab2..b1a0db9ec 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java @@ -16,6 +16,7 @@ package ai.langstream.agents.grpc; import ai.langstream.api.runner.code.AgentSource; +import ai.langstream.api.runner.code.ErrorTypes; import ai.langstream.api.runner.code.Record; import ai.langstream.api.util.ConfigurationUtils; import io.grpc.ManagedChannel; @@ -72,14 +73,19 @@ public List read() throws Exception { } @Override - public void permanentFailure(Record record, Exception error) { + public void permanentFailure(Record record, Exception error, ErrorTypes errorType) + throws Exception { if (record instanceof GrpcAgentRecord grpcAgentRecord) { request.onNext( SourceRequest.newBuilder() .setPermanentFailure( PermanentFailure.newBuilder() .setRecordId(grpcAgentRecord.id()) - .setErrorMessage(error.getMessage())) + .setErrorMessage(error.getMessage()) + .setErrorType( + errorType == null + ? ErrorTypes.INTERNAL_ERROR.toString() + : errorType.toString())) .build()); } else { throw new IllegalArgumentException( diff --git a/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto b/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto index 8c1afec9e..fc662d717 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto +++ b/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto @@ -83,6 +83,7 @@ message TopicProducerResponse { message PermanentFailure { int64 record_id = 1; string error_message = 2; + string error_type = 3; } message SourceRequest { @@ -110,6 +111,7 @@ message ProcessorResult { int64 record_id = 1; optional string error = 2; repeated Record records = 3; + optional string error_type = 4; } message SinkRequest { diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java index 3d27d4a75..1364380f7 100644 --- a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java @@ -119,7 +119,7 @@ void testAvroAndSchema() throws Exception { @Test void testPermanentFailure() throws Exception { List read = readRecords(source, 1); - source.permanentFailure(read.get(0), new RuntimeException("permanent-failure")); + source.permanentFailure(read.get(0), new RuntimeException("permanent-failure"), null); assertEquals(testSourceService.permanentFailure.getRecordId(), 42); assertEquals(testSourceService.permanentFailure.getErrorMessage(), "permanent-failure"); } diff --git a/langstream-agents/langstream-agent-pulsardlq/src/main/java/ai/langstream/agents/pulsardlq/PulsarDLQSource.java b/langstream-agents/langstream-agent-pulsardlq/src/main/java/ai/langstream/agents/pulsardlq/PulsarDLQSource.java index a36b55268..18587127c 100644 --- a/langstream-agents/langstream-agent-pulsardlq/src/main/java/ai/langstream/agents/pulsardlq/PulsarDLQSource.java +++ b/langstream-agents/langstream-agent-pulsardlq/src/main/java/ai/langstream/agents/pulsardlq/PulsarDLQSource.java @@ -15,15 +15,14 @@ */ package ai.langstream.agents.pulsardlq; -import ai.langstream.api.runner.code.AbstractAgentCode; -import ai.langstream.api.runner.code.AgentSource; -import ai.langstream.api.runner.code.Header; +import ai.langstream.api.runner.code.*; import ai.langstream.api.runner.code.Record; import ai.langstream.api.util.ConfigurationUtils; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.regex.Pattern; import lombok.extern.slf4j.Slf4j; import org.apache.pulsar.client.api.Consumer; @@ -40,41 +39,20 @@ public class PulsarDLQSource extends AbstractAgentCode implements AgentSource { private PulsarClient pulsarClient; private Consumer dlqTopicsConsumer; private boolean includePartitioned; - - private static class SimpleHeader implements Header { - private final String key; - private final Object value; - - public SimpleHeader(String key, Object value) { - this.key = key; - this.value = value; - } - - @Override - public String key() { - return key; - } - - @Override - public Object value() { - return value; - } - - @Override - public String valueAsString() { - if (value != null) { - return value.toString(); - } else { - return null; - } - } - } + private int timeoutMs; private static class PulsarRecord implements Record { private final Message message; + private final Collection
headers = new ArrayList<>(); public PulsarRecord(Message message) { this.message = message; + Map properties = message.getProperties(); + if (properties != null) { + for (Map.Entry entry : properties.entrySet()) { + headers.add(new SimpleRecord.SimpleHeader(entry.getKey(), entry.getValue())); + } + } } @Override @@ -103,14 +81,6 @@ public Long timestamp() { @Override public Collection
headers() { - Collection
headers = new ArrayList<>(); - Map properties = message.getProperties(); - - if (properties != null) { - for (Map.Entry entry : properties.entrySet()) { - headers.add(new SimpleHeader(entry.getKey(), entry.getValue())); - } - } return headers; } @@ -135,6 +105,7 @@ public void init(Map configuration) throws Exception { dlqSuffix = ConfigurationUtils.getString("dlq-suffix", "-DLQ", configuration); includePartitioned = ConfigurationUtils.getBoolean("include-partitioned", false, configuration); + timeoutMs = ConfigurationUtils.getInt("timeout-ms", 0, configuration); log.info("Initializing PulsarDLQSource with pulsarUrl: {}", pulsarUrl); log.info("Namespace: {}", namespace); log.info("Subscription: {}", subscription); @@ -144,10 +115,11 @@ public void init(Map configuration) throws Exception { @Override public void start() throws Exception { + log.info("Starting pulsar client {}", pulsarUrl); pulsarClient = PulsarClient.builder().serviceUrl(pulsarUrl).build(); - // The maximum lenth of the regex is 50 characters - // Uisng the persistent:// prefix generally works better, but - // it can push the partitioned pattern over the 50 character limit, so + // The maximum length of the regex is 50 characters + // Using the persistent:// prefix generally works better, but + // it can push the partitioned pattern over the 50 characters limit, so // we drop it for partitioned topics String patternString = "persistent://" + namespace + "/.*" + dlqSuffix; if (includePartitioned) { @@ -167,15 +139,26 @@ public void start() throws Exception { @Override public void close() throws Exception { super.close(); - dlqTopicsConsumer.close(); - pulsarClient.close(); + if (dlqTopicsConsumer != null) { + dlqTopicsConsumer.close(); + } + if (pulsarClient != null) { + pulsarClient.close(); + } } @Override public List read() throws Exception { - - Message msg = dlqTopicsConsumer.receive(); - + Message msg; + if (timeoutMs > 0) { + msg = dlqTopicsConsumer.receive(timeoutMs, TimeUnit.MILLISECONDS); + } else { + msg = dlqTopicsConsumer.receive(); + } + if (msg == null) { + log.debug("No message received"); + return List.of(); + } log.info("Received message: {}", new String(msg.getData())); Record record = new PulsarRecord(msg); return List.of(record); @@ -185,12 +168,13 @@ public List read() throws Exception { public void commit(List records) throws Exception { for (Record r : records) { PulsarRecord record = (PulsarRecord) r; - dlqTopicsConsumer.acknowledge(record.messageId()); // acknowledge the message + dlqTopicsConsumer.acknowledge(record.messageId()); } } @Override - public void permanentFailure(Record record, Exception error) throws Exception { + public void permanentFailure(Record record, Exception error, ErrorTypes errorType) + throws Exception { PulsarRecord pulsarRecord = (PulsarRecord) record; log.error("Failure on record {}", pulsarRecord, error); dlqTopicsConsumer.negativeAcknowledge(pulsarRecord.messageId()); diff --git a/langstream-agents/langstream-agents-flow-control/src/main/java/ai/langstream/agents/flow/DispatchAgent.java b/langstream-agents/langstream-agents-flow-control/src/main/java/ai/langstream/agents/flow/DispatchAgent.java index 197dafa48..55c911a57 100644 --- a/langstream-agents/langstream-agents-flow-control/src/main/java/ai/langstream/agents/flow/DispatchAgent.java +++ b/langstream-agents/langstream-agents-flow-control/src/main/java/ai/langstream/agents/flow/DispatchAgent.java @@ -98,7 +98,6 @@ public void start() throws Exception { @Override public void process(List records, RecordSink recordSink) { - log.info("got to process!{}", records); for (Record record : records) { processRecord(record, recordSink); } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java index 3d3103375..45a10e9d9 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/gateways/ConsumeGateway.java @@ -128,7 +128,10 @@ public void setup( } public void startReadingAsync( - Executor executor, Supplier stop, Consumer onMessage) { + Executor executor, + Supplier stop, + Consumer onMessage, + Consumer onError) { if (requestContext == null || reader == null) { throw new IllegalStateException("Not initialized"); } @@ -142,16 +145,23 @@ public void startReadingAsync( log.debug("[{}] Started reader", logRef); readMessages(stop, onMessage); } catch (Throwable ex) { - log.error("[{}] Error reading messages", logRef, ex); throw new RuntimeException(ex); } finally { closeReader(); } }, executor); + readerFuture.whenComplete( + (v, ex) -> { + if (ex != null) { + log.error("[{}] Error reading messages", logRef, ex); + onError.accept(ex); + } + }); } - private void readMessages(Supplier stop, Consumer onMessage) throws Exception { + private void readMessages(Supplier stop, Consumer onMessage) + throws Exception { while (true) { if (Thread.interrupted() || interrupted) { return; @@ -182,8 +192,7 @@ private void readMessages(Supplier stop, Consumer onMessage) th new ConsumePushMessage.Record( record.key(), record.value(), messageHeaders), offset); - final String jsonMessage = mapper.writeValueAsString(message); - onMessage.accept(jsonMessage); + onMessage.accept(message); } } } diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java index 7ae88001d..f5338ee51 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/http/GatewayResource.java @@ -17,16 +17,20 @@ import ai.langstream.api.gateway.GatewayRequestContext; import ai.langstream.api.model.Gateway; +import ai.langstream.api.runner.code.ErrorTypes; import ai.langstream.api.runner.code.Header; import ai.langstream.api.runner.code.Record; +import ai.langstream.api.runner.code.SystemHeaders; import ai.langstream.api.runtime.ClusterRuntimeRegistry; import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.api.ConsumePushMessage; import ai.langstream.apigateway.api.ProducePayload; import ai.langstream.apigateway.api.ProduceRequest; import ai.langstream.apigateway.api.ProduceResponse; import ai.langstream.apigateway.gateways.*; import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; +import com.fasterxml.jackson.databind.ObjectMapper; import jakarta.annotation.PreDestroy; import jakarta.servlet.http.HttpServletRequest; import jakarta.validation.constraints.NotBlank; @@ -46,11 +50,10 @@ import java.util.stream.Collectors; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.concurrent.BasicThreadFactory; import org.springframework.core.io.InputStreamResource; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import org.springframework.http.ResponseEntity; +import org.springframework.http.*; import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; @@ -69,9 +72,16 @@ @AllArgsConstructor public class GatewayResource { + private static final List RECORD_ERROR_HEADERS = + List.of( + SystemHeaders.ERROR_HANDLING_ERROR_MESSAGE.getKey(), + SystemHeaders.ERROR_HANDLING_CAUSE_ERROR_MESSAGE.getKey(), + SystemHeaders.ERROR_HANDLING_ROOT_CAUSE_ERROR_MESSAGE.getKey()); protected static final String GATEWAY_SERVICE_PATH = "/service/{tenant}/{application}/{gateway}/**"; - protected static final String SERVICE_REQUEST_ID_HEADER = "langstream-service-request-id"; + protected static final String SERVICE_REQUEST_ID_HEADER = + SystemHeaders.SERVICE_REQUEST_ID_HEADER.getKey(); + protected static final ObjectMapper mapper = new ObjectMapper(); private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider; private final ClusterRuntimeRegistry clusterRuntimeRegistry; private final TopicProducerCache topicProducerCache; @@ -299,7 +309,10 @@ private CompletableFuture handleServiceWithTopics( .getTopicConnectionsRuntimeRegistry(), clusterRuntimeRegistry, topicConnectionsRuntimeCache); - completableFuture.thenRunAsync(consumeGateway::close, consumeThreadPool); + completableFuture.whenComplete( + (r, t) -> { + consumeGateway.close(); + }); final Gateway.ServiceOptions serviceOptions = authContext.gateway().getServiceOptions(); try { @@ -323,8 +336,9 @@ record -> { stop::get, record -> { stop.set(true); - completableFuture.complete(ResponseEntity.ok(record)); - }); + completableFuture.complete(buildResponseFromReceivedRecord(record)); + }, + completableFuture::completeExceptionally); } catch (Exception ex) { log.error("Error while setting up consume gateway", ex); throw new RuntimeException(ex); @@ -348,6 +362,59 @@ record -> { return completableFuture; } + private static ResponseEntity buildResponseFromReceivedRecord( + ConsumePushMessage consumePushMessage) { + Objects.requireNonNull(consumePushMessage); + ConsumePushMessage.Record record = consumePushMessage.record(); + if (record != null && record.headers() != null) { + String errorType = + record.headers().get(SystemHeaders.ERROR_HANDLING_ERROR_TYPE.getKey()); + if (errorType != null) { + int statusCode = convertRecordErrorToStatusCode(errorType); + String errorMessage = convertRecordErrorToHttpResponseMessage(record); + return ResponseEntity.status(statusCode) + .body( + ProblemDetail.forStatusAndDetail( + HttpStatus.valueOf(statusCode), errorMessage)); + } + } + try { + String asString = mapper.writeValueAsString(consumePushMessage); + return ResponseEntity.ok(asString); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static String convertRecordErrorToHttpResponseMessage( + ConsumePushMessage.Record record) { + String errorMessage = null; + for (String header : RECORD_ERROR_HEADERS) { + String value = record.headers().get(header); + if (!StringUtils.isEmpty(value)) { + errorMessage = value; + break; + } + } + if (errorMessage == null) { + if (record.value() != null) { + errorMessage = record.value().toString(); + } else { + // in this case the user will only have the status code as a hint + errorMessage = ""; + } + } + return errorMessage; + } + + private static int convertRecordErrorToStatusCode(String errorTypeString) { + ErrorTypes errorType = ErrorTypes.valueOf(errorTypeString.toUpperCase()); + return switch (errorType) { + case INVALID_RECORD -> 400; + case INTERNAL_ERROR -> 500; + }; + } + private Map computeQueryString(WebRequest request) { final Map queryString = request.getParameterMap().keySet().stream() diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java index 0c1f99cc7..5c60bd3dc 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java @@ -126,21 +126,22 @@ private AuthenticatedGatewayRequestContext getContext(WebSocketSession session) return (AuthenticatedGatewayRequestContext) session.getAttributes().get("context"); } - private void closeSession(WebSocketSession session, Throwable throwable) throws IOException { + private void closeSession(WebSocketSession session, Throwable throwable) { CloseStatus status = CloseStatus.SERVER_ERROR; if (throwable instanceof IllegalArgumentException) { status = CloseStatus.POLICY_VIOLATION; } try { session.close(status.withReason(throwable.getMessage())); + } catch (IOException e) { + log.error("error while closing websocket", e); } finally { callHandlerOnClose(session, status); } } @Override - protected void handleTextMessage(WebSocketSession session, TextMessage message) - throws Exception { + protected void handleTextMessage(WebSocketSession session, TextMessage message) { try { onMessage(session, getContext(session), message); } catch (Throwable throwable) { @@ -280,10 +281,15 @@ protected void startReadingMessages(WebSocketSession webSocketSession, Executor () -> !webSocketSession.isOpen(), message -> { try { - webSocketSession.sendMessage(new TextMessage(message)); + String jsonStringMessage = mapper.writeValueAsString(message); + webSocketSession.sendMessage(new TextMessage(jsonStringMessage)); } catch (IOException ex) { throw new RuntimeException(ex); } + }, + throwable -> { + log.error("error while reading messages", throwable); + closeSession(webSocketSession, throwable); }); } diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java index c964ee461..f2025e352 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/http/GatewayResourceTest.java @@ -24,7 +24,9 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; import ai.langstream.api.model.*; +import ai.langstream.api.runner.code.Header; import ai.langstream.api.runner.code.Record; +import ai.langstream.api.runner.code.SimpleRecord; import ai.langstream.api.runner.topics.TopicConnectionsRuntime; import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry; import ai.langstream.api.runner.topics.TopicConsumer; @@ -50,9 +52,7 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.nio.file.Path; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; @@ -61,11 +61,7 @@ import org.apache.commons.lang3.concurrent.BasicThreadFactory; import org.awaitility.Awaitility; import org.jetbrains.annotations.NotNull; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.mockito.Mockito; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.autoconfigure.actuate.observability.AutoConfigureObservability; @@ -98,9 +94,7 @@ abstract class GatewayResourceTest { protected static final ObjectMapper MAPPER = new ObjectMapper(); static List topics; - static ExecutorService futuresExecutor = - Executors.newCachedThreadPool( - new BasicThreadFactory.Builder().namingPattern("test-exec-%d").build()); + ExecutorService futuresExecutor; static Gateways testGateways; protected static ApplicationStore getMockedStore(String instanceYaml) { @@ -206,10 +200,15 @@ public static void beforeAll(WireMockRuntimeInfo wmRuntimeInfo) { } @BeforeEach - public void beforeEach(WireMockRuntimeInfo wmRuntimeInfo) { + public void beforeEach(WireMockRuntimeInfo wmRuntimeInfo, TestInfo testInfo) { testGateways = null; topics = null; Awaitility.setDefaultTimeout(30, TimeUnit.SECONDS); + futuresExecutor = + Executors.newCachedThreadPool( + new BasicThreadFactory.Builder() + .namingPattern("test-exec-" + testInfo.getDisplayName() + "-%d") + .build()); } @AfterAll @@ -222,9 +221,7 @@ public void afterEach() throws Exception { Metrics.globalRegistry.clear(); futuresExecutor.shutdownNow(); futuresExecutor.awaitTermination(1, TimeUnit.MINUTES); - futuresExecutor = - Executors.newCachedThreadPool( - new BasicThreadFactory.Builder().namingPattern("test-exec-%d").build()); + futuresExecutor = null; } @SneakyThrows @@ -738,7 +735,7 @@ void testService() throws Exception { final String outputTopic = genTopic("testService-output"); prepareTopicsForTest(inputTopic, outputTopic); - startTopicExchange(inputTopic, outputTopic); + startTopicExchange(inputTopic, outputTopic, false); testGateways = new Gateways( @@ -798,8 +795,59 @@ void testService() throws Exception { produceJsonAndGetBody(valueUrl, "{\"key\": \"my-key\", \"value\": \"my-value\"}")); } - private void startTopicExchange(String logicalFromTopic, String logicalToTopic) + @Test + void testServiceWithError() throws Exception { + final String inputTopic = genTopic("testServiceWithError-input"); + final String outputTopic = genTopic("testServiceWithError-output"); + prepareTopicsForTest(inputTopic, outputTopic); + + startTopicExchange(inputTopic, outputTopic, true); + + testGateways = + new Gateways( + List.of( + Gateway.builder() + .id("svc") + .type(Gateway.GatewayType.service) + .serviceOptions( + new Gateway.ServiceOptions( + null, + inputTopic, + outputTopic, + Gateway.ProducePayloadSchema.full, + List.of())) + .build(), + Gateway.builder() + .id("svc-value") + .type(Gateway.GatewayType.service) + .serviceOptions( + new Gateway.ServiceOptions( + null, + inputTopic, + outputTopic, + Gateway.ProducePayloadSchema.value, + List.of())) + .build())); + + final String url = + "http://localhost:%d/api/gateways/service/tenant1/application1/svc".formatted(port); + + HttpRequest request = + HttpRequest.newBuilder(URI.create(url)) + .POST(HttpRequest.BodyPublishers.ofString("my-string")) + .build(); + HttpResponse response = CLIENT.send(request, HttpResponse.BodyHandlers.ofString()); + + assertEquals(500, response.statusCode()); + assertEquals( + "{\"type\":\"about:blank\",\"title\":\"Internal Server Error\",\"status\":500,\"detail\":\"the agent failed!\",\"instance\":\"/api/gateways/service/tenant1/application1/svc\"}", + response.body()); + } + + private void startTopicExchange( + String logicalFromTopic, String logicalToTopic, boolean injectAgentFailure) throws Exception { + CompletableFuture started = new CompletableFuture<>(); final CompletableFuture future = CompletableFuture.runAsync( () -> { @@ -815,7 +863,7 @@ private void startTopicExchange(String logicalFromTopic, String logicalToTopic) final String toTopic = resolveTopicName(logicalToTopic); try (final TopicConsumer consumer = runtime.createConsumer( - null, + "gateway-resource-test" + fromTopic, streamingCluster, Map.of( "topic", @@ -826,11 +874,12 @@ private void startTopicExchange(String logicalFromTopic, String logicalToTopic) try (final TopicProducer producer = runtime.createProducer( - null, + "gateway-resource-test" + toTopic, streamingCluster, Map.of("topic", toTopic)); ) { producer.start(); + started.complete(null); while (true) { if (Thread.currentThread().isInterrupted()) { break; @@ -852,7 +901,23 @@ private void startTopicExchange(String logicalFromTopic, String logicalToTopic) record.value() == null ? "NULL" : record.value().getClass()); - producer.write(record).get(); + Collection
headers = + new ArrayList<>(record.headers()); + if (injectAgentFailure) { + headers.add( + SimpleRecord.SimpleHeader.of( + "langstream-error-message", + "the agent failed!")); + headers.add( + SimpleRecord.SimpleHeader.of( + "langstream-error-type", + "INTERNAL_ERROR")); + } + producer.write( + SimpleRecord.copyFrom(record) + .headers(headers) + .build()) + .get(); } consumer.commit(records); log.info( @@ -877,6 +942,8 @@ private void startTopicExchange(String logicalFromTopic, String logicalToTopic) } }, futuresExecutor); + started.get(); + log.info("Topic exchange started"); } private record MsgRecord(Object key, Object value, Map headers) {} diff --git a/langstream-api/src/main/java/ai/langstream/api/runner/code/AgentProcessor.java b/langstream-api/src/main/java/ai/langstream/api/runner/code/AgentProcessor.java index 8926e77cb..83863e94e 100644 --- a/langstream-api/src/main/java/ai/langstream/api/runner/code/AgentProcessor.java +++ b/langstream-api/src/main/java/ai/langstream/api/runner/code/AgentProcessor.java @@ -38,9 +38,19 @@ default ComponentType componentType() { return ComponentType.PROCESSOR; } - record SourceRecordAndResult(Record sourceRecord, List resultRecords, Throwable error) { + record SourceRecordAndResult( + Record sourceRecord, + List resultRecords, + Throwable error, + ErrorTypes errorType) { + public SourceRecordAndResult { resultRecords = Objects.requireNonNullElseGet(resultRecords, List::of); } + + public SourceRecordAndResult( + Record sourceRecord, List resultRecords, Throwable error) { + this(sourceRecord, resultRecords, error, null); + } } } diff --git a/langstream-api/src/main/java/ai/langstream/api/runner/code/AgentSource.java b/langstream-api/src/main/java/ai/langstream/api/runner/code/AgentSource.java index 9554159a4..54ba99c51 100644 --- a/langstream-api/src/main/java/ai/langstream/api/runner/code/AgentSource.java +++ b/langstream-api/src/main/java/ai/langstream/api/runner/code/AgentSource.java @@ -18,7 +18,6 @@ import ai.langstream.api.runtime.ComponentType; import java.util.List; -/** Body of the agent */ public interface AgentSource extends AgentCode { /** @@ -46,9 +45,12 @@ default ComponentType componentType() { * dead letter queue or throw an error * * @param record the record that failed + * @param error the error that caused the failure + * @param errorType the type of error if defined * @throws Exception if the source fails to process the permanently failed record */ - default void permanentFailure(Record record, Exception error) throws Exception { + default void permanentFailure(Record record, Exception error, ErrorTypes errorType) + throws Exception { throw error; } } diff --git a/langstream-api/src/main/java/ai/langstream/api/runner/code/ErrorTypes.java b/langstream-api/src/main/java/ai/langstream/api/runner/code/ErrorTypes.java new file mode 100644 index 000000000..b46a18270 --- /dev/null +++ b/langstream-api/src/main/java/ai/langstream/api/runner/code/ErrorTypes.java @@ -0,0 +1,21 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.api.runner.code; + +public enum ErrorTypes { + INVALID_RECORD, + INTERNAL_ERROR +} diff --git a/langstream-api/src/main/java/ai/langstream/api/runner/code/SystemHeaders.java b/langstream-api/src/main/java/ai/langstream/api/runner/code/SystemHeaders.java new file mode 100644 index 000000000..1dce93bfd --- /dev/null +++ b/langstream-api/src/main/java/ai/langstream/api/runner/code/SystemHeaders.java @@ -0,0 +1,47 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.api.runner.code; + +import lombok.Getter; + +@Getter +public enum SystemHeaders { + SERVICE_REQUEST_ID_HEADER("langstream-service-request-id"), + ERROR_HANDLING_ERROR_MESSAGE("langstream-error-message"), + ERROR_HANDLING_ERROR_MESSAGE_LEGACY("error-msg"), + ERROR_HANDLING_ERROR_CLASS("langstream-error-class"), + ERROR_HANDLING_ERROR_CLASS_LEGACY("error-class"), + ERROR_HANDLING_ERROR_TYPE("langstream-error-type"), + + ERROR_HANDLING_CAUSE_ERROR_MESSAGE("langstream-error-cause-message"), + ERROR_HANDLING_CAUSE_ERROR_MESSAGE_LEGACY("cause-msg"), + ERROR_HANDLING_CAUSE_ERROR_CLASS("langstream-error-cause-class"), + ERROR_HANDLING_CAUSE_ERROR_CLASS_LEGACY("cause-class"), + + ERROR_HANDLING_ROOT_CAUSE_ERROR_MESSAGE("langstream-error-root-cause-message"), + ERROR_HANDLING_ROOT_CAUSE_ERROR_MESSAGE_LEGACY("root-cause-msg"), + ERROR_HANDLING_ROOT_CAUSE_ERROR_CLASS("langstream-error-root-cause-class"), + ERROR_HANDLING_ROOT_CAUSE_ERROR_CLASS_LEGACY("root-cause-class"), + + ERROR_HANDLING_SOURCE_TOPIC("langstream-error-source-topic"), + ERROR_HANDLING_SOURCE_TOPIC_LEGACY("source-topic"); + + private final String key; + + SystemHeaders(String key) { + this.key = key; + } +} diff --git a/langstream-cli/src/main/java/ai/langstream/cli/LangStreamCLI.java b/langstream-cli/src/main/java/ai/langstream/cli/LangStreamCLI.java index 69bd08969..7743e3edf 100644 --- a/langstream-cli/src/main/java/ai/langstream/cli/LangStreamCLI.java +++ b/langstream-cli/src/main/java/ai/langstream/cli/LangStreamCLI.java @@ -95,13 +95,12 @@ private static String computeErrorMessage(Exception e) { final HttpResponse response = httpRequestFailedException.getResponse(); if (response != null) { Object body = httpRequestFailedException.getResponse().body(); + msg += String.format(" with code %d:", response.statusCode()); if (body != null) { if (body instanceof byte[]) { body = new String((byte[]) body, StandardCharsets.UTF_8); } - msg += String.format(": %s", body); - } else { - msg += String.format(": %s", response.statusCode()); + msg += String.format("\n%s", body); } } return msg; diff --git a/langstream-cli/src/main/java/ai/langstream/cli/api/model/Gateways.java b/langstream-cli/src/main/java/ai/langstream/cli/api/model/Gateways.java index 41e2a68e6..ab9b34eb1 100644 --- a/langstream-cli/src/main/java/ai/langstream/cli/api/model/Gateways.java +++ b/langstream-cli/src/main/java/ai/langstream/cli/api/model/Gateways.java @@ -65,6 +65,7 @@ public static class Gateway { public static final String TYPE_PRODUCE = "produce"; public static final String TYPE_CONSUME = "consume"; public static final String TYPE_CHAT = "chat"; + public static final String TYPE_SERVICE = "service"; String id; String type; diff --git a/langstream-cli/src/main/java/ai/langstream/cli/commands/RootGatewayCmd.java b/langstream-cli/src/main/java/ai/langstream/cli/commands/RootGatewayCmd.java index b6dc3ae2b..7aeec13b1 100644 --- a/langstream-cli/src/main/java/ai/langstream/cli/commands/RootGatewayCmd.java +++ b/langstream-cli/src/main/java/ai/langstream/cli/commands/RootGatewayCmd.java @@ -18,13 +18,19 @@ import ai.langstream.cli.commands.gateway.ChatGatewayCmd; import ai.langstream.cli.commands.gateway.ConsumeGatewayCmd; import ai.langstream.cli.commands.gateway.ProduceGatewayCmd; +import ai.langstream.cli.commands.gateway.ServiceGatewayCmd; import lombok.Getter; import picocli.CommandLine; @CommandLine.Command( name = "gateway", header = "Interact with a application gateway", - subcommands = {ProduceGatewayCmd.class, ConsumeGatewayCmd.class, ChatGatewayCmd.class}) + subcommands = { + ProduceGatewayCmd.class, + ConsumeGatewayCmd.class, + ChatGatewayCmd.class, + ServiceGatewayCmd.class + }) @Getter public class RootGatewayCmd { @CommandLine.ParentCommand private RootCmd rootCmd; diff --git a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java index c4382a1f4..f8dab4c36 100644 --- a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java +++ b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java @@ -134,16 +134,17 @@ protected GatewayRequestInfo validateGatewayAndGetUrl( } } if (protocol == Protocols.http) { - if (!type.equals("produce")) { - throw new IllegalArgumentException("HTTP protocol is only supported for produce"); + if (!type.equals("produce") && !type.equals("service")) { + throw new IllegalArgumentException( + "HTTP protocol is only supported for produce and service gateways."); } + Map gwOptions = + type.equals("produce") + ? gatewayInfo.getProduceOptions() + : gatewayInfo.getServiceOptions(); boolean fullPayloadSchema = - gatewayInfo.getProduceOptions() == null - || !"value" - .equals( - gatewayInfo - .getProduceOptions() - .getOrDefault("payload-schema", "full")); + gwOptions == null + || !"value".equals(gwOptions.getOrDefault("payload-schema", "full")); return new GatewayRequestInfo( String.format( diff --git a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ServiceGatewayCmd.java b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ServiceGatewayCmd.java new file mode 100644 index 000000000..be5a881c6 --- /dev/null +++ b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ServiceGatewayCmd.java @@ -0,0 +1,136 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.cli.commands.gateway; + +import ai.langstream.cli.api.model.Gateways; +import com.fasterxml.jackson.core.JsonProcessingException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.Map; +import lombok.SneakyThrows; +import picocli.CommandLine; + +@CommandLine.Command(name = "service", header = "Interact with a service gateway") +public class ServiceGatewayCmd extends BaseGatewayCmd { + + @CommandLine.Parameters(description = "Application ID") + private String applicationId; + + @CommandLine.Parameters(description = "Gateway ID") + private String gatewayId; + + @CommandLine.Option( + names = {"-p", "--param"}, + description = "Gateway parameters. Format: key=value") + private Map params; + + @CommandLine.Option( + names = {"-c", "--credentials"}, + description = + "Credentials for the gateway. Required if the gateway requires authentication.") + private String credentials; + + @CommandLine.Option( + names = {"-tc", "--test-credentials"}, + description = "Test credentials for the gateway.") + private String testCredentials; + + @CommandLine.Option( + names = {"--connect-timeout"}, + description = "Connect timeout in seconds.") + private long connectTimeoutSeconds = 0; + + @CommandLine.Option( + names = {"-v", "--value"}, + description = "Message value") + private String messageValue; + + @CommandLine.Option( + names = {"-k", "--key"}, + description = "Message key") + private String messageKey; + + @CommandLine.Option( + names = {"--header"}, + description = "Messages headers. Format: key=value") + private Map headers; + + @Override + @SneakyThrows + public void run() { + GatewayRequestInfo gatewayRequestInfo = + validateGatewayAndGetUrl( + applicationId, + gatewayId, + Gateways.Gateway.TYPE_SERVICE, + params, + Map.of(), + credentials, + testCredentials, + Protocols.http); + final Duration connectTimeout = + connectTimeoutSeconds > 0 ? Duration.ofSeconds(connectTimeoutSeconds) : null; + + String json; + if (gatewayRequestInfo.isFullPayloadSchema()) { + final ProduceGatewayCmd.ProduceRequest produceRequest = + new ProduceGatewayCmd.ProduceRequest(messageKey, messageValue, headers); + json = messageMapper.writeValueAsString(produceRequest); + } else { + if (messageKey != null) { + log("Warning: key is ignored when the payload schema is value"); + } + if (headers != null && !headers.isEmpty()) { + log("Warning: headers are ignored when the payload schema is value"); + } + try { + // it's already a json string + messageMapper.readValue(messageValue, Map.class); + json = messageValue; + } catch (JsonProcessingException ex) { + json = messageMapper.writeValueAsString(messageValue); + } + } + + produceHttp( + gatewayRequestInfo.getUrl(), connectTimeout, gatewayRequestInfo.getHeaders(), json); + } + + private void produceHttp( + String producePath, Duration connectTimeout, Map headers, String json) + throws Exception { + final HttpRequest.Builder builder = + HttpRequest.newBuilder(URI.create(producePath)) + .header("Content-Type", "application/json") + .version(HttpClient.Version.HTTP_1_1) + .POST(HttpRequest.BodyPublishers.ofString(json)); + if (connectTimeout != null) { + builder.timeout(connectTimeout); + } + if (headers != null) { + headers.forEach(builder::header); + } + final HttpRequest request = builder.build(); + final HttpResponse response = + getClient() + .getHttpClientFacade() + .http(request, HttpResponse.BodyHandlers.ofString()); + log(response.body()); + } +} diff --git a/langstream-cli/src/test/java/ai/langstream/cli/commands/applications/GatewaysCmdTest.java b/langstream-cli/src/test/java/ai/langstream/cli/commands/applications/GatewaysCmdTest.java index 72a837241..2a9dcbc17 100644 --- a/langstream-cli/src/test/java/ai/langstream/cli/commands/applications/GatewaysCmdTest.java +++ b/langstream-cli/src/test/java/ai/langstream/cli/commands/applications/GatewaysCmdTest.java @@ -121,4 +121,40 @@ public void testProduceHttpHeader() throws Exception { assertEquals(0, result.exitCode()); assertEquals("", result.err()); } + + @Test + public void testServiceGateway() throws Exception { + Map response = + Map.of( + "application", + Map.of( + "" + "gateways", + Map.of( + "gateways", + List.of( + Map.of( + "id", + "g1", + "type", + "service", + "service-options", + Map.of( + "input-topic", "from", + "output-topic", "to", + "payload-schema", "value")))))); + + wireMock.register( + WireMock.get(String.format("/api/applications/%s/my-app?stats=false", TENANT)) + .willReturn(WireMock.ok(new ObjectMapper().writeValueAsString(response)))); + + wireMock.register( + WireMock.post(String.format("/api/gateways/service/%s/my-app/g1", TENANT)) + .withRequestBody(equalToJson("{\"my\": true}")) + .willReturn(WireMock.ok())); + + CommandResult result = + executeCommand("gateway", "service", "my-app", "g1", "-v", "{\"my\": true}"); + assertEquals(0, result.exitCode()); + assertEquals("", result.err()); + } } diff --git a/langstream-e2e-tests/src/test/java/ai/langstream/tests/PythonDLQIT.java b/langstream-e2e-tests/src/test/java/ai/langstream/tests/PythonDLQIT.java new file mode 100644 index 000000000..0ea9a9761 --- /dev/null +++ b/langstream-e2e-tests/src/test/java/ai/langstream/tests/PythonDLQIT.java @@ -0,0 +1,69 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.tests; + +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import ai.langstream.tests.util.BaseEndToEndTest; +import ai.langstream.tests.util.TestSuites; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +@Slf4j +@ExtendWith(BaseEndToEndTest.class) +@Tag(TestSuites.CATEGORY_OTHER) +public class PythonDLQIT extends BaseEndToEndTest { + + @Test + public void test() throws Exception { + assumeTrue(streamingCluster.type().equals("pulsar")); + installLangStreamCluster(true); + final String tenant = "ten-" + System.currentTimeMillis(); + setupTenant(tenant); + final String applicationId = "my-test-app"; + Map appEnv = new HashMap<>(); + Map serviceMap = + (Map) streamingCluster.configuration().get("service"); + String brokerUrl = serviceMap.get("serviceUrl").toString(); + appEnv.put("PULSAR_BROKER_URL", brokerUrl); + + deployLocalApplicationAndAwaitReady( + tenant, applicationId, "python-processor-with-dlq", appEnv, 2); + CompletableFuture commandResult = + executeCommandOnClientAsync( + "bin/langstream gateway service %s svc -v '{\"my-schema\":true}' --connect-timeout 60" + .formatted(applicationId) + .split(" ")); + try { + commandResult.get(); + Assertions.fail("Expected exception"); + } catch (ExecutionException ex) { + CommandExecFailedException failedException = (CommandExecFailedException) ex.getCause(); + log.info("Error: {}", failedException.getStderr()); + String stderr = failedException.getStderr(); + Assertions.assertTrue(stderr.contains("with code 400:")); + Assertions.assertTrue(stderr.contains("exception from python processor")); + } + deleteAppAndAwaitCleanup(tenant, applicationId); + } +} diff --git a/langstream-e2e-tests/src/test/java/ai/langstream/tests/util/BaseEndToEndTest.java b/langstream-e2e-tests/src/test/java/ai/langstream/tests/util/BaseEndToEndTest.java index 7871e65a6..6032f541a 100644 --- a/langstream-e2e-tests/src/test/java/ai/langstream/tests/util/BaseEndToEndTest.java +++ b/langstream-e2e-tests/src/test/java/ai/langstream/tests/util/BaseEndToEndTest.java @@ -78,7 +78,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import java.util.stream.Collectors; -import lombok.SneakyThrows; +import lombok.*; import lombok.extern.slf4j.Slf4j; import org.awaitility.Awaitility; import org.junit.jupiter.api.AfterAll; @@ -223,17 +223,14 @@ private static Pod getFirstPodFromDeployment(String deploymentName) { @SneakyThrows public static String executeCommandOnClient(String... args) { - return executeCommandOnClient(2, TimeUnit.MINUTES, args); + return executeCommandOnClientAsync(args).get(2, TimeUnit.MINUTES); } @SneakyThrows - protected static String executeCommandOnClient(long timeout, TimeUnit unit, String... args) { + protected static CompletableFuture executeCommandOnClientAsync(String... args) { final Pod pod = getFirstPodFromDeployment("langstream-client"); return execInPod( - pod.getMetadata().getName(), - pod.getSpec().getContainers().get(0).getName(), - args) - .get(timeout, unit); + pod.getMetadata().getName(), pod.getSpec().getContainers().get(0).getName(), args); } @SneakyThrows @@ -275,6 +272,15 @@ public static CompletableFuture execInPod( return execInPodInNamespace(namespace, podName, containerName, cmds); } + @AllArgsConstructor + @Getter + @ToString + public static class CommandExecFailedException extends RuntimeException { + private final String command; + private final String stdout; + private final String stderr; + } + public static CompletableFuture execInPodInNamespace( String namespace, String podName, String containerName, String... cmds) { @@ -298,13 +304,17 @@ public void onFailure(Throwable t, Response failureResponse) { if (!completed.compareAndSet(false, true)) { return; } + String errString = error.toString(StandardCharsets.UTF_8); + String outString = out.toString(StandardCharsets.UTF_8); log.warn( "Error executing {} encountered; \nstderr: {}\nstdout: {}", cmd, - error.toString(StandardCharsets.UTF_8), - out.toString(), + errString, + outString, t); - response.completeExceptionally(t); + CommandExecFailedException commandExecFailedException = + new CommandExecFailedException(cmd, outString, errString); + response.completeExceptionally(commandExecFailedException); } @Override @@ -313,18 +323,18 @@ public void onExit(int code, Status status) { return; } if (code != 0) { + String errString = error.toString(StandardCharsets.UTF_8); + String outString = out.toString(StandardCharsets.UTF_8); log.warn( "Error executing {} encountered; \ncode: {}\n stderr: {}\nstdout: {}", cmd, code, - error.toString(StandardCharsets.UTF_8), - out.toString(StandardCharsets.UTF_8)); - response.completeExceptionally( - new RuntimeException( - "Command failed with err code: " - + code - + ", stderr: " - + error.toString(StandardCharsets.UTF_8))); + errString, + outString); + + CommandExecFailedException commandExecFailedException = + new CommandExecFailedException(cmd, outString, errString); + response.completeExceptionally(commandExecFailedException); } else { log.info( "Command completed {}; \nstderr: {}\nstdout: {}", @@ -340,12 +350,13 @@ public void onClose(int rc, String reason) { if (!completed.compareAndSet(false, true)) { return; } + String outString = out.toString(StandardCharsets.UTF_8); log.info( "Command completed {}; \nstderr: {}\nstdout: {}", cmd, error.toString(StandardCharsets.UTF_8), - out.toString(StandardCharsets.UTF_8)); - response.complete(out.toString(StandardCharsets.UTF_8)); + outString); + response.complete(outString); } }; diff --git a/langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/dlq-pipeline.yaml b/langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/dlq-pipeline.yaml new file mode 100644 index 000000000..607936508 --- /dev/null +++ b/langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/dlq-pipeline.yaml @@ -0,0 +1,34 @@ +# +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +name: "DLQ Pipeline" +topics: + - name: ls-test-topic0-deadletter + creation-mode: create-if-not-exists + - name: ls-test-topic1 + creation-mode: create-if-not-exists +pipeline: + - name: "Read DLQ" + id: "read-errors" + type: "pulsardlq-source" + output: "ls-test-topic1" + configuration: + pulsar-url: "${secrets.pulsar.broker-url}" + namespace: "public/default" + subscription: "dlq-subscription" + dlq-suffix: "-deadletter" + include-partitioned: false + timeout-ms: 1000 \ No newline at end of file diff --git a/langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/gateways.yaml b/langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/gateways.yaml new file mode 100644 index 000000000..0e2bbf86d --- /dev/null +++ b/langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/gateways.yaml @@ -0,0 +1,23 @@ +# +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +gateways: + - id: svc + type: service + service-options: + input-topic: ls-test-topic0 + output-topic: ls-test-topic1 + payload-schema: value \ No newline at end of file diff --git a/langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/pipeline.yaml b/langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/pipeline.yaml new file mode 100644 index 000000000..f72bcc434 --- /dev/null +++ b/langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/pipeline.yaml @@ -0,0 +1,38 @@ +# +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +module: "module-1" +id: "pipeline-1" +name: "Exclamation processor" +topics: + - name: ls-test-topic0 + creation-mode: create-if-not-exists + schema: + type: string + keySchema: + type: string + - name: ls-test-topic1 + creation-mode: create-if-not-exists +pipeline: + - name: "Process using Python" + id: "test-python-processor" + type: "python-processor" + input: ls-test-topic0 + output: ls-test-topic1 + errors: + on-failure: dead-letter + configuration: + className: example.FailProcessor \ No newline at end of file diff --git a/langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/python/example.py b/langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/python/example.py new file mode 100644 index 000000000..30bf5fc6a --- /dev/null +++ b/langstream-e2e-tests/src/test/resources/apps/python-processor-with-dlq/python/example.py @@ -0,0 +1,26 @@ +# +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from langstream import SimpleRecord, Processor, AgentContext +import logging, os + + +class FailProcessor(Processor): + + def process(self, record): + logging.info("Processing record" + str(record)) + from langstream import InvalidRecordError + raise InvalidRecordError("exception from python processor") \ No newline at end of file diff --git a/langstream-e2e-tests/src/test/resources/secrets/secret1.yaml b/langstream-e2e-tests/src/test/resources/secrets/secret1.yaml index 084964b4b..f722fdb57 100644 --- a/langstream-e2e-tests/src/test/resources/secrets/secret1.yaml +++ b/langstream-e2e-tests/src/test/resources/secrets/secret1.yaml @@ -66,6 +66,10 @@ secrets: data: producer-config-json: "${KAFKA_PRODUCER_CONFIG}" + - id: pulsar + data: + broker-url: "${PULSAR_BROKER_URL}" + - id: s3 data: endpoint: "${S3_ENDPOINT}" diff --git a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/PulsarDLQSourceAgentProvider.java b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/PulsarDLQSourceAgentProvider.java index 0dc3b62ad..542a91028 100644 --- a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/PulsarDLQSourceAgentProvider.java +++ b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/PulsarDLQSourceAgentProvider.java @@ -101,5 +101,15 @@ public static class PulsarDLQSourceConfiguration { defaultValue = "false") @JsonProperty("include-partitioned") private boolean includePartitioned; + + @ConfigProperty( + description = + """ + Timeout in milliseconds to wait for messages from the DLQ topics. + Default is 0, meaning it will wait indefinitely. + """, + defaultValue = "0") + @JsonProperty("timeout-ms") + private int timeoutMs; } } diff --git a/langstream-runtime/langstream-runtime-impl/pom.xml b/langstream-runtime/langstream-runtime-impl/pom.xml index 503f42076..227d95ab9 100644 --- a/langstream-runtime/langstream-runtime-impl/pom.xml +++ b/langstream-runtime/langstream-runtime-impl/pom.xml @@ -430,6 +430,20 @@ exec-maven-plugin 3.1.0 + + python-setup-poetry + generate-sources + + exec + + + ${project.basedir}/src/main/python + poetry + + install + + + python-lint-black generate-sources @@ -438,8 +452,10 @@ ${project.basedir}/src/main/python - black + poetry + run + black . @@ -452,8 +468,10 @@ ${project.basedir}/src/main/python - ruff + poetry + run + ruff check --fix . @@ -467,8 +485,12 @@ exec - ${project.build.directory}/python - tox + ${project.basedir}/src/main/python + poetry + + run + tox + ${skipTests} @@ -479,7 +501,7 @@ exec - ${project.build.directory}/python + ${project.basedir}/src/main/python poetry ${skipPythonPackage} @@ -755,6 +777,10 @@ org.codehaus.mojo exec-maven-plugin + + python-setup-poetry + none + python-lint-black none diff --git a/langstream-runtime/langstream-runtime-impl/src/main/java/ai/langstream/runtime/agent/AgentRunner.java b/langstream-runtime/langstream-runtime-impl/src/main/java/ai/langstream/runtime/agent/AgentRunner.java index 3b9d1a5de..e6e97e96d 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/java/ai/langstream/runtime/agent/AgentRunner.java +++ b/langstream-runtime/langstream-runtime-impl/src/main/java/ai/langstream/runtime/agent/AgentRunner.java @@ -19,19 +19,8 @@ import static ai.langstream.api.model.ErrorsSpec.FAIL; import static ai.langstream.api.model.ErrorsSpec.SKIP; -import ai.langstream.api.runner.code.AgentCode; -import ai.langstream.api.runner.code.AgentCodeAndLoader; -import ai.langstream.api.runner.code.AgentCodeRegistry; -import ai.langstream.api.runner.code.AgentContext; -import ai.langstream.api.runner.code.AgentProcessor; -import ai.langstream.api.runner.code.AgentService; -import ai.langstream.api.runner.code.AgentSink; -import ai.langstream.api.runner.code.AgentSource; -import ai.langstream.api.runner.code.AgentStatusResponse; -import ai.langstream.api.runner.code.BadRecordHandler; -import ai.langstream.api.runner.code.MetricsReporter; +import ai.langstream.api.runner.code.*; import ai.langstream.api.runner.code.Record; -import ai.langstream.api.runner.code.RecordSink; import ai.langstream.api.runner.topics.TopicAdmin; import ai.langstream.api.runner.topics.TopicConnectionProvider; import ai.langstream.api.runner.topics.TopicConnectionsRuntime; @@ -564,8 +553,9 @@ public ComponentType componentType() { } @Override - public void permanentFailure(Record record, Exception error) throws Exception { - wrapped.permanentFailure(record, error); + public void permanentFailure(Record record, Exception error, ErrorTypes errorType) + throws Exception { + wrapped.permanentFailure(record, error, errorType); } @Override @@ -851,7 +841,7 @@ private static void writeRecordToTheSink( new PermanentFailureException(error); try { source.permanentFailure( - sourceRecord, permanentFailureException); + sourceRecord, permanentFailureException, null); } catch (Exception err) { err.addSuppressed(permanentFailureException); log.error("Cannot send permanent failure to the source", err); @@ -926,14 +916,17 @@ private static void runProcessorAgent( new PermanentFailureException(error); permanentFailureException.fillInStackTrace(); source.permanentFailure( - sourceRecord, permanentFailureException); + sourceRecord, + permanentFailureException, + result.errorType()); if (errorsHandler.failProcessingOnPermanentErrors()) { log.error("Failing processing on permanent error"); finalSink.emit( new AgentProcessor.SourceRecordAndResult( sourceRecord, List.of(), - permanentFailureException)); + permanentFailureException, + result.errorType())); } else { // in case the source does not throw an exception we mark // the record as "skipped" diff --git a/langstream-runtime/langstream-runtime-impl/src/main/java/ai/langstream/runtime/agent/TopicConsumerSource.java b/langstream-runtime/langstream-runtime-impl/src/main/java/ai/langstream/runtime/agent/TopicConsumerSource.java index 4db01ee87..e15bacfac 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/java/ai/langstream/runtime/agent/TopicConsumerSource.java +++ b/langstream-runtime/langstream-runtime-impl/src/main/java/ai/langstream/runtime/agent/TopicConsumerSource.java @@ -16,8 +16,7 @@ package ai.langstream.runtime.agent; import ai.langstream.ai.agents.commons.MutableRecord; -import ai.langstream.api.runner.code.AbstractAgentCode; -import ai.langstream.api.runner.code.AgentSource; +import ai.langstream.api.runner.code.*; import ai.langstream.api.runner.code.Record; import ai.langstream.api.runner.topics.TopicConsumer; import ai.langstream.api.runner.topics.TopicProducer; @@ -49,24 +48,58 @@ public void commit(List records) throws Exception { } @Override - public void permanentFailure(Record record, Exception error) { - // DLQ - log.error("Permanent failure on record {}", record, error); + public void permanentFailure(Record record, Exception error, ErrorTypes errorType) { + errorType = errorType == null ? ErrorTypes.INTERNAL_ERROR : errorType; + log.error("Permanent {} failure on record {}", errorType, record, error); MutableRecord recordWithError = MutableRecord.recordToMutableRecord(record, false); String sourceTopic = record.origin(); + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_ERROR_TYPE.getKey(), errorType.toString()); + + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_SOURCE_TOPIC.getKey(), sourceTopic); + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_SOURCE_TOPIC_LEGACY.getKey(), sourceTopic); + + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_ERROR_MESSAGE.getKey(), error.getMessage()); + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_ERROR_MESSAGE_LEGACY.getKey(), error.getMessage()); + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_ERROR_CLASS.getKey(), error.getClass().getName()); + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_ERROR_CLASS_LEGACY.getKey(), + error.getClass().getName()); + Throwable cause = error.getCause(); - recordWithError.setProperty("error-msg", error.getMessage()); - recordWithError.setProperty("error-class", error.getClass().getName()); - recordWithError.setProperty("source-topic", sourceTopic); if (cause != null) { - recordWithError.setProperty("cause-msg", cause.getMessage()); - recordWithError.setProperty("cause-class", cause.getClass().getName()); + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_CAUSE_ERROR_MESSAGE.getKey(), cause.getMessage()); + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_CAUSE_ERROR_MESSAGE_LEGACY.getKey(), + cause.getMessage()); + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_CAUSE_ERROR_CLASS.getKey(), + cause.getClass().getName()); + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_CAUSE_ERROR_CLASS_LEGACY.getKey(), + cause.getClass().getName()); Throwable rootCause = cause; while (rootCause.getCause() != null) { rootCause = rootCause.getCause(); } - recordWithError.setProperty("root-cause-msg", rootCause.getMessage()); - recordWithError.setProperty("root-cause-class", rootCause.getClass().getName()); + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_ROOT_CAUSE_ERROR_MESSAGE.getKey(), + rootCause.getMessage()); + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_ROOT_CAUSE_ERROR_MESSAGE_LEGACY.getKey(), + rootCause.getMessage()); + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_ROOT_CAUSE_ERROR_CLASS.getKey(), + rootCause.getClass().getName()); + recordWithError.setProperty( + SystemHeaders.ERROR_HANDLING_ROOT_CAUSE_ERROR_CLASS_LEGACY.getKey(), + rootCause.getClass().getName()); } Record finalRecord = MutableRecord.mutableRecordToRecord(recordWithError).get(); log.info("Writing to DLQ: {}", finalRecord); diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/__init__.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/__init__.py index 89c46f7e8..83883ccc3 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/__init__.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/__init__.py @@ -25,7 +25,7 @@ AgentContext, Service, ) -from .util import SimpleRecord, AvroValue +from .util import SimpleRecord, AvroValue, InvalidRecordError __all__ = [ "Record", @@ -38,4 +38,5 @@ "SimpleRecord", "AvroValue", "AgentContext", + "InvalidRecordError", ] diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/api.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/api.py index 9fe8fa760..06184858a 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/api.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/api.py @@ -124,7 +124,7 @@ def commit(self, record: Record): processed.""" pass - def permanent_failure(self, record: Record, error: Exception): + def permanent_failure(self, record: Record, error: Exception, error_type: str): """Called by the framework to indicate that the agent has permanently failed to process a record. The Source agent may send the record to a dead letter queue or raise an error. diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/util.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/util.py index 98f06444f..28dbc9a31 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/util.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/util.py @@ -19,7 +19,7 @@ from .api import Record -__all__ = ["SimpleRecord", "AvroValue"] +__all__ = ["SimpleRecord", "AvroValue", "InvalidRecordError"] class SimpleRecord(Record): @@ -69,3 +69,7 @@ def __repr__(self): class AvroValue(object): schema: dict value: Any + + +class InvalidRecordError(Exception): + pass diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py index fdd0194ee..baf6ede33 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py @@ -161,6 +161,7 @@ async def handle_read_requests(self, context, read_records): "permanent_failure", record, RuntimeError(failure.error_message), + failure.error_type, ) request = await context.read() @@ -207,6 +208,10 @@ async def process(self, requests: AsyncIterable[ProcessorRequest], _): grpc_result.records.append(grpc_record) yield ProcessorResponse(results=[grpc_result]) except Exception as e: + if e.__class__.__name__ == "InvalidRecordError": + grpc_result.error_type = "INVALID_RECORD" + else: + grpc_result.error_type = "INTERNAL_ERROR" grpc_result.error = str(e) yield ProcessorResponse(results=[grpc_result]) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py index 525be91a7..2daa20301 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py @@ -16,7 +16,8 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! -# source: langstream_grpc/proto/agent.proto +# source: agent.proto +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -32,47 +33,45 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n!langstream_grpc/proto/agent.proto\x1a\x1bgoogle/protobuf/empty.proto"!\n\x0cInfoResponse\x12\x11\n\tjson_info\x18\x01 \x01(\t"\xa3\x02\n\x05Value\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\x15\n\x0b\x62ytes_value\x18\x02 \x01(\x0cH\x00\x12\x17\n\rboolean_value\x18\x03 \x01(\x08H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x14\n\nbyte_value\x18\x05 \x01(\x05H\x00\x12\x15\n\x0bshort_value\x18\x06 \x01(\x05H\x00\x12\x13\n\tint_value\x18\x07 \x01(\x05H\x00\x12\x14\n\nlong_value\x18\x08 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\t \x01(\x02H\x00\x12\x16\n\x0c\x64ouble_value\x18\n \x01(\x01H\x00\x12\x14\n\njson_value\x18\x0b \x01(\tH\x00\x12\x14\n\navro_value\x18\x0c \x01(\x0cH\x00\x42\x0c\n\ntype_oneof"-\n\x06Header\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x05value\x18\x02 \x01(\x0b\x32\x06.Value"*\n\x06Schema\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x0c"\xb3\x01\n\x06Record\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x18\n\x03key\x18\x02 \x01(\x0b\x32\x06.ValueH\x00\x88\x01\x01\x12\x1a\n\x05value\x18\x03 \x01(\x0b\x32\x06.ValueH\x01\x88\x01\x01\x12\x18\n\x07headers\x18\x04 \x03(\x0b\x32\x07.Header\x12\x0e\n\x06origin\x18\x05 \x01(\t\x12\x16\n\ttimestamp\x18\x06 \x01(\x03H\x02\x88\x01\x01\x42\x06\n\x04_keyB\x08\n\x06_valueB\x0c\n\n_timestamp"K\n\x18TopicProducerWriteResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error"X\n\x15TopicProducerResponse\x12\r\n\x05topic\x18\x01 \x01(\t\x12\x17\n\x06schema\x18\x02 \x01(\x0b\x32\x07.Schema\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record"<\n\x10PermanentFailure\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x15\n\rerror_message\x18\x02 \x01(\t"X\n\rSourceRequest\x12\x19\n\x11\x63ommitted_records\x18\x01 \x03(\x03\x12,\n\x11permanent_failure\x18\x02 \x01(\x0b\x32\x11.PermanentFailure"C\n\x0eSourceResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"E\n\x10ProcessorRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"O\n\x11ProcessorResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12!\n\x07results\x18\x02 \x03(\x0b\x32\x10.ProcessorResult"\\\n\x0fProcessorResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x07records\x18\x03 \x03(\x0b\x32\x07.RecordB\x08\n\x06_error"?\n\x0bSinkRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x17\n\x06record\x18\x02 \x01(\x0b\x32\x07.Record"?\n\x0cSinkResponse\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error2\xaf\x02\n\x0c\x41gentService\x12\x35\n\nagent_info\x12\x16.google.protobuf.Empty\x1a\r.InfoResponse"\x00\x12-\n\x04read\x12\x0e.SourceRequest\x1a\x0f.SourceResponse"\x00(\x01\x30\x01\x12\x36\n\x07process\x12\x11.ProcessorRequest\x1a\x12.ProcessorResponse"\x00(\x01\x30\x01\x12*\n\x05write\x12\x0c.SinkRequest\x1a\r.SinkResponse"\x00(\x01\x30\x01\x12U\n\x1aget_topic_producer_records\x12\x19.TopicProducerWriteResult\x1a\x16.TopicProducerResponse"\x00(\x01\x30\x01\x42\x1d\n\x19\x61i.langstream.agents.grpcP\x01\x62\x06proto3' + b'\n\x0b\x61gent.proto\x1a\x1bgoogle/protobuf/empty.proto"!\n\x0cInfoResponse\x12\x11\n\tjson_info\x18\x01 \x01(\t"\xa3\x02\n\x05Value\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\x15\n\x0b\x62ytes_value\x18\x02 \x01(\x0cH\x00\x12\x17\n\rboolean_value\x18\x03 \x01(\x08H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x14\n\nbyte_value\x18\x05 \x01(\x05H\x00\x12\x15\n\x0bshort_value\x18\x06 \x01(\x05H\x00\x12\x13\n\tint_value\x18\x07 \x01(\x05H\x00\x12\x14\n\nlong_value\x18\x08 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\t \x01(\x02H\x00\x12\x16\n\x0c\x64ouble_value\x18\n \x01(\x01H\x00\x12\x14\n\njson_value\x18\x0b \x01(\tH\x00\x12\x14\n\navro_value\x18\x0c \x01(\x0cH\x00\x42\x0c\n\ntype_oneof"-\n\x06Header\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x05value\x18\x02 \x01(\x0b\x32\x06.Value"*\n\x06Schema\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x0c"\xb3\x01\n\x06Record\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x18\n\x03key\x18\x02 \x01(\x0b\x32\x06.ValueH\x00\x88\x01\x01\x12\x1a\n\x05value\x18\x03 \x01(\x0b\x32\x06.ValueH\x01\x88\x01\x01\x12\x18\n\x07headers\x18\x04 \x03(\x0b\x32\x07.Header\x12\x0e\n\x06origin\x18\x05 \x01(\t\x12\x16\n\ttimestamp\x18\x06 \x01(\x03H\x02\x88\x01\x01\x42\x06\n\x04_keyB\x08\n\x06_valueB\x0c\n\n_timestamp"K\n\x18TopicProducerWriteResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error"X\n\x15TopicProducerResponse\x12\r\n\x05topic\x18\x01 \x01(\t\x12\x17\n\x06schema\x18\x02 \x01(\x0b\x32\x07.Schema\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record"P\n\x10PermanentFailure\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x15\n\rerror_message\x18\x02 \x01(\t\x12\x12\n\nerror_type\x18\x03 \x01(\t"X\n\rSourceRequest\x12\x19\n\x11\x63ommitted_records\x18\x01 \x03(\x03\x12,\n\x11permanent_failure\x18\x02 \x01(\x0b\x32\x11.PermanentFailure"C\n\x0eSourceResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"E\n\x10ProcessorRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"O\n\x11ProcessorResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12!\n\x07results\x18\x02 \x03(\x0b\x32\x10.ProcessorResult"\x84\x01\n\x0fProcessorResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x07records\x18\x03 \x03(\x0b\x32\x07.Record\x12\x17\n\nerror_type\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\x08\n\x06_errorB\r\n\x0b_error_type"?\n\x0bSinkRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x17\n\x06record\x18\x02 \x01(\x0b\x32\x07.Record"?\n\x0cSinkResponse\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error2\xaf\x02\n\x0c\x41gentService\x12\x35\n\nagent_info\x12\x16.google.protobuf.Empty\x1a\r.InfoResponse"\x00\x12-\n\x04read\x12\x0e.SourceRequest\x1a\x0f.SourceResponse"\x00(\x01\x30\x01\x12\x36\n\x07process\x12\x11.ProcessorRequest\x1a\x12.ProcessorResponse"\x00(\x01\x30\x01\x12*\n\x05write\x12\x0c.SinkRequest\x1a\r.SinkResponse"\x00(\x01\x30\x01\x12U\n\x1aget_topic_producer_records\x12\x19.TopicProducerWriteResult\x1a\x16.TopicProducerResponse"\x00(\x01\x30\x01\x42\x1d\n\x19\x61i.langstream.agents.grpcP\x01\x62\x06proto3' ) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages( - DESCRIPTOR, "langstream_grpc.proto.agent_pb2", _globals -) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "agent_pb2", _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b"\n\031ai.langstream.agents.grpcP\001" - _globals["_INFORESPONSE"]._serialized_start = 66 - _globals["_INFORESPONSE"]._serialized_end = 99 - _globals["_VALUE"]._serialized_start = 102 - _globals["_VALUE"]._serialized_end = 393 - _globals["_HEADER"]._serialized_start = 395 - _globals["_HEADER"]._serialized_end = 440 - _globals["_SCHEMA"]._serialized_start = 442 - _globals["_SCHEMA"]._serialized_end = 484 - _globals["_RECORD"]._serialized_start = 487 - _globals["_RECORD"]._serialized_end = 666 - _globals["_TOPICPRODUCERWRITERESULT"]._serialized_start = 668 - _globals["_TOPICPRODUCERWRITERESULT"]._serialized_end = 743 - _globals["_TOPICPRODUCERRESPONSE"]._serialized_start = 745 - _globals["_TOPICPRODUCERRESPONSE"]._serialized_end = 833 - _globals["_PERMANENTFAILURE"]._serialized_start = 835 - _globals["_PERMANENTFAILURE"]._serialized_end = 895 - _globals["_SOURCEREQUEST"]._serialized_start = 897 - _globals["_SOURCEREQUEST"]._serialized_end = 985 - _globals["_SOURCERESPONSE"]._serialized_start = 987 - _globals["_SOURCERESPONSE"]._serialized_end = 1054 - _globals["_PROCESSORREQUEST"]._serialized_start = 1056 - _globals["_PROCESSORREQUEST"]._serialized_end = 1125 - _globals["_PROCESSORRESPONSE"]._serialized_start = 1127 - _globals["_PROCESSORRESPONSE"]._serialized_end = 1206 - _globals["_PROCESSORRESULT"]._serialized_start = 1208 - _globals["_PROCESSORRESULT"]._serialized_end = 1300 - _globals["_SINKREQUEST"]._serialized_start = 1302 - _globals["_SINKREQUEST"]._serialized_end = 1365 - _globals["_SINKRESPONSE"]._serialized_start = 1367 - _globals["_SINKRESPONSE"]._serialized_end = 1430 - _globals["_AGENTSERVICE"]._serialized_start = 1433 - _globals["_AGENTSERVICE"]._serialized_end = 1736 + _globals["DESCRIPTOR"]._options = None + _globals["DESCRIPTOR"]._serialized_options = b"\n\031ai.langstream.agents.grpcP\001" + _globals["_INFORESPONSE"]._serialized_start = 44 + _globals["_INFORESPONSE"]._serialized_end = 77 + _globals["_VALUE"]._serialized_start = 80 + _globals["_VALUE"]._serialized_end = 371 + _globals["_HEADER"]._serialized_start = 373 + _globals["_HEADER"]._serialized_end = 418 + _globals["_SCHEMA"]._serialized_start = 420 + _globals["_SCHEMA"]._serialized_end = 462 + _globals["_RECORD"]._serialized_start = 465 + _globals["_RECORD"]._serialized_end = 644 + _globals["_TOPICPRODUCERWRITERESULT"]._serialized_start = 646 + _globals["_TOPICPRODUCERWRITERESULT"]._serialized_end = 721 + _globals["_TOPICPRODUCERRESPONSE"]._serialized_start = 723 + _globals["_TOPICPRODUCERRESPONSE"]._serialized_end = 811 + _globals["_PERMANENTFAILURE"]._serialized_start = 813 + _globals["_PERMANENTFAILURE"]._serialized_end = 893 + _globals["_SOURCEREQUEST"]._serialized_start = 895 + _globals["_SOURCEREQUEST"]._serialized_end = 983 + _globals["_SOURCERESPONSE"]._serialized_start = 985 + _globals["_SOURCERESPONSE"]._serialized_end = 1052 + _globals["_PROCESSORREQUEST"]._serialized_start = 1054 + _globals["_PROCESSORREQUEST"]._serialized_end = 1123 + _globals["_PROCESSORRESPONSE"]._serialized_start = 1125 + _globals["_PROCESSORRESPONSE"]._serialized_end = 1204 + _globals["_PROCESSORRESULT"]._serialized_start = 1207 + _globals["_PROCESSORRESULT"]._serialized_end = 1339 + _globals["_SINKREQUEST"]._serialized_start = 1341 + _globals["_SINKREQUEST"]._serialized_end = 1404 + _globals["_SINKRESPONSE"]._serialized_start = 1406 + _globals["_SINKRESPONSE"]._serialized_end = 1469 + _globals["_AGENTSERVICE"]._serialized_start = 1472 + _globals["_AGENTSERVICE"]._serialized_end = 1775 # @@protoc_insertion_point(module_scope) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi index d84bfd4f6..22d3e499a 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi @@ -13,13 +13,13 @@ from typing import ( DESCRIPTOR: _descriptor.FileDescriptor class InfoResponse(_message.Message): - __slots__ = ["json_info"] + __slots__ = ("json_info",) JSON_INFO_FIELD_NUMBER: _ClassVar[int] json_info: str def __init__(self, json_info: _Optional[str] = ...) -> None: ... class Value(_message.Message): - __slots__ = [ + __slots__ = ( "schema_id", "bytes_value", "boolean_value", @@ -32,7 +32,7 @@ class Value(_message.Message): "double_value", "json_value", "avro_value", - ] + ) SCHEMA_ID_FIELD_NUMBER: _ClassVar[int] BYTES_VALUE_FIELD_NUMBER: _ClassVar[int] BOOLEAN_VALUE_FIELD_NUMBER: _ClassVar[int] @@ -74,7 +74,7 @@ class Value(_message.Message): ) -> None: ... class Header(_message.Message): - __slots__ = ["name", "value"] + __slots__ = ("name", "value") NAME_FIELD_NUMBER: _ClassVar[int] VALUE_FIELD_NUMBER: _ClassVar[int] name: str @@ -86,7 +86,7 @@ class Header(_message.Message): ) -> None: ... class Schema(_message.Message): - __slots__ = ["schema_id", "value"] + __slots__ = ("schema_id", "value") SCHEMA_ID_FIELD_NUMBER: _ClassVar[int] VALUE_FIELD_NUMBER: _ClassVar[int] schema_id: int @@ -96,7 +96,7 @@ class Schema(_message.Message): ) -> None: ... class Record(_message.Message): - __slots__ = ["record_id", "key", "value", "headers", "origin", "timestamp"] + __slots__ = ("record_id", "key", "value", "headers", "origin", "timestamp") RECORD_ID_FIELD_NUMBER: _ClassVar[int] KEY_FIELD_NUMBER: _ClassVar[int] VALUE_FIELD_NUMBER: _ClassVar[int] @@ -120,7 +120,7 @@ class Record(_message.Message): ) -> None: ... class TopicProducerWriteResult(_message.Message): - __slots__ = ["record_id", "error"] + __slots__ = ("record_id", "error") RECORD_ID_FIELD_NUMBER: _ClassVar[int] ERROR_FIELD_NUMBER: _ClassVar[int] record_id: int @@ -130,7 +130,7 @@ class TopicProducerWriteResult(_message.Message): ) -> None: ... class TopicProducerResponse(_message.Message): - __slots__ = ["topic", "schema", "record"] + __slots__ = ("topic", "schema", "record") TOPIC_FIELD_NUMBER: _ClassVar[int] SCHEMA_FIELD_NUMBER: _ClassVar[int] RECORD_FIELD_NUMBER: _ClassVar[int] @@ -145,17 +145,22 @@ class TopicProducerResponse(_message.Message): ) -> None: ... class PermanentFailure(_message.Message): - __slots__ = ["record_id", "error_message"] + __slots__ = ("record_id", "error_message", "error_type") RECORD_ID_FIELD_NUMBER: _ClassVar[int] ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int] + ERROR_TYPE_FIELD_NUMBER: _ClassVar[int] record_id: int error_message: str + error_type: str def __init__( - self, record_id: _Optional[int] = ..., error_message: _Optional[str] = ... + self, + record_id: _Optional[int] = ..., + error_message: _Optional[str] = ..., + error_type: _Optional[str] = ..., ) -> None: ... class SourceRequest(_message.Message): - __slots__ = ["committed_records", "permanent_failure"] + __slots__ = ("committed_records", "permanent_failure") COMMITTED_RECORDS_FIELD_NUMBER: _ClassVar[int] PERMANENT_FAILURE_FIELD_NUMBER: _ClassVar[int] committed_records: _containers.RepeatedScalarFieldContainer[int] @@ -167,7 +172,7 @@ class SourceRequest(_message.Message): ) -> None: ... class SourceResponse(_message.Message): - __slots__ = ["schema", "records"] + __slots__ = ("schema", "records") SCHEMA_FIELD_NUMBER: _ClassVar[int] RECORDS_FIELD_NUMBER: _ClassVar[int] schema: Schema @@ -179,7 +184,7 @@ class SourceResponse(_message.Message): ) -> None: ... class ProcessorRequest(_message.Message): - __slots__ = ["schema", "records"] + __slots__ = ("schema", "records") SCHEMA_FIELD_NUMBER: _ClassVar[int] RECORDS_FIELD_NUMBER: _ClassVar[int] schema: Schema @@ -191,7 +196,7 @@ class ProcessorRequest(_message.Message): ) -> None: ... class ProcessorResponse(_message.Message): - __slots__ = ["schema", "results"] + __slots__ = ("schema", "results") SCHEMA_FIELD_NUMBER: _ClassVar[int] RESULTS_FIELD_NUMBER: _ClassVar[int] schema: Schema @@ -203,22 +208,25 @@ class ProcessorResponse(_message.Message): ) -> None: ... class ProcessorResult(_message.Message): - __slots__ = ["record_id", "error", "records"] + __slots__ = ("record_id", "error", "records", "error_type") RECORD_ID_FIELD_NUMBER: _ClassVar[int] ERROR_FIELD_NUMBER: _ClassVar[int] RECORDS_FIELD_NUMBER: _ClassVar[int] + ERROR_TYPE_FIELD_NUMBER: _ClassVar[int] record_id: int error: str records: _containers.RepeatedCompositeFieldContainer[Record] + error_type: str def __init__( self, record_id: _Optional[int] = ..., error: _Optional[str] = ..., records: _Optional[_Iterable[_Union[Record, _Mapping]]] = ..., + error_type: _Optional[str] = ..., ) -> None: ... class SinkRequest(_message.Message): - __slots__ = ["schema", "record"] + __slots__ = ("schema", "record") SCHEMA_FIELD_NUMBER: _ClassVar[int] RECORD_FIELD_NUMBER: _ClassVar[int] schema: Schema @@ -230,7 +238,7 @@ class SinkRequest(_message.Message): ) -> None: ... class SinkResponse(_message.Message): - __slots__ = ["record_id", "error"] + __slots__ = ("record_id", "error") RECORD_ID_FIELD_NUMBER: _ClassVar[int] ERROR_FIELD_NUMBER: _ClassVar[int] record_id: int diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2_grpc.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2_grpc.py index 1930bb556..e3874da85 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2_grpc.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2_grpc.py @@ -18,8 +18,8 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc +from langstream_grpc.proto import agent_pb2 as agent__pb2 from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -from langstream_grpc.proto import agent_pb2 as langstream__grpc_dot_proto_dot_agent__pb2 class AgentServiceStub(object): @@ -34,27 +34,27 @@ def __init__(self, channel): self.agent_info = channel.unary_unary( "/AgentService/agent_info", request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - response_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.InfoResponse.FromString, + response_deserializer=agent__pb2.InfoResponse.FromString, ) self.read = channel.stream_stream( "/AgentService/read", - request_serializer=langstream__grpc_dot_proto_dot_agent__pb2.SourceRequest.SerializeToString, - response_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.SourceResponse.FromString, + request_serializer=agent__pb2.SourceRequest.SerializeToString, + response_deserializer=agent__pb2.SourceResponse.FromString, ) self.process = channel.stream_stream( "/AgentService/process", - request_serializer=langstream__grpc_dot_proto_dot_agent__pb2.ProcessorRequest.SerializeToString, - response_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.ProcessorResponse.FromString, + request_serializer=agent__pb2.ProcessorRequest.SerializeToString, + response_deserializer=agent__pb2.ProcessorResponse.FromString, ) self.write = channel.stream_stream( "/AgentService/write", - request_serializer=langstream__grpc_dot_proto_dot_agent__pb2.SinkRequest.SerializeToString, - response_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.SinkResponse.FromString, + request_serializer=agent__pb2.SinkRequest.SerializeToString, + response_deserializer=agent__pb2.SinkResponse.FromString, ) self.get_topic_producer_records = channel.stream_stream( "/AgentService/get_topic_producer_records", - request_serializer=langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerWriteResult.SerializeToString, - response_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerResponse.FromString, + request_serializer=agent__pb2.TopicProducerWriteResult.SerializeToString, + response_deserializer=agent__pb2.TopicProducerResponse.FromString, ) @@ -97,27 +97,27 @@ def add_AgentServiceServicer_to_server(servicer, server): "agent_info": grpc.unary_unary_rpc_method_handler( servicer.agent_info, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - response_serializer=langstream__grpc_dot_proto_dot_agent__pb2.InfoResponse.SerializeToString, + response_serializer=agent__pb2.InfoResponse.SerializeToString, ), "read": grpc.stream_stream_rpc_method_handler( servicer.read, - request_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.SourceRequest.FromString, - response_serializer=langstream__grpc_dot_proto_dot_agent__pb2.SourceResponse.SerializeToString, + request_deserializer=agent__pb2.SourceRequest.FromString, + response_serializer=agent__pb2.SourceResponse.SerializeToString, ), "process": grpc.stream_stream_rpc_method_handler( servicer.process, - request_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.ProcessorRequest.FromString, - response_serializer=langstream__grpc_dot_proto_dot_agent__pb2.ProcessorResponse.SerializeToString, + request_deserializer=agent__pb2.ProcessorRequest.FromString, + response_serializer=agent__pb2.ProcessorResponse.SerializeToString, ), "write": grpc.stream_stream_rpc_method_handler( servicer.write, - request_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.SinkRequest.FromString, - response_serializer=langstream__grpc_dot_proto_dot_agent__pb2.SinkResponse.SerializeToString, + request_deserializer=agent__pb2.SinkRequest.FromString, + response_serializer=agent__pb2.SinkResponse.SerializeToString, ), "get_topic_producer_records": grpc.stream_stream_rpc_method_handler( servicer.get_topic_producer_records, - request_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerWriteResult.FromString, - response_serializer=langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerResponse.SerializeToString, + request_deserializer=agent__pb2.TopicProducerWriteResult.FromString, + response_serializer=agent__pb2.TopicProducerResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -148,7 +148,7 @@ def agent_info( target, "/AgentService/agent_info", google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - langstream__grpc_dot_proto_dot_agent__pb2.InfoResponse.FromString, + agent__pb2.InfoResponse.FromString, options, channel_credentials, insecure, @@ -176,8 +176,8 @@ def read( request_iterator, target, "/AgentService/read", - langstream__grpc_dot_proto_dot_agent__pb2.SourceRequest.SerializeToString, - langstream__grpc_dot_proto_dot_agent__pb2.SourceResponse.FromString, + agent__pb2.SourceRequest.SerializeToString, + agent__pb2.SourceResponse.FromString, options, channel_credentials, insecure, @@ -205,8 +205,8 @@ def process( request_iterator, target, "/AgentService/process", - langstream__grpc_dot_proto_dot_agent__pb2.ProcessorRequest.SerializeToString, - langstream__grpc_dot_proto_dot_agent__pb2.ProcessorResponse.FromString, + agent__pb2.ProcessorRequest.SerializeToString, + agent__pb2.ProcessorResponse.FromString, options, channel_credentials, insecure, @@ -234,8 +234,8 @@ def write( request_iterator, target, "/AgentService/write", - langstream__grpc_dot_proto_dot_agent__pb2.SinkRequest.SerializeToString, - langstream__grpc_dot_proto_dot_agent__pb2.SinkResponse.FromString, + agent__pb2.SinkRequest.SerializeToString, + agent__pb2.SinkResponse.FromString, options, channel_credentials, insecure, @@ -263,8 +263,8 @@ def get_topic_producer_records( request_iterator, target, "/AgentService/get_topic_producer_records", - langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerWriteResult.SerializeToString, - langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerResponse.FromString, + agent__pb2.TopicProducerWriteResult.SerializeToString, + agent__pb2.TopicProducerResponse.FromString, options, channel_credentials, insecure, diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py index 82eefae20..43c2aef3a 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py @@ -184,6 +184,20 @@ async def test_failing_record(): assert len(response.results) == 1 assert response.results[0].HasField("error") is True assert response.results[0].error == "failure" + assert response.results[0].error_type == "INTERNAL_ERROR" + + +async def test_failing_record_bad_record(): + async with ServerAndStub( + "langstream_grpc.tests.test_grpc_processor.MyFailingProcessorForBadRecord" + ) as server_and_stub: + async for response in server_and_stub.stub.process( + [ProcessorRequest(records=[GrpcRecord()])] + ): + assert len(response.results) == 1 + assert response.results[0].HasField("error") is True + assert response.results[0].error == "this record is invalid" + assert response.results[0].error_type == "INVALID_RECORD" @pytest.mark.parametrize("klass", ["MyFutureProcessor", "MyAsyncProcessor"]) @@ -288,6 +302,13 @@ def process(self, record: Record) -> List[RecordType]: raise Exception("failure") +class MyFailingProcessorForBadRecord(Processor): + def process(self, record: Record) -> List[RecordType]: + from langstream_grpc.util import InvalidRecordError + + raise InvalidRecordError("this record is invalid") + + class MyFutureProcessor(Processor): def __init__(self): self.executor = ThreadPoolExecutor(max_workers=10) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/util.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/util.py index 98f06444f..28dbc9a31 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/util.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/util.py @@ -19,7 +19,7 @@ from .api import Record -__all__ = ["SimpleRecord", "AvroValue"] +__all__ = ["SimpleRecord", "AvroValue", "InvalidRecordError"] class SimpleRecord(Record): @@ -69,3 +69,7 @@ def __repr__(self): class AvroValue(object): schema: dict value: Any + + +class InvalidRecordError(Exception): + pass diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/scripts/generate-grpc-code.sh b/langstream-runtime/langstream-runtime-impl/src/main/python/scripts/generate-grpc-code.sh new file mode 100755 index 000000000..651a22c9d --- /dev/null +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/scripts/generate-grpc-code.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +poetry install +grpc_proto_dir=../../../../../langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto +out_dir=./langstream_grpc/proto +poetry run python -m grpc_tools.protoc \ + -I${grpc_proto_dir} \ + --python_out=${out_dir} \ + --pyi_out=${out_dir} \ + --grpc_python_out=${out_dir} \ + ${grpc_proto_dir}/agent.proto \ No newline at end of file diff --git a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/ErrorHandlingTest.java b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/ErrorHandlingIT.java similarity index 73% rename from langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/ErrorHandlingTest.java rename to langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/ErrorHandlingIT.java index ba74055ce..b2e3ef8e9 100644 --- a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/ErrorHandlingTest.java +++ b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/ErrorHandlingIT.java @@ -15,9 +15,12 @@ */ package ai.langstream.agents; +import static ai.langstream.testrunners.AbstractApplicationRunner.INTEGRATION_TESTS_GROUP1; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.fail; +import ai.langstream.api.runner.code.Record; +import ai.langstream.api.runner.code.SimpleRecord; import ai.langstream.api.runner.topics.TopicConsumer; import ai.langstream.api.runner.topics.TopicProducer; import ai.langstream.mockagents.MockProcessorAgentsCodeProvider; @@ -30,10 +33,16 @@ import java.util.UUID; import lombok.extern.slf4j.Slf4j; import org.awaitility.Awaitility; +import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; @Slf4j -class ErrorHandlingTest extends AbstractGenericStreamingApplicationRunner { +@Tag(INTEGRATION_TESTS_GROUP1) +class ErrorHandlingIT extends AbstractGenericStreamingApplicationRunner { + + public ErrorHandlingIT() { + super("pulsar"); + } @Test public void testDiscardErrors() throws Exception { @@ -51,9 +60,6 @@ public void testDiscardErrors() throws Exception { topics: - name: "%s" creation-mode: create-if-not-exists - options: - # we want to read more than one record at a time - consumer.max.poll.records: 10 - name: "%s" creation-mode: create-if-not-exists errors: @@ -92,8 +98,8 @@ public void testDiscardErrors() throws Exception { public void testDeadLetter() throws Exception { String tenant = "tenant"; String[] expectedAgents = {"app-step1"}; - String inputTopic = "input-topic-" + UUID.randomUUID(); - String outputTopic = "output-topic-" + UUID.randomUUID(); + String inputTopic = "input"; + String outputTopic = "output"; Map application = Map.of( @@ -126,17 +132,70 @@ public void testDeadLetter() throws Exception { TopicConsumer consumerDeadletter = createConsumer(inputTopic + "-deadletter")) { List expectedMessages = new ArrayList<>(); - List expectedMessagesDeadletter = new ArrayList<>(); + List expectedMessagesDeadletter = new ArrayList<>(); for (int i = 0; i < 10; i++) { sendMessage(producer, "fail-me-" + i); sendMessage(producer, "keep-me-" + i); expectedMessages.add("keep-me-" + i); - expectedMessagesDeadletter.add("fail-me-" + i); + expectedMessagesDeadletter.add( + SimpleRecord.builder() + .value("fail-me-" + i) + .key(null) + .headers( + List.of( + SimpleRecord.SimpleHeader.of( + "langstream-error-type", + "INTERNAL_ERROR"), + SimpleRecord.SimpleHeader.of( + "cause-class", + "ai.langstream.mockagents.MockProcessorAgentsCodeProvider$InjectedFailure"), + SimpleRecord.SimpleHeader.of( + "langstream-error-cause-class", + "ai.langstream.mockagents.MockProcessorAgentsCodeProvider$InjectedFailure"), + SimpleRecord.SimpleHeader.of( + "error-class", + "ai.langstream.runtime.agent.AgentRunner$PermanentFailureException"), + SimpleRecord.SimpleHeader.of( + "langstream-error-class", + "ai.langstream.runtime.agent.AgentRunner$PermanentFailureException"), + SimpleRecord.SimpleHeader.of( + "source-topic", + "persistent://public/default/input"), + SimpleRecord.SimpleHeader.of( + "langstream-error-source-topic", + "persistent://public/default/input"), + SimpleRecord.SimpleHeader.of( + "cause-msg", + "Failing on content: fail-me-" + i), + SimpleRecord.SimpleHeader.of( + "langstream-error-cause-message", + "Failing on content: fail-me-" + i), + SimpleRecord.SimpleHeader.of( + "root-cause-class", + "ai.langstream.mockagents.MockProcessorAgentsCodeProvider$InjectedFailure"), + SimpleRecord.SimpleHeader.of( + "langstream-error-root-cause-class", + "ai.langstream.mockagents.MockProcessorAgentsCodeProvider$InjectedFailure"), + SimpleRecord.SimpleHeader.of( + "root-cause-msg", + "Failing on content: fail-me-" + i), + SimpleRecord.SimpleHeader.of( + "langstream-error-root-cause-message", + "Failing on content: fail-me-" + i), + SimpleRecord.SimpleHeader.of( + "error-msg", + "ai.langstream.mockagents.MockProcessorAgentsCodeProvider$InjectedFailure: Failing on content: fail-me-" + + i), + SimpleRecord.SimpleHeader.of( + "langstream-error-message", + "ai.langstream.mockagents.MockProcessorAgentsCodeProvider$InjectedFailure: Failing on content: fail-me-" + + i))) + .build()); } - executeAgentRunners(applicationRuntime); + executeAgentRunners(applicationRuntime, 15); - waitForMessages(consumerDeadletter, expectedMessagesDeadletter); + waitForRecords(consumerDeadletter, expectedMessagesDeadletter); waitForMessages(consumer, expectedMessages); } } @@ -158,9 +217,6 @@ public void testFailOnErrors() throws Exception { topics: - name: "%s" creation-mode: create-if-not-exists - options: - # we want to read more than one record at a time - consumer.max.poll.records: 10 - name: "%s" creation-mode: create-if-not-exists errors: @@ -227,9 +283,6 @@ public void testDiscardErrorsOnSink() throws Exception { topics: - name: "%s" creation-mode: create-if-not-exists - options: - # we want to read more than one record at a time - consumer.max.poll.records: 10 errors: on-failure: fail retries: 5 @@ -283,9 +336,6 @@ public void testFailOnErrorsOnSink() throws Exception { topics: - name: "%s" creation-mode: create-if-not-exists - options: - # we want to read more than one record at a time - consumer.max.poll.records: 10 errors: on-failure: skip retries: 5 @@ -352,7 +402,6 @@ public void testDeadLetterOnSink() throws Exception { input: "%s" errors: on-failure: dead-letter - retries: 3 configuration: fail-on-content: "fail-me" """ @@ -364,7 +413,7 @@ public void testDeadLetterOnSink() throws Exception { TopicConsumer consumerDeadletter = createConsumer(inputTopic + "-deadletter")) { List expectedMessages = new ArrayList<>(); - List expectedMessagesDeadletter = new ArrayList<>(); + List expectedMessagesDeadletter = new ArrayList<>(); for (int i = 0; i < 10; i++) { sendMessage(producer, "fail-me-" + i); sendMessage(producer, "keep-me-" + i); @@ -372,9 +421,9 @@ public void testDeadLetterOnSink() throws Exception { expectedMessagesDeadletter.add("fail-me-" + i); } - executeAgentRunners(applicationRuntime); + executeAgentRunners(applicationRuntime, 20); - waitForMessages(consumerDeadletter, expectedMessagesDeadletter); + waitForMessagesInAnyOrder(consumerDeadletter, expectedMessagesDeadletter); Awaitility.await() .untilAsserted( diff --git a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/StatefulAgentsTest.java b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/StatefulAgentsIT.java similarity index 93% rename from langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/StatefulAgentsTest.java rename to langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/StatefulAgentsIT.java index 74a8490c3..ea00152bb 100644 --- a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/StatefulAgentsTest.java +++ b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/StatefulAgentsIT.java @@ -15,6 +15,8 @@ */ package ai.langstream.agents; +import static ai.langstream.testrunners.AbstractApplicationRunner.INTEGRATION_TESTS_GROUP1; + import ai.langstream.api.runner.topics.TopicConsumer; import ai.langstream.api.runner.topics.TopicProducer; import ai.langstream.testrunners.AbstractGenericStreamingApplicationRunner; @@ -22,10 +24,12 @@ import java.util.Map; import java.util.UUID; import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; @Slf4j -class StatefulAgentsTest extends AbstractGenericStreamingApplicationRunner { +@Tag(INTEGRATION_TESTS_GROUP1) +class StatefulAgentsIT extends AbstractGenericStreamingApplicationRunner { @Test public void testSingleStatefulAgent() throws Exception { diff --git a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/WebCrawlerSourceIT.java b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/WebCrawlerSourceIT.java index 2d3d42ea2..d8eedc803 100644 --- a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/WebCrawlerSourceIT.java +++ b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/agents/WebCrawlerSourceIT.java @@ -15,6 +15,7 @@ */ package ai.langstream.agents; +import static ai.langstream.testrunners.AbstractApplicationRunner.INTEGRATION_TESTS_GROUP1; import static com.github.tomakehurst.wiremock.client.WireMock.*; import static org.junit.jupiter.api.Assertions.*; @@ -32,10 +33,12 @@ import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; @Slf4j @WireMockTest +@Tag(INTEGRATION_TESTS_GROUP1) class WebCrawlerSourceIT extends AbstractGenericStreamingApplicationRunner { static WireMockRuntimeInfo wireMockRuntimeInfo; diff --git a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/pulsar/PulsarDLQSourceIT.java b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/pulsar/PulsarDLQSourceIT.java new file mode 100644 index 000000000..74e2625c5 --- /dev/null +++ b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/pulsar/PulsarDLQSourceIT.java @@ -0,0 +1,166 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.pulsar; + +import static org.junit.jupiter.api.Assertions.*; + +import ai.langstream.api.runner.code.Record; +import ai.langstream.api.runner.code.SimpleRecord; +import ai.langstream.api.runner.topics.TopicConsumer; +import ai.langstream.api.runner.topics.TopicProducer; +import ai.langstream.testrunners.AbstractGenericStreamingApplicationRunner; +import ai.langstream.testrunners.pulsar.PulsarApplicationRunner; +import java.nio.charset.StandardCharsets; +import java.util.*; +import lombok.extern.slf4j.Slf4j; +import org.apache.pulsar.client.api.*; +import org.junit.jupiter.api.Test; + +@Slf4j +class PulsarDLQSourceIT extends AbstractGenericStreamingApplicationRunner { + + public PulsarDLQSourceIT() { + super("pulsar"); + } + + @Test + public void test() throws Exception { + String tenant = "tenant"; + String[] expectedAgents = {"app-step1", "app-read-errors"}; + + String serviceUrl = ((PulsarApplicationRunner) streamingClusterRunner).getBrokerUrl(); + + Map application = + Map.of( + "module.yaml", + """ + module: "module-1" + id: "pipeline-1" + topics: + - name: "input" + creation-mode: create-if-not-exists + - name: "output" + creation-mode: create-if-not-exists + pipeline: + - name: "some agent" + id: "step1" + type: "mock-failing-processor" + input: "input" + output: "output" + errors: + on-failure: dead-letter + configuration: + fail-on-content: "fail-me" + """, + "pipeline-readerrors.yaml", + """ + id: "read-errors" + topics: + - name: "input-deadletter" + creation-mode: create-if-not-exists + - name: "dlq-out" + creation-mode: create-if-not-exists + pipeline: + - name: "dlq" + id: "read-errors" + type: "pulsardlq-source" + output: "dlq-out" + configuration: + pulsar-url: "%s" + namespace: "public/default" + subscription: "dlq-subscription" + dlq-suffix: "-deadletter" + include-partitioned: false + timeout-ms: 1000 + """ + .formatted(serviceUrl)); + try (ApplicationRuntime applicationRuntime = + deployApplication( + tenant, "app", application, buildInstanceYaml(), expectedAgents)) { + try (TopicProducer producer = createProducer("input"); + TopicConsumer consumer = createConsumer("output"); + TopicConsumer consumerDeadletterOut = createConsumer("dlq-out")) { + + List expectedMessages = new ArrayList<>(); + List expectedMessagesDeadletter = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + sendMessage(producer, "fail-me-" + i); + sendMessage(producer, "keep-me-" + i); + expectedMessages.add("keep-me-" + i); + expectedMessagesDeadletter.add( + SimpleRecord.builder() + .value(("fail-me-" + i).getBytes(StandardCharsets.UTF_8)) + .key(null) + .headers( + List.of( + SimpleRecord.SimpleHeader.of( + "langstream-error-type", + "INTERNAL_ERROR"), + SimpleRecord.SimpleHeader.of( + "cause-class", + "ai.langstream.mockagents.MockProcessorAgentsCodeProvider$InjectedFailure"), + SimpleRecord.SimpleHeader.of( + "langstream-error-cause-class", + "ai.langstream.mockagents.MockProcessorAgentsCodeProvider$InjectedFailure"), + SimpleRecord.SimpleHeader.of( + "error-class", + "ai.langstream.runtime.agent.AgentRunner$PermanentFailureException"), + SimpleRecord.SimpleHeader.of( + "langstream-error-class", + "ai.langstream.runtime.agent.AgentRunner$PermanentFailureException"), + SimpleRecord.SimpleHeader.of( + "source-topic", + "persistent://public/default/input"), + SimpleRecord.SimpleHeader.of( + "langstream-error-source-topic", + "persistent://public/default/input"), + SimpleRecord.SimpleHeader.of( + "cause-msg", + "Failing on content: fail-me-" + i), + SimpleRecord.SimpleHeader.of( + "langstream-error-cause-message", + "Failing on content: fail-me-" + i), + SimpleRecord.SimpleHeader.of( + "root-cause-class", + "ai.langstream.mockagents.MockProcessorAgentsCodeProvider$InjectedFailure"), + SimpleRecord.SimpleHeader.of( + "langstream-error-root-cause-class", + "ai.langstream.mockagents.MockProcessorAgentsCodeProvider$InjectedFailure"), + SimpleRecord.SimpleHeader.of( + "root-cause-msg", + "Failing on content: fail-me-" + i), + SimpleRecord.SimpleHeader.of( + "langstream-error-root-cause-message", + "Failing on content: fail-me-" + i), + SimpleRecord.SimpleHeader.of( + "error-msg", + "ai.langstream.mockagents.MockProcessorAgentsCodeProvider$InjectedFailure: Failing on content: fail-me-" + + i), + SimpleRecord.SimpleHeader.of( + "langstream-error-message", + "ai.langstream.mockagents.MockProcessorAgentsCodeProvider$InjectedFailure: Failing on content: fail-me-" + + i))) + .build()); + } + + executeAgentRunners(applicationRuntime, 15); + + waitForRecords(consumerDeadletterOut, expectedMessagesDeadletter); + waitForMessages(consumer, expectedMessages); + } + } + } +} diff --git a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/runtime/agent/AgentRunnerTest.java b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/runtime/agent/AgentRunnerTest.java index 41cb32349..e8fa0f897 100644 --- a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/runtime/agent/AgentRunnerTest.java +++ b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/runtime/agent/AgentRunnerTest.java @@ -20,14 +20,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import ai.langstream.api.runner.code.AbstractAgentCode; -import ai.langstream.api.runner.code.AgentContext; -import ai.langstream.api.runner.code.AgentSink; -import ai.langstream.api.runner.code.AgentSource; -import ai.langstream.api.runner.code.MetricsReporter; +import ai.langstream.api.runner.code.*; import ai.langstream.api.runner.code.Record; -import ai.langstream.api.runner.code.SimpleRecord; -import ai.langstream.api.runner.code.SingleRecordAgentProcessor; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -36,14 +30,13 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.function.Supplier; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; @Slf4j class AgentRunnerTest { - private final ExecutorService executorService = Executors.newCachedThreadPool(); - @Test void skip() throws Exception { SimpleSource source = new SimpleSource(List.of(SimpleRecord.of("key", "fail-me"))); @@ -53,14 +46,7 @@ void skip() throws Exception { new StandardErrorsHandler(Map.of("retries", 0, "onFailure", "skip")); AgentContext context = createMockAgentContext(); - AgentRunner.runMainLoop( - source, - processor, - sink, - context, - errorHandler, - source::hasMoreRecords, - executorService); + runMainLoopSync(source, processor, sink, context, errorHandler, source::hasMoreRecords); processor.expectExecutions(1); source.expectUncommitted(0); } @@ -83,14 +69,13 @@ void failWithRetries() { assertThrows( AgentRunner.PermanentFailureException.class, () -> - AgentRunner.runMainLoop( + runMainLoopSync( source, processor, sink, context, errorHandler, - source::hasMoreRecords, - executorService)); + source::hasMoreRecords)); processor.expectExecutions(3); source.expectUncommitted(1); } @@ -106,14 +91,13 @@ void failNoRetries() { assertThrows( AgentRunner.PermanentFailureException.class, () -> - AgentRunner.runMainLoop( + runMainLoopSync( source, processor, sink, context, errorHandler, - source::hasMoreRecords, - executorService)); + source::hasMoreRecords)); processor.expectExecutions(1); source.expectUncommitted(1); } @@ -130,14 +114,7 @@ void someFailedSomeGoodWithSkip() throws Exception { StandardErrorsHandler errorHandler = new StandardErrorsHandler(Map.of("retries", 0, "onFailure", "skip")); AgentContext context = createMockAgentContext(); - AgentRunner.runMainLoop( - source, - processor, - sink, - context, - errorHandler, - source::hasMoreRecords, - executorService); + runMainLoopSync(source, processor, sink, context, errorHandler, source::hasMoreRecords); processor.expectExecutions(2); source.expectUncommitted(0); } @@ -154,14 +131,7 @@ void someGoodSomeFailedWithSkip() throws Exception { StandardErrorsHandler errorHandler = new StandardErrorsHandler(Map.of("retries", 0, "onFailure", "skip")); AgentContext context = createMockAgentContext(); - AgentRunner.runMainLoop( - source, - processor, - sink, - context, - errorHandler, - source::hasMoreRecords, - executorService); + runMainLoopSync(source, processor, sink, context, errorHandler, source::hasMoreRecords); processor.expectExecutions(2); source.expectUncommitted(0); } @@ -178,14 +148,7 @@ void someGoodSomeFailedWithRetry() throws Exception { StandardErrorsHandler errorHandler = new StandardErrorsHandler(Map.of("retries", 5, "onFailure", "fail")); AgentContext context = createMockAgentContext(); - AgentRunner.runMainLoop( - source, - processor, - sink, - context, - errorHandler, - source::hasMoreRecords, - executorService); + runMainLoopSync(source, processor, sink, context, errorHandler, source::hasMoreRecords); processor.expectExecutions(5); source.expectUncommitted(0); } @@ -203,14 +166,7 @@ void someGoodSomeFailedWithSkipAndBatching() throws Exception { StandardErrorsHandler errorHandler = new StandardErrorsHandler(Map.of("retries", 0, "onFailure", "skip")); AgentContext context = createMockAgentContext(); - AgentRunner.runMainLoop( - source, - processor, - sink, - context, - errorHandler, - source::hasMoreRecords, - executorService); + runMainLoopSync(source, processor, sink, context, errorHandler, source::hasMoreRecords); processor.expectExecutions(2); source.expectUncommitted(0); } @@ -228,14 +184,7 @@ void someFailedSomeGoodWithSkipAndBatching() throws Exception { StandardErrorsHandler errorHandler = new StandardErrorsHandler(Map.of("retries", 0, "onFailure", "skip")); AgentContext context = createMockAgentContext(); - AgentRunner.runMainLoop( - source, - processor, - sink, - context, - errorHandler, - source::hasMoreRecords, - executorService); + runMainLoopSync(source, processor, sink, context, errorHandler, source::hasMoreRecords); // all the records are processed in one batch processor.expectExecutions(2); source.expectUncommitted(0); @@ -254,14 +203,7 @@ void someFailedSomeGoodWithRetryAndBatching() throws Exception { StandardErrorsHandler errorHandler = new StandardErrorsHandler(Map.of("retries", 3, "onFailure", "fail")); AgentContext context = createMockAgentContext(); - AgentRunner.runMainLoop( - source, - processor, - sink, - context, - errorHandler, - source::hasMoreRecords, - executorService); + runMainLoopSync(source, processor, sink, context, errorHandler, source::hasMoreRecords); // all the records are processed in one batch processor.expectExecutions(4); source.expectUncommitted(0); @@ -313,7 +255,9 @@ public synchronized List read() { @Override public synchronized void commit(List records) { + System.out.println("COMMIT " + records + " UNCOMMITTED " + uncommitted); uncommitted.removeAll(records); + System.out.println("AFTER COMMIT UNCOMMITTED " + uncommitted); } synchronized void expectUncommitted(int count) { @@ -382,4 +326,26 @@ void expectExecutions(int count) { assertEquals(count, executionCount); } } + + private static void runMainLoopSync( + AgentSource source, + AgentProcessor processor, + AgentSink sink, + AgentContext agentContext, + ErrorsHandler errorsHandler, + Supplier continueLoop) + throws Exception { + final ExecutorService executorService = Executors.newCachedThreadPool(); + AgentRunner.runMainLoop( + source, + processor, + sink, + agentContext, + errorsHandler, + continueLoop, + executorService); + + executorService.shutdown(); + executorService.awaitTermination(30, java.util.concurrent.TimeUnit.SECONDS); + } } diff --git a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/testrunners/AbstractApplicationRunner.java b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/testrunners/AbstractApplicationRunner.java index bdab570af..faf45a14c 100644 --- a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/testrunners/AbstractApplicationRunner.java +++ b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/testrunners/AbstractApplicationRunner.java @@ -37,6 +37,7 @@ import ai.langstream.runtime.agent.api.AgentAPIController; import ai.langstream.runtime.api.agent.RuntimePodConfiguration; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; import com.github.dockerjava.api.model.Image; import io.fabric8.kubernetes.api.model.Secret; import java.io.IOException; @@ -58,6 +59,7 @@ import org.apache.commons.io.FileUtils; import org.apache.kafka.clients.producer.KafkaProducer; import org.awaitility.Awaitility; +import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.extension.*; @@ -71,6 +73,8 @@ public abstract class AbstractApplicationRunner { public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); public static final String INTEGRATION_TESTS_GROUP1 = "group-1"; + public static final ObjectMapper JSON_MAPPER = + new ObjectMapper().configure(SerializationFeature.INDENT_OUTPUT, true); private static final int DEFAULT_NUM_LOOPS = 10; public static final Path agentsDirectory; @@ -434,6 +438,23 @@ protected void sendMessage(TopicProducer producer, Object content) { sendFullMessage(producer, null, content, List.of()); } + protected List waitForRecords(TopicConsumer consumer, List expectedRecords) { + return waitForMessages( + consumer, + expectedRecords.size(), + (result, received) -> { + for (int i = 0; i < expectedRecords.size(); i++) { + Record expectedRecord = expectedRecords.get(i); + Record actualRecord = result.get(i); + assertRecordEquals( + actualRecord, + expectedRecord.key(), + expectedRecord.value(), + recordHeadersToMap(expectedRecord)); + } + }); + } + protected List waitForMessages( TopicConsumer consumer, int expectedSize, @@ -545,14 +566,12 @@ protected List waitForMessagesInAnyOrder( return result; } + @SneakyThrows protected static void assertRecordEquals( Record record, Object key, Object value, Map headers) { final Object recordKey = record.key(); final Object recordValue = record.value(); - Map recordHeaders = new HashMap<>(); - for (Header header : record.headers()) { - recordHeaders.put(header.key(), header.valueAsString()); - } + Map recordHeaders = recordHeadersToMap(record); log.info( """ @@ -573,15 +592,45 @@ protected static void assertRecordEquals( value, headers); assertEquals(key, record.key()); - assertEquals(value, record.value()); - assertEquals(headers, recordHeaders); + if (value instanceof byte[]) { + assertArrayEquals((byte[]) value, (byte[]) record.value()); + } else { + assertEquals(value, record.value()); + } + assertEquals( + headers.size(), + recordHeaders.size(), + "Headers size is different, expected: " + + headers.size() + + " but was: " + + recordHeaders.size()); + for (Map.Entry stringStringEntry : headers.entrySet()) { + String key1 = stringStringEntry.getKey(); + String value1 = stringStringEntry.getValue(); + String recordHeaderValue = recordHeaders.get(key1); + assertEquals( + value1, + recordHeaderValue, + "Header " + + key1 + + " is different, expected:\n" + + value1 + + "\nbut was:\n" + + recordHeaderValue); + } } - protected static void assertRecordHeadersEquals(Record record, Map headers) { + @NotNull + private static Map recordHeadersToMap(Record record) { Map recordHeaders = new HashMap<>(); for (Header header : record.headers()) { recordHeaders.put(header.key(), header.valueAsString()); } + return recordHeaders; + } + + protected static void assertRecordHeadersEquals(Record record, Map headers) { + Map recordHeaders = recordHeadersToMap(record); log.info( """ diff --git a/requirements.txt b/requirements.txt index 311c5a0b9..e3f10e480 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1 @@ -tox -black -ruff poetry