From e736ca82665fb20a632351b85b2790570e696df0 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Thu, 21 Sep 2023 11:34:54 +0200 Subject: [PATCH] Allow to return dict instead of Records in Python agents (#460) --- .../python/example.py | 4 +- .../python/example.py | 2 +- .../src/main/python/langstream/api.py | 15 ++++++- .../python/langstream_grpc/grpc_service.py | 13 +++++- .../tests/test_grpc_processor.py | 22 +++++++++- .../langstream_grpc/tests/test_grpc_source.py | 40 +++++++++++++++---- .../src/main/python/langstream_runtime/api.py | 17 ++++++-- .../main/python/langstream_runtime/runtime.py | 2 + 8 files changed, 98 insertions(+), 17 deletions(-) diff --git a/langstream-e2e-tests/src/test/resources/apps/experimental-python-processor/python/example.py b/langstream-e2e-tests/src/test/resources/apps/experimental-python-processor/python/example.py index 2fb59b47f..879ea2bea 100644 --- a/langstream-e2e-tests/src/test/resources/apps/experimental-python-processor/python/example.py +++ b/langstream-e2e-tests/src/test/resources/apps/experimental-python-processor/python/example.py @@ -14,7 +14,7 @@ # limitations under the License. # -from langstream import SimpleRecord, SingleRecordProcessor +from langstream import SingleRecordProcessor class Exclamation(SingleRecordProcessor): @@ -23,4 +23,4 @@ def init(self, config): self.secret_value = config["secret_value"] def process_record(self, record): - return [SimpleRecord(record.value() + "!!" + self.secret_value)] + return [(record.value() + "!!" + self.secret_value,)] diff --git a/langstream-e2e-tests/src/test/resources/apps/experimental-python-source/python/example.py b/langstream-e2e-tests/src/test/resources/apps/experimental-python-source/python/example.py index dcefa5445..9f6b3b2c3 100644 --- a/langstream-e2e-tests/src/test/resources/apps/experimental-python-source/python/example.py +++ b/langstream-e2e-tests/src/test/resources/apps/experimental-python-source/python/example.py @@ -17,8 +17,8 @@ from langstream import SimpleRecord import logging -class TestSource(object): +class TestSource(object): def __init__(self): self.sent = False diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/api.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/api.py index ff0216734..2e998cf0a 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/api.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream/api.py @@ -92,7 +92,12 @@ def read(self) -> List[RecordType]: """The Source agent generates records and returns them as list of records. :returns: the list of records. The records must either respect the Record - API contract (have methods value(), key() and so on) or be tuples/list. + API contract (have methods value(), key() and so on) or be a dict or + tuples/list. + If the records are dict, the keys if present shall be "value", "key", + "headers", "origin" and "timestamp". + Eg: + * if you return [{"value": "foo"}] a record Record(value="foo") will be built. If the records are tuples/list, the framework will automatically construct Record objects from them with the values in the following order : value, key, headers, origin, timestamp. @@ -138,7 +143,13 @@ def process( exception. Eg: [(input_record, RuntimeError("Could not process"))] When the processing is successful, the output records must either respect the - Record API contract (have methods value(), key() and so on) or be tuples/list. + Record API contract (have methods value(), key() and so on) or be a dict or + tuples/list. + If the records are dict, the keys if present shall be "value", "key", + "headers", "origin" and "timestamp". + Eg: + * if you return [(input_record, [{"value": "foo"}])] a record + Record(value="foo") will be built. If the output records are tuples/list, the framework will automatically construct Record objects from them with the values in the following order : value, key, headers, origin, timestamp. diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py index db43aa431..460e59937 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py @@ -57,6 +57,14 @@ def __init__( self.record_id = record_id +def wrap_in_record(record): + if isinstance(record, tuple) or isinstance(record, list): + return SimpleRecord(*record) + if isinstance(record, dict): + return SimpleRecord(**record) + return record + + def handle_requests( agent: Source, requests: Iterable[SourceRequest], @@ -112,6 +120,7 @@ def read(self, requests: Iterable[SourceRequest], context): raise op_result[0] records = self.agent.read() if len(records) > 0: + records = [wrap_in_record(record) for record in records] grpc_records = [] for record in records: schemas, grpc_record = self.to_grpc_record(record) @@ -140,7 +149,9 @@ def process(self, requests: Iterable[ProcessorRequest], context): grpc_result.error = str(result) else: for record in result: - schemas, grpc_record = self.to_grpc_record(record) + schemas, grpc_record = self.to_grpc_record( + wrap_in_record(record) + ) for schema in schemas: yield ProcessorResponse(schema=schema) grpc_result.records.append(grpc_record) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py index f7edf5a49..4ea790421 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py @@ -200,4 +200,24 @@ def agent_info(self) -> Dict[str, Any]: def process_record(self, record: Record) -> List[RecordType]: if record.origin() == "failing-record": raise Exception("failure") - return [record] + if isinstance(record.value(), str): + return [record] + if isinstance(record.value(), float): + return [ + { + "value": record.value(), + "key": record.key(), + "headers": record.headers(), + "origin": record.origin(), + "timestamp": record.timestamp(), + } + ] + return [ + ( + record.value(), + record.key(), + record.headers(), + record.origin(), + record.timestamp(), + ) + ] diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py index 59577ab67..57fcd4b3c 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py @@ -16,6 +16,7 @@ import json import queue +import time from io import BytesIO from typing import List @@ -31,7 +32,7 @@ ) from langstream_grpc.proto.agent_pb2_grpc import AgentServiceStub from langstream_runtime.api import Record, RecordType, Source -from langstream_runtime.util import SimpleRecord, AvroValue +from langstream_runtime.util import AvroValue, SimpleRecord @pytest.fixture(autouse=True) @@ -52,8 +53,19 @@ def server_and_stub(): def test_read(server_and_stub): server, stub = server_and_stub - responses: list[SourceResponse] - responses = list(stub.read(iter([]))) + stop = False + + def requests(): + while not stop: + time.sleep(0.1) + yield from () + + responses: list[SourceResponse] = [] + i = 0 + for response in stub.read(iter(requests())): + responses.append(response) + i += 1 + stop = i == 4 response_schema = responses[0] assert len(response_schema.records) == 0 @@ -77,6 +89,18 @@ def test_read(server_and_stub): finally: fp.close() + response_record = responses[2] + assert len(response_schema.records) == 0 + record = response_record.records[0] + assert record.record_id == 2 + assert record.value.long_value == 42 + + response_record = responses[3] + assert len(response_schema.records) == 0 + record = response_record.records[0] + assert record.record_id == 3 + assert record.value.long_value == 43 + def test_commit(server_and_stub): server, stub = server_and_stub @@ -84,7 +108,7 @@ def test_commit(server_and_stub): def send_commit(): committed = 0 - while committed < 2: + while committed < 3: try: commit_id = to_commit.get(True, 1) yield SourceRequest(committed_records=[commit_id]) @@ -98,8 +122,9 @@ def send_commit(): for record in response.records: to_commit.put(record.record_id) - assert len(server.agent.committed) == 1 + assert len(server.agent.committed) == 2 assert server.agent.committed[0] == server.agent.sent[0] + assert server.agent.committed[1].value() == server.agent.sent[1]["value"] def test_permanent_failure(server_and_stub): @@ -141,7 +166,8 @@ def __init__(self): value={"field": "test"}, ) ), - SimpleRecord(value=42), + {"value": 42}, + (43,), ] self.sent = [] self.committed = [] @@ -156,7 +182,7 @@ def read(self) -> List[RecordType]: def commit(self, records: List[Record]): for record in records: - if record.value() == 42: + if record.value() == 43: raise Exception("test error") self.committed.extend(records) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/api.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/api.py index d47342d29..6e49b0149 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/api.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/api.py @@ -60,7 +60,7 @@ def headers(self) -> List[Tuple[str, Any]]: pass -RecordType = Union[Record, list, tuple] +RecordType = Union[Record, dict, list, tuple] class TopicConsumer(ABC): @@ -146,7 +146,12 @@ def read(self) -> List[RecordType]: """The Source agent generates records and returns them as list of records. :returns: the list of records. The records must either respect the Record - API contract (have methods value(), key() and so on) or be tuples/list. + API contract (have methods value(), key() and so on) or be a dict or + tuples/list. + If the records are dict, the keys if present shall be "value", "key", + "headers", "origin" and "timestamp". + Eg: + * if you return [{"value": "foo"}] a record Record(value="foo") will be built. If the records are tuples/list, the framework will automatically construct Record objects from them with the values in the following order : value, key, headers, origin, timestamp. @@ -192,7 +197,13 @@ def process( exception. Eg: [(input_record, RuntimeError("Could not process"))] When the processing is successful, the output records must either respect the - Record API contract (have methods value(), key() and so on) or be tuples/list. + Record API contract (have methods value(), key() and so on) or be a dict or + tuples/list. + If the records are dict, the keys if present shall be "value", "key", + "headers", "origin" and "timestamp". + Eg: + * if you return [(input_record, [{"value": "foo"}])] a record + Record(value="foo") will be built. If the output records are tuples/list, the framework will automatically construct Record objects from them with the values in the following order : value, key, headers, origin, timestamp. diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/runtime.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/runtime.py index 8dbaf3062..bae98e9ed 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/runtime.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_runtime/runtime.py @@ -357,6 +357,8 @@ def wrap_in_record(records): for i, record in enumerate(records): if isinstance(record, tuple) or isinstance(record, list): records[i] = SimpleRecord(*record) + if isinstance(record, dict): + records[i] = SimpleRecord(**record) return records