Merge pull request #47 from defog-ai/rishabh/api
Added an option to let users make a call to any external API with a prompt
rishsriv authored Nov 15, 2023
2 parents 866d3f5 + 8effd9f commit d76237e
Showing 4 changed files with 156 additions and 12 deletions.
21 changes: 18 additions & 3 deletions README.md
@@ -104,7 +104,7 @@ Having implemented the query generator, the next piece of abstraction would be t
### OpenAI / Anthropic
Remember to set your API key (`OPENAI_API_KEY` or `ANTHROPIC_API_KEY`) as an environment variable before running the test if you plan to call the OpenAI, Anthropic, or other LLM APIs.
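
For example, in a bash shell (the key values below are placeholders):

```bash
# set whichever key matches the provider you plan to test against
export OPENAI_API_KEY="your-openai-api-key"
export ANTHROPIC_API_KEY="your-anthropic-api-key"
```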

To test it out with just 10 questions (instead of all 175) using the gpt-3.5-turbo model, parallelized across 5 workers:
To test it out with just 10 questions (instead of all 200), parallelized across 5 workers:

```bash
python main.py \
@@ -128,7 +128,7 @@ python main.py \
```

### Hugging Face
To test it out with our fine-tuned sql model with just 10 questions (instead of all 175):
To test it out with our fine-tuned sql model with just 10 questions (instead of all 200):

```bash
# use the -W option to ignore warnings about sequential use of transformers pipeline
@@ -165,15 +165,30 @@ python -W ignore main.py \
```
While you can do the same for the other runners, the time savings are most significant when loading a large model locally, vs calling an always-on API.

### API
To test it out with just 10 questions (instead of all 200), parallelized across 3 calls:
```bash
mkdir results
python main.py \
-q data/questions_gen.csv \
-o results/results.csv \
-g api \
-b 5 \
-f prompts/prompt.md \
--url YOUR_API_URL \
-p 3 \
-n 10
```
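
The endpoint passed via `--url` should accept a JSON POST and return JSON containing a `text` list whose first element is the completion (see `eval/api_runner.py`). As a rough compatibility check, here is a sketch of the request the runner sends; `YOUR_API_URL` and the prompt value are placeholders:

```bash
# sketch of the POST issued by eval/api_runner.py; best_of mirrors the -b flag
curl -X POST YOUR_API_URL \
  -H "Content-Type: application/json" \
  -d '{
        "prompt": "<generated prompt>",
        "n": 1,
        "use_beam_search": true,
        "best_of": 5,
        "temperature": 0,
        "stop": [";", "```"],
        "max_tokens": 600
      }'
```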

### CLI Flags
You can use the following flags in the command line to change the configurations of your evaluation runs.
| CLI Flags | Description |
|-------------|-------|
| -q, --questions_file | CSV file that contains the test questions and true queries. |
| -n, --num_questions | Use this to limit the total number of questions you want to test. |
| -g, --model_type | Model type used. Make sure this matches the model used. Currently defined options in `main.py` are `oa` for OpenAI models and `hf` for Hugging Face models. |
| -g, --model_type | Model type used. Make sure this matches the model used. Currently defined options in `main.py` are `oa` for OpenAI models, `anthropic` for Anthropic models, `hf` for Hugging Face models, and `api` for API endpoints. |
| -m, --model | Model that will be tested and used to generate the queries. Currently defined options for OpenAI models are chat models `gpt-3.5-turbo-0613` and `gpt-4-0613`, and non-chat model `text-davinci-003`. For Hugging Face models, simply use the path of your chosen model (e.g. `defog/sqlcoder`). |
| --url | The URL of the API you want to send the prompt to. Only used when model_type is `api` |
| -f, --prompt_file | Markdown file with the prompt used for query generation. You can pass in a list of prompts to test sequentially without reloading the script. |
| -d, --use_private_data | Use this to read from your own private data library. |
| -o, --output_file | Output CSV file that will store your results. You need to pass the same number of output file paths as the number of prompt files. |
128 changes: 128 additions & 0 deletions eval/api_runner.py
@@ -0,0 +1,128 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional
from eval.eval import compare_query_results
import pandas as pd
from utils.pruning import prune_metadata_str
from utils.questions import prepare_questions_df
from tqdm import tqdm
from time import time
import requests


def generate_prompt(prompt_file, question, db_name, public_data):
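    # fill the prompt template in prompt_file with the question and a pruned
    # table-metadata string for the target database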
with open(prompt_file, "r") as f:
prompt = f.read()

pruned_metadata_str = prune_metadata_str(question, db_name, public_data)
prompt = prompt.format(
user_question=question, table_metadata_string=pruned_metadata_str
)
return prompt


def process_row(row, api_url, num_beams):
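    # send a single prompt to the API endpoint, extract the generated SQL,
    # and score it against the golden query for this question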
start_time = time()
r = requests.post(
api_url,
json={
"prompt": row["prompt"],
"n": 1,
"use_beam_search": True,
"best_of": num_beams,
"temperature": 0,
"stop": [";", "```"],
"max_tokens": 600,
},
)
end_time = time()
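    # the endpoint is expected to return JSON with a "text" list; keep only the
    # first SQL statement of the completion and terminate it with a semicolon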
generated_query = (
r.json()["text"][0].split("```")[-1].split("```")[0].split(";")[0].strip() + ";"
)

row["generated_query"] = generated_query
row["latency_seconds"] = end_time - start_time
golden_query = row["query"]
db_name = row["db_name"]
question = row["question"]
query_category = row["query_category"]
exact_match = correct = 0

db_creds = {
"host": "localhost",
"port": 5432,
"user": "postgres",
"password": "postgres",
"database": db_name,
}

try:
exact_match, correct = compare_query_results(
query_gold=golden_query,
query_gen=generated_query,
db_name=db_name,
db_creds=db_creds,
question=question,
query_category=query_category,
)
row["exact_match"] = int(exact_match)
row["correct"] = int(correct)
row["error_msg"] = ""
    except Exception as e:
        row["exact_match"] = 0
        row["correct"] = 0
        row["error_db_exec"] = 1
        row["error_msg"] = f"QUERY EXECUTION ERROR: {e}"

return row


def run_api_eval(args):
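    # end-to-end evaluation against an external API: build prompts, send them in
    # parallel, score the generated SQL, and write per-question results to a CSV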
# get params from args
questions_file = args.questions_file
prompt_file = args.prompt_file
num_questions = args.num_questions
public_data = not args.use_private_data
api_url = args.url
output_file = args.output_file
num_beams = args.num_beams
max_workers = args.parallel_threads

print("preparing questions...")
# get questions
print(f"Using {num_questions} questions from {questions_file}")
df = prepare_questions_df(questions_file, num_questions)

# create a prompt for each question
df["prompt"] = df[["question", "db_name"]].apply(
lambda row: generate_prompt(
prompt_file, row["question"], row["db_name"], public_data
),
axis=1,
)

    print("questions prepared\nnow sending prompts to the API...")
    # track progress and results across worker threads
total_tried = 0
total_correct = 0
output_rows = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for row in df.to_dict("records"):
futures.append(executor.submit(process_row, row, api_url, num_beams))

with tqdm(as_completed(futures), total=len(futures)) as pbar:
for f in pbar:
row = f.result()
output_rows.append(row)
if row["correct"]:
total_correct += 1
total_tried += 1
pbar.update(1)
pbar.set_description(
f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)"
)

output_df = pd.DataFrame(output_rows)
del output_df["prompt"]
print(output_df.groupby("query_category")[["exact_match", "correct"]].mean())
output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
output_df.to_csv(output_file, index=False, float_format="%.2f")
5 changes: 5 additions & 0 deletions main.py
@@ -7,6 +7,7 @@
parser.add_argument("-g", "--model_type", type=str, required=True)
parser.add_argument("-m", "--model", type=str)
parser.add_argument("-a", "--adapter", type=str)
parser.add_argument("--url", type=str)
parser.add_argument("-b", "--num_beams", type=int, default=4)
# take in a list of prompt files
parser.add_argument("-f", "--prompt_file", nargs="+", type=str, required=True)
@@ -45,6 +46,10 @@
from eval.hf_runner import run_hf_eval

run_hf_eval(args)
elif args.model_type == "api":
from eval.api_runner import run_api_eval

run_api_eval(args)
else:
    raise ValueError(
        f"Invalid model type: {args.model_type}. Model type must be one of: 'oa', 'anthropic', 'hf', 'api'"
14 changes: 5 additions & 9 deletions prompts/prompt.md
@@ -1,15 +1,11 @@
### Task
# Task
Generate a SQL query to answer the following question:
`{user_question}`

### Database Schema
# Database Schema
The query will run on a database with the following schema:
{table_metadata_string}

### SQL
Follow these steps to create the SQL Query:
1. Only use the columns and tables present in the database schema
2. Use table aliases to prevent ambiguity when doing joins. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.

Given the database schema, here is the SQL query that answers `{user_question}`:
```sql
# SQL
Here is the query to answer the question `{user_question}`
```
