-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added an option to let users make a call to any external API with a prompt #47
Merged
Merged
Changes from 4 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
14f9886
actually pass the seed parameters to the API 🤦🏾♂️
rishsriv c917c4b
added an option to evaluate models via an API
rishsriv 100cfcd
linting
rishsriv 2a9c699
added parallelization to api runner
rishsriv 1209b3a
added instructions for using the API mode to README
rishsriv 399ebc3
cleaned up api runner, removing unnecessary comments and function params
rishsriv 8effd9f
Merge branch 'main' into rishabh/api
rishsriv File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
from concurrent.futures import ThreadPoolExecutor, as_completed | ||
from typing import Optional | ||
from eval.eval import compare_query_results | ||
import pandas as pd | ||
from utils.pruning import prune_metadata_str | ||
from utils.questions import prepare_questions_df | ||
from tqdm import tqdm | ||
from time import time | ||
import requests | ||
|
||
|
||
def generate_prompt(prompt_file, question, db_name, public_data): | ||
with open(prompt_file, "r") as f: | ||
prompt = f.read() | ||
|
||
pruned_metadata_str = prune_metadata_str(question, db_name, public_data) | ||
prompt = prompt.format( | ||
user_question=question, table_metadata_string=pruned_metadata_str | ||
) | ||
return prompt | ||
|
||
|
||
def process_row(row, api_url, num_beams, public_data): | ||
start_time = time() | ||
# we set return_full_text to False so that we don't get the prompt text in the generated text | ||
# this simplifies our postprocessing to deal with just the truncation of the end of the query | ||
r = requests.post( | ||
api_url, | ||
json={ | ||
"prompt": row["prompt"], | ||
"n": 1, | ||
"use_beam_search": True, | ||
"best_of": num_beams, | ||
"temperature": 0, | ||
"stop": [";", "```"], | ||
"max_tokens": 600, | ||
}, | ||
) | ||
end_time = time() | ||
generated_query = ( | ||
r.json()["text"][0].split("```")[-1].split("```")[0].split(";")[0].strip() + ";" | ||
) | ||
|
||
row["generated_query"] = generated_query | ||
row["latency_seconds"] = end_time - start_time | ||
golden_query = row["query"] | ||
db_name = row["db_name"] | ||
question = row["question"] | ||
query_category = row["query_category"] | ||
exact_match = correct = 0 | ||
|
||
db_creds = { | ||
"host": "localhost", | ||
"port": 5432, | ||
"user": "postgres", | ||
"password": "postgres", | ||
"database": db_name, | ||
} | ||
|
||
try: | ||
exact_match, correct = compare_query_results( | ||
query_gold=golden_query, | ||
query_gen=generated_query, | ||
db_name=db_name, | ||
db_creds=db_creds, | ||
question=question, | ||
query_category=query_category, | ||
) | ||
row["exact_match"] = int(exact_match) | ||
row["correct"] = int(correct) | ||
row["error_msg"] = "" | ||
except Exception as e: | ||
row["error_db_exec"] = 1 | ||
row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" | ||
|
||
return row | ||
|
||
|
||
def run_api_eval(args): | ||
# get params from args | ||
questions_file = args.questions_file | ||
prompt_file = args.prompt_file | ||
num_questions = args.num_questions | ||
public_data = not args.use_private_data | ||
api_url = args.url | ||
output_file = args.output_file | ||
num_beams = args.num_beams | ||
max_workers = args.parallel_threads | ||
|
||
print("preparing questions...") | ||
# get questions | ||
print(f"Using {num_questions} questions from {questions_file}") | ||
df = prepare_questions_df(questions_file, num_questions) | ||
|
||
# create a prompt for each question | ||
df["prompt"] = df[["question", "db_name"]].apply( | ||
lambda row: generate_prompt( | ||
prompt_file, row["question"], row["db_name"], public_data | ||
), | ||
axis=1, | ||
) | ||
|
||
print("questions prepared\nnow loading model...") | ||
# initialize tokenizer and model | ||
total_tried = 0 | ||
total_correct = 0 | ||
output_rows = [] | ||
|
||
with ThreadPoolExecutor(max_workers=5) as executor: | ||
futures = [] | ||
for row in df.to_dict("records"): | ||
futures.append( | ||
executor.submit(process_row, row, api_url, num_beams, public_data) | ||
) | ||
|
||
with tqdm(as_completed(futures), total=len(futures)) as pbar: | ||
for f in pbar: | ||
row = f.result() | ||
output_rows.append(row) | ||
if row["correct"]: | ||
total_correct += 1 | ||
total_tried += 1 | ||
pbar.update(1) | ||
pbar.set_description( | ||
f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" | ||
) | ||
|
||
output_df = pd.DataFrame(output_rows) | ||
del output_df["prompt"] | ||
print(output_df.groupby("query_category")[["exact_match", "correct"]].mean()) | ||
output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) | ||
output_df.to_csv(output_file, index=False, float_format="%.2f") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,11 @@ | ||
### Task | ||
# Task | ||
Generate a SQL query to answer the following question: | ||
`{user_question}` | ||
|
||
### Database Schema | ||
# Database Schema | ||
The query will run on a database with the following schema: | ||
{table_metadata_string} | ||
|
||
### SQL | ||
Follow these steps to create the SQL Query: | ||
1. Only use the columns and tables present in the database schema | ||
2. Use table aliases to prevent ambiguity when doing joins. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. | ||
|
||
Given the database schema, here is the SQL query that answers `{user_question}`: | ||
```sql | ||
# SQL | ||
Here is the query to answer the question `{user_question}` | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: are we able to set
return_full_text
in the request params here? if not we can remove the comment above.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah we're not, removed!