diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java new file mode 100644 index 000000000..7bfabb427 --- /dev/null +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java @@ -0,0 +1,254 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.agents.grpc; + +import ai.langstream.api.runner.code.AbstractAgentCode; +import ai.langstream.api.runner.code.AgentContext; +import ai.langstream.api.runner.code.SimpleRecord; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.protobuf.ByteString; +import com.google.protobuf.Empty; +import io.grpc.ManagedChannel; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import lombok.extern.slf4j.Slf4j; +import org.apache.avro.Conversions; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.BinaryEncoder; +import org.apache.avro.io.Decoder; +import org.apache.avro.io.DecoderFactory; +import org.apache.avro.io.EncoderFactory; + +@Slf4j +abstract class AbstractGrpcAgent extends AbstractAgentCode { + protected static final ObjectMapper MAPPER = new ObjectMapper(); + protected ManagedChannel channel; + + // For each schema sent, we increment the schemaId + private final AtomicInteger schemaId = new AtomicInteger(0); + + // Schemas sent to the server + private final Map schemaIds = new ConcurrentHashMap<>(); + + // Schemas received from the server + protected final Map serverSchemas = new ConcurrentHashMap<>(); + + protected AgentContext agentContext; + protected AgentServiceGrpc.AgentServiceBlockingStub blockingStub; + + protected record GrpcAgentRecord( + Long id, + Object key, + Object value, + String origin, + Long timestamp, + Collection headers) + implements ai.langstream.api.runner.code.Record {} + + public AbstractGrpcAgent() {} + + public AbstractGrpcAgent(ManagedChannel channel) { + this.channel = channel; + } + + public abstract void onNewSchemaToSend(Schema schema); + + @Override + public void start() throws Exception { + if (channel == null) { + throw new IllegalStateException("Channel not initialized"); + } + blockingStub = + AgentServiceGrpc.newBlockingStub(channel).withDeadlineAfter(30, TimeUnit.SECONDS); + } + + @Override + public void setContext(AgentContext context) throws Exception { + this.agentContext = context; + } + + @Override + protected Map buildAdditionalInfo() { + try { + return MAPPER.readValue( + blockingStub.agentInfo(Empty.getDefaultInstance()).getJsonInfo(), Map.class); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @Override + public synchronized void close() throws Exception { + if (channel != null) { + channel.shutdown(); + } + } + + protected Object fromGrpc(Value value) throws IOException { + if (value == null) { + return null; + } + return switch (value.getTypeOneofCase()) { + case BYTES_VALUE -> value.getBytesValue().toByteArray(); + case BOOLEAN_VALUE -> value.getBooleanValue(); + case STRING_VALUE -> value.getStringValue(); + case BYTE_VALUE -> (byte) value.getByteValue(); + case SHORT_VALUE -> (short) value.getShortValue(); + case INT_VALUE -> value.getIntValue(); + case LONG_VALUE -> value.getLongValue(); + case FLOAT_VALUE -> value.getFloatValue(); + case DOUBLE_VALUE -> value.getDoubleValue(); + case JSON_VALUE -> value.getJsonValue(); + case AVRO_VALUE -> { + Object serverSchema = serverSchemas.get(value.getSchemaId()); + if (serverSchema instanceof org.apache.avro.Schema schema) { + yield deserializeGenericRecord(schema, value.getAvroValue().toByteArray()); + } else { + log.error("Unknown schema id {}", value.getSchemaId()); + throw new RuntimeException("Unknown schema id " + value.getSchemaId()); + } + } + case TYPEONEOF_NOT_SET -> null; + }; + } + + protected GrpcAgentRecord fromGrpc(Record record) throws IOException { + List headers = new ArrayList<>(); + for (Header header : record.getHeadersList()) { + headers.add(fromGrpc(header)); + } + return new GrpcAgentRecord( + record.getRecordId(), + fromGrpc(record.getKey()), + fromGrpc(record.getValue()), + record.getOrigin().isEmpty() ? null : record.getOrigin(), + record.hasTimestamp() ? record.getTimestamp() : null, + headers); + } + + protected SimpleRecord.SimpleHeader fromGrpc(Header header) throws IOException { + return SimpleRecord.SimpleHeader.of(header.getName(), fromGrpc(header.getValue())); + } + + protected Record.Builder toGrpc(ai.langstream.api.runner.code.Record record) + throws IOException { + Record.Builder recordBuilder = Record.newBuilder(); + if (record.value() != null) { + recordBuilder.setValue(toGrpc(record.value())); + } + + if (record.key() != null) { + recordBuilder.setKey(toGrpc(record.key())); + } + + if (record.origin() != null) { + recordBuilder.setOrigin(record.origin()); + } + + if (record.timestamp() != null) { + recordBuilder.setTimestamp(record.timestamp()); + } + + if (record.headers() != null) { + for (ai.langstream.api.runner.code.Header h : record.headers()) { + Header.Builder headerBuilder = recordBuilder.addHeadersBuilder().setName(h.key()); + if (h.value() != null) { + headerBuilder.setValue(toGrpc(h.value())); + } + } + } + return recordBuilder; + } + + protected Value toGrpc(Object obj) throws IOException { + if (obj == null) { + return null; + } + Value.Builder valueBuilder = Value.newBuilder(); + if (obj instanceof String value) { + valueBuilder.setStringValue(value); + } else if (obj instanceof byte[] value) { + valueBuilder.setBytesValue(ByteString.copyFrom((value))); + } else if (obj instanceof Boolean value) { + valueBuilder.setBooleanValue(value); + } else if (obj instanceof Byte value) { + valueBuilder.setByteValue(value.intValue()); + } else if (obj instanceof Short value) { + valueBuilder.setShortValue(value.intValue()); + } else if (obj instanceof Integer value) { + valueBuilder.setIntValue(value); + } else if (obj instanceof Long value) { + valueBuilder.setLongValue(value); + } else if (obj instanceof Float value) { + valueBuilder.setFloatValue(value); + } else if (obj instanceof Double value) { + valueBuilder.setDoubleValue(value); + } else if (obj instanceof JsonNode value) { + valueBuilder.setJsonValue(value.toString()); + } else if (obj instanceof GenericRecord genericRecord) { + org.apache.avro.Schema schema = genericRecord.getSchema(); + Integer schemaId = + schemaIds.computeIfAbsent( + schema, + s -> { + int sId = this.schemaId.incrementAndGet(); + onNewSchemaToSend( + Schema.newBuilder() + .setValue( + ByteString.copyFromUtf8(schema.toString())) + .setSchemaId(sId) + .build()); + return sId; + }); + + valueBuilder.setSchemaId(schemaId); + valueBuilder.setAvroValue(ByteString.copyFrom(serializeGenericRecord(genericRecord))); + } else { + throw new IllegalArgumentException("Unsupported type " + obj.getClass()); + } + return valueBuilder.build(); + } + + private static byte[] serializeGenericRecord(GenericRecord record) throws IOException { + GenericDatumWriter writer = new GenericDatumWriter<>(record.getSchema()); + // enable Decimal conversion, otherwise attempting to serialize java.math.BigDecimal will + // throw ClassCastException. + writer.getData().addLogicalTypeConversion(new Conversions.DecimalConversion()); + ByteArrayOutputStream oo = new ByteArrayOutputStream(); + BinaryEncoder encoder = EncoderFactory.get().directBinaryEncoder(oo, null); + writer.write(record, encoder); + return oo.toByteArray(); + } + + private static GenericRecord deserializeGenericRecord( + org.apache.avro.Schema schema, byte[] data) throws IOException { + GenericDatumReader reader = new GenericDatumReader<>(schema); + reader.getData().addLogicalTypeConversion(new Conversions.DecimalConversion()); + Decoder decoder = DecoderFactory.get().binaryDecoder(data, null); + return reader.read(null, decoder); + } +} diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentProcessor.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentProcessor.java index 84bce6d13..e85e8506a 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentProcessor.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentProcessor.java @@ -15,95 +15,51 @@ */ package ai.langstream.agents.grpc; -import ai.langstream.api.runner.code.AbstractAgentCode; -import ai.langstream.api.runner.code.AgentContext; import ai.langstream.api.runner.code.AgentProcessor; import ai.langstream.api.runner.code.RecordSink; -import ai.langstream.api.runner.code.SimpleRecord; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.protobuf.ByteString; -import com.google.protobuf.Empty; import io.grpc.ManagedChannel; import io.grpc.stub.StreamObserver; -import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import lombok.extern.slf4j.Slf4j; -import org.apache.avro.Conversions; -import org.apache.avro.generic.GenericDatumReader; -import org.apache.avro.generic.GenericDatumWriter; -import org.apache.avro.generic.GenericRecord; -import org.apache.avro.io.BinaryEncoder; -import org.apache.avro.io.Decoder; -import org.apache.avro.io.DecoderFactory; -import org.apache.avro.io.EncoderFactory; @Slf4j -public class GrpcAgentProcessor extends AbstractAgentCode implements AgentProcessor { - protected static final ObjectMapper MAPPER = new ObjectMapper(); - protected ManagedChannel channel; +public class GrpcAgentProcessor extends AbstractGrpcAgent implements AgentProcessor { private StreamObserver request; private RecordSink sink; // For each record sent, we increment the recordId - private final AtomicLong recordId = new AtomicLong(0); + protected final AtomicLong recordId = new AtomicLong(0); // For each record sent, we store the record and the sink to which the result should be emitted private final Map sourceRecords = new ConcurrentHashMap<>(); - // For each schema sent, we increment the schemaId - private final AtomicInteger schemaId = new AtomicInteger(0); - - // Schemas sent to the server - private final Map schemaIds = new ConcurrentHashMap<>(); - - // Schemas received from the server - private final Map serverSchemas = new ConcurrentHashMap<>(); - private final StreamObserver responseObserver = getResponseObserver(); - protected AgentContext agentContext; - protected AgentServiceGrpc.AgentServiceBlockingStub blockingStub; private record RecordAndSink( ai.langstream.api.runner.code.Record sourceRecord, RecordSink sink) {} - public GrpcAgentProcessor() {} - - public GrpcAgentProcessor(ManagedChannel channel) { - this.channel = channel; + public GrpcAgentProcessor() { + super(); } - @Override - public void start() throws Exception { - if (channel == null) { - throw new IllegalStateException("Channel not initialized"); - } - blockingStub = - AgentServiceGrpc.newBlockingStub(channel).withDeadlineAfter(30, TimeUnit.SECONDS); - request = AgentServiceGrpc.newStub(channel).withWaitForReady().process(responseObserver); + public GrpcAgentProcessor(ManagedChannel channel) { + super(channel); } @Override - public void setContext(AgentContext context) throws Exception { - this.agentContext = context; + public synchronized void onNewSchemaToSend(Schema schema) { + request.onNext(ProcessorRequest.newBuilder().setSchema(schema).build()); } @Override - protected Map buildAdditionalInfo() { - try { - return MAPPER.readValue( - blockingStub.agentInfo(Empty.getDefaultInstance()).getJsonInfo(), Map.class); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } + public void start() throws Exception { + super.start(); + request = AgentServiceGrpc.newStub(channel).withWaitForReady().process(responseObserver); } @Override @@ -117,33 +73,7 @@ public synchronized void process( for (ai.langstream.api.runner.code.Record record : records) { long rId = recordId.incrementAndGet(); try { - Record.Builder recordBuilder = requestBuilder.addRecordsBuilder().setRecordId(rId); - - if (record.value() != null) { - recordBuilder.setValue(toGrpc(record.value())); - } - - if (record.key() != null) { - recordBuilder.setKey(toGrpc(record.key())); - } - - if (record.origin() != null) { - recordBuilder.setOrigin(record.origin()); - } - - if (record.timestamp() != null) { - recordBuilder.setTimestamp(record.timestamp()); - } - - if (record.headers() != null) { - for (ai.langstream.api.runner.code.Header h : record.headers()) { - Header.Builder headerBuilder = - recordBuilder.addHeadersBuilder().setName(h.key()); - if (h.value() != null) { - headerBuilder.setValue(toGrpc(h.value())); - } - } - } + requestBuilder.addRecords(toGrpc(record).setRecordId(rId)); sourceRecords.put(rId, new RecordAndSink(record, recordSink)); } catch (Exception e) { recordSink.emit(new SourceRecordAndResult(record, null, e)); @@ -159,37 +89,7 @@ public synchronized void close() throws Exception { if (request != null) { request.onCompleted(); } - if (channel != null) { - channel.shutdown(); - } - } - - private Object fromGrpc(Value value) throws IOException { - if (value == null) { - return null; - } - return switch (value.getTypeOneofCase()) { - case BYTES_VALUE -> value.getBytesValue().toByteArray(); - case BOOLEAN_VALUE -> value.getBooleanValue(); - case STRING_VALUE -> value.getStringValue(); - case BYTE_VALUE -> (byte) value.getByteValue(); - case SHORT_VALUE -> (short) value.getShortValue(); - case INT_VALUE -> value.getIntValue(); - case LONG_VALUE -> value.getLongValue(); - case FLOAT_VALUE -> value.getFloatValue(); - case DOUBLE_VALUE -> value.getDoubleValue(); - case JSON_VALUE -> value.getJsonValue(); - case AVRO_VALUE -> { - Object serverSchema = serverSchemas.get(value.getSchemaId()); - if (serverSchema instanceof org.apache.avro.Schema schema) { - yield deserializeGenericRecord(schema, value.getAvroValue().toByteArray()); - } else { - log.error("Unknown schema id {}", value.getSchemaId()); - throw new RuntimeException("Unknown schema id " + value.getSchemaId()); - } - } - case TYPEONEOF_NOT_SET -> null; - }; + super.close(); } private SourceRecordAndResult fromGrpc( @@ -207,103 +107,6 @@ private SourceRecordAndResult fromGrpc( return new SourceRecordAndResult(sourceRecord, resultRecords, null); } - private SimpleRecord fromGrpc(Record record) throws IOException { - List headers = new ArrayList<>(); - for (Header header : record.getHeadersList()) { - headers.add(fromGrpc(header)); - } - SimpleRecord.SimpleRecordBuilder result = - SimpleRecord.builder() - .key(fromGrpc(record.getKey())) - .value(fromGrpc(record.getValue())) - .headers(headers); - - if (!record.getOrigin().isEmpty()) { - result.origin(record.getOrigin()); - } - - if (record.hasTimestamp()) { - result.timestamp(record.getTimestamp()); - } - - return result.build(); - } - - private ai.langstream.api.runner.code.Header fromGrpc(Header header) throws IOException { - return SimpleRecord.SimpleHeader.of(header.getName(), fromGrpc(header.getValue())); - } - - private Value toGrpc(Object obj) throws IOException { - if (obj == null) { - return null; - } - Value.Builder valueBuilder = Value.newBuilder(); - if (obj instanceof String value) { - valueBuilder.setStringValue(value); - } else if (obj instanceof byte[] value) { - valueBuilder.setBytesValue(ByteString.copyFrom((value))); - } else if (obj instanceof Boolean value) { - valueBuilder.setBooleanValue(value); - } else if (obj instanceof Byte value) { - valueBuilder.setByteValue(value.intValue()); - } else if (obj instanceof Short value) { - valueBuilder.setShortValue(value.intValue()); - } else if (obj instanceof Integer value) { - valueBuilder.setIntValue(value); - } else if (obj instanceof Long value) { - valueBuilder.setLongValue(value); - } else if (obj instanceof Float value) { - valueBuilder.setFloatValue(value); - } else if (obj instanceof Double value) { - valueBuilder.setDoubleValue(value); - } else if (obj instanceof JsonNode value) { - valueBuilder.setJsonValue(value.toString()); - } else if (obj instanceof GenericRecord genericRecord) { - org.apache.avro.Schema schema = genericRecord.getSchema(); - Integer schemaId = - schemaIds.computeIfAbsent( - schema, - s -> { - int sId = this.schemaId.incrementAndGet(); - request.onNext( - ProcessorRequest.newBuilder() - .setSchema( - Schema.newBuilder() - .setValue( - ByteString.copyFromUtf8( - schema.toString())) - .setSchemaId(sId)) - .build()); - return sId; - }); - - valueBuilder.setSchemaId(schemaId); - valueBuilder.setAvroValue(ByteString.copyFrom(serializeGenericRecord(genericRecord))); - } else { - throw new IllegalArgumentException("Unsupported type " + obj.getClass()); - } - return valueBuilder.build(); - } - - private static byte[] serializeGenericRecord(GenericRecord record) throws IOException { - GenericDatumWriter writer = new GenericDatumWriter<>(record.getSchema()); - // enable Decimal conversion, otherwise attempting to serialize java.math.BigDecimal will - // throw ClassCastException. - writer.getData().addLogicalTypeConversion(new Conversions.DecimalConversion()); - ByteArrayOutputStream oo = new ByteArrayOutputStream(); - BinaryEncoder encoder = EncoderFactory.get().directBinaryEncoder(oo, null); - writer.write(record, encoder); - return oo.toByteArray(); - } - - private static GenericRecord deserializeGenericRecord( - org.apache.avro.Schema schema, byte[] data) throws IOException { - GenericDatumReader reader = new GenericDatumReader<>(schema); - reader.getData().addLogicalTypeConversion(new Conversions.DecimalConversion()); - Decoder decoder = DecoderFactory.get().binaryDecoder(data, null); - return reader.read(null, decoder); - } - private StreamObserver getResponseObserver() { return new StreamObserver<>() { @Override diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java new file mode 100644 index 000000000..e4f2b0f13 --- /dev/null +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentSource.java @@ -0,0 +1,120 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.agents.grpc; + +import ai.langstream.api.runner.code.AgentSource; +import ai.langstream.api.runner.code.Record; +import io.grpc.ManagedChannel; +import io.grpc.stub.StreamObserver; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentLinkedQueue; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class GrpcAgentSource extends AbstractGrpcAgent implements AgentSource { + + private static final int MAX_RECORDS_PER_READ = 10_000; + private StreamObserver request; + private final StreamObserver responseObserver; + + // TODO: use a bounded queue ? backpressure ? + private final ConcurrentLinkedQueue readRecords = new ConcurrentLinkedQueue<>(); + + public GrpcAgentSource() { + super(); + this.responseObserver = getResponseObserver(); + } + + public GrpcAgentSource(ManagedChannel channel) { + super(channel); + this.responseObserver = getResponseObserver(); + } + + @Override + public void onNewSchemaToSend(Schema schema) { + throw new UnsupportedOperationException("Shouldn't be called on a source"); + } + + @Override + public void start() throws Exception { + super.start(); + request = AgentServiceGrpc.newStub(channel).withWaitForReady().read(responseObserver); + } + + @Override + public List read() throws Exception { + List read = new ArrayList<>(); + for (int i = 0; i < MAX_RECORDS_PER_READ; i++) { + Record record = readRecords.poll(); + if (record == null) { + break; + } + read.add(record); + } + return read; + } + + @Override + public void commit(List records) throws Exception { + SourceRequest.Builder requestBuilder = SourceRequest.newBuilder(); + for (Record record : records) { + if (record instanceof GrpcAgentRecord grpcAgentRecord) { + requestBuilder.addCommittedRecords(grpcAgentRecord.id()); + } else { + throw new IllegalArgumentException( + "Record %s is not a GrpcAgentRecord".formatted(record)); + } + } + request.onNext(requestBuilder.build()); + } + + private StreamObserver getResponseObserver() { + return new StreamObserver<>() { + @Override + public void onNext(SourceResponse response) { + if (response.hasSchema()) { + org.apache.avro.Schema schema = + new org.apache.avro.Schema.Parser() + .parse(response.getSchema().getValue().toStringUtf8()); + serverSchemas.put(response.getSchema().getSchemaId(), schema); + } + try { + for (ai.langstream.agents.grpc.Record record : response.getRecordsList()) { + readRecords.add(fromGrpc(record)); + } + } catch (Exception e) { + agentContext.criticalFailure( + new RuntimeException("Error while processing records", e)); + } + } + + @Override + public void onError(Throwable throwable) { + agentContext.criticalFailure( + new RuntimeException( + "gRPC server sent error: %s".formatted(throwable.getMessage()), + throwable)); + } + + @Override + public void onCompleted() { + agentContext.criticalFailure( + new RuntimeException("gRPC server completed the stream unexpectedly")); + } + }; + } +} diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentsCodeProvider.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentsCodeProvider.java index 4d3745641..9b1eebfdb 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentsCodeProvider.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/GrpcAgentsCodeProvider.java @@ -17,15 +17,24 @@ import ai.langstream.api.runner.code.AgentCode; import ai.langstream.api.runner.code.AgentCodeProvider; +import java.util.Set; public class GrpcAgentsCodeProvider implements AgentCodeProvider { + + private static final Set SUPPORTED_AGENT_TYPES = + Set.of("experimental-python-source", "experimental-python-processor"); + @Override public boolean supports(String agentType) { - return "experimental-python-processor".equals(agentType); + return SUPPORTED_AGENT_TYPES.contains(agentType); } @Override public AgentCode createInstance(String agentType) { - return new PythonGrpcAgentProcessor(); + return switch (agentType) { + case "experimental-python-source" -> new PythonGrpcAgentSource(); + case "experimental-python-processor" -> new PythonGrpcAgentProcessor(); + default -> throw new IllegalStateException("Unexpected agent type: " + agentType); + }; } } diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentProcessor.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentProcessor.java index aff0e12e0..2c57c0c2c 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentProcessor.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentProcessor.java @@ -15,92 +15,29 @@ */ package ai.langstream.agents.grpc; -import com.google.protobuf.Empty; -import io.grpc.ManagedChannelBuilder; -import java.net.ServerSocket; -import java.nio.file.Path; import java.util.Map; -import java.util.concurrent.TimeUnit; -import lombok.extern.slf4j.Slf4j; -@Slf4j public class PythonGrpcAgentProcessor extends GrpcAgentProcessor { + + private PythonGrpcServer server; private Map configuration; - private Process pythonProcess; @Override - public void init(Map configuration) { + public void init(Map configuration) throws Exception { + super.init(configuration); this.configuration = configuration; } @Override public void start() throws Exception { - // Get a free port - int port; - try (ServerSocket socket = new ServerSocket(0)) { - socket.setReuseAddress(true); - port = socket.getLocalPort(); - } - - Path pythonCodeDirectory = agentContext.getCodeDirectory().resolve("python"); - log.info("Python code directory {}", pythonCodeDirectory); - - final String pythonPath = System.getenv("PYTHONPATH"); - final String newPythonPath = - "%s:%s:%s" - .formatted( - pythonPath, - pythonCodeDirectory.toAbsolutePath(), - pythonCodeDirectory.resolve("lib").toAbsolutePath()); - - // copy input/output to standard input/output of the java process - // this allows to use "kubectl logs" easily - ProcessBuilder processBuilder = - new ProcessBuilder( - "python3", - "-m", - "langstream_grpc", - "[::]:%s".formatted(port), - MAPPER.writeValueAsString(configuration)) - .inheritIO() - .redirectOutput(ProcessBuilder.Redirect.INHERIT) - .redirectError(ProcessBuilder.Redirect.INHERIT); - processBuilder.environment().put("PYTHONPATH", newPythonPath); - processBuilder.environment().put("NLTK_DATA", "/app/nltk_data"); - pythonProcess = processBuilder.start(); - this.channel = - ManagedChannelBuilder.forAddress("localhost", port) - .directExecutor() - .usePlaintext() - .build(); - AgentServiceGrpc.AgentServiceBlockingStub stub = - AgentServiceGrpc.newBlockingStub(channel).withDeadlineAfter(30, TimeUnit.SECONDS); - for (int i = 0; ; i++) { - try { - stub.agentInfo(Empty.getDefaultInstance()); - break; - } catch (Exception e) { - if (i > 8) { - throw e; - } - log.info("Waiting for python agent to start"); - Thread.sleep(1000); - } - } + server = new PythonGrpcServer(agentContext.getCodeDirectory(), configuration); + channel = server.start(); super.start(); } @Override - public void close() throws Exception { + public synchronized void close() throws Exception { + if (server != null) server.close(); super.close(); - if (pythonProcess != null) { - pythonProcess.destroy(); - int exitCode = pythonProcess.waitFor(); - log.info("Python process exited with code {}", exitCode); - - if (exitCode != 0) { - throw new RuntimeException("Python code exited with code " + exitCode); - } - } } } diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSource.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSource.java new file mode 100644 index 000000000..210b32561 --- /dev/null +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcAgentSource.java @@ -0,0 +1,43 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.agents.grpc; + +import java.util.Map; + +public class PythonGrpcAgentSource extends GrpcAgentSource { + + private PythonGrpcServer server; + private Map configuration; + + @Override + public void init(Map configuration) throws Exception { + super.init(configuration); + this.configuration = configuration; + } + + @Override + public void start() throws Exception { + server = new PythonGrpcServer(agentContext.getCodeDirectory(), configuration); + channel = server.start(); + super.start(); + } + + @Override + public synchronized void close() throws Exception { + if (server != null) server.close(); + super.close(); + } +} diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcServer.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcServer.java new file mode 100644 index 000000000..d6b91bb27 --- /dev/null +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/PythonGrpcServer.java @@ -0,0 +1,108 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.agents.grpc; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.protobuf.Empty; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import java.net.ServerSocket; +import java.nio.file.Path; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class PythonGrpcServer { + private static final ObjectMapper MAPPER = new ObjectMapper(); + + private final Path codeDirectory; + private final Map configuration; + private Process pythonProcess; + + public PythonGrpcServer(Path codeDirectory, Map configuration) { + this.codeDirectory = codeDirectory; + this.configuration = configuration; + } + + public ManagedChannel start() throws Exception { + // Get a free port + int port; + try (ServerSocket socket = new ServerSocket(0)) { + socket.setReuseAddress(true); + port = socket.getLocalPort(); + } + + Path pythonCodeDirectory = codeDirectory.resolve("python"); + log.info("Python code directory {}", pythonCodeDirectory); + + final String pythonPath = System.getenv("PYTHONPATH"); + final String newPythonPath = + "%s:%s:%s" + .formatted( + pythonPath, + pythonCodeDirectory.toAbsolutePath(), + pythonCodeDirectory.resolve("lib").toAbsolutePath()); + + // copy input/output to standard input/output of the java process + // this allows to use "kubectl logs" easily + ProcessBuilder processBuilder = + new ProcessBuilder( + "python3", + "-m", + "langstream_grpc", + "[::]:%s".formatted(port), + MAPPER.writeValueAsString(configuration)) + .inheritIO() + .redirectOutput(ProcessBuilder.Redirect.INHERIT) + .redirectError(ProcessBuilder.Redirect.INHERIT); + processBuilder.environment().put("PYTHONPATH", newPythonPath); + processBuilder.environment().put("NLTK_DATA", "/app/nltk_data"); + pythonProcess = processBuilder.start(); + ManagedChannel channel = + ManagedChannelBuilder.forAddress("localhost", port) + .directExecutor() + .usePlaintext() + .build(); + AgentServiceGrpc.AgentServiceBlockingStub stub = + AgentServiceGrpc.newBlockingStub(channel).withDeadlineAfter(30, TimeUnit.SECONDS); + for (int i = 0; ; i++) { + try { + stub.agentInfo(Empty.getDefaultInstance()); + break; + } catch (Exception e) { + if (i > 8) { + throw e; + } + log.info("Waiting for python agent to start"); + Thread.sleep(1000); + } + } + return channel; + } + + public void close() throws Exception { + if (pythonProcess != null) { + pythonProcess.destroy(); + int exitCode = pythonProcess.waitFor(); + log.info("Python process exited with code {}", exitCode); + + if (exitCode != 0) { + throw new RuntimeException("Python code exited with code " + exitCode); + } + } + } +} diff --git a/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto b/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto index 3a9bfde07..4bdbebfc5 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto +++ b/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto @@ -23,6 +23,7 @@ import "google/protobuf/empty.proto"; service AgentService { rpc agent_info(google.protobuf.Empty) returns (InfoResponse) {} + rpc read(stream SourceRequest) returns (stream SourceResponse) {} rpc process(stream ProcessorRequest) returns (stream ProcessorResponse) {} } @@ -66,6 +67,16 @@ message Record { optional int64 timestamp = 6; } +message SourceRequest { + repeated int64 committed_records = 1; +} + +message SourceResponse { + Schema schema = 1; + repeated Record records = 2; +} + + message ProcessorRequest { Schema schema = 1; repeated Record records = 2; diff --git a/langstream-agents/langstream-agent-grpc/src/main/resources/META-INF/ai.langstream.agents.index b/langstream-agents/langstream-agent-grpc/src/main/resources/META-INF/ai.langstream.agents.index index 037285d44..babae276d 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/resources/META-INF/ai.langstream.agents.index +++ b/langstream-agents/langstream-agent-grpc/src/main/resources/META-INF/ai.langstream.agents.index @@ -1 +1,2 @@ +experimental-python-source experimental-python-processor \ No newline at end of file diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java index 4d9ed0205..091d7d39d 100644 --- a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java @@ -16,6 +16,7 @@ package ai.langstream.agents.grpc; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertSame; @@ -193,9 +194,11 @@ void testProcess(Object value, Object key, Object header) throws Exception { @Test void testEmpty() throws Exception { GrpcAgentProcessor processor = new GrpcAgentProcessor(channel); - processor.setContext(new TestAgentContext()); + TestAgentContext context = new TestAgentContext(); + processor.setContext(context); processor.start(); assertProcessSuccessful(processor, SimpleRecord.builder().build()); + assertFalse(context.failureCalled.await(1, TimeUnit.SECONDS)); processor.close(); } @@ -235,7 +238,7 @@ void testServerError(String origin) throws Exception { processor.start(); processor.process(List.of(inputRecord), result -> {}); - assertTrue(testAgentContext.failureCalled.await(5, TimeUnit.SECONDS)); + assertTrue(testAgentContext.failureCalled.await(1, TimeUnit.SECONDS)); processor.close(); } diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java new file mode 100644 index 000000000..8598c4c86 --- /dev/null +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java @@ -0,0 +1,258 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.agents.grpc; + +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import ai.langstream.api.runner.code.AgentContext; +import ai.langstream.api.runner.code.Record; +import ai.langstream.api.runner.topics.TopicAdmin; +import ai.langstream.api.runner.topics.TopicConnectionProvider; +import ai.langstream.api.runner.topics.TopicConsumer; +import ai.langstream.api.runner.topics.TopicProducer; +import com.google.protobuf.ByteString; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.StreamObserver; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.apache.avro.Conversions; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.BinaryEncoder; +import org.apache.avro.io.EncoderFactory; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class GrpcAgentSourceTest { + private Server server; + private ManagedChannel channel; + + private final TestSourceService testSourceService = new TestSourceService(); + + @BeforeEach + public void setUp() throws Exception { + String serverName = InProcessServerBuilder.generateName(); + server = + InProcessServerBuilder.forName(serverName) + .directExecutor() + .addService(testSourceService) + .build() + .start(); + + channel = InProcessChannelBuilder.forName(serverName).directExecutor().build(); + } + + @AfterEach + public void tearDown() throws Exception { + channel.shutdownNow(); + server.shutdownNow(); + channel.awaitTermination(30, TimeUnit.SECONDS); + server.awaitTermination(30, TimeUnit.SECONDS); + } + + @Test + void testCommit() throws Exception { + GrpcAgentSource source = new GrpcAgentSource(channel); + TestAgentContext context = new TestAgentContext(); + source.setContext(context); + source.start(); + List read = readRecords(source, 3); + source.commit(List.of(read.get(0))); + assertFalse(context.failureCalled.await(1, TimeUnit.SECONDS)); + assertEquals(1, testSourceService.committedRecords.size()); + assertEquals(42, testSourceService.committedRecords.get(0)); + source.close(); + } + + @Test + void testSourceGrpcError() throws Exception { + GrpcAgentSource source = new GrpcAgentSource(channel); + TestAgentContext context = new TestAgentContext(); + source.setContext(context); + source.start(); + List read = readRecords(source, 3); + source.commit(List.of(read.get(1))); + assertTrue(context.failureCalled.await(1, TimeUnit.SECONDS)); + source.close(); + } + + @Test + void testSourceGrpcCompletedUnexpectedly() throws Exception { + GrpcAgentSource source = new GrpcAgentSource(channel); + TestAgentContext context = new TestAgentContext(); + source.setContext(context); + source.start(); + List read = readRecords(source, 3); + source.commit(List.of(read.get(2))); + assertTrue(context.failureCalled.await(1, TimeUnit.SECONDS)); + source.close(); + } + + @Test + void testAvroAndSchema() throws Exception { + GrpcAgentSource source = new GrpcAgentSource(channel); + source.setContext(new TestAgentContext()); + source.start(); + List read = readRecords(source, 1); + GenericRecord record = (GenericRecord) read.get(0).value(); + assertEquals("test-string", record.get("testField").toString()); + source.close(); + } + + static List readRecords(GrpcAgentSource source, int numberOfRecords) { + List read = new ArrayList<>(); + await().atMost(5, TimeUnit.SECONDS) + .until( + () -> { + read.addAll(source.read()); + return read.size() >= numberOfRecords; + }); + return read; + } + + static byte[] serializeGenericRecord(GenericRecord record) throws IOException { + GenericDatumWriter writer = new GenericDatumWriter<>(record.getSchema()); + // enable Decimal conversion, otherwise attempting to serialize java.math.BigDecimal will + // throw ClassCastException. + writer.getData().addLogicalTypeConversion(new Conversions.DecimalConversion()); + ByteArrayOutputStream oo = new ByteArrayOutputStream(); + BinaryEncoder encoder = EncoderFactory.get().directBinaryEncoder(oo, null); + writer.write(record, encoder); + return oo.toByteArray(); + } + + static class TestSourceService extends AgentServiceGrpc.AgentServiceImplBase { + + final List committedRecords = new CopyOnWriteArrayList<>(); + + @Override + public StreamObserver read(StreamObserver responseObserver) { + + String schemaStr = + "{\"type\":\"record\",\"name\":\"testRecord\",\"fields\":[{\"name\":\"testField\",\"type\":\"string\"}]}"; + org.apache.avro.Schema avroSchema = + new org.apache.avro.Schema.Parser().parse(schemaStr); + GenericData.Record avroRecord = new GenericData.Record(avroSchema); + avroRecord.put("testField", "test-string"); + try { + responseObserver.onNext( + SourceResponse.newBuilder() + .setSchema( + Schema.newBuilder() + .setValue(ByteString.copyFromUtf8(schemaStr)) + .setSchemaId(42) + .build()) + .addRecords( + ai.langstream.agents.grpc.Record.newBuilder() + .setRecordId(42) + .setValue( + Value.newBuilder() + .setSchemaId(42) + .setAvroValue( + ByteString.copyFrom( + serializeGenericRecord( + avroRecord))))) + .build()); + responseObserver.onNext( + SourceResponse.newBuilder() + .addRecords( + ai.langstream.agents.grpc.Record.newBuilder() + .setRecordId(43)) + .build()); + responseObserver.onNext( + SourceResponse.newBuilder() + .addRecords( + ai.langstream.agents.grpc.Record.newBuilder() + .setRecordId(44)) + .build()); + } catch (IOException e) { + responseObserver.onError(e); + } + + return new StreamObserver<>() { + @Override + public void onNext(SourceRequest request) { + committedRecords.addAll(request.getCommittedRecordsList()); + if (request.getCommittedRecordsList().contains(43L)) { + responseObserver.onError(new RuntimeException("test error")); + } else if (request.getCommittedRecordsList().contains(44L)) { + responseObserver.onCompleted(); + } + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + }; + } + } + + static class TestAgentContext implements AgentContext { + + private final CountDownLatch failureCalled = new CountDownLatch(1); + + @Override + public TopicConsumer getTopicConsumer() { + return null; + } + + @Override + public TopicProducer getTopicProducer() { + return null; + } + + @Override + public String getGlobalAgentId() { + return null; + } + + @Override + public TopicAdmin getTopicAdmin() { + return null; + } + + @Override + public TopicConnectionProvider getTopicConnectionProvider() { + return null; + } + + @Override + public void criticalFailure(Throwable error) { + failureCalled.countDown(); + } + + @Override + public Path getCodeDirectory() { + return null; + } + } +} diff --git a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/GrpcAgentsProvider.java b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/GrpcAgentsProvider.java index 6631364ce..433f83667 100644 --- a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/GrpcAgentsProvider.java +++ b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/agents/GrpcAgentsProvider.java @@ -28,7 +28,7 @@ public class GrpcAgentsProvider extends AbstractComposableAgentProvider { private static final Set SUPPORTED_AGENT_TYPES = - Set.of("experimental-python-processor"); + Set.of("experimental-python-source", "experimental-python-processor"); public GrpcAgentsProvider() { super(SUPPORTED_AGENT_TYPES, List.of(KubernetesClusterRuntime.CLUSTER_TYPE, "none")); @@ -36,6 +36,11 @@ public GrpcAgentsProvider() { @Override protected ComponentType getComponentType(AgentConfiguration agentConfiguration) { - return ComponentType.PROCESSOR; + return switch (agentConfiguration.getType()) { + case "experimental-python-source" -> ComponentType.SOURCE; + case "experimental-python-processor" -> ComponentType.PROCESSOR; + default -> throw new IllegalStateException( + "Unexpected agent type: " + agentConfiguration.getType()); + }; } }