From 5568ae3ee708d527adf419805333ef5a7fe8666c Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Wed, 20 Sep 2023 17:22:58 -0400 Subject: [PATCH 1/3] get_training_data --- src/vanna/__init__.py | 2 +- src/vanna/base/base.py | 21 ++++++++- src/vanna/chromadb/chromadb_vector.py | 67 +++++++++++++++++++++++++-- src/vanna/openai/openai_chat.py | 2 +- src/vanna/remote.py | 30 +++++++++++- 5 files changed, 114 insertions(+), 8 deletions(-) diff --git a/src/vanna/__init__.py b/src/vanna/__init__.py index 8ce850ea..c3dec99d 100644 --- a/src/vanna/__init__.py +++ b/src/vanna/__init__.py @@ -109,7 +109,7 @@ end subgraph OpenAI_Chat - get_prompt + get_sql_prompt submit_prompt generate_question generate_plotly_code diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 3ef05521..0d976dc5 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -25,7 +25,7 @@ def generate_sql_from_question(self, question: str, **kwargs) -> str: question_sql_list = self.get_similar_question_sql(question, **kwargs) ddl_list = self.get_related_ddl(question, **kwargs) doc_list = self.get_related_documentation(question, **kwargs) - prompt = self.get_prompt( + prompt = self.get_sql_prompt( question=question, question_sql_list=question_sql_list, ddl_list=ddl_list, @@ -35,6 +35,19 @@ def generate_sql_from_question(self, question: str, **kwargs) -> str: llm_response = self.submit_prompt(prompt, **kwargs) return llm_response + def generate_questions(self, **kwargs) -> list[str]: + """ + **Example:** + ```python + vn.generate_questions() + ``` + + Generate a list of questions that you can ask Vanna.AI. + """ + question_sql = self.get_similar_question_sql(question="", **kwargs) + + return [q['question'] for q in question_sql] + # ----------------- Use Any Embeddings API ----------------- # @abstractmethod def generate_embedding(self, data: str, **kwargs) -> list[float]: @@ -65,10 +78,14 @@ def add_ddl(self, ddl: str, **kwargs) -> str: def add_documentation(self, doc: str, **kwargs) -> str: pass + @abstractmethod + def get_training_data(self) -> pd.DataFrame: + pass + # ----------------- Use Any Language Model API ----------------- # @abstractmethod - def get_prompt( + def get_sql_prompt( self, question: str, question_sql_list: list, diff --git a/src/vanna/chromadb/chromadb_vector.py b/src/vanna/chromadb/chromadb_vector.py index cbdb6336..7115c2ad 100644 --- a/src/vanna/chromadb/chromadb_vector.py +++ b/src/vanna/chromadb/chromadb_vector.py @@ -5,6 +5,7 @@ import chromadb from chromadb.config import Settings from chromadb.utils import embedding_functions +import pandas as pd from ..base import VannaBase @@ -49,23 +50,83 @@ def add_question_sql(self, question: str, sql: str, **kwargs): self.sql_collection.add( documents=question_sql_json, embeddings=self.generate_embedding(question_sql_json), - ids=str(uuid.uuid4()), + ids=str(uuid.uuid4())+"-sql", ) def add_ddl(self, ddl: str, **kwargs): self.ddl_collection.add( documents=ddl, embeddings=self.generate_embedding(ddl), - ids=str(uuid.uuid4()), + ids=str(uuid.uuid4())+"-ddl", ) def add_documentation(self, doc: str, **kwargs): self.documentation_collection.add( documents=doc, embeddings=self.generate_embedding(doc), - ids=str(uuid.uuid4()), + ids=str(uuid.uuid4())+"-doc", ) + def get_training_data(self, **kwargs) -> pd.DataFrame: + sql_data = self.sql_collection.get() + + df = pd.DataFrame() + + if sql_data is not None: + # Extract the documents and ids + documents = [json.loads(doc) for doc in sql_data['documents']] + ids = sql_data['ids'] + + # Create a DataFrame + df_sql = pd.DataFrame({ + 'id': ids, + 'question': [doc['question'] for doc in documents], + 'content': [doc['sql'] for doc in documents] + }) + + df_sql["training_data_type"] = "sql" + + df = pd.concat([df, df_sql]) + + ddl_data = self.ddl_collection.get() + + if ddl_data is not None: + # Extract the documents and ids + documents = [doc for doc in ddl_data['documents']] + ids = ddl_data['ids'] + + # Create a DataFrame + df_ddl = pd.DataFrame({ + 'id': ids, + 'question': [None for doc in documents], + 'content': [doc for doc in documents] + }) + + df_ddl["training_data_type"] = "ddl" + + df = pd.concat([df, df_ddl]) + + doc_data = self.documentation_collection.get() + + if doc_data is not None: + # Extract the documents and ids + documents = [doc for doc in doc_data['documents']] + ids = doc_data['ids'] + + # Create a DataFrame + df_doc = pd.DataFrame({ + 'id': ids, + 'question': [None for doc in documents], + 'content': [doc for doc in documents] + }) + + df_doc["training_data_type"] = "documentation" + + df = pd.concat([df, df_doc]) + + return df + + # Static method to extract the documents from the results of a query @staticmethod def _extract_documents(query_results) -> list: diff --git a/src/vanna/openai/openai_chat.py b/src/vanna/openai/openai_chat.py index 40e19350..1646168e 100644 --- a/src/vanna/openai/openai_chat.py +++ b/src/vanna/openai/openai_chat.py @@ -37,7 +37,7 @@ def user_message(message: str) -> dict: def assistant_message(message: str) -> dict: return {"role": "assistant", "content": message} - def get_prompt( + def get_sql_prompt( self, question: str, question_sql_list: list, diff --git a/src/vanna/remote.py b/src/vanna/remote.py index 519299c4..cf01e50f 100644 --- a/src/vanna/remote.py +++ b/src/vanna/remote.py @@ -3,6 +3,7 @@ from typing import Callable, List, Tuple, Union import requests +import pandas as pd from .base import VannaBase from .types import ( @@ -91,6 +92,33 @@ def _rpc_call(self, method, params): def _dataclass_to_dict(self, obj): return dataclasses.asdict(obj) + def get_training_data(self, **kwargs) -> pd.DataFrame: + """ + Get the training data for the current model + + **Example:** + ```python + training_data = vn.get_training_data() + ``` + + Returns: + pd.DataFrame or None: The training data, or None if an error occurred. + + """ + params = [] + + d = self._rpc_call(method="get_training_data", params=params) + + if "result" not in d: + return None + + # Load the result into a dataclass + training_data = DataFrameJSON(**d["result"]) + + df = pd.read_json(training_data.data) + + return df + def add_ddl(self, ddl: str, **kwargs) -> str: """ Adds a DDL statement to the model's training data @@ -283,7 +311,7 @@ def generate_question(self, sql: str, **kwargs) -> str: return question.question - def get_prompt( + def get_sql_prompt( self, question: str, question_sql_list: list, From fc9430a1a20273c077a04d4f181516a305b66ca8 Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Thu, 21 Sep 2023 13:46:17 -0400 Subject: [PATCH 2/3] more sync --- src/vanna/base/base.py | 10 +++-- src/vanna/chromadb/chromadb_vector.py | 31 ++++++++++++--- src/vanna/remote.py | 54 ++++++++++++++++++++++++++- 3 files changed, 84 insertions(+), 11 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 0d976dc5..a5c0150d 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -21,7 +21,7 @@ def __init__(self, config=None): self.config = config self.run_sql_is_set = False - def generate_sql_from_question(self, question: str, **kwargs) -> str: + def generate_sql(self, question: str, **kwargs) -> str: question_sql_list = self.get_similar_question_sql(question, **kwargs) ddl_list = self.get_related_ddl(question, **kwargs) doc_list = self.get_related_documentation(question, **kwargs) @@ -79,7 +79,11 @@ def add_documentation(self, doc: str, **kwargs) -> str: pass @abstractmethod - def get_training_data(self) -> pd.DataFrame: + def get_training_data(self, **kwargs) -> pd.DataFrame: + pass + + @abstractmethod + def remove_training_data(id: str, **kwargs) -> bool: pass # ----------------- Use Any Language Model API ----------------- # @@ -432,7 +436,7 @@ def ask( question = input("Enter a question: ") try: - sql = self.generate_sql_from_question(question=question) + sql = self.generate_sql(question=question) except Exception as e: print(e) return None, None, None diff --git a/src/vanna/chromadb/chromadb_vector.py b/src/vanna/chromadb/chromadb_vector.py index 7115c2ad..7af60a80 100644 --- a/src/vanna/chromadb/chromadb_vector.py +++ b/src/vanna/chromadb/chromadb_vector.py @@ -40,32 +40,39 @@ def generate_embedding(self, data: str, **kwargs) -> list[float]: return embedding[0] return embedding - def add_question_sql(self, question: str, sql: str, **kwargs): + def add_question_sql(self, question: str, sql: str, **kwargs) -> str: question_sql_json = json.dumps( { "question": question, "sql": sql, } ) + id = str(uuid.uuid4())+"-sql" self.sql_collection.add( documents=question_sql_json, embeddings=self.generate_embedding(question_sql_json), - ids=str(uuid.uuid4())+"-sql", + ids=id, ) - def add_ddl(self, ddl: str, **kwargs): + return id + + def add_ddl(self, ddl: str, **kwargs) -> str: + id = str(uuid.uuid4())+"-ddl" self.ddl_collection.add( documents=ddl, embeddings=self.generate_embedding(ddl), - ids=str(uuid.uuid4())+"-ddl", + ids=id, ) + return id - def add_documentation(self, doc: str, **kwargs): + def add_documentation(self, doc: str, **kwargs) -> str: + id = str(uuid.uuid4())+"-doc" self.documentation_collection.add( documents=doc, embeddings=self.generate_embedding(doc), - ids=str(uuid.uuid4())+"-doc", + ids=id, ) + return id def get_training_data(self, **kwargs) -> pd.DataFrame: sql_data = self.sql_collection.get() @@ -126,6 +133,18 @@ def get_training_data(self, **kwargs) -> pd.DataFrame: return df + def remove_training_data(self, id: str, **kwargs) -> bool: + if id.endswith("-sql"): + self.sql_collection.delete(ids=id) + return True + elif id.endswith("-ddl"): + self.ddl_collection.delete(ids=id) + return True + elif id.endswith("-doc"): + self.documentation_collection.delete(ids=id) + return True + else: + return False # Static method to extract the documents from the results of a query @staticmethod diff --git a/src/vanna/remote.py b/src/vanna/remote.py index cf01e50f..194babd5 100644 --- a/src/vanna/remote.py +++ b/src/vanna/remote.py @@ -4,6 +4,7 @@ import requests import pandas as pd +from io import StringIO from .base import VannaBase from .types import ( @@ -115,10 +116,59 @@ def get_training_data(self, **kwargs) -> pd.DataFrame: # Load the result into a dataclass training_data = DataFrameJSON(**d["result"]) - df = pd.read_json(training_data.data) + df = pd.read_json(StringIO(training_data.data)) return df + def remove_training_data(self, id: str, **kwargs) -> bool: + """ + Remove training data from the model + + **Example:** + ```python + vn.remove_training_data(id="1-ddl") + ``` + + Args: + id (str): The ID of the training data to remove. + """ + params = [StringData(data=id)] + + d = self._rpc_call(method="remove_training_data", params=params) + + if "result" not in d: + raise Exception(f"Error removing training data") + + status = Status(**d["result"]) + + if not status.success: + raise Exception(f"Error removing training data: {status.message}") + + return status.success + + def generate_questions(self) -> list[str]: + """ + **Example:** + ```python + vn.generate_questions() + # ['What is the average salary of employees?', 'What is the total salary of employees?', ...] + ``` + + Generate questions using the Vanna.AI API. + + Returns: + List[str] or None: The questions, or None if an error occurred. + """ + d = self._rpc_call(method="generate_questions", params=[]) + + if "result" not in d: + return None + + # Load the result into a dataclass + question_string_list = QuestionStringList(**d["result"]) + + return question_string_list.questions + def add_ddl(self, ddl: str, **kwargs) -> str: """ Adds a DDL statement to the model's training data @@ -343,7 +393,7 @@ def get_related_documentation(self, question: str, **kwargs) -> list: Not necessary for remote models as related documentation is generated on the server side. """ - def generate_sql_from_question(self, question: str, **kwargs) -> str: + def generate_sql(self, question: str, **kwargs) -> str: """ **Example:** ```python From b4efdf1d00202164ae26c1ce39a19e558f6ff0bc Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Fri, 22 Sep 2023 12:27:29 -0400 Subject: [PATCH 3/3] followup questions --- src/vanna/base/base.py | 28 ++++++++++++ src/vanna/openai/openai_chat.py | 80 ++++++++++++++++++++++++++------- src/vanna/remote.py | 50 +++++++++++++++++++++ 3 files changed, 143 insertions(+), 15 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index a5c0150d..17d364d7 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -10,6 +10,7 @@ import plotly.express as px import plotly.graph_objects as go import requests +import re from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError from ..types import TrainingPlan, TrainingPlanItem @@ -35,6 +36,22 @@ def generate_sql(self, question: str, **kwargs) -> str: llm_response = self.submit_prompt(prompt, **kwargs) return llm_response + def generate_followup_questions(self, question: str, **kwargs) -> str: + question_sql_list = self.get_similar_question_sql(question, **kwargs) + ddl_list = self.get_related_ddl(question, **kwargs) + doc_list = self.get_related_documentation(question, **kwargs) + prompt = self.get_followup_questions_prompt( + question=question, + question_sql_list=question_sql_list, + ddl_list=ddl_list, + doc_list=doc_list, + **kwargs, + ) + llm_response = self.submit_prompt(prompt, **kwargs) + + numbers_removed = re.sub(r'^\d+\.\s*', '', llm_response, flags=re.MULTILINE) + return numbers_removed.split("\n") + def generate_questions(self, **kwargs) -> list[str]: """ **Example:** @@ -99,6 +116,17 @@ def get_sql_prompt( ): pass + @abstractmethod + def get_followup_questions_prompt( + self, + question: str, + question_sql_list: list, + ddl_list: list, + doc_list: list, + **kwargs + ): + pass + @abstractmethod def submit_prompt(self, prompt, **kwargs) -> str: pass diff --git a/src/vanna/openai/openai_chat.py b/src/vanna/openai/openai_chat.py index 1646168e..72ff8a36 100644 --- a/src/vanna/openai/openai_chat.py +++ b/src/vanna/openai/openai_chat.py @@ -2,6 +2,7 @@ from abc import abstractmethod import openai +import pandas as pd from ..base import VannaBase @@ -37,6 +38,43 @@ def user_message(message: str) -> dict: def assistant_message(message: str) -> dict: return {"role": "assistant", "content": message} + @staticmethod + def str_to_approx_token_count(string: str) -> int: + return len(string) / 4 + + @staticmethod + def add_ddl_to_prompt(initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000) -> str: + if len(ddl_list) > 0: + initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" + + for ddl in ddl_list: + if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(ddl) < max_tokens: + initial_prompt += f"{ddl}\n\n" + + return initial_prompt + + @staticmethod + def add_documentation_to_prompt(initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000) -> str: + if len(documentation_list) > 0: + initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" + + for documentation in documentation_list: + if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(documentation) < max_tokens: + initial_prompt += f"{documentation}\n\n" + + return initial_prompt + + @staticmethod + def add_sql_to_prompt(initial_prompt: str, sql_list: list[str], max_tokens: int = 14000) -> str: + if len(sql_list) > 0: + initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" + + for question in sql_list: + if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(question["sql"]) < max_tokens: + initial_prompt += f"{question['question']}\n{question['sql']}\n\n" + + return initial_prompt + def get_sql_prompt( self, question: str, @@ -44,22 +82,12 @@ def get_sql_prompt( ddl_list: list, doc_list: list, **kwargs, - ) -> str: + ): initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n" - if len(ddl_list) > 0: - initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" + initial_prompt = OpenAI_Chat.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000) - for ddl in ddl_list: - if len(initial_prompt) < 50000: # Add DDL if it fits - initial_prompt += f"{ddl}\n\n" - - if len(doc_list) > 0: - initial_prompt += f"The following information may or may not be useful in constructing the SQL to answer the question\n" - - for doc in doc_list: - if len(initial_prompt) < 60000: # Add Documentation if it fits - initial_prompt += f"{doc}\n\n" + initial_prompt = OpenAI_Chat.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000) message_log = [OpenAI_Chat.system_message(initial_prompt)] @@ -75,6 +103,28 @@ def get_sql_prompt( return message_log + def get_followup_questions_prompt( + self, + question: str, + df: pd.DataFrame, + question_sql_list: list, + ddl_list: list, + doc_list: list, + **kwargs + ): + initial_prompt = f"The user initially asked the question: '{question}': \n\n" + + initial_prompt = OpenAI_Chat.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000) + + initial_prompt = OpenAI_Chat.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000) + + initial_prompt = OpenAI_Chat.add_sql_to_prompt(initial_prompt, question_sql_list, max_tokens=14000) + + message_log = [OpenAI_Chat.system_message(initial_prompt)] + message_log.append(OpenAI_Chat.user_message("Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions.")) + + return message_log + def generate_question(self, sql: str, **kwargs) -> str: response = self.submit_prompt( [ @@ -150,7 +200,7 @@ def submit_prompt(self, prompt, **kwargs) -> str: len(message["content"]) / 4 ) # Use 4 as an approximation for the number of characters per token - if "engine" in self.config: + if self.config is not None and "engine" in self.config: print( f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)" ) @@ -161,7 +211,7 @@ def submit_prompt(self, prompt, **kwargs) -> str: stop=None, temperature=0.7, ) - elif "model" in self.config: + elif self.config is not None and "model" in self.config: print( f"Using model {self.config['model']} for {num_tokens} tokens (approx)" ) diff --git a/src/vanna/remote.py b/src/vanna/remote.py index 194babd5..c5f843ae 100644 --- a/src/vanna/remote.py +++ b/src/vanna/remote.py @@ -373,6 +373,19 @@ def get_sql_prompt( Not necessary for remote models as prompts are generated on the server side. """ + def get_followup_questions_prompt( + self, + question: str, + df: pd.DataFrame, + question_sql_list: list, + ddl_list: list, + doc_list: list, + **kwargs, + ): + """ + Not necessary for remote models as prompts are generated on the server side. + """ + def submit_prompt(self, prompt, **kwargs) -> str: """ Not necessary for remote models as prompts are handled on the server side. @@ -420,3 +433,40 @@ def generate_sql(self, question: str, **kwargs) -> str: sql_answer = SQLAnswer(**d["result"]) return sql_answer.sql + + def generate_followup_questions(self, question: str, df: pd.DataFrame, **kwargs) -> list[str]: + """ + **Example:** + ```python + vn.generate_followup_questions(question="What is the average salary of employees?", df=df) + # ['What is the average salary of employees in the Sales department?', 'What is the average salary of employees in the Engineering department?', ...] + ``` + + Generate follow-up questions using the Vanna.AI API. + + Args: + question (str): The question to generate follow-up questions for. + df (pd.DataFrame): The DataFrame to generate follow-up questions for. + + Returns: + List[str] or None: The follow-up questions, or None if an error occurred. + """ + params = [ + DataResult( + question=question, + sql=None, + table_markdown="", + error=None, + correction_attempts=0, + ) + ] + + d = self._rpc_call(method="generate_followup_questions", params=params) + + if "result" not in d: + return None + + # Load the result into a dataclass + question_string_list = QuestionStringList(**d["result"]) + + return question_string_list.questions \ No newline at end of file