Skip to content

Commit

Permalink
Invocation state proxy for Function Executors and new API for functions
Browse files Browse the repository at this point in the history
We don't want to have Indexify Server API credentials in Function Executors
because it's a security risk if Function Executors run not trusted code.
To achieve that Executor provides a gRPC API to Function Executor to
get/set graph invocation state. Once a request is recieved Executor
finds invocation ID of the requesting Function Executor and does
Indexify Server API call to get/set and invocation state key.

There are a few customer function facing API changes here:

Instead of calling `get_ctx().set_state_key()` and
`get_ctx().get_state_key()` they call
`get_ctx().invocation_state.get()` and
`get_ctx().invocation_state.set()`.

This intoduces a separate API for graph invocation state which is more
clear than adding unbounded number of methods into object returned
by `get_ctx()`.

Another change is that `set(key, value)` now accepts any value
that is serializable using CloudPickle. This is provides a consisten
"no surprises" UX because we're using CloudPickle for function inputs
and outputs. Supporting this required small changes on Indexify Server
side as previous JSON was explicitly used at HTTP protocol and storage
layers. Now it supports arbitrary binary and textual formats for
invocation state values.

Finally `get(key)` now returns Optional[Any] so None is returned
if the key wasn't set yet. This allows the customer code to decide what
to do in this case.

Testing:

make fmt
make test
  • Loading branch information
eabatalov committed Dec 20, 2024
1 parent 208c2e8 commit f0b1267
Show file tree
Hide file tree
Showing 30 changed files with 1,336 additions and 236 deletions.
15 changes: 4 additions & 11 deletions python-sdk/indexify/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@
console = Console(theme=custom_theme)

app = typer.Typer(pretty_exceptions_enable=False, no_args_is_help=True)
config_path_option: Optional[str] = typer.Option(
None, help="Path to the TLS configuration file"
)


@app.command(
Expand Down Expand Up @@ -186,7 +183,9 @@ def executor(
dev: Annotated[
bool, typer.Option("--dev", "-d", help="Run the executor in development mode")
] = False,
config_path: Optional[str] = config_path_option,
config_path: Optional[str] = typer.Option(
None, help="Path to the TLS configuration file"
),
executor_cache: Optional[str] = typer.Option(
"~/.indexify/executor_cache", help="Path to the executor cache directory"
),
Expand Down Expand Up @@ -242,27 +241,21 @@ def function_executor(
function_executor_server_address: str = typer.Option(
help="Function Executor server address"
),
indexify_server_address: str = typer.Option(help="Indexify server address"),
dev: Annotated[
bool, typer.Option("--dev", "-d", help="Run the executor in development mode")
] = False,
config_path: Optional[str] = config_path_option,
):
if not dev:
configure_production_logging()

logger.info(
"starting function executor server",
function_executor_server_address=function_executor_server_address,
indexify_server_address=indexify_server_address,
config_path=config_path,
)

FunctionExecutorServer(
server_address=function_executor_server_address,
service=FunctionExecutorService(
indexify_server_address=indexify_server_address, config_path=config_path
),
service=FunctionExecutorService(),
).run()


Expand Down
26 changes: 15 additions & 11 deletions python-sdk/indexify/executor/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ async def _fetch_url(
self, url: str, resource_description: str, logger: Any
) -> SerializedObject:
logger.info(f"fetching {resource_description}", url=url)
response = await self._client.get(url)
response: httpx.Response = await self._client.get(url)
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
Expand All @@ -153,13 +153,17 @@ async def _fetch_url(
)
raise

# We're hardcoding the content type currently used by Python SDK. It might change in the future.
# There's no other way for now to determine if the response is a bytes or string.
if response.headers["content-type"] == "application/octet-stream":
return SerializedObject(
bytes=response.content, content_type=response.headers["content-type"]
)
else:
return SerializedObject(
string=response.text, content_type=response.headers["content-type"]
)
return serialized_object_from_http_response(response)


def serialized_object_from_http_response(response: httpx.Response) -> SerializedObject:
# We're hardcoding the content type currently used by Python SDK. It might change in the future.
# There's no other way for now to determine if the response is a bytes or string.
if response.headers["content-type"] == "application/octet-stream":
return SerializedObject(
bytes=response.content, content_type=response.headers["content-type"]
)
else:
return SerializedObject(
string=response.text, content_type=response.headers["content-type"]
)
12 changes: 6 additions & 6 deletions python-sdk/indexify/executor/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ def __init__(
self._logger.info("running the extractor with TLS enabled")
protocol = "https"

self._server_addr = server_addr
self._base_url = f"{protocol}://{self._server_addr}"
self._code_path = code_path
self._function_worker = FunctionWorker(
function_executor_factory=ProcessFunctionExecutorFactory(
indexify_server_address=server_addr,
development_mode=development_mode,
config_path=config_path,
)
),
base_url=self._base_url,
config_path=config_path,
)
self._server_addr = server_addr
self._base_url = f"{protocol}://{self._server_addr}"
self._code_path = code_path
self._downloader = Downloader(
code_path=code_path, base_url=self._base_url, config_path=config_path
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import grpc

from indexify.common_util import get_httpx_client
from indexify.function_executor.proto.function_executor_pb2 import (
InitializeRequest,
InitializeResponse,
Expand All @@ -13,6 +14,7 @@

from .function_executor import FunctionExecutor
from .function_executor_factory import FunctionExecutorFactory
from .invocation_state_client import InvocationStateClient


class FunctionExecutorMap:
Expand All @@ -21,10 +23,17 @@ class FunctionExecutorMap:
The map is safe to use by multiple couroutines running in event loop on the same thread
but it's not thread safe (can't be used from different threads concurrently)."""

def __init__(self, factory: FunctionExecutorFactory):
def __init__(
self,
factory: FunctionExecutorFactory,
base_url: str,
config_path: Optional[str],
):
self._factory = factory
self._base_url = base_url
self._config_path = config_path
# Map of initialized Function executors ready to run tasks.
# function ID -> FunctionExecutor
# Function ID -> FunctionExecutor.
self._executors: Dict[str, FunctionExecutor] = {}
# We have to do all operations under this lock because we need to ensure
# that we don't create more Function Executors than required. This is important
Expand Down Expand Up @@ -52,6 +61,7 @@ async def get_or_create(
return self._executors[id]

executor: Optional[FunctionExecutor] = None
invocation_state_client: Optional[InvocationStateClient] = None
try:
executor = await self._factory.create(logger, state=initial_state)
channel: grpc.aio.Channel = await executor.channel()
Expand All @@ -61,7 +71,22 @@ async def get_or_create(
)
if not initialize_response.success:
raise Exception("initialize RPC failed at function executor")
invocation_state_client = InvocationStateClient(
stub=stub,
base_url=self._base_url,
http_client=get_httpx_client(
config_path=self._config_path, make_async=True
),
graph=initialize_request.graph_name,
namespace=initialize_request.namespace,
logger=logger,
)
await invocation_state_client.start()
# This is dirty but requires refactoring to implement properly.
initial_state.invocation_state_client = invocation_state_client
except Exception:
if invocation_state_client is not None:
await invocation_state_client.destroy()
if executor is not None:
await self._factory.destroy(executor=executor, logger=logger)
# Function Executor creation or initialization failed.
Expand All @@ -82,6 +107,8 @@ async def delete(
# Function Executor was already deleted or replaced and the caller is not aware of this.
return
del self._executors[id]
if function_executor.state().invocation_state_client is not None:
await function_executor.state().invocation_state_client.destroy()
await self._factory.destroy(executor=function_executor, logger=logger)

async def clear(self, logger):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import asyncio
from typing import Any, AsyncGenerator, Optional, Union

import httpx

from indexify.executor.downloader import serialized_object_from_http_response
from indexify.function_executor.proto.function_executor_pb2 import (
GetInvocationStateResponse,
InvocationStateRequest,
InvocationStateResponse,
SerializedObject,
SetInvocationStateResponse,
)
from indexify.function_executor.proto.function_executor_pb2_grpc import (
FunctionExecutorStub,
)
from indexify.function_executor.proto.message_validator import MessageValidator


class InvocationStateClient:
"""InvocationStateClient is a client for the invocation state server of a Function Executor.
The client initializes the Function Executor's invocation state server and executes requests
it sends to the client.
"""

def __init__(
self,
stub: FunctionExecutorStub,
base_url: str,
http_client: httpx.AsyncClient,
graph: str,
namespace: str,
logger: Any,
):
self._stub: FunctionExecutorStub = stub
self._base_url: str = base_url
self._http_client: httpx.AsyncClient = http_client
self._graph: str = graph
self._namespace: str = namespace
self._logger: Any = logger.bind(
module=__name__, graph=graph, namespace=namespace
)
self._client_response_queue: asyncio.Queue[
Union[InvocationStateResponse, str]
] = asyncio.Queue()
self._task_id_to_invocation_id: dict[str, str] = {}
self._request_loop_task: Optional[asyncio.Task] = None

async def start(self) -> None:
"""Starts the invocation state client.
This method initializes the Function Executor's invocation state server first.
This is why this method needs to be awaited before executing any tasks on the Function Executor
that might use invocation state feature."""
server_requests = self._stub.initialize_invocation_state_server(
self._response_generator()
)
self._request_loop_task = asyncio.create_task(
self._request_loop(server_requests)
)

def add_task_to_invocation_id_entry(self, task_id: str, invocation_id: str) -> None:
"""Adds a task ID to invocation ID entry to the client's internal state.
This allows to authorize requests to the invocation state server.
If a request is not comming from the task ID that was added here then it will
be rejected. It's caller's responsibility to only add task IDs that are being
executed by the Function Executor so the Function Executor can't get access to
invocation state of tasks it doesn't run."""
self._task_id_to_invocation_id[task_id] = invocation_id

def remove_task_to_invocation_id_entry(self, task_id: str) -> None:
del self._task_id_to_invocation_id[task_id]

async def destroy(self) -> None:
if self._request_loop_task is not None:
self._request_loop_task.cancel()
await self._client_response_queue.put("shutdown")

async def _request_loop(
self, server_requests: AsyncGenerator[InvocationStateRequest]
) -> None:
async for request in server_requests:
request: InvocationStateRequest
try:
self._validate_request(request)
response: InvocationStateResponse = await self._process_request(request)
await self._client_response_queue.put(response)
except Exception as e:
await self._client_response_queue.put(
InvocationStateResponse(
request_id=request.request_id,
success=False,
)
)
self._logger.error(
"failed to process request",
exc_info=e,
request_id=request.request_id,
)

def _validate_request(self, request: InvocationStateRequest) -> None:
(
MessageValidator(request)
.required_field("request_id")
.required_field("task_id")
)
if request.HasField("get"):
(MessageValidator(request.get).required_field("key"))
elif request.HasField("set"):
(
MessageValidator(request.set)
.required_field("key")
.required_serialized_object("value")
)
else:
raise ValueError("unknown request type")

async def _process_request(
self, request: InvocationStateRequest
) -> InvocationStateResponse:
# This is a very important check. We don't trust invocation ID and task ID
# supplied by Function Executor. If a task ID entry doesn't exist then it's
# a privelege escalation attempt.
invocation_id: str = self._task_id_to_invocation_id[request.task_id]
if request.HasField("get"):
value: Optional[SerializedObject] = await self._get_server_state(
invocation_id, request.get.key
)
return InvocationStateResponse(
request_id=request.request_id,
success=True,
get=GetInvocationStateResponse(
key=request.get.key,
value=value,
),
)
elif request.HasField("set"):
await self._set_server_state(
invocation_id, request.set.key, request.set.value
)
return InvocationStateResponse(
request_id=request.request_id,
success=True,
set=SetInvocationStateResponse(),
)

async def _response_generator(self) -> AsyncGenerator[InvocationStateResponse]:
while True:
response = await self._client_response_queue.get()
# Hacky cancellation of the generator.
if response == "shutdown":
break
yield response

async def _set_server_state(
self, invocation_id: str, key: str, value: SerializedObject
) -> None:
url: str = (
f"{self._base_url}/internal/namespaces/{self._namespace}/compute_graphs/{self._graph}/invocations/{invocation_id}/ctx/{key}"
)
payload = value.bytes if value.HasField("bytes") else value.string

response = await self._http_client.post(
url=url,
files=[
(
"value",
("value", payload, value.content_type),
),
],
)

try:
response.raise_for_status()
except Exception as e:
self._logger.error(
"failed to set graph invocation state",
invocation_id=invocation_id,
key=key,
status_code=response.status_code,
error=response.text,
exc_info=e,
)
raise

async def _get_server_state(
self, invocation_id: str, key: str
) -> Optional[SerializedObject]:
url: str = (
f"{self._base_url}/internal/namespaces/{self._namespace}/compute_graphs/{self._graph}/invocations/{invocation_id}/ctx/{key}"
)

response: httpx.Response = await self._http_client.get(url)
if response.status_code == 404:
return None

try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
self._logger.error(
f"failed to download graph invocation state value",
invocation_id=invocation_id,
key=key,
status_code=response.status_code,
error=response.text,
exc_info=e,
)
raise

return serialized_object_from_http_response(response)
Loading

0 comments on commit f0b1267

Please sign in to comment.