Skip to content

Commit

Permalink
Adjust the way LLM class is instantiated + fix issue where .env file
GEN_AI_API_KEY wasn't being used (#630)
Browse files Browse the repository at this point in the history
  • Loading branch information
Weves authored Oct 26, 2023
1 parent 604e511 commit 76275b2
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 66 deletions.
5 changes: 1 addition & 4 deletions backend/danswer/configs/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,7 @@
)

# If the Generative AI model requires an API key for access, otherwise can leave blank
# API key for the Generative AI model. Prefers GEN_AI_API_KEY, falling back to
# OPENAI_API_KEY for backwards compatibility. Deliberately left as None (not a
# dummy placeholder) when neither env var is set, so downstream code can detect
# that no key was actually provided.
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))

# If using GPT4All, HuggingFace Inference API, or OpenAI - specify the model version
GEN_AI_MODEL_VERSION = os.environ.get(
Expand Down
7 changes: 4 additions & 3 deletions backend/danswer/direct_qa/open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@


def _ensure_openai_api_key(api_key: str | None) -> str:
try:
return api_key or get_gen_ai_api_key()
except ConfigNotFoundError:
final_api_key = api_key or get_gen_ai_api_key()
if final_api_key is None:
raise OpenAIKeyMissing()

return final_api_key


def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
"""
Expand Down
15 changes: 10 additions & 5 deletions backend/danswer/direct_qa/qa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.utils import check_number_of_tokens
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
Expand All @@ -31,11 +32,15 @@
logger = setup_logger()


def get_gen_ai_api_key() -> str | None:
    """Fetch the Generative AI API key, or None if none is configured.

    Returns:
        The key provided via the admin UI (dynamic config store) if present,
        otherwise the ``GEN_AI_API_KEY`` env-derived constant, which may
        itself be None when no key was ever supplied.
    """
    # first check if the key has been provided by the UI
    try:
        return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
    except ConfigNotFoundError:
        pass

    # if not provided by the UI, fall back to the env variable
    return GEN_AI_API_KEY


def extract_answer_quotes_freeform(
Expand Down
6 changes: 0 additions & 6 deletions backend/danswer/llm/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_VERSION_OPENAI
from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose

Expand All @@ -23,11 +22,6 @@ def __init__(
*args: list[Any],
**kwargs: dict[str, Any]
):
# set a dummy API key if not specified so that LangChain doesn't throw an
# exception when trying to initialize the LLM which would prevent the API
# server from starting up
if not api_key:
api_key = GEN_AI_API_KEY
self._llm = AzureChatOpenAI(
model=model_version,
openai_api_type="azure",
Expand Down
58 changes: 24 additions & 34 deletions backend/danswer/llm/build.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Any

from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
Expand All @@ -11,48 +9,40 @@
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.azure import AzureGPT
from danswer.llm.google_colab_demo import GoogleColabDemo
from danswer.llm.llm import LLM
from danswer.llm.openai import OpenAIGPT


def get_llm_from_model(model: str, **kwargs: Any) -> LLM:
    """Instantiate the concrete LLM implementation for ``model``.

    Args:
        model: a DanswerGenAIModel value identifying the implementation.
        **kwargs: passed through to the chosen LLM class's constructor.

    Raises:
        ValueError: if ``model`` does not map to a known implementation.
    """
    if model == DanswerGenAIModel.OPENAI_CHAT.value:
        # Azure-hosted OpenAI requires a different client than openai.com
        if API_TYPE_OPENAI == "azure":
            return AzureGPT(**kwargs)
        return OpenAIGPT(**kwargs)
    if (
        model == DanswerGenAIModel.REQUEST.value
        and kwargs.get("model_host_type") == ModelHostType.COLAB_DEMO
    ):
        return GoogleColabDemo(**kwargs)

    raise ValueError(f"Unknown LLM model: {model}")


def get_default_llm(
    api_key: str | None = None,
    timeout: int = QA_TIMEOUT,
) -> LLM:
    """NOTE: api_key/timeout must be a special args since we may want to check
    if an API key is valid for the default model setup OR we may want to use the
    default model with a different timeout specified."""
    if api_key is None:
        # may still be None afterwards if no key is configured anywhere
        api_key = get_gen_ai_api_key()

    model_args = {
        # provide a dummy key since LangChain will throw an exception if not
        # given, which would prevent server startup
        "api_key": api_key or "dummy_api_key",
        "timeout": timeout,
        "model_version": GEN_AI_MODEL_VERSION,
        "endpoint": GEN_AI_ENDPOINT,
        "max_output_tokens": GEN_AI_MAX_OUTPUT_TOKENS,
        "temperature": GEN_AI_TEMPERATURE,
    }
    if INTERNAL_MODEL_VERSION == DanswerGenAIModel.OPENAI_CHAT.value:
        # Azure-hosted OpenAI requires a different client than openai.com
        if API_TYPE_OPENAI == "azure":
            return AzureGPT(**model_args)  # type: ignore
        return OpenAIGPT(**model_args)  # type: ignore
    if (
        INTERNAL_MODEL_VERSION == DanswerGenAIModel.REQUEST.value
        and GEN_AI_HOST_TYPE == ModelHostType.COLAB_DEMO
    ):
        return GoogleColabDemo(**model_args)  # type: ignore

    raise ValueError(f"Unknown LLM model: {INTERNAL_MODEL_VERSION}")
7 changes: 0 additions & 7 deletions backend/danswer/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from langchain.chat_models.openai import ChatOpenAI

from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose
Expand All @@ -20,12 +19,6 @@ def __init__(
*args: list[Any],
**kwargs: dict[str, Any]
):
# set a dummy API key if not specified so that LangChain doesn't throw an
# exception when trying to initialize the LLM which would prevent the API
# server from starting up
if not api_key:
api_key = GEN_AI_API_KEY

self._llm = ChatOpenAI(
model=model_version,
openai_api_key=api_key,
Expand Down
12 changes: 5 additions & 7 deletions backend/danswer/server/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,14 +502,9 @@ def validate_existing_genai_api_key(
# First time checking the key, nothing unusual
pass

try:
genai_api_key = get_gen_ai_api_key()
except ConfigNotFoundError:
genai_api_key = get_gen_ai_api_key()
if genai_api_key is None:
raise HTTPException(status_code=404, detail="Key not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

get_dynamic_config_store().store(check_key_time, curr_time.timestamp())

try:
is_valid = check_model_api_key_is_valid(genai_api_key)
Expand All @@ -520,6 +515,9 @@ def validate_existing_genai_api_key(
if not is_valid:
raise HTTPException(status_code=400, detail="Invalid API key provided")

# mark check as successful
get_dynamic_config_store().store(check_key_time, curr_time.timestamp())


@router.get("/admin/genai-api-key", response_model=ApiKey)
def get_gen_ai_api_key_from_dynamic_config_store(
Expand Down

1 comment on commit 76275b2

@vercel
Copy link

@vercel vercel bot commented on 76275b2 Oct 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.