Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added an option to let users make a call to any external API with a prompt #47

Merged
merged 7 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions eval/api_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional
from eval.eval import compare_query_results
import pandas as pd
from utils.pruning import prune_metadata_str
from utils.questions import prepare_questions_df
from tqdm import tqdm
from time import time
import requests


def generate_prompt(prompt_file, question, db_name, public_data):
    """Build the full model prompt for one question.

    Reads the prompt template from `prompt_file` and fills in the user
    question plus the pruned table metadata for `db_name`.
    """
    with open(prompt_file, "r") as f:
        template = f.read()

    table_metadata = prune_metadata_str(question, db_name, public_data)
    return template.format(
        user_question=question, table_metadata_string=table_metadata
    )


def process_row(row, api_url, num_beams, public_data):
    """Send one prompt to the external inference API and score the result.

    POSTs the row's prompt to `api_url` (vLLM-style request body), extracts
    the generated SQL, then compares it against the golden query on a local
    Postgres instance. Mutates and returns `row` with the fields:
    generated_query, latency_seconds, exact_match, correct, error_db_exec,
    error_msg.

    `public_data` is accepted for signature parity with the other runners;
    it is not used here (pruning happens at prompt-generation time).
    """
    start_time = time()
    r = requests.post(
        api_url,
        json={
            "prompt": row["prompt"],
            "n": 1,
            "use_beam_search": True,
            "best_of": num_beams,
            "temperature": 0,
            "stop": [";", "```"],
            "max_tokens": 600,
        },
    )
    end_time = time()
    # Take the text after the last markdown fence (if any), cut at the first
    # semicolon, and re-append the ";" that the stop token stripped.
    generated_query = (
        r.json()["text"][0].split("```")[-1].split(";")[0].strip() + ";"
    )

    row["generated_query"] = generated_query
    row["latency_seconds"] = end_time - start_time
    golden_query = row["query"]
    db_name = row["db_name"]
    question = row["question"]
    query_category = row["query_category"]
    # Initialize every score/error field up front so downstream aggregation
    # (e.g. `row["correct"]`, the exact_match/error_db_exec columns) never
    # hits a missing key when comparison raises.
    row["exact_match"] = 0
    row["correct"] = 0
    row["error_db_exec"] = 0
    row["error_msg"] = ""

    db_creds = {
        "host": "localhost",
        "port": 5432,
        "user": "postgres",
        "password": "postgres",
        "database": db_name,
    }

    try:
        exact_match, correct = compare_query_results(
            query_gold=golden_query,
            query_gen=generated_query,
            db_name=db_name,
            db_creds=db_creds,
            question=question,
            query_category=query_category,
        )
        row["exact_match"] = int(exact_match)
        row["correct"] = int(correct)
    except Exception as e:
        row["error_db_exec"] = 1
        row["error_msg"] = f"QUERY EXECUTION ERROR: {e}"

    return row


def run_api_eval(args):
    """Run the evaluation against an external inference API.

    Prepares the questions, builds one prompt per question, fans the API
    calls out across `args.parallel_threads` worker threads, prints a
    running accuracy, and writes the scored results to `args.output_file`
    as CSV (sorted, with the prompt column dropped).
    """
    # get params from args
    questions_file = args.questions_file
    prompt_file = args.prompt_file
    num_questions = args.num_questions
    public_data = not args.use_private_data
    api_url = args.url
    output_file = args.output_file
    num_beams = args.num_beams
    max_workers = args.parallel_threads

    print("preparing questions...")
    # get questions
    print(f"Using {num_questions} questions from {questions_file}")
    df = prepare_questions_df(questions_file, num_questions)

    # create a prompt for each question
    df["prompt"] = df[["question", "db_name"]].apply(
        lambda row: generate_prompt(
            prompt_file, row["question"], row["db_name"], public_data
        ),
        axis=1,
    )

    # no local model is loaded in API mode — the requests go to an external server
    print("questions prepared\nnow sending requests to the API...")
    total_tried = 0
    total_correct = 0
    output_rows = []

    # BUG FIX: honor the user-supplied --parallel_threads value instead of
    # the previously hardcoded max_workers=5.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for row in df.to_dict("records"):
            futures.append(
                executor.submit(process_row, row, api_url, num_beams, public_data)
            )

        with tqdm(as_completed(futures), total=len(futures)) as pbar:
            for f in pbar:
                row = f.result()
                output_rows.append(row)
                # .get() so a row that errored before scoring doesn't KeyError
                if row.get("correct"):
                    total_correct += 1
                total_tried += 1
                pbar.update(1)
                pbar.set_description(
                    f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)"
                )

    output_df = pd.DataFrame(output_rows)
    del output_df["prompt"]
    print(output_df.groupby("query_category")[["exact_match", "correct"]].mean())
    output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
    output_df.to_csv(output_file, index=False, float_format="%.2f")
5 changes: 5 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
parser.add_argument("-g", "--model_type", type=str, required=True)
parser.add_argument("-m", "--model", type=str)
parser.add_argument("-a", "--adapter", type=str)
parser.add_argument("--url", type=str)
parser.add_argument("-b", "--num_beams", type=int, default=4)
parser.add_argument("-f", "--prompt_file", type=str, required=True)
parser.add_argument("-d", "--use_private_data", action="store_true")
Expand Down Expand Up @@ -38,6 +39,10 @@
from eval.hf_runner import run_hf_eval

run_hf_eval(args)
elif args.model_type == "api":
from eval.api_runner import run_api_eval

run_api_eval(args)
else:
raise ValueError(
f"Invalid model type: {args.model_type}. Model type must be one of: 'oa', 'hf'"
Expand Down
14 changes: 5 additions & 9 deletions prompts/prompt.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
### Task
# Task
Generate a SQL query to answer the following question:
`{user_question}`

### Database Schema
# Database Schema
The query will run on a database with the following schema:
{table_metadata_string}

### SQL
Follow these steps to create the SQL Query:
1. Only use the columns and tables present in the database schema
2. Use table aliases to prevent ambiguity when doing joins. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.

Given the database schema, here is the SQL query that answers `{user_question}`:
```sql
# SQL
Here is the query to answer the question `{user_question}`
```
3 changes: 2 additions & 1 deletion query_generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_chat_completion(
temperature=0,
stop=[],
logit_bias={},
seed=123,
seed=100,
):
"""Get OpenAI chat completion for a given prompt and model"""
generated_text = ""
Expand All @@ -55,6 +55,7 @@ def get_chat_completion(
temperature=temperature,
stop=stop,
logit_bias=logit_bias,
seed=seed,
)
generated_text = completion.choices[0].message.content
except Exception as e:
Expand Down