From ec972d4a016010501ffb5f28202711dca753ec13 Mon Sep 17 00:00:00 2001 From: Richard Kuo Date: Tue, 31 Dec 2024 09:38:19 -0800 Subject: [PATCH 1/3] semaphore WIP --- .../background/celery/tasks/indexing/tasks.py | 146 ++++++++++++------ backend/onyx/redis/redis_connector_index.py | 51 +++++- backend/onyx/redis/redis_pool.py | 6 +- 3 files changed, 153 insertions(+), 50 deletions(-) diff --git a/backend/onyx/background/celery/tasks/indexing/tasks.py b/backend/onyx/background/celery/tasks/indexing/tasks.py index a9ce4274c38..a8446332a4e 100644 --- a/backend/onyx/background/celery/tasks/indexing/tasks.py +++ b/backend/onyx/background/celery/tasks/indexing/tasks.py @@ -19,6 +19,7 @@ from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.celery_redis import celery_find_task from onyx.background.celery.celery_redis import celery_get_unacked_task_ids +from onyx.background.indexing.job_client import SimpleJob from onyx.background.indexing.job_client import SimpleJobClient from onyx.background.indexing.run_indexing import run_indexing_entrypoint from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP @@ -728,64 +729,21 @@ def try_creating_indexing_task( return index_attempt_id -@shared_task( - name=OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK, - bind=True, - acks_late=False, - track_started=True, -) -def connector_indexing_proxy_task( +def connector_indexing_wait_for_spawned_task( self: Task, index_attempt_id: int, cc_pair_id: int, search_settings_id: int, tenant_id: str | None, + job: SimpleJob, + redis_connector_index: RedisConnectorIndex, ) -> None: - """celery tasks are forked, but forking is unstable. This proxies work to a spawned task.""" - task_logger.info( - f"Indexing watchdog - starting: attempt={index_attempt_id} " - f"cc_pair={cc_pair_id} " - f"search_settings={search_settings_id}" - ) - - if not self.request.id: - task_logger.error("self.request.id is None!") - - client = SimpleJobClient() - - job = client.submit( - connector_indexing_task_wrapper, - index_attempt_id, - cc_pair_id, - search_settings_id, - tenant_id, - global_version.is_ee_version(), - pure=False, - ) - - if not job: - task_logger.info( - f"Indexing watchdog - spawn failed: attempt={index_attempt_id} " - f"cc_pair={cc_pair_id} " - f"search_settings={search_settings_id}" - ) - return - - task_logger.info( - f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} " - f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} " - f"cc_pair={cc_pair_id} " - f"search_settings={search_settings_id}" - ) - - redis_connector = RedisConnector(tenant_id, cc_pair_id) - redis_connector_index = redis_connector.new_index(search_settings_id) - while True: sleep(5) # renew active signal redis_connector_index.set_active() + redis_connector_index.refresh_semaphore() # if the job is done, clean up and break if job.done(): @@ -884,8 +842,100 @@ def connector_indexing_proxy_task( ) continue + +@shared_task( + name=OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK, + bind=True, + acks_late=False, + track_started=True, + max_retries=128, # an arbitrarily large but finite number of retries +) +def connector_indexing_proxy_task( + self: Task, + index_attempt_id: int, + cc_pair_id: int, + search_settings_id: int, + tenant_id: str | None, +) -> None: + """celery tasks are forked, but forking is unstable. 
This proxies work to a spawned task.""" + TASK_RETRY_DELAY = 1800 # in seconds + task_logger.info( - f"Indexing watchdog - finished: attempt={index_attempt_id} " + f"Indexing watchdog - starting: attempt={index_attempt_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id}" + ) + + client = SimpleJobClient() + + while True: + redis_connector = RedisConnector(tenant_id, cc_pair_id) + redis_connector_index = redis_connector.new_index(search_settings_id) + + celery_task_id = self.request.id + if not celery_task_id: + task_logger.error( + f"Indexing watchdog - task does not have an id: " + f"attempt={index_attempt_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id}" + ) + break + + if MULTI_TENANT: + if not redis_connector_index.acquire_semaphore(celery_task_id): + task_logger.warning( + f"Indexing watchdog - could not acquire semaphore, delaying execution: " + f"attempt={index_attempt_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id} " + f"semaphore_limit={redis_connector_index.SEMAPHORE_LIMIT}" + f"task_delay={TASK_RETRY_DELAY}" + ) + + # this will delay the task for 30 minutes if the semaphore can't be acquired + raise self.retry(countdown=TASK_RETRY_DELAY) + + job = client.submit( + connector_indexing_task_wrapper, + index_attempt_id, + cc_pair_id, + search_settings_id, + tenant_id, + global_version.is_ee_version(), + pure=False, + ) + + if not job: + task_logger.info( + f"Indexing watchdog - spawn failed: attempt={index_attempt_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id}" + ) + break + + task_logger.info( + f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id}" + ) + + # This polls the spawned task and other conditions until we are finished + connector_indexing_wait_for_spawned_task( + self, + index_attempt_id, + cc_pair_id, + search_settings_id, + tenant_id, + job, + redis_connector_index, + ) + redis_connector_index.release_semaphore() + break + + task_logger.info( + "Indexing watchdog - finished: " + f"attempt={index_attempt_id} " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id}" ) diff --git a/backend/onyx/redis/redis_connector_index.py b/backend/onyx/redis/redis_connector_index.py index 5cf5d449d26..75386418d1a 100644 --- a/backend/onyx/redis/redis_connector_index.py +++ b/backend/onyx/redis/redis_connector_index.py @@ -32,9 +32,13 @@ class RedisConnectorIndex: TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate # used to signal the overall workflow is still active - # it's difficult to prevent ACTIVE_PREFIX = PREFIX + "_active" + # used to limit the number of simulataneous indexing workers at task runtime + SEMAPHORE_KEY = PREFIX + "_semaphore" + SEMAPHORE_LIMIT = 1 + SEMAPHORE_TIMEOUT = 15 * 60 # 15 minutes + def __init__( self, tenant_id: str | None, @@ -59,6 +63,7 @@ def __init__( ) self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}" self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}" + self.semaphore_member: str | None = None @classmethod def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str: @@ -112,6 +117,50 @@ def set_terminate(self, celery_task_id: str) -> None: # 10 minute TTL is good. self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600) + def acquire_semaphore(self, celery_task_id: str) -> bool: + """Used to limit the number of simultaneous indexing workers per tenant. 
+ This semaphore does not need to be strongly consistent and is written as such. + """ + self.semaphore_member = celery_task_id + + redis_time_raw = self.redis.time() + redis_time = cast(tuple[int, int], redis_time_raw) + now = ( + redis_time[0] + redis_time[1] / 1_000_000 + ) # unix timestamp in floating point seconds + + self.redis.zremrangebyscore( + RedisConnectorIndex.SEMAPHORE_KEY, "-inf", now - self.SEMAPHORE_TIMEOUT + ) + self.redis.zadd(RedisConnectorIndex.SEMAPHORE_KEY, {self.semaphore_member: now}) + rank_bytes = self.redis.zrank( + RedisConnectorIndex.SEMAPHORE_KEY, self.semaphore_member + ) + rank = cast(int, rank_bytes) + if rank >= RedisConnectorIndex.SEMAPHORE_LIMIT: + self.redis.zrem(RedisConnectorIndex.SEMAPHORE_KEY, self.semaphore_member) + return False + + return True + + def refresh_semaphore(self) -> None: + if not self.semaphore_member: + return + + redis_time_raw = self.redis.time() + redis_time = cast(tuple[int, int], redis_time_raw) + now = ( + redis_time[0] + redis_time[1] / 1_000_000 + ) # unix timestamp in floating point seconds + + self.redis.zadd(RedisConnectorIndex.SEMAPHORE_KEY, {self.semaphore_member: now}) + + def release_semaphore(self) -> None: + if not self.semaphore_member: + return + + self.redis.zrem(RedisConnectorIndex.SEMAPHORE_KEY, self.semaphore_member) + def set_active(self) -> None: """This sets a signal to keep the indexing flow from getting cleaned up within the expiration time. diff --git a/backend/onyx/redis/redis_pool.py b/backend/onyx/redis/redis_pool.py index f7e372887cf..c6125d98e48 100644 --- a/backend/onyx/redis/redis_pool.py +++ b/backend/onyx/redis/redis_pool.py @@ -107,11 +107,15 @@ def __getattribute__(self, item: str) -> Any: "sadd", "srem", "scard", + "zadd", + "zrem", + "zremrangebyscore", + "zrank", ] # Regular methods that need simple prefixing if item == "scan_iter": return self._prefix_scan_iter(original_attr) - elif item in methods_to_wrap and callable(original_attr): + if item in methods_to_wrap and callable(original_attr): return self._prefix_method(original_attr) return original_attr From b41efbc0255629cbc55328a1a26f1dd54bdd9000 Mon Sep 17 00:00:00 2001 From: Richard Kuo Date: Tue, 31 Dec 2024 16:33:35 -0800 Subject: [PATCH 2/3] working semaphore --- .../background/celery/tasks/indexing/tasks.py | 74 +++++++++++++++++-- backend/onyx/redis/redis_connector_index.py | 27 +++++-- 2 files changed, 87 insertions(+), 14 deletions(-) diff --git a/backend/onyx/background/celery/tasks/indexing/tasks.py b/backend/onyx/background/celery/tasks/indexing/tasks.py index a8446332a4e..937ba65b686 100644 --- a/backend/onyx/background/celery/tasks/indexing/tasks.py +++ b/backend/onyx/background/celery/tasks/indexing/tasks.py @@ -737,13 +737,23 @@ def connector_indexing_wait_for_spawned_task( tenant_id: str | None, job: SimpleJob, redis_connector_index: RedisConnectorIndex, + use_semaphore: bool, ) -> None: while True: sleep(5) # renew active signal redis_connector_index.set_active() - redis_connector_index.refresh_semaphore() + if use_semaphore: + refreshed = redis_connector_index.refresh_semaphore() + if not refreshed: + task_logger.warning( + "Indexing watchdog - refresh semaphore failed: " + f"attempt={index_attempt_id} " + f"tenant={tenant_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id} " + ) # if the job is done, clean up and break if job.done(): @@ -848,7 +858,7 @@ def connector_indexing_wait_for_spawned_task( bind=True, acks_late=False, track_started=True, - max_retries=128, # an arbitrarily large but 
finite number of retries + max_retries=64, # an arbitrarily large but finite number of retries ) def connector_indexing_proxy_task( self: Task, @@ -860,6 +870,9 @@ def connector_indexing_proxy_task( """celery tasks are forked, but forking is unstable. This proxies work to a spawned task.""" TASK_RETRY_DELAY = 1800 # in seconds + # we only care about limiting concurrency in a multitenant scenario, not self hosted + use_semaphore = True if MULTI_TENANT else False + task_logger.info( f"Indexing watchdog - starting: attempt={index_attempt_id} " f"cc_pair={cc_pair_id} " @@ -882,19 +895,64 @@ def connector_indexing_proxy_task( ) break - if MULTI_TENANT: - if not redis_connector_index.acquire_semaphore(celery_task_id): + if use_semaphore: + rank = redis_connector_index.acquire_semaphore(celery_task_id) + if rank >= redis_connector_index.SEMAPHORE_LIMIT: task_logger.warning( f"Indexing watchdog - could not acquire semaphore, delaying execution: " f"attempt={index_attempt_id} " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id} " - f"semaphore_limit={redis_connector_index.SEMAPHORE_LIMIT}" + f"semaphore_rank={rank} " + f"semaphore_limit={redis_connector_index.SEMAPHORE_LIMIT} " + f"task_retries={self.request.retries} " f"task_delay={TASK_RETRY_DELAY}" ) + # max_retries = None will retry forever, which we don't want + # max_retries = 0 really means no retries + if self.max_retries is None or self.request.retries >= self.max_retries: + task_logger.warning( + f"Indexing watchdog - could not acquire semaphore within max_retries. " + f"Canceling the attempt: " + f"attempt={index_attempt_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id} " + f"task_retries={self.request.retries} " + f"max_retries={self.max_retries} " + ) + try: + with get_session_with_tenant(tenant_id) as db_session: + mark_attempt_canceled( + index_attempt_id, + db_session, + f"Canceled the indexing attempt because an indexing slot could not be " + f"acquired within the retry limit: max_retries={self.max_retries}", + ) + except Exception: + # if the DB exceptions, we'll just get an unfriendly failure message + # in the UI instead of the cancellation message + logger.error( + "Indexing watchdog - transient exception marking index attempt as canceled: " + f"attempt={index_attempt_id} " + f"tenant={tenant_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id}" + ) + break + # this will delay the task for 30 minutes if the semaphore can't be acquired + # the max_retries annotation on the task will hard limit the number of times self.retry can succeed raise self.retry(countdown=TASK_RETRY_DELAY) + else: + task_logger.info( + f"Indexing watchdog - acquired semaphore: " + f"attempt={index_attempt_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id} " + f"semaphore_rank={rank} " + f"semaphore_limit={redis_connector_index.SEMAPHORE_LIMIT}" + ) job = client.submit( connector_indexing_task_wrapper, @@ -929,8 +987,12 @@ def connector_indexing_proxy_task( tenant_id, job, redis_connector_index, + use_semaphore, ) - redis_connector_index.release_semaphore() + + if use_semaphore: + redis_connector_index.release_semaphore() + break task_logger.info( diff --git a/backend/onyx/redis/redis_connector_index.py b/backend/onyx/redis/redis_connector_index.py index 75386418d1a..8c216372794 100644 --- a/backend/onyx/redis/redis_connector_index.py +++ b/backend/onyx/redis/redis_connector_index.py @@ -36,7 +36,7 @@ class RedisConnectorIndex: # used to limit the number of simulataneous 
indexing workers at task runtime SEMAPHORE_KEY = PREFIX + "_semaphore" - SEMAPHORE_LIMIT = 1 + SEMAPHORE_LIMIT = 6 SEMAPHORE_TIMEOUT = 15 * 60 # 15 minutes def __init__( @@ -117,7 +117,7 @@ def set_terminate(self, celery_task_id: str) -> None: # 10 minute TTL is good. self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600) - def acquire_semaphore(self, celery_task_id: str) -> bool: + def acquire_semaphore(self, celery_task_id: str) -> int: """Used to limit the number of simultaneous indexing workers per tenant. This semaphore does not need to be strongly consistent and is written as such. """ @@ -131,21 +131,24 @@ def acquire_semaphore(self, celery_task_id: str) -> bool: self.redis.zremrangebyscore( RedisConnectorIndex.SEMAPHORE_KEY, "-inf", now - self.SEMAPHORE_TIMEOUT - ) + ) # clean up old semaphore entries + + # add ourselves to the semaphore and check our position/rank self.redis.zadd(RedisConnectorIndex.SEMAPHORE_KEY, {self.semaphore_member: now}) rank_bytes = self.redis.zrank( RedisConnectorIndex.SEMAPHORE_KEY, self.semaphore_member ) rank = cast(int, rank_bytes) if rank >= RedisConnectorIndex.SEMAPHORE_LIMIT: + # we're over the limit, acquiring the semaphore has failed. self.redis.zrem(RedisConnectorIndex.SEMAPHORE_KEY, self.semaphore_member) - return False - return True + # return the rank ... we failed to acquire the semaphore is rank >= SEMAPHORE_LIMIT + return rank - def refresh_semaphore(self) -> None: + def refresh_semaphore(self) -> bool: if not self.semaphore_member: - return + return False redis_time_raw = self.redis.time() redis_time = cast(tuple[int, int], redis_time_raw) @@ -153,7 +156,13 @@ def refresh_semaphore(self) -> None: redis_time[0] + redis_time[1] / 1_000_000 ) # unix timestamp in floating point seconds - self.redis.zadd(RedisConnectorIndex.SEMAPHORE_KEY, {self.semaphore_member: now}) + reply = self.redis.zadd( + RedisConnectorIndex.SEMAPHORE_KEY, {self.semaphore_member: now}, xx=True + ) + if reply is None: + return False + + return True def release_semaphore(self) -> None: if not self.semaphore_member: @@ -221,6 +230,8 @@ def reset(self) -> None: @staticmethod def reset_all(r: redis.Redis) -> None: """Deletes all redis values for all connectors""" + r.delete(RedisConnectorIndex.SEMAPHORE_KEY) + for key in r.scan_iter(RedisConnectorIndex.ACTIVE_PREFIX + "*"): r.delete(key) From 76bbbd88bd84cad4540601ac11fa3c3993a543c6 Mon Sep 17 00:00:00 2001 From: Richard Kuo Date: Tue, 31 Dec 2024 18:00:12 -0800 Subject: [PATCH 3/3] typo --- backend/onyx/redis/redis_connector_index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/onyx/redis/redis_connector_index.py b/backend/onyx/redis/redis_connector_index.py index 8c216372794..0869806f42e 100644 --- a/backend/onyx/redis/redis_connector_index.py +++ b/backend/onyx/redis/redis_connector_index.py @@ -143,7 +143,7 @@ def acquire_semaphore(self, celery_task_id: str) -> int: # we're over the limit, acquiring the semaphore has failed. self.redis.zrem(RedisConnectorIndex.SEMAPHORE_KEY, self.semaphore_member) - # return the rank ... we failed to acquire the semaphore is rank >= SEMAPHORE_LIMIT + # return the rank ... we failed to acquire the semaphore if rank >= SEMAPHORE_LIMIT return rank def refresh_semaphore(self) -> bool:
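
Note on the semaphore pattern introduced above: it is a Redis sorted set keyed by Celery task id and scored with the Redis server clock. Acquire optimistically ZADDs the member and compares its ZRANK against the limit, refresh re-scores the member with XX so a reaped holder is not silently re-added, release ZREMs it, and stale holders are expired via ZREMRANGEBYSCORE on every acquire. Below is a minimal standalone sketch of the same pattern against a plain redis-py client; SEM_KEY, LIMIT, TIMEOUT, and the acquire/refresh/release helpers are illustrative stand-ins, not the RedisConnectorIndex API.

import redis

SEM_KEY = "example:indexing_semaphore"  # illustrative key, not the production key
LIMIT = 6                               # max concurrent holders per key
TIMEOUT = 15 * 60                       # holders that stop refreshing expire after 15 min


def _now(r: redis.Redis) -> float:
    # Use the Redis server clock so every worker scores against the same time source.
    seconds, microseconds = r.time()
    return seconds + microseconds / 1_000_000


def acquire(r: redis.Redis, member: str) -> bool:
    now = _now(r)
    # Reap holders that crashed or stopped refreshing.
    r.zremrangebyscore(SEM_KEY, "-inf", now - TIMEOUT)
    # Optimistically join, then check our position against the limit.
    r.zadd(SEM_KEY, {member: now})
    rank = r.zrank(SEM_KEY, member)
    if rank is None or rank >= LIMIT:
        # Over the limit (or already reaped): back out so we don't hold a slot.
        r.zrem(SEM_KEY, member)
        return False
    return True


def refresh(r: redis.Redis, member: str) -> bool:
    # xx=True only re-scores an existing member; if another worker's cleanup
    # already reaped us, we must not silently re-add ourselves here.
    r.zadd(SEM_KEY, {member: _now(r)}, xx=True)
    return r.zscore(SEM_KEY, member) is not None


def release(r: redis.Redis, member: str) -> None:
    r.zrem(SEM_KEY, member)


if __name__ == "__main__":
    r = redis.Redis()
    member = "celery-task-id-123"  # the patch uses self.request.id here
    if acquire(r, member):
        try:
            pass  # do the indexing work, calling refresh(r, member) every few seconds
        finally:
            release(r, member)

As the docstring in the patch says, this is deliberately not strongly consistent: under a burst of simultaneous acquires the rank check is approximate, which is acceptable for capping indexing concurrency per tenant.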
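
On the Celery side, the proxy task turns a failed acquire into a deferred retry rather than tying up a worker: it re-enqueues itself with self.retry(countdown=TASK_RETRY_DELAY), and once self.request.retries reaches max_retries it cancels the attempt instead of retrying forever. A stripped-down sketch of that control flow follows; try_acquire_slot, release_slot, run_attempt, and cancel_attempt are hypothetical no-op stand-ins for the semaphore and indexing calls in the patch.

import logging

from celery import Task, shared_task

RETRY_DELAY = 1800  # seconds, mirrors TASK_RETRY_DELAY in the patch

logger = logging.getLogger(__name__)


def try_acquire_slot(attempt_id: int) -> bool:  # stand-in for acquire_semaphore
    return True


def release_slot(attempt_id: int) -> None:  # stand-in for release_semaphore
    pass


def run_attempt(attempt_id: int) -> None:  # stand-in for the spawned indexing job
    pass


def cancel_attempt(attempt_id: int, reason: str) -> None:  # stand-in for mark_attempt_canceled
    logger.warning("canceling attempt %s: %s", attempt_id, reason)


@shared_task(bind=True, acks_late=False, max_retries=64)
def indexing_watchdog(self: Task, attempt_id: int) -> None:
    if not try_acquire_slot(attempt_id):
        # max_retries=None would mean unbounded retries, which we don't want,
        # so treat it as exhausted too and cancel cleanly.
        if self.max_retries is None or self.request.retries >= self.max_retries:
            cancel_attempt(attempt_id, "no indexing slot within the retry limit")
            return
        # Re-enqueue this same task RETRY_DELAY seconds from now; Celery tracks
        # the count in self.request.retries and enforces max_retries.
        raise self.retry(countdown=RETRY_DELAY)

    try:
        run_attempt(attempt_id)
    finally:
        release_slot(attempt_id)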