From 5c1e2c1bbfea196ba2c45146f90ce2ba1090a90b Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 20 Sep 2023 20:13:52 +0200 Subject: [PATCH] Allow to return dict instead of Records in Python agents --- .../python/example.py | 4 +- .../python/langstream_grpc/grpc_service.py | 13 +++++- .../tests/test_grpc_service.py | 25 +++++++++++- .../langstream_grpc/tests/test_grpc_source.py | 40 +++++++++++++++---- .../src/main/python/langstream_runtime/api.py | 2 +- .../main/python/langstream_runtime/runtime.py | 2 + 6 files changed, 74 insertions(+), 12 deletions(-) diff --git a/langstream-e2e-tests/src/test/resources/apps/experimental-python-processor/python/example.py b/langstream-e2e-tests/src/test/resources/apps/experimental-python-processor/python/example.py index 2fb59b47f..879ea2bea 100644 --- a/langstream-e2e-tests/src/test/resources/apps/experimental-python-processor/python/example.py +++ b/langstream-e2e-tests/src/test/resources/apps/experimental-python-processor/python/example.py @@ -14,7 +14,7 @@ # limitations under the License. # -from langstream import SimpleRecord, SingleRecordProcessor +from langstream import SingleRecordProcessor class Exclamation(SingleRecordProcessor): @@ -23,4 +23,4 @@ def init(self, config): self.secret_value = config["secret_value"] def process_record(self, record): - return [SimpleRecord(record.value() + "!!" + self.secret_value)] + return [(record.value() + "!!" + self.secret_value,)] diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py index 1a1f266d0..337faa236 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py @@ -57,6 +57,14 @@ def __init__( self.record_id = record_id +def wrap_in_record(record): + if isinstance(record, tuple) or isinstance(record, list): + return SimpleRecord(*record) + if isinstance(record, dict): + return SimpleRecord(**record) + return record + + def handle_requests( agent: Source, requests: Iterable[SourceRequest], @@ -103,6 +111,7 @@ def read(self, requests: Iterable[SourceRequest], context): raise op_result[0] records = self.agent.read() if len(records) > 0: + records = [wrap_in_record(record) for record in records] grpc_records = [] for record in records: schemas, grpc_record = self.to_grpc_record(record) @@ -131,7 +140,9 @@ def process(self, requests: Iterable[ProcessorRequest], context): grpc_result.error = str(result) else: for record in result: - schemas, grpc_record = self.to_grpc_record(record) + schemas, grpc_record = self.to_grpc_record( + wrap_in_record(record) + ) for schema in schemas: yield ProcessorResponse(schema=schema) grpc_result.records.append(grpc_record) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_service.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_service.py index c388fb9bd..ec7a2dbe4 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_service.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_service.py @@ -194,10 +194,33 @@ def test_info(stub): class MyProcessor(SingleRecordProcessor): + def __init__(self): + self.i = 0 + def agent_info(self) -> Dict[str, Any]: return {"test-info-key": "test-info-value"} def process_record(self, record: Record) -> List[RecordType]: if record.origin() == "failing-record": raise Exception("failure") - return [record] + if isinstance(record.value(), str): + return [record] + if isinstance(record.value(), float): + return [ + { + "value": record.value(), + "key": record.key(), + "headers": record.headers(), + "origin": record.origin(), + "timestamp": record.timestamp(), + } + ] + return [ + ( + record.value(), + record.key(), + record.headers(), + record.origin(), + record.timestamp(), + ) + ] diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py index 5b33d3373..0f523fb8d 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py @@ -16,6 +16,7 @@ import json import queue +import time from io import BytesIO from typing import List @@ -30,7 +31,7 @@ ) from langstream_grpc.proto.agent_pb2_grpc import AgentServiceStub from langstream_runtime.api import Record, RecordType, Source -from langstream_runtime.util import SimpleRecord, AvroValue +from langstream_runtime.util import AvroValue, SimpleRecord @pytest.fixture(autouse=True) @@ -51,8 +52,19 @@ def server_and_stub(): def test_read(server_and_stub): server, stub = server_and_stub - responses: list[SourceResponse] - responses = list(stub.read(iter([]))) + stop = False + + def requests(): + while not stop: + time.sleep(0.1) + yield from () + + responses: list[SourceResponse] = [] + i = 0 + for response in stub.read(iter(requests())): + responses.append(response) + i += 1 + stop = i == 4 response_schema = responses[0] assert len(response_schema.records) == 0 @@ -76,6 +88,18 @@ def test_read(server_and_stub): finally: fp.close() + response_record = responses[2] + assert len(response_schema.records) == 0 + record = response_record.records[0] + assert record.record_id == 2 + assert record.value.long_value == 42 + + response_record = responses[3] + assert len(response_schema.records) == 0 + record = response_record.records[0] + assert record.record_id == 3 + assert record.value.long_value == 43 + def test_commit(server_and_stub): server, stub = server_and_stub @@ -83,7 +107,7 @@ def test_commit(server_and_stub): def send_commit(): committed = 0 - while committed < 2: + while committed < 3: try: commit_id = to_commit.get(True) yield SourceRequest(committed_records=[commit_id]) @@ -97,8 +121,9 @@ def send_commit(): for record in response.records: to_commit.put(record.record_id) - assert len(server.agent.committed) == 1 + assert len(server.agent.committed) == 2 assert server.agent.committed[0] == server.agent.sent[0] + assert server.agent.committed[1].value() == server.agent.sent[1]["value"] class MySource(Source): @@ -115,7 +140,8 @@ def __init__(self): value={"field": "test"}, ) ), - SimpleRecord(value=42), + {"value": 42}, + (43,), ] self.sent = [] self.committed = [] @@ -129,6 +155,6 @@ def read(self) -> List[RecordType]: def commit(self, records: List[Record]): for record in records: - if record.value() == 42: + if record.value() == 43: raise Exception("test error") self.committed.extend(records) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/api.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/api.py index d47342d29..083a9975c 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/api.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/api.py @@ -60,7 +60,7 @@ def headers(self) -> List[Tuple[str, Any]]: pass -RecordType = Union[Record, list, tuple] +RecordType = Union[Record, dict, list, tuple] class TopicConsumer(ABC): diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/runtime.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/runtime.py index 8dbaf3062..bae98e9ed 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/runtime.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/runtime.py @@ -357,6 +357,8 @@ def wrap_in_record(records): for i, record in enumerate(records): if isinstance(record, tuple) or isinstance(record, list): records[i] = SimpleRecord(*record) + if isinstance(record, dict): + records[i] = SimpleRecord(**record) return records