Skip to content

Commit

Permalink
Add rate limiting wrapper + add to Document360
Browse files Browse the repository at this point in the history
  • Loading branch information
Weves committed Oct 30, 2023
1 parent 64ebaf2 commit 08909b4
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import time
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import cast
from typing import TypeVar

from danswer.utils.logger import setup_logger

logger = setup_logger()


F = TypeVar("F", bound=Callable[..., Any])


class RateLimitTriedTooManyTimesError(Exception):
pass


class _RateLimitDecorator:
"""Builds a generic wrapper/decorator for calls to external APIs that
prevents making more than `max_calls` requests per `period`
Implementation inspired by the `ratelimit` library:
https://github.com/tomasbasham/ratelimit.
NOTE: is not thread safe.
"""

def __init__(
self,
max_calls: int,
period: float, # in seconds
sleep_time: float = 2, # in seconds
sleep_backoff: float = 2, # applies exponential backoff
max_num_sleep: int = 0,
):
self.max_calls = max_calls
self.period = period
self.sleep_time = sleep_time
self.sleep_backoff = sleep_backoff
self.max_num_sleep = max_num_sleep

self.call_history: list[float] = []
self.curr_calls = 0

def __call__(self, func: F) -> F:
@wraps(func)
def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
# cleanup calls which are no longer relevant
self._cleanup()

# check if we've exceeded the rate limit
sleep_cnt = 0
while len(self.call_history) == self.max_calls:
sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt)
logger.info(
f"Rate limit exceeded for function {func.__name__}. "
f"Waiting {sleep_time} seconds before retrying."
)
time.sleep(sleep_time)
sleep_cnt += 1
if self.max_num_sleep != 0 and sleep_cnt >= self.max_num_sleep:
raise RateLimitTriedTooManyTimesError(
f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'"
)

self._cleanup()

# add the current call to the call history
self.call_history.append(time.monotonic())
return func(*args, **kwargs)

return cast(F, wrapped_func)

def _cleanup(self) -> None:
curr_time = time.monotonic()
time_to_expire_before = curr_time - self.period
self.call_history = [
call_time
for call_time in self.call_history
if call_time > time_to_expire_before
]


rate_limit_builder = _RateLimitDecorator
9 changes: 9 additions & 0 deletions backend/danswer/connectors/document360/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
Expand Down Expand Up @@ -46,6 +50,11 @@ def load_credentials(self, credentials: dict[str, Any]) -> Optional[dict[str, An
self.portal_id = credentials.get("portal_id")
return None

# rate limiting set based on the enterprise plan: https://apidocs.document360.com/apidocs/rate-limiting
# NOTE: retry will handle cases where user is not on enterprise plan - we will just hit the rate limit
# and then retry after a period
@retry_builder()
@rate_limit_builder(max_calls=100, period=60)
def _make_request(self, endpoint: str, params: Optional[dict] = None) -> Any:
if not self.api_token:
raise ConnectorMissingCredentialError("Document360")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import time
import unittest

from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)


class TestRateLimit(unittest.TestCase):
call_cnt = 0

def test_rate_limit_basic(self) -> None:
self.call_cnt = 0

@rate_limit_builder(max_calls=2, period=5)
def func() -> None:
self.call_cnt += 1

start = time.time()

# make calls that shouldn't be rate-limited
func()
func()
time_to_finish_non_ratelimited = time.time() - start

# make a call which SHOULD be rate-limited
func()
time_to_finish_ratelimited = time.time() - start

self.assertEqual(self.call_cnt, 3)
self.assertLess(time_to_finish_non_ratelimited, 1)
self.assertGreater(time_to_finish_ratelimited, 5)


if __name__ == "__main__":
unittest.main()

1 comment on commit 08909b4

@vercel
Copy link

@vercel vercel bot commented on 08909b4 Oct 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.