diff --git a/scaler/client/agent/client_agent.py b/scaler/client/agent/client_agent.py index 7fb3294..f391353 100644 --- a/scaler/client/agent/client_agent.py +++ b/scaler/client/agent/client_agent.py @@ -13,6 +13,7 @@ from scaler.client.serializer.mixins import Serializer from scaler.io.async_connector import AsyncConnector from scaler.protocol.python.message import ( + ClientClearRequest, ClientDisconnect, ClientHeartbeatEcho, ClientShutdownResponse, @@ -140,6 +141,10 @@ async def __on_receive_from_client(self, message: Message): await self._task_manager.on_cancel_graph_task(message) return + if isinstance(message, ClientClearRequest): + await self._object_manager.on_client_clear_request(message) + return + raise TypeError(f"Unknown {message=}") async def __on_receive_from_scheduler(self, message: Message): @@ -176,7 +181,7 @@ async def __get_loops(self): finally: self._stop_event.set() # always set the stop event before setting futures' exceptions - await self._object_manager.clean_all_objects() + await self._object_manager.clear_all_objects(clear_serializer=True) self._connector_external.destroy() self._connector_internal.destroy() diff --git a/scaler/client/agent/future_manager.py b/scaler/client/agent/future_manager.py index d9733ec..911085b 100644 --- a/scaler/client/agent/future_manager.py +++ b/scaler/client/agent/future_manager.py @@ -8,7 +8,7 @@ from scaler.client.serializer.mixins import Serializer from scaler.io.utility import concat_list_of_bytes from scaler.protocol.python.common import TaskStatus -from scaler.protocol.python.message import ObjectResponse, TaskResult +from scaler.protocol.python.message import ObjectResponse, TaskCancel, TaskResult from scaler.utility.exceptions import DisconnectedError, NoWorkerError, TaskNotFoundError, WorkerDiedError from scaler.utility.metadata.profile_result import retrieve_profiling_result_from_task_result from scaler.utility.object_utility import deserialize_failure @@ -34,6 +34,8 @@ def 
cancel_all_futures(self): for task_id, future in self._task_id_to_future.items(): future.cancel() + self._task_id_to_future.clear() + def set_all_futures_with_exception(self, exception: Exception): with self._lock: for future in self._task_id_to_future.values(): @@ -42,7 +44,7 @@ def set_all_futures_with_exception(self, exception: Exception): except InvalidStateError: continue # Future got canceled - self._task_id_to_future = dict() + self._task_id_to_future.clear() def on_task_result(self, result: TaskResult): with self._lock: @@ -94,6 +96,10 @@ def on_task_result(self, result: TaskResult): except InvalidStateError: return # Future got canceled + def on_cancel_task(self, task_cancel: TaskCancel): + with self._lock: + self._task_id_to_future.pop(task_cancel.task_id, None) + def on_object_response(self, response: ObjectResponse): for object_id, object_name, object_bytes in zip( response.object_content.object_ids, diff --git a/scaler/client/agent/mixins.py b/scaler/client/agent/mixins.py index a66b1bf..59c988b 100644 --- a/scaler/client/agent/mixins.py +++ b/scaler/client/agent/mixins.py @@ -10,6 +10,7 @@ ObjectRequest, ObjectResponse, Task, + TaskCancel, TaskResult, ) @@ -40,11 +41,11 @@ async def on_object_request(self, request: ObjectRequest): raise NotImplementedError() @abc.abstractmethod - def record_task_result(self, task_id: bytes, object_id: bytes): + def on_task_result(self, result: TaskResult): raise NotImplementedError() @abc.abstractmethod - async def clean_all_objects(self): + async def clear_all_objects(self, clear_serializer: bool): raise NotImplementedError() @@ -79,6 +80,10 @@ def set_all_futures_with_exception(self, exception: Exception): def on_task_result(self, result: TaskResult): raise NotImplementedError() + @abc.abstractmethod + def on_cancel_task(self, task_cancel: TaskCancel): + raise NotImplementedError() + @abc.abstractmethod def on_object_response(self, response: ObjectResponse): raise NotImplementedError() diff --git 
a/scaler/client/agent/object_manager.py b/scaler/client/agent/object_manager.py index 92ff051..88f5e0e 100644 --- a/scaler/client/agent/object_manager.py +++ b/scaler/client/agent/object_manager.py @@ -3,12 +3,19 @@ from scaler.client.agent.mixins import ObjectManager from scaler.io.async_connector import AsyncConnector from scaler.protocol.python.common import ObjectContent -from scaler.protocol.python.message import ObjectInstruction, ObjectRequest +from scaler.protocol.python.message import ( + ClientClearRequest, + ObjectInstruction, + ObjectRequest, + TaskResult, +) class ClientObjectManager(ObjectManager): def __init__(self, identity: bytes): self._sent_object_ids: Set[bytes] = set() + self._sent_serializer_id: Optional[bytes] = None + self._identity = identity self._connector_internal: Optional[AsyncConnector] = None @@ -28,18 +35,34 @@ async def on_object_request(self, object_request: ObjectRequest): assert object_request.request_type == ObjectRequest.ObjectRequestType.Get await self._connector_external.send(object_request) - def record_task_result(self, task_id: bytes, object_id: bytes): - self._sent_object_ids.add(object_id) + def on_task_result(self, task_result: TaskResult): + # TODO: received result objects should be deleted from the scheduler when no longer needed. + # This requires not deleting objects that are required by not-yet-computed dependent graph tasks. + # For now, we just remove the objects when the client makes a clear request, or on client shutdown. 
+ # https://github.com/Citi/scaler/issues/43 + + self._sent_object_ids.update(task_result.results) + + async def on_client_clear_request(self, client_clear_request: ClientClearRequest): + await self.clear_all_objects(clear_serializer=False) + + async def clear_all_objects(self, clear_serializer): + cleared_object_ids = self._sent_object_ids.copy() + + if clear_serializer: + self._sent_serializer_id = None + elif self._sent_serializer_id is not None: + cleared_object_ids.remove(self._sent_serializer_id) + + self._sent_object_ids.difference_update(cleared_object_ids) - async def clean_all_objects(self): await self._connector_external.send( ObjectInstruction.new_msg( ObjectInstruction.ObjectInstructionType.Delete, self._identity, - ObjectContent.new_msg(tuple(self._sent_object_ids)), + ObjectContent.new_msg(tuple(cleared_object_ids)), ) ) - self._sent_object_ids = set() async def __send_object_creation(self, instruction: ObjectInstruction): assert instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Create @@ -48,6 +71,13 @@ async def __send_object_creation(self, instruction: ObjectInstruction): if not new_object_ids: return + if b"serializer" in instruction.object_content.object_names: + if self._sent_serializer_id is not None: + raise ValueError("trying to send multiple serializers.") + + serializer_index = instruction.object_content.object_names.index(b"serializer") + self._sent_serializer_id = instruction.object_content.object_ids[serializer_index] + new_object_content = ObjectContent.new_msg( *zip( *filter( @@ -71,5 +101,10 @@ async def __send_object_creation(self, instruction: ObjectInstruction): async def __delete_objects(self, instruction: ObjectInstruction): assert instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Delete + + if self._sent_serializer_id in instruction.object_content.object_ids: + raise ValueError("trying to delete serializer.") + 
self._sent_object_ids.difference_update(instruction.object_content.object_ids) + await self._connector_external.send(instruction) diff --git a/scaler/client/agent/task_manager.py b/scaler/client/agent/task_manager.py index 8b54c77..537d803 100644 --- a/scaler/client/agent/task_manager.py +++ b/scaler/client/agent/task_manager.py @@ -3,7 +3,6 @@ from scaler.client.agent.future_manager import ClientFutureManager from scaler.client.agent.mixins import ObjectManager, TaskManager from scaler.io.async_connector import AsyncConnector -from scaler.protocol.python.common import TaskStatus from scaler.protocol.python.message import GraphTask, GraphTaskCancel, Task, TaskCancel, TaskResult @@ -39,6 +38,8 @@ async def on_cancel_task(self, task_cancel: TaskCancel): return self._task_ids.remove(task_cancel.task_id) + self._future_manager.on_cancel_task(task_cancel) + await self._connector_external.send(task_cancel) async def on_new_graph_task(self, task: GraphTask): @@ -54,13 +55,14 @@ async def on_cancel_graph_task(self, task_cancel: GraphTaskCancel): await self._connector_external.send(task_cancel) async def on_task_result(self, result: TaskResult): + # All task result objects must be propagated to the object manager, even if we do not track the task anymore + # (e.g. if it got cancelled). If we don't, we might lose track of these result objects and not properly clear + # them. 
+ self._object_manager.on_task_result(result) + if result.task_id not in self._task_ids: return self._task_ids.remove(result.task_id) - if result.status != TaskStatus.Canceled: - for result_object_id in result.results: - self._object_manager.record_task_result(result.task_id, result_object_id) - self._future_manager.on_task_result(result) diff --git a/scaler/client/client.py b/scaler/client/client.py index d66161f..fd850fe 100644 --- a/scaler/client/client.py +++ b/scaler/client/client.py @@ -310,10 +310,12 @@ def send_object(self, obj: Any, name: Optional[str] = None) -> ObjectReference: def clear(self): """ - clear the resources used by the client, this will cancel all running futures and invalidate all existing object + clear all resources used by the client, this will cancel all running futures and invalidate all existing object references """ - ... + + self._future_manager.cancel_all_futures() + self._object_buffer.clear() def disconnect(self): """ diff --git a/scaler/client/object_buffer.py b/scaler/client/object_buffer.py index 885c8cc..5144d1d 100644 --- a/scaler/client/object_buffer.py +++ b/scaler/client/object_buffer.py @@ -8,7 +8,7 @@ from scaler.io.sync_connector import SyncConnector from scaler.io.utility import chunk_to_list_of_bytes from scaler.protocol.python.common import ObjectContent -from scaler.protocol.python.message import ObjectInstruction +from scaler.protocol.python.message import ClientClearRequest, ObjectInstruction from scaler.utility.object_utility import generate_object_id, generate_serializer_object_id @@ -65,7 +65,7 @@ def commit_send_objects(self): ) ) - self._pending_objects = list() + self._pending_objects.clear() def commit_delete_objects(self): if not self._pending_delete_objects: @@ -81,6 +81,16 @@ def commit_delete_objects(self): self._pending_delete_objects.clear() + def clear(self): + """ + remove all committed and pending objects. 
+ """ + + self._pending_delete_objects.clear() + self._pending_objects.clear() + + self._connector.send(ClientClearRequest.new_msg()) + def __construct_serializer(self) -> ObjectCache: serializer_bytes = cloudpickle.dumps(self._serializer, protocol=pickle.HIGHEST_PROTOCOL) object_id = generate_serializer_object_id(self._identity) diff --git a/scaler/protocol/capnp/message.capnp b/scaler/protocol/capnp/message.capnp index c554aa2..8e4abb3 100644 --- a/scaler/protocol/capnp/message.capnp +++ b/scaler/protocol/capnp/message.capnp @@ -107,6 +107,9 @@ struct DisconnectResponse { worker @0 :Data; } +struct ClientClearRequest { +} + struct ClientDisconnect { disconnectType @0 :DisconnectType; @@ -200,9 +203,11 @@ struct Message { stateTask @19 :StateTask; stateGraphTask @20 :StateGraphTask; - clientDisconnect @21 :ClientDisconnect; - clientShutdownResponse @22 :ClientShutdownResponse; + clientClearRequest @21 :ClientClearRequest; + + clientDisconnect @22 :ClientDisconnect; + clientShutdownResponse @23 :ClientShutdownResponse; - processorInitialized @23 :ProcessorInitialized; + processorInitialized @24 :ProcessorInitialized; } } diff --git a/scaler/protocol/python/message.py b/scaler/protocol/python/message.py index e38fb80..75f9bb3 100644 --- a/scaler/protocol/python/message.py +++ b/scaler/protocol/python/message.py @@ -395,6 +395,15 @@ def new_msg(worker: bytes) -> "DisconnectResponse": return DisconnectResponse(_message.DisconnectResponse(worker=worker)) +class ClientClearRequest(Message): + def __init__(self, msg): + super().__init__(msg) + + @staticmethod + def new_msg() -> "ClientClearRequest": + return ClientClearRequest(_message.ClientClearRequest()) + + class ClientDisconnect(Message): class DisconnectType(enum.Enum): Disconnect = _message.ClientDisconnect.DisconnectType.disconnect @@ -637,6 +646,7 @@ def new_msg() -> "ProcessorInitialized": "stateWorker": StateWorker, "stateTask": StateTask, "stateGraphTask": StateGraphTask, + "clientClearRequest": 
ClientClearRequest, "clientDisconnect": ClientDisconnect, "clientShutdownResponse": ClientShutdownResponse, "processorInitialized": ProcessorInitialized, diff --git a/scaler/scheduler/task_manager.py b/scaler/scheduler/task_manager.py index 30f1f1b..40319ba 100644 --- a/scaler/scheduler/task_manager.py +++ b/scaler/scheduler/task_manager.py @@ -55,6 +55,12 @@ def register( async def routine(self): task_id = await self._unassigned.get() + # FIXME: As the assign_task_to_worker() call can be blocking (especially if there is no worker connected to the + # scheduler), we might end up with the task object being in neither _running nor _unassigned. + # In this case, the scheduler will answer any task cancellation request with a "task not found" error, which is + # a bug. + # https://github.com/Citi/scaler/issues/45 + if not await self._worker_manager.assign_task_to_worker(self._task_id_to_task[task_id]): await self._unassigned.put(task_id) return diff --git a/tests/test_client.py b/tests/test_client.py index d80ebed..3e46359 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -6,7 +6,7 @@ from concurrent.futures import CancelledError from scaler import Client, SchedulerClusterCombo -from scaler.utility.exceptions import ProcessorDiedError +from scaler.utility.exceptions import MissingObjects, ProcessorDiedError from scaler.utility.logging.scoped_logger import ScopedLogger from scaler.utility.logging.utility import setup_logger from tests.utility import get_available_tcp_port, logging_test_name @@ -302,8 +302,8 @@ def test_clear(self): self.assertTrue(future.cancelled()) # using an old reference should fail - with self.assertRaises(KeyError): - client.submit(noop_sleep, arg_reference) + with self.assertRaises(MissingObjects): + client.submit(noop_sleep, arg_reference).result() # but new tasks should work fine self.assertEqual(client.submit(round, 3.14).result(), 3.0)