Skip to content

Commit

Permalink
Parse out sql from code block
Browse files Browse the repository at this point in the history
  • Loading branch information
zainhoda committed Jan 17, 2024
1 parent 0684b5a commit 2ed7a19
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "vanna"
version = "0.0.33"
version = "0.0.34"
authors = [
{ name="Zain Hoda", email="[email protected]" },
]
Expand Down
33 changes: 33 additions & 0 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ def __init__(self, config=None):
self.config = config
self.run_sql_is_set = False

def log(self, message: str):
print(message)

def generate_sql(self, question: str, **kwargs) -> str:
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
Expand All @@ -35,8 +38,31 @@ def generate_sql(self, question: str, **kwargs) -> str:
**kwargs,
)
llm_response = self.submit_prompt(prompt, **kwargs)
return self.extract_sql(llm_response)

def extract_sql(self, llm_response: str) -> str:
# If the llm_response contains a markdown code block, with or without the sql tag, extract the sql from it
sql = re.search(r"```sql\n(.*)```", llm_response, re.DOTALL)
if sql:
self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
return sql.group(1)

sql = re.search(r"```(.*)```", llm_response, re.DOTALL)
if sql:
self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
return sql.group(1)

return llm_response

def is_sql_valid(self, sql: str) -> bool:
# This is a check to see the SQL is valid and should be run
# This simple function just checks if the SQL contains a SELECT statement

if "SELECT" in sql.upper():
return True
else:
return False

def generate_followup_questions(self, question: str, **kwargs) -> str:
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
Expand Down Expand Up @@ -489,6 +515,13 @@ def ask(
return sql, None, None

try:
if self.is_sql_valid(sql) is False:
print("SQL is not valid, please try again.")
if print_results:
return None
else:
return sql, None, None

df = self.run_sql(sql)

if print_results:
Expand Down

0 comments on commit 2ed7a19

Please sign in to comment.