diff --git a/src/vanna/openai/openai_chat.py b/src/vanna/openai/openai_chat.py index 16f74793..361ed441 100644 --- a/src/vanna/openai/openai_chat.py +++ b/src/vanna/openai/openai_chat.py @@ -9,53 +9,55 @@ class OpenAI_Chat(VannaBase): def __init__(self, client=None, config=None): VannaBase.__init__(self, config=config) - # default parameters - can be overrided using config - self.temperature = 0.7 + # Ensure config is a dictionary + config = config or {} - if "temperature" in config: - self.temperature = config["temperature"] + # Default parameters - can be overridden using config + self.temperature = config.get("temperature", 0.7) - if "api_type" in config: - raise Exception( - "Passing api_type is now deprecated. Please pass an OpenAI client instead." - ) - - if "api_base" in config: - raise Exception( - "Passing api_base is now deprecated. Please pass an OpenAI client instead." - ) - - if "api_version" in config: - raise Exception( - "Passing api_version is now deprecated. Please pass an OpenAI client instead." - ) + # Raise exceptions for deprecated parameters + for deprecated_param in ["api_type", "api_base", "api_version"]: + if deprecated_param in config: + raise ValueError( + f"Passing {deprecated_param} is now deprecated. Please pass an OpenAI client instead." 
+ ) if client is not None: self.client = client return - if config is None and client is None: - self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) - return - - if "api_key" in config: - self.client = OpenAI(api_key=config["api_key"]) + # Initialize the OpenAI client with optional overrides from config + self.client = OpenAI( + api_key=config.get("api_key"), + base_url=config.get("base_url") + ) - def system_message(self, message: str) -> any: + def system_message(self, message: str) -> dict: return {"role": "system", "content": message} - def user_message(self, message: str) -> any: + def user_message(self, message: str) -> dict: return {"role": "user", "content": message} - def assistant_message(self, message: str) -> any: + def assistant_message(self, message: str) -> dict: return {"role": "assistant", "content": message} + def generate_response(self, prompt, num_tokens): + model = (self.config or {}).get("model", "gpt-4o-mini") + print(f"Using model {model} for {num_tokens} tokens (approx)") + response = self.client.chat.completions.create( + model=model, + messages=prompt, + stop=None, + temperature=self.temperature, + ) + return response + def submit_prompt(self, prompt, **kwargs) -> str: if prompt is None: - raise Exception("Prompt is None") + raise ValueError("Prompt is None") if len(prompt) == 0: - raise Exception("Prompt is empty") + raise ValueError("Prompt is empty") # Count the number of tokens in the message log # Use 4 as an approximation for the number of characters per token @@ -63,66 +65,14 @@ def submit_prompt(self, prompt, **kwargs) -> str: for message in prompt: num_tokens += len(message["content"]) / 4 - if kwargs.get("model", None) is not None: - model = kwargs.get("model", None) - print( - f"Using model {model} for {num_tokens} tokens (approx)" - ) - response = self.client.chat.completions.create( - model=model, - messages=prompt, - stop=None, - temperature=self.temperature, - ) - elif kwargs.get("engine", None) is not None: - engine =
kwargs.get("engine", None) - print( - f"Using model {engine} for {num_tokens} tokens (approx)" - ) - response = self.client.chat.completions.create( - engine=engine, - messages=prompt, - stop=None, - temperature=self.temperature, - ) - elif self.config is not None and "engine" in self.config: - print( - f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)" - ) - response = self.client.chat.completions.create( - engine=self.config["engine"], - messages=prompt, - stop=None, - temperature=self.temperature, - ) - elif self.config is not None and "model" in self.config: - print( - f"Using model {self.config['model']} for {num_tokens} tokens (approx)" - ) - response = self.client.chat.completions.create( - model=self.config["model"], - messages=prompt, - stop=None, - temperature=self.temperature, - ) - else: - if num_tokens > 3500: - model = "gpt-3.5-turbo-16k" - else: - model = "gpt-3.5-turbo" - - print(f"Using model {model} for {num_tokens} tokens (approx)") - response = self.client.chat.completions.create( - model=model, - messages=prompt, - stop=None, - temperature=self.temperature, - ) - - # Find the first response from the chatbot that has text in it (some responses may not have text) + # Use the generate_response method to get the response + response = self.generate_response(prompt, num_tokens) + + # Find the first response from the chatbot that has text in it + # (some responses may not have text) for choice in response.choices: if "text" in choice: return choice.text - # If no response with text is found, return the first response's content (which may be empty) - return response.choices[0].message.content + # If no response with text is found, return the first response's content + return response.choices[0].message.content \ No newline at end of file