-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add rate limiting wrapper + add to Document360
- Loading branch information
Showing
3 changed files
with
131 additions
and
0 deletions.
There are no files selected for viewing
86 changes: 86 additions & 0 deletions
86
backend/danswer/connectors/cross_connector_utils/rate_limit_wrapper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import time | ||
from collections.abc import Callable | ||
from functools import wraps | ||
from typing import Any | ||
from typing import cast | ||
from typing import TypeVar | ||
|
||
from danswer.utils.logger import setup_logger | ||
|
||
logger = setup_logger() | ||
|
||
|
||
F = TypeVar("F", bound=Callable[..., Any]) | ||
|
||
|
||
class RateLimitTriedTooManyTimesError(Exception): | ||
pass | ||
|
||
|
||
class _RateLimitDecorator: | ||
"""Builds a generic wrapper/decorator for calls to external APIs that | ||
prevents making more than `max_calls` requests per `period` | ||
Implementation inspired by the `ratelimit` library: | ||
https://github.com/tomasbasham/ratelimit. | ||
NOTE: is not thread safe. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
max_calls: int, | ||
period: float, # in seconds | ||
sleep_time: float = 2, # in seconds | ||
sleep_backoff: float = 2, # applies exponential backoff | ||
max_num_sleep: int = 0, | ||
): | ||
self.max_calls = max_calls | ||
self.period = period | ||
self.sleep_time = sleep_time | ||
self.sleep_backoff = sleep_backoff | ||
self.max_num_sleep = max_num_sleep | ||
|
||
self.call_history: list[float] = [] | ||
self.curr_calls = 0 | ||
|
||
def __call__(self, func: F) -> F: | ||
@wraps(func) | ||
def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any: | ||
# cleanup calls which are no longer relevant | ||
self._cleanup() | ||
|
||
# check if we've exceeded the rate limit | ||
sleep_cnt = 0 | ||
while len(self.call_history) == self.max_calls: | ||
sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt) | ||
logger.info( | ||
f"Rate limit exceeded for function {func.__name__}. " | ||
f"Waiting {sleep_time} seconds before retrying." | ||
) | ||
time.sleep(sleep_time) | ||
sleep_cnt += 1 | ||
if self.max_num_sleep != 0 and sleep_cnt >= self.max_num_sleep: | ||
raise RateLimitTriedTooManyTimesError( | ||
f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'" | ||
) | ||
|
||
self._cleanup() | ||
|
||
# add the current call to the call history | ||
self.call_history.append(time.monotonic()) | ||
return func(*args, **kwargs) | ||
|
||
return cast(F, wrapped_func) | ||
|
||
def _cleanup(self) -> None: | ||
curr_time = time.monotonic() | ||
time_to_expire_before = curr_time - self.period | ||
self.call_history = [ | ||
call_time | ||
for call_time in self.call_history | ||
if call_time > time_to_expire_before | ||
] | ||
|
||
|
||
rate_limit_builder = _RateLimitDecorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 36 additions & 0 deletions
36
backend/tests/unit/danswer/connectors/cross_connector_utils/test_rate_limit.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import time | ||
import unittest | ||
|
||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import ( | ||
rate_limit_builder, | ||
) | ||
|
||
|
||
class TestRateLimit(unittest.TestCase): | ||
call_cnt = 0 | ||
|
||
def test_rate_limit_basic(self) -> None: | ||
self.call_cnt = 0 | ||
|
||
@rate_limit_builder(max_calls=2, period=5) | ||
def func() -> None: | ||
self.call_cnt += 1 | ||
|
||
start = time.time() | ||
|
||
# make calls that shouldn't be rate-limited | ||
func() | ||
func() | ||
time_to_finish_non_ratelimited = time.time() - start | ||
|
||
# make a call which SHOULD be rate-limited | ||
func() | ||
time_to_finish_ratelimited = time.time() - start | ||
|
||
self.assertEqual(self.call_cnt, 3) | ||
self.assertLess(time_to_finish_non_ratelimited, 1) | ||
self.assertGreater(time_to_finish_ratelimited, 5) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
08909b4
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Successfully deployed to the following URLs:
internal-search – ./
internal-search.vercel.app
internal-search-danswer.vercel.app
internal-search-git-main-danswer.vercel.app
www.danswer.dev
danswer.dev