diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/api.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/api.py index 855a5cec0..298cc3574 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/api.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/api.py @@ -63,12 +63,28 @@ def headers(self) -> List[Tuple[str, Any]]: RecordType = Union[Record, dict, list, tuple] +class TopicProducer(ABC): + """The topic producer interface""" + + async def awrite(self, topic: str, record: Record): + """Write a record to a topic (for async methods).""" + pass + + def write(self, topic: str, record: Record) -> Future: + """Write a record to a topic (for non-async methods).""" + pass + + class AgentContext(ABC): """The Agent context interface""" - @abstractmethod def get_persistent_state_directory(self) -> Optional[str]: """Return the path of the agent disk. Return None if not configured.""" + return None + + @abstractmethod + def get_topic_producer(self) -> TopicProducer: + """Return the topic producer""" pass diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py index 988fd4532..089471550 100644 --- a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/grpc_service.py @@ -22,7 +22,7 @@ import threading from concurrent.futures import Future from io import BytesIO -from typing import Union, List, Tuple, Any, Optional, AsyncIterable +from typing import Union, List, Tuple, Any, Optional, AsyncIterable, Dict import fastavro import grpc @@ -42,9 +42,11 @@ SourceResponse, SinkRequest, SinkResponse, + TopicProducerRecord, + TopicProducerWriteResult, ) from langstream_grpc.proto.agent_pb2_grpc import AgentServiceServicer -from .api import Source, Sink, Processor, Record, Agent, AgentContext +from .api import Source, Sink, Processor, Record, Agent, AgentContext, TopicProducer from .util import SimpleRecord, AvroValue @@ -71,20 +73,52 @@ def wrap_in_record(record): class AgentService(AgentServiceServicer): - def __init__(self, agent: Union[Agent, Source, Sink, Processor]): + def __init__( + self, agent: Union[Agent, Source, Sink, Processor], topic_producer_records + ): self.agent = agent self.schema_id = 0 self.schemas = {} self.client_schemas = {} + self.topic_producer_record_id = 0 + self.topic_producer_records = topic_producer_records + self.topic_producer_records_pending: Dict[int, asyncio.Future] = {} async def agent_info(self, _, __): info = await acall_method_if_exists(self.agent, "agent_info") or {} return InfoResponse(json_info=json.dumps(info)) - async def get_topic_producer_records(self, request_iterator, context): - # TODO: to be implemented - async for _ in request_iterator: - yield + async def poll_topic_producer_records(self, context): + while True: + topic, record, future = await self.topic_producer_records.get() + # TODO: handle schemas + _, grpc_record = self.to_grpc_record(record) + self.topic_producer_record_id += 1 + self.topic_producer_records_pending[self.topic_producer_record_id] = future + grpc_record.record_id = self.topic_producer_record_id + await context.write(TopicProducerRecord(topic=topic, record=grpc_record)) + + async def handle_write_results(self, context): + write_result = await context.read() + while write_result != grpc.aio.EOF: + future = self.topic_producer_records_pending[write_result.record_id] + if write_result.error: + future.set_exception(RuntimeError(write_result.error)) + else: + future.set_result(None) + write_result = await context.read() + + async def get_topic_producer_records( + self, requests: AsyncIterable[TopicProducerWriteResult], context + ): + poll_task = asyncio.create_task(self.poll_topic_producer_records(context)) + write_result_task = asyncio.create_task(self.handle_write_results(context)) + done, pending = await asyncio.wait( + [poll_task, write_result_task], return_when=asyncio.FIRST_COMPLETED + ) + pending.pop().cancel() + # propagate exception if needed + done.pop().result() async def do_read(self, context, read_records): last_record_id = 0 @@ -335,25 +369,47 @@ def crash_process(): os.exit(1) -async def init_agent(configuration, context) -> Agent: +async def init_agent(configuration, context, topic_producer_records) -> Agent: full_class_name = configuration["className"] class_name = full_class_name.split(".")[-1] module_name = full_class_name[: -len(class_name) - 1] module = importlib.import_module(module_name) agent = getattr(module, class_name)() - context_impl = DefaultAgentContext(configuration, context) + context_impl = DefaultAgentContext(configuration, context, topic_producer_records) await acall_method_if_exists(agent, "init", configuration, context_impl) return agent +class DefaultTopicProducer(TopicProducer): + def __init__( + self, topic_producer_records: asyncio.Queue[Tuple[str, Record, asyncio.Future]] + ): + self.topic_producer_records = topic_producer_records + self.event_loop = asyncio.get_running_loop() + + async def awrite(self, topic: str, record: Record): + write_future = self.event_loop.create_future() + await self.topic_producer_records.put((topic, record, write_future)) + return await write_future + + def write(self, topic: str, record: Record) -> Future: + return asyncio.run_coroutine_threadsafe( + self.awrite(topic, record), self.event_loop + ) + + class DefaultAgentContext(AgentContext): - def __init__(self, configuration: dict, context: dict): + def __init__(self, configuration: dict, context: dict, topic_producer_records): self.configuration = configuration self.context = context + self.topic_producer = DefaultTopicProducer(topic_producer_records) def get_persistent_state_directory(self) -> Optional[str]: return self.context.get("persistentStateDirectory") + def get_topic_producer(self) -> TopicProducer: + return self.topic_producer + class AgentServer(object): def __init__(self, target: str): @@ -372,16 +428,20 @@ async def init(self, config, context): value = env["value"] logging.debug(f"Setting environment variable {key}={value}") os.environ[key] = value - self.agent = await init_agent(configuration, json.loads(context)) - async def start(self): - await acall_method_if_exists(self.agent, "start") - call_method_new_thread_if_exists(self.agent, "main", crash_process) + topic_producer_records = asyncio.Queue(1000) + self.agent = await init_agent( + configuration, json.loads(context), topic_producer_records + ) agent_pb2_grpc.add_AgentServiceServicer_to_server( - AgentService(self.agent), self.grpc_server + AgentService(self.agent, topic_producer_records), self.grpc_server ) + async def start(self): + await acall_method_if_exists(self.agent, "start") + call_method_new_thread_if_exists(self.agent, "main", crash_process) + await self.grpc_server.start() logging.info("GRPC Server started, listening on " + self.target) diff --git a/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_topic_producer.py b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_topic_producer.py new file mode 100644 index 000000000..550dfdd22 --- /dev/null +++ b/langstream-runtime/langstream-runtime-impl/src/main/python/langstream_grpc/tests/test_grpc_topic_producer.py @@ -0,0 +1,125 @@ +# +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List, Dict, Any, Optional + +import grpc +import pytest + +from langstream_grpc.api import Record, Processor, AgentContext +from langstream_grpc.proto.agent_pb2 import ( + Record as GrpcRecord, + ProcessorRequest, + Value, + TopicProducerWriteResult, +) +from langstream_grpc.tests.server_and_stub import ServerAndStub + + +@pytest.mark.parametrize("klass", ["MyProcessor", "MyAsyncProcessor"]) +async def test_topic_producer_success(klass): + async with ServerAndStub( + f"langstream_grpc.tests.test_grpc_topic_producer.{klass}" + ) as server_and_stub: + process_call = server_and_stub.stub.process() + await process_call.write( + ProcessorRequest(records=[GrpcRecord(value=Value(string_value="test"))]) + ) + + topic_producer_call = server_and_stub.stub.get_topic_producer_records() + topic_producer_record = await topic_producer_call.read() + + assert topic_producer_record.topic == "topic-producer-topic" + assert topic_producer_record.record.value.string_value == "test" + await topic_producer_call.write( + TopicProducerWriteResult(record_id=topic_producer_record.record.record_id) + ) + + await topic_producer_call.done_writing() + + processed = await process_call.read() + assert processed.results[0].records[0].value.string_value == "test" + + await process_call.done_writing() + + +@pytest.mark.parametrize("klass", ["MyProcessor", "MyAsyncProcessor"]) +async def test_topic_producer_write_error(klass): + async with ServerAndStub( + f"langstream_grpc.tests.test_grpc_topic_producer.{klass}" + ) as server_and_stub: + process_call = server_and_stub.stub.process() + await process_call.write( + ProcessorRequest(records=[GrpcRecord(value=Value(string_value="test"))]) + ) + + topic_producer_call = server_and_stub.stub.get_topic_producer_records() + topic_producer_record = await topic_producer_call.read() + + assert topic_producer_record.topic == "topic-producer-topic" + assert topic_producer_record.record.value.string_value == "test" + await topic_producer_call.write( + TopicProducerWriteResult( + record_id=topic_producer_record.record.record_id, error="test-error" + ) + ) + + await topic_producer_call.done_writing() + + response = await process_call.read() + assert "test-error" in response.results[0].error + + await process_call.done_writing() + + +async def test_topic_producer_invalid(): + async with ServerAndStub( + "langstream_grpc.tests.test_grpc_topic_producer.MyFailingProcessor" + ) as server_and_stub: + process_call = server_and_stub.stub.process() + await process_call.write( + ProcessorRequest(records=[GrpcRecord(value=Value(string_value="test"))]) + ) + + topic_producer_call = server_and_stub.stub.get_topic_producer_records() + with pytest.raises(grpc.RpcError): + await topic_producer_call.read() + + +class MyProcessor(Processor): + def __init__(self): + self.context: Optional[AgentContext] = None + + def init(self, config: Dict[str, Any], context: AgentContext): + self.context = context + + def process(self, record: Record) -> List[Record]: + self.context.get_topic_producer().write("topic-producer-topic", record).result() + return [record] + + +class MyAsyncProcessor(MyProcessor): + async def process(self, record: Record) -> List[Record]: + await self.context.get_topic_producer().awrite("topic-producer-topic", record) + return [record] + + +class MyFailingProcessor(MyProcessor): + async def process(self, record: Record) -> List[Record]: + await self.context.get_topic_producer().awrite( + "topic-producer-topic", "invalid" + ) + return [record]