Skip to content

Commit

Permalink
Clear send objects and futures when calling Client.clear().
Browse files Browse the repository at this point in the history
Signed-off-by: rafa-be <[email protected]>
  • Loading branch information
rafa-be committed Nov 20, 2024
1 parent 35c0a9a commit b8e147b
Show file tree
Hide file tree
Showing 11 changed files with 112 additions and 26 deletions.
7 changes: 6 additions & 1 deletion scaler/client/agent/client_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from scaler.client.serializer.mixins import Serializer
from scaler.io.async_connector import AsyncConnector
from scaler.protocol.python.message import (
ClientClearRequest,
ClientDisconnect,
ClientHeartbeatEcho,
ClientShutdownResponse,
Expand Down Expand Up @@ -140,6 +141,10 @@ async def __on_receive_from_client(self, message: Message):
await self._task_manager.on_cancel_graph_task(message)
return

if isinstance(message, ClientClearRequest):
await self._object_manager.on_client_clear_request(message)
return

raise TypeError(f"Unknown {message=}")

async def __on_receive_from_scheduler(self, message: Message):
Expand Down Expand Up @@ -176,7 +181,7 @@ async def __get_loops(self):
finally:
self._stop_event.set() # always set the stop event before setting futures' exceptions

await self._object_manager.clean_all_objects()
await self._object_manager.clear_all_objects(clear_serializer=True)

self._connector_external.destroy()
self._connector_internal.destroy()
Expand Down
10 changes: 8 additions & 2 deletions scaler/client/agent/future_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from scaler.client.serializer.mixins import Serializer
from scaler.io.utility import concat_list_of_bytes
from scaler.protocol.python.common import TaskStatus
from scaler.protocol.python.message import ObjectResponse, TaskResult
from scaler.protocol.python.message import ObjectResponse, TaskCancel, TaskResult
from scaler.utility.exceptions import DisconnectedError, NoWorkerError, TaskNotFoundError, WorkerDiedError
from scaler.utility.metadata.profile_result import retrieve_profiling_result_from_task_result
from scaler.utility.object_utility import deserialize_failure
Expand All @@ -34,6 +34,8 @@ def cancel_all_futures(self):
for task_id, future in self._task_id_to_future.items():
future.cancel()

self._task_id_to_future.clear()

def set_all_futures_with_exception(self, exception: Exception):
with self._lock:
for future in self._task_id_to_future.values():
Expand All @@ -42,7 +44,7 @@ def set_all_futures_with_exception(self, exception: Exception):
except InvalidStateError:
continue # Future got canceled

self._task_id_to_future = dict()
self._task_id_to_future.clear()

def on_task_result(self, result: TaskResult):
with self._lock:
Expand Down Expand Up @@ -94,6 +96,10 @@ def on_task_result(self, result: TaskResult):
except InvalidStateError:
return # Future got canceled

def on_cancel_task(self, task_cancel: TaskCancel):
    """Stop tracking the future associated with a task that is being cancelled.

    The entry for ``task_cancel.task_id`` is removed from the task-id-to-future map while holding the manager's
    lock. A task id that is not tracked is silently ignored.
    """
    cancelled_task_id = task_cancel.task_id

    with self._lock:
        # A default is passed to pop() so that an unknown task id does not raise KeyError.
        self._task_id_to_future.pop(cancelled_task_id, None)

def on_object_response(self, response: ObjectResponse):
for object_id, object_name, object_bytes in zip(
response.object_content.object_ids,
Expand Down
9 changes: 7 additions & 2 deletions scaler/client/agent/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ObjectRequest,
ObjectResponse,
Task,
TaskCancel,
TaskResult,
)

Expand Down Expand Up @@ -40,11 +41,11 @@ async def on_object_request(self, request: ObjectRequest):
raise NotImplementedError()

@abc.abstractmethod
def record_task_result(self, task_id: bytes, object_id: bytes):
def on_task_result(self, result: TaskResult):
raise NotImplementedError()

@abc.abstractmethod
async def clean_all_objects(self):
async def clear_all_objects(self, clear_serializer: bool):
raise NotImplementedError()


Expand Down Expand Up @@ -79,6 +80,10 @@ def set_all_futures_with_exception(self, exception: Exception):
def on_task_result(self, result: TaskResult):
raise NotImplementedError()

@abc.abstractmethod
def on_cancel_task(self, task_cancel: TaskCancel):
    """Handle the cancellation of a single task, described by the given ``TaskCancel`` message."""
    raise NotImplementedError()

@abc.abstractmethod
def on_object_response(self, response: ObjectResponse):
raise NotImplementedError()
Expand Down
47 changes: 41 additions & 6 deletions scaler/client/agent/object_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
from scaler.client.agent.mixins import ObjectManager
from scaler.io.async_connector import AsyncConnector
from scaler.protocol.python.common import ObjectContent
from scaler.protocol.python.message import ObjectInstruction, ObjectRequest
from scaler.protocol.python.message import (
ClientClearRequest,
ObjectInstruction,
ObjectRequest,
TaskResult,
)


class ClientObjectManager(ObjectManager):
def __init__(self, identity: bytes):
self._sent_object_ids: Set[bytes] = set()
self._sent_serializer_id: Optional[bytes] = None

self._identity = identity

self._connector_internal: Optional[AsyncConnector] = None
Expand All @@ -28,18 +35,34 @@ async def on_object_request(self, object_request: ObjectRequest):
assert object_request.request_type == ObjectRequest.ObjectRequestType.Get
await self._connector_external.send(object_request)

def record_task_result(self, task_id: bytes, object_id: bytes):
self._sent_object_ids.add(object_id)
def on_task_result(self, task_result: TaskResult):
# TODO: received result objects should be deleted from the scheduler when no longer needed.
# This requires not deleting objects that are still needed by not-yet-computed dependent graph tasks.
# For now, we just remove the objects when the client makes a clear request, or on client shutdown.
# https://github.com/Citi/scaler/issues/43

self._sent_object_ids.update(task_result.results)

async def on_client_clear_request(self, client_clear_request: ClientClearRequest):
    """Handle a ``ClientClearRequest`` by clearing the tracked objects.

    No field of the request is read here. ``clear_serializer=False`` is passed to ``clear_all_objects()`` so that
    the serializer is kept.
    """
    await self.clear_all_objects(clear_serializer=False)

async def clear_all_objects(self, clear_serializer):
cleared_object_ids = self._sent_object_ids.copy()

if clear_serializer:
self._sent_serializer_id = None
elif self._sent_serializer_id is not None:
cleared_object_ids.remove(self._sent_serializer_id)

self._sent_object_ids.difference_update(cleared_object_ids)

async def clean_all_objects(self):
await self._connector_external.send(
ObjectInstruction.new_msg(
ObjectInstruction.ObjectInstructionType.Delete,
self._identity,
ObjectContent.new_msg(tuple(self._sent_object_ids)),
ObjectContent.new_msg(tuple(cleared_object_ids)),
)
)
self._sent_object_ids = set()

async def __send_object_creation(self, instruction: ObjectInstruction):
assert instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Create
Expand All @@ -48,6 +71,13 @@ async def __send_object_creation(self, instruction: ObjectInstruction):
if not new_object_ids:
return

if b"serializer" in instruction.object_content.object_names:
if self._sent_serializer_id is not None:
raise ValueError("trying to send multiple serializers.")

serializer_index = instruction.object_content.object_names.index(b"serializer")
self._sent_serializer_id = instruction.object_content.object_ids[serializer_index]

new_object_content = ObjectContent.new_msg(
*zip(
*filter(
Expand All @@ -71,5 +101,10 @@ async def __send_object_creation(self, instruction: ObjectInstruction):

async def __delete_objects(self, instruction: ObjectInstruction):
assert instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Delete

if self._sent_serializer_id in instruction.object_content.object_ids:
raise ValueError("trying to delete serializer.")

self._sent_object_ids.difference_update(instruction.object_content.object_ids)

await self._connector_external.send(instruction)
12 changes: 7 additions & 5 deletions scaler/client/agent/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from scaler.client.agent.future_manager import ClientFutureManager
from scaler.client.agent.mixins import ObjectManager, TaskManager
from scaler.io.async_connector import AsyncConnector
from scaler.protocol.python.common import TaskStatus
from scaler.protocol.python.message import GraphTask, GraphTaskCancel, Task, TaskCancel, TaskResult


Expand Down Expand Up @@ -39,6 +38,8 @@ async def on_cancel_task(self, task_cancel: TaskCancel):
return

self._task_ids.remove(task_cancel.task_id)
self._future_manager.on_cancel_task(task_cancel)

await self._connector_external.send(task_cancel)

async def on_new_graph_task(self, task: GraphTask):
Expand All @@ -54,13 +55,14 @@ async def on_cancel_graph_task(self, task_cancel: GraphTaskCancel):
await self._connector_external.send(task_cancel)

async def on_task_result(self, result: TaskResult):
# All task result objects must be propagated to the object manager, even if we do not track the task anymore
# (e.g. if it got cancelled). If we don't, we might lose track of these result objects and not properly clear
# them.
self._object_manager.on_task_result(result)

if result.task_id not in self._task_ids:
return

self._task_ids.remove(result.task_id)

if result.status != TaskStatus.Canceled:
for result_object_id in result.results:
self._object_manager.record_task_result(result.task_id, result_object_id)

self._future_manager.on_task_result(result)
6 changes: 4 additions & 2 deletions scaler/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,12 @@ def send_object(self, obj: Any, name: Optional[str] = None) -> ObjectReference:

def clear(self):
"""
clear the resources used by the client, this will cancel all running futures and invalidate all existing object
clear all resources used by the client; this will cancel all running futures and invalidate all existing object
references
"""
...

self._future_manager.cancel_all_futures()
self._object_buffer.clear()

def disconnect(self):
"""
Expand Down
14 changes: 12 additions & 2 deletions scaler/client/object_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from scaler.io.sync_connector import SyncConnector
from scaler.io.utility import chunk_to_list_of_bytes
from scaler.protocol.python.common import ObjectContent
from scaler.protocol.python.message import ObjectInstruction
from scaler.protocol.python.message import ClientClearRequest, ObjectInstruction
from scaler.utility.object_utility import generate_object_id, generate_serializer_object_id


Expand Down Expand Up @@ -65,7 +65,7 @@ def commit_send_objects(self):
)
)

self._pending_objects = list()
self._pending_objects.clear()

def commit_delete_objects(self):
if not self._pending_delete_objects:
Expand All @@ -81,6 +81,16 @@ def commit_delete_objects(self):

self._pending_delete_objects.clear()

def clear(self):
    """
    remove all committed and pending objects.

    Locally buffered creations and deletions are dropped, then a ``ClientClearRequest`` is sent through the
    connector.
    """

    # Drop everything that was buffered locally but not yet committed.
    self._pending_delete_objects.clear()
    self._pending_objects.clear()

    # The request is built without arguments: it carries no payload.
    self._connector.send(ClientClearRequest.new_msg())

def __construct_serializer(self) -> ObjectCache:
serializer_bytes = cloudpickle.dumps(self._serializer, protocol=pickle.HIGHEST_PROTOCOL)
object_id = generate_serializer_object_id(self._identity)
Expand Down
11 changes: 8 additions & 3 deletions scaler/protocol/capnp/message.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ struct DisconnectResponse {
worker @0 :Data;
}

struct ClientClearRequest {
    # Intentionally empty: the message carries no fields.
}

struct ClientDisconnect {
disconnectType @0 :DisconnectType;

Expand Down Expand Up @@ -200,9 +203,11 @@ struct Message {
stateTask @19 :StateTask;
stateGraphTask @20 :StateGraphTask;

clientDisconnect @21 :ClientDisconnect;
clientShutdownResponse @22 :ClientShutdownResponse;
clientClearRequest @21 :ClientClearRequest;

clientDisconnect @22 :ClientDisconnect;
clientShutdownResponse @23 :ClientShutdownResponse;

processorInitialized @23 :ProcessorInitialized;
processorInitialized @24 :ProcessorInitialized;
}
}
10 changes: 10 additions & 0 deletions scaler/protocol/python/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,15 @@ def new_msg(worker: bytes) -> "DisconnectResponse":
return DisconnectResponse(_message.DisconnectResponse(worker=worker))


class ClientClearRequest(Message):
    """Message without payload, wrapping the ``ClientClearRequest`` protocol struct."""

    def __init__(self, msg):
        super().__init__(msg)

    @staticmethod
    def new_msg() -> "ClientClearRequest":
        """Build a new ``ClientClearRequest``; no arguments are needed as the message has no fields."""
        return ClientClearRequest(_message.ClientClearRequest())


class ClientDisconnect(Message):
class DisconnectType(enum.Enum):
Disconnect = _message.ClientDisconnect.DisconnectType.disconnect
Expand Down Expand Up @@ -637,6 +646,7 @@ def new_msg() -> "ProcessorInitialized":
"stateWorker": StateWorker,
"stateTask": StateTask,
"stateGraphTask": StateGraphTask,
"clientClearRequest": ClientClearRequest,
"clientDisconnect": ClientDisconnect,
"clientShutdownResponse": ClientShutdownResponse,
"processorInitialized": ProcessorInitialized,
Expand Down
6 changes: 6 additions & 0 deletions scaler/scheduler/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def register(
async def routine(self):
task_id = await self._unassigned.get()

# FIXME: As the assign_task_to_worker() call can be blocking (especially if there is no worker connected to the
# scheduler), we might end up with the task object being in neither _running nor _unassigned.
# In this case, the scheduler will answer any task cancellation request with a "task not found" error, which is
# a bug.
# https://github.com/Citi/scaler/issues/45

if not await self._worker_manager.assign_task_to_worker(self._task_id_to_task[task_id]):
await self._unassigned.put(task_id)
return
Expand Down
6 changes: 3 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from concurrent.futures import CancelledError

from scaler import Client, SchedulerClusterCombo
from scaler.utility.exceptions import ProcessorDiedError
from scaler.utility.exceptions import MissingObjects, ProcessorDiedError
from scaler.utility.logging.scoped_logger import ScopedLogger
from scaler.utility.logging.utility import setup_logger
from tests.utility import get_available_tcp_port, logging_test_name
Expand Down Expand Up @@ -302,8 +302,8 @@ def test_clear(self):
self.assertTrue(future.cancelled())

# using an old reference should fail
with self.assertRaises(KeyError):
client.submit(noop_sleep, arg_reference)
with self.assertRaises(MissingObjects):
client.submit(noop_sleep, arg_reference).result()

# but new tasks should work fine
self.assertEqual(client.submit(round, 3.14).result(), 3.0)

0 comments on commit b8e147b

Please sign in to comment.