From 2ed7a197f7e85ccbdb4b13009bffd7313e030d3a Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Wed, 17 Jan 2024 13:22:55 -0500 Subject: [PATCH] Parse out sql from code block --- pyproject.toml | 2 +- src/vanna/base/base.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8da49c9c..cf21a3ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "vanna" -version = "0.0.33" +version = "0.0.34" authors = [ { name="Zain Hoda", email="zain@vanna.ai" }, ] diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 20adcfc2..292b2839 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -23,6 +23,9 @@ def __init__(self, config=None): self.config = config self.run_sql_is_set = False + def log(self, message: str): + print(message) + def generate_sql(self, question: str, **kwargs) -> str: question_sql_list = self.get_similar_question_sql(question, **kwargs) ddl_list = self.get_related_ddl(question, **kwargs) @@ -35,8 +38,31 @@ def generate_sql(self, question: str, **kwargs) -> str: **kwargs, ) llm_response = self.submit_prompt(prompt, **kwargs) + return self.extract_sql(llm_response) + + def extract_sql(self, llm_response: str) -> str: + # If the llm_response contains a markdown code block, with or without the sql tag, extract the sql from it + sql = re.search(r"```sql\n(.*)```", llm_response, re.DOTALL) + if sql: + self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}") + return sql.group(1) + + sql = re.search(r"```(.*)```", llm_response, re.DOTALL) + if sql: + self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}") + return sql.group(1) + return llm_response + def is_sql_valid(self, sql: str) -> bool: + # This is a check to see the SQL is valid and should be run + # This simple function just checks if the SQL contains a SELECT statement + + if "SELECT" in sql.upper(): + return True + else: + return False + def generate_followup_questions(self, question: str, **kwargs) -> str: question_sql_list = self.get_similar_question_sql(question, **kwargs) ddl_list = self.get_related_ddl(question, **kwargs) @@ -489,6 +515,13 @@ def ask( return sql, None, None try: + if self.is_sql_valid(sql) is False: + print("SQL is not valid, please try again.") + if print_results: + return None + else: + return sql, None, None + df = self.run_sql(sql) if print_results: