diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py index cf1ef064e..2c7e4cf56 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py @@ -22,7 +22,7 @@ import threading from concurrent.futures import Future from io import BytesIO -from typing import Union, List, Tuple, Any, Optional, Dict, AsyncIterable +from typing import Union, List, Tuple, Any, Optional, AsyncIterable import fastavro import grpc @@ -86,40 +86,33 @@ async def get_topic_producer_records(self, request_iterator, context): async for _ in request_iterator: yield - async def read(self, requests: AsyncIterable[SourceRequest], _): - read_records = {} - op_result = [] + async def do_read(self, context, read_records): last_record_id = 0 - read_requests_task = asyncio.create_task( - self.handle_read_requests(requests, read_records, op_result) - ) while True: - if len(op_result) > 0: - if op_result[0] is True: - break - raise op_result[0] - records = await asyncio.to_thread(self.agent.read) + if inspect.iscoroutinefunction(self.agent.read): + records = await self.agent.read() + else: + records = await asyncio.to_thread(self.agent.read) if len(records) > 0: records = [wrap_in_record(record) for record in records] grpc_records = [] for record in records: schemas, grpc_record = self.to_grpc_record(record) for schema in schemas: - yield SourceResponse(schema=schema) + await context.write(SourceResponse(schema=schema)) grpc_records.append(grpc_record) for i, record in enumerate(records): last_record_id += 1 grpc_records[i].record_id = last_record_id read_records[last_record_id] = record - yield SourceResponse(records=grpc_records) - read_requests_task.cancel() + await context.write(SourceResponse(records=grpc_records)) + else: + await asyncio.sleep(0) + + async def read(self, requests: AsyncIterable[SourceRequest], context): + read_records = {} + read_requests_task = asyncio.create_task(self.do_read(context, read_records)) - async def handle_read_requests( - self, - requests: AsyncIterable[SourceRequest], - read_records: Dict[int, Record], - read_result, - ): try: async for request in requests: if len(request.committed_records) > 0: @@ -136,9 +129,8 @@ async def handle_read_requests( record, RuntimeError(failure.error_message), ) - read_result.append(True) - except Exception as e: - read_result.append(e) + finally: + read_requests_task.cancel() async def process(self, requests: AsyncIterable[ProcessorRequest], _): async for request in requests: @@ -149,9 +141,13 @@ async def process(self, requests: AsyncIterable[ProcessorRequest], _): for source_record in request.records: grpc_result = ProcessorResult(record_id=source_record.record_id) try: - processed_records = await asyncio.to_thread( - self.agent.process, self.from_grpc_record(source_record) - ) + r = self.from_grpc_record(source_record) + if inspect.iscoroutinefunction(self.agent.process): + processed_records = await self.agent.process(r) + else: + processed_records = await asyncio.to_thread( + self.agent.process, r + ) if isinstance(processed_records, Future): processed_records = await asyncio.wrap_future( processed_records @@ -175,9 +171,11 @@ async def write(self, requests: AsyncIterable[SinkRequest], context): self.client_schemas[request.schema.schema_id] = schema if request.HasField("record"): try: - result = await asyncio.to_thread( - self.agent.write, self.from_grpc_record(request.record) - ) + r = self.from_grpc_record(request.record) + if inspect.iscoroutinefunction(self.agent.write): + result = await self.agent.write(r) + else: + result = await asyncio.to_thread(self.agent.write, r) if isinstance(result, Future): await asyncio.wrap_future(result) yield SinkResponse(record_id=request.record.record_id) @@ -280,9 +278,16 @@ def call_method_if_exists(klass, method, *args, **kwargs): return None -async def acall_method_if_exists(klass, method, *args, **kwargs): +async def acall_method_if_exists(klass, method_name, *args, **kwargs): + method = getattr(klass, method_name, None) + if inspect.iscoroutinefunction(method): + defined_positional_parameters_count = len(inspect.signature(method).parameters) + if defined_positional_parameters_count >= len(args): + return await method(*args, **kwargs) + else: + return await method(*args[:defined_positional_parameters_count], **kwargs) return await asyncio.to_thread( - call_method_if_exists, klass, method, *args, **kwargs + call_method_if_exists, klass, method_name, *args, **kwargs ) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py index 638f05dcc..3692bc432 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_processor.py @@ -186,9 +186,10 @@ async def test_failing_record(): assert response.results[0].error == "failure" -async def test_future_record(): +@pytest.mark.parametrize("klass", ["MyFutureProcessor", "MyAsyncProcessor"]) +async def test_future_record(klass): async with ServerAndStub( - "langstream_grpc.tests.test_grpc_processor.MyFutureProcessor" + f"langstream_grpc.tests.test_grpc_processor.{klass}" ) as server_and_stub: response: ProcessorResponse async for response in server_and_stub.stub.process( @@ -270,13 +271,17 @@ def process(self, record: Record) -> List[RecordType]: class MyFutureProcessor(Processor): def __init__(self): - self.written_records = [] self.executor = ThreadPoolExecutor(max_workers=10) def process(self, record: Record) -> Future[List[RecordType]]: return self.executor.submit(lambda r: [r], record) +class MyAsyncProcessor(Processor): + async def process(self, record: Record) -> List[RecordType]: + return [record] + + class ProcessorInitOneParameter: def __init__(self): self.myparam = None diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py index 0d882a2e4..32d726cd2 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py @@ -18,6 +18,7 @@ from io import BytesIO import fastavro +import pytest from langstream_grpc.api import Record, Sink from langstream_grpc.proto.agent_pb2 import ( @@ -30,9 +31,10 @@ from langstream_grpc.tests.server_and_stub import ServerAndStub -async def test_write(): +@pytest.mark.parametrize("klass", ["MySink", "MyFutureSink", "MyAsyncSink"]) +async def test_write(klass): async with ServerAndStub( - "langstream_grpc.tests.test_grpc_sink.MySink" + f"langstream_grpc.tests.test_grpc_sink.{klass}" ) as server_and_stub: async def requests(): @@ -66,6 +68,7 @@ async def requests(): assert len(responses) == 1 assert responses[0].record_id == 43 + assert responses[0].error == "" assert len(server_and_stub.server.agent.written_records) == 1 assert ( server_and_stub.server.agent.written_records[0].value().value["field"] @@ -94,30 +97,6 @@ async def test_write_error(): assert responses[0].error == "test-error" -async def test_write_future(): - async with ServerAndStub( - "langstream_grpc.tests.test_grpc_sink.MyFutureSink" - ) as server_and_stub: - responses: list[SinkResponse] - responses = [ - response - async for response in server_and_stub.stub.write( - [ - SinkRequest( - record=GrpcRecord( - record_id=42, - value=Value(string_value="test"), - ) - ) - ] - ) - ] - assert len(responses) == 1 - assert responses[0].record_id == 42 - assert len(server_and_stub.server.agent.written_records) == 1 - assert server_and_stub.server.agent.written_records[0].value() == "test" - - class MySink(Sink): def __init__(self): self.written_records = [] @@ -126,11 +105,6 @@ def write(self, record: Record): self.written_records.append(record) -class MyErrorSink(Sink): - def write(self, record: Record): - raise RuntimeError("test-error") - - class MyFutureSink(Sink): def __init__(self): self.written_records = [] @@ -138,3 +112,13 @@ def __init__(self): def write(self, record: Record) -> Future[None]: return self.executor.submit(lambda r: self.written_records.append(r), record) + + +class MyAsyncSink(MySink): + async def write(self, record: Record): + super().write(record) + + +class MyErrorSink(Sink): + def write(self, record: Record): + raise RuntimeError("test-error") diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py index 3f019ea31..8ce0ad683 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py @@ -33,107 +33,113 @@ from langstream_grpc.util import AvroValue, SimpleRecord -@pytest.fixture -async def server_and_stub(): +@pytest.mark.parametrize("klass", ["MySource", "MyAsyncSource"]) +async def test_read(klass): async with ServerAndStub( - "langstream_grpc.tests.test_grpc_source.MySource" + f"langstream_grpc.tests.test_grpc_source.{klass}" ) as server_and_stub: - yield server_and_stub - - -async def test_read(server_and_stub): - stop = False - - async def requests(): - while not stop: - await asyncio.sleep(0.1) - yield - - responses: list[SourceResponse] = [] - i = 0 - async for response in server_and_stub.stub.read(requests()): - responses.append(response) - i += 1 - stop = i == 4 - - response_schema = responses[0] - assert len(response_schema.records) == 0 - assert response_schema.HasField("schema") - assert response_schema.schema.schema_id == 1 - schema = response_schema.schema.value.decode("utf-8") - assert ( - schema - == '{"name":"test.Test","type":"record","fields":[{"name":"field","type":"string"}]}' # noqa: E501 - ) - - response_record = responses[1] - assert len(response_schema.records) == 0 - record = response_record.records[0] - assert record.record_id == 1 - assert record.value.schema_id == 1 - fp = BytesIO(record.value.avro_value) - try: - decoded = fastavro.schemaless_reader(fp, json.loads(schema)) - assert decoded == {"field": "test"} - finally: - fp.close() - - response_record = responses[2] - assert len(response_schema.records) == 0 - record = response_record.records[0] - assert record.record_id == 2 - assert record.value.long_value == 42 - - response_record = responses[3] - assert len(response_schema.records) == 0 - record = response_record.records[0] - assert record.record_id == 3 - assert record.value.long_value == 43 - - -async def test_commit(server_and_stub): - to_commit = asyncio.Queue() - - async def send_commit(): - committed = 0 - while committed < 3: - commit_id = await to_commit.get() - yield SourceRequest(committed_records=[commit_id]) - committed += 1 - - with pytest.raises(grpc.RpcError): - response: SourceResponse - async for response in server_and_stub.stub.read(send_commit()): - for record in response.records: - await to_commit.put(record.record_id) + stop = False - sent = server_and_stub.server.agent.sent - committed = server_and_stub.server.agent.committed - assert len(committed) == 2 - assert committed[0] == sent[0] - assert committed[1].value() == sent[1]["value"] + async def requests(): + while not stop: + await asyncio.sleep(0.1) + yield + responses: list[SourceResponse] = [] + i = 0 + async for response in server_and_stub.stub.read(requests()): + responses.append(response) + for record in response.records: + print(record.record_id) + i += 1 + stop = i == 4 + + response_schema = responses[0] + assert len(response_schema.records) == 0 + assert response_schema.HasField("schema") + assert response_schema.schema.schema_id == 1 + schema = response_schema.schema.value.decode("utf-8") + assert ( + schema + == '{"name":"test.Test","type":"record","fields":[{"name":"field","type":"string"}]}' # noqa: E501 + ) -async def test_permanent_failure(server_and_stub): - to_fail = asyncio.Queue() + response_record = responses[1] + assert len(response_schema.records) == 0 + record = response_record.records[0] + assert record.record_id == 1 + assert record.value.schema_id == 1 + fp = BytesIO(record.value.avro_value) + try: + decoded = fastavro.schemaless_reader(fp, json.loads(schema)) + assert decoded == {"field": "test"} + finally: + fp.close() + + response_record = responses[2] + assert len(response_schema.records) == 0 + record = response_record.records[0] + assert record.record_id == 2 + assert record.value.long_value == 42 + + response_record = responses[3] + assert len(response_schema.records) == 0 + record = response_record.records[0] + assert record.record_id == 3 + assert record.value.long_value == 43 + + +@pytest.mark.parametrize("klass", ["MySource", "MyAsyncSource"]) +async def test_commit(klass): + async with ServerAndStub( + f"langstream_grpc.tests.test_grpc_source.{klass}" + ) as server_and_stub: + to_commit = asyncio.Queue() + + async def send_commit(): + committed = 0 + while committed < 3: + commit_id = await to_commit.get() + yield SourceRequest(committed_records=[commit_id]) + committed += 1 + + with pytest.raises(grpc.RpcError): + response: SourceResponse + async for response in server_and_stub.stub.read(send_commit()): + for record in response.records: + await to_commit.put(record.record_id) + + sent = server_and_stub.server.agent.sent + committed = server_and_stub.server.agent.committed + assert len(committed) == 2 + assert committed[0] == sent[0] + assert committed[1].value() == sent[1]["value"] + + +@pytest.mark.parametrize("klass", ["MySource", "MyAsyncSource"]) +async def test_permanent_failure(klass): + async with ServerAndStub( + f"langstream_grpc.tests.test_grpc_source.{klass}" + ) as server_and_stub: + to_fail = asyncio.Queue() - async def send_failure(): - record_id = await to_fail.get() - yield SourceRequest( - permanent_failure=PermanentFailure( - record_id=record_id, error_message="failure" + async def send_failure(): + record_id = await to_fail.get() + yield SourceRequest( + permanent_failure=PermanentFailure( + record_id=record_id, error_message="failure" + ) ) - ) - response: SourceResponse - async for response in server_and_stub.stub.read(send_failure()): - for record in response.records: - await to_fail.put(record.record_id) + response: SourceResponse + async for response in server_and_stub.stub.read(send_failure()): + for record in response.records: + await to_fail.put(record.record_id) - failures = server_and_stub.server.agent.failures - assert len(failures) == 1 - assert failures[0][0] == server_and_stub.server.agent.sent[0] - assert str(failures[0][1]) == "failure" + failures = server_and_stub.server.agent.failures + assert len(failures) == 1 + assert failures[0][0] == server_and_stub.server.agent.sent[0] + assert str(failures[0][1]) == "failure" class MySource(Source): @@ -171,3 +177,14 @@ def commit(self, record: Record): def permanent_failure(self, record: Record, error: Exception): self.failures.append((record, error)) + + +class MyAsyncSource(MySource): + async def read(self) -> List[RecordType]: + return super().read() + + async def commit(self, record: Record): + super().commit(record) + + async def permanent_failure(self, record: Record, error: Exception): + return super().permanent_failure(record, error)