Skip to content

Commit

Permalink
k
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhongsun96 committed Dec 31, 2024
1 parent 240f3e4 commit accd743
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 62 deletions.
11 changes: 10 additions & 1 deletion backend/onyx/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days

# Default request timeout, mostly used by connectors
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)

# set `VALID_EMAIL_DOMAINS` to a comma seperated list of domains in order to
# restrict access to Onyx to only users with emails from those domains.
# E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx
Expand Down Expand Up @@ -367,12 +370,18 @@
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
)

# Typically set to http://localhost:3000 for OAuth connector development
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")

# Egnyte specific configs
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")

# Linear specific configs
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")

DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from dateutil.parser import parse

from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
from onyx.configs.constants import IGNORE_FOR_QA
from onyx.connectors.models import BasicExpertInfo
from onyx.utils.text_processing import is_valid_email
Expand Down Expand Up @@ -71,3 +72,10 @@ def process_in_batches(

def get_metadata_keys_to_ignore() -> list[str]:
return [IGNORE_FOR_QA]


def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
if CONNECTOR_LOCALHOST_OVERRIDE:
# Used for development
base_domain = CONNECTOR_LOCALHOST_OVERRIDE
return f"{base_domain.strip('/')}/connector/oauth/callback/{connector_id}"
68 changes: 10 additions & 58 deletions backend/onyx/connectors/egnyte/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,18 @@
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from logging import Logger
from typing import Any
from typing import cast
from typing import IO
from urllib.parse import quote

import requests
from retry import retry

from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
from onyx.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_oauth_callback_uri,
)
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import OAuthConnector
Expand All @@ -34,54 +31,13 @@
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import read_text_file
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries


logger = setup_logger()

_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
_TIMEOUT = 60


def _request_with_retries(
method: str,
url: str,
data: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
timeout: int = _TIMEOUT,
stream: bool = False,
tries: int = 8,
delay: float = 1,
backoff: float = 2,
) -> requests.Response:
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
def _make_request() -> requests.Response:
response = requests.request(
method,
url,
data=data,
headers=headers,
params=params,
timeout=timeout,
stream=stream,
)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
if e.response.status_code != 403:
logger.exception(
f"Failed to call Egnyte API.\n"
f"URL: {url}\n"
# NOTE: can't log headers because they contain the access token
# f"Headers: {headers}\n"
f"Data: {data}\n"
f"Params: {params}"
)
raise e
return response

return _make_request()


def _parse_last_modified(last_modified: str) -> datetime:
Expand Down Expand Up @@ -189,10 +145,7 @@ def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")

if EGNYTE_LOCALHOST_OVERRIDE:
base_domain = EGNYTE_LOCALHOST_OVERRIDE

callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
return (
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
f"?client_id={EGNYTE_CLIENT_ID}"
Expand All @@ -213,7 +166,7 @@ def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:

# Exchange code for token
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
redirect_uri = f"{EGNYTE_LOCALHOST_OVERRIDE or base_domain}/connector/oauth/callback/egnyte"
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
data = {
"client_id": EGNYTE_CLIENT_ID,
"client_secret": EGNYTE_CLIENT_SECRET,
Expand All @@ -224,7 +177,7 @@ def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}

response = _request_with_retries(
response = request_with_retries(
method="POST",
url=url,
data=data,
Expand Down Expand Up @@ -264,8 +217,8 @@ def _get_files_list(

url_encoded_path = quote(path or "", safe="")
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
response = _request_with_retries(
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
response = request_with_retries(
method="GET", url=url, headers=headers, params=params
)
if not response.ok:
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
Expand Down Expand Up @@ -320,11 +273,10 @@ def _process_files(
}
url_encoded_path = quote(file["path"], safe="")
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
response = _request_with_retries(
response = request_with_retries(
method="GET",
url=url,
headers=headers,
timeout=_TIMEOUT,
stream=True,
)

Expand Down
66 changes: 64 additions & 2 deletions backend/onyx/connectors/linear/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,23 @@
import requests

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import LINEAR_CLIENT_ID
from onyx.configs.app_configs import LINEAR_CLIENT_SECRET
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_oauth_callback_uri,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import OAuthConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries


logger = setup_logger()
Expand Down Expand Up @@ -57,16 +64,71 @@ def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response
)


class LinearConnector(LoadConnector, PollConnector):
class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.batch_size = batch_size
self.linear_api_key: str | None = None

@classmethod
def oauth_id(cls) -> DocumentSource:
return DocumentSource.LINEAR

@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
if not LINEAR_CLIENT_ID:
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")

callback_uri = get_oauth_callback_uri(base_domain, "linear")
return (
f"https://linear.app/oauth/authorize"
f"?client_id={LINEAR_CLIENT_ID}"
f"&redirect_uri={callback_uri}"
f"&response_type=code"
f"&scope=read"
f"&state={state}"
)

@classmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
data = {
"code": code,
"redirect_uri": get_oauth_callback_uri(base_domain, "linear"),
"client_id": LINEAR_CLIENT_ID,
"client_secret": LINEAR_CLIENT_SECRET,
"grant_type": "authorization_code",
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}

response = request_with_retries(
method="POST",
url="https://api.linear.app/oauth/token",
data=data,
headers=headers,
backoff=0,
delay=0.1,
)
if not response.ok:
raise RuntimeError(f"Failed to exchange code for token: {response.text}")

token_data = response.json()

return {
"access_token": token_data["access_token"],
}

def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.linear_api_key = cast(str, credentials["linear_api_key"])
if "linear_api_key" in credentials:
self.linear_api_key = cast(str, credentials["linear_api_key"])
elif "access_token" in credentials:
self.linear_api_key = "Bearer " + cast(str, credentials["access_token"])
else:
# The error is slightly different for Linear cause the failure likely happened upstream
# Could be that the refresh token expired
raise ConnectorMissingCredentialError("Linear")

return None

def _process_issues(
Expand Down
2 changes: 1 addition & 1 deletion backend/onyx/redis/redis_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
# value = redis_client.get('key')
# print(value.decode()) # Output: 'value'

_async_redis_connection = None
_async_redis_connection: aioredis.Redis | None = None
_async_lock = asyncio.Lock()


Expand Down
47 changes: 47 additions & 0 deletions backend/onyx/utils/retry_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from typing import cast
from typing import TypeVar

import requests
from retry import retry

from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.utils.logger import setup_logger

logger = setup_logger()
Expand Down Expand Up @@ -42,3 +44,48 @@ def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
return cast(F, wrapped_func)

return retry_with_default


def request_with_retries(
method: str,
url: str,
*,
data: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
timeout: int = REQUEST_TIMEOUT_SECONDS,
stream: bool = False,
tries: int = 8,
delay: float = 1,
backoff: float = 2,
) -> requests.Response:
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
def _make_request() -> requests.Response:
response = requests.request(
method=method,
url=url,
data=data,
headers=headers,
params=params,
timeout=timeout,
stream=stream,
)
try:
response.raise_for_status()
except requests.exceptions.HTTPError:
logger.exception(
"Request failed:\n%s",
{
"method": method,
"url": url,
"data": data,
"headers": headers,
"params": params,
"timeout": timeout,
"stream": stream,
},
)
raise
return response

return _make_request()

0 comments on commit accd743

Please sign in to comment.