-
Notifications
You must be signed in to change notification settings - Fork 121
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Invocation state proxy for Function Executors and new API for functions
We don't want to have Indexify Server API credentials in Function Executors because it's a security risk if Function Executors run not trusted code. To achieve that Executor provides a gRPC API to Function Executor to get/set graph invocation state. Once a request is recieved Executor finds invocation ID of the requesting Function Executor and does Indexify Server API call to get/set and invocation state key. There are a few customer function facing API changes here: Instead of calling `get_ctx().set_state_key()` and `get_ctx().get_state_key()` they call `get_ctx().invocation_state.get()` and `get_ctx().invocation_state.set()`. This intoduces a separate API for graph invocation state which is more clear than adding unbounded number of methods into object returned by `get_ctx()`. Another change is that `set(key, value)` now accepts any value that is serializable using CloudPickle. This is provides a consisten "no surprises" UX because we're using CloudPickle for function inputs and outputs. Supporting this required small changes on Indexify Server side as previous JSON was explicitly used at HTTP protocol and storage layers. Now it supports arbitrary binary and textual formats for invocation state values. Finally `get(key)` now returns Optional[Any] so None is returned if the key wasn't set yet. This allows the customer code to decide what to do in this case. Testing: make fmt make test
- Loading branch information
Showing
30 changed files
with
1,336 additions
and
236 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
212 changes: 212 additions & 0 deletions
212
python-sdk/indexify/executor/function_executor/invocation_state_client.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
import asyncio | ||
from typing import Any, AsyncGenerator, Optional, Union | ||
|
||
import httpx | ||
|
||
from indexify.executor.downloader import serialized_object_from_http_response | ||
from indexify.function_executor.proto.function_executor_pb2 import ( | ||
GetInvocationStateResponse, | ||
InvocationStateRequest, | ||
InvocationStateResponse, | ||
SerializedObject, | ||
SetInvocationStateResponse, | ||
) | ||
from indexify.function_executor.proto.function_executor_pb2_grpc import ( | ||
FunctionExecutorStub, | ||
) | ||
from indexify.function_executor.proto.message_validator import MessageValidator | ||
|
||
|
||
class InvocationStateClient: | ||
"""InvocationStateClient is a client for the invocation state server of a Function Executor. | ||
The client initializes the Function Executor's invocation state server and executes requests | ||
it sends to the client. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
stub: FunctionExecutorStub, | ||
base_url: str, | ||
http_client: httpx.AsyncClient, | ||
graph: str, | ||
namespace: str, | ||
logger: Any, | ||
): | ||
self._stub: FunctionExecutorStub = stub | ||
self._base_url: str = base_url | ||
self._http_client: httpx.AsyncClient = http_client | ||
self._graph: str = graph | ||
self._namespace: str = namespace | ||
self._logger: Any = logger.bind( | ||
module=__name__, graph=graph, namespace=namespace | ||
) | ||
self._client_response_queue: asyncio.Queue[ | ||
Union[InvocationStateResponse, str] | ||
] = asyncio.Queue() | ||
self._task_id_to_invocation_id: dict[str, str] = {} | ||
self._request_loop_task: Optional[asyncio.Task] = None | ||
|
||
async def start(self) -> None: | ||
"""Starts the invocation state client. | ||
This method initializes the Function Executor's invocation state server first. | ||
This is why this method needs to be awaited before executing any tasks on the Function Executor | ||
that might use invocation state feature.""" | ||
server_requests = self._stub.initialize_invocation_state_server( | ||
self._response_generator() | ||
) | ||
self._request_loop_task = asyncio.create_task( | ||
self._request_loop(server_requests) | ||
) | ||
|
||
def add_task_to_invocation_id_entry(self, task_id: str, invocation_id: str) -> None: | ||
"""Adds a task ID to invocation ID entry to the client's internal state. | ||
This allows to authorize requests to the invocation state server. | ||
If a request is not comming from the task ID that was added here then it will | ||
be rejected. It's caller's responsibility to only add task IDs that are being | ||
executed by the Function Executor so the Function Executor can't get access to | ||
invocation state of tasks it doesn't run.""" | ||
self._task_id_to_invocation_id[task_id] = invocation_id | ||
|
||
def remove_task_to_invocation_id_entry(self, task_id: str) -> None: | ||
del self._task_id_to_invocation_id[task_id] | ||
|
||
async def destroy(self) -> None: | ||
if self._request_loop_task is not None: | ||
self._request_loop_task.cancel() | ||
await self._client_response_queue.put("shutdown") | ||
|
||
async def _request_loop( | ||
self, server_requests: AsyncGenerator[InvocationStateRequest] | ||
) -> None: | ||
async for request in server_requests: | ||
request: InvocationStateRequest | ||
try: | ||
self._validate_request(request) | ||
response: InvocationStateResponse = await self._process_request(request) | ||
await self._client_response_queue.put(response) | ||
except Exception as e: | ||
await self._client_response_queue.put( | ||
InvocationStateResponse( | ||
request_id=request.request_id, | ||
success=False, | ||
) | ||
) | ||
self._logger.error( | ||
"failed to process request", | ||
exc_info=e, | ||
request_id=request.request_id, | ||
) | ||
|
||
def _validate_request(self, request: InvocationStateRequest) -> None: | ||
( | ||
MessageValidator(request) | ||
.required_field("request_id") | ||
.required_field("task_id") | ||
) | ||
if request.HasField("get"): | ||
(MessageValidator(request.get).required_field("key")) | ||
elif request.HasField("set"): | ||
( | ||
MessageValidator(request.set) | ||
.required_field("key") | ||
.required_serialized_object("value") | ||
) | ||
else: | ||
raise ValueError("unknown request type") | ||
|
||
async def _process_request( | ||
self, request: InvocationStateRequest | ||
) -> InvocationStateResponse: | ||
# This is a very important check. We don't trust invocation ID and task ID | ||
# supplied by Function Executor. If a task ID entry doesn't exist then it's | ||
# a privelege escalation attempt. | ||
invocation_id: str = self._task_id_to_invocation_id[request.task_id] | ||
if request.HasField("get"): | ||
value: Optional[SerializedObject] = await self._get_server_state( | ||
invocation_id, request.get.key | ||
) | ||
return InvocationStateResponse( | ||
request_id=request.request_id, | ||
success=True, | ||
get=GetInvocationStateResponse( | ||
key=request.get.key, | ||
value=value, | ||
), | ||
) | ||
elif request.HasField("set"): | ||
await self._set_server_state( | ||
invocation_id, request.set.key, request.set.value | ||
) | ||
return InvocationStateResponse( | ||
request_id=request.request_id, | ||
success=True, | ||
set=SetInvocationStateResponse(), | ||
) | ||
|
||
async def _response_generator(self) -> AsyncGenerator[InvocationStateResponse]: | ||
while True: | ||
response = await self._client_response_queue.get() | ||
# Hacky cancellation of the generator. | ||
if response == "shutdown": | ||
break | ||
yield response | ||
|
||
async def _set_server_state( | ||
self, invocation_id: str, key: str, value: SerializedObject | ||
) -> None: | ||
url: str = ( | ||
f"{self._base_url}/internal/namespaces/{self._namespace}/compute_graphs/{self._graph}/invocations/{invocation_id}/ctx/{key}" | ||
) | ||
payload = value.bytes if value.HasField("bytes") else value.string | ||
|
||
response = await self._http_client.post( | ||
url=url, | ||
files=[ | ||
( | ||
"value", | ||
("value", payload, value.content_type), | ||
), | ||
], | ||
) | ||
|
||
try: | ||
response.raise_for_status() | ||
except Exception as e: | ||
self._logger.error( | ||
"failed to set graph invocation state", | ||
invocation_id=invocation_id, | ||
key=key, | ||
status_code=response.status_code, | ||
error=response.text, | ||
exc_info=e, | ||
) | ||
raise | ||
|
||
async def _get_server_state( | ||
self, invocation_id: str, key: str | ||
) -> Optional[SerializedObject]: | ||
url: str = ( | ||
f"{self._base_url}/internal/namespaces/{self._namespace}/compute_graphs/{self._graph}/invocations/{invocation_id}/ctx/{key}" | ||
) | ||
|
||
response: httpx.Response = await self._http_client.get(url) | ||
if response.status_code == 404: | ||
return None | ||
|
||
try: | ||
response.raise_for_status() | ||
except httpx.HTTPStatusError as e: | ||
self._logger.error( | ||
f"failed to download graph invocation state value", | ||
invocation_id=invocation_id, | ||
key=key, | ||
status_code=response.status_code, | ||
error=response.text, | ||
exc_info=e, | ||
) | ||
raise | ||
|
||
return serialized_object_from_http_response(response) |
Oops, something went wrong.