Skip to content

Commit

Permalink
Allow to return dict instead of Records in Python agents
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Sep 20, 2023
1 parent 190abec commit 5c1e2c1
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
#

from langstream import SimpleRecord, SingleRecordProcessor
from langstream import SingleRecordProcessor


class Exclamation(SingleRecordProcessor):
Expand All @@ -23,4 +23,4 @@ def init(self, config):
self.secret_value = config["secret_value"]

def process_record(self, record):
return [SimpleRecord(record.value() + "!!" + self.secret_value)]
return [(record.value() + "!!" + self.secret_value,)]
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ def __init__(
self.record_id = record_id


def wrap_in_record(record):
if isinstance(record, tuple) or isinstance(record, list):
return SimpleRecord(*record)
if isinstance(record, dict):
return SimpleRecord(**record)
return record


def handle_requests(
agent: Source,
requests: Iterable[SourceRequest],
Expand Down Expand Up @@ -103,6 +111,7 @@ def read(self, requests: Iterable[SourceRequest], context):
raise op_result[0]
records = self.agent.read()
if len(records) > 0:
records = [wrap_in_record(record) for record in records]
grpc_records = []
for record in records:
schemas, grpc_record = self.to_grpc_record(record)
Expand Down Expand Up @@ -131,7 +140,9 @@ def process(self, requests: Iterable[ProcessorRequest], context):
grpc_result.error = str(result)
else:
for record in result:
schemas, grpc_record = self.to_grpc_record(record)
schemas, grpc_record = self.to_grpc_record(
wrap_in_record(record)
)
for schema in schemas:
yield ProcessorResponse(schema=schema)
grpc_result.records.append(grpc_record)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,33 @@ def test_info(stub):


class MyProcessor(SingleRecordProcessor):
def __init__(self):
self.i = 0

def agent_info(self) -> Dict[str, Any]:
return {"test-info-key": "test-info-value"}

def process_record(self, record: Record) -> List[RecordType]:
if record.origin() == "failing-record":
raise Exception("failure")
return [record]
if isinstance(record.value(), str):
return [record]
if isinstance(record.value(), float):
return [
{
"value": record.value(),
"key": record.key(),
"headers": record.headers(),
"origin": record.origin(),
"timestamp": record.timestamp(),
}
]
return [
(
record.value(),
record.key(),
record.headers(),
record.origin(),
record.timestamp(),
)
]
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import json
import queue
import time
from io import BytesIO
from typing import List

Expand All @@ -30,7 +31,7 @@
)
from langstream_grpc.proto.agent_pb2_grpc import AgentServiceStub
from langstream_runtime.api import Record, RecordType, Source
from langstream_runtime.util import SimpleRecord, AvroValue
from langstream_runtime.util import AvroValue, SimpleRecord


@pytest.fixture(autouse=True)
Expand All @@ -51,8 +52,19 @@ def server_and_stub():
def test_read(server_and_stub):
server, stub = server_and_stub

responses: list[SourceResponse]
responses = list(stub.read(iter([])))
stop = False

def requests():
while not stop:
time.sleep(0.1)
yield from ()

responses: list[SourceResponse] = []
i = 0
for response in stub.read(iter(requests())):
responses.append(response)
i += 1
stop = i == 4

response_schema = responses[0]
assert len(response_schema.records) == 0
Expand All @@ -76,14 +88,26 @@ def test_read(server_and_stub):
finally:
fp.close()

response_record = responses[2]
assert len(response_schema.records) == 0
record = response_record.records[0]
assert record.record_id == 2
assert record.value.long_value == 42

response_record = responses[3]
assert len(response_schema.records) == 0
record = response_record.records[0]
assert record.record_id == 3
assert record.value.long_value == 43


def test_commit(server_and_stub):
server, stub = server_and_stub
to_commit = queue.Queue()

def send_commit():
committed = 0
while committed < 2:
while committed < 3:
try:
commit_id = to_commit.get(True)
yield SourceRequest(committed_records=[commit_id])
Expand All @@ -97,8 +121,9 @@ def send_commit():
for record in response.records:
to_commit.put(record.record_id)

assert len(server.agent.committed) == 1
assert len(server.agent.committed) == 2
assert server.agent.committed[0] == server.agent.sent[0]
assert server.agent.committed[1].value() == server.agent.sent[1]["value"]


class MySource(Source):
Expand All @@ -115,7 +140,8 @@ def __init__(self):
value={"field": "test"},
)
),
SimpleRecord(value=42),
{"value": 42},
(43,),
]
self.sent = []
self.committed = []
Expand All @@ -129,6 +155,6 @@ def read(self) -> List[RecordType]:

def commit(self, records: List[Record]):
for record in records:
if record.value() == 42:
if record.value() == 43:
raise Exception("test error")
self.committed.extend(records)
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def headers(self) -> List[Tuple[str, Any]]:
pass


RecordType = Union[Record, list, tuple]
RecordType = Union[Record, dict, list, tuple]


class TopicConsumer(ABC):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ def wrap_in_record(records):
for i, record in enumerate(records):
if isinstance(record, tuple) or isinstance(record, list):
records[i] = SimpleRecord(*record)
if isinstance(record, dict):
records[i] = SimpleRecord(**record)
return records


Expand Down

0 comments on commit 5c1e2c1

Please sign in to comment.