From a732e3e3ebd9c3b8149735b23b022989451f75f9 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 22 Nov 2023 14:23:56 +0100 Subject: [PATCH] Add support for schemas in python topic producer (#734) --- .../agents/grpc/AbstractGrpcAgent.java | 98 +++++++++++-------- .../proto/langstream_grpc/proto/agent.proto | 5 +- .../agents/grpc/AbstractGrpcAgentTest.java | 73 ++++++++++++-- .../agents/grpc/GrpcAgentProcessorTest.java | 2 +- .../agents/grpc/GrpcAgentSinkTest.java | 2 +- .../agents/grpc/GrpcAgentSourceTest.java | 2 +- .../python/langstream_grpc/grpc_service.py | 9 +- .../python/langstream_grpc/proto/agent_pb2.py | 42 ++++---- .../langstream_grpc/proto/agent_pb2.pyi | 7 +- .../langstream_grpc/proto/agent_pb2_grpc.py | 6 +- .../tests/test_grpc_topic_producer.py | 59 +++++++++-- 11 files changed, 213 insertions(+), 92 deletions(-) diff --git a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java index c62aef757..acdf2fe26 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java +++ b/langstream-agents/langstream-agent-grpc/src/main/java/ai/langstream/agents/grpc/AbstractGrpcAgent.java @@ -107,47 +107,65 @@ public void start() throws Exception { asyncStub.getTopicProducerRecords( new StreamObserver<>() { @Override - public void onNext(TopicProducerRecord topicProducerRecord) { - TopicProducer topicProducer = - topicProducers.computeIfAbsent( - topicProducerRecord.getTopic(), - topic -> { - TopicProducer tp = - agentContext - .getTopicConnectionProvider() - .createProducer( - agentContext - .getGlobalAgentId(), - topic, - Map.of()); - tp.start(); - return tp; - }); + public void onNext(TopicProducerResponse topicProducerResponse) { try { - topicProducer - .write(fromGrpc(topicProducerRecord.getRecord())) - .whenComplete( - (r, e) -> { - if (e != null) { - log.error("Error writing record", e); - sendTopicProducerWriteResult( - TopicProducerWriteResult - .newBuilder() - .setError( - e - .getMessage())); - } else { - sendTopicProducerWriteResult( - TopicProducerWriteResult - .newBuilder() - .setRecordId( - topicProducerRecord - .getRecord() - .getRecordId())); - } - }); - } catch (IOException e) { - agentContext.criticalFailure(e); + if (topicProducerResponse.hasSchema()) { + serverSchemas.put( + topicProducerResponse.getSchema().getSchemaId(), + new org.apache.avro.Schema.Parser() + .parse( + topicProducerResponse + .getSchema() + .getValue() + .toStringUtf8())); + } + if (topicProducerResponse.hasRecord() + && !"".equals(topicProducerResponse.getTopic())) { + TopicProducer topicProducer = + topicProducers.computeIfAbsent( + topicProducerResponse.getTopic(), + topic -> { + TopicProducer tp = + agentContext + .getTopicConnectionProvider() + .createProducer( + agentContext + .getGlobalAgentId(), + topic, + Map.of()); + tp.start(); + return tp; + }); + topicProducer + .write(fromGrpc(topicProducerResponse.getRecord())) + .whenComplete( + (r, e) -> { + if (e != null) { + log.error( + "Error writing record", e); + sendTopicProducerWriteResult( + TopicProducerWriteResult + .newBuilder() + .setError( + e + .getMessage())); + } else { + sendTopicProducerWriteResult( + TopicProducerWriteResult + .newBuilder() + .setRecordId( + topicProducerResponse + .getRecord() + .getRecordId())); + } + }); + } + } catch (Exception e) { + agentContext.criticalFailure( + new RuntimeException( + "getTopicProducerRecords: Error while processing TopicProducerResponse: %s" + .formatted(e.getMessage()), + e)); } } diff --git a/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto b/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto index 5edc60a25..8c1afec9e 100644 --- a/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto +++ b/langstream-agents/langstream-agent-grpc/src/main/proto/langstream_grpc/proto/agent.proto @@ -26,7 +26,7 @@ service AgentService { rpc read(stream SourceRequest) returns (stream SourceResponse) {} rpc process(stream ProcessorRequest) returns (stream ProcessorResponse) {} rpc write(stream SinkRequest) returns (stream SinkResponse) {} - rpc get_topic_producer_records(stream TopicProducerWriteResult) returns (stream TopicProducerRecord) {} + rpc get_topic_producer_records(stream TopicProducerWriteResult) returns (stream TopicProducerResponse) {} } message InfoResponse { @@ -74,8 +74,9 @@ message TopicProducerWriteResult { optional string error = 2; } -message TopicProducerRecord { +message TopicProducerResponse { string topic = 1; + Schema schema = 2; Record record = 3; } diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/AbstractGrpcAgentTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/AbstractGrpcAgentTest.java index c05b23755..c0506b9ec 100644 --- a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/AbstractGrpcAgentTest.java +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/AbstractGrpcAgentTest.java @@ -23,6 +23,7 @@ import ai.langstream.api.runner.topics.TopicConnectionProvider; import ai.langstream.api.runner.topics.TopicConsumer; import ai.langstream.api.runner.topics.TopicProducer; +import com.google.protobuf.ByteString; import com.google.protobuf.Empty; import io.grpc.ManagedChannel; import io.grpc.Server; @@ -30,12 +31,21 @@ import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.stub.StreamObserver; +import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.nio.file.Path; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import lombok.extern.slf4j.Slf4j; +import org.apache.avro.Conversions; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.BinaryEncoder; +import org.apache.avro.io.EncoderFactory; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -62,17 +72,47 @@ public void agentInfo( @Override public StreamObserver getTopicProducerRecords( - StreamObserver responseObserver) { + StreamObserver responseObserver) { + org.apache.avro.Schema schema = + SchemaBuilder.record("testRecord") + .fields() + .name("testField") + .type() + .stringType() + .noDefault() + .endRecord(); + GenericData.Record avroRecord = new GenericData.Record(schema); + avroRecord.put("testField", "test-string"); + responseObserver.onNext( - TopicProducerRecord.newBuilder() - .setTopic("test-topic") - .setRecord( - ai.langstream.agents.grpc.Record.newBuilder() - .setRecordId(42) + TopicProducerResponse.newBuilder() + .setSchema( + Schema.newBuilder() + .setSchemaId(123) .setValue( - Value.newBuilder() - .setStringValue("test-value1"))) + ByteString.copyFromUtf8( + schema.toString())) + .build()) .build()); + + try { + responseObserver.onNext( + TopicProducerResponse.newBuilder() + .setTopic("test-topic") + .setRecord( + ai.langstream.agents.grpc.Record.newBuilder() + .setRecordId(42) + .setValue( + Value.newBuilder() + .setSchemaId(123) + .setAvroValue( + ByteString.copyFrom( + serializeGenericRecord( + avroRecord))))) + .build()); + } catch (IOException e) { + throw new RuntimeException(e); + } return new StreamObserver<>() { @Override @@ -88,7 +128,7 @@ public void onNext(TopicProducerWriteResult topicProducerWriteResult) { responseObserver.onCompleted(); } else if (topicProducerWriteResult.getRecordId() == 42) { responseObserver.onNext( - TopicProducerRecord.newBuilder() + TopicProducerResponse.newBuilder() .setTopic("test-topic") .setRecord( ai.langstream.agents.grpc.Record @@ -164,7 +204,9 @@ void testTopicProducerSuccess() throws Exception { TestAgentContext context = new TestAgentContextSuccessful(); startProcessor(context); LinkedBlockingQueue recordsToWrite = context.recordsToWrite; - assertEquals("test-value1", recordsToWrite.poll(5, TimeUnit.SECONDS).value().toString()); + assertEquals( + "{\"testField\": \"test-string\"}", + recordsToWrite.poll(5, TimeUnit.SECONDS).value().toString()); assertEquals("test-value2", recordsToWrite.poll(5, TimeUnit.SECONDS).value().toString()); } @@ -275,4 +317,15 @@ protected CompletableFuture writeRecord(Record record) { return CompletableFuture.failedFuture(new RuntimeException("test-complete")); } } + + private static byte[] serializeGenericRecord(GenericRecord record) throws IOException { + GenericDatumWriter writer = new GenericDatumWriter<>(record.getSchema()); + // enable Decimal conversion, otherwise attempting to serialize java.math.BigDecimal will + // throw ClassCastException. + writer.getData().addLogicalTypeConversion(new Conversions.DecimalConversion()); + ByteArrayOutputStream oo = new ByteArrayOutputStream(); + BinaryEncoder encoder = EncoderFactory.get().directBinaryEncoder(oo, null); + writer.write(record, encoder); + return oo.toByteArray(); + } } diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java index cf8994324..6ff615da2 100644 --- a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentProcessorTest.java @@ -124,7 +124,7 @@ public void onCompleted() { @Override public StreamObserver getTopicProducerRecords( - StreamObserver responseObserver) { + StreamObserver responseObserver) { return new StreamObserver<>() { @Override public void onNext(TopicProducerWriteResult topicProducerWriteResult) {} diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSinkTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSinkTest.java index 7269ff858..1b76ecc4d 100644 --- a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSinkTest.java +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSinkTest.java @@ -187,7 +187,7 @@ public void onCompleted() { @Override public StreamObserver getTopicProducerRecords( - StreamObserver responseObserver) { + StreamObserver responseObserver) { return new StreamObserver<>() { @Override public void onNext(TopicProducerWriteResult topicProducerWriteResult) {} diff --git a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java index 74159cb0c..4a7fc104f 100644 --- a/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java +++ b/langstream-agents/langstream-agent-grpc/src/test/java/ai/langstream/agents/grpc/GrpcAgentSourceTest.java @@ -219,7 +219,7 @@ public void onCompleted() { @Override public StreamObserver getTopicProducerRecords( - StreamObserver responseObserver) { + StreamObserver responseObserver) { return new StreamObserver<>() { @Override public void onNext(TopicProducerWriteResult topicProducerWriteResult) {} diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py index 089471550..8518f6dbc 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py @@ -42,8 +42,8 @@ SourceResponse, SinkRequest, SinkResponse, - TopicProducerRecord, TopicProducerWriteResult, + TopicProducerResponse, ) from langstream_grpc.proto.agent_pb2_grpc import AgentServiceServicer from .api import Source, Sink, Processor, Record, Agent, AgentContext, TopicProducer @@ -91,12 +91,13 @@ async def agent_info(self, _, __): async def poll_topic_producer_records(self, context): while True: topic, record, future = await self.topic_producer_records.get() - # TODO: handle schemas - _, grpc_record = self.to_grpc_record(record) + schemas, grpc_record = self.to_grpc_record(record) + for schema in schemas: + await context.write(TopicProducerResponse(schema=schema)) self.topic_producer_record_id += 1 self.topic_producer_records_pending[self.topic_producer_record_id] = future grpc_record.record_id = self.topic_producer_record_id - await context.write(TopicProducerRecord(topic=topic, record=grpc_record)) + await context.write(TopicProducerResponse(topic=topic, record=grpc_record)) async def handle_write_results(self, context): write_result = await context.read() diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py index ee36bdb57..525be91a7 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.py @@ -32,7 +32,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n!langstream_grpc/proto/agent.proto\x1a\x1bgoogle/protobuf/empty.proto"!\n\x0cInfoResponse\x12\x11\n\tjson_info\x18\x01 \x01(\t"\xa3\x02\n\x05Value\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\x15\n\x0b\x62ytes_value\x18\x02 \x01(\x0cH\x00\x12\x17\n\rboolean_value\x18\x03 \x01(\x08H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x14\n\nbyte_value\x18\x05 \x01(\x05H\x00\x12\x15\n\x0bshort_value\x18\x06 \x01(\x05H\x00\x12\x13\n\tint_value\x18\x07 \x01(\x05H\x00\x12\x14\n\nlong_value\x18\x08 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\t \x01(\x02H\x00\x12\x16\n\x0c\x64ouble_value\x18\n \x01(\x01H\x00\x12\x14\n\njson_value\x18\x0b \x01(\tH\x00\x12\x14\n\navro_value\x18\x0c \x01(\x0cH\x00\x42\x0c\n\ntype_oneof"-\n\x06Header\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x05value\x18\x02 \x01(\x0b\x32\x06.Value"*\n\x06Schema\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x0c"\xb3\x01\n\x06Record\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x18\n\x03key\x18\x02 \x01(\x0b\x32\x06.ValueH\x00\x88\x01\x01\x12\x1a\n\x05value\x18\x03 \x01(\x0b\x32\x06.ValueH\x01\x88\x01\x01\x12\x18\n\x07headers\x18\x04 \x03(\x0b\x32\x07.Header\x12\x0e\n\x06origin\x18\x05 \x01(\t\x12\x16\n\ttimestamp\x18\x06 \x01(\x03H\x02\x88\x01\x01\x42\x06\n\x04_keyB\x08\n\x06_valueB\x0c\n\n_timestamp"K\n\x18TopicProducerWriteResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error"=\n\x13TopicProducerRecord\x12\r\n\x05topic\x18\x01 \x01(\t\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record"<\n\x10PermanentFailure\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x15\n\rerror_message\x18\x02 \x01(\t"X\n\rSourceRequest\x12\x19\n\x11\x63ommitted_records\x18\x01 \x03(\x03\x12,\n\x11permanent_failure\x18\x02 \x01(\x0b\x32\x11.PermanentFailure"C\n\x0eSourceResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"E\n\x10ProcessorRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"O\n\x11ProcessorResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12!\n\x07results\x18\x02 \x03(\x0b\x32\x10.ProcessorResult"\\\n\x0fProcessorResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x07records\x18\x03 \x03(\x0b\x32\x07.RecordB\x08\n\x06_error"?\n\x0bSinkRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x17\n\x06record\x18\x02 \x01(\x0b\x32\x07.Record"?\n\x0cSinkResponse\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error2\xad\x02\n\x0c\x41gentService\x12\x35\n\nagent_info\x12\x16.google.protobuf.Empty\x1a\r.InfoResponse"\x00\x12-\n\x04read\x12\x0e.SourceRequest\x1a\x0f.SourceResponse"\x00(\x01\x30\x01\x12\x36\n\x07process\x12\x11.ProcessorRequest\x1a\x12.ProcessorResponse"\x00(\x01\x30\x01\x12*\n\x05write\x12\x0c.SinkRequest\x1a\r.SinkResponse"\x00(\x01\x30\x01\x12S\n\x1aget_topic_producer_records\x12\x19.TopicProducerWriteResult\x1a\x14.TopicProducerRecord"\x00(\x01\x30\x01\x42\x1d\n\x19\x61i.langstream.agents.grpcP\x01\x62\x06proto3' + b'\n!langstream_grpc/proto/agent.proto\x1a\x1bgoogle/protobuf/empty.proto"!\n\x0cInfoResponse\x12\x11\n\tjson_info\x18\x01 \x01(\t"\xa3\x02\n\x05Value\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\x15\n\x0b\x62ytes_value\x18\x02 \x01(\x0cH\x00\x12\x17\n\rboolean_value\x18\x03 \x01(\x08H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x14\n\nbyte_value\x18\x05 \x01(\x05H\x00\x12\x15\n\x0bshort_value\x18\x06 \x01(\x05H\x00\x12\x13\n\tint_value\x18\x07 \x01(\x05H\x00\x12\x14\n\nlong_value\x18\x08 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\t \x01(\x02H\x00\x12\x16\n\x0c\x64ouble_value\x18\n \x01(\x01H\x00\x12\x14\n\njson_value\x18\x0b \x01(\tH\x00\x12\x14\n\navro_value\x18\x0c \x01(\x0cH\x00\x42\x0c\n\ntype_oneof"-\n\x06Header\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x05value\x18\x02 \x01(\x0b\x32\x06.Value"*\n\x06Schema\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x0c"\xb3\x01\n\x06Record\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x18\n\x03key\x18\x02 \x01(\x0b\x32\x06.ValueH\x00\x88\x01\x01\x12\x1a\n\x05value\x18\x03 \x01(\x0b\x32\x06.ValueH\x01\x88\x01\x01\x12\x18\n\x07headers\x18\x04 \x03(\x0b\x32\x07.Header\x12\x0e\n\x06origin\x18\x05 \x01(\t\x12\x16\n\ttimestamp\x18\x06 \x01(\x03H\x02\x88\x01\x01\x42\x06\n\x04_keyB\x08\n\x06_valueB\x0c\n\n_timestamp"K\n\x18TopicProducerWriteResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error"X\n\x15TopicProducerResponse\x12\r\n\x05topic\x18\x01 \x01(\t\x12\x17\n\x06schema\x18\x02 \x01(\x0b\x32\x07.Schema\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record"<\n\x10PermanentFailure\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x15\n\rerror_message\x18\x02 \x01(\t"X\n\rSourceRequest\x12\x19\n\x11\x63ommitted_records\x18\x01 \x03(\x03\x12,\n\x11permanent_failure\x18\x02 \x01(\x0b\x32\x11.PermanentFailure"C\n\x0eSourceResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"E\n\x10ProcessorRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"O\n\x11ProcessorResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12!\n\x07results\x18\x02 \x03(\x0b\x32\x10.ProcessorResult"\\\n\x0fProcessorResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x07records\x18\x03 \x03(\x0b\x32\x07.RecordB\x08\n\x06_error"?\n\x0bSinkRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x17\n\x06record\x18\x02 \x01(\x0b\x32\x07.Record"?\n\x0cSinkResponse\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error2\xaf\x02\n\x0c\x41gentService\x12\x35\n\nagent_info\x12\x16.google.protobuf.Empty\x1a\r.InfoResponse"\x00\x12-\n\x04read\x12\x0e.SourceRequest\x1a\x0f.SourceResponse"\x00(\x01\x30\x01\x12\x36\n\x07process\x12\x11.ProcessorRequest\x1a\x12.ProcessorResponse"\x00(\x01\x30\x01\x12*\n\x05write\x12\x0c.SinkRequest\x1a\r.SinkResponse"\x00(\x01\x30\x01\x12U\n\x1aget_topic_producer_records\x12\x19.TopicProducerWriteResult\x1a\x16.TopicProducerResponse"\x00(\x01\x30\x01\x42\x1d\n\x19\x61i.langstream.agents.grpcP\x01\x62\x06proto3' ) _globals = globals() @@ -55,24 +55,24 @@ _globals["_RECORD"]._serialized_end = 666 _globals["_TOPICPRODUCERWRITERESULT"]._serialized_start = 668 _globals["_TOPICPRODUCERWRITERESULT"]._serialized_end = 743 - _globals["_TOPICPRODUCERRECORD"]._serialized_start = 745 - _globals["_TOPICPRODUCERRECORD"]._serialized_end = 806 - _globals["_PERMANENTFAILURE"]._serialized_start = 808 - _globals["_PERMANENTFAILURE"]._serialized_end = 868 - _globals["_SOURCEREQUEST"]._serialized_start = 870 - _globals["_SOURCEREQUEST"]._serialized_end = 958 - _globals["_SOURCERESPONSE"]._serialized_start = 960 - _globals["_SOURCERESPONSE"]._serialized_end = 1027 - _globals["_PROCESSORREQUEST"]._serialized_start = 1029 - _globals["_PROCESSORREQUEST"]._serialized_end = 1098 - _globals["_PROCESSORRESPONSE"]._serialized_start = 1100 - _globals["_PROCESSORRESPONSE"]._serialized_end = 1179 - _globals["_PROCESSORRESULT"]._serialized_start = 1181 - _globals["_PROCESSORRESULT"]._serialized_end = 1273 - _globals["_SINKREQUEST"]._serialized_start = 1275 - _globals["_SINKREQUEST"]._serialized_end = 1338 - _globals["_SINKRESPONSE"]._serialized_start = 1340 - _globals["_SINKRESPONSE"]._serialized_end = 1403 - _globals["_AGENTSERVICE"]._serialized_start = 1406 - _globals["_AGENTSERVICE"]._serialized_end = 1707 + _globals["_TOPICPRODUCERRESPONSE"]._serialized_start = 745 + _globals["_TOPICPRODUCERRESPONSE"]._serialized_end = 833 + _globals["_PERMANENTFAILURE"]._serialized_start = 835 + _globals["_PERMANENTFAILURE"]._serialized_end = 895 + _globals["_SOURCEREQUEST"]._serialized_start = 897 + _globals["_SOURCEREQUEST"]._serialized_end = 985 + _globals["_SOURCERESPONSE"]._serialized_start = 987 + _globals["_SOURCERESPONSE"]._serialized_end = 1054 + _globals["_PROCESSORREQUEST"]._serialized_start = 1056 + _globals["_PROCESSORREQUEST"]._serialized_end = 1125 + _globals["_PROCESSORRESPONSE"]._serialized_start = 1127 + _globals["_PROCESSORRESPONSE"]._serialized_end = 1206 + _globals["_PROCESSORRESULT"]._serialized_start = 1208 + _globals["_PROCESSORRESULT"]._serialized_end = 1300 + _globals["_SINKREQUEST"]._serialized_start = 1302 + _globals["_SINKREQUEST"]._serialized_end = 1365 + _globals["_SINKRESPONSE"]._serialized_start = 1367 + _globals["_SINKRESPONSE"]._serialized_end = 1430 + _globals["_AGENTSERVICE"]._serialized_start = 1433 + _globals["_AGENTSERVICE"]._serialized_end = 1736 # @@protoc_insertion_point(module_scope) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi index 8ae844d8c..d84bfd4f6 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2.pyi @@ -129,15 +129,18 @@ class TopicProducerWriteResult(_message.Message): self, record_id: _Optional[int] = ..., error: _Optional[str] = ... ) -> None: ... -class TopicProducerRecord(_message.Message): - __slots__ = ["topic", "record"] +class TopicProducerResponse(_message.Message): + __slots__ = ["topic", "schema", "record"] TOPIC_FIELD_NUMBER: _ClassVar[int] + SCHEMA_FIELD_NUMBER: _ClassVar[int] RECORD_FIELD_NUMBER: _ClassVar[int] topic: str + schema: Schema record: Record def __init__( self, topic: _Optional[str] = ..., + schema: _Optional[_Union[Schema, _Mapping]] = ..., record: _Optional[_Union[Record, _Mapping]] = ..., ) -> None: ... diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2_grpc.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2_grpc.py index acffce84c..1930bb556 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2_grpc.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/proto/agent_pb2_grpc.py @@ -54,7 +54,7 @@ def __init__(self, channel): self.get_topic_producer_records = channel.stream_stream( "/AgentService/get_topic_producer_records", request_serializer=langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerWriteResult.SerializeToString, - response_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerRecord.FromString, + response_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerResponse.FromString, ) @@ -117,7 +117,7 @@ def add_AgentServiceServicer_to_server(servicer, server): "get_topic_producer_records": grpc.stream_stream_rpc_method_handler( servicer.get_topic_producer_records, request_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerWriteResult.FromString, - response_serializer=langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerRecord.SerializeToString, + response_serializer=langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -264,7 +264,7 @@ def get_topic_producer_records( target, "/AgentService/get_topic_producer_records", langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerWriteResult.SerializeToString, - langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerRecord.FromString, + langstream__grpc_dot_proto_dot_agent__pb2.TopicProducerResponse.FromString, options, channel_credentials, insecure, diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_topic_producer.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_topic_producer.py index 550dfdd22..8908870e2 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_topic_producer.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_topic_producer.py @@ -14,8 +14,11 @@ # limitations under the License. # +import json +from io import BytesIO from typing import List, Dict, Any, Optional +import fastavro import grpc import pytest @@ -25,6 +28,7 @@ ProcessorRequest, Value, TopicProducerWriteResult, + Schema, ) from langstream_grpc.tests.server_and_stub import ServerAndStub @@ -35,23 +39,64 @@ async def test_topic_producer_success(klass): f"langstream_grpc.tests.test_grpc_topic_producer.{klass}" ) as server_and_stub: process_call = server_and_stub.stub.process() + + schema = { + "type": "record", + "name": "Test", + "namespace": "test", + "fields": [{"name": "field", "type": {"type": "string"}}], + } + canonical_schema = fastavro.schema.to_parsing_canonical_form(schema) await process_call.write( - ProcessorRequest(records=[GrpcRecord(value=Value(string_value="test"))]) + ProcessorRequest( + schema=Schema(schema_id=42, value=canonical_schema.encode("utf-8")) + ) ) + fp = BytesIO() + try: + fastavro.schemaless_writer(fp, schema, {"field": "test"}) + await process_call.write( + ProcessorRequest( + records=[ + GrpcRecord( + record_id=43, + value=Value(schema_id=42, avro_value=fp.getvalue()), + ) + ] + ) + ) + finally: + fp.close() + topic_producer_call = server_and_stub.stub.get_topic_producer_records() - topic_producer_record = await topic_producer_call.read() + response = await topic_producer_call.read() + + assert response.HasField("schema") + assert response.schema.schema_id == 1 + assert response.schema.value.decode("utf-8") == canonical_schema + + response = await topic_producer_call.read() + assert response.topic == "topic-producer-topic" + record = response.record + assert record.record_id == 1 + assert record.value.schema_id == 1 + fp = BytesIO(record.value.avro_value) + try: + decoded = fastavro.schemaless_reader(fp, json.loads(canonical_schema)) + assert decoded == {"field": "test"} + finally: + fp.close() - assert topic_producer_record.topic == "topic-producer-topic" - assert topic_producer_record.record.value.string_value == "test" await topic_producer_call.write( - TopicProducerWriteResult(record_id=topic_producer_record.record.record_id) + TopicProducerWriteResult(record_id=record.record_id) ) await topic_producer_call.done_writing() - processed = await process_call.read() - assert processed.results[0].records[0].value.string_value == "test" + response = await process_call.read() + assert response.results[0].records[0].value.schema_id == 1 + assert response.results[0].record_id == 43 await process_call.done_writing()