From a41984d87fd65f61468c729a1f61f1f66653f145 Mon Sep 17 00:00:00 2001
From: Sharpner6 <1sc2l4qi@duck.com>
Date: Wed, 23 Oct 2024 21:33:58 -0400
Subject: [PATCH] Fix serialized big data issue

- a single serialized object can now exceed 500MB. Cap'n Proto limits each
  Data/Text/Blob field to roughly 500MB; per the author's recommendation, the
  payload is stored as List(List(Data)) and the chunks are concatenated back
  together, which supports objects larger than 500MB. Because pycapnp does not
  yet support memoryview, splitting and concatenating inevitably create extra
  copies, but both happen immediately after the serialization and
  deserialization stages, so the memory is released right away
- later, when we move from pyzmq to pynng, the memory-copy issue should ease,
  because pynng exposes more raw operations and lets us save threads

Signed-off-by: Sharpner6 <1sc2l4qi@duck.com>
---
 scaler/about.py                               |  2 +-
 scaler/client/agent/future_manager.py         |  7 ++++---
 scaler/client/client.py                       |  7 ++-----
 scaler/client/object_buffer.py                | 11 +++++++----
 scaler/cluster/combo.py                       |  2 +-
 scaler/cluster/scheduler.py                   |  2 +-
 scaler/entry_points/cluster.py                |  2 +-
 scaler/entry_points/top.py                    |  2 +-
 scaler/io/async_binder.py                     |  2 +-
 scaler/io/config.py                           |  5 ++++-
 scaler/io/utility.py                          | 18 +++++++++++++++---
 scaler/protocol/capnp/_python.py              |  1 +
 scaler/protocol/capnp/common.capnp            |  2 +-
 scaler/protocol/python/common.py              |  8 ++++----
 scaler/protocol/python/message.py             |  6 +++---
 scaler/scheduler/mixins.py                    |  8 ++++----
 scaler/scheduler/object_manager.py            | 17 +++++++++--------
 .../graph/topological_sorter_graphblas.py     |  2 +-
 scaler/utility/queues/async_priority_queue.py |  1 -
 scaler/worker/agent/heartbeat_manager.py      |  2 +-
 scaler/worker/agent/processor/object_cache.py |  9 +++++----
 scaler/worker/agent/processor/processor.py    |  7 +++++--
 scaler/worker/agent/processor_holder.py       |  2 +-
 scaler/worker/agent/processor_manager.py      |  7 ++++---
 tests/test_async_sorted_priority_queue.py     |  3 ++-
 tests/test_client.py                          | 12 ++++++++++++
 tests/test_death_timeout.py                   | 12 ++++--------
 tests/test_object_usage.py                    |  3 ++-
 tests/test_worker_object_tracker.py           |  5 +++--
 29 files changed, 100 insertions(+), 67 deletions(-)

diff --git a/scaler/about.py b/scaler/about.py
index d3d41ce..d48d2bf 100644
--- a/scaler/about.py
+++ b/scaler/about.py
@@ -1 +1 @@
-__version__ = "1.8.10"
+__version__ = "1.8.11"
diff --git a/scaler/client/agent/future_manager.py b/scaler/client/agent/future_manager.py
index 154fcd5..d9733ec 100644
--- a/scaler/client/agent/future_manager.py
+++ b/scaler/client/agent/future_manager.py
@@ -1,11 +1,12 @@
 import logging
 import threading
-from concurrent.futures import InvalidStateError, Future
+from concurrent.futures import Future, InvalidStateError
 from typing import Dict, Tuple
 
 from scaler.client.agent.mixins import FutureManager
 from scaler.client.future import ScalerFuture
 from scaler.client.serializer.mixins import Serializer
+from scaler.io.utility import concat_list_of_bytes
 from scaler.protocol.python.common import TaskStatus
 from scaler.protocol.python.message import ObjectResponse, TaskResult
 from scaler.utility.exceptions import DisconnectedError, NoWorkerError, TaskNotFoundError, WorkerDiedError
@@ -106,9 +107,9 @@ def on_object_response(self, response: ObjectResponse):
 
             try:
                 if status == TaskStatus.Success:
-                    future.set_result(self._serializer.deserialize(object_bytes))
+                    future.set_result(self._serializer.deserialize(concat_list_of_bytes(object_bytes)))
                 elif status == TaskStatus.Failed:
-                    future.set_exception(deserialize_failure(object_bytes))
+                    future.set_exception(deserialize_failure(concat_list_of_bytes(object_bytes)))
             except InvalidStateError:
                 continue  # future got canceled
 
diff --git a/scaler/client/client.py b/scaler/client/client.py
index 8fc6071..e5c9899 100644
--- a/scaler/client/client.py
+++ b/scaler/client/client.py
@@ -91,10 +91,7 @@ def __initialize__(
         self._stop_event = threading.Event()
         self._context = zmq.Context()
         self._connector = SyncConnector(
-            context=self._context,
-            socket_type=zmq.PAIR,
-            address=self._client_agent_address,
-            identity=self._identity,
+            context=self._context, socket_type=zmq.PAIR, address=self._client_agent_address, identity=self._identity
         )
 
         self._future_manager = ClientFutureManager(self._serializer)
@@ -309,7 +306,7 @@ def send_object(self, obj: Any, name: Optional[str] = None) -> ObjectReference:
         self.__assert_client_not_stopped()
 
         cache = self._object_buffer.buffer_send_object(obj, name)
-        return ObjectReference(cache.object_name, cache.object_id, len(cache.object_bytes))
+        return ObjectReference(cache.object_name, cache.object_id, sum(map(len, cache.object_bytes)))
 
     def disconnect(self):
         """
diff --git a/scaler/client/object_buffer.py b/scaler/client/object_buffer.py
index bbb1050..885c8cc 100644
--- a/scaler/client/object_buffer.py
+++ b/scaler/client/object_buffer.py
@@ -6,6 +6,7 @@
 
 from scaler.client.serializer.mixins import Serializer
 from scaler.io.sync_connector import SyncConnector
+from scaler.io.utility import chunk_to_list_of_bytes
 from scaler.protocol.python.common import ObjectContent
 from scaler.protocol.python.message import ObjectInstruction
 from scaler.utility.object_utility import generate_object_id, generate_serializer_object_id
@@ -15,7 +16,7 @@ class ObjectCache:
     object_id: bytes
     object_name: bytes
-    object_bytes: bytes
+    object_bytes: List[bytes]
 
 
 class ObjectBuffer:
@@ -83,13 +84,15 @@ def commit_delete_objects(self):
     def __construct_serializer(self) -> ObjectCache:
         serializer_bytes = cloudpickle.dumps(self._serializer, protocol=pickle.HIGHEST_PROTOCOL)
         object_id = generate_serializer_object_id(self._identity)
-        return ObjectCache(object_id, b"serializer", serializer_bytes)
+        return ObjectCache(object_id, b"serializer", chunk_to_list_of_bytes(serializer_bytes))
 
     def __construct_function(self, fn: Callable) -> ObjectCache:
         function_bytes = self._serializer.serialize(fn)
         object_id = generate_object_id(self._identity, function_bytes)
         function_cache = ObjectCache(
-            object_id, getattr(fn, "__name__", f"").encode(), function_bytes
+            object_id,
+            getattr(fn, "__name__", f"").encode(),
+            chunk_to_list_of_bytes(function_bytes),
         )
         return function_cache
 
@@ -97,4 +100,4 @@ def __construct_object(self, obj: Any, name: Optional[str] = None) -> ObjectCach
         object_payload = self._serializer.serialize(obj)
         object_id = generate_object_id(self._identity, object_payload)
         name_bytes = name.encode() if name else f"".encode()
-        return ObjectCache(object_id, name_bytes, object_payload)
+        return ObjectCache(object_id, name_bytes, chunk_to_list_of_bytes(object_payload))
diff --git a/scaler/cluster/combo.py b/scaler/cluster/combo.py
index d7af8ec..858722b 100644
--- a/scaler/cluster/combo.py
+++ b/scaler/cluster/combo.py
@@ -7,6 +7,7 @@
 from scaler.io.config import (
     DEFAULT_CLIENT_TIMEOUT_SECONDS,
     DEFAULT_GARBAGE_COLLECT_INTERVAL_SECONDS,
+    DEFAULT_HARD_PROCESSOR_SUSPEND,
     DEFAULT_HEARTBEAT_INTERVAL_SECONDS,
     DEFAULT_IO_THREADS,
     DEFAULT_LOAD_BALANCE_SECONDS,
@@ -18,7 +19,6 @@
     DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES,
     DEFAULT_WORKER_DEATH_TIMEOUT,
     DEFAULT_WORKER_TIMEOUT_SECONDS,
-    DEFAULT_HARD_PROCESSOR_SUSPEND,
 )
 from scaler.utility.zmq_config import ZMQConfig
diff --git a/scaler/cluster/scheduler.py b/scaler/cluster/scheduler.py
index 347721a..b770160 100644
--- a/scaler/cluster/scheduler.py
+++ b/scaler/cluster/scheduler.py
@@ -1,7 +1,7 @@
 import asyncio
 import multiprocessing
 from asyncio import AbstractEventLoop, Task
-from typing import Optional, Tuple, Any
+from typing import Any, Optional, Tuple
 
 from scaler.scheduler.config import SchedulerConfig
 from scaler.scheduler.scheduler import Scheduler, scheduler_main
diff --git a/scaler/entry_points/cluster.py b/scaler/entry_points/cluster.py
index 581465a..8814003 100644
--- a/scaler/entry_points/cluster.py
+++ b/scaler/entry_points/cluster.py
@@ -4,13 +4,13 @@
 from scaler.cluster.cluster import Cluster
 from scaler.io.config import (
     DEFAULT_GARBAGE_COLLECT_INTERVAL_SECONDS,
+    DEFAULT_HARD_PROCESSOR_SUSPEND,
     DEFAULT_HEARTBEAT_INTERVAL_SECONDS,
     DEFAULT_IO_THREADS,
     DEFAULT_NUMBER_OF_WORKER,
     DEFAULT_TASK_TIMEOUT_SECONDS,
     DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES,
     DEFAULT_WORKER_DEATH_TIMEOUT,
-    DEFAULT_HARD_PROCESSOR_SUSPEND,
 )
 from scaler.utility.event_loop import EventLoopType, register_event_loop
 from scaler.utility.zmq_config import ZMQConfig
diff --git a/scaler/entry_points/top.py b/scaler/entry_points/top.py
index c8c598e..f47970a 100644
--- a/scaler/entry_points/top.py
+++ b/scaler/entry_points/top.py
@@ -1,7 +1,7 @@
 import argparse
 import curses
 import functools
-from typing import List, Literal, Dict, Union
+from typing import Dict, List, Literal, Union
 
 from scaler.io.sync_subscriber import SyncSubscriber
 from scaler.protocol.python.message import StateScheduler
diff --git a/scaler/io/async_binder.py b/scaler/io/async_binder.py
index 80e1b93..bb2e192 100644
--- a/scaler/io/async_binder.py
+++ b/scaler/io/async_binder.py
@@ -2,7 +2,7 @@
 import os
 import uuid
 from collections import defaultdict
-from typing import Awaitable, Callable, List, Optional, Dict
+from typing import Awaitable, Callable, Dict, List, Optional
 
 import zmq.asyncio
 from zmq import Frame
diff --git a/scaler/io/config.py b/scaler/io/config.py
index 3934444..7e22715 100644
--- a/scaler/io/config.py
+++ b/scaler/io/config.py
@@ -12,8 +12,11 @@
 # number of seconds for profiling
 PROFILING_INTERVAL_SECONDS = 1
 
+# cap'n proto only allows a Data/Text/Blob field to be at most 500MB
+CAPNP_DATA_SIZE_LIMIT = 2**29 - 1
+
 # message size limitation, max can be 2**64
-MESSAGE_SIZE_LIMIT = 2**64 - 1
+CAPNP_MESSAGE_SIZE_LIMIT = 2**64 - 1
 
 # ==========================
 # SCHEDULER SPECIFIC OPTIONS
diff --git a/scaler/io/utility.py b/scaler/io/utility.py
index 6465671..861841f 100644
--- a/scaler/io/utility.py
+++ b/scaler/io/utility.py
@@ -1,14 +1,14 @@
 import logging
-from typing import Optional
+from typing import List, Optional
 
-from scaler.io.config import MESSAGE_SIZE_LIMIT
+from scaler.io.config import CAPNP_DATA_SIZE_LIMIT, CAPNP_MESSAGE_SIZE_LIMIT
 from scaler.protocol.capnp._python import _message  # noqa
 from scaler.protocol.python.message import PROTOCOL
 from scaler.protocol.python.mixins import Message
 
 
 def deserialize(data: bytes) -> Optional[Message]:
-    with _message.Message.from_bytes(data, traversal_limit_in_words=MESSAGE_SIZE_LIMIT) as payload:
+    with _message.Message.from_bytes(data, traversal_limit_in_words=CAPNP_MESSAGE_SIZE_LIMIT) as payload:
         if not hasattr(payload, payload.which()):
             logging.error(f"unknown message type: {payload.which()}")
             return None
@@ -20,3 +20,15 @@ def deserialize(data: bytes) -> Optional[Message]:
 def serialize(message: Message) -> bytes:
     payload = _message.Message(**{PROTOCOL.inverse[type(message)]: message.get_message()})
     return payload.to_bytes()
+
+
+def chunk_to_list_of_bytes(data: bytes) -> List[bytes]:
+    # TODO: change to list of memoryview when capnp can support memoryview
+    return [data[i : i + CAPNP_DATA_SIZE_LIMIT] for i in range(0, len(data), CAPNP_DATA_SIZE_LIMIT)]
+
+
+def concat_list_of_bytes(data: List[bytes]) -> bytes:
+    one_object_bytes = bytearray()
+    for chunk in data:
+        one_object_bytes.extend(chunk)
+    return one_object_bytes
diff --git a/scaler/protocol/capnp/_python.py b/scaler/protocol/capnp/_python.py
index 93d725f..6d77cae 100644
--- a/scaler/protocol/capnp/_python.py
+++ b/scaler/protocol/capnp/_python.py
@@ -1,4 +1,5 @@
 import capnp  # noqa
+
 import scaler.protocol.capnp.common_capnp as _common  # noqa
 import scaler.protocol.capnp.message_capnp as _message  # noqa
 import scaler.protocol.capnp.status_capnp as _status  # noqa
diff --git a/scaler/protocol/capnp/common.capnp b/scaler/protocol/capnp/common.capnp
index 3f08c37..628f626 100644
--- a/scaler/protocol/capnp/common.capnp
+++ b/scaler/protocol/capnp/common.capnp
@@ -18,5 +18,5 @@ enum TaskStatus {
 struct ObjectContent {
   objectIds @0 :List(Data);
   objectNames @1 :List(Data);
-  objectBytes @2 :List(Data);
+  objectBytes @2 :List(List(Data));
 }
diff --git a/scaler/protocol/python/common.py b/scaler/protocol/python/common.py
index 1e94e9c..e254ddf 100644
--- a/scaler/protocol/python/common.py
+++ b/scaler/protocol/python/common.py
@@ -1,6 +1,6 @@
 import dataclasses
 import enum
-from typing import Tuple
+from typing import List, Tuple
 
 from scaler.protocol.capnp._python import _common  # noqa
 from scaler.protocol.python.mixins import Message
@@ -26,7 +26,7 @@ class TaskStatus(enum.Enum):
 @dataclasses.dataclass
 class ObjectContent(Message):
     def __init__(self, msg):
-        self._msg = msg
+        super().__init__(msg)
 
     @property
     def object_ids(self) -> Tuple[bytes, ...]:
@@ -37,14 +37,14 @@ def object_names(self) -> Tuple[bytes, ...]:
         return tuple(self._msg.objectNames)
 
     @property
-    def object_bytes(self) -> Tuple[bytes, ...]:
+    def object_bytes(self) -> Tuple[List[bytes], ...]:
         return tuple(self._msg.objectBytes)
 
     @staticmethod
     def new_msg(
         object_ids: Tuple[bytes, ...],
         object_names: Tuple[bytes, ...] = tuple(),
-        object_bytes: Tuple[bytes, ...] = tuple(),
+        object_bytes: Tuple[List[bytes], ...] = tuple(),
     ) -> "ObjectContent":
         return ObjectContent(
             _common.ObjectContent(
diff --git a/scaler/protocol/python/message.py b/scaler/protocol/python/message.py
index 43d55b9..e38fb80 100644
--- a/scaler/protocol/python/message.py
+++ b/scaler/protocol/python/message.py
@@ -1,19 +1,19 @@
 import dataclasses
 import enum
 import os
-from typing import List, Set, Tuple, Optional, Type
+from typing import List, Optional, Set, Tuple, Type
 
 import bidict
 
 from scaler.protocol.capnp._python import _message  # noqa
-from scaler.protocol.python.common import TaskStatus, ObjectContent
+from scaler.protocol.python.common import ObjectContent, TaskStatus
 from scaler.protocol.python.mixins import Message
 from scaler.protocol.python.status import (
     BinderStatus,
     ClientManagerStatus,
     ObjectManagerStatus,
-    Resource,
     ProcessorStatus,
+    Resource,
     TaskManagerStatus,
     WorkerManagerStatus,
 )
diff --git a/scaler/scheduler/mixins.py b/scaler/scheduler/mixins.py
index bb6af88..5169e63 100644
--- a/scaler/scheduler/mixins.py
+++ b/scaler/scheduler/mixins.py
@@ -1,5 +1,5 @@
 import abc
-from typing import Optional, Set
+from typing import List, Optional, Set
 
 from scaler.protocol.python.message import (
     ClientDisconnect,
@@ -7,12 +7,12 @@
     DisconnectRequest,
     GraphTask,
     GraphTaskCancel,
+    ObjectInstruction,
     ObjectRequest,
     Task,
     TaskCancel,
     TaskResult,
     WorkerHeartbeat,
-    ObjectInstruction,
 )
 
 from scaler.utility.mixins import Reporter
@@ -27,7 +27,7 @@ async def on_object_request(self, source: bytes, request: ObjectRequest):
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: bytes):
+    def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: List[bytes]):
         raise NotImplementedError()
 
     @abc.abstractmethod
@@ -47,7 +47,7 @@ def get_object_name(self, object_id: bytes) -> bytes:
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def get_object_content(self, object_id: bytes) -> bytes:
+    def get_object_content(self, object_id: bytes) -> List[bytes]:
         raise NotImplementedError()
 
diff --git a/scaler/scheduler/object_manager.py b/scaler/scheduler/object_manager.py
index 4953e22..0907e28 100644
--- a/scaler/scheduler/object_manager.py
+++ b/scaler/scheduler/object_manager.py
@@ -1,7 +1,7 @@
 import dataclasses
 import logging
 from asyncio import Queue
-from typing import Optional, Set
+from typing import List, Optional, Set
 
 from scaler.io.async_binder import AsyncBinder
 from scaler.io.async_connector import AsyncConnector
@@ -19,7 +19,7 @@ class _ObjectCreation(ObjectUsage):
     object_id: bytes
     object_creator: bytes
     object_name: bytes
-    object_bytes: bytes
+    object_bytes: List[bytes]
 
     def get_object_key(self) -> bytes:
         return self.object_id
@@ -71,7 +71,7 @@ async def on_object_request(self, source: bytes, request: ObjectRequest):
 
         logging.error(f"received unknown object request type {request=} from {source=!r}")
 
-    def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: bytes):
+    def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: List[bytes]):
         creation = _ObjectCreation(object_id, object_user, object_name, object_bytes)
         logging.debug(
             f"add object cache "
@@ -102,15 +102,16 @@ def get_object_name(self, object_id: bytes) -> bytes:
 
         return self._object_storage.get_object(object_id).object_name
 
-    def get_object_content(self, object_id: bytes) -> bytes:
+    def get_object_content(self, object_id: bytes) -> List[bytes]:
         if not self.has_object(object_id):
-            return b""
+            return list()
 
         return self._object_storage.get_object(object_id).object_bytes
 
     def get_status(self) -> ObjectManagerStatus:
         return ObjectManagerStatus.new_msg(
-            self._object_storage.object_count(), sum(len(v.object_bytes) for _, v in self._object_storage.items())
+            self._object_storage.object_count(),
+            sum(sum(map(len, v.object_bytes)) for _, v in self._object_storage.items()),
         )
 
     async def __process_get_request(self, source: bytes, request: ObjectRequest):
@@ -139,12 +140,12 @@ def __on_object_create(self, source: bytes, instruction: ObjectInstruction):
             logging.error(f"received object creation from {source!r} for unknown client {instruction.object_user!r}")
             return
 
-        for object_id, object_name, object_content in zip(
+        for object_id, object_name, object_bytes in zip(
             instruction.object_content.object_ids,
             instruction.object_content.object_names,
             instruction.object_content.object_bytes,
         ):
-            self.on_add_object(instruction.object_user, object_id, object_name, object_content)
+            self.on_add_object(instruction.object_user, object_id, object_name, object_bytes)
 
     def __finished_object_storage(self, creation: _ObjectCreation):
         logging.debug(
diff --git a/scaler/utility/graph/topological_sorter_graphblas.py b/scaler/utility/graph/topological_sorter_graphblas.py
index 717608f..60e39a8 100644
--- a/scaler/utility/graph/topological_sorter_graphblas.py
+++ b/scaler/utility/graph/topological_sorter_graphblas.py
@@ -1,7 +1,7 @@
 import collections
 import graphlib
 import itertools
-from typing import Hashable, Iterable, List, Optional, Tuple, TypeVar, Generic, Mapping
+from typing import Generic, Hashable, Iterable, List, Mapping, Optional, Tuple, TypeVar
 
 from bidict import bidict
 
diff --git a/scaler/utility/queues/async_priority_queue.py b/scaler/utility/queues/async_priority_queue.py
index f02bf57..6f58ed4 100644
--- a/scaler/utility/queues/async_priority_queue.py
+++ b/scaler/utility/queues/async_priority_queue.py
@@ -2,7 +2,6 @@
 from asyncio import Queue
 from typing import Dict, List, Tuple, Union
 
-
 PriorityType = Union[int, Tuple["PriorityType", ...]]
 
diff --git a/scaler/worker/agent/heartbeat_manager.py b/scaler/worker/agent/heartbeat_manager.py
index 04823f3..98b51be 100644
--- a/scaler/worker/agent/heartbeat_manager.py
+++ b/scaler/worker/agent/heartbeat_manager.py
@@ -4,7 +4,7 @@
 import psutil
 
 from scaler.io.async_connector import AsyncConnector
-from scaler.protocol.python.message import WorkerHeartbeat, WorkerHeartbeatEcho, Resource
+from scaler.protocol.python.message import Resource, WorkerHeartbeat, WorkerHeartbeatEcho
 from scaler.protocol.python.status import ProcessorStatus
 from scaler.utility.mixins import Looper
 from scaler.worker.agent.mixins import HeartbeatManager, ProcessorManager, TaskManager, TimeoutManager
diff --git a/scaler/worker/agent/processor/object_cache.py b/scaler/worker/agent/processor/object_cache.py
index ee8e42e..4cbbb8a 100644
--- a/scaler/worker/agent/processor/object_cache.py
+++ b/scaler/worker/agent/processor/object_cache.py
@@ -5,13 +5,14 @@
 import platform
 import threading
 import time
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 import cloudpickle
 import psutil
 
 from scaler.client.serializer.mixins import Serializer
 from scaler.io.config import CLEANUP_INTERVAL_SECONDS
+from scaler.io.utility import concat_list_of_bytes
 from scaler.protocol.python.common import ObjectContent
 from scaler.protocol.python.message import Task
 from scaler.utility.exceptions import DeserializeObjectError
@@ -51,8 +52,8 @@ def add_serializer(self, client: bytes, serializer: Serializer):
     def serialize(self, client: bytes, obj: Any) -> bytes:
         return self.get_serializer(client).serialize(obj)
 
-    def deserialize(self, client: bytes, payload: bytes) -> Any:
-        return self.get_serializer(client).deserialize(payload)
+    def deserialize(self, client: bytes, payload: List[bytes]) -> Any:
+        return self.get_serializer(client).deserialize(concat_list_of_bytes(payload))
 
     def add_objects(self, object_content: ObjectContent, task: Task):
         zipped = list(zip(object_content.object_ids, object_content.object_names, object_content.object_bytes))
@@ -60,7 +61,7 @@ def add_objects(self, object_content: ObjectContent, task: Task):
         others = filter(lambda o: not is_object_id_serializer(o[0]), zipped)
 
         for object_id, object_name, object_bytes in serializers:
-            self.add_serializer(object_id, cloudpickle.loads(object_bytes))
+            self.add_serializer(object_id, cloudpickle.loads(concat_list_of_bytes(object_bytes)))
 
         for object_id, object_name, object_bytes in others:
             try:
diff --git a/scaler/worker/agent/processor/processor.py b/scaler/worker/agent/processor/processor.py
index 95aad7a..b331bee 100644
--- a/scaler/worker/agent/processor/processor.py
+++ b/scaler/worker/agent/processor/processor.py
@@ -13,7 +13,8 @@
 
 from scaler.io.config import DUMMY_CLIENT
 from scaler.io.sync_connector import SyncConnector
-from scaler.protocol.python.common import TaskStatus, ObjectContent
+from scaler.io.utility import chunk_to_list_of_bytes
+from scaler.protocol.python.common import ObjectContent, TaskStatus
 from scaler.protocol.python.message import (
     ObjectInstruction,
     ObjectRequest,
@@ -263,7 +264,9 @@ def __send_result(self, source: bytes, task_id: bytes, status: TaskStatus, resul
                 ObjectInstruction.ObjectInstructionType.Create,
                 source,
                 ObjectContent.new_msg(
-                    (result_object_id,), (f"".encode(),), (result_bytes,)
+                    (result_object_id,),
+                    (f"".encode(),),
+                    (chunk_to_list_of_bytes(result_bytes),),
                 ),
             )
         )
diff --git a/scaler/worker/agent/processor_holder.py b/scaler/worker/agent/processor_holder.py
index 97a829d..48ea1c2 100644
--- a/scaler/worker/agent/processor_holder.py
+++ b/scaler/worker/agent/processor_holder.py
@@ -10,7 +10,7 @@
 from scaler.io.config import DEFAULT_PROCESSOR_KILL_DELAY_SECONDS
 from scaler.protocol.python.message import Task
 from scaler.utility.zmq_config import ZMQConfig
-from scaler.worker.agent.processor.processor import Processor, SUSPEND_SIGNAL
+from scaler.worker.agent.processor.processor import SUSPEND_SIGNAL, Processor
 
 
 class ProcessorHolder:
diff --git a/scaler/worker/agent/processor_manager.py b/scaler/worker/agent/processor_manager.py
index 3d467aa..10e9fc8 100644
--- a/scaler/worker/agent/processor_manager.py
+++ b/scaler/worker/agent/processor_manager.py
@@ -11,7 +11,8 @@
 # from scaler.utility.logging.utility import setup_logger
 from scaler.io.async_binder import AsyncBinder
 from scaler.io.async_connector import AsyncConnector
-from scaler.protocol.python.common import TaskStatus, ObjectContent
+from scaler.io.utility import chunk_to_list_of_bytes
+from scaler.protocol.python.common import ObjectContent, TaskStatus
 from scaler.protocol.python.message import (
     ObjectInstruction,
     ObjectRequest,
@@ -26,7 +27,7 @@
 from scaler.utility.mixins import Looper
 from scaler.utility.object_utility import generate_object_id, serialize_failure
 from scaler.utility.zmq_config import ZMQConfig, ZMQType
-from scaler.worker.agent.mixins import HeartbeatManager, ProcessorManager, ProfilingManager, TaskManager, ObjectTracker
+from scaler.worker.agent.mixins import HeartbeatManager, ObjectTracker, ProcessorManager, ProfilingManager, TaskManager
 from scaler.worker.agent.processor_holder import ProcessorHolder
 
 
@@ -148,7 +149,7 @@ async def on_failing_task(self, process_status: str):
 
         profile_result = self.__end_task(self._current_holder)
 
-        result_object_bytes = serialize_failure(ProcessorDiedError(f"{process_status=}"))
+        result_object_bytes = chunk_to_list_of_bytes(serialize_failure(ProcessorDiedError(f"{process_status=}")))
         result_object_id = generate_object_id(source, uuid.uuid4().bytes)
 
         await self._connector_external.send(
diff --git a/tests/test_async_sorted_priority_queue.py b/tests/test_async_sorted_priority_queue.py
index e1cb4d0..d124f21 100644
--- a/tests/test_async_sorted_priority_queue.py
+++ b/tests/test_async_sorted_priority_queue.py
@@ -2,7 +2,8 @@
 import unittest
 
 from scaler.utility.logging.utility import setup_logger
-from scaler.utility.queues.async_sorted_priority_queue import AsyncSortedPriorityQueue
+from scaler.utility.queues.async_sorted_priority_queue import \
+    AsyncSortedPriorityQueue
 from tests.utility import logging_test_name
 
diff --git a/tests/test_client.py b/tests/test_client.py
index fadd817..02e75ac 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -115,6 +115,18 @@ def test_heavy_function(self):
         expected = [task * size for task in tasks]
         self.assertEqual(results, expected)
 
+    def test_very_large_payload(self):
+        def func(data: bytes):
+            return data
+
+        with Client(self.address) as client:
+            payload = os.urandom(2**29 + 300)  # 512MB + 300B
+            future = client.submit(func, payload)
+
+            result = future.result()
+
+            self.assertTrue(payload == result)
+
     def test_sleep(self):
         with Client(self.address) as client:
             time.sleep(5)
diff --git a/tests/test_death_timeout.py b/tests/test_death_timeout.py
index 58273d9..698cb4d 100644
--- a/tests/test_death_timeout.py
+++ b/tests/test_death_timeout.py
@@ -3,18 +3,14 @@
 import unittest
 
 from scaler import Client, Cluster, SchedulerClusterCombo
-from scaler.io.config import (
-    DEFAULT_GARBAGE_COLLECT_INTERVAL_SECONDS,
-    DEFAULT_HEARTBEAT_INTERVAL_SECONDS,
-    DEFAULT_IO_THREADS,
-    DEFAULT_TASK_TIMEOUT_SECONDS,
-    DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES,
-)
+from scaler.io.config import (DEFAULT_GARBAGE_COLLECT_INTERVAL_SECONDS,
+                              DEFAULT_HEARTBEAT_INTERVAL_SECONDS,
+                              DEFAULT_IO_THREADS, DEFAULT_TASK_TIMEOUT_SECONDS,
+                              DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES)
 from scaler.utility.logging.utility import setup_logger
 from scaler.utility.zmq_config import ZMQConfig
 from tests.utility import get_available_tcp_port, logging_test_name
 
-
 # This is a manual test because it can loop infinitely if it fails
 
diff --git a/tests/test_object_usage.py b/tests/test_object_usage.py
index a062e9d..94a84d9 100644
--- a/tests/test_object_usage.py
+++ b/tests/test_object_usage.py
@@ -1,7 +1,8 @@
 import dataclasses
 import unittest
 
-from scaler.scheduler.object_usage.object_tracker import ObjectTracker, ObjectUsage
+from scaler.scheduler.object_usage.object_tracker import (ObjectTracker,
+                                                          ObjectUsage)
 from scaler.utility.logging.utility import setup_logger
 from tests.utility import logging_test_name
 
diff --git a/tests/test_worker_object_tracker.py b/tests/test_worker_object_tracker.py
index ef18eed..34bb2cc 100644
--- a/tests/test_worker_object_tracker.py
+++ b/tests/test_worker_object_tracker.py
@@ -1,7 +1,8 @@
 import unittest
 
 from scaler.protocol.python.common import ObjectContent
-from scaler.protocol.python.message import ObjectInstruction, ObjectRequest, ObjectResponse
+from scaler.protocol.python.message import (ObjectInstruction, ObjectRequest,
+                                            ObjectResponse)
 from scaler.utility.logging.utility import setup_logger
 from scaler.worker.agent.object_tracker import VanillaObjectTracker
 from tests.utility import logging_test_name
@@ -56,7 +57,7 @@ def test_object_tracker(self) -> None:
                 ObjectContent.new_msg(
                     (b"object_1", b"object_2", b"object_3"),
                     (b"name_1", b"name_2", b"name_3"),
-                    (b"content_1", b"content_2", b"content_3"),
+                    ([b"content_1"], [b"content_2"], [b"content_3"]),
                 ),
             )
         )
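
Note: the sketch below is not part of the patch; it is a minimal, standalone
illustration of the round-trip contract that the helpers added in
scaler/io/utility.py (chunk_to_list_of_bytes / concat_list_of_bytes) are meant
to satisfy. CHUNK_LIMIT here is a small stand-in for CAPNP_DATA_SIZE_LIMIT
(2**29 - 1) so the example stays readable; the slicing logic is the same.

    from typing import List

    CHUNK_LIMIT = 4  # stand-in for CAPNP_DATA_SIZE_LIMIT, illustration only

    def chunk_to_list_of_bytes(data: bytes) -> List[bytes]:
        # slicing bytes copies each chunk; the patch's TODO plans to switch to
        # memoryview once pycapnp supports it, to avoid these copies
        return [data[i : i + CHUNK_LIMIT] for i in range(0, len(data), CHUNK_LIMIT)]

    def concat_list_of_bytes(chunks: List[bytes]) -> bytes:
        # rebuild the original payload; done right after deserialization so the
        # intermediate chunks can be released immediately
        buffer = bytearray()
        for chunk in chunks:
            buffer.extend(chunk)
        return bytes(buffer)

    payload = b"0123456789"
    chunks = chunk_to_list_of_bytes(payload)
    assert chunks == [b"0123", b"4567", b"89"]
    assert concat_list_of_bytes(chunks) == payload

The 2**29 - 1 constant matches Cap'n Proto's wire format: a list pointer
carries a 29-bit element count, so a single Data blob can hold at most
2**29 - 1 bytes (roughly 500MB), and List(List(Data)) sidesteps that
per-blob cap.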