-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #149 from vanna-ai/flask
Experimental integrated Flask app
- Loading branch information
Showing
3 changed files
with
333 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" | |
|
||
[project] | ||
name = "vanna" | ||
version = "0.0.31" | ||
version = "0.0.32" | ||
authors = [ | ||
{ name="Zain Hoda", email="[email protected]" }, | ||
] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,331 @@ | ||
import flask | ||
from flask import Flask, Response, jsonify, request | ||
import logging | ||
import requests | ||
from functools import wraps | ||
|
||
from abc import ABC, abstractmethod | ||
import uuid | ||
|
||
class Cache(ABC): | ||
@abstractmethod | ||
def generate_id(self, *args, **kwargs): | ||
pass | ||
|
||
@abstractmethod | ||
def get(self, id, field): | ||
pass | ||
|
||
@abstractmethod | ||
def get_all(self, field_list) -> list: | ||
pass | ||
|
||
@abstractmethod | ||
def set(self, id, field, value): | ||
pass | ||
|
||
@abstractmethod | ||
def delete(self, id): | ||
pass | ||
|
||
|
||
class MemoryCache(Cache): | ||
def __init__(self): | ||
self.cache = {} | ||
|
||
def generate_id(self, *args, **kwargs): | ||
return str(uuid.uuid4()) | ||
|
||
def set(self, id, field, value): | ||
if id not in self.cache: | ||
self.cache[id] = {} | ||
|
||
self.cache[id][field] = value | ||
|
||
def get(self, id, field): | ||
if id not in self.cache: | ||
return None | ||
|
||
if field not in self.cache[id]: | ||
return None | ||
|
||
return self.cache[id][field] | ||
|
||
def get_all(self, field_list) -> list: | ||
return [ | ||
{ | ||
"id": id, | ||
**{ | ||
field: self.get(id=id, field=field) | ||
for field in field_list | ||
} | ||
} | ||
for id in self.cache | ||
] | ||
|
||
def delete(self, id): | ||
if id in self.cache: | ||
del self.cache[id] | ||
|
||
class VannaFlaskApp: | ||
flask_app = None | ||
|
||
def requires_cache(self, fields): | ||
def decorator(f): | ||
@wraps(f) | ||
def decorated(*args, **kwargs): | ||
id = request.args.get('id') | ||
|
||
if id is None: | ||
return jsonify({"type": "error", "error": "No id provided"}) | ||
|
||
for field in fields: | ||
if self.cache.get(id=id, field=field) is None: | ||
return jsonify({"type": "error", "error": f"No {field} found"}) | ||
|
||
field_values = {field: self.cache.get(id=id, field=field) for field in fields} | ||
|
||
# Add the id to the field_values | ||
field_values['id'] = id | ||
|
||
return f(*args, **field_values, **kwargs) | ||
return decorated | ||
return decorator | ||
|
||
def __init__(self, vn, cache: Cache = MemoryCache()): | ||
self.flask_app = Flask(__name__) | ||
self.vn = vn | ||
self.cache = cache | ||
|
||
log = logging.getLogger('werkzeug') | ||
log.setLevel(logging.ERROR) | ||
|
||
@self.flask_app.route('/api/v0/generate_questions', methods=['GET']) | ||
def generate_questions(): | ||
# If self has an _model attribute and model=='chinook' | ||
if hasattr(self.vn, '_model') and self.vn._model == 'chinook': | ||
return jsonify({ | ||
"type": "question_list", | ||
"questions": ['What are the top 10 artists by sales?', 'What are the total sales per year by country?', 'Who is the top selling artist in each genre? Show the sales numbers.', 'How do the employees rank in terms of sales performance?', 'Which 5 cities have the most customers?'], | ||
"header": "Here are some questions you can ask:" | ||
}) | ||
|
||
@self.flask_app.route('/api/v0/generate_sql', methods=['GET']) | ||
def generate_sql(): | ||
question = flask.request.args.get('question') | ||
|
||
if question is None: | ||
return jsonify({"type": "error", "error": "No question provided"}) | ||
|
||
id = self.cache.generate_id(question=question) | ||
sql = vn.generate_sql(question=question) | ||
|
||
self.cache.set(id=id, field='question', value=question) | ||
self.cache.set(id=id, field='sql', value=sql) | ||
|
||
return jsonify( | ||
{ | ||
"type": "sql", | ||
"id": id, | ||
"text": sql, | ||
}) | ||
|
||
@self.flask_app.route('/api/v0/run_sql', methods=['GET']) | ||
@self.requires_cache(['sql']) | ||
def run_sql(id: str, sql: str): | ||
try: | ||
if not vn.run_sql_is_set: | ||
return jsonify({"type": "error", "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries."}) | ||
|
||
df = vn.run_sql(sql=sql) | ||
|
||
cache.set(id=id, field='df', value=df) | ||
|
||
return jsonify( | ||
{ | ||
"type": "df", | ||
"id": id, | ||
"df": df.head(10).to_json(orient='records'), | ||
}) | ||
|
||
except Exception as e: | ||
return jsonify({"type": "error", "error": str(e)}) | ||
|
||
@self.flask_app.route('/api/v0/download_csv', methods=['GET']) | ||
@self.requires_cache(['df']) | ||
def download_csv(id: str, df): | ||
csv = df.to_csv() | ||
|
||
return Response( | ||
csv, | ||
mimetype="text/csv", | ||
headers={"Content-disposition": | ||
f"attachment; filename={id}.csv"}) | ||
|
||
@self.flask_app.route('/api/v0/generate_plotly_figure', methods=['GET']) | ||
@self.requires_cache(['df', 'question', 'sql']) | ||
def generate_plotly_figure(id: str, df, question, sql): | ||
try: | ||
code = vn.generate_plotly_code(question=question, sql=sql, df_metadata=f"Running df.dtypes gives:\n {df.dtypes}") | ||
fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False) | ||
fig_json = fig.to_json() | ||
|
||
cache.set(id=id, field='fig_json', value=fig_json) | ||
|
||
return jsonify( | ||
{ | ||
"type": "plotly_figure", | ||
"id": id, | ||
"fig": fig_json, | ||
}) | ||
except Exception as e: | ||
# Print the stack trace | ||
import traceback | ||
traceback.print_exc() | ||
|
||
return jsonify({"type": "error", "error": str(e)}) | ||
|
||
@self.flask_app.route('/api/v0/get_training_data', methods=['GET']) | ||
def get_training_data(): | ||
df = vn.get_training_data() | ||
|
||
return jsonify( | ||
{ | ||
"type": "df", | ||
"id": "training_data", | ||
"df": df.tail(25).to_json(orient='records'), | ||
}) | ||
|
||
@self.flask_app.route('/api/v0/remove_training_data', methods=['POST']) | ||
def remove_training_data(): | ||
# Get id from the JSON body | ||
id = flask.request.json.get('id') | ||
|
||
if id is None: | ||
return jsonify({"type": "error", "error": "No id provided"}) | ||
|
||
if vn.remove_training_data(id=id): | ||
return jsonify({"success": True}) | ||
else: | ||
return jsonify({"type": "error", "error": "Couldn't remove training data"}) | ||
|
||
@self.flask_app.route('/api/v0/train', methods=['POST']) | ||
def add_training_data(): | ||
question = flask.request.json.get('question') | ||
sql = flask.request.json.get('sql') | ||
ddl = flask.request.json.get('ddl') | ||
documentation = flask.request.json.get('documentation') | ||
|
||
try: | ||
id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation) | ||
|
||
return jsonify({"id": id}) | ||
except Exception as e: | ||
print("TRAINING ERROR", e) | ||
return jsonify({"type": "error", "error": str(e)}) | ||
|
||
@self.flask_app.route('/api/v0/generate_followup_questions', methods=['GET']) | ||
@self.requires_cache(['df', 'question']) | ||
def generate_followup_questions(id: str, df, question): | ||
followup_questions = [] | ||
# followup_questions = vn.generate_followup_questions(question=question, df=df) | ||
# if followup_questions is not None and len(followup_questions) > 5: | ||
# followup_questions = followup_questions[:5] | ||
|
||
cache.set(id=id, field='followup_questions', value=followup_questions) | ||
|
||
return jsonify( | ||
{ | ||
"type": "question_list", | ||
"id": id, | ||
"questions": followup_questions, | ||
"header": "Followup Questions can be enabled in a future version if you allow the LLM to 'see' your query results." | ||
}) | ||
|
||
@self.flask_app.route('/api/v0/load_question', methods=['GET']) | ||
@self.requires_cache(['question', 'sql', 'df', 'fig_json', 'followup_questions']) | ||
def load_question(id: str, question, sql, df, fig_json, followup_questions): | ||
try: | ||
return jsonify( | ||
{ | ||
"type": "question_cache", | ||
"id": id, | ||
"question": question, | ||
"sql": sql, | ||
"df": df.head(10).to_json(orient='records'), | ||
"fig": fig_json, | ||
"followup_questions": followup_questions, | ||
}) | ||
|
||
except Exception as e: | ||
return jsonify({"type": "error", "error": str(e)}) | ||
|
||
@self.flask_app.route('/api/v0/get_question_history', methods=['GET']) | ||
def get_question_history(): | ||
return jsonify({"type": "question_history", "questions": cache.get_all(field_list=['question']) }) | ||
|
||
|
||
@self.flask_app.route('/api/v0/<path:catch_all>', methods=['GET', 'POST']) | ||
def catch_all(catch_all): | ||
return jsonify({"type": "error", "error": "The rest of the API is not ported yet."}) | ||
|
||
@self.flask_app.route('/assets/<path:filename>') | ||
def proxy_assets(filename): | ||
remote_url = f'https://vanna.ai/assets/{filename}' | ||
response = requests.get(remote_url, stream=True) | ||
|
||
# Check if the request to the remote URL was successful | ||
if response.status_code == 200: | ||
excluded_headers = ['content-encoding', 'content-length', 'transfer-encoding', 'connection'] | ||
headers = [(name, value) for (name, value) in response.raw.headers.items() if name.lower() not in excluded_headers] | ||
return Response(response.content, response.status_code, headers) | ||
else: | ||
return 'Error fetching file from remote server', response.status_code | ||
|
||
# Proxy the /vanna.svg file to the remote server | ||
@self.flask_app.route('/vanna.svg') | ||
def proxy_vanna_svg(): | ||
remote_url = f'https://vanna.ai/img/vanna.svg' | ||
response = requests.get(remote_url, stream=True) | ||
|
||
# Check if the request to the remote URL was successful | ||
if response.status_code == 200: | ||
excluded_headers = ['content-encoding', 'content-length', 'transfer-encoding', 'connection'] | ||
headers = [(name, value) for (name, value) in response.raw.headers.items() if name.lower() not in excluded_headers] | ||
return Response(response.content, response.status_code, headers) | ||
else: | ||
return 'Error fetching file from remote server', response.status_code | ||
|
||
@self.flask_app.route('/', defaults={'path': ''}) | ||
@self.flask_app.route('/<path:path>') | ||
def hello(path: str): | ||
return """ | ||
<!doctype html> | ||
<html lang="en"> | ||
<head> | ||
<meta charset="UTF-8" /> | ||
<link rel="icon" type="image/svg+xml" href="/vanna.svg" /> | ||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> | ||
<link href="https://fonts.googleapis.com/css2?family=Roboto+Slab:wght@350&display=swap" rel="stylesheet"> | ||
<script src="https://cdn.plot.ly/plotly-latest.min.js" type="text/javascript"></script> | ||
<title>Vanna.AI</title> | ||
<script type="module" crossorigin src="/assets/index-d29524f4.js"></script> | ||
<link rel="stylesheet" href="/assets/index-b1a5a2f1.css"> | ||
</head> | ||
<body class="bg-white dark:bg-slate-900"> | ||
<div id="app"></div> | ||
</body> | ||
</html> | ||
""" | ||
|
||
def run(self): | ||
try: | ||
from google.colab import output | ||
output.serve_kernel_port_as_window(8084) | ||
from google.colab.output import eval_js | ||
print("Your app is running at:") | ||
print(eval_js("google.colab.kernel.proxyPort(8084)")) | ||
except: | ||
print("Your app is running at:") | ||
print("http://localhost:8084") | ||
self.flask_app.run(host='0.0.0.0', port=8084, debug=False) |