From ccd39838020f78f5be9fde0cd78652d08e9f6626 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Tue, 31 Dec 2024 16:11:09 -0800 Subject: [PATCH] Linear OAuth Connector (#3570) --- backend/onyx/configs/app_configs.py | 11 ++- .../miscellaneous_utils.py | 8 +++ backend/onyx/connectors/egnyte/connector.py | 68 +++---------------- backend/onyx/connectors/linear/connector.py | 67 +++++++++++++++++- backend/onyx/redis/redis_pool.py | 2 +- backend/onyx/utils/retry_wrapper.py | 47 +++++++++++++ 6 files changed, 141 insertions(+), 62 deletions(-) diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 8baebfdf2fe..1c8be75fc8e 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -58,6 +58,9 @@ os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7 ) # 7 days +# Default request timeout, mostly used by connectors +REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60) + # set `VALID_EMAIL_DOMAINS` to a comma seperated list of domains in order to # restrict access to Onyx to only users with emails from those domains. # E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx @@ -367,12 +370,18 @@ os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true" ) +# Typically set to http://localhost:3000 for OAuth connector development +CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE") + # Egnyte specific configs -EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE") EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN") EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID") EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET") +# Linear specific configs +LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID") +LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET") + DASK_JOB_CLIENT_ENABLED = ( os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true" ) diff --git a/backend/onyx/connectors/cross_connector_utils/miscellaneous_utils.py b/backend/onyx/connectors/cross_connector_utils/miscellaneous_utils.py index 2170ea7ba2c..f504f2e7e97 100644 --- a/backend/onyx/connectors/cross_connector_utils/miscellaneous_utils.py +++ b/backend/onyx/connectors/cross_connector_utils/miscellaneous_utils.py @@ -6,6 +6,7 @@ from dateutil.parser import parse +from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE from onyx.configs.constants import IGNORE_FOR_QA from onyx.connectors.models import BasicExpertInfo from onyx.utils.text_processing import is_valid_email @@ -71,3 +72,10 @@ def process_in_batches( def get_metadata_keys_to_ignore() -> list[str]: return [IGNORE_FOR_QA] + + +def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str: + if CONNECTOR_LOCALHOST_OVERRIDE: + # Used for development + base_domain = CONNECTOR_LOCALHOST_OVERRIDE + return f"{base_domain.strip('/')}/connector/oauth/callback/{connector_id}" diff --git a/backend/onyx/connectors/egnyte/connector.py b/backend/onyx/connectors/egnyte/connector.py index ba4ece903cd..0fa82cd55e7 100644 --- a/backend/onyx/connectors/egnyte/connector.py +++ b/backend/onyx/connectors/egnyte/connector.py @@ -3,21 +3,18 @@ from collections.abc import Generator from datetime import datetime from datetime import timezone -from logging import Logger from typing import Any -from typing import cast from typing import IO from urllib.parse import quote -import requests -from retry import retry - from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN from onyx.configs.app_configs import EGNYTE_CLIENT_ID from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET -from onyx.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource +from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( + get_oauth_callback_uri, +) from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import OAuthConnector @@ -34,54 +31,13 @@ from onyx.file_processing.extract_file_text import is_valid_file_ext from onyx.file_processing.extract_file_text import read_text_file from onyx.utils.logger import setup_logger +from onyx.utils.retry_wrapper import request_with_retries logger = setup_logger() _EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1" _EGNYTE_APP_BASE = "https://{domain}.egnyte.com" -_TIMEOUT = 60 - - -def _request_with_retries( - method: str, - url: str, - data: dict[str, Any] | None = None, - headers: dict[str, Any] | None = None, - params: dict[str, Any] | None = None, - timeout: int = _TIMEOUT, - stream: bool = False, - tries: int = 8, - delay: float = 1, - backoff: float = 2, -) -> requests.Response: - @retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger)) - def _make_request() -> requests.Response: - response = requests.request( - method, - url, - data=data, - headers=headers, - params=params, - timeout=timeout, - stream=stream, - ) - try: - response.raise_for_status() - except requests.exceptions.HTTPError as e: - if e.response.status_code != 403: - logger.exception( - f"Failed to call Egnyte API.\n" - f"URL: {url}\n" - # NOTE: can't log headers because they contain the access token - # f"Headers: {headers}\n" - f"Data: {data}\n" - f"Params: {params}" - ) - raise e - return response - - return _make_request() def _parse_last_modified(last_modified: str) -> datetime: @@ -189,10 +145,7 @@ def oauth_authorization_url(cls, base_domain: str, state: str) -> str: if not EGNYTE_BASE_DOMAIN: raise ValueError("EGNYTE_DOMAIN environment variable must be set") - if EGNYTE_LOCALHOST_OVERRIDE: - base_domain = EGNYTE_LOCALHOST_OVERRIDE - - callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte" + callback_uri = get_oauth_callback_uri(base_domain, "egnyte") return ( f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token" f"?client_id={EGNYTE_CLIENT_ID}" @@ -213,7 +166,7 @@ def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]: # Exchange code for token url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token" - redirect_uri = f"{EGNYTE_LOCALHOST_OVERRIDE or base_domain}/connector/oauth/callback/egnyte" + redirect_uri = get_oauth_callback_uri(base_domain, "egnyte") data = { "client_id": EGNYTE_CLIENT_ID, "client_secret": EGNYTE_CLIENT_SECRET, @@ -224,7 +177,7 @@ def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]: } headers = {"Content-Type": "application/x-www-form-urlencoded"} - response = _request_with_retries( + response = request_with_retries( method="POST", url=url, data=data, @@ -264,8 +217,8 @@ def _get_files_list( url_encoded_path = quote(path or "", safe="") url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}" - response = _request_with_retries( - method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT + response = request_with_retries( + method="GET", url=url, headers=headers, params=params ) if not response.ok: raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}") @@ -320,11 +273,10 @@ def _process_files( } url_encoded_path = quote(file["path"], safe="") url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}" - response = _request_with_retries( + response = request_with_retries( method="GET", url=url, headers=headers, - timeout=_TIMEOUT, stream=True, ) diff --git a/backend/onyx/connectors/linear/connector.py b/backend/onyx/connectors/linear/connector.py index b356026e0f5..0bd10e91f77 100644 --- a/backend/onyx/connectors/linear/connector.py +++ b/backend/onyx/connectors/linear/connector.py @@ -7,16 +7,23 @@ import requests from onyx.configs.app_configs import INDEX_BATCH_SIZE +from onyx.configs.app_configs import LINEAR_CLIENT_ID +from onyx.configs.app_configs import LINEAR_CLIENT_SECRET from onyx.configs.constants import DocumentSource +from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( + get_oauth_callback_uri, +) from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector +from onyx.connectors.interfaces import OAuthConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import Section from onyx.utils.logger import setup_logger +from onyx.utils.retry_wrapper import request_with_retries logger = setup_logger() @@ -57,7 +64,7 @@ def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response ) -class LinearConnector(LoadConnector, PollConnector): +class LinearConnector(LoadConnector, PollConnector, OAuthConnector): def __init__( self, batch_size: int = INDEX_BATCH_SIZE, @@ -65,8 +72,64 @@ def __init__( self.batch_size = batch_size self.linear_api_key: str | None = None + @classmethod + def oauth_id(cls) -> DocumentSource: + return DocumentSource.LINEAR + + @classmethod + def oauth_authorization_url(cls, base_domain: str, state: str) -> str: + if not LINEAR_CLIENT_ID: + raise ValueError("LINEAR_CLIENT_ID environment variable must be set") + + callback_uri = get_oauth_callback_uri(base_domain, DocumentSource.LINEAR.value) + return ( + f"https://linear.app/oauth/authorize" + f"?client_id={LINEAR_CLIENT_ID}" + f"&redirect_uri={callback_uri}" + f"&response_type=code" + f"&scope=read" + f"&state={state}" + ) + + @classmethod + def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]: + data = { + "code": code, + "redirect_uri": get_oauth_callback_uri( + base_domain, DocumentSource.LINEAR.value + ), + "client_id": LINEAR_CLIENT_ID, + "client_secret": LINEAR_CLIENT_SECRET, + "grant_type": "authorization_code", + } + headers = {"Content-Type": "application/x-www-form-urlencoded"} + + response = request_with_retries( + method="POST", + url="https://api.linear.app/oauth/token", + data=data, + headers=headers, + backoff=0, + delay=0.1, + ) + if not response.ok: + raise RuntimeError(f"Failed to exchange code for token: {response.text}") + + token_data = response.json() + + return { + "access_token": token_data["access_token"], + } + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: - self.linear_api_key = cast(str, credentials["linear_api_key"]) + if "linear_api_key" in credentials: + self.linear_api_key = cast(str, credentials["linear_api_key"]) + elif "access_token" in credentials: + self.linear_api_key = "Bearer " + cast(str, credentials["access_token"]) + else: + # May need to handle case in the future if the OAuth flow expires + raise ConnectorMissingCredentialError("Linear") + return None def _process_issues( diff --git a/backend/onyx/redis/redis_pool.py b/backend/onyx/redis/redis_pool.py index f7e372887cf..c118d410dcd 100644 --- a/backend/onyx/redis/redis_pool.py +++ b/backend/onyx/redis/redis_pool.py @@ -199,7 +199,7 @@ def get_redis_client(*, tenant_id: str | None) -> Redis: # value = redis_client.get('key') # print(value.decode()) # Output: 'value' -_async_redis_connection = None +_async_redis_connection: aioredis.Redis | None = None _async_lock = asyncio.Lock() diff --git a/backend/onyx/utils/retry_wrapper.py b/backend/onyx/utils/retry_wrapper.py index ef710709934..b441d5250ac 100644 --- a/backend/onyx/utils/retry_wrapper.py +++ b/backend/onyx/utils/retry_wrapper.py @@ -4,8 +4,10 @@ from typing import cast from typing import TypeVar +import requests from retry import retry +from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS from onyx.utils.logger import setup_logger logger = setup_logger() @@ -42,3 +44,48 @@ def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any: return cast(F, wrapped_func) return retry_with_default + + +def request_with_retries( + method: str, + url: str, + *, + data: dict[str, Any] | None = None, + headers: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + timeout: int = REQUEST_TIMEOUT_SECONDS, + stream: bool = False, + tries: int = 8, + delay: float = 1, + backoff: float = 2, +) -> requests.Response: + @retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger)) + def _make_request() -> requests.Response: + response = requests.request( + method=method, + url=url, + data=data, + headers=headers, + params=params, + timeout=timeout, + stream=stream, + ) + try: + response.raise_for_status() + except requests.exceptions.HTTPError: + logger.exception( + "Request failed:\n%s", + { + "method": method, + "url": url, + "data": data, + "headers": headers, + "params": params, + "timeout": timeout, + "stream": stream, + }, + ) + raise + return response + + return _make_request()