diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py index 2c7e4cf56..988fd4532 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py @@ -109,28 +109,39 @@ async def do_read(self, context, read_records): else: await asyncio.sleep(0) - async def read(self, requests: AsyncIterable[SourceRequest], context): - read_records = {} - read_requests_task = asyncio.create_task(self.do_read(context, read_records)) - - try: - async for request in requests: - if len(request.committed_records) > 0: - for record_id in request.committed_records: - record = read_records.pop(record_id, None) - if record is not None: - await acall_method_if_exists(self.agent, "commit", record) - if request.HasField("permanent_failure"): - failure = request.permanent_failure - record = read_records.pop(failure.record_id, None) + async def handle_read_requests(self, context, read_records): + request = await context.read() + while request != grpc.aio.EOF: + if len(request.committed_records) > 0: + for record_id in request.committed_records: + record = read_records.pop(record_id, None) + if record is not None: + await acall_method_if_exists(self.agent, "commit", record) + if request.HasField("permanent_failure"): + failure = request.permanent_failure + record = read_records.pop(failure.record_id, None) + if record is not None: await acall_method_if_exists( self.agent, "permanent_failure", record, RuntimeError(failure.error_message), ) - finally: - read_requests_task.cancel() + request = await context.read() + + async def read(self, requests: AsyncIterable[SourceRequest], context): + read_records = {} + read_task = asyncio.create_task(self.do_read(context, read_records)) + read_requests_task = asyncio.create_task( + self.handle_read_requests(context, read_records) + ) + + done, pending = await asyncio.wait( + [read_task, read_requests_task], return_when=asyncio.FIRST_COMPLETED + ) + pending.pop().cancel() + # propagate exception if needed + done.pop().result() async def process(self, requests: AsyncIterable[ProcessorRequest], _): async for request in requests: diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py index 32d726cd2..028e9e7f6 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_sink.py @@ -18,6 +18,7 @@ from io import BytesIO import fastavro +import grpc.aio import pytest from langstream_grpc.api import Record, Sink @@ -36,45 +37,47 @@ async def test_write(klass): async with ServerAndStub( f"langstream_grpc.tests.test_grpc_sink.{klass}" ) as server_and_stub: - - async def requests(): - schema = { - "type": "record", - "name": "Test", - "namespace": "test", - "fields": [{"name": "field", "type": {"type": "string"}}], - } - canonical_schema = fastavro.schema.to_parsing_canonical_form(schema) - yield SinkRequest( + write_call = server_and_stub.stub.write() + + schema = { + "type": "record", + "name": "Test", + "namespace": "test", + "fields": [{"name": "field", "type": {"type": "string"}}], + } + canonical_schema = fastavro.schema.to_parsing_canonical_form(schema) + await write_call.write( + SinkRequest( schema=Schema(schema_id=42, value=canonical_schema.encode("utf-8")) ) + ) - fp = BytesIO() - try: - fastavro.schemaless_writer(fp, schema, {"field": "test"}) - yield SinkRequest( + fp = BytesIO() + try: + fastavro.schemaless_writer(fp, schema, {"field": "test"}) + await write_call.write( + SinkRequest( record=GrpcRecord( record_id=43, value=Value(schema_id=42, avro_value=fp.getvalue()), ) ) - finally: - fp.close() - - responses: list[SinkResponse] - responses = [ - response async for response in server_and_stub.stub.write(requests()) - ] + ) + finally: + fp.close() - assert len(responses) == 1 - assert responses[0].record_id == 43 - assert responses[0].error == "" + response = await write_call.read() + assert response.record_id == 43 + assert response.error == "" assert len(server_and_stub.server.agent.written_records) == 1 assert ( server_and_stub.server.agent.written_records[0].value().value["field"] == "test" ) + await write_call.done_writing() + assert await write_call.read() == grpc.aio.EOF + async def test_write_error(): async with ServerAndStub( diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py index 8ce0ad683..8352828b0 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_source.py @@ -14,7 +14,6 @@ # limitations under the License. # -import asyncio import json from io import BytesIO from typing import List @@ -25,7 +24,6 @@ from langstream_grpc.api import Record, RecordType, Source from langstream_grpc.proto.agent_pb2 import ( - SourceResponse, SourceRequest, PermanentFailure, ) @@ -33,40 +31,26 @@ from langstream_grpc.util import AvroValue, SimpleRecord -@pytest.mark.parametrize("klass", ["MySource", "MyAsyncSource"]) +@pytest.mark.parametrize("klass", ["MyAsyncSource"]) async def test_read(klass): async with ServerAndStub( f"langstream_grpc.tests.test_grpc_source.{klass}" ) as server_and_stub: - stop = False - - async def requests(): - while not stop: - await asyncio.sleep(0.1) - yield - - responses: list[SourceResponse] = [] - i = 0 - async for response in server_and_stub.stub.read(requests()): - responses.append(response) - for record in response.records: - print(record.record_id) - i += 1 - stop = i == 4 - - response_schema = responses[0] - assert len(response_schema.records) == 0 - assert response_schema.HasField("schema") - assert response_schema.schema.schema_id == 1 - schema = response_schema.schema.value.decode("utf-8") + read_call = server_and_stub.stub.read() + response = await read_call.read() + + assert len(response.records) == 0 + assert response.HasField("schema") + assert response.schema.schema_id == 1 + schema = response.schema.value.decode("utf-8") assert ( schema == '{"name":"test.Test","type":"record","fields":[{"name":"field","type":"string"}]}' # noqa: E501 ) - response_record = responses[1] - assert len(response_schema.records) == 0 - record = response_record.records[0] + response = await read_call.read() + assert len(response.records) == 1 + record = response.records[0] assert record.record_id == 1 assert record.value.schema_id == 1 fp = BytesIO(record.value.avro_value) @@ -76,38 +60,42 @@ async def requests(): finally: fp.close() - response_record = responses[2] - assert len(response_schema.records) == 0 - record = response_record.records[0] + response = await read_call.read() + assert len(response.records) == 1 + record = response.records[0] assert record.record_id == 2 assert record.value.long_value == 42 - response_record = responses[3] - assert len(response_schema.records) == 0 - record = response_record.records[0] + response = await read_call.read() + assert len(response.records) == 1 + record = response.records[0] assert record.record_id == 3 assert record.value.long_value == 43 + await read_call.done_writing() + + assert await read_call.read() == grpc.aio.EOF + @pytest.mark.parametrize("klass", ["MySource", "MyAsyncSource"]) async def test_commit(klass): async with ServerAndStub( f"langstream_grpc.tests.test_grpc_source.{klass}" ) as server_and_stub: - to_commit = asyncio.Queue() + read_call = server_and_stub.stub.read() - async def send_commit(): - committed = 0 - while committed < 3: - commit_id = await to_commit.get() - yield SourceRequest(committed_records=[commit_id]) - committed += 1 + # first read is a schema + await read_call.read() + + for _ in range(3): + response = await read_call.read() + assert len(response.records) == 1 + await read_call.write( + SourceRequest(committed_records=[response.records[0].record_id]) + ) with pytest.raises(grpc.RpcError): - response: SourceResponse - async for response in server_and_stub.stub.read(send_commit()): - for record in response.records: - await to_commit.put(record.record_id) + await read_call.read() sent = server_and_stub.server.agent.sent committed = server_and_stub.server.agent.committed @@ -121,20 +109,29 @@ async def test_permanent_failure(klass): async with ServerAndStub( f"langstream_grpc.tests.test_grpc_source.{klass}" ) as server_and_stub: - to_fail = asyncio.Queue() + read_call = server_and_stub.stub.read() + + # first read is a schema + await read_call.read() - async def send_failure(): - record_id = await to_fail.get() - yield SourceRequest( + response = await read_call.read() + await read_call.write( + SourceRequest( permanent_failure=PermanentFailure( - record_id=record_id, error_message="failure" + record_id=response.records[0].record_id, error_message="failure" ) ) + ) + + await read_call.done_writing() - response: SourceResponse - async for response in server_and_stub.stub.read(send_failure()): - for record in response.records: - await to_fail.put(record.record_id) + try: + response = await read_call.read() + while response != grpc.aio.EOF: + response = await read_call.read() + pytest.fail("call should have raised an exception") + except grpc.RpcError as e: + assert "failure" in e.details() failures = server_and_stub.server.agent.failures assert len(failures) == 1 @@ -142,6 +139,18 @@ async def send_failure(): assert str(failures[0][1]) == "failure" +async def test_read_error(): + async with ServerAndStub( + "langstream_grpc.tests.test_grpc_source.MyErrorSource" + ) as server_and_stub: + read_call = server_and_stub.stub.read() + try: + await read_call.read() + pytest.fail("call should have raised an exception") + except grpc.RpcError as e: + assert "test-error" in e.details() + + class MySource(Source): def __init__(self): self.records = [ @@ -177,6 +186,7 @@ def commit(self, record: Record): def permanent_failure(self, record: Record, error: Exception): self.failures.append((record, error)) + raise error class MyAsyncSource(MySource): @@ -188,3 +198,8 @@ async def commit(self, record: Record): async def permanent_failure(self, record: Record, error: Exception): return super().permanent_failure(record, error) + + +class MyErrorSource(Source): + def read(self) -> List[RecordType]: + raise ValueError("test-error")