Sync functions between local and remote #120

Merged
merged 3 commits into from
Sep 22, 2023
Merged
2 changes: 1 addition & 1 deletion src/vanna/__init__.py
@@ -109,7 +109,7 @@
end

subgraph OpenAI_Chat
get_prompt
get_sql_prompt
submit_prompt
generate_question
generate_plotly_code
57 changes: 53 additions & 4 deletions src/vanna/base/base.py
@@ -10,6 +10,7 @@
import plotly.express as px
import plotly.graph_objects as go
import requests
import re

from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
from ..types import TrainingPlan, TrainingPlanItem
@@ -21,11 +22,11 @@ def __init__(self, config=None):
self.config = config
self.run_sql_is_set = False

def generate_sql_from_question(self, question: str, **kwargs) -> str:
def generate_sql(self, question: str, **kwargs) -> str:
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
prompt = self.get_prompt(
prompt = self.get_sql_prompt(
question=question,
question_sql_list=question_sql_list,
ddl_list=ddl_list,
@@ -35,6 +36,35 @@ def generate_sql_from_question(self, question: str, **kwargs) -> str:
llm_response = self.submit_prompt(prompt, **kwargs)
return llm_response

def generate_followup_questions(self, question: str, **kwargs) -> list[str]:
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
prompt = self.get_followup_questions_prompt(
question=question,
question_sql_list=question_sql_list,
ddl_list=ddl_list,
doc_list=doc_list,
**kwargs,
)
llm_response = self.submit_prompt(prompt, **kwargs)

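# Strip leading list numbering (e.g. "1. ", "2. ") that the model typically prepends to each question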
numbers_removed = re.sub(r'^\d+\.\s*', '', llm_response, flags=re.MULTILINE)
return numbers_removed.split("\n")

def generate_questions(self, **kwargs) -> list[str]:
"""
**Example:**
```python
vn.generate_questions()
```

Generate a list of questions that you can ask Vanna.AI.
"""
question_sql = self.get_similar_question_sql(question="", **kwargs)

return [q['question'] for q in question_sql]

# ----------------- Use Any Embeddings API ----------------- #
@abstractmethod
def generate_embedding(self, data: str, **kwargs) -> list[float]:
@@ -65,10 +95,18 @@ def add_ddl(self, ddl: str, **kwargs) -> str:
def add_documentation(self, doc: str, **kwargs) -> str:
pass

@abstractmethod
def get_training_data(self, **kwargs) -> pd.DataFrame:
pass

@abstractmethod
def remove_training_data(self, id: str, **kwargs) -> bool:
pass

# ----------------- Use Any Language Model API ----------------- #

@abstractmethod
def get_prompt(
def get_sql_prompt(
self,
question: str,
question_sql_list: list,
@@ -78,6 +116,17 @@ def get_prompt(
):
pass

@abstractmethod
def get_followup_questions_prompt(
self,
question: str,
question_sql_list: list,
ddl_list: list,
doc_list: list,
**kwargs
):
pass

@abstractmethod
def submit_prompt(self, prompt, **kwargs) -> str:
pass
@@ -415,7 +464,7 @@ def ask(
question = input("Enter a question: ")

try:
sql = self.generate_sql_from_question(question=question)
sql = self.generate_sql(question=question)
except Exception as e:
print(e)
return None, None, None
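Taken together, the base-class additions define the flow a concrete subclass inherits. A minimal usage sketch follows — it assumes a `MyVanna` class mixing the ChromaDB vector store with the OpenAI chat model, in the style of the repo's examples; the import paths, class wiring, and config keys here are illustrative and not part of this diff:

```python
# Import paths and the ChromaDB_VectorStore name are assumed from the repo
# layout; treat this as a sketch rather than a verbatim recipe.
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
from vanna.openai.openai_chat import OpenAI_Chat

class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)

vn = MyVanna(config={"api_key": "sk-...", "model": "gpt-3.5-turbo"})

# Renamed in this PR: generate_sql_from_question -> generate_sql.
sql = vn.generate_sql(question="What are the top 10 customers by sales?")

# New in this PR: follow-up questions grounded in the same retrieved context.
followups = vn.generate_followup_questions(question="What are the top 10 customers by sales?")

# New in this PR: list training data and remove an entry by its suffixed id.
training_df = vn.get_training_data()
vn.remove_training_data(id=training_df["id"].iloc[0])
```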
92 changes: 86 additions & 6 deletions src/vanna/chromadb/chromadb_vector.py
@@ -5,6 +5,7 @@
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions
import pandas as pd

from ..base import VannaBase

@@ -39,32 +40,111 @@ def generate_embedding(self, data: str, **kwargs) -> list[float]:
return embedding[0]
return embedding

def add_question_sql(self, question: str, sql: str, **kwargs):
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
question_sql_json = json.dumps(
{
"question": question,
"sql": sql,
}
)
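# Suffix the uuid with "-sql" so remove_training_data can tell which collection the id belongs to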
id = str(uuid.uuid4())+"-sql"
self.sql_collection.add(
documents=question_sql_json,
embeddings=self.generate_embedding(question_sql_json),
ids=str(uuid.uuid4()),
ids=id,
)

def add_ddl(self, ddl: str, **kwargs):
return id

def add_ddl(self, ddl: str, **kwargs) -> str:
id = str(uuid.uuid4())+"-ddl"
self.ddl_collection.add(
documents=ddl,
embeddings=self.generate_embedding(ddl),
ids=str(uuid.uuid4()),
ids=id,
)
return id

def add_documentation(self, doc: str, **kwargs):
def add_documentation(self, doc: str, **kwargs) -> str:
id = str(uuid.uuid4())+"-doc"
self.documentation_collection.add(
documents=doc,
embeddings=self.generate_embedding(doc),
ids=str(uuid.uuid4()),
ids=id,
)
return id

def get_training_data(self, **kwargs) -> pd.DataFrame:
sql_data = self.sql_collection.get()

df = pd.DataFrame()

if sql_data is not None:
# Extract the documents and ids
documents = [json.loads(doc) for doc in sql_data['documents']]
ids = sql_data['ids']

# Create a DataFrame
df_sql = pd.DataFrame({
'id': ids,
'question': [doc['question'] for doc in documents],
'content': [doc['sql'] for doc in documents]
})

df_sql["training_data_type"] = "sql"

df = pd.concat([df, df_sql])

ddl_data = self.ddl_collection.get()

if ddl_data is not None:
# Extract the documents and ids
documents = [doc for doc in ddl_data['documents']]
ids = ddl_data['ids']

# Create a DataFrame
df_ddl = pd.DataFrame({
'id': ids,
'question': [None for doc in documents],
'content': [doc for doc in documents]
})

df_ddl["training_data_type"] = "ddl"

df = pd.concat([df, df_ddl])

doc_data = self.documentation_collection.get()

if doc_data is not None:
# Extract the documents and ids
documents = [doc for doc in doc_data['documents']]
ids = doc_data['ids']

# Create a DataFrame
df_doc = pd.DataFrame({
'id': ids,
'question': [None for doc in documents],
'content': [doc for doc in documents]
})

df_doc["training_data_type"] = "documentation"

df = pd.concat([df, df_doc])

return df

def remove_training_data(self, id: str, **kwargs) -> bool:
if id.endswith("-sql"):
self.sql_collection.delete(ids=id)
return True
elif id.endswith("-ddl"):
self.ddl_collection.delete(ids=id)
return True
elif id.endswith("-doc"):
self.documentation_collection.delete(ids=id)
return True
else:
return False

# Static method to extract the documents from the results of a query
@staticmethod
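The uuid suffix is what lets `remove_training_data` dispatch without a separate type argument: each `add_*` method stamps the id with its collection, and deletion routes on that suffix. A self-contained sketch of the convention, with plain dicts standing in for the Chroma collections:

```python
import uuid

# Minimal sketch of the suffix-routing convention introduced in this PR.
# The real implementation stores into Chroma collections; dicts stand in here.
collections = {"sql": {}, "ddl": {}, "doc": {}}

def add(kind: str, document: str) -> str:
    # Tag the id with its collection, as add_question_sql/add_ddl/add_documentation do.
    id = str(uuid.uuid4()) + f"-{kind}"
    collections[kind][id] = document
    return id

def remove_training_data(id: str) -> bool:
    # Route the delete on the suffix; unknown suffixes report failure.
    for kind in ("sql", "ddl", "doc"):
        if id.endswith(f"-{kind}"):
            return collections[kind].pop(id, None) is not None
    return False

id = add("sql", '{"question": "...", "sql": "SELECT ..."}')
assert remove_training_data(id)              # True: suffix routed to the sql store
assert not remove_training_data("bogus-id")  # unknown suffix -> False
```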
82 changes: 66 additions & 16 deletions src/vanna/openai/openai_chat.py
@@ -2,6 +2,7 @@
from abc import abstractmethod

import openai
import pandas as pd

from ..base import VannaBase

@@ -37,29 +38,56 @@ def user_message(message: str) -> dict:
def assistant_message(message: str) -> dict:
return {"role": "assistant", "content": message}

def get_prompt(
@staticmethod
def str_to_approx_token_count(string: str) -> int:
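# Rough heuristic, also used in submit_prompt: roughly 4 characters per token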
return len(string) // 4

@staticmethod
def add_ddl_to_prompt(initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000) -> str:
if len(ddl_list) > 0:
initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for ddl in ddl_list:
if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(ddl) < max_tokens:
initial_prompt += f"{ddl}\n\n"

return initial_prompt

@staticmethod
def add_documentation_to_prompt(initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000) -> str:
if len(documentation_list) > 0:
initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for documentation in documentation_list:
if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(documentation) < max_tokens:
initial_prompt += f"{documentation}\n\n"

return initial_prompt

@staticmethod
def add_sql_to_prompt(initial_prompt: str, sql_list: list[str], max_tokens: int = 14000) -> str:
if len(sql_list) > 0:
initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for question in sql_list:
if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(question["sql"]) < max_tokens:
initial_prompt += f"{question['question']}\n{question['sql']}\n\n"

return initial_prompt

def get_sql_prompt(
self,
question: str,
question_sql_list: list,
ddl_list: list,
doc_list: list,
**kwargs,
) -> str:
):
initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n"

if len(ddl_list) > 0:
initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
initial_prompt = OpenAI_Chat.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000)

for ddl in ddl_list:
if len(initial_prompt) < 50000: # Add DDL if it fits
initial_prompt += f"{ddl}\n\n"

if len(doc_list) > 0:
initial_prompt += f"The following information may or may not be useful in constructing the SQL to answer the question\n"

for doc in doc_list:
if len(initial_prompt) < 60000: # Add Documentation if it fits
initial_prompt += f"{doc}\n\n"
initial_prompt = OpenAI_Chat.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000)

message_log = [OpenAI_Chat.system_message(initial_prompt)]

@@ -75,6 +103,28 @@ def get_prompt(

return message_log

def get_followup_questions_prompt(
self,
question: str,
question_sql_list: list,
ddl_list: list,
doc_list: list,
**kwargs
):
initial_prompt = f"The user initially asked the question: '{question}': \n\n"

initial_prompt = OpenAI_Chat.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000)

initial_prompt = OpenAI_Chat.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000)

initial_prompt = OpenAI_Chat.add_sql_to_prompt(initial_prompt, question_sql_list, max_tokens=14000)

message_log = [OpenAI_Chat.system_message(initial_prompt)]
message_log.append(OpenAI_Chat.user_message("Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions."))

return message_log

def generate_question(self, sql: str, **kwargs) -> str:
response = self.submit_prompt(
[
@@ -150,7 +200,7 @@ def submit_prompt(self, prompt, **kwargs) -> str:
len(message["content"]) / 4
) # Use 4 as an approximation for the number of characters per token

if "engine" in self.config:
if self.config is not None and "engine" in self.config:
print(
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
)
Expand All @@ -161,7 +211,7 @@ def submit_prompt(self, prompt, **kwargs) -> str:
stop=None,
temperature=0.7,
)
elif "model" in self.config:
elif self.config is not None and "model" in self.config:
print(
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
)
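The prompt-building refactor replaces the old raw character caps (50,000 / 60,000 characters) with a shared token budget: each helper estimates tokens at roughly 4 characters apiece and greedily appends context items while the prompt stays under `max_tokens`. A distilled sketch of that pattern — the helper name here is illustrative, but the logic mirrors `add_ddl_to_prompt` and its siblings:

```python
def approx_tokens(s: str) -> int:
    # Same heuristic as str_to_approx_token_count: ~4 characters per token.
    return len(s) // 4

def add_context_to_prompt(prompt: str, items: list[str], max_tokens: int = 14000) -> str:
    # Greedy fill: append an item only while the running total stays under budget.
    # Oversized items are skipped rather than truncated, so a later, shorter
    # item can still make it in -- matching the behavior of the PR's helpers.
    for item in items:
        if approx_tokens(prompt) + approx_tokens(item) < max_tokens:
            prompt += f"{item}\n\n"
    return prompt

prompt = "The user provides a question and you provide SQL.\n"
prompt = add_context_to_prompt(prompt, ["CREATE TABLE sales (...)", "CREATE TABLE customers (...)"])
```

One consequence of the skip-not-break design is that the assembled context is not strictly in retrieval order: an item that overflows the budget is dropped while subsequent smaller items may still be included.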