Skip to content

Commit

Permalink
Merge pull request #609 from vanna-ai/multi-turn
Browse files Browse the repository at this point in the history
Multi-turn conversations
  • Loading branch information
zainhoda authored Aug 21, 2024
2 parents eaf3f5a + a292acc commit 7cb744a
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 35 deletions.
27 changes: 27 additions & 0 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,33 @@ def should_generate_chart(self, df: pd.DataFrame) -> bool:

return False

def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
"""
**Example:**
```python
rewritten_question = vn.generate_rewritten_question("Who are the top 5 customers by sales?", "Show me their email addresses")
```
Generate a rewritten question by combining the last question and the new question if they are related. If the new question is self-contained and not related to the last question, return the new question.
Args:
last_question (str): The previous question that was asked.
new_question (str): The new question to be combined with the last question.
**kwargs: Additional keyword arguments.
Returns:
str: The combined question if related, otherwise the new question.
"""
if last_question is None:
return new_question

prompt = [
self.system_message("Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
]

return self.submit_prompt(prompt=prompt, **kwargs)

def generate_followup_questions(
self, question: str, sql: str, df: pd.DataFrame, n_questions: int = 5, **kwargs
) -> list:
Expand Down
26 changes: 26 additions & 0 deletions src/vanna/flask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import uuid
from abc import ABC, abstractmethod
from functools import wraps
import importlib.metadata

import flask
import requests
Expand Down Expand Up @@ -353,6 +354,30 @@ def generate_sql(user: any):
}
)

@self.flask_app.route("/api/v0/generate_rewritten_question", methods=["GET"])
@self.requires_auth
def generate_rewritten_question(user: any):
"""
Generate a rewritten question
---
parameters:
- name: last_question
in: query
type: string
required: true
- name: new_question
in: query
type: string
required: true
"""

last_question = flask.request.args.get("last_question")
new_question = flask.request.args.get("new_question")

rewritten_question = self.vn.generate_rewritten_question(last_question, new_question)

return jsonify({"type": "rewritten_question", "question": rewritten_question})

@self.flask_app.route("/api/v0/get_function", methods=["GET"])
@self.requires_auth
def get_function(user: any):
Expand Down Expand Up @@ -1212,6 +1237,7 @@ def __init__(
self.config["followup_questions"] = followup_questions
self.config["summarization"] = summarization
self.config["function_generation"] = function_generation and hasattr(vn, "get_function")
self.config["version"] = importlib.metadata.version('vanna')

self.index_html_path = index_html_path
self.assets_folder = assets_folder
Expand Down
70 changes: 35 additions & 35 deletions src/vanna/flask/assets.py

Large diffs are not rendered by default.

0 comments on commit 7cb744a

Please sign in to comment.