Skip to content

Commit

Permalink
Merge pull request #149 from vanna-ai/flask
Browse files Browse the repository at this point in the history
Experimental integrated Flask app
  • Loading branch information
zainhoda authored Jan 16, 2024
2 parents a4cdf75 + 065d3a8 commit 33db008
Show file tree
Hide file tree
Showing 3 changed files with 333 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "vanna"
version = "0.0.31"
version = "0.0.32"
authors = [
{ name="Zain Hoda", email="[email protected]" },
]
Expand Down
2 changes: 1 addition & 1 deletion src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def connect_to_sqlite(self, url: str):
url = path

# Connect to the database
conn = sqlite3.connect(url)
conn = sqlite3.connect(url, check_same_thread=False)

def run_sql_sqlite(sql: str):
return pd.read_sql_query(sql, conn)
Expand Down
331 changes: 331 additions & 0 deletions src/vanna/flask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,331 @@
import flask
from flask import Flask, Response, jsonify, request
import logging
import requests
from functools import wraps

from abc import ABC, abstractmethod
import uuid

class Cache(ABC):
@abstractmethod
def generate_id(self, *args, **kwargs):
pass

@abstractmethod
def get(self, id, field):
pass

@abstractmethod
def get_all(self, field_list) -> list:
pass

@abstractmethod
def set(self, id, field, value):
pass

@abstractmethod
def delete(self, id):
pass


class MemoryCache(Cache):
def __init__(self):
self.cache = {}

def generate_id(self, *args, **kwargs):
return str(uuid.uuid4())

def set(self, id, field, value):
if id not in self.cache:
self.cache[id] = {}

self.cache[id][field] = value

def get(self, id, field):
if id not in self.cache:
return None

if field not in self.cache[id]:
return None

return self.cache[id][field]

def get_all(self, field_list) -> list:
return [
{
"id": id,
**{
field: self.get(id=id, field=field)
for field in field_list
}
}
for id in self.cache
]

def delete(self, id):
if id in self.cache:
del self.cache[id]

class VannaFlaskApp:
flask_app = None

def requires_cache(self, fields):
def decorator(f):
@wraps(f)
def decorated(*args, **kwargs):
id = request.args.get('id')

if id is None:
return jsonify({"type": "error", "error": "No id provided"})

for field in fields:
if self.cache.get(id=id, field=field) is None:
return jsonify({"type": "error", "error": f"No {field} found"})

field_values = {field: self.cache.get(id=id, field=field) for field in fields}

# Add the id to the field_values
field_values['id'] = id

return f(*args, **field_values, **kwargs)
return decorated
return decorator

def __init__(self, vn, cache: Cache = MemoryCache()):
self.flask_app = Flask(__name__)
self.vn = vn
self.cache = cache

log = logging.getLogger('werkzeug')
log.setLevel(logging.ERROR)

@self.flask_app.route('/api/v0/generate_questions', methods=['GET'])
def generate_questions():
# If self has an _model attribute and model=='chinook'
if hasattr(self.vn, '_model') and self.vn._model == 'chinook':
return jsonify({
"type": "question_list",
"questions": ['What are the top 10 artists by sales?', 'What are the total sales per year by country?', 'Who is the top selling artist in each genre? Show the sales numbers.', 'How do the employees rank in terms of sales performance?', 'Which 5 cities have the most customers?'],
"header": "Here are some questions you can ask:"
})

@self.flask_app.route('/api/v0/generate_sql', methods=['GET'])
def generate_sql():
question = flask.request.args.get('question')

if question is None:
return jsonify({"type": "error", "error": "No question provided"})

id = self.cache.generate_id(question=question)
sql = vn.generate_sql(question=question)

self.cache.set(id=id, field='question', value=question)
self.cache.set(id=id, field='sql', value=sql)

return jsonify(
{
"type": "sql",
"id": id,
"text": sql,
})

@self.flask_app.route('/api/v0/run_sql', methods=['GET'])
@self.requires_cache(['sql'])
def run_sql(id: str, sql: str):
try:
if not vn.run_sql_is_set:
return jsonify({"type": "error", "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries."})

df = vn.run_sql(sql=sql)

cache.set(id=id, field='df', value=df)

return jsonify(
{
"type": "df",
"id": id,
"df": df.head(10).to_json(orient='records'),
})

except Exception as e:
return jsonify({"type": "error", "error": str(e)})

@self.flask_app.route('/api/v0/download_csv', methods=['GET'])
@self.requires_cache(['df'])
def download_csv(id: str, df):
csv = df.to_csv()

return Response(
csv,
mimetype="text/csv",
headers={"Content-disposition":
f"attachment; filename={id}.csv"})

@self.flask_app.route('/api/v0/generate_plotly_figure', methods=['GET'])
@self.requires_cache(['df', 'question', 'sql'])
def generate_plotly_figure(id: str, df, question, sql):
try:
code = vn.generate_plotly_code(question=question, sql=sql, df_metadata=f"Running df.dtypes gives:\n {df.dtypes}")
fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
fig_json = fig.to_json()

cache.set(id=id, field='fig_json', value=fig_json)

return jsonify(
{
"type": "plotly_figure",
"id": id,
"fig": fig_json,
})
except Exception as e:
# Print the stack trace
import traceback
traceback.print_exc()

return jsonify({"type": "error", "error": str(e)})

@self.flask_app.route('/api/v0/get_training_data', methods=['GET'])
def get_training_data():
df = vn.get_training_data()

return jsonify(
{
"type": "df",
"id": "training_data",
"df": df.tail(25).to_json(orient='records'),
})

@self.flask_app.route('/api/v0/remove_training_data', methods=['POST'])
def remove_training_data():
# Get id from the JSON body
id = flask.request.json.get('id')

if id is None:
return jsonify({"type": "error", "error": "No id provided"})

if vn.remove_training_data(id=id):
return jsonify({"success": True})
else:
return jsonify({"type": "error", "error": "Couldn't remove training data"})

@self.flask_app.route('/api/v0/train', methods=['POST'])
def add_training_data():
question = flask.request.json.get('question')
sql = flask.request.json.get('sql')
ddl = flask.request.json.get('ddl')
documentation = flask.request.json.get('documentation')

try:
id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation)

return jsonify({"id": id})
except Exception as e:
print("TRAINING ERROR", e)
return jsonify({"type": "error", "error": str(e)})

@self.flask_app.route('/api/v0/generate_followup_questions', methods=['GET'])
@self.requires_cache(['df', 'question'])
def generate_followup_questions(id: str, df, question):
followup_questions = []
# followup_questions = vn.generate_followup_questions(question=question, df=df)
# if followup_questions is not None and len(followup_questions) > 5:
# followup_questions = followup_questions[:5]

cache.set(id=id, field='followup_questions', value=followup_questions)

return jsonify(
{
"type": "question_list",
"id": id,
"questions": followup_questions,
"header": "Followup Questions can be enabled in a future version if you allow the LLM to 'see' your query results."
})

@self.flask_app.route('/api/v0/load_question', methods=['GET'])
@self.requires_cache(['question', 'sql', 'df', 'fig_json', 'followup_questions'])
def load_question(id: str, question, sql, df, fig_json, followup_questions):
try:
return jsonify(
{
"type": "question_cache",
"id": id,
"question": question,
"sql": sql,
"df": df.head(10).to_json(orient='records'),
"fig": fig_json,
"followup_questions": followup_questions,
})

except Exception as e:
return jsonify({"type": "error", "error": str(e)})

@self.flask_app.route('/api/v0/get_question_history', methods=['GET'])
def get_question_history():
return jsonify({"type": "question_history", "questions": cache.get_all(field_list=['question']) })


@self.flask_app.route('/api/v0/<path:catch_all>', methods=['GET', 'POST'])
def catch_all(catch_all):
return jsonify({"type": "error", "error": "The rest of the API is not ported yet."})

@self.flask_app.route('/assets/<path:filename>')
def proxy_assets(filename):
remote_url = f'https://vanna.ai/assets/{filename}'
response = requests.get(remote_url, stream=True)

# Check if the request to the remote URL was successful
if response.status_code == 200:
excluded_headers = ['content-encoding', 'content-length', 'transfer-encoding', 'connection']
headers = [(name, value) for (name, value) in response.raw.headers.items() if name.lower() not in excluded_headers]
return Response(response.content, response.status_code, headers)
else:
return 'Error fetching file from remote server', response.status_code

# Proxy the /vanna.svg file to the remote server
@self.flask_app.route('/vanna.svg')
def proxy_vanna_svg():
remote_url = f'https://vanna.ai/img/vanna.svg'
response = requests.get(remote_url, stream=True)

# Check if the request to the remote URL was successful
if response.status_code == 200:
excluded_headers = ['content-encoding', 'content-length', 'transfer-encoding', 'connection']
headers = [(name, value) for (name, value) in response.raw.headers.items() if name.lower() not in excluded_headers]
return Response(response.content, response.status_code, headers)
else:
return 'Error fetching file from remote server', response.status_code

@self.flask_app.route('/', defaults={'path': ''})
@self.flask_app.route('/<path:path>')
def hello(path: str):
return """
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vanna.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<link href="https://fonts.googleapis.com/css2?family=Roboto+Slab:wght@350&display=swap" rel="stylesheet">
<script src="https://cdn.plot.ly/plotly-latest.min.js" type="text/javascript"></script>
<title>Vanna.AI</title>
<script type="module" crossorigin src="/assets/index-d29524f4.js"></script>
<link rel="stylesheet" href="/assets/index-b1a5a2f1.css">
</head>
<body class="bg-white dark:bg-slate-900">
<div id="app"></div>
</body>
</html>
"""

def run(self):
try:
from google.colab import output
output.serve_kernel_port_as_window(8084)
from google.colab.output import eval_js
print("Your app is running at:")
print(eval_js("google.colab.kernel.proxyPort(8084)"))
except:
print("Your app is running at:")
print("http://localhost:8084")
self.flask_app.run(host='0.0.0.0', port=8084, debug=False)

0 comments on commit 33db008

Please sign in to comment.