diff --git a/python-sdk/indexify/cli.py b/python-sdk/indexify/cli.py index 37484e27e..36db65be2 100644 --- a/python-sdk/indexify/cli.py +++ b/python-sdk/indexify/cli.py @@ -45,9 +45,6 @@ console = Console(theme=custom_theme) app = typer.Typer(pretty_exceptions_enable=False, no_args_is_help=True) -config_path_option: Optional[str] = typer.Option( - None, help="Path to the TLS configuration file" -) @app.command( @@ -209,7 +206,9 @@ def executor( dev: Annotated[ bool, typer.Option("--dev", "-d", help="Run the executor in development mode") ] = False, - config_path: Optional[str] = config_path_option, + config_path: Optional[str] = typer.Option( + None, help="Path to the TLS configuration file" + ), executor_cache: Optional[str] = typer.Option( "~/.indexify/executor_cache", help="Path to the executor cache directory" ), @@ -264,11 +263,9 @@ def function_executor( function_executor_server_address: str = typer.Option( help="Function Executor server address" ), - indexify_server_address: str = typer.Option(help="Indexify server address"), dev: Annotated[ bool, typer.Option("--dev", "-d", help="Run the executor in development mode") ] = False, - config_path: Optional[str] = config_path_option, ): if not dev: configure_production_logging() @@ -276,15 +273,11 @@ def function_executor( logger.info( "starting function executor server", function_executor_server_address=function_executor_server_address, - indexify_server_address=indexify_server_address, - config_path=config_path, ) FunctionExecutorServer( server_address=function_executor_server_address, - service=FunctionExecutorService( - indexify_server_address=indexify_server_address, config_path=config_path - ), + service=FunctionExecutorService(), ).run() diff --git a/python-sdk/indexify/executor/downloader.py b/python-sdk/indexify/executor/downloader.py index 9c4cee09d..7feb8a0ff 100644 --- a/python-sdk/indexify/executor/downloader.py +++ b/python-sdk/indexify/executor/downloader.py @@ -142,7 +142,7 @@ async def _fetch_url( self, url: str, resource_description: str, logger: Any ) -> SerializedObject: logger.info(f"fetching {resource_description}", url=url) - response = await self._client.get(url) + response: httpx.Response = await self._client.get(url) try: response.raise_for_status() except httpx.HTTPStatusError as e: @@ -153,13 +153,17 @@ async def _fetch_url( ) raise - # We're hardcoding the content type currently used by Python SDK. It might change in the future. - # There's no other way for now to determine if the response is a bytes or string. - if response.headers["content-type"] == "application/octet-stream": - return SerializedObject( - bytes=response.content, content_type=response.headers["content-type"] - ) - else: - return SerializedObject( - string=response.text, content_type=response.headers["content-type"] - ) + return serialized_object_from_http_response(response) + + +def serialized_object_from_http_response(response: httpx.Response) -> SerializedObject: + # We're hardcoding the content type currently used by Python SDK. It might change in the future. + # There's no other way for now to determine if the response is a bytes or string. + if response.headers["content-type"] == "application/octet-stream": + return SerializedObject( + bytes=response.content, content_type=response.headers["content-type"] + ) + else: + return SerializedObject( + string=response.text, content_type=response.headers["content-type"] + ) diff --git a/python-sdk/indexify/executor/executor.py b/python-sdk/indexify/executor/executor.py index 91b46a0b3..7e4a935a5 100644 --- a/python-sdk/indexify/executor/executor.py +++ b/python-sdk/indexify/executor/executor.py @@ -42,16 +42,16 @@ def __init__( self._logger.info("running the extractor with TLS enabled") protocol = "https" + self._server_addr = server_addr + self._base_url = f"{protocol}://{self._server_addr}" + self._code_path = code_path self._function_worker = FunctionWorker( function_executor_factory=ProcessFunctionExecutorFactory( - indexify_server_address=server_addr, development_mode=development_mode, - config_path=config_path, - ) + ), + base_url=self._base_url, + config_path=config_path, ) - self._server_addr = server_addr - self._base_url = f"{protocol}://{self._server_addr}" - self._code_path = code_path self._downloader = Downloader( code_path=code_path, base_url=self._base_url, config_path=config_path ) diff --git a/python-sdk/indexify/executor/function_executor/function_executor_map.py b/python-sdk/indexify/executor/function_executor/function_executor_map.py index ffeda06c8..5d7eed77a 100644 --- a/python-sdk/indexify/executor/function_executor/function_executor_map.py +++ b/python-sdk/indexify/executor/function_executor/function_executor_map.py @@ -3,6 +3,7 @@ import grpc +from indexify.common_util import get_httpx_client from indexify.function_executor.proto.function_executor_pb2 import ( InitializeRequest, InitializeResponse, @@ -13,6 +14,7 @@ from .function_executor import FunctionExecutor from .function_executor_factory import FunctionExecutorFactory +from .invocation_state_client import InvocationStateClient class FunctionExecutorMap: @@ -21,10 +23,17 @@ class FunctionExecutorMap: The map is safe to use by multiple couroutines running in event loop on the same thread but it's not thread safe (can't be used from different threads concurrently).""" - def __init__(self, factory: FunctionExecutorFactory): + def __init__( + self, + factory: FunctionExecutorFactory, + base_url: str, + config_path: Optional[str], + ): self._factory = factory + self._base_url = base_url + self._config_path = config_path # Map of initialized Function executors ready to run tasks. - # function ID -> FunctionExecutor + # Function ID -> FunctionExecutor. self._executors: Dict[str, FunctionExecutor] = {} # We have to do all operations under this lock because we need to ensure # that we don't create more Function Executors than required. This is important @@ -52,6 +61,7 @@ async def get_or_create( return self._executors[id] executor: Optional[FunctionExecutor] = None + invocation_state_client: Optional[InvocationStateClient] = None try: executor = await self._factory.create(logger, state=initial_state) channel: grpc.aio.Channel = await executor.channel() @@ -61,7 +71,22 @@ async def get_or_create( ) if not initialize_response.success: raise Exception("initialize RPC failed at function executor") + invocation_state_client = InvocationStateClient( + stub=stub, + base_url=self._base_url, + http_client=get_httpx_client( + config_path=self._config_path, make_async=True + ), + graph=initialize_request.graph_name, + namespace=initialize_request.namespace, + logger=logger, + ) + await invocation_state_client.start() + # This is dirty but requires refactoring to implement properly. + initial_state.invocation_state_client = invocation_state_client except Exception: + if invocation_state_client is not None: + await invocation_state_client.destroy() if executor is not None: await self._factory.destroy(executor=executor, logger=logger) # Function Executor creation or initialization failed. @@ -82,10 +107,14 @@ async def delete( # Function Executor was already deleted or replaced and the caller is not aware of this. return del self._executors[id] + if function_executor.state().invocation_state_client is not None: + await function_executor.state().invocation_state_client.destroy() await self._factory.destroy(executor=function_executor, logger=logger) async def clear(self, logger): async with self._executors_lock: while self._executors: id, function_executor = self._executors.popitem() + if function_executor.state().invocation_state_client is not None: + await function_executor.state().invocation_state_client.destroy() await self._factory.destroy(function_executor, logger) diff --git a/python-sdk/indexify/executor/function_executor/invocation_state_client.py b/python-sdk/indexify/executor/function_executor/invocation_state_client.py new file mode 100644 index 000000000..283e0a5e3 --- /dev/null +++ b/python-sdk/indexify/executor/function_executor/invocation_state_client.py @@ -0,0 +1,232 @@ +import asyncio +from typing import Any, AsyncGenerator, Optional, Union + +import grpc +import httpx + +from indexify.executor.downloader import serialized_object_from_http_response +from indexify.function_executor.proto.function_executor_pb2 import ( + GetInvocationStateResponse, + InvocationStateRequest, + InvocationStateResponse, + SerializedObject, + SetInvocationStateResponse, +) +from indexify.function_executor.proto.function_executor_pb2_grpc import ( + FunctionExecutorStub, +) +from indexify.function_executor.proto.message_validator import MessageValidator + + +class InvocationStateClient: + """InvocationStateClient is a client for the invocation state server of a Function Executor. + + The client initializes the Function Executor's invocation state server and executes requests + it sends to the client. + """ + + def __init__( + self, + stub: FunctionExecutorStub, + base_url: str, + http_client: httpx.AsyncClient, + graph: str, + namespace: str, + logger: Any, + ): + self._stub: FunctionExecutorStub = stub + self._base_url: str = base_url + self._http_client: httpx.AsyncClient = http_client + self._graph: str = graph + self._namespace: str = namespace + self._logger: Any = logger.bind( + module=__name__, graph=graph, namespace=namespace + ) + self._client_response_queue: asyncio.Queue[ + Union[InvocationStateResponse, str] + ] = asyncio.Queue() + self._task_id_to_invocation_id: dict[str, str] = {} + self._request_loop_task: Optional[asyncio.Task] = None + + async def start(self) -> None: + """Starts the invocation state client. + + This method initializes the Function Executor's invocation state server first. + This is why this method needs to be awaited before executing any tasks on the Function Executor + that might use invocation state feature.""" + server_requests = self._stub.initialize_invocation_state_server( + self._response_generator() + ) + self._request_loop_task = asyncio.create_task( + self._request_loop(server_requests) + ) + + def add_task_to_invocation_id_entry(self, task_id: str, invocation_id: str) -> None: + """Adds a task ID to invocation ID entry to the client's internal state. + + This allows to authorize requests to the invocation state server. + If a request is not comming from the task ID that was added here then it will + be rejected. It's caller's responsibility to only add task IDs that are being + executed by the Function Executor so the Function Executor can't get access to + invocation state of tasks it doesn't run.""" + self._task_id_to_invocation_id[task_id] = invocation_id + + def remove_task_to_invocation_id_entry(self, task_id: str) -> None: + del self._task_id_to_invocation_id[task_id] + + async def destroy(self) -> None: + if self._request_loop_task is not None: + self._request_loop_task.cancel() + await self._client_response_queue.put("shutdown") + + async def _request_loop( + self, server_requests: AsyncGenerator[InvocationStateRequest, None] + ) -> None: + try: + async for request in server_requests: + await self._process_request_no_raise(request) + except grpc.aio.AioRpcError: + # Reading from the stream failed. + # This is a normal situation when the server is shutting down. + pass + except asyncio.CancelledError: + # This async task was cancelled by destroy(). Normal situation too. + pass + + async def _process_request_no_raise(self, request: InvocationStateRequest) -> None: + try: + await self._process_request(request) + except Exception as e: + try: + await self._client_response_queue.put( + InvocationStateResponse( + request_id=request.request_id, + success=False, + ) + ) + except Exception as ee: + self._logger.error("failed to send error response", exc_info=ee) + + self._logger.error( + "failed to process request", + exc_info=e, + request_id=request.request_id, + ) + + async def _process_request( + self, request: InvocationStateRequest + ) -> InvocationStateResponse: + self._validate_request(request) + # This is a very important check. We don't trust invocation ID and task ID + # supplied by Function Executor. If a task ID entry doesn't exist then it's + # a privelege escalation attempt. + invocation_id: str = self._task_id_to_invocation_id[request.task_id] + if request.HasField("get"): + value: Optional[SerializedObject] = await self._get_server_state( + invocation_id, request.get.key + ) + await self._client_response_queue.put( + InvocationStateResponse( + request_id=request.request_id, + success=True, + get=GetInvocationStateResponse( + key=request.get.key, + value=value, + ), + ) + ) + elif request.HasField("set"): + await self._set_server_state( + invocation_id, request.set.key, request.set.value + ) + await self._client_response_queue.put( + InvocationStateResponse( + request_id=request.request_id, + success=True, + set=SetInvocationStateResponse(), + ) + ) + + async def _response_generator( + self, + ) -> AsyncGenerator[InvocationStateResponse, None]: + while True: + response = await self._client_response_queue.get() + # Hacky cancellation of the generator. + if response == "shutdown": + break + yield response + + async def _set_server_state( + self, invocation_id: str, key: str, value: SerializedObject + ) -> None: + url: str = ( + f"{self._base_url}/internal/namespaces/{self._namespace}/compute_graphs/{self._graph}/invocations/{invocation_id}/ctx/{key}" + ) + payload = value.bytes if value.HasField("bytes") else value.string + + response = await self._http_client.post( + url=url, + files=[ + ( + "value", + ("value", payload, value.content_type), + ), + ], + ) + + try: + response.raise_for_status() + except Exception as e: + self._logger.error( + "failed to set graph invocation state", + invocation_id=invocation_id, + key=key, + status_code=response.status_code, + error=response.text, + exc_info=e, + ) + raise + + async def _get_server_state( + self, invocation_id: str, key: str + ) -> Optional[SerializedObject]: + url: str = ( + f"{self._base_url}/internal/namespaces/{self._namespace}/compute_graphs/{self._graph}/invocations/{invocation_id}/ctx/{key}" + ) + + response: httpx.Response = await self._http_client.get(url) + if response.status_code == 404: + return None + + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + self._logger.error( + f"failed to download graph invocation state value", + invocation_id=invocation_id, + key=key, + status_code=response.status_code, + error=response.text, + exc_info=e, + ) + raise + + return serialized_object_from_http_response(response) + + def _validate_request(self, request: InvocationStateRequest) -> None: + ( + MessageValidator(request) + .required_field("request_id") + .required_field("task_id") + ) + if request.HasField("get"): + (MessageValidator(request.get).required_field("key")) + elif request.HasField("set"): + ( + MessageValidator(request.set) + .required_field("key") + .required_serialized_object("value") + ) + else: + raise ValueError("unknown request type") diff --git a/python-sdk/indexify/executor/function_executor/process_function_executor_factory.py b/python-sdk/indexify/executor/function_executor/process_function_executor_factory.py index 6181e819f..c753ea4e1 100644 --- a/python-sdk/indexify/executor/function_executor/process_function_executor_factory.py +++ b/python-sdk/indexify/executor/function_executor/process_function_executor_factory.py @@ -8,13 +8,9 @@ class ProcessFunctionExecutorFactory(FunctionExecutorFactory): def __init__( self, - indexify_server_address: str, development_mode: bool, - config_path: Optional[str], ): - self._indexify_server_address: str = indexify_server_address self._development_mode: bool = development_mode - self._config_path: Optional[str] = config_path # Registred ports range end at 49151. We start from 50000 to hopefully avoid conflicts. self._free_ports = set(range(50000, 51000)) @@ -30,13 +26,9 @@ async def create( "function-executor", "--function-executor-server-address", _server_address(port), - "--indexify-server-address", - self._indexify_server_address, ] if self._development_mode: args.append("--dev") - if self._config_path is not None: - args.extend(["--config-path", self._config_path]) # Run the process with our stdout, stderr. We want to see process logs and exceptions in our process output. # This is useful for dubugging. Customer function stdout and stderr is captured and returned in the response # so we won't see it in our process outputs. This is the right behavior as customer function stdout and stderr diff --git a/python-sdk/indexify/executor/function_worker.py b/python-sdk/indexify/executor/function_worker.py index 297bc750c..8cb79f9ff 100644 --- a/python-sdk/indexify/executor/function_worker.py +++ b/python-sdk/indexify/executor/function_worker.py @@ -23,6 +23,7 @@ FunctionExecutorFactory, ) from .function_executor.function_executor_map import FunctionExecutorMap +from .function_executor.invocation_state_client import InvocationStateClient class FunctionWorkerInput: @@ -63,15 +64,28 @@ def __init__( function_id_with_version: str, function_id_without_version: str, ongoing_tasks_count: int, + invocation_state_client: Optional[InvocationStateClient] = None, ): self.function_id_with_version: str = function_id_with_version self.function_id_without_version: str = function_id_without_version self.ongoing_tasks_count: int = ongoing_tasks_count + self.invocation_state_client: Optional[InvocationStateClient] = ( + invocation_state_client + ) class FunctionWorker: - def __init__(self, function_executor_factory: FunctionExecutorFactory): - self._function_executors = FunctionExecutorMap(function_executor_factory) + def __init__( + self, + function_executor_factory: FunctionExecutorFactory, + base_url: str, + config_path: Optional[str], + ): + self._function_executors = FunctionExecutorMap( + factory=function_executor_factory, + base_url=base_url, + config_path=config_path, + ) async def run(self, input: FunctionWorkerInput) -> FunctionWorkerOutput: logger = _logger(input.task) @@ -176,6 +190,9 @@ async def _run_in_executor( input.function_input.init_value ) channel: grpc.aio.Channel = await function_executor.channel() + function_executor.state().invocation_state_client.add_task_to_invocation_id_entry( + task_id=input.task.id, invocation_id=input.task.invocation_id + ) run_task_response: RunTaskResponse = await FunctionExecutorStub( channel ).run_task(run_task_request) @@ -184,6 +201,9 @@ async def _run_in_executor( # If this Function Executor was destroyed then it's not # visible in the map but we still have a reference to it. function_executor.state().ongoing_tasks_count -= 1 + function_executor.state().invocation_state_client.remove_task_to_invocation_id_entry( + input.task.id + ) async def shutdown(self) -> None: await self._function_executors.clear( diff --git a/python-sdk/indexify/function_executor/function_executor_service.py b/python-sdk/indexify/function_executor/function_executor_service.py index 3ae6e23f2..3d7e19bdc 100644 --- a/python-sdk/indexify/function_executor/function_executor_service.py +++ b/python-sdk/indexify/function_executor/function_executor_service.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Iterator, Optional, Union import grpc import structlog @@ -14,9 +14,14 @@ RequestValidator as RunTaskRequestValidator, ) from .initialize_request_validator import InitializeRequestValidator +from .invocation_state.invocation_state_proxy_server import ( + InvocationStateProxyServer, +) +from .invocation_state.proxied_invocation_state import ProxiedInvocationState from .proto.function_executor_pb2 import ( InitializeRequest, InitializeResponse, + InvocationStateResponse, RunTaskRequest, RunTaskResponse, ) @@ -24,15 +29,14 @@ class FunctionExecutorService(FunctionExecutorServicer): - def __init__(self, indexify_server_address: str, config_path: Optional[str]): - self._indexify_server_address = indexify_server_address - self._config_path = config_path + def __init__(self): self._logger = structlog.get_logger(module=__name__) self._namespace: Optional[str] = None self._graph_name: Optional[str] = None self._graph_version: Optional[int] = None self._function_name: Optional[str] = None self._function: Optional[Union[IndexifyFunction, IndexifyRouter]] = None + self._invocation_state_proxy_server: Optional[InvocationStateProxyServer] = None def initialize( self, request: InitializeRequest, context: grpc.ServicerContext @@ -65,6 +69,17 @@ def initialize( return InitializeResponse(success=True) + def initialize_invocation_state_server( + self, + client_responses: Iterator[InvocationStateResponse], + context: grpc.ServicerContext, + ): + self._invocation_state_proxy_server = InvocationStateProxyServer( + client_responses, self._logger + ) + self._logger.info("initialized invocation proxy server") + yield from self._invocation_state_proxy_server.run() + def run_task( self, request: RunTaskRequest, context: grpc.ServicerContext ) -> RunTaskResponse: @@ -75,12 +90,12 @@ def run_task( RunTaskRequestValidator(request=request).check() return RunTaskHandler( request=request, - namespace=self._namespace, graph_name=self._graph_name, graph_version=self._graph_version, function_name=self._function_name, function=self._function, + invocation_state=ProxiedInvocationState( + request.task_id, self._invocation_state_proxy_server + ), logger=self._logger, - indexify_server_addr=self._indexify_server_address, - config_path=self._config_path, ).run() diff --git a/python-sdk/indexify/function_executor/handlers/run_function/handler.py b/python-sdk/indexify/function_executor/handlers/run_function/handler.py index 7b33235aa..07f8f946f 100644 --- a/python-sdk/indexify/function_executor/handlers/run_function/handler.py +++ b/python-sdk/indexify/function_executor/handlers/run_function/handler.py @@ -16,6 +16,9 @@ IndexifyRouter, RouterCallResult, ) +from indexify.functions_sdk.invocation_state.invocation_state import ( + InvocationState, +) from indexify.http_client import IndexifyClient from .function_inputs_loader import FunctionInputs, FunctionInputsLoader @@ -26,14 +29,12 @@ class Handler: def __init__( self, request: RunTaskRequest, - namespace: str, graph_name: str, graph_version: int, function_name: str, function: Union[IndexifyFunction, IndexifyRouter], + invocation_state: InvocationState, logger: Any, - indexify_server_addr: str, - config_path: Optional[str], ): self._function_name: str = function_name self._logger = logger.bind( @@ -53,12 +54,7 @@ def __init__( invocation_id=request.graph_invocation_id, graph_name=graph_name, graph_version=str(graph_version), - indexify_client=_indexify_client( - logger=self._logger, - namespace=namespace, - indexify_server_addr=indexify_server_addr, - config_path=config_path, - ), + invocation_state=invocation_state, ), ) diff --git a/python-sdk/indexify/function_executor/invocation_state/invocation_state_proxy_server.py b/python-sdk/indexify/function_executor/invocation_state/invocation_state_proxy_server.py new file mode 100644 index 000000000..b04d5461f --- /dev/null +++ b/python-sdk/indexify/function_executor/invocation_state/invocation_state_proxy_server.py @@ -0,0 +1,170 @@ +import queue +import threading +from typing import Any, Iterator, Optional + +from indexify.functions_sdk.object_serializer import ( + CloudPickleSerializer, + get_serializer, +) + +from ..proto.function_executor_pb2 import ( + GetInvocationStateRequest, + InvocationStateRequest, + InvocationStateResponse, + SerializedObject, + SetInvocationStateRequest, +) +from .response_validator import ResponseValidator + + +class InvocationStateProxyServer: + """A gRPC server that proxies InvocationState calls to the gRPC client. + + The gRPC client is responsible for the actual implementation of the InvocationState. + We do the proxying to remove authorization logic and credentials from Function Executor. + This improves security posture of Function Executor because it may run untrusted code. + """ + + def __init__( + self, client_responses: Iterator[InvocationStateResponse], logger: Any + ): + self._client_responses: Iterator[InvocationStateResponse] = client_responses + self._logger: Any = logger.bind(module=__name__) + self._reciever_thread: threading.Thread = threading.Thread( + target=self._reciever + ) + self._request_queue: queue.SimpleQueue = queue.SimpleQueue() + # This lock protects everything below. + self._lock: threading.Lock = threading.Lock() + # Python supports big integers natively so we don't need + # to be worried about interger overflows. + self._request_seq_num: int = 0 + # Request ID -> Client Response. + self._response_map: dict[str, InvocationStateResponse] = {} + self._new_response: threading.Condition = threading.Condition(self._lock) + + def run(self) -> Iterator[InvocationStateRequest]: + # There's no need to implement shutdown of the server and its threads because + # the server lives while the Function Executor process lives. + self._reciever_thread.start() + yield from self._sender() + + def _reciever(self) -> None: + self._logger.info("reciever thread started") + try: + for response in self._client_responses: + validator = ResponseValidator(response) + try: + validator.check() + except ValueError as e: + self._logger.error("invalid response from the client", exc_info=e) + continue + + with self._lock: + self._response_map[response.request_id] = response + self._new_response.notify_all() + except Exception as e: + self._logger.error("error in reciever thread, exiting", exc_info=e) + + def _sender(self) -> Iterator[InvocationStateRequest]: + while True: + yield self._request_queue.get() + with self._lock: + # Wait until we get a response for the request. + # This allows to ensure a serialized order of reads and writes so + # we can avoid a read returning not previously written value. + self._new_response.wait() + + def set(self, task_id: str, key: str, value: Any) -> None: + with self._lock: + request_id: str = str(self._request_seq_num) + self._request_seq_num += 1 + + # We currently use CloudPickleSerializer for function inputs, + # outputs and invocation state values. This provides consistent UX. + request = InvocationStateRequest( + request_id=request_id, + task_id=task_id, + set=SetInvocationStateRequest( + key=key, + value=SerializedObject( + content_type=CloudPickleSerializer.content_type, + bytes=CloudPickleSerializer.serialize(value), + ), + ), + ) + self._request_queue.put(request) + while request_id not in self._response_map: + self._new_response.wait() + + response: InvocationStateResponse = self._response_map.pop(request_id) + if response.request_id != request_id: + self._logger.error( + "response request_id doesn't match actual request_id", + request_id=request_id, + response=response, + ) + raise RuntimeError( + "response request_id doesn't match actual request_id" + ) + if not response.HasField("set"): + self._logger.error( + "set response is missing in the client response", + request_id=request_id, + response=response, + ) + raise RuntimeError("set response is missing in the client response") + if not response.success: + self._logger.error( + "failed to set the invocation state for key", + key=key, + ) + raise RuntimeError("failed to set the invocation state for key") + + def get(self, task_id: str, key: str) -> Optional[Any]: + with self._lock: + request_id: str = str(self._request_seq_num) + self._request_seq_num += 1 + + request = InvocationStateRequest( + request_id=request_id, + task_id=task_id, + get=GetInvocationStateRequest( + key=key, + ), + ) + self._request_queue.put(request) + while request_id not in self._response_map: + self._new_response.wait() + + response: InvocationStateResponse = self._response_map.pop(request_id) + if response.request_id != request_id: + self._logger.error( + "response request_id doesn't match actual request_id", + request_id=request_id, + response=response, + ) + raise RuntimeError( + "response request_id doesn't match actual request_id" + ) + if not response.HasField("get"): + self._logger.error( + "get response is missing in the client response", + request_id=request_id, + response=response, + ) + raise RuntimeError("get response is missing in the client response") + if not response.success: + self._logger.error( + "failed to get the invocation state for key", + key=key, + ) + raise RuntimeError("failed to get the invocation state for key") + if not response.get.HasField("value"): + return None + + return get_serializer(response.get.value.content_type).deserialize( + response.get.value.bytes + if response.get.value.HasField("bytes") + else response.get.value.string + ) diff --git a/python-sdk/indexify/function_executor/invocation_state/proxied_invocation_state.py b/python-sdk/indexify/function_executor/invocation_state/proxied_invocation_state.py new file mode 100644 index 000000000..92c6194ff --- /dev/null +++ b/python-sdk/indexify/function_executor/invocation_state/proxied_invocation_state.py @@ -0,0 +1,24 @@ +from typing import Any, Optional + +from indexify.functions_sdk.invocation_state.invocation_state import ( + InvocationState, +) + +from .invocation_state_proxy_server import InvocationStateProxyServer + + +class ProxiedInvocationState(InvocationState): + """InvocationState that proxies the calls via InvocationStateProxyServer.""" + + def __init__(self, task_id: str, proxy_server: InvocationStateProxyServer): + self._task_id: str = task_id + self._proxy_server: InvocationStateProxyServer = proxy_server + + def set(self, key: str, value: Any) -> None: + """Set a key-value pair.""" + self._proxy_server.set(self._task_id, key, value) + + def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: + """Get a value by key. If the key does not exist, return the default value.""" + value: Optional[Any] = self._proxy_server.get(self._task_id, key) + return default if value is None else value diff --git a/python-sdk/indexify/function_executor/invocation_state/response_validator.py b/python-sdk/indexify/function_executor/invocation_state/response_validator.py new file mode 100644 index 000000000..0d9b06d15 --- /dev/null +++ b/python-sdk/indexify/function_executor/invocation_state/response_validator.py @@ -0,0 +1,29 @@ +from ..proto.function_executor_pb2 import InvocationStateResponse +from ..proto.message_validator import MessageValidator + + +class ResponseValidator(MessageValidator): + def __init__(self, response: InvocationStateResponse): + self._response = response + + def check(self): + """Validates the request. + + Raises: ValueError: If the response is invalid. + """ + ( + MessageValidator(self._response) + .required_field("request_id") + .required_field("success") + ) + + if self._response.HasField("set"): + pass + elif self._response.HasField("get"): + ( + MessageValidator(self._response.get) + .required_field("key") + .optional_serialized_object("value") + ) + else: + raise ValueError(f"Unknown response type: {self._response}") diff --git a/python-sdk/indexify/function_executor/proto/function_executor.proto b/python-sdk/indexify/function_executor/proto/function_executor.proto index e417ff1b8..84c68084f 100644 --- a/python-sdk/indexify/function_executor/proto/function_executor.proto +++ b/python-sdk/indexify/function_executor/proto/function_executor.proto @@ -33,6 +33,49 @@ message InitializeResponse { optional bool success = 1; } +message SetInvocationStateRequest { + optional string key = 1; + optional SerializedObject value = 2; +} + +message SetInvocationStateResponse {} + +message GetInvocationStateRequest { + optional string key = 1; +} + +message GetInvocationStateResponse { + optional string key = 1; + optional SerializedObject value = 2; +} + +// InvocationStateRequest is sent by RPC Server to the client +// to perform actions on a task's graph invocation state. +message InvocationStateRequest { + // The ID of the request sent by the client. + // Must be unique per Function Executor. + optional string request_id = 1; + // The ID of the task initiated the request. + optional string task_id = 2; + oneof request { + SetInvocationStateRequest set = 3; + GetInvocationStateRequest get = 4; + } +} + +// InvocationStateResponse is sent by RPC client to the Server. +// A response contains the result of the action performed on the +// task's graph invocation state. +message InvocationStateResponse { + // The id of the request this response is for. + optional string request_id = 1; + optional bool success = 2; + oneof response { + SetInvocationStateResponse set = 3; + GetInvocationStateResponse get = 4; + } +} + message FunctionOutput { repeated SerializedObject outputs = 1; } @@ -64,6 +107,10 @@ service FunctionExecutor { // once per Function Executor as it can only run a single function. // It should be called before calling RunTask for the function. rpc initialize(InitializeRequest) returns (InitializeResponse); + // Initializes a server that sends requests to the client to perform actions on + // a task's graph invocation state. This method is called only once per Function Executor + // It should be called before calling RunTask for the function. + rpc initialize_invocation_state_server(stream InvocationStateResponse) returns (stream InvocationStateRequest); // Executes the task defined in the request. // Multiple tasks can be running in parallel. rpc run_task(RunTaskRequest) returns (RunTaskResponse); diff --git a/python-sdk/indexify/function_executor/proto/function_executor_pb2.py b/python-sdk/indexify/function_executor/proto/function_executor_pb2.py index d36651beb..354f3ebea 100644 --- a/python-sdk/indexify/function_executor/proto/function_executor_pb2.py +++ b/python-sdk/indexify/function_executor/proto/function_executor_pb2.py @@ -24,7 +24,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n8indexify/function_executor/proto/function_executor.proto\x12\x19\x66unction_executor_service"i\n\x10SerializedObject\x12\x0f\n\x05\x62ytes\x18\x01 \x01(\x0cH\x00\x12\x10\n\x06string\x18\x02 \x01(\tH\x00\x12\x19\n\x0c\x63ontent_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x42\x06\n\x04\x64\x61taB\x0f\n\r_content_type"\x88\x02\n\x11InitializeRequest\x12\x16\n\tnamespace\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ngraph_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x1a\n\rgraph_version\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x1a\n\rfunction_name\x18\x05 \x01(\tH\x03\x88\x01\x01\x12?\n\x05graph\x18\x07 \x01(\x0b\x32+.function_executor_service.SerializedObjectH\x04\x88\x01\x01\x42\x0c\n\n_namespaceB\r\n\x0b_graph_nameB\x10\n\x0e_graph_versionB\x10\n\x0e_function_nameB\x08\n\x06_graph"6\n\x12InitializeResponse\x12\x14\n\x07success\x18\x01 \x01(\x08H\x00\x88\x01\x01\x42\n\n\x08_success"N\n\x0e\x46unctionOutput\x12<\n\x07outputs\x18\x01 \x03(\x0b\x32+.function_executor_service.SerializedObject"\x1d\n\x0cRouterOutput\x12\r\n\x05\x65\x64ges\x18\x01 \x03(\t"\xb0\x02\n\x0eRunTaskRequest\x12 \n\x13graph_invocation_id\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x14\n\x07task_id\x18\x06 \x01(\tH\x01\x88\x01\x01\x12H\n\x0e\x66unction_input\x18\t \x01(\x0b\x32+.function_executor_service.SerializedObjectH\x02\x88\x01\x01\x12M\n\x13\x66unction_init_value\x18\n \x01(\x0b\x32+.function_executor_service.SerializedObjectH\x03\x88\x01\x01\x42\x16\n\x14_graph_invocation_idB\n\n\x08_task_idB\x11\n\x0f_function_inputB\x16\n\x14_function_init_value"\xf1\x02\n\x0fRunTaskResponse\x12\x14\n\x07task_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12G\n\x0f\x66unction_output\x18\x02 \x01(\x0b\x32).function_executor_service.FunctionOutputH\x01\x88\x01\x01\x12\x43\n\rrouter_output\x18\x03 \x01(\x0b\x32\'.function_executor_service.RouterOutputH\x02\x88\x01\x01\x12\x13\n\x06stdout\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x13\n\x06stderr\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x17\n\nis_reducer\x18\x06 \x01(\x08H\x05\x88\x01\x01\x12\x14\n\x07success\x18\x07 \x01(\x08H\x06\x88\x01\x01\x42\n\n\x08_task_idB\x12\n\x10_function_outputB\x10\n\x0e_router_outputB\t\n\x07_stdoutB\t\n\x07_stderrB\r\n\x0b_is_reducerB\n\n\x08_success2\xe0\x01\n\x10\x46unctionExecutor\x12i\n\ninitialize\x12,.function_executor_service.InitializeRequest\x1a-.function_executor_service.InitializeResponse\x12\x61\n\x08run_task\x12).function_executor_service.RunTaskRequest\x1a*.function_executor_service.RunTaskResponseb\x06proto3' + b'\n8indexify/function_executor/proto/function_executor.proto\x12\x19\x66unction_executor_service"i\n\x10SerializedObject\x12\x0f\n\x05\x62ytes\x18\x01 \x01(\x0cH\x00\x12\x10\n\x06string\x18\x02 \x01(\tH\x00\x12\x19\n\x0c\x63ontent_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x42\x06\n\x04\x64\x61taB\x0f\n\r_content_type"\x88\x02\n\x11InitializeRequest\x12\x16\n\tnamespace\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x17\n\ngraph_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x1a\n\rgraph_version\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x1a\n\rfunction_name\x18\x05 \x01(\tH\x03\x88\x01\x01\x12?\n\x05graph\x18\x07 \x01(\x0b\x32+.function_executor_service.SerializedObjectH\x04\x88\x01\x01\x42\x0c\n\n_namespaceB\r\n\x0b_graph_nameB\x10\n\x0e_graph_versionB\x10\n\x0e_function_nameB\x08\n\x06_graph"6\n\x12InitializeResponse\x12\x14\n\x07success\x18\x01 \x01(\x08H\x00\x88\x01\x01\x42\n\n\x08_success"\x80\x01\n\x19SetInvocationStateRequest\x12\x10\n\x03key\x18\x01 \x01(\tH\x00\x88\x01\x01\x12?\n\x05value\x18\x02 \x01(\x0b\x32+.function_executor_service.SerializedObjectH\x01\x88\x01\x01\x42\x06\n\x04_keyB\x08\n\x06_value"\x1c\n\x1aSetInvocationStateResponse"5\n\x19GetInvocationStateRequest\x12\x10\n\x03key\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x06\n\x04_key"\x81\x01\n\x1aGetInvocationStateResponse\x12\x10\n\x03key\x18\x01 \x01(\tH\x00\x88\x01\x01\x12?\n\x05value\x18\x02 \x01(\x0b\x32+.function_executor_service.SerializedObjectH\x01\x88\x01\x01\x42\x06\n\x04_keyB\x08\n\x06_value"\xf7\x01\n\x16InvocationStateRequest\x12\x17\n\nrequest_id\x18\x01 \x01(\tH\x01\x88\x01\x01\x12\x14\n\x07task_id\x18\x02 \x01(\tH\x02\x88\x01\x01\x12\x43\n\x03set\x18\x03 \x01(\x0b\x32\x34.function_executor_service.SetInvocationStateRequestH\x00\x12\x43\n\x03get\x18\x04 \x01(\x0b\x32\x34.function_executor_service.GetInvocationStateRequestH\x00\x42\t\n\x07requestB\r\n\x0b_request_idB\n\n\x08_task_id"\xfb\x01\n\x17InvocationStateResponse\x12\x17\n\nrequest_id\x18\x01 \x01(\tH\x01\x88\x01\x01\x12\x14\n\x07success\x18\x02 \x01(\x08H\x02\x88\x01\x01\x12\x44\n\x03set\x18\x03 \x01(\x0b\x32\x35.function_executor_service.SetInvocationStateResponseH\x00\x12\x44\n\x03get\x18\x04 \x01(\x0b\x32\x35.function_executor_service.GetInvocationStateResponseH\x00\x42\n\n\x08responseB\r\n\x0b_request_idB\n\n\x08_success"N\n\x0e\x46unctionOutput\x12<\n\x07outputs\x18\x01 \x03(\x0b\x32+.function_executor_service.SerializedObject"\x1d\n\x0cRouterOutput\x12\r\n\x05\x65\x64ges\x18\x01 \x03(\t"\xb0\x02\n\x0eRunTaskRequest\x12 \n\x13graph_invocation_id\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x14\n\x07task_id\x18\x06 \x01(\tH\x01\x88\x01\x01\x12H\n\x0e\x66unction_input\x18\t \x01(\x0b\x32+.function_executor_service.SerializedObjectH\x02\x88\x01\x01\x12M\n\x13\x66unction_init_value\x18\n \x01(\x0b\x32+.function_executor_service.SerializedObjectH\x03\x88\x01\x01\x42\x16\n\x14_graph_invocation_idB\n\n\x08_task_idB\x11\n\x0f_function_inputB\x16\n\x14_function_init_value"\xf1\x02\n\x0fRunTaskResponse\x12\x14\n\x07task_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12G\n\x0f\x66unction_output\x18\x02 \x01(\x0b\x32).function_executor_service.FunctionOutputH\x01\x88\x01\x01\x12\x43\n\rrouter_output\x18\x03 \x01(\x0b\x32\'.function_executor_service.RouterOutputH\x02\x88\x01\x01\x12\x13\n\x06stdout\x18\x04 \x01(\tH\x03\x88\x01\x01\x12\x13\n\x06stderr\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x17\n\nis_reducer\x18\x06 \x01(\x08H\x05\x88\x01\x01\x12\x14\n\x07success\x18\x07 \x01(\x08H\x06\x88\x01\x01\x42\n\n\x08_task_idB\x12\n\x10_function_outputB\x10\n\x0e_router_outputB\t\n\x07_stdoutB\t\n\x07_stderrB\r\n\x0b_is_reducerB\n\n\x08_success2\xf2\x02\n\x10\x46unctionExecutor\x12i\n\ninitialize\x12,.function_executor_service.InitializeRequest\x1a-.function_executor_service.InitializeResponse\x12\x8f\x01\n"initialize_invocation_state_server\x12\x32.function_executor_service.InvocationStateResponse\x1a\x31.function_executor_service.InvocationStateRequest(\x01\x30\x01\x12\x61\n\x08run_task\x12).function_executor_service.RunTaskRequest\x1a*.function_executor_service.RunTaskResponseb\x06proto3' ) _globals = globals() @@ -40,14 +40,26 @@ _globals["_INITIALIZEREQUEST"]._serialized_end = 459 _globals["_INITIALIZERESPONSE"]._serialized_start = 461 _globals["_INITIALIZERESPONSE"]._serialized_end = 515 - _globals["_FUNCTIONOUTPUT"]._serialized_start = 517 - _globals["_FUNCTIONOUTPUT"]._serialized_end = 595 - _globals["_ROUTEROUTPUT"]._serialized_start = 597 - _globals["_ROUTEROUTPUT"]._serialized_end = 626 - _globals["_RUNTASKREQUEST"]._serialized_start = 629 - _globals["_RUNTASKREQUEST"]._serialized_end = 933 - _globals["_RUNTASKRESPONSE"]._serialized_start = 936 - _globals["_RUNTASKRESPONSE"]._serialized_end = 1305 - _globals["_FUNCTIONEXECUTOR"]._serialized_start = 1308 - _globals["_FUNCTIONEXECUTOR"]._serialized_end = 1532 + _globals["_SETINVOCATIONSTATEREQUEST"]._serialized_start = 518 + _globals["_SETINVOCATIONSTATEREQUEST"]._serialized_end = 646 + _globals["_SETINVOCATIONSTATERESPONSE"]._serialized_start = 648 + _globals["_SETINVOCATIONSTATERESPONSE"]._serialized_end = 676 + _globals["_GETINVOCATIONSTATEREQUEST"]._serialized_start = 678 + _globals["_GETINVOCATIONSTATEREQUEST"]._serialized_end = 731 + _globals["_GETINVOCATIONSTATERESPONSE"]._serialized_start = 734 + _globals["_GETINVOCATIONSTATERESPONSE"]._serialized_end = 863 + _globals["_INVOCATIONSTATEREQUEST"]._serialized_start = 866 + _globals["_INVOCATIONSTATEREQUEST"]._serialized_end = 1113 + _globals["_INVOCATIONSTATERESPONSE"]._serialized_start = 1116 + _globals["_INVOCATIONSTATERESPONSE"]._serialized_end = 1367 + _globals["_FUNCTIONOUTPUT"]._serialized_start = 1369 + _globals["_FUNCTIONOUTPUT"]._serialized_end = 1447 + _globals["_ROUTEROUTPUT"]._serialized_start = 1449 + _globals["_ROUTEROUTPUT"]._serialized_end = 1478 + _globals["_RUNTASKREQUEST"]._serialized_start = 1481 + _globals["_RUNTASKREQUEST"]._serialized_end = 1785 + _globals["_RUNTASKRESPONSE"]._serialized_start = 1788 + _globals["_RUNTASKRESPONSE"]._serialized_end = 2157 + _globals["_FUNCTIONEXECUTOR"]._serialized_start = 2160 + _globals["_FUNCTIONEXECUTOR"]._serialized_end = 2530 # @@protoc_insertion_point(module_scope) diff --git a/python-sdk/indexify/function_executor/proto/function_executor_pb2.pyi b/python-sdk/indexify/function_executor/proto/function_executor_pb2.pyi index 01a8a71c3..320632b4c 100644 --- a/python-sdk/indexify/function_executor/proto/function_executor_pb2.pyi +++ b/python-sdk/indexify/function_executor/proto/function_executor_pb2.pyi @@ -52,6 +52,76 @@ class InitializeResponse(_message.Message): success: bool def __init__(self, success: bool = ...) -> None: ... +class SetInvocationStateRequest(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: SerializedObject + def __init__( + self, + key: _Optional[str] = ..., + value: _Optional[_Union[SerializedObject, _Mapping]] = ..., + ) -> None: ... + +class SetInvocationStateResponse(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class GetInvocationStateRequest(_message.Message): + __slots__ = ("key",) + KEY_FIELD_NUMBER: _ClassVar[int] + key: str + def __init__(self, key: _Optional[str] = ...) -> None: ... + +class GetInvocationStateResponse(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: SerializedObject + def __init__( + self, + key: _Optional[str] = ..., + value: _Optional[_Union[SerializedObject, _Mapping]] = ..., + ) -> None: ... + +class InvocationStateRequest(_message.Message): + __slots__ = ("request_id", "task_id", "set", "get") + REQUEST_ID_FIELD_NUMBER: _ClassVar[int] + TASK_ID_FIELD_NUMBER: _ClassVar[int] + SET_FIELD_NUMBER: _ClassVar[int] + GET_FIELD_NUMBER: _ClassVar[int] + request_id: str + task_id: str + set: SetInvocationStateRequest + get: GetInvocationStateRequest + def __init__( + self, + request_id: _Optional[str] = ..., + task_id: _Optional[str] = ..., + set: _Optional[_Union[SetInvocationStateRequest, _Mapping]] = ..., + get: _Optional[_Union[GetInvocationStateRequest, _Mapping]] = ..., + ) -> None: ... + +class InvocationStateResponse(_message.Message): + __slots__ = ("request_id", "success", "set", "get") + REQUEST_ID_FIELD_NUMBER: _ClassVar[int] + SUCCESS_FIELD_NUMBER: _ClassVar[int] + SET_FIELD_NUMBER: _ClassVar[int] + GET_FIELD_NUMBER: _ClassVar[int] + request_id: str + success: bool + set: SetInvocationStateResponse + get: GetInvocationStateResponse + def __init__( + self, + request_id: _Optional[str] = ..., + success: bool = ..., + set: _Optional[_Union[SetInvocationStateResponse, _Mapping]] = ..., + get: _Optional[_Union[GetInvocationStateResponse, _Mapping]] = ..., + ) -> None: ... + class FunctionOutput(_message.Message): __slots__ = ("outputs",) OUTPUTS_FIELD_NUMBER: _ClassVar[int] diff --git a/python-sdk/indexify/function_executor/proto/function_executor_pb2_grpc.py b/python-sdk/indexify/function_executor/proto/function_executor_pb2_grpc.py index 8c2a446c3..e2a2d1fd0 100644 --- a/python-sdk/indexify/function_executor/proto/function_executor_pb2_grpc.py +++ b/python-sdk/indexify/function_executor/proto/function_executor_pb2_grpc.py @@ -46,6 +46,12 @@ def __init__(self, channel): response_deserializer=indexify_dot_function__executor_dot_proto_dot_function__executor__pb2.InitializeResponse.FromString, _registered_method=True, ) + self.initialize_invocation_state_server = channel.stream_stream( + "/function_executor_service.FunctionExecutor/initialize_invocation_state_server", + request_serializer=indexify_dot_function__executor_dot_proto_dot_function__executor__pb2.InvocationStateResponse.SerializeToString, + response_deserializer=indexify_dot_function__executor_dot_proto_dot_function__executor__pb2.InvocationStateRequest.FromString, + _registered_method=True, + ) self.run_task = channel.unary_unary( "/function_executor_service.FunctionExecutor/run_task", request_serializer=indexify_dot_function__executor_dot_proto_dot_function__executor__pb2.RunTaskRequest.SerializeToString, @@ -67,6 +73,15 @@ def initialize(self, request, context): context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def initialize_invocation_state_server(self, request_iterator, context): + """Initializes a server that sends requests to the client to perform actions on + a task's graph invocation state. This method is called only once per Function Executor + It should be called before calling RunTask for the function. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def run_task(self, request, context): """Executes the task defined in the request. Multiple tasks can be running in parallel. @@ -83,6 +98,11 @@ def add_FunctionExecutorServicer_to_server(servicer, server): request_deserializer=indexify_dot_function__executor_dot_proto_dot_function__executor__pb2.InitializeRequest.FromString, response_serializer=indexify_dot_function__executor_dot_proto_dot_function__executor__pb2.InitializeResponse.SerializeToString, ), + "initialize_invocation_state_server": grpc.stream_stream_rpc_method_handler( + servicer.initialize_invocation_state_server, + request_deserializer=indexify_dot_function__executor_dot_proto_dot_function__executor__pb2.InvocationStateResponse.FromString, + response_serializer=indexify_dot_function__executor_dot_proto_dot_function__executor__pb2.InvocationStateRequest.SerializeToString, + ), "run_task": grpc.unary_unary_rpc_method_handler( servicer.run_task, request_deserializer=indexify_dot_function__executor_dot_proto_dot_function__executor__pb2.RunTaskRequest.FromString, @@ -132,6 +152,36 @@ def initialize( _registered_method=True, ) + @staticmethod + def initialize_invocation_state_server( + request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.stream_stream( + request_iterator, + target, + "/function_executor_service.FunctionExecutor/initialize_invocation_state_server", + indexify_dot_function__executor_dot_proto_dot_function__executor__pb2.InvocationStateResponse.SerializeToString, + indexify_dot_function__executor_dot_proto_dot_function__executor__pb2.InvocationStateRequest.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True, + ) + @staticmethod def run_task( request, diff --git a/python-sdk/indexify/functions_sdk/graph.py b/python-sdk/indexify/functions_sdk/graph.py index 8ce2d981d..382b6622b 100644 --- a/python-sdk/indexify/functions_sdk/graph.py +++ b/python-sdk/indexify/functions_sdk/graph.py @@ -37,6 +37,7 @@ IndexifyRouter, RouterCallResult, ) +from .invocation_state.local_invocation_state import LocalInvocationState from .object_serializer import get_serializer RouterFn = Annotated[ @@ -236,13 +237,12 @@ def run(self, block_until_done: bool = False, **kwargs) -> str: payload=serializer.serialize(v), encoder=node.input_encoder ) self._results[input.id] = outputs - ctx = GraphInvocationContext( + self._local_graph_ctx = GraphInvocationContext( invocation_id=input.id, graph_name=self.name, graph_version="1", - indexify_client=None, + invocation_state=LocalInvocationState(), ) - self._local_graph_ctx = ctx self._run(input, outputs) return input.id diff --git a/python-sdk/indexify/functions_sdk/indexify_functions.py b/python-sdk/indexify/functions_sdk/indexify_functions.py index 8c5630f67..27245e633 100644 --- a/python-sdk/indexify/functions_sdk/indexify_functions.py +++ b/python-sdk/indexify/functions_sdk/indexify_functions.py @@ -13,35 +13,27 @@ get_origin, ) -from pydantic import BaseModel, Field, PrivateAttr +from pydantic import BaseModel from typing_extensions import get_type_hints from .data_objects import IndexifyData from .image import DEFAULT_IMAGE, Image +from .invocation_state.invocation_state import InvocationState from .object_serializer import get_serializer -class GraphInvocationContext(BaseModel): - invocation_id: str - graph_name: str - graph_version: str - indexify_client: Optional[Any] = Field(default=None) # avoids circular import - _local_state: Dict[str, Any] = PrivateAttr(default_factory=dict) - - def set_state_key(self, key: str, value: Any) -> None: - if self.indexify_client is None: - self._local_state[key] = value - return - self.indexify_client.set_state_key( - self.graph_name, self.invocation_id, key, value - ) - - def get_state_key(self, key: str) -> Any: - if self.indexify_client is None: - return self._local_state.get(key) - return self.indexify_client.get_state_key( - self.graph_name, self.invocation_id, key - ) +class GraphInvocationContext: + def __init__( + self, + invocation_id: str, + graph_name: str, + graph_version: str, + invocation_state: InvocationState, + ): + self.invocation_id = invocation_id + self.graph_name = graph_name + self.graph_version = graph_version + self.invocation_state = invocation_state def is_pydantic_model_from_annotation(type_annotation): diff --git a/python-sdk/indexify/functions_sdk/invocation_state/invocation_state.py b/python-sdk/indexify/functions_sdk/invocation_state/invocation_state.py new file mode 100644 index 000000000..beced521c --- /dev/null +++ b/python-sdk/indexify/functions_sdk/invocation_state/invocation_state.py @@ -0,0 +1,22 @@ +from typing import Any, Optional + + +class InvocationState: + """Abstract interface for Graph invocation state key-value API. + + The API allows to set and get key-value pairs from Indexify functions. + The key-value pairs are scoped per Graph invocation. + Each new invocation starts with an empty state (empty set of key-value pairs). + A value can be any CloudPickleSerializer serializable object.""" + + def set(self, key: str, value: Any) -> None: + """Set a key-value pair. + + Raises Exception if an error occured.""" + raise NotImplementedError() + + def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: + """Get a value by key. If the key does not exist, return the default value. + + Raises Exception if an error occured.""" + raise NotImplementedError() diff --git a/python-sdk/indexify/functions_sdk/invocation_state/local_invocation_state.py b/python-sdk/indexify/functions_sdk/invocation_state/local_invocation_state.py new file mode 100644 index 000000000..bc3d489bb --- /dev/null +++ b/python-sdk/indexify/functions_sdk/invocation_state/local_invocation_state.py @@ -0,0 +1,30 @@ +from typing import Any, Dict, Optional + +from ..object_serializer import CloudPickleSerializer +from .invocation_state import InvocationState + + +class LocalInvocationState(InvocationState): + """InvocationState that stores the key-value pairs in memory. + + This is intended to be used with local graphs.""" + + def __init__(self): + """Creates a new instance. + + Caller needs to ensure that the returned instance is only used for a single invocation state. + """ + self._state: Dict[str, bytes] = {} + + def set(self, key: str, value: Any) -> None: + # It's important to serialize the value even in the local implementation + # so there are no unexpected errors when running in remote graph mode. + self._state[key] = CloudPickleSerializer.serialize(value) + + def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: + serialized_value: Optional[bytes] = self._state.get(key, None) + return ( + default + if serialized_value is None + else CloudPickleSerializer.deserialize(serialized_value) + ) diff --git a/python-sdk/indexify/http_client.py b/python-sdk/indexify/http_client.py index 38b8b7690..23cb67f8b 100644 --- a/python-sdk/indexify/http_client.py +++ b/python-sdk/indexify/http_client.py @@ -208,23 +208,6 @@ def namespaces(self) -> List[str]: namespaces.append(item["name"]) return namespaces - def set_state_key( - self, compute_graph: str, invocation_id: str, key: str, value: Json - ) -> None: - response = self._post( - f"internal/namespaces/{self.namespace}/compute_graphs/{compute_graph}/invocations/{invocation_id}/ctx", - json={"key": key, "value": value}, - ) - response.raise_for_status() - - def get_state_key(self, compute_graph: str, invocation_id: str, key: str) -> Json: - response = self._get( - f"internal/namespaces/{self.namespace}/compute_graphs/{compute_graph}/invocations/{invocation_id}/ctx", - json={"key": key}, - ) - response.raise_for_status() - return response.json().get("value") - @classmethod def new_namespace( cls, namespace: str, server_addr: Optional[str] = "http://localhost:8900" diff --git a/python-sdk/tests/function_executor/test_invocation_state.py b/python-sdk/tests/function_executor/test_invocation_state.py new file mode 100644 index 000000000..0b6d5b7fa --- /dev/null +++ b/python-sdk/tests/function_executor/test_invocation_state.py @@ -0,0 +1,389 @@ +import sys +import threading +import unittest +from typing import Any, Dict, List + +from pydantic import BaseModel + +from indexify import Graph +from indexify.function_executor.proto.function_executor_pb2 import ( + GetInvocationStateRequest, + GetInvocationStateResponse, + InitializeRequest, + InitializeResponse, + InvocationStateRequest, + InvocationStateResponse, + RunTaskResponse, + SerializedObject, + SetInvocationStateRequest, + SetInvocationStateResponse, +) +from indexify.function_executor.proto.function_executor_pb2_grpc import ( + FunctionExecutorStub, +) +from indexify.functions_sdk.indexify_functions import ( + get_ctx, + indexify_function, +) +from indexify.functions_sdk.object_serializer import CloudPickleSerializer +from tests.function_executor.utils import ( + FunctionExecutorServerTestCase, + deserialized_function_output, + run_task, +) +from tests.testing import test_graph_name + + +class StructuredField(BaseModel): + list: List[int] + dictionary: Dict[str, Any] + + +class StructuredState(BaseModel): + string: str + integer: int + structured: StructuredField + + +def invocation_state_client_stub( + test_case: unittest.TestCase, + stub: FunctionExecutorStub, + expected_requests: List[InvocationStateRequest], + responses: List[InvocationStateResponse], +) -> threading.Thread: + server_request_iterator = stub.initialize_invocation_state_server(iter(responses)) + + def loop(): + for expected_request in expected_requests: + request = next(server_request_iterator) + request: InvocationStateRequest + test_case.assertEqual(request.request_id, expected_request.request_id) + test_case.assertEqual(request.task_id, expected_request.task_id) + if request.HasField("set"): + test_case.assertEqual(request.set.key, expected_request.set.key) + # Two different serialized objects are not equal so we need to deserialize them. + test_case.assertEqual( + CloudPickleSerializer.deserialize(request.set.value.bytes), + CloudPickleSerializer.deserialize(expected_request.set.value.bytes), + ) + else: + test_case.assertEqual(request.get.key, expected_request.get.key) + + invocation_state_client_thread = threading.Thread(target=loop) + invocation_state_client_thread.start() + return invocation_state_client_thread + + +class TestSetInvocationState(FunctionExecutorServerTestCase): + def _create_graph(self): + @indexify_function() + def set_invocation_state(x: int) -> str: + get_ctx().invocation_state.set( + "test_state_key", + StructuredState( + string="hello", + integer=x, + structured=StructuredField( + list=[1, 2, 3], dictionary={"a": 1, "b": 2} + ), + ), + ) + return "success" + + return Graph( + name=test_graph_name(self), + description="test", + start_node=set_invocation_state, + ) + + def _initialize_function_executor(self, stub: FunctionExecutorStub): + graph = self._create_graph() + initialize_response: InitializeResponse = stub.initialize( + InitializeRequest( + namespace="test", + graph_name="test", + graph_version=1, + function_name="set_invocation_state", + graph=SerializedObject( + bytes=CloudPickleSerializer.serialize( + graph.serialize(additional_modules=[sys.modules[__name__]]) + ), + content_type=CloudPickleSerializer.content_type, + ), + ) + ) + self.assertTrue(initialize_response.success) + + def test_success(self): + with self._rpc_channel() as channel: + stub: FunctionExecutorStub = FunctionExecutorStub(channel) + self._initialize_function_executor(stub) + expected_requests = [ + InvocationStateRequest( + request_id="0", + task_id="test-task", + set=SetInvocationStateRequest( + key="test_state_key", + value=SerializedObject( + content_type=CloudPickleSerializer.content_type, + bytes=CloudPickleSerializer.serialize( + StructuredState( + string="hello", + integer=42, + structured=StructuredField( + list=[1, 2, 3], dictionary={"a": 1, "b": 2} + ), + ) + ), + ), + ), + ), + ] + responses = [ + InvocationStateResponse( + request_id="0", success=True, set=SetInvocationStateResponse() + ), + ] + client_thread = invocation_state_client_stub( + self, stub, expected_requests, responses + ) + run_task_response: RunTaskResponse = run_task(stub, 42) + self.assertTrue(run_task_response.success) + fn_outputs = deserialized_function_output( + self, run_task_response.function_output + ) + self.assertEqual(len(fn_outputs), 1) + self.assertEqual("success", fn_outputs[0]) + + print( + "Joining invocation state client thread, it should exit immediately..." + ) + client_thread.join() + + def test_client_failure(self): + with self._rpc_channel() as channel: + stub: FunctionExecutorStub = FunctionExecutorStub(channel) + self._initialize_function_executor(stub) + expected_requests = [ + InvocationStateRequest( + request_id="0", + task_id="test-task", + set=SetInvocationStateRequest( + key="test_state_key", + value=SerializedObject( + content_type=CloudPickleSerializer.content_type, + bytes=CloudPickleSerializer.serialize( + StructuredState( + string="hello", + integer=42, + structured=StructuredField( + list=[1, 2, 3], dictionary={"a": 1, "b": 2} + ), + ) + ), + ), + ), + ), + ] + responses = [ + InvocationStateResponse( + request_id="0", success=False, set=SetInvocationStateResponse() + ), + ] + client_thread = invocation_state_client_stub( + self, stub, expected_requests, responses + ) + run_task_response: RunTaskResponse = run_task(stub, 42) + self.assertFalse(run_task_response.success) + self.assertTrue( + 'RuntimeError("failed to set the invocation state for key")' + in run_task_response.stderr + ) + + print( + "Joining invocation state client thread, it should exit immediately..." + ) + client_thread.join() + + +class TestGetInvocationState(FunctionExecutorServerTestCase): + def _create_graph_with_result_validation(self): + @indexify_function() + def get_invocation_state(x: int) -> str: + got_state: StructuredState = get_ctx().invocation_state.get( + "test_state_key" + ) + expected_state: StructuredState = StructuredState( + string="hello", + integer=x, + structured=StructuredField(list=[1, 2, 3], dictionary={"a": 1, "b": 2}), + ) + return "success" if got_state == expected_state else "failure" + + return Graph( + name=test_graph_name(self), + description="test", + start_node=get_invocation_state, + ) + + def _initialize_function_executor(self, graph: Graph, stub: FunctionExecutorStub): + initialize_response: InitializeResponse = stub.initialize( + InitializeRequest( + namespace="test", + graph_name="test", + graph_version=1, + function_name="get_invocation_state", + graph=SerializedObject( + bytes=CloudPickleSerializer.serialize( + graph.serialize(additional_modules=[sys.modules[__name__]]) + ), + content_type=CloudPickleSerializer.content_type, + ), + ) + ) + self.assertTrue(initialize_response.success) + + def test_success(self): + with self._rpc_channel() as channel: + stub: FunctionExecutorStub = FunctionExecutorStub(channel) + self._initialize_function_executor( + self._create_graph_with_result_validation(), stub + ) + expected_requests = [ + InvocationStateRequest( + request_id="0", + task_id="test-task", + get=GetInvocationStateRequest( + key="test_state_key", + ), + ), + ] + responses = [ + InvocationStateResponse( + request_id="0", + success=True, + get=GetInvocationStateResponse( + key="test_state_key", + value=SerializedObject( + content_type=CloudPickleSerializer.content_type, + bytes=CloudPickleSerializer.serialize( + StructuredState( + string="hello", + integer=33, + structured=StructuredField( + list=[1, 2, 3], dictionary={"a": 1, "b": 2} + ), + ) + ), + ), + ), + ), + ] + client_thread = invocation_state_client_stub( + self, stub, expected_requests, responses + ) + run_task_response: RunTaskResponse = run_task(stub, 33) + self.assertTrue(run_task_response.success) + fn_outputs = deserialized_function_output( + self, run_task_response.function_output + ) + self.assertEqual(len(fn_outputs), 1) + self.assertEqual("success", fn_outputs[0]) + + print( + "Joining invocation state client thread, it should exit immediately..." + ) + client_thread.join() + + def test_success_none_value(self): + @indexify_function() + def get_invocation_state(x: int) -> str: + got_state: StructuredState = get_ctx().invocation_state.get( + "test_state_key" + ) + return "success" if got_state is None else "failure" + + graph = Graph( + name=test_graph_name(self), + description="test", + start_node=get_invocation_state, + ) + + with self._rpc_channel() as channel: + stub: FunctionExecutorStub = FunctionExecutorStub(channel) + self._initialize_function_executor(graph, stub) + expected_requests = [ + InvocationStateRequest( + request_id="0", + task_id="test-task", + get=GetInvocationStateRequest( + key="test_state_key", + ), + ), + ] + responses = [ + InvocationStateResponse( + request_id="0", + success=True, + get=GetInvocationStateResponse( + key="test_state_key", + value=None, + ), + ), + ] + client_thread = invocation_state_client_stub( + self, stub, expected_requests, responses + ) + run_task_response: RunTaskResponse = run_task(stub, 33) + self.assertTrue(run_task_response.success) + fn_outputs = deserialized_function_output( + self, run_task_response.function_output + ) + self.assertEqual(len(fn_outputs), 1) + self.assertEqual("success", fn_outputs[0]) + + print( + "Joining invocation state client thread, it should exit immediately..." + ) + client_thread.join() + + def test_client_failure(self): + with self._rpc_channel() as channel: + stub: FunctionExecutorStub = FunctionExecutorStub(channel) + self._initialize_function_executor( + self._create_graph_with_result_validation(), stub + ) + expected_requests = [ + InvocationStateRequest( + request_id="0", + task_id="test-task", + get=GetInvocationStateRequest( + key="test_state_key", + ), + ), + ] + responses = [ + InvocationStateResponse( + request_id="0", + success=False, + get=GetInvocationStateResponse(key="test_state_key"), + ), + ] + client_thread = invocation_state_client_stub( + self, stub, expected_requests, responses + ) + run_task_response: RunTaskResponse = run_task(stub, 14) + self.assertFalse(run_task_response.success) + self.assertTrue( + 'RuntimeError("failed to get the invocation state for key")' + in run_task_response.stderr + ) + + print( + "Joining invocation state client thread, it should exit immediately..." + ) + client_thread.join() + + +if __name__ == "__main__": + unittest.main() diff --git a/python-sdk/tests/function_executor/test_max_payload.py b/python-sdk/tests/function_executor/test_max_payload.py index 00c42d4ed..9feff4e9b 100644 --- a/python-sdk/tests/function_executor/test_max_payload.py +++ b/python-sdk/tests/function_executor/test_max_payload.py @@ -9,7 +9,6 @@ from indexify.function_executor.proto.function_executor_pb2 import ( InitializeRequest, InitializeResponse, - RunTaskRequest, RunTaskResponse, SerializedObject, ) @@ -18,7 +17,11 @@ ) from indexify.functions_sdk.data_objects import File from indexify.functions_sdk.object_serializer import CloudPickleSerializer -from tests.function_executor.utils import FunctionExecutorServerTestCase +from tests.function_executor.utils import ( + FunctionExecutorServerTestCase, + deserialized_function_output, + run_task, +) # Current max input and output sizes that we support. MAX_FUNCTION_PAYLOAD_SIZE_BYTES = math.floor(1.9 * 1024 * 1024 * 1024) # 1.9 GB @@ -72,7 +75,6 @@ def test_max_function_input_size(self): ) max_input_data = random_bytes(MAX_FUNCTION_PAYLOAD_SIZE_BYTES) max_input = File(data=max_input_data, sha_256=hash(max_input_data)) - serialized_max_input = CloudPickleSerializer.serialize(max_input) with self._rpc_channel() as channel: stub: FunctionExecutorStub = FunctionExecutorStub(channel) @@ -92,26 +94,13 @@ def test_max_function_input_size(self): ) self.assertTrue(initialize_response.success) - run_task_response: RunTaskResponse = stub.run_task( - RunTaskRequest( - graph_invocation_id="123", - task_id="test-task", - function_input=SerializedObject( - bytes=serialized_max_input, - content_type=CloudPickleSerializer.content_type, - ), - ) - ) - + run_task_response: RunTaskResponse = run_task(stub, max_input) self.assertTrue(run_task_response.success) self.assertFalse(run_task_response.is_reducer) - fn_outputs = [] - for output in run_task_response.function_output.outputs: - self.assertEqual( - output.content_type, CloudPickleSerializer.content_type - ) - fn_outputs.append(CloudPickleSerializer.deserialize(output.bytes)) + fn_outputs = deserialized_function_output( + self, run_task_response.function_output + ) self.assertEqual(len(fn_outputs), 1) self.assertEqual("success", fn_outputs[0]) @@ -140,26 +129,14 @@ def test_max_function_output_size(self): ) self.assertTrue(initialize_response.success) - run_task_response: RunTaskResponse = stub.run_task( - RunTaskRequest( - graph_invocation_id="123", - task_id="test-task", - function_input=SerializedObject( - bytes=CloudPickleSerializer.serialize(1), - content_type=CloudPickleSerializer.content_type, - ), - ) - ) + run_task_response: RunTaskResponse = run_task(stub, 1) self.assertTrue(run_task_response.success) self.assertFalse(run_task_response.is_reducer) - fn_outputs = [] - for output in run_task_response.function_output.outputs: - self.assertEqual( - output.content_type, CloudPickleSerializer.content_type - ) - fn_outputs.append(CloudPickleSerializer.deserialize(output.bytes)) + fn_outputs = deserialized_function_output( + self, run_task_response.function_output + ) self.assertEqual(len(fn_outputs), 1) output_file: File = fn_outputs[0] self.assertEqual(MAX_FUNCTION_PAYLOAD_SIZE_BYTES, len(output_file.data)) diff --git a/python-sdk/tests/function_executor/test_rpc_server.py b/python-sdk/tests/function_executor/test_run_task.py similarity index 74% rename from python-sdk/tests/function_executor/test_rpc_server.py rename to python-sdk/tests/function_executor/test_run_task.py index 4b3b2d52d..607b2fed8 100644 --- a/python-sdk/tests/function_executor/test_rpc_server.py +++ b/python-sdk/tests/function_executor/test_run_task.py @@ -8,7 +8,6 @@ from indexify.function_executor.proto.function_executor_pb2 import ( InitializeRequest, InitializeResponse, - RunTaskRequest, RunTaskResponse, SerializedObject, ) @@ -18,7 +17,11 @@ from indexify.functions_sdk.data_objects import File from indexify.functions_sdk.indexify_functions import indexify_function from indexify.functions_sdk.object_serializer import CloudPickleSerializer -from tests.function_executor.utils import FunctionExecutorServerTestCase +from tests.function_executor.utils import ( + FunctionExecutorServerTestCase, + deserialized_function_output, + run_task, +) @indexify_function() @@ -71,8 +74,8 @@ def create_graph_exception(): return graph -class TestRPCServer(FunctionExecutorServerTestCase): - def test_run_task_success(self): +class TestRunTask(FunctionExecutorServerTestCase): + def test_function_success(self): with self._rpc_channel() as channel: stub: FunctionExecutorStub = FunctionExecutorStub(channel) initialize_response: InitializeResponse = stub.initialize( @@ -93,34 +96,22 @@ def test_run_task_success(self): ) self.assertTrue(initialize_response.success) - run_task_response: RunTaskResponse = stub.run_task( - RunTaskRequest( - graph_invocation_id="123", - task_id="test-task", - function_input=SerializedObject( - bytes=CloudPickleSerializer.serialize( - File(data=bytes(b"hello"), mime_type="text/plain") - ), - content_type=CloudPickleSerializer.content_type, - ), - ) + run_task_response: RunTaskResponse = run_task( + stub, File(data=bytes(b"hello"), mime_type="text/plain") ) self.assertTrue(run_task_response.success) self.assertFalse(run_task_response.is_reducer) - fn_outputs = [] - for output in run_task_response.function_output.outputs: - self.assertEqual( - output.content_type, CloudPickleSerializer.content_type - ) - fn_outputs.append(CloudPickleSerializer.deserialize(output.bytes)) + fn_outputs = deserialized_function_output( + self, run_task_response.function_output + ) self.assertEqual(len(fn_outputs), 2) expected = FileChunk(data=b"hello", start=5, end=5) self.assertEqual(expected, fn_outputs[1]) - def test_run_task_extractor_raises_error(self): + def test_function_raises_error(self): with self._rpc_channel() as channel: stub: FunctionExecutorStub = FunctionExecutorStub(channel) initialize_response: InitializeResponse = stub.initialize( @@ -141,16 +132,7 @@ def test_run_task_extractor_raises_error(self): ) self.assertTrue(initialize_response.success) - run_task_response: RunTaskResponse = stub.run_task( - RunTaskRequest( - graph_invocation_id="123", - task_id="test-task", - function_input=SerializedObject( - bytes=CloudPickleSerializer.serialize(10), - content_type=CloudPickleSerializer.content_type, - ), - ) - ) + run_task_response: RunTaskResponse = run_task(stub, 10) self.assertFalse(run_task_response.success) self.assertFalse(run_task_response.is_reducer) diff --git a/python-sdk/tests/function_executor/utils.py b/python-sdk/tests/function_executor/utils.py index a5160d0ba..78f5234b0 100644 --- a/python-sdk/tests/function_executor/utils.py +++ b/python-sdk/tests/function_executor/utils.py @@ -1,10 +1,20 @@ -import os import subprocess import unittest +from typing import Any, List import grpc from indexify.function_executor.proto.configuration import GRPC_CHANNEL_OPTIONS +from indexify.function_executor.proto.function_executor_pb2 import ( + FunctionOutput, + RunTaskRequest, + RunTaskResponse, + SerializedObject, +) +from indexify.function_executor.proto.function_executor_pb2_grpc import ( + FunctionExecutorStub, +) +from indexify.functions_sdk.object_serializer import CloudPickleSerializer class FunctionExecutorServerTestCase(unittest.TestCase): @@ -23,8 +33,6 @@ def setUpClass(cls): "--dev", "--function-executor-server-address", cls.FUNCTION_EXECUTOR_SERVER_ADDRESS, - "--indexify-server-address", - os.environ.get("INDEXIFY_URL"), ] ) @@ -51,3 +59,26 @@ def _rpc_channel(self) -> grpc.Channel: self.fail( f"Failed to connect to the gRPC server within {SERVER_STARTUP_TIMEOUT_SEC} seconds: {e}" ) + + +def run_task(stub: FunctionExecutorStub, input: Any) -> RunTaskResponse: + return stub.run_task( + RunTaskRequest( + graph_invocation_id="123", + task_id="test-task", + function_input=SerializedObject( + bytes=CloudPickleSerializer.serialize(input), + content_type=CloudPickleSerializer.content_type, + ), + ) + ) + + +def deserialized_function_output( + test_case: unittest.TestCase, function_output: FunctionOutput +) -> List[Any]: + outputs: List[Any] = [] + for output in function_output.outputs: + test_case.assertEqual(output.content_type, CloudPickleSerializer.content_type) + outputs.append(CloudPickleSerializer.deserialize(output.bytes)) + return outputs diff --git a/python-sdk/tests/test_functions.py b/python-sdk/tests/test_functions.py index 425018ecd..5d350e561 100644 --- a/python-sdk/tests/test_functions.py +++ b/python-sdk/tests/test_functions.py @@ -10,9 +10,15 @@ indexify_function, indexify_router, ) +from indexify.functions_sdk.invocation_state.local_invocation_state import ( + LocalInvocationState, +) TEST_GRAPH_CTX = GraphInvocationContext( - invocation_id="123", graph_name="test", graph_version="1" + invocation_id="123", + graph_name="test", + graph_version="1", + invocation_state=LocalInvocationState(), ) @@ -90,8 +96,8 @@ def test_get_ctx(self): @indexify_function() def extractor_c(url: str) -> str: ctx = get_ctx() # type: ignore - ctx.set_state_key("foo", "bar") - foo_val = ctx.get_state_key("foo") + ctx.invocation_state.set("foo", "bar") + foo_val = ctx.invocation_state.get("foo") return ctx.invocation_id extractor_wrapper = IndexifyFunctionWrapper(extractor_c, TEST_GRAPH_CTX) diff --git a/python-sdk/tests/test_graph_behaviours.py b/python-sdk/tests/test_graph_behaviours.py index 2a57ff31e..1b9226032 100644 --- a/python-sdk/tests/test_graph_behaviours.py +++ b/python-sdk/tests/test_graph_behaviours.py @@ -66,7 +66,7 @@ class ComplexObject(BaseModel): @indexify_function() def simple_function_ctx(x: MyObject) -> ComplexObject: ctx = get_ctx() - ctx.set_state_key("my_key", 10) + ctx.invocation_state.set("my_key", 10) return ComplexObject( invocation_id=ctx.invocation_id, graph_name=ctx.graph_name, @@ -77,7 +77,7 @@ def simple_function_ctx(x: MyObject) -> ComplexObject: @indexify_function() def simple_function_ctx_b(x: ComplexObject) -> int: ctx = get_ctx() - val = ctx.get_state_key("my_key") + val = ctx.invocation_state.get("my_key") return val + 1 @@ -90,9 +90,9 @@ def __init__(self): def run(self, x: ComplexObject) -> int: ctx = get_ctx() print(f"ctx: {ctx}") - val = ctx.get_state_key("my_key") + val = ctx.invocation_state.get("my_key") assert val == 10 - not_present = ctx.get_state_key("not_present") + not_present = ctx.invocation_state.get("not_present") assert not_present is None return val + 1 diff --git a/server/src/http_objects.rs b/server/src/http_objects.rs index d311b9936..a8b67e4d6 100644 --- a/server/src/http_objects.rs +++ b/server/src/http_objects.rs @@ -554,28 +554,6 @@ pub struct InvocationQueryParams { pub block_until_finish: Option, } -#[derive(Debug, Serialize, Deserialize, ToSchema)] -pub struct CtxStatePutRequest { - pub key: String, - - // Could have encoded this as string but - // making sure we get valid json from user - // code - pub value: serde_json::Value, -} - -#[derive(Debug, Serialize, Deserialize, ToSchema)] -pub struct CtxStateGetRequest { - /// The key to retrieve, and if none, return all keys - pub key: String, -} - -#[derive(Debug, Serialize, Deserialize, ToSchema)] -pub struct CtxStateGetResponse { - // Values indexed by key names - pub value: Option, -} - #[cfg(test)] mod tests { use crate::http_objects::{ComputeFn, DynamicRouter}; diff --git a/server/src/routes.rs b/server/src/routes.rs index a4a592eec..53d408cb7 100644 --- a/server/src/routes.rs +++ b/server/src/routes.rs @@ -26,7 +26,7 @@ use metrics::api_io_stats; use nanoid::nanoid; use prometheus::Encoder; use state_store::{ - kv::{WriteContextData, KVS}, + kv::{ReadContextData, WriteContextData, KVS}, requests::{ CreateOrUpdateComputeGraphRequest, DeleteComputeGraphRequest, @@ -42,10 +42,7 @@ use tracing::{error, info}; use utoipa::{OpenApi, ToSchema}; use utoipa_swagger_ui::SwaggerUi; -use crate::{ - executors::{self, EXECUTOR_TIMEOUT}, - http_objects::{CtxStateGetRequest, CtxStateGetResponse, CtxStatePutRequest}, -}; +use crate::executors::{self, EXECUTOR_TIMEOUT}; mod download; mod internal_ingest; @@ -195,11 +192,11 @@ pub fn create_routes(route_state: RouteState) -> Router { get(download_fn_output_by_key).with_state(route_state.clone()), ) .route( - "/internal/namespaces/:namespace/compute_graphs/:compute_graph/invocations/:invocation_id/ctx", + "/internal/namespaces/:namespace/compute_graphs/:compute_graph/invocations/:invocation_id/ctx/:name", post(set_ctx_state_key).with_state(route_state.clone()), ) .route( - "/internal/namespaces/:namespace/compute_graphs/:compute_graph/invocations/:invocation_id/ctx", + "/internal/namespaces/:namespace/compute_graphs/:compute_graph/invocations/:invocation_id/ctx/:name", get(get_ctx_state_key).with_state(route_state.clone()), ) .layer(OtelInResponseLayer::default()) @@ -888,18 +885,47 @@ async fn get_versioned_code( } async fn set_ctx_state_key( - Path((namespace, compute_graph, invocation_id)): Path<(String, String, String)>, + Path((namespace, compute_graph, invocation_id, key)): Path<(String, String, String, String)>, State(state): State, - Json(payload): Json, + mut values: Multipart, ) -> Result<(), IndexifyAPIError> { - let request = WriteContextData { + let mut request: WriteContextData = WriteContextData { namespace, compute_graph, invocation_id, - key: payload.key, - value: serde_json::to_vec(&payload.value) - .map_err(|e| IndexifyAPIError::bad_request(&e.to_string()))?, + key, + value: vec![], }; + + while let Some(field) = values.next_field().await.unwrap() { + if let Some(name) = field.name() { + if name == "value" { + let content_type: &str = field.content_type().ok_or_else(|| { + IndexifyAPIError::bad_request("content-type of the value is required") + })?; + if content_type != "application/octet-stream" { + // Server doesn't support flexible client controlled content-type yet because + // we don't yet store content-type in the kv store. + return Err(IndexifyAPIError::bad_request( + "only 'application/octet-stream' content-type is currently supported", + )); + } + request.value = field + .bytes() + .await + .map_err(|e| { + IndexifyAPIError::internal_error(anyhow!("failed reading the value: {}", e)) + })? + .to_vec(); + } else { + return Err(IndexifyAPIError::bad_request(&format!( + "unexpected field: {}", + name + ))); + } + } + } + state .kvs .put_ctx_state(request) @@ -909,17 +935,33 @@ async fn set_ctx_state_key( } async fn get_ctx_state_key( - Path((namespace, compute_graph, invocation_id)): Path<(String, String, String)>, + Path((namespace, compute_graph, invocation_id, key)): Path<(String, String, String, String)>, State(state): State, - Json(request): Json, -) -> Result, IndexifyAPIError> { +) -> Result, IndexifyAPIError> { let value = state .kvs - .get_ctx_state_key(&namespace, &compute_graph, &invocation_id, &request.key) + .get_ctx_state_key(ReadContextData { + namespace, + compute_graph, + invocation_id, + key, + }) .await .map_err(IndexifyAPIError::internal_error)?; - let value = value.map(|v| serde_json::from_slice(&v).unwrap()); - Ok(Json(CtxStateGetResponse { value })) + match value { + Some(value) => Response::builder() + .header("Content-Type", "application/octet-stream") + .header("Content-Length", value.len().to_string()) + .body(Body::from(value)) + .map_err(|e| { + tracing::error!("failed streaming get ctx response: {:?}", e); + IndexifyAPIError::internal_error_str("failed streaming the response") + }), + None => Ok(Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::empty()) + .unwrap()), + } } #[axum::debug_handler] diff --git a/server/state_store/src/kv.rs b/server/state_store/src/kv.rs index dbbc52cd7..13025eca4 100644 --- a/server/state_store/src/kv.rs +++ b/server/state_store/src/kv.rs @@ -14,6 +14,14 @@ pub struct WriteContextData { pub key: String, pub value: Vec, } + +pub struct ReadContextData { + pub namespace: String, + pub compute_graph: String, + pub invocation_id: String, + pub key: String, +} + pub struct KVS { kv_store: Arc, metrics: Metrics, @@ -47,17 +55,14 @@ impl KVS { Ok(()) } - pub async fn get_ctx_state_key( - &self, - namespace: &str, - compute_graph: &str, - invocation_id: &str, - key: &str, - ) -> Result> { + pub async fn get_ctx_state_key(&self, req: ReadContextData) -> Result> { let timer_kvs = &[KeyValue::new("op", "get_ctx_state_key")]; let _timer = Timer::start_with_labels(&self.metrics.reads, timer_kvs); - let key = format!("{}|{}|{}|{}", namespace, compute_graph, invocation_id, key); + let key = format!( + "{}|{}|{}|{}", + req.namespace, req.compute_graph, req.invocation_id, req.key + ); let value = self.kv_store.get(key.as_bytes()).await?; Ok(value) }