Feature/indexing limiters #3572

Open · wants to merge 4 commits into base: main
207 changes: 160 additions & 47 deletions backend/onyx/background/celery/tasks/indexing/tasks.py
@@ -21,6 +21,7 @@
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.indexing.job_client import SimpleJob
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
@@ -754,63 +755,31 @@ def try_creating_indexing_task(
return index_attempt_id


@shared_task(
name=OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
bind=True,
acks_late=False,
track_started=True,
)
def connector_indexing_proxy_task(
def connector_indexing_wait_for_spawned_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
job: SimpleJob,
redis_connector_index: RedisConnectorIndex,
use_semaphore: bool,
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)

if not self.request.id:
task_logger.error("self.request.id is None!")

client = SimpleJobClient()

job = client.submit(
connector_indexing_task_wrapper,
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
global_version.is_ee_version(),
pure=False,
)

if not job:
task_logger.info(
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return

task_logger.info(
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)

redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)

while True:
sleep(5)

# renew active signal
redis_connector_index.set_active()
if use_semaphore:
refreshed = redis_connector_index.refresh_semaphore()
if not refreshed:
task_logger.warning(
"Indexing watchdog - refresh semaphore failed: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
)

# if the job is done, clean up and break
if job.done():
@@ -915,8 +884,152 @@ def connector_indexing_proxy_task(
)
continue


@shared_task(
name=OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
bind=True,
acks_late=False,
track_started=True,
max_retries=64, # an arbitrarily large but finite number of retries
)
def connector_indexing_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
TASK_RETRY_DELAY = 1800 # in seconds

# we only care about limiting concurrency in a multitenant scenario, not self hosted
use_semaphore = bool(MULTI_TENANT)

task_logger.info(
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)

client = SimpleJobClient()

while True:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)

celery_task_id = self.request.id
if not celery_task_id:
task_logger.error(
f"Indexing watchdog - task does not have an id: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
break

if use_semaphore:
rank = redis_connector_index.acquire_semaphore(celery_task_id)
if rank >= redis_connector_index.SEMAPHORE_LIMIT:
task_logger.warning(
f"Indexing watchdog - could not acquire semaphore, delaying execution: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"semaphore_rank={rank} "
f"semaphore_limit={redis_connector_index.SEMAPHORE_LIMIT} "
f"task_retries={self.request.retries} "
f"task_delay={TASK_RETRY_DELAY}"
)

# max_retries = None will retry forever, which we don't want
# max_retries = 0 really means no retries
if self.max_retries is None or self.request.retries >= self.max_retries:
task_logger.warning(
f"Indexing watchdog - could not acquire semaphore within max_retries. "
f"Canceling the attempt: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"task_retries={self.request.retries} "
f"max_retries={self.max_retries} "
)
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
f"Canceled the indexing attempt because an indexing slot could not be "
f"acquired within the retry limit: max_retries={self.max_retries}",
)
except Exception:
# if the DB raises an exception, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.error(
"Indexing watchdog - transient exception marking index attempt as canceled: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
break

# this will delay the task for 30 minutes if the semaphore can't be acquired
# the max_retries annotation on the task will hard limit the number of times self.retry can succeed
raise self.retry(countdown=TASK_RETRY_DELAY)
else:
task_logger.info(
f"Indexing watchdog - acquired semaphore: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"semaphore_rank={rank} "
f"semaphore_limit={redis_connector_index.SEMAPHORE_LIMIT}"
)

job = client.submit(
connector_indexing_task_wrapper,
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
global_version.is_ee_version(),
pure=False,
)

if not job:
task_logger.info(
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
break

task_logger.info(
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)

# This polls the spawned task and other conditions until we are finished
connector_indexing_wait_for_spawned_task(
self,
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
job,
redis_connector_index,
use_semaphore,
)

if use_semaphore:
redis_connector_index.release_semaphore()

break

task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
"Indexing watchdog - finished: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
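Reviewer note: the flow above leans on Celery's bound-task retry API, which may be unfamiliar. Below is a minimal sketch of the same pattern, assuming hypothetical helper names (`try_acquire_slot`, `release_slot`, `guarded_task`) and a toy in-process slot set standing in for the Redis semaphore; it is an illustration, not the PR's implementation.

```python
from celery import Celery, Task

app = Celery("sketch", broker="redis://localhost:6379/0")

RETRY_DELAY = 1800  # seconds; mirrors TASK_RETRY_DELAY above
SLOT_LIMIT = 6

_slots: set[str] = set()  # toy in-process stand-in for the Redis semaphore


def try_acquire_slot(task_id: str) -> bool:
    if len(_slots) >= SLOT_LIMIT:
        return False
    _slots.add(task_id)
    return True


def release_slot(task_id: str) -> None:
    _slots.discard(task_id)


@app.task(bind=True, acks_late=False, max_retries=64)
def guarded_task(self: Task, attempt_id: int) -> None:
    task_id = self.request.id or ""
    if not try_acquire_slot(task_id):
        # max_retries=None would retry forever; guard against it as the PR does
        if self.max_retries is None or self.request.retries >= self.max_retries:
            return  # give up (the PR marks the attempt canceled at this point)
        # re-enqueue this task to run again after RETRY_DELAY seconds;
        # raising stops execution of this invocation immediately
        raise self.retry(countdown=RETRY_DELAY)
    try:
        print(f"doing work for attempt {attempt_id}")
    finally:
        release_slot(task_id)
```

The `raise self.retry(...)` idiom matters: `retry()` raises a `Retry` exception, so nothing after it runs on this invocation, the broker redelivers the task after the countdown, and `self.request.retries` increments on each redelivery until `max_retries` is exhausted.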
62 changes: 61 additions & 1 deletion backend/onyx/redis/redis_connector_index.py
@@ -32,9 +32,13 @@ class RedisConnectorIndex:
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate

# used to signal the overall workflow is still active
# it's difficult to prevent
ACTIVE_PREFIX = PREFIX + "_active"

# used to limit the number of simultaneous indexing workers at task runtime
SEMAPHORE_KEY = PREFIX + "_semaphore"
SEMAPHORE_LIMIT = 6
SEMAPHORE_TIMEOUT = 15 * 60 # 15 minutes

def __init__(
self,
tenant_id: str | None,
@@ -59,6 +63,7 @@ def __init__(
)
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
self.semaphore_member: str | None = None

@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -112,6 +117,59 @@ def set_terminate(self, celery_task_id: str) -> None:
# 10 minute TTL is good.
self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)

def acquire_semaphore(self, celery_task_id: str) -> int:
"""Used to limit the number of simultaneous indexing workers per tenant.
This semaphore does not need to be strongly consistent and is written as such.
"""
self.semaphore_member = celery_task_id

redis_time_raw = self.redis.time()
redis_time = cast(tuple[int, int], redis_time_raw)
now = (
redis_time[0] + redis_time[1] / 1_000_000
) # unix timestamp in floating point seconds

self.redis.zremrangebyscore(
RedisConnectorIndex.SEMAPHORE_KEY, "-inf", now - self.SEMAPHORE_TIMEOUT
) # clean up old semaphore entries

# add ourselves to the semaphore and check our position/rank
self.redis.zadd(RedisConnectorIndex.SEMAPHORE_KEY, {self.semaphore_member: now})
rank_bytes = self.redis.zrank(
RedisConnectorIndex.SEMAPHORE_KEY, self.semaphore_member
)
rank = cast(int, rank_bytes)
if rank >= RedisConnectorIndex.SEMAPHORE_LIMIT:
# we're over the limit, acquiring the semaphore has failed.
self.redis.zrem(RedisConnectorIndex.SEMAPHORE_KEY, self.semaphore_member)

# return the rank ... we failed to acquire the semaphore if rank >= SEMAPHORE_LIMIT
return rank

def refresh_semaphore(self) -> bool:
if not self.semaphore_member:
return False

redis_time_raw = self.redis.time()
redis_time = cast(tuple[int, int], redis_time_raw)
now = (
redis_time[0] + redis_time[1] / 1_000_000
) # unix timestamp in floating point seconds

reply = self.redis.zadd(
RedisConnectorIndex.SEMAPHORE_KEY, {self.semaphore_member: now}, xx=True
)
if reply is None:
return False

return True

def release_semaphore(self) -> None:
if not self.semaphore_member:
return

self.redis.zrem(RedisConnectorIndex.SEMAPHORE_KEY, self.semaphore_member)

def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
the expiration time.
Expand Down Expand Up @@ -172,6 +230,8 @@ def reset(self) -> None:
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
r.delete(RedisConnectorIndex.SEMAPHORE_KEY)

for key in r.scan_iter(RedisConnectorIndex.ACTIVE_PREFIX + "*"):
r.delete(key)

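Reviewer note: the semaphore here is a single sorted set whose members are Celery task ids and whose scores are Redis server timestamps; rank decides admission, and ZREMRANGEBYSCORE reaps holders that stopped refreshing. A condensed sketch of the same pattern against plain redis-py (the key name, limit, and timeout are chosen here for illustration):

```python
import redis

KEY = "example_semaphore"  # illustrative; the PR uses SEMAPHORE_KEY
LIMIT = 6
TIMEOUT = 15 * 60  # holders that haven't refreshed within this window are reaped


def redis_now(r: redis.Redis) -> float:
    secs, micros = r.time()  # use the server clock to avoid client clock skew
    return secs + micros / 1_000_000


def acquire(r: redis.Redis, member: str) -> bool:
    now = redis_now(r)
    r.zremrangebyscore(KEY, "-inf", now - TIMEOUT)  # drop expired holders
    r.zadd(KEY, {member: now})
    rank = r.zrank(KEY, member)
    if rank is not None and rank >= LIMIT:
        r.zrem(KEY, member)  # over the limit: back out immediately
        return False
    return True


def refresh(r: redis.Redis, member: str) -> bool:
    # xx=True updates only an existing member, so an entry already reaped by
    # another worker is not resurrected; ch=True reports whether anything changed
    changed = r.zadd(KEY, {member: redis_now(r)}, xx=True, ch=True)
    return bool(changed)


def release(r: redis.Redis, member: str) -> None:
    r.zrem(KEY, member)
```

As the docstring says, this is deliberately not strongly consistent: the ZADD and ZRANK are separate round trips, so concurrent acquires interleaving with removals can occasionally admit one holder more or fewer than the limit, which is acceptable for rate-limiting indexing workers.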
6 changes: 5 additions & 1 deletion backend/onyx/redis/redis_pool.py
@@ -107,11 +107,15 @@ def __getattribute__(self, item: str) -> Any:
"sadd",
"srem",
"scard",
"zadd",
"zrem",
"zremrangebyscore",
"zrank",
] # Regular methods that need simple prefixing

if item == "scan_iter":
return self._prefix_scan_iter(original_attr)
elif item in methods_to_wrap and callable(original_attr):
if item in methods_to_wrap and callable(original_attr):
return self._prefix_method(original_attr)
return original_attr

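Reviewer note: the redis_pool.py hunk just adds the four sorted-set commands to the list of methods that receive tenant key prefixing, which is why the semaphore above works per tenant. A stripped-down sketch of that `__getattribute__` wrapper pattern, assuming an illustrative `PrefixingRedis` class (not the PR's actual class name):

```python
from typing import Any

import redis


class PrefixingRedis:
    """Proxies a redis.Redis client, prefixing the key argument of wrapped methods."""

    METHODS_TO_WRAP = {
        "get", "set", "delete", "sadd", "srem", "scard",
        "zadd", "zrem", "zremrangebyscore", "zrank",
    }

    def __init__(self, client: redis.Redis, tenant_id: str) -> None:
        # object.__setattr__ sidesteps our own __getattribute__ hook
        object.__setattr__(self, "_client", client)
        object.__setattr__(self, "_prefix", f"{tenant_id}:")

    def __getattribute__(self, item: str) -> Any:
        client = object.__getattribute__(self, "_client")
        prefix = object.__getattribute__(self, "_prefix")
        attr = getattr(client, item)
        if item in PrefixingRedis.METHODS_TO_WRAP and callable(attr):
            def wrapped(key: str, *args: Any, **kwargs: Any) -> Any:
                # the first positional argument is the key for these commands
                return attr(prefix + key, *args, **kwargs)
            return wrapped
        return attr


# usage: keys are transparently namespaced per tenant
# r = PrefixingRedis(redis.Redis(), tenant_id="tenant_a")
# r.zadd("semaphore", {"task-123": 1.0})  # actually writes "tenant_a:semaphore"
```

Without adding `zadd`, `zrem`, `zremrangebyscore`, and `zrank` to `methods_to_wrap`, the semaphore's sorted-set operations would bypass prefixing and every tenant would collide on one raw key.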