Skip to content

Commit

Permalink
Task priorities are now positive numbers.
Browse files Browse the repository at this point in the history
Signed-off-by: rafa-be <[email protected]>
  • Loading branch information
rafa-be committed Oct 29, 2024
1 parent 7192a7d commit de74cd9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 16 deletions.
2 changes: 1 addition & 1 deletion scaler/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def __get_task_flags(self) -> TaskFlags:
parent_task_priority = self.__get_parent_task_priority()

if parent_task_priority is not None:
task_priority = parent_task_priority - 1
task_priority = parent_task_priority + 1
else:
task_priority = 0

Expand Down
39 changes: 25 additions & 14 deletions scaler/worker/agent/task_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional, Set, Tuple
from typing import Dict, Optional, Set

from scaler.io.async_connector import AsyncConnector
from scaler.protocol.python.common import TaskStatus
Expand Down Expand Up @@ -32,12 +32,9 @@ def register(self, connector: AsyncConnector, processor_manager: ProcessorManage
self._processor_manager = processor_manager

async def on_task_new(self, task: Task):
task_priority = self.__get_task_priority(task)

self._queued_task_id_to_task[task.task_id] = task
self._queued_task_ids.put_nowait(((task_priority, _QUEUED_TASKS_PRIORITY), task.task_id))
self.__enqueue_task(task, is_suspended=False)

await self.__suspend_if_priority_is_lower(task_priority)
await self.__suspend_if_priority_is_higher(task)

async def on_cancel_task(self, task_cancel: TaskCancel):
task_id = task_cancel.task_id
Expand Down Expand Up @@ -94,26 +91,40 @@ async def __processing_task(self):
else:
self._processor_manager.on_resume_task(task_id)

async def __suspend_if_priority_is_lower(self, new_task_priority: int):
async def __suspend_if_priority_is_higher(self, new_task: Task):
current_task = self._processor_manager.current_task()

if current_task is None:
return

new_task_priority = self.__get_task_priority(new_task)
current_task_priority = self.__get_task_priority(current_task)

if new_task_priority >= current_task_priority:
if new_task_priority <= current_task_priority:
return

self._queued_task_ids.put_nowait(((current_task_priority, _SUSPENDED_TASKS_PRIORITY), current_task.task_id))
self._queued_task_id_to_task[current_task.task_id] = current_task
self.__enqueue_task(current_task, is_suspended=True)

await self._processor_manager.on_suspend_task(current_task.task_id)

def __enqueue_task(self, task: Task, is_suspended: bool):
task_priority = self.__get_task_priority(task)

# Higher-priority tasks have an higher priority value. But as the queue is sorted by increasing order, we negate
# the inserted value so they will be at the head of the queue.
if is_suspended:
queue_priority = (-task_priority, _SUSPENDED_TASKS_PRIORITY)
else:
queue_priority = (-task_priority, _QUEUED_TASKS_PRIORITY)

self._queued_task_ids.put_nowait((queue_priority, task.task_id))
self._queued_task_id_to_task[task.task_id] = task

@staticmethod
def __get_task_priority(task: Task) -> int:
return retrieve_task_flags_from_task(task).priority
priority = retrieve_task_flags_from_task(task).priority

@staticmethod
def __is_suspended_task(priority: Tuple[int, int]) -> bool:
return priority[1] == _SUSPENDED_TASKS_PRIORITY
if priority < 0:
raise ValueError(f"invalid task priority, must be positive or zero, got {priority}")

return priority
4 changes: 3 additions & 1 deletion tests/test_async_sorted_priority_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ async def async_test():
await queue.put([1, 1])
await queue.put([3, 6])
await queue.put([2, 4])
await queue.put([-3, 0]) # supports negative priorities
await queue.put([1, 2])

queue.remove(4)
self.assertEqual(queue.qsize(), 5)
self.assertEqual(queue.qsize(), 6)

self.assertEqual(await queue.get(), [-3, 0])
self.assertEqual(await queue.get(), [1, 1])
self.assertEqual(await queue.get(), [1, 2])
self.assertEqual(await queue.get(), [2, 3])
Expand Down

0 comments on commit de74cd9

Please sign in to comment.