Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
Add gRPC Source Python server side (LangStream#455)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet authored Sep 20, 2023
1 parent accb94d commit dc9095d
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
import importlib
import json
import logging
import threading
from io import BytesIO
from typing import Iterable, Union, List, Tuple, Any, Optional
from typing import Iterable, Union, List, Tuple, Any, Optional, Dict

import fastavro
import grpc
Expand All @@ -34,6 +35,8 @@
ProcessorResult,
Schema,
InfoResponse,
SourceRequest,
SourceResponse,
)
from langstream_grpc.proto.agent_pb2_grpc import AgentServiceServicer
from langstream_runtime.api import Source, Sink, Processor, Record, Agent
Expand All @@ -54,6 +57,26 @@ def __init__(
self.record_id = record_id


def handle_requests(
agent: Source,
requests: Iterable[SourceRequest],
read_records: Dict[int, Record],
read_result,
):
try:
for request in requests:
if len(request.committed_records) > 0:
records = []
for record_id in request.committed_records:
record = read_records.pop(record_id, None)
if record is not None:
records.append(record)
call_method_if_exists(agent, "commit", records)
read_result.append(True)
except Exception as e:
read_result.append(e)


class AgentService(AgentServiceServicer):
def __init__(self, agent: Union[Agent, Source, Sink, Processor]):
self.agent = agent
Expand All @@ -65,6 +88,34 @@ def agent_info(self, _, context):
info = call_method_if_exists(self.agent, "agent_info") or {}
return InfoResponse(json_info=json.dumps(info))

def read(self, requests: Iterable[SourceRequest], context):
read_records = {}
op_result = []
read_thread = threading.Thread(
target=handle_requests, args=(self.agent, requests, read_records, op_result)
)
last_record_id = 0
read_thread.start()
while True:
if len(op_result) > 0:
if op_result[0] is True:
break
raise op_result[0]
records = self.agent.read()
if len(records) > 0:
grpc_records = []
for record in records:
schemas, grpc_record = self.to_grpc_record(record)
for schema in schemas:
yield SourceResponse(schema=schema)
grpc_records.append(grpc_record)
for i, record in enumerate(records):
last_record_id += 1
grpc_records[i].record_id = last_record_id
read_records[last_record_id] = record
yield SourceResponse(records=grpc_records)
read_thread.join()

def process(self, requests: Iterable[ProcessorRequest], context):
for request in requests:
if request.HasField("schema"):
Expand All @@ -79,13 +130,12 @@ def process(self, requests: Iterable[ProcessorRequest], context):
if isinstance(result, Exception):
grpc_result.error = str(result)
else:
for r in result:
schemas, grpc_record = self.to_grpc_record(r)
for record in result:
schemas, grpc_record = self.to_grpc_record(record)
for schema in schemas:
yield ProcessorResponse(schema=schema)
grpc_result.records.append(grpc_record)
grpc_results.append(grpc_result)

yield ProcessorResponse(results=grpc_results)

def from_grpc_record(self, record: GrpcRecord) -> SimpleRecord:
Expand All @@ -110,6 +160,8 @@ def from_grpc_value(self, value: Value):
)
finally:
avro_value.close()
if value.HasField("json_value"):
return json.loads(value.json_value)
return getattr(value, value.WhichOneof("type_oneof"))

def to_grpc_record(self, record: Record) -> Tuple[List[Schema], GrpcRecord]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n!langstream_grpc/proto/agent.proto\x1a\x1bgoogle/protobuf/empty.proto"!\n\x0cInfoResponse\x12\x11\n\tjson_info\x18\x01 \x01(\t"\xa3\x02\n\x05Value\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\x15\n\x0b\x62ytes_value\x18\x02 \x01(\x0cH\x00\x12\x17\n\rboolean_value\x18\x03 \x01(\x08H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x14\n\nbyte_value\x18\x05 \x01(\x05H\x00\x12\x15\n\x0bshort_value\x18\x06 \x01(\x05H\x00\x12\x13\n\tint_value\x18\x07 \x01(\x05H\x00\x12\x14\n\nlong_value\x18\x08 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\t \x01(\x02H\x00\x12\x16\n\x0c\x64ouble_value\x18\n \x01(\x01H\x00\x12\x14\n\njson_value\x18\x0b \x01(\tH\x00\x12\x14\n\navro_value\x18\x0c \x01(\x0cH\x00\x42\x0c\n\ntype_oneof"-\n\x06Header\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x05value\x18\x02 \x01(\x0b\x32\x06.Value"*\n\x06Schema\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x0c"\xb3\x01\n\x06Record\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x18\n\x03key\x18\x02 \x01(\x0b\x32\x06.ValueH\x00\x88\x01\x01\x12\x1a\n\x05value\x18\x03 \x01(\x0b\x32\x06.ValueH\x01\x88\x01\x01\x12\x18\n\x07headers\x18\x04 \x03(\x0b\x32\x07.Header\x12\x0e\n\x06origin\x18\x05 \x01(\t\x12\x16\n\ttimestamp\x18\x06 \x01(\x03H\x02\x88\x01\x01\x42\x06\n\x04_keyB\x08\n\x06_valueB\x0c\n\n_timestamp"E\n\x10ProcessorRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"O\n\x11ProcessorResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12!\n\x07results\x18\x02 \x03(\x0b\x32\x10.ProcessorResult"\\\n\x0fProcessorResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x07records\x18\x03 \x03(\x0b\x32\x07.RecordB\x08\n\x06_error2}\n\x0c\x41gentService\x12\x35\n\nagent_info\x12\x16.google.protobuf.Empty\x1a\r.InfoResponse"\x00\x12\x36\n\x07process\x12\x11.ProcessorRequest\x1a\x12.ProcessorResponse"\x00(\x01\x30\x01\x42\x1d\n\x19\x61i.langstream.agents.grpcP\x01\x62\x06proto3'
b'\n!langstream_grpc/proto/agent.proto\x1a\x1bgoogle/protobuf/empty.proto"!\n\x0cInfoResponse\x12\x11\n\tjson_info\x18\x01 \x01(\t"\xa3\x02\n\x05Value\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\x15\n\x0b\x62ytes_value\x18\x02 \x01(\x0cH\x00\x12\x17\n\rboolean_value\x18\x03 \x01(\x08H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x14\n\nbyte_value\x18\x05 \x01(\x05H\x00\x12\x15\n\x0bshort_value\x18\x06 \x01(\x05H\x00\x12\x13\n\tint_value\x18\x07 \x01(\x05H\x00\x12\x14\n\nlong_value\x18\x08 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\t \x01(\x02H\x00\x12\x16\n\x0c\x64ouble_value\x18\n \x01(\x01H\x00\x12\x14\n\njson_value\x18\x0b \x01(\tH\x00\x12\x14\n\navro_value\x18\x0c \x01(\x0cH\x00\x42\x0c\n\ntype_oneof"-\n\x06Header\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x05value\x18\x02 \x01(\x0b\x32\x06.Value"*\n\x06Schema\x12\x11\n\tschema_id\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x0c"\xb3\x01\n\x06Record\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x18\n\x03key\x18\x02 \x01(\x0b\x32\x06.ValueH\x00\x88\x01\x01\x12\x1a\n\x05value\x18\x03 \x01(\x0b\x32\x06.ValueH\x01\x88\x01\x01\x12\x18\n\x07headers\x18\x04 \x03(\x0b\x32\x07.Header\x12\x0e\n\x06origin\x18\x05 \x01(\t\x12\x16\n\ttimestamp\x18\x06 \x01(\x03H\x02\x88\x01\x01\x42\x06\n\x04_keyB\x08\n\x06_valueB\x0c\n\n_timestamp"*\n\rSourceRequest\x12\x19\n\x11\x63ommitted_records\x18\x01 \x03(\x03"C\n\x0eSourceResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"E\n\x10ProcessorRequest\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12\x18\n\x07records\x18\x02 \x03(\x0b\x32\x07.Record"O\n\x11ProcessorResponse\x12\x17\n\x06schema\x18\x01 \x01(\x0b\x32\x07.Schema\x12!\n\x07results\x18\x02 \x03(\x0b\x32\x10.ProcessorResult"\\\n\x0fProcessorResult\x12\x11\n\trecord_id\x18\x01 \x01(\x03\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x07records\x18\x03 \x03(\x0b\x32\x07.RecordB\x08\n\x06_error2\xac\x01\n\x0c\x41gentService\x12\x35\n\nagent_info\x12\x16.google.protobuf.Empty\x1a\r.InfoResponse"\x00\x12-\n\x04read\x12\x0e.SourceRequest\x1a\x0f.SourceResponse"\x00(\x01\x30\x01\x12\x36\n\x07process\x12\x11.ProcessorRequest\x1a\x12.ProcessorResponse"\x00(\x01\x30\x01\x42\x1d\n\x19\x61i.langstream.agents.grpcP\x01\x62\x06proto3'
)

_globals = globals()
Expand All @@ -53,12 +53,16 @@
_globals["_SCHEMA"]._serialized_end = 484
_globals["_RECORD"]._serialized_start = 487
_globals["_RECORD"]._serialized_end = 666
_globals["_PROCESSORREQUEST"]._serialized_start = 668
_globals["_PROCESSORREQUEST"]._serialized_end = 737
_globals["_PROCESSORRESPONSE"]._serialized_start = 739
_globals["_PROCESSORRESPONSE"]._serialized_end = 818
_globals["_PROCESSORRESULT"]._serialized_start = 820
_globals["_PROCESSORRESULT"]._serialized_end = 912
_globals["_AGENTSERVICE"]._serialized_start = 914
_globals["_AGENTSERVICE"]._serialized_end = 1039
_globals["_SOURCEREQUEST"]._serialized_start = 668
_globals["_SOURCEREQUEST"]._serialized_end = 710
_globals["_SOURCERESPONSE"]._serialized_start = 712
_globals["_SOURCERESPONSE"]._serialized_end = 779
_globals["_PROCESSORREQUEST"]._serialized_start = 781
_globals["_PROCESSORREQUEST"]._serialized_end = 850
_globals["_PROCESSORRESPONSE"]._serialized_start = 852
_globals["_PROCESSORRESPONSE"]._serialized_end = 931
_globals["_PROCESSORRESULT"]._serialized_start = 933
_globals["_PROCESSORRESULT"]._serialized_end = 1025
_globals["_AGENTSERVICE"]._serialized_start = 1028
_globals["_AGENTSERVICE"]._serialized_end = 1200
# @@protoc_insertion_point(module_scope)
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,24 @@ class Record(_message.Message):
timestamp: _Optional[int] = ...,
) -> None: ...

class SourceRequest(_message.Message):
__slots__ = ["committed_records"]
COMMITTED_RECORDS_FIELD_NUMBER: _ClassVar[int]
committed_records: _containers.RepeatedScalarFieldContainer[int]
def __init__(self, committed_records: _Optional[_Iterable[int]] = ...) -> None: ...

class SourceResponse(_message.Message):
__slots__ = ["schema", "records"]
SCHEMA_FIELD_NUMBER: _ClassVar[int]
RECORDS_FIELD_NUMBER: _ClassVar[int]
schema: Schema
records: _containers.RepeatedCompositeFieldContainer[Record]
def __init__(
self,
schema: _Optional[_Union[Schema, _Mapping]] = ...,
records: _Optional[_Iterable[_Union[Record, _Mapping]]] = ...,
) -> None: ...

class ProcessorRequest(_message.Message):
__slots__ = ["schema", "records"]
SCHEMA_FIELD_NUMBER: _ClassVar[int]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def __init__(self, channel):
request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
response_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.InfoResponse.FromString,
)
self.read = channel.stream_stream(
"/AgentService/read",
request_serializer=langstream__grpc_dot_proto_dot_agent__pb2.SourceRequest.SerializeToString,
response_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.SourceResponse.FromString,
)
self.process = channel.stream_stream(
"/AgentService/process",
request_serializer=langstream__grpc_dot_proto_dot_agent__pb2.ProcessorRequest.SerializeToString,
Expand All @@ -52,6 +57,12 @@ def agent_info(self, request, context):
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")

def read(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")

def process(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
Expand All @@ -66,6 +77,11 @@ def add_AgentServiceServicer_to_server(servicer, server):
request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
response_serializer=langstream__grpc_dot_proto_dot_agent__pb2.InfoResponse.SerializeToString,
),
"read": grpc.stream_stream_rpc_method_handler(
servicer.read,
request_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.SourceRequest.FromString,
response_serializer=langstream__grpc_dot_proto_dot_agent__pb2.SourceResponse.SerializeToString,
),
"process": grpc.stream_stream_rpc_method_handler(
servicer.process,
request_deserializer=langstream__grpc_dot_proto_dot_agent__pb2.ProcessorRequest.FromString,
Expand Down Expand Up @@ -111,6 +127,35 @@ def agent_info(
metadata,
)

@staticmethod
def read(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
target,
"/AgentService/read",
langstream__grpc_dot_proto_dot_agent__pb2.SourceRequest.SerializeToString,
langstream__grpc_dot_proto_dot_agent__pb2.SourceResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)

@staticmethod
def process(
request_iterator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def stub():
pytest.param(
"bytes_value", "bytes_value", b"test-value", b"test-key", b"test-header"
),
pytest.param(
"json_value",
"json_value",
'{"test": "value"}',
'{"test": "key"}',
'{"test": "header"}',
),
pytest.param("boolean_value", "boolean_value", True, False, True),
pytest.param("boolean_value", "boolean_value", False, True, False),
pytest.param("byte_value", "long_value", 42, 43, 44),
pytest.param("short_value", "long_value", 42, 43, 44),
pytest.param("int_value", "long_value", 42, 43, 44),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
#
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import queue
from io import BytesIO
from typing import List

import fastavro
import grpc
import pytest

from langstream_grpc.grpc_service import AgentServer
from langstream_grpc.proto.agent_pb2 import (
SourceResponse,
SourceRequest,
)
from langstream_grpc.proto.agent_pb2_grpc import AgentServiceStub
from langstream_runtime.api import Record, RecordType, Source
from langstream_runtime.util import SimpleRecord, AvroValue


@pytest.fixture(autouse=True)
def server_and_stub():
config = """{
"className": "langstream_grpc.tests.test_grpc_source.MySource"
}"""
server = AgentServer("[::]:0", config)
server.start()
channel = grpc.insecure_channel("localhost:%d" % server.port)

yield server, AgentServiceStub(channel=channel)

channel.close()
server.stop()


def test_read(server_and_stub):
server, stub = server_and_stub

responses: list[SourceResponse]
responses = list(stub.read(iter([])))

response_schema = responses[0]
assert len(response_schema.records) == 0
assert response_schema.HasField("schema")
assert response_schema.schema.schema_id == 1
schema = response_schema.schema.value.decode("utf-8")
assert (
schema
== '{"name":"test.Test","type":"record","fields":[{"name":"field","type":"string"}]}' # noqa: E501
)

response_record = responses[1]
assert len(response_schema.records) == 0
record = response_record.records[0]
assert record.record_id == 1
assert record.value.schema_id == 1
fp = BytesIO(record.value.avro_value)
try:
decoded = fastavro.schemaless_reader(fp, json.loads(schema))
assert decoded == {"field": "test"}
finally:
fp.close()


def test_commit(server_and_stub):
server, stub = server_and_stub
to_commit = queue.Queue()

def send_commit():
committed = 0
while committed < 2:
try:
commit_id = to_commit.get(True)
yield SourceRequest(committed_records=[commit_id])
committed += 1
except queue.Empty:
pass

with pytest.raises(grpc.RpcError):
response: SourceResponse
for response in stub.read(iter(send_commit())):
for record in response.records:
to_commit.put(record.record_id)

assert len(server.agent.committed) == 1
assert server.agent.committed[0] == server.agent.sent[0]


class MySource(Source):
def __init__(self):
self.records = [
SimpleRecord(
value=AvroValue(
schema={
"type": "record",
"name": "Test",
"namespace": "test",
"fields": [{"name": "field", "type": {"type": "string"}}],
},
value={"field": "test"},
)
),
SimpleRecord(value=42),
]
self.sent = []
self.committed = []

def read(self) -> List[RecordType]:
if len(self.records) > 0:
record = self.records.pop(0)
self.sent.append(record)
return [record]
return []

def commit(self, records: List[Record]):
for record in records:
if record.value() == 42:
raise Exception("test error")
self.committed.extend(records)

0 comments on commit dc9095d

Please sign in to comment.