Skip to content

Commit

Permalink
Merge pull request #350 from Aymane11/extract-sql-non-markdown-response
Browse files Browse the repository at this point in the history
Add sql extraction in case of non-markdown response or CTE
  • Loading branch information
zainhoda authored Apr 12, 2024
2 parents fe2d439 + f601a51 commit e167f2f
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 17 deletions.
6 changes: 3 additions & 3 deletions src/vanna/ZhipuAI/ZhipuAI_Chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def add_ddl_to_prompt(
initial_prompt: str, ddl_list: List[str], max_tokens: int = 14000
) -> str:
if len(ddl_list) > 0:
initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for ddl in ddl_list:
if (
Expand All @@ -57,7 +57,7 @@ def add_documentation_to_prompt(
initial_prompt: str, documentation_List: List[str], max_tokens: int = 14000
) -> str:
if len(documentation_List) > 0:
initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for documentation in documentation_List:
if (
Expand All @@ -74,7 +74,7 @@ def add_sql_to_prompt(
initial_prompt: str, sql_List: List[str], max_tokens: int = 14000
) -> str:
if len(sql_List) > 0:
initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for question in sql_List:
if (
Expand Down
20 changes: 16 additions & 4 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,18 @@ def generate_sql(self, question: str, **kwargs) -> str:
return self.extract_sql(llm_response)

def extract_sql(self, llm_response: str) -> str:
# If the llm_response is not markdown formatted, extract sql by finding select and ; in the response
sql = re.search(r"SELECT.*?;", llm_response, re.DOTALL)
if sql:
self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}"
)
return sql.group(0)

# If the llm_response contains a CTE (with clause), extract the sql bewteen WITH and ;
sql = re.search(r"WITH.*?;", llm_response, re.DOTALL)
if sql:
self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}")
return sql.group(0)
# If the llm_response contains a markdown code block, with or without the sql tag, extract the sql from it
sql = re.search(r"```sql\n(.*)```", llm_response, re.DOTALL)
if sql:
Expand Down Expand Up @@ -363,7 +375,7 @@ def add_ddl_to_prompt(
self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000
) -> str:
if len(ddl_list) > 0:
initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for ddl in ddl_list:
if (
Expand All @@ -382,7 +394,7 @@ def add_documentation_to_prompt(
max_tokens: int = 14000,
) -> str:
if len(documentation_list) > 0:
initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for documentation in documentation_list:
if (
Expand All @@ -398,7 +410,7 @@ def add_sql_to_prompt(
self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000
) -> str:
if len(sql_list) > 0:
initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for question in sql_list:
if (
Expand Down Expand Up @@ -1238,7 +1250,7 @@ def train(
"""

if question and not sql:
raise ValidationError(f"Please also provide a SQL query")
raise ValidationError("Please also provide a SQL query")

if documentation:
print("Adding documentation....")
Expand Down
2 changes: 1 addition & 1 deletion src/vanna/flask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def proxy_assets(filename):
# Proxy the /vanna.svg file to the remote server
@self.flask_app.route("/vanna.svg")
def proxy_vanna_svg():
remote_url = f"https://vanna.ai/img/vanna.svg"
remote_url = "https://vanna.ai/img/vanna.svg"
response = requests.get(remote_url, stream=True)

# Check if the request to the remote URL was successful
Expand Down
2 changes: 1 addition & 1 deletion src/vanna/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def remove_training_data(self, id: str, **kwargs) -> bool:
d = self._rpc_call(method="remove_training_data", params=params)

if "result" not in d:
raise Exception(f"Error removing training data")
raise Exception("Error removing training data")

status = Status(**d["result"])

Expand Down
16 changes: 8 additions & 8 deletions src/vanna/vannadb/vannadb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

from ..base import VannaBase
from ..types import (
DataFrameJSON,
Question,
QuestionSQLPair,
Status,
StatusWithId,
StringData,
TrainingData,
DataFrameJSON,
Question,
QuestionSQLPair,
Status,
StatusWithId,
StringData,
TrainingData,
)


Expand Down Expand Up @@ -141,7 +141,7 @@ def remove_training_data(self, id: str, **kwargs) -> bool:
d = self._rpc_call(method="remove_training_data", params=params)

if "result" not in d:
raise Exception(f"Error removing training data")
raise Exception("Error removing training data")

status = Status(**d["result"])

Expand Down

0 comments on commit e167f2f

Please sign in to comment.