add number of columns and shuffle as args
update all runners with the args
wongjingping committed Mar 6, 2024
1 parent f4d7090 commit 068547a
Showing 16 changed files with 158 additions and 78 deletions.
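In short, this commit threads two new knobs through every runner: `--num_columns` (default 20 in `main.py`) caps how many columns per table survive metadata pruning, and `--shuffle_metadata` randomizes the order of tables and columns in the prompt. The snippet below is only a toy illustration of what those knobs mean; it is not the code from `utils/pruning.py` (which this commit does not show), and the schema, function name, and truncation rule are invented for the example.

```python
import random

def prune_and_shuffle(schema: dict, columns_to_keep: int = 20, shuffle: bool = False) -> str:
    """Toy stand-in for metadata pruning: cap columns per table, optionally shuffle order."""
    tables = list(schema.items())
    if shuffle:
        random.shuffle(tables)
    lines = []
    for table, columns in tables:
        cols = list(columns)
        if shuffle:
            random.shuffle(cols)
        # keep at most `columns_to_keep` columns per table
        lines.append(f"{table}({', '.join(cols[:columns_to_keep])})")
    return "\n".join(lines)

# Example: keep at most 2 columns per table, deterministic order
schema = {"users": ["id", "name", "email"], "orders": ["id", "user_id", "total"]}
print(prune_and_shuffle(schema, columns_to_keep=2, shuffle=False))
# users(id, name)
# orders(id, user_id)
```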
33 changes: 10 additions & 23 deletions README.md
@@ -126,7 +126,7 @@ To test it out with just 10 questions (instead of all 200), parallelized across
```bash
python main.py \
-db postgres \
-o results/my_query_generator.csv \
-o results/openai.csv \
-g oa \
-f prompts/prompt_openai.md \
-m gpt-3.5-turbo-0613 \
@@ -138,7 +138,7 @@ To test out the full suite of questions for claude-2:
```bash
python main.py \
-db postgres \
-o results/claude-2.csv \
-o results/claude-3.csv \
-g anthropic \
-f prompts/prompt_anthropic.md \
-m claude-3-opus-20240229 \
@@ -166,14 +166,15 @@ We also have a [vllm](https://blog.vllm.ai/) runner which uses the vLLM engine t
```bash
python -W ignore main.py \
-db postgres \
-o "results/results.csv" \
-o "results/vllm.csv" \
-g vllm \
-f "prompts/prompt.md" \
-m defog/sqlcoder-7b-2
```

Optionally, if you're running evals on a model that is quantized with AWQ, add the `-qz` or `--quantized` parameter. This is only applicable to the vllm runner.

#### Running with an API Server
If you're testing several different settings, you can set up an API server so the model is loaded once instead of being reloaded for each run, and then run the tests against it. To set up the API server:
```bash
# to set up a vllm server
@@ -182,10 +183,10 @@ python -m vllm.entrypoints.api_server \
--tensor-parallel-size 4 \
--dtype float16

# to run sql-eval using the api runner - depending on how much your GPUs can take, can increase p to higher values
# to run sql-eval using the api runner - depending on how much your GPUs can take, can increase p and b to higher values
python main.py \
-db postgres \
-o results/results.csv \
-o results/api.csv \
-g api \
-b 1 \
-f prompts/prompt.md \
@@ -194,7 +195,7 @@ python main.py \
-n 10
```

### Multiple Prompts
#### Multiple Prompts

If you'd like to test out a few prompts in a single run (to save the few minutes spent loading the model into GPU at the start of each run), you can specify a list of prompt files in `--prompt_file` (e.g. `-f prompts/prompt-1.md prompts/prompt-2.md prompts/prompt-3.md`), as well as a corresponding list of output files in `--output_file` (e.g. `-o results/results-1.csv results/results-2.csv results/results-3.csv`). The number of prompts and output files must be the same. Here's a sample command:
```bash
@@ -207,20 +208,6 @@ python -W ignore main.py \
```
While you can do the same for the other runners, the time savings are most significant when loading a large model locally, vs calling an always-on API.

### API
To test it out with just 10 questions (instead of all 200), parallelized across 3 calls:
```bash
mkdir results
python main.py \
-db postgres \
-o results/results.csv \
-g api \
-b 5 \
-f prompts/prompt.md \
--api_url YOUR_API_URL \
-p 3 \
-n 10
```

### Llama CPP
To run the eval using Llama CPP, you can use the following code. Before running this, you must install `llama-cpp-python` with the following (on Apple Silicon)
@@ -232,7 +219,7 @@ Note that llama-cpp-python library does not currently have beam search, and henc
```bash
python -W ignore main.py \
-db postgres \
-o "results/results.csv" \
-o "results/llama_cpp.csv" \
-g llama_cpp \
-f "prompts/prompt.md" \
-m path/to/model.gguf
@@ -246,7 +233,7 @@ Note that MLX does not currently have beam search, and hence will have lower qua
```bash
python -W ignore main.py \
-db postgres \
-o "results/results.csv" \
-o "results/mlx_sqlcoder-7b-2.csv" \
-g mlx \
-f "prompts/prompt.md" \
-m mlx-community/defog-sqlcoder-7b-2
@@ -258,7 +245,7 @@ Before running this, you must create an account with [Google AI](https://ai.goog
```bash
python -W ignore main.py \
-db postgres \
-o "results/results.csv" \
-o "results/gemini_pro.csv" \
-g gemini \
-f "prompts/prompt_gemini.md" \
-m gemini-pro \
12 changes: 3 additions & 9 deletions eval/anthropic_runner.py
@@ -22,8 +22,6 @@ def run_anthropic_eval(args):
args.questions_file, args.db_type, args.num_questions, args.k_shot
)
for prompt_file, output_file in zip(args.prompt_file, args.output_file):
qg_class = AnthropicQueryGenerator

input_rows = question_query_df.to_dict("records")
output_rows = []
with ThreadPoolExecutor(args.parallel_threads) as executor:
@@ -34,20 +32,14 @@
db_name = row["db_name"]
db_creds = db_creds_all[row["db_type"]]

qg = qg_class(
qg = AnthropicQueryGenerator(
db_creds=copy.deepcopy(db_creds),
db_name=db_name,
model=args.model,
prompt_file=prompt_file,
timeout=args.timeout_gen,
use_public_data=not args.use_private_data,
verbose=args.verbose,
instructions=row["instructions"],
k_shot_prompt=row["k_shot_prompt"],
glossary=row["glossary"],
table_metadata_string=row["table_metadata_string"],
prev_invalid_sql=row["prev_invalid_sql"],
prev_error_msg=row["prev_error_msg"],
)

generated_query_fut = executor.submit(
@@ -59,6 +51,8 @@
table_metadata_string=row["table_metadata_string"],
prev_invalid_sql=row["prev_invalid_sql"],
prev_error_msg=row["prev_error_msg"],
columns_to_keep=args.num_columns,
shuffle=args.shuffle_metadata,
)
futures.append(generated_query_fut)

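As the diff above shows, the Anthropic runner (and the OpenAI runner further down) no longer passes the per-question fields to the generator's constructor; everything, including the two new values, now goes to `generate_query` via `executor.submit` as `columns_to_keep=args.num_columns` and `shuffle=args.shuffle_metadata`. A minimal, self-contained sketch of that pattern; the class, rows, and values below are simplified stand-ins, not the real `AnthropicQueryGenerator`:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

class StubQueryGenerator:
    """Simplified stand-in for a query generator class."""
    def __init__(self, db_name: str):
        self.db_name = db_name

    def generate_query(self, question: str, columns_to_keep: int, shuffle: bool) -> dict:
        # The real generator prunes metadata and calls the model API here.
        return {"db": self.db_name, "question": question,
                "columns_to_keep": columns_to_keep, "shuffle": shuffle}

rows = [
    {"db_name": "broker", "question": "How many trades settled today?"},
    {"db_name": "restaurant", "question": "What is the average bill size?"},
]
num_columns, shuffle_metadata = 20, False  # stand-ins for args.num_columns / args.shuffle_metadata

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [
        executor.submit(
            StubQueryGenerator(row["db_name"]).generate_query,
            question=row["question"],
            columns_to_keep=num_columns,
            shuffle=shuffle_metadata,
        )
        for row in rows
    ]
    for fut in as_completed(futures):
        print(fut.result())
```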
2 changes: 2 additions & 0 deletions eval/api_runner.py
@@ -115,6 +115,8 @@ def run_api_eval(args):
row["prev_invalid_sql"],
row["prev_error_msg"],
public_data,
args.num_columns,
args.shuffle_metadata,
),
axis=1,
)
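The `df.apply`-based runners (api, gemini, hf, llama_cpp, mlx, mistral, vllm) thread the same two values through as the last two positional arguments of their module-level `generate_prompt`. A rough sketch of that pattern with a stub `generate_prompt`; the argument order mirrors the diff, but the stub body and the sample rows are illustrative only:

```python
import pandas as pd

def generate_prompt(prompt_file, question, db_name, public_data=True,
                    num_columns=20, shuffle=False):
    # Stub: the real runners load the prompt template and prune the schema here.
    return f"[{db_name}] {question} (cols<={num_columns}, shuffle={shuffle})"

df = pd.DataFrame([
    {"question": "How many users signed up last week?", "db_name": "broker"},
    {"question": "What is the total order value?", "db_name": "restaurant"},
])

args_num_columns, args_shuffle_metadata = 20, True  # stand-ins for the parsed CLI args
df["prompt"] = df.apply(
    lambda row: generate_prompt(
        "prompts/prompt.md",
        row["question"],
        row["db_name"],
        True,
        args_num_columns,
        args_shuffle_metadata,
    ),
    axis=1,
)
print(df["prompt"].tolist())
```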
24 changes: 13 additions & 11 deletions eval/gemini_runner.py
@@ -1,20 +1,17 @@
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional
import os
from time import time

from eval.eval import compare_query_results
import pandas as pd
from tqdm import tqdm
from vertexai.preview.generative_models import GenerativeModel

from eval.eval import compare_query_results
from utils.creds import db_creds_all
from utils.pruning import prune_metadata_str
from utils.questions import prepare_questions_df
from utils.creds import db_creds_all
from tqdm import tqdm
from time import time
import requests
from utils.reporting import upload_results

import vertexai
from vertexai.preview.generative_models import GenerativeModel, Part


def multiturn_generate_content(model_name="gemini-pro"):
config = {"max_output_tokens": 600, "temperature": 0, "top_p": 1}
@@ -34,6 +31,8 @@ def generate_prompt(
prev_invalid_sql="",
prev_error_msg="",
public_data=True,
num_columns_to_keep=20,
shuffle=True,
):
if "gemini" not in prompt_file:
raise ValueError("Invalid prompt file. Please use prompt_gemini.md")
@@ -44,7 +43,7 @@

if table_metadata_string == "":
pruned_metadata_str = prune_metadata_str(
question_instructions, db_name, public_data
question_instructions, db_name, public_data, num_columns_to_keep, shuffle
)
else:
pruned_metadata_str = table_metadata_string
@@ -143,6 +142,8 @@ def run_gemini_eval(args):
row["prev_invalid_sql"],
row["prev_error_msg"],
public_data,
args.num_columns,
args.shuffle_metadata,
),
axis=1,
)
@@ -151,6 +152,7 @@
total_correct = 0
output_rows = []

print(f"Running evaluation using {model_name}...")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for row in df.to_dict("records"):
2 changes: 2 additions & 0 deletions eval/hf_runner.py
@@ -146,6 +146,8 @@ def run_hf_eval(args):
row["prev_invalid_sql"],
row["prev_error_msg"],
public_data,
args.num_columns,
args.shuffle_metadata,
),
axis=1,
)
2 changes: 2 additions & 0 deletions eval/llama_cpp_runner.py
@@ -104,6 +104,8 @@ def run_llama_cpp_eval(args):
row["prev_invalid_sql"],
row["prev_error_msg"],
public_data,
args.num_columns,
args.shuffle_metadata,
),
axis=1,
)
37 changes: 23 additions & 14 deletions eval/mistral_runner.py
@@ -1,16 +1,17 @@
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
import os
from time import time
from typing import Optional

from eval.eval import compare_query_results
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
import pandas as pd
from tqdm import tqdm

from eval.eval import compare_query_results
from utils.creds import db_creds_all
from utils.pruning import prune_metadata_str
from utils.questions import prepare_questions_df
from utils.creds import db_creds_all
from tqdm import tqdm
from time import time
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from utils.reporting import upload_results

api_key = os.environ.get("MISTRAL_API_KEY")
@@ -29,6 +30,8 @@ def generate_prompt(
prev_invalid_sql="",
prev_error_msg="",
public_data=True,
columns_to_keep=20,
shuffle=True,
):
with open(prompt_file, "r") as f:
prompt = f.read()
@@ -43,7 +46,7 @@

if table_metadata_string == "":
pruned_metadata_str = prune_metadata_str(
question_instructions, db_name, public_data
question_instructions, db_name, public_data, columns_to_keep, shuffle
)
else:
pruned_metadata_str = table_metadata_string
@@ -81,13 +84,17 @@ def process_row(row, model):
end_time = time()
generated_query = chat_response.choices[0].message.content

# replace all backslashes with empty string
generated_query = generated_query.replace("\\", "")
try:
# replace all backslashes with empty string
generated_query = generated_query.replace("\\", "")

generated_query = generated_query.split(";")[0].split("```sql")[-1].strip()
generated_query = [i for i in generated_query.split("```") if i.strip() != ""][
0
] + ";"
generated_query = generated_query.split(";")[0].split("```sql")[-1].strip()
generated_query = [i for i in generated_query.split("```") if i.strip() != ""][
0
] + ";"
except Exception as e:
print(e)
generated_query = chat_response.choices[0].message.content
row["generated_query"] = generated_query
row["latency_seconds"] = end_time - start_time
golden_query = row["query"]
@@ -161,6 +168,8 @@ def run_mistral_eval(args):
row["prev_invalid_sql"],
row["prev_error_msg"],
public_data,
args.num_columns,
args.shuffle_metadata,
),
axis=1,
)
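Aside from the new arguments, the Mistral runner now wraps its SQL post-processing in a `try/except`, so a completion without a well-formed SQL code fence falls back to the raw response instead of crashing the row. A standalone sketch of that extraction logic; the function name and sample strings are invented, but the parsing steps mirror the diff above:

```python
def extract_sql(completion: str) -> str:
    """Best-effort extraction of a single SQL statement from a model response."""
    try:
        text = completion.replace("\\", "")
        # keep everything up to the first semicolon, after a ```sql fence if present
        text = text.split(";")[0].split("```sql")[-1].strip()
        # drop any remaining fence fragments and re-terminate the statement
        text = [part for part in text.split("```") if part.strip() != ""][0] + ";"
        return text
    except Exception as e:
        print(e)
        return completion  # fall back to the raw response

print(extract_sql("```sql\nSELECT COUNT(*) FROM users;\n```"))  # SELECT COUNT(*) FROM users;
print(extract_sql("```"))  # raises internally, prints the error, returns the raw response
```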
2 changes: 2 additions & 0 deletions eval/mlx_runner.py
@@ -97,6 +97,8 @@ def run_mlx_eval(args):
row["prev_invalid_sql"],
row["prev_error_msg"],
public_data,
args.num_columns,
args.shuffle_metadata,
),
axis=1,
)
8 changes: 2 additions & 6 deletions eval/openai_runner.py
@@ -41,12 +41,6 @@ def run_openai_eval(args):
timeout=args.timeout_gen,
use_public_data=not args.use_private_data,
verbose=args.verbose,
instructions=row["instructions"],
k_shot_prompt=row["k_shot_prompt"],
glossary=row["glossary"],
table_metadata_string=row["table_metadata_string"],
prev_invalid_sql=row["prev_invalid_sql"],
prev_error_msg=row["prev_error_msg"],
)

generated_query_fut = executor.submit(
@@ -58,6 +52,8 @@
table_metadata_string=row["table_metadata_string"],
prev_invalid_sql=row["prev_invalid_sql"],
prev_error_msg=row["prev_error_msg"],
columns_to_keep=args.num_columns,
shuffle=args.shuffle_metadata,
)
futures.append(generated_query_fut)

2 changes: 2 additions & 0 deletions eval/vllm_runner.py
@@ -79,6 +79,8 @@ def run_vllm_eval(args):
row["prev_invalid_sql"],
row["prev_error_msg"],
public_data,
args.num_columns,
args.shuffle_metadata,
),
axis=1,
)
11 changes: 8 additions & 3 deletions main.py
@@ -3,18 +3,23 @@

if __name__ == "__main__":
parser = argparse.ArgumentParser()
# data-related parameters
parser.add_argument("-q", "--questions_file", type=str)
parser.add_argument("-n", "--num_questions", type=int, default=None)
parser.add_argument("-db", "--db_type", type=str, required=True)
parser.add_argument("-d", "--use_private_data", action="store_true")
# model-related parameters
parser.add_argument("-g", "--model_type", type=str, required=True)
parser.add_argument("-m", "--model", type=str)
parser.add_argument("-a", "--adapter", type=str)
parser.add_argument("--api_url", type=str)
parser.add_argument("-b", "--num_beams", type=int, default=4)
# take in a list of prompt files
# inference-technique-related parameters
parser.add_argument("-f", "--prompt_file", nargs="+", type=str, required=True)
parser.add_argument("-d", "--use_private_data", action="store_true")
parser.add_argument("-b", "--num_beams", type=int, default=4)
parser.add_argument("-c", "--num_columns", type=int, default=20)
parser.add_argument("-s", "--shuffle_metadata", action="store_true")
parser.add_argument("-k", "--k_shot", action="store_true")
# execution-related parameters
parser.add_argument("-o", "--output_file", nargs="+", type=str, required=True)
parser.add_argument("-p", "--parallel_threads", type=int, default=5)
parser.add_argument("-t", "--timeout_gen", type=float, default=30.0)
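`main.py` now groups the arguments by purpose and adds the two new flags under the inference-technique parameters. For reference, a stripped-down, runnable version of just the new additions; the help strings are added here for illustration and are not part of the commit:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--num_columns", type=int, default=20,
                    help="max number of columns to keep per table when pruning metadata (illustrative help text)")
parser.add_argument("-s", "--shuffle_metadata", action="store_true",
                    help="shuffle table/column order in the prompt (illustrative help text)")

# Example: `-c 10 -s` overrides the default column cap and enables shuffling
args = parser.parse_args(["-c", "10", "-s"])
print(args.num_columns, args.shuffle_metadata)  # 10 True
```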
8 changes: 7 additions & 1 deletion query_generators/anthropic.py
@@ -86,6 +86,8 @@ def generate_query(
table_metadata_string: str,
prev_invalid_sql: str,
prev_error_msg: str,
columns_to_keep: int,
shuffle: bool,
) -> dict:
start_time = time.time()
self.err = ""
Expand All @@ -97,7 +99,11 @@ def generate_query(
question_instructions = question + " " + instructions
if table_metadata_string == "":
pruned_metadata_str = prune_metadata_str(
question_instructions, self.db_name, self.use_public_data
question_instructions,
self.db_name,
self.use_public_data,
columns_to_keep,
shuffle,
)
else:
pruned_metadata_str = table_metadata_string