Skip to content

Commit

Permalink
Merge pull request #319 from vanna-ai/app-upgrades
Browse files Browse the repository at this point in the history
Built-in Flask App Upgrades
  • Loading branch information
zainhoda authored Mar 27, 2024
2 parents e78cabb + 9dc80b0 commit 5824eff
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
exclude: 'docs|node_modules|migrations|.git|.tox'
exclude: 'docs|node_modules|migrations|.git|.tox|assets.py'
default_stages: [ commit ]
fail_fast: true

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "vanna"
version = "0.2.1"
version = "0.3.0"
authors = [
{ name="Zain Hoda", email="[email protected]" },
]
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
[flake8]
ignore = BLK100,W503,E203,E722,F821,F841
max-line-length = 100
exclude = .tox,.git,docs,venv,jupyter_notebook_config.py,jupyter_lab_config.py
exclude = .tox,.git,docs,venv,jupyter_notebook_config.py,jupyter_lab_config.py,assets.py

[tool:brunette]
verbose = true
single-quotes = false
target-version = py39
exclude = .tox,.git,docs,venv
exclude = .tox,.git,docs,venv,assets.py
176 changes: 158 additions & 18 deletions src/vanna/flask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def decorated(*args, **kwargs):
id = request.args.get("id")

if id is None:
return jsonify({"type": "error", "error": "No id provided"})
id = request.json.get("id")
if id is None:
return jsonify({"type": "error", "error": "No id provided"})

for field in fields:
if self.cache.get(id=id, field=field) is None:
Expand All @@ -94,15 +96,94 @@ def decorated(*args, **kwargs):

return decorator

def __init__(self, vn, cache: Cache = MemoryCache(), allow_llm_to_see_data=False):
def __init__(self, vn, cache: Cache = MemoryCache(),
allow_llm_to_see_data=False,
logo="https://img.vanna.ai/vanna-flask.svg",
title="Welcome to Vanna.AI",
subtitle="Your AI-powered copilot for SQL queries.",
show_training_data=True,
suggested_questions=True,
sql=True,
table=True,
csv_download=True,
chart=True,
redraw_chart=True,
auto_fix_sql=True,
ask_results_correct=True,
followup_questions=True,
summarization=True
):
"""
Expose a Flask app that can be used to interact with a Vanna instance.
Args:
vn: The Vanna instance to interact with.
cache: The cache to use. Defaults to MemoryCache, which uses an in-memory cache. You can also pass in a custom cache that implements the Cache interface.
allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False.
logo: The logo to display in the UI. Defaults to the Vanna logo.
title: The title to display in the UI. Defaults to "Welcome to Vanna.AI".
subtitle: The subtitle to display in the UI. Defaults to "Your AI-powered copilot for SQL queries.".
show_training_data: Whether to show the training data in the UI. Defaults to True.
suggested_questions: Whether to show suggested questions in the UI. Defaults to True.
sql: Whether to show the SQL input in the UI. Defaults to True.
table: Whether to show the table output in the UI. Defaults to True.
csv_download: Whether to allow downloading the table output as a CSV file. Defaults to True.
chart: Whether to show the chart output in the UI. Defaults to True.
redraw_chart: Whether to allow redrawing the chart. Defaults to True.
auto_fix_sql: Whether to allow auto-fixing SQL errors. Defaults to True.
ask_results_correct: Whether to ask the user if the results are correct. Defaults to True.
followup_questions: Whether to show followup questions. Defaults to True.
summarization: Whether to show summarization. Defaults to True.
Returns:
None
"""
self.flask_app = Flask(__name__)
self.vn = vn
self.cache = cache
self.allow_llm_to_see_data = allow_llm_to_see_data
self.logo = logo
self.title = title
self.subtitle = subtitle
self.show_training_data = show_training_data
self.suggested_questions = suggested_questions
self.sql = sql
self.table = table
self.csv_download = csv_download
self.chart = chart
self.redraw_chart = redraw_chart
self.auto_fix_sql = auto_fix_sql
self.ask_results_correct = ask_results_correct
self.followup_questions = followup_questions
self.summarization = summarization

log = logging.getLogger("werkzeug")
log.setLevel(logging.ERROR)

@self.flask_app.route("/api/v0/get_config", methods=["GET"])
def get_config():
return jsonify(
{
"type": "config",
"config": {
"logo": self.logo,
"title": self.title,
"subtitle": self.subtitle,
"show_training_data": self.show_training_data,
"suggested_questions": self.suggested_questions,
"sql": self.sql,
"table": self.table,
"csv_download": self.csv_download,
"chart": self.chart,
"redraw_chart": self.redraw_chart,
"auto_fix_sql": self.auto_fix_sql,
"ask_results_correct": self.ask_results_correct,
"followup_questions": self.followup_questions,
"summarization": self.summarization,
},
}
)

@self.flask_app.route("/api/v0/generate_questions", methods=["GET"])
def generate_questions():
# If self has an _model attribute and model=='chinook'
Expand Down Expand Up @@ -199,12 +280,52 @@ def run_sql(id: str, sql: str):
{
"type": "df",
"id": id,
"df": df.head(10).to_json(orient="records"),
"df": df.head(10).to_json(orient='records', date_format='iso'),
}
)

except Exception as e:
return jsonify({"type": "error", "error": str(e)})
return jsonify({"type": "sql_error", "error": str(e)})

@self.flask_app.route("/api/v0/fix_sql", methods=["POST"])
@self.requires_cache(["question", "sql"])
def fix_sql(id: str, question:str, sql: str):
error = flask.request.json.get("error")

if error is None:
return jsonify({"type": "error", "error": "No error provided"})

question = f"I have an error: {error}\n\nHere is the SQL I tried to run: {sql}\n\nThis is the question I was trying to answer: {question}\n\nCan you rewrite the SQL to fix the error?"

fixed_sql = vn.generate_sql(question=question)

self.cache.set(id=id, field="sql", value=fixed_sql)

return jsonify(
{
"type": "sql",
"id": id,
"text": fixed_sql,
}
)


@self.flask_app.route('/api/v0/update_sql', methods=['POST'])
@self.requires_cache([])
def update_sql(id: str):
sql = flask.request.json.get('sql')

if sql is None:
return jsonify({"type": "error", "error": "No sql provided"})

cache.set(id=id, field='sql', value=sql)

return jsonify(
{
"type": "sql",
"id": id,
"text": sql,
})

@self.flask_app.route("/api/v0/download_csv", methods=["GET"])
@self.requires_cache(["df"])
Expand All @@ -220,6 +341,11 @@ def download_csv(id: str, df):
@self.flask_app.route("/api/v0/generate_plotly_figure", methods=["GET"])
@self.requires_cache(["df", "question", "sql"])
def generate_plotly_figure(id: str, df, question, sql):
chart_instructions = flask.request.args.get('chart_instructions')

if chart_instructions is not None:
question = f"{question}. When generating the chart, use these special instructions: {chart_instructions}"

try:
code = vn.generate_plotly_code(
question=question,
Expand Down Expand Up @@ -352,9 +478,9 @@ def generate_summary(id: str, df, question):

@self.flask_app.route("/api/v0/load_question", methods=["GET"])
@self.requires_cache(
["question", "sql", "df", "fig_json", "followup_questions"]
["question", "sql", "df", "fig_json"]
)
def load_question(id: str, question, sql, df, fig_json, followup_questions):
def load_question(id: str, question, sql, df, fig_json):
try:
return jsonify(
{
Expand All @@ -364,7 +490,6 @@ def load_question(id: str, question, sql, df, fig_json, followup_questions):
"sql": sql,
"df": df.head(10).to_json(orient="records"),
"fig": fig_json,
"followup_questions": followup_questions,
}
)

Expand Down Expand Up @@ -425,16 +550,31 @@ def proxy_vanna_svg():
def hello(path: str):
return html_content

def run(self):
try:
from google.colab import output
def run(self, *args, **kwargs):
"""
Run the Flask app.
Args:
*args: Arguments to pass to Flask's run method.
**kwargs: Keyword arguments to pass to Flask's run method.
Returns:
None
"""
if args or kwargs:
self.flask_app.run(*args, **kwargs)

else:
try:
from google.colab import output

output.serve_kernel_port_as_window(8084)
from google.colab.output import eval_js

output.serve_kernel_port_as_window(8084)
from google.colab.output import eval_js
print("Your app is running at:")
print(eval_js("google.colab.kernel.proxyPort(8084)"))
except:
print("Your app is running at:")
print("http://localhost:8084")

print("Your app is running at:")
print(eval_js("google.colab.kernel.proxyPort(8084)"))
except:
print("Your app is running at:")
print("http://localhost:8084")
self.flask_app.run(host="0.0.0.0", port=8084, debug=False)
self.flask_app.run(host="0.0.0.0", port=8084, debug=False)
42 changes: 22 additions & 20 deletions src/vanna/flask/assets.py

Large diffs are not rendered by default.

0 comments on commit 5824eff

Please sign in to comment.