Skip to content

Commit

Permalink
Fix Python agent read when an error is raised (#732)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet authored Nov 21, 2023
1 parent e365f7f commit 598caf7
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,28 +109,39 @@ async def do_read(self, context, read_records):
else:
await asyncio.sleep(0)

async def read(self, requests: AsyncIterable[SourceRequest], context):
read_records = {}
read_requests_task = asyncio.create_task(self.do_read(context, read_records))

try:
async for request in requests:
if len(request.committed_records) > 0:
for record_id in request.committed_records:
record = read_records.pop(record_id, None)
if record is not None:
await acall_method_if_exists(self.agent, "commit", record)
if request.HasField("permanent_failure"):
failure = request.permanent_failure
record = read_records.pop(failure.record_id, None)
async def handle_read_requests(self, context, read_records):
request = await context.read()
while request != grpc.aio.EOF:
if len(request.committed_records) > 0:
for record_id in request.committed_records:
record = read_records.pop(record_id, None)
if record is not None:
await acall_method_if_exists(self.agent, "commit", record)
if request.HasField("permanent_failure"):
failure = request.permanent_failure
record = read_records.pop(failure.record_id, None)
if record is not None:
await acall_method_if_exists(
self.agent,
"permanent_failure",
record,
RuntimeError(failure.error_message),
)
finally:
read_requests_task.cancel()
request = await context.read()

async def read(self, requests: AsyncIterable[SourceRequest], context):
    """Serve one ``read`` call until either of its two worker tasks finishes.

    Two tasks are started and share the ``read_records`` dict:
    ``self.do_read(context, read_records)`` and
    ``self.handle_read_requests(context, read_records)``. As soon as one of
    them completes, the other is cancelled and any exception raised by a
    completed task is re-raised to the caller.

    Args:
        requests: Not used by this method; incoming requests are pulled via
            ``context.read()`` inside ``handle_read_requests``.
        context: Call context handed to both worker tasks.

    Raises:
        Exception: whatever a completed worker task raised.
    """
    read_records = {}
    tasks = [
        asyncio.create_task(self.do_read(context, read_records)),
        asyncio.create_task(self.handle_read_requests(context, read_records)),
    ]

    try:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Runs on normal completion AND when this coroutine is itself
        # cancelled while waiting: asyncio.wait() never cancels the awaited
        # tasks on its own, so they would otherwise keep running.
        for task in tasks:
            if not task.done():
                task.cancel()

    # Both tasks may finish in the same event-loop iteration, in which case
    # nothing is pending; never assume exactly one task per set.
    # Propagate the first failure in a deterministic (creation) order.
    for task in tasks:
        if task in done:
            task.result()

async def process(self, requests: AsyncIterable[ProcessorRequest], _):
async for request in requests:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from io import BytesIO

import fastavro
import grpc.aio
import pytest

from langstream_grpc.api import Record, Sink
Expand All @@ -36,45 +37,47 @@ async def test_write(klass):
async with ServerAndStub(
f"langstream_grpc.tests.test_grpc_sink.{klass}"
) as server_and_stub:

async def requests():
schema = {
"type": "record",
"name": "Test",
"namespace": "test",
"fields": [{"name": "field", "type": {"type": "string"}}],
}
canonical_schema = fastavro.schema.to_parsing_canonical_form(schema)
yield SinkRequest(
write_call = server_and_stub.stub.write()

schema = {
"type": "record",
"name": "Test",
"namespace": "test",
"fields": [{"name": "field", "type": {"type": "string"}}],
}
canonical_schema = fastavro.schema.to_parsing_canonical_form(schema)
await write_call.write(
SinkRequest(
schema=Schema(schema_id=42, value=canonical_schema.encode("utf-8"))
)
)

fp = BytesIO()
try:
fastavro.schemaless_writer(fp, schema, {"field": "test"})
yield SinkRequest(
fp = BytesIO()
try:
fastavro.schemaless_writer(fp, schema, {"field": "test"})
await write_call.write(
SinkRequest(
record=GrpcRecord(
record_id=43,
value=Value(schema_id=42, avro_value=fp.getvalue()),
)
)
finally:
fp.close()

responses: list[SinkResponse]
responses = [
response async for response in server_and_stub.stub.write(requests())
]
)
finally:
fp.close()

assert len(responses) == 1
assert responses[0].record_id == 43
assert responses[0].error == ""
response = await write_call.read()
assert response.record_id == 43
assert response.error == ""
assert len(server_and_stub.server.agent.written_records) == 1
assert (
server_and_stub.server.agent.written_records[0].value().value["field"]
== "test"
)

await write_call.done_writing()
assert await write_call.read() == grpc.aio.EOF


async def test_write_error():
async with ServerAndStub(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.
#

import asyncio
import json
from io import BytesIO
from typing import List
Expand All @@ -25,48 +24,33 @@

from langstream_grpc.api import Record, RecordType, Source
from langstream_grpc.proto.agent_pb2 import (
SourceResponse,
SourceRequest,
PermanentFailure,
)
from langstream_grpc.tests.server_and_stub import ServerAndStub
from langstream_grpc.util import AvroValue, SimpleRecord


@pytest.mark.parametrize("klass", ["MySource", "MyAsyncSource"])
@pytest.mark.parametrize("klass", ["MyAsyncSource"])
async def test_read(klass):
async with ServerAndStub(
f"langstream_grpc.tests.test_grpc_source.{klass}"
) as server_and_stub:
stop = False

async def requests():
while not stop:
await asyncio.sleep(0.1)
yield

responses: list[SourceResponse] = []
i = 0
async for response in server_and_stub.stub.read(requests()):
responses.append(response)
for record in response.records:
print(record.record_id)
i += 1
stop = i == 4

response_schema = responses[0]
assert len(response_schema.records) == 0
assert response_schema.HasField("schema")
assert response_schema.schema.schema_id == 1
schema = response_schema.schema.value.decode("utf-8")
read_call = server_and_stub.stub.read()
response = await read_call.read()

assert len(response.records) == 0
assert response.HasField("schema")
assert response.schema.schema_id == 1
schema = response.schema.value.decode("utf-8")
assert (
schema
== '{"name":"test.Test","type":"record","fields":[{"name":"field","type":"string"}]}' # noqa: E501
)

response_record = responses[1]
assert len(response_schema.records) == 0
record = response_record.records[0]
response = await read_call.read()
assert len(response.records) == 1
record = response.records[0]
assert record.record_id == 1
assert record.value.schema_id == 1
fp = BytesIO(record.value.avro_value)
Expand All @@ -76,38 +60,42 @@ async def requests():
finally:
fp.close()

response_record = responses[2]
assert len(response_schema.records) == 0
record = response_record.records[0]
response = await read_call.read()
assert len(response.records) == 1
record = response.records[0]
assert record.record_id == 2
assert record.value.long_value == 42

response_record = responses[3]
assert len(response_schema.records) == 0
record = response_record.records[0]
response = await read_call.read()
assert len(response.records) == 1
record = response.records[0]
assert record.record_id == 3
assert record.value.long_value == 43

await read_call.done_writing()

assert await read_call.read() == grpc.aio.EOF


@pytest.mark.parametrize("klass", ["MySource", "MyAsyncSource"])
async def test_commit(klass):
async with ServerAndStub(
f"langstream_grpc.tests.test_grpc_source.{klass}"
) as server_and_stub:
to_commit = asyncio.Queue()
read_call = server_and_stub.stub.read()

async def send_commit():
committed = 0
while committed < 3:
commit_id = await to_commit.get()
yield SourceRequest(committed_records=[commit_id])
committed += 1
# first read is a schema
await read_call.read()

for _ in range(3):
response = await read_call.read()
assert len(response.records) == 1
await read_call.write(
SourceRequest(committed_records=[response.records[0].record_id])
)

with pytest.raises(grpc.RpcError):
response: SourceResponse
async for response in server_and_stub.stub.read(send_commit()):
for record in response.records:
await to_commit.put(record.record_id)
await read_call.read()

sent = server_and_stub.server.agent.sent
committed = server_and_stub.server.agent.committed
Expand All @@ -121,27 +109,48 @@ async def test_permanent_failure(klass):
async with ServerAndStub(
f"langstream_grpc.tests.test_grpc_source.{klass}"
) as server_and_stub:
to_fail = asyncio.Queue()
read_call = server_and_stub.stub.read()

# first read is a schema
await read_call.read()

async def send_failure():
record_id = await to_fail.get()
yield SourceRequest(
response = await read_call.read()
await read_call.write(
SourceRequest(
permanent_failure=PermanentFailure(
record_id=record_id, error_message="failure"
record_id=response.records[0].record_id, error_message="failure"
)
)
)

await read_call.done_writing()

response: SourceResponse
async for response in server_and_stub.stub.read(send_failure()):
for record in response.records:
await to_fail.put(record.record_id)
try:
response = await read_call.read()
while response != grpc.aio.EOF:
response = await read_call.read()
pytest.fail("call should have raised an exception")
except grpc.RpcError as e:
assert "failure" in e.details()

failures = server_and_stub.server.agent.failures
assert len(failures) == 1
assert failures[0][0] == server_and_stub.server.agent.sent[0]
assert str(failures[0][1]) == "failure"


async def test_read_error():
    """An error raised by the source's ``read`` must fail the gRPC call.

    Starts a server backed by ``MyErrorSource`` and checks that the first
    read on the stream raises ``grpc.RpcError`` whose details contain the
    original error message.
    """
    agent_class = "langstream_grpc.tests.test_grpc_source.MyErrorSource"
    async with ServerAndStub(agent_class) as server_and_stub:
        call = server_and_stub.stub.read()
        try:
            await call.read()
        except grpc.RpcError as rpc_error:
            assert "test-error" in rpc_error.details()
        else:
            # Reaching this branch means the read completed without an error.
            pytest.fail("call should have raised an exception")


class MySource(Source):
def __init__(self):
self.records = [
Expand Down Expand Up @@ -177,6 +186,7 @@ def commit(self, record: Record):

def permanent_failure(self, record: Record, error: Exception):
self.failures.append((record, error))
raise error


class MyAsyncSource(MySource):
Expand All @@ -188,3 +198,8 @@ async def commit(self, record: Record):

async def permanent_failure(self, record: Record, error: Exception):
return super().permanent_failure(record, error)


class MyErrorSource(Source):
    """Source whose ``read`` always fails, used to test error propagation."""

    def read(self) -> List[RecordType]:
        """Raise ``ValueError("test-error")`` instead of returning records."""
        raise ValueError("test-error")

0 comments on commit 598caf7

Please sign in to comment.