Skip to content

Commit

Permalink
Add topic producer API Python side (#733)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet authored Nov 22, 2023
1 parent 598caf7 commit 00e4415
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,28 @@ def headers(self) -> List[Tuple[str, Any]]:
RecordType = Union[Record, dict, list, tuple]


class TopicProducer(ABC):
    """Interface for objects that can write records to a topic.

    The base implementations do nothing and return ``None``.
    """

    async def awrite(self, topic: str, record: Record):
        """Write a record to a topic; variant to call from async methods."""

    def write(self, topic: str, record: Record) -> Future:
        """Write a record to a topic; variant to call from non-async methods."""


class AgentContext(ABC):
    """Interface of the agent context."""

    @abstractmethod
    def get_persistent_state_directory(self) -> Optional[str]:
        """Return the path of the agent disk, or None if it is not configured."""

    @abstractmethod
    def get_topic_producer(self) -> TopicProducer:
        """Return the topic producer."""


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import threading
from concurrent.futures import Future
from io import BytesIO
from typing import Union, List, Tuple, Any, Optional, AsyncIterable
from typing import Union, List, Tuple, Any, Optional, AsyncIterable, Dict

import fastavro
import grpc
Expand All @@ -42,9 +42,11 @@
SourceResponse,
SinkRequest,
SinkResponse,
TopicProducerRecord,
TopicProducerWriteResult,
)
from langstream_grpc.proto.agent_pb2_grpc import AgentServiceServicer
from .api import Source, Sink, Processor, Record, Agent, AgentContext
from .api import Source, Sink, Processor, Record, Agent, AgentContext, TopicProducer
from .util import SimpleRecord, AvroValue


Expand All @@ -71,20 +73,52 @@ def wrap_in_record(record):


class AgentService(AgentServiceServicer):
def __init__(
    self, agent: Union[Agent, Source, Sink, Processor], topic_producer_records
):
    """Initialise the servicer state for ``agent``.

    ``topic_producer_records`` is the queue from which
    ``poll_topic_producer_records`` takes ``(topic, record, future)`` items.
    """
    self.agent = agent

    # Schema-related state.
    self.schema_id = 0
    self.schemas = {}
    self.client_schemas = {}

    # Topic producer state: last record id handed out, the source queue and
    # the futures still waiting for a write result, keyed by record id.
    self.topic_producer_record_id = 0
    self.topic_producer_records = topic_producer_records
    self.topic_producer_records_pending: Dict[int, asyncio.Future] = {}

async def agent_info(self, _, __):
    """Return the agent info, JSON-encoded, wrapped in an ``InfoResponse``.

    An empty JSON object is sent when the ``agent_info`` call on the agent
    yields a falsy value.
    """
    info = await acall_method_if_exists(self.agent, "agent_info")
    if not info:
        info = {}
    return InfoResponse(json_info=json.dumps(info))

async def get_topic_producer_records(self, request_iterator, context):
# TODO: to be implemented
async for _ in request_iterator:
yield
async def poll_topic_producer_records(self, context):
    """Forward queued topic producer records to the gRPC stream.

    Loops forever: takes ``(topic, record, future)`` items from
    ``self.topic_producer_records``, converts each record to its gRPC form,
    gives it the next record id, registers ``future`` under that id in
    ``self.topic_producer_records_pending`` and writes the record to
    ``context``.

    If the conversion or the write raises, the pending entry is removed and
    the error is set on ``future`` before being re-raised, so that the
    producer awaiting the future is not left waiting forever.
    """
    while True:
        topic, record, future = await self.topic_producer_records.get()
        record_id = None
        try:
            # TODO: handle schemas
            _, grpc_record = self.to_grpc_record(record)
            self.topic_producer_record_id += 1
            record_id = self.topic_producer_record_id
            self.topic_producer_records_pending[record_id] = future
            grpc_record.record_id = record_id
            await context.write(TopicProducerRecord(topic=topic, record=grpc_record))
        except Exception as error:
            # No write result will ever arrive for this record: drop the
            # bookkeeping entry and fail the waiting producer explicitly.
            if record_id is not None:
                self.topic_producer_records_pending.pop(record_id, None)
            if not future.done():
                future.set_exception(error)
            raise

async def handle_write_results(self, context):
    """Resolve pending producer futures from the incoming write results.

    Reads write results from ``context`` until ``grpc.aio.EOF``. For each
    result the future registered under ``record_id`` is removed from
    ``self.topic_producer_records_pending`` and completed: with a
    ``RuntimeError`` carrying the reported error when ``error`` is set,
    with ``None`` otherwise.

    Results whose id is not pending (unknown or already handled) are logged
    and skipped; futures that are already done (e.g. cancelled) are left
    untouched.
    """
    while (write_result := await context.read()) != grpc.aio.EOF:
        # Pop rather than look up, so handled entries do not accumulate.
        future = self.topic_producer_records_pending.pop(
            write_result.record_id, None
        )
        if future is None:
            logging.warning(
                f"Ignoring write result for unknown record id {write_result.record_id}"
            )
            continue
        if future.done():
            # Completing a finished future would raise InvalidStateError
            # and take the whole stream down.
            continue
        if write_result.error:
            future.set_exception(RuntimeError(write_result.error))
        else:
            future.set_result(None)

async def get_topic_producer_records(
    self, requests: AsyncIterable[TopicProducerWriteResult], context
):
    """Run the topic producer stream until one of its two halves ends.

    Starts ``poll_topic_producer_records`` (outgoing records) and
    ``handle_write_results`` (incoming write results) as tasks on
    ``context`` and waits for the first one to finish. Whatever is still
    running is then cancelled, also when this coroutine itself is
    cancelled while waiting. ``requests`` is not used: results are read
    from ``context`` directly.

    Raises:
        Any exception raised by a task that finished.
    """
    tasks = [
        asyncio.create_task(self.poll_topic_producer_records(context)),
        asyncio.create_task(self.handle_write_results(context)),
    ]
    try:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Both tasks may already be done, so never assume exactly one is
        # still pending.
        for task in tasks:
            if not task.done():
                task.cancel()
    # propagate exception if needed
    for task in done:
        task.result()

async def do_read(self, context, read_records):
last_record_id = 0
Expand Down Expand Up @@ -335,25 +369,47 @@ def crash_process():
os.exit(1)


async def init_agent(configuration, context, topic_producer_records) -> Agent:
    """Create the agent named by ``configuration["className"]`` and init it.

    The value is split at its last dot into a module name and a class name;
    the module is imported and the class instantiated without arguments.
    The agent's ``init`` is then invoked through ``acall_method_if_exists``
    with the configuration and a ``DefaultAgentContext``.
    """
    module_name, _, class_name = configuration["className"].rpartition(".")
    agent_class = getattr(importlib.import_module(module_name), class_name)
    agent = agent_class()
    context_impl = DefaultAgentContext(configuration, context, topic_producer_records)
    await acall_method_if_exists(agent, "init", configuration, context_impl)
    return agent


class DefaultTopicProducer(TopicProducer):
    """Topic producer that hands records over through an asyncio queue.

    Every write puts a ``(topic, record, future)`` tuple on the queue and
    completes once that future has been resolved.
    """

    def __init__(
        self, topic_producer_records: asyncio.Queue[Tuple[str, Record, asyncio.Future]]
    ):
        self.topic_producer_records = topic_producer_records
        # Loop running at construction time; both write variants rely on it.
        self.event_loop = asyncio.get_running_loop()

    async def awrite(self, topic: str, record: Record):
        """Queue the record and wait for its write future to be resolved."""
        completion = self.event_loop.create_future()
        item = (topic, record, completion)
        await self.topic_producer_records.put(item)
        return await completion

    def write(self, topic: str, record: Record) -> Future:
        """Schedule ``awrite`` on the captured event loop.

        Returns the future produced by ``asyncio.run_coroutine_threadsafe``.
        """
        coroutine = self.awrite(topic, record)
        return asyncio.run_coroutine_threadsafe(coroutine, self.event_loop)


class DefaultAgentContext(AgentContext):
    """Agent context backed by the configuration and context dictionaries."""

    def __init__(self, configuration: dict, context: dict, topic_producer_records):
        self.configuration = configuration
        self.context = context
        self.topic_producer = DefaultTopicProducer(topic_producer_records)

    def get_persistent_state_directory(self) -> Optional[str]:
        """Return the ``persistentStateDirectory`` entry of the context, if any."""
        state_directory = self.context.get("persistentStateDirectory")
        return state_directory

    def get_topic_producer(self) -> TopicProducer:
        """Return the topic producer created at construction time."""
        return self.topic_producer


class AgentServer(object):
def __init__(self, target: str):
Expand All @@ -372,16 +428,20 @@ async def init(self, config, context):
value = env["value"]
logging.debug(f"Setting environment variable {key}={value}")
os.environ[key] = value
self.agent = await init_agent(configuration, json.loads(context))

async def start(self):
await acall_method_if_exists(self.agent, "start")
call_method_new_thread_if_exists(self.agent, "main", crash_process)
topic_producer_records = asyncio.Queue(1000)
self.agent = await init_agent(
configuration, json.loads(context), topic_producer_records
)

agent_pb2_grpc.add_AgentServiceServicer_to_server(
AgentService(self.agent), self.grpc_server
AgentService(self.agent, topic_producer_records), self.grpc_server
)

async def start(self):
await acall_method_if_exists(self.agent, "start")
call_method_new_thread_if_exists(self.agent, "main", crash_process)

await self.grpc_server.start()
logging.info("GRPC Server started, listening on " + self.target)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List, Dict, Any, Optional

import grpc
import pytest

from langstream_grpc.api import Record, Processor, AgentContext
from langstream_grpc.proto.agent_pb2 import (
Record as GrpcRecord,
ProcessorRequest,
Value,
TopicProducerWriteResult,
)
from langstream_grpc.tests.server_and_stub import ServerAndStub


@pytest.mark.parametrize("klass", ["MyProcessor", "MyAsyncProcessor"])
async def test_topic_producer_success(klass):
    """A successful write result lets the processor return its record."""
    agent_class = f"langstream_grpc.tests.test_grpc_topic_producer.{klass}"
    async with ServerAndStub(agent_class) as server_and_stub:
        stub = server_and_stub.stub

        # Send one record to process.
        process_call = stub.process()
        request = ProcessorRequest(
            records=[GrpcRecord(value=Value(string_value="test"))]
        )
        await process_call.write(request)

        # The processor's write shows up on the topic producer stream.
        topic_producer_call = stub.get_topic_producer_records()
        produced = await topic_producer_call.read()
        assert produced.topic == "topic-producer-topic"
        assert produced.record.value.string_value == "test"

        # Report the write as successful (no error set).
        write_result = TopicProducerWriteResult(record_id=produced.record.record_id)
        await topic_producer_call.write(write_result)
        await topic_producer_call.done_writing()

        processed = await process_call.read()
        assert processed.results[0].records[0].value.string_value == "test"

        await process_call.done_writing()


@pytest.mark.parametrize("klass", ["MyProcessor", "MyAsyncProcessor"])
async def test_topic_producer_write_error(klass):
    """A write result carrying an error surfaces in the processing result."""
    agent_class = f"langstream_grpc.tests.test_grpc_topic_producer.{klass}"
    async with ServerAndStub(agent_class) as server_and_stub:
        stub = server_and_stub.stub

        # Send one record to process.
        process_call = stub.process()
        request = ProcessorRequest(
            records=[GrpcRecord(value=Value(string_value="test"))]
        )
        await process_call.write(request)

        # The processor's write shows up on the topic producer stream.
        topic_producer_call = stub.get_topic_producer_records()
        produced = await topic_producer_call.read()
        assert produced.topic == "topic-producer-topic"
        assert produced.record.value.string_value == "test"

        # Report the write as failed.
        write_result = TopicProducerWriteResult(
            record_id=produced.record.record_id, error="test-error"
        )
        await topic_producer_call.write(write_result)
        await topic_producer_call.done_writing()

        response = await process_call.read()
        assert "test-error" in response.results[0].error

        await process_call.done_writing()


async def test_topic_producer_invalid():
    """Reading the producer stream fails when the agent writes a non-record."""
    agent_class = "langstream_grpc.tests.test_grpc_topic_producer.MyFailingProcessor"
    async with ServerAndStub(agent_class) as server_and_stub:
        stub = server_and_stub.stub

        process_call = stub.process()
        request = ProcessorRequest(
            records=[GrpcRecord(value=Value(string_value="test"))]
        )
        await process_call.write(request)

        topic_producer_call = stub.get_topic_producer_records()
        with pytest.raises(grpc.RpcError):
            await topic_producer_call.read()


class MyProcessor(Processor):
    """Processor writing each record to a topic with the blocking API."""

    def __init__(self):
        self.context: Optional[AgentContext] = None

    def init(self, config: Dict[str, Any], context: AgentContext):
        # Keep the context around: process() gets the topic producer from it.
        self.context = context

    def process(self, record: Record) -> List[Record]:
        producer = self.context.get_topic_producer()
        write_future = producer.write("topic-producer-topic", record)
        # Block until the write has been resolved (raises if it failed).
        write_future.result()
        return [record]


class MyAsyncProcessor(MyProcessor):
    """Same as ``MyProcessor`` but writing through the async API."""

    async def process(self, record: Record) -> List[Record]:
        producer = self.context.get_topic_producer()
        await producer.awrite("topic-producer-topic", record)
        return [record]


class MyFailingProcessor(MyProcessor):
    """Processor that writes a plain string instead of a record."""

    async def process(self, record: Record) -> List[Record]:
        producer = self.context.get_topic_producer()
        # "invalid" is deliberately not a Record.
        await producer.awrite("topic-producer-topic", "invalid")
        return [record]

0 comments on commit 00e4415

Please sign in to comment.