Skip to content

Commit

Permalink
Initial implementation of the non-blocking tagging task allocator.
Browse files Browse the repository at this point in the history
Signed-off-by: rafa-be <[email protected]>
  • Loading branch information
rafa-be committed Dec 19, 2024
1 parent f4b4daa commit 1bc05f0
Show file tree
Hide file tree
Showing 9 changed files with 432 additions and 12 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ graphlib-backport; python_version < '3.9'
psutil
pycapnp
pyzmq
sortedcontainers
tblib
2 changes: 1 addition & 1 deletion scaler/about.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.8.13"
__version__ = "1.8.14"
10 changes: 9 additions & 1 deletion scaler/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def get(
task=Task.new_msg(
task_id=graph_task.task_id,
source=self._identity,
tags=set(),
metadata=b"",
func_object_id=b"",
function_args=[],
Expand Down Expand Up @@ -370,6 +371,7 @@ def __submit(self, function_object_id: bytes, args: Tuple[Any, ...], delayed: bo
task = Task.new_msg(
task_id=task_id,
source=self._identity,
tags=set(),
metadata=task_flags_bytes,
func_object_id=function_object_id,
function_args=arguments,
Expand Down Expand Up @@ -486,6 +488,7 @@ def __construct_graph(
task_id_to_tasks[task_id] = Task.new_msg(
task_id=task_id,
source=self._identity,
tags=set(),
metadata=task_flags_bytes,
func_object_id=function_cache.object_id,
function_args=arguments,
Expand All @@ -506,7 +509,12 @@ def __construct_graph(
argument, data = node_name_to_arguments[key]
future: ScalerFuture = self._future_factory(
task=Task.new_msg(
task_id=argument.data, source=self._identity, metadata=b"", func_object_id=b"", function_args=[]
task_id=argument.data,
source=self._identity,
tags=set(),
metadata=b"",
func_object_id=b"",
function_args=[],
),
is_delayed=False,
group_task_id=graph_task_id,
Expand Down
20 changes: 11 additions & 9 deletions scaler/protocol/capnp/message.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ using Status = import "status.capnp";
struct Task {
taskId @0 :Data;
source @1 :Data;
metadata @2 :Data;
funcObjectId @3 :Data;
functionArgs @4 :List(Argument);
tags @2 :List(Text);
metadata @3 :Data;
funcObjectId @4 :Data;
functionArgs @5 :List(Argument);

struct Argument {
type @0 :ArgumentType;
Expand Down Expand Up @@ -58,12 +59,13 @@ struct ClientHeartbeatEcho {
}

struct WorkerHeartbeat {
agent @0 :Status.Resource;
rssFree @1 :UInt64;
queuedTasks @2 :UInt32;
latencyUS @3 :UInt32;
taskLock @4 :Bool;
processors @5 :List(Status.ProcessorStatus);
tags @0 :List(Text);
agent @1 :Status.Resource;
rssFree @2 :UInt64;
queuedTasks @3 :UInt32;
latencyUS @4 :UInt32;
taskLock @5 :Bool;
processors @6 :List(Status.ProcessorStatus);
}

struct WorkerHeartbeatEcho {
Expand Down
18 changes: 17 additions & 1 deletion scaler/protocol/python/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def task_id(self) -> bytes:
def source(self) -> bytes:
return self._msg.source

@property
def tags(self) -> Set[str]:
return set(self._msg.tags)

@property
def metadata(self) -> bytes:
return self._msg.metadata
Expand All @@ -66,12 +70,18 @@ def function_args(self) -> List[Argument]:

@staticmethod
def new_msg(
task_id: bytes, source: bytes, metadata: bytes, func_object_id: bytes, function_args: List[Argument]
task_id: bytes,
source: bytes,
tags: Set[str],
metadata: bytes,
func_object_id: bytes,
function_args: List[Argument]
) -> "Task":
return Task(
_message.Task(
taskId=task_id,
source=source,
tags=list(tags),
metadata=metadata,
funcObjectId=func_object_id,
functionArgs=[_message.Task.Argument(type=arg.type.value, data=arg.data) for arg in function_args],
Expand Down Expand Up @@ -232,6 +242,10 @@ class WorkerHeartbeat(Message):
def __init__(self, msg):
super().__init__(msg)

@property
def tags(self) -> Set[str]:
return set(self._msg.tags)

@property
def agent(self) -> Resource:
return Resource(self._msg.agent)
Expand All @@ -258,6 +272,7 @@ def processors(self) -> List[ProcessorStatus]:

@staticmethod
def new_msg(
tags: Set[str],
agent: Resource,
rss_free: int,
queued_tasks: int,
Expand All @@ -267,6 +282,7 @@ def new_msg(
) -> "WorkerHeartbeat":
return WorkerHeartbeat(
_message.WorkerHeartbeat(
tags=list(tags),
agent=agent.get_message(),
rssFree=rss_free,
queuedTasks=queued_tasks,
Expand Down
244 changes: 244 additions & 0 deletions scaler/scheduler/allocators/tagged_allocator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import dataclasses
from collections import OrderedDict, defaultdict
from itertools import takewhile
from sortedcontainers import SortedList
from typing import Dict, Iterable, List, Optional, Set

from scaler.protocol.python.message import Task


@dataclasses.dataclass(frozen=True)
class _TaskHolder:
task_id: bytes = dataclasses.field()
tags: Set[str] = dataclasses.field()


@dataclasses.dataclass(frozen=True)
class _WorkerHolder:
worker_id: bytes = dataclasses.field()
tags: Set[str] = dataclasses.field()

# Queued tasks, ordered from oldest to youngest tasks.
task_id_to_task: OrderedDict[bytes, _TaskHolder] = dataclasses.field(default_factory=OrderedDict)

def n_tasks(self) -> int:
return len(self.task_id_to_task)

def copy(self) -> "_WorkerHolder":
return _WorkerHolder(self.worker_id, self.tags, self.task_id_to_task.copy())


class TaggedAllocator: # FIXME: remove async. methods from the TaskAllocator mixin to make it a derivative.
def __init__(self, max_tasks_per_worker: int):
self._max_tasks_per_worker = max_tasks_per_worker

self._worker_id_to_worker: Dict[bytes, _WorkerHolder] = {}

self._task_id_to_worker_id: Dict[bytes, bytes] = {}
self._tag_to_worker_ids: Dict[str, Set[bytes]] = {}

def add_worker(self, worker_id: bytes, tags: Set[str]) -> bool:
if worker_id in self._worker_id_to_worker:
return False

worker = _WorkerHolder(worker_id=worker_id, tags=tags)
self._worker_id_to_worker[worker_id] = worker

for tag in tags:
if tag not in self._tag_to_worker_ids:
self._tag_to_worker_ids[tag] = set()

self._tag_to_worker_ids[tag].add(worker.worker_id)

return True

def remove_worker(self, worker_id: bytes) -> List[bytes]:
worker = self._worker_id_to_worker.pop(worker_id, None)

if worker is None:
return []

for tag in worker.tags:
self._tag_to_worker_ids[tag].discard(worker.worker_id)
if len(self._tag_to_worker_ids[tag]) == 0:
self._tag_to_worker_ids.pop(tag)

task_ids = list(worker.task_id_to_task.keys())
for task_id in task_ids:
self._task_id_to_worker_id.pop(task_id)

return task_ids

def get_worker_ids(self) -> Set[bytes]:
return set(self._worker_id_to_worker.keys())

def get_worker_by_task_id(self, task_id: bytes) -> bytes:
return self._task_id_to_worker_id.get(task_id, b"")

def assign_task(self, task: Task) -> Optional[bytes]:
available_workers = self.__get_available_workers_for_tags(task.tags)

if len(available_workers) <= 0:
return None

min_loaded_worker = min(available_workers, key=lambda worker: worker.n_tasks())
min_loaded_worker.task_id_to_task[task.task_id] = _TaskHolder(task.task_id, task.tags)

self._task_id_to_worker_id[task.task_id] = min_loaded_worker.worker_id

return min_loaded_worker.worker_id

def remove_task(self, task_id: bytes) -> Optional[bytes]:
worker_id = self._task_id_to_worker_id.pop(task_id, None)

if worker_id is None:
return None

worker = self._worker_id_to_worker[worker_id]
worker.task_id_to_task.pop(task_id)

return worker_id

def get_assigned_worker(self, task_id: bytes) -> Optional[bytes]:
if task_id not in self._task_id_to_worker_id:
return None

return self._task_id_to_worker_id[task_id]

def has_available_worker(self, tags: Optional[Set[str]] = None) -> bool:
if tags is None:
tags = set()

return len(self.__get_available_workers_for_tags(tags)) > 0

def balance(self) -> Dict[bytes, List[bytes]]:
"""Returns, for every worker id, the list of task ids to balance out."""

has_idle_workers = any(worker.n_tasks() == 0 for worker in self._worker_id_to_worker.values())

if not has_idle_workers:
return {}

# The balancing algorithm works by trying to move tasks from workers that have more queued tasks than the
# average (high-load workers) to workers that have less tasks than the average (low-load workers).
#
# Because of the tag constraints, this might result in less than optimal balancing. However, it will greatly
# limit the number of messages transmitted to workers, and reduce the algorithmic worst-case of the balancing
# process.
#
# The overall worst-case time complexity of the balancing algorithm is:
#
# O(n_workers * log(n_workers) + n_tasks * n_workers * n_tags)
#
# However, if the cluster does not use any tag, time complexity is always:
#
# O(n_workers * log(n_workers) + n_tasks * log(n_workers))
#
# See <https://github.com/Citi/scaler/issues/32#issuecomment-2541897645> for more details.

n_tasks = sum(worker.n_tasks() for worker in self._worker_id_to_worker.values())
avg_tasks_per_worker = n_tasks / len(self._worker_id_to_worker)

def is_balanced(worker: _WorkerHolder) -> bool:
return abs(worker.n_tasks() - avg_tasks_per_worker) <= 1

# First, we create a copy of the current workers objects so that we can modify their respective task queues.
# We also filter out workers that are already balanced as we will not touch these.
#
# Time complexity is O(n_workers)

workers = [worker.copy() for worker in self._worker_id_to_worker.values() if not is_balanced(worker)]

# Then, we sort the remaining workers by the number of queued tasks.
#
# Time complexity is O(n_workers * log(n_workers))

sorted_workers: SortedList[_WorkerHolder] = SortedList(workers, key=lambda worker: worker.n_tasks())

# Finally, we repeatedly remove one task from the most loaded worker until either:
#
# - all workers are balanced;
# - we cannot find a low-load worker than can accept tasks from a high-load worker.
#
# Worst-case time complexity is O(n_tasks * n_workers * n_tags). If no tag is used in the cluster, complexity is
# always O(n_tasks * log(n_workers))

balancing_advice: Dict[bytes, List[bytes]] = defaultdict(list)
unbalanceable_tasks: Set[bytes] = set()

while len(sorted_workers) >= 2:
most_loaded_worker: _WorkerHolder = sorted_workers.pop(-1)

if most_loaded_worker.n_tasks() - avg_tasks_per_worker <= 1:
# Most loaded worker is not high-load, stop
break

# Go through all of the most loaded worker's tasks, trying to find a low-load worker that can accept it.

receiving_worker: Optional[_WorkerHolder] = None
moved_task: Optional[_TaskHolder] = None

for task in reversed(most_loaded_worker.task_id_to_task.values()): # Try to balance youngest tasks first.
if task.task_id in unbalanceable_tasks:
continue

worker_candidates = takewhile(lambda worker: worker.n_tasks() < avg_tasks_per_worker, sorted_workers)
receiving_worker_index = self.__balance_try_reassign_task(task, worker_candidates)

if receiving_worker_index is not None:
receiving_worker = sorted_workers.pop(receiving_worker_index)
moved_task = task
break
else:
# We could not find a receiving worker for this task, remember the task as unbalanceable in case the
# worker pops-up again. This greatly reduces the worst-case big-O complexity of the algorithm.
unbalanceable_tasks.add(task.task_id)

# Re-inserts the workers in the sorted list if these can be balanced more.

if moved_task is not None:
assert receiving_worker is not None

balancing_advice[most_loaded_worker.worker_id].append(moved_task.task_id)

most_loaded_worker.task_id_to_task.pop(moved_task.task_id)
receiving_worker.task_id_to_task[moved_task.task_id] = moved_task

if not is_balanced(most_loaded_worker):
sorted_workers.add(most_loaded_worker)

if not is_balanced(receiving_worker):
sorted_workers.add(receiving_worker)

return balancing_advice

@staticmethod
def __balance_try_reassign_task(task: _TaskHolder, worker_candidates: Iterable[_WorkerHolder]) -> Optional[int]:
"""Returns the index of the first worker that can accept the task."""

# Time complexity is O(n_worker * n_tags)

for worker_index, worker in enumerate(worker_candidates):
if task.tags.issubset(worker.tags):
return worker_index

return None

def statistics(self) -> Dict:
return {
worker.worker_id: {"free": self._max_tasks_per_worker - worker.n_tasks(), "sent": worker.n_tasks()}
for worker in self._worker_id_to_worker.values()
}

def __get_available_workers_for_tags(self, tags: Set[str]) -> List[_WorkerHolder]:
if any(tag not in self._tag_to_worker_ids for tag in tags):
return []

matching_worker_ids = set(self._worker_id_to_worker.keys())

for tag in tags:
matching_worker_ids.intersection_update(self._tag_to_worker_ids[tag])

matching_workers = [self._worker_id_to_worker[worker_id] for worker_id in matching_worker_ids]

return [worker for worker in matching_workers if worker.n_tasks() < self._max_tasks_per_worker]
1 change: 1 addition & 0 deletions scaler/scheduler/graph_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ async def __check_one_graph(self, graph_task_id: bytes):
task = Task.new_msg(
task_id=task_info.task.task_id,
source=task_info.task.source,
tags=set(),
metadata=task_info.task.metadata,
func_object_id=task_info.task.func_object_id,
function_args=[self.__get_argument(graph_task_id, arg) for arg in task_info.task.function_args],
Expand Down
1 change: 1 addition & 0 deletions scaler/worker/agent/heartbeat_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ async def routine(self):

await self._connector_external.send(
WorkerHeartbeat.new_msg(
set(),
Resource.new_msg(int(self._agent_process.cpu_percent() * 10), self._agent_process.memory_info().rss),
psutil.virtual_memory().available,
self._worker_task_manager.get_queued_size() - num_suspended_processors,
Expand Down
Loading

0 comments on commit 1bc05f0

Please sign in to comment.