Add support for Gemini models #87

Merged · 3 commits · Feb 19, 2024
13 changes: 13 additions & 0 deletions README.md
@@ -251,6 +251,19 @@ python -W ignore main.py \
-m mlx-community/defog-sqlcoder-7b-2
```

### Gemini

```bash
python -W ignore main.py \
-db postgres \
-o "results/results.csv" \
-g gemini \
-f "prompts/prompt.md" \
-m gemini-pro \
-p 1 \
-n 5
```
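
Note that the Gemini runner goes through Google's Vertex AI SDK, so Google Cloud credentials (e.g. Application Default Credentials and a default project) must be configured in your environment before running this command.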

### Mistral
```bash
python -W ignore main.py \
191 changes: 191 additions & 0 deletions eval/gemini_runner.py
@@ -0,0 +1,191 @@
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from time import time

import pandas as pd
from tqdm import tqdm
import vertexai
from vertexai.preview.generative_models import GenerativeModel

from eval.eval import compare_query_results
from utils.creds import db_creds_all
from utils.pruning import prune_metadata_str
from utils.questions import prepare_questions_df
from utils.reporting import upload_results


def multiturn_generate_content(model_name="gemini-pro"):
    config = {"max_output_tokens": 600, "temperature": 0, "top_p": 1}
    model = GenerativeModel(model_name, generation_config=config)
    chat = model.start_chat()
    return chat
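
# Note: vertexai.init() is never called explicitly here, so the SDK falls back
# to Application Default Credentials and the project/region configured in the
# environment. A hypothetical explicit setup (placeholder values) would be:
#
#   vertexai.init(project="my-gcp-project", location="us-central1")
#   chat = multiturn_generate_content("gemini-pro")
#   print(chat.send_message("SELECT 1").text)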


def generate_prompt(
    prompt_file,
    question,
    db_name,
    instructions="",
    k_shot_prompt="",
    glossary="",
    table_metadata_string="",
    public_data=True,
):
    with open(prompt_file, "r") as f:
        prompt = f.read()
    question_instructions = question + " " + instructions

    if table_metadata_string == "":
        pruned_metadata_str = prune_metadata_str(
            question_instructions, db_name, public_data
        )
    else:
        pruned_metadata_str = table_metadata_string

    # NOTE: the inline template below overrides the contents of prompt_file,
    # and instructions, k_shot_prompt, and glossary are accepted but not yet
    # interpolated into the Gemini prompt
    prompt = """Generate a PostgreSQL query to answer the following question: `{user_question}`

The query will run on a database with the following schema:
{table_metadata_string}

Please return only the SQL query in your response, nothing else."""

    prompt = prompt.format(
        user_question=question,
        table_metadata_string=pruned_metadata_str,
    )
    return prompt
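
# For illustration (hypothetical question and db_name), a call like
#   generate_prompt("prompts/prompt.md", "How many users signed up last month?", "academic")
# returns the inline template above with the question and the pruned schema
# for that database substituted in.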


def process_row(row, model_name):
    start_time = time()
    chat = multiturn_generate_content(model_name=model_name)
    response = chat.send_message(row["prompt"])

    end_time = time()
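    # pull the SQL out of a ```sql ... ``` fenced block if one is present;
    # if the model returned bare SQL with no fences, both splits are no-ops
    # and the full response text is kept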
    generated_query = response.text.split("```sql")[-1].split("```")[0].strip()
    print(generated_query)

    row["generated_query"] = generated_query
    row["latency_seconds"] = end_time - start_time
    golden_query = row["query"]
    db_name = row["db_name"]
    db_type = row["db_type"]
    question = row["question"]
    query_category = row["query_category"]
    exact_match = correct = 0
    row["error_db_exec"] = 0  # default so the column exists even when nothing fails
    try:
        exact_match, correct = compare_query_results(
            query_gold=golden_query,
            query_gen=generated_query,
            db_name=db_name,
            db_type=db_type,
            db_creds=db_creds_all[row["db_type"]],
            question=question,
            query_category=query_category,
        )
        row["exact_match"] = int(exact_match)
        row["correct"] = int(correct)
        row["error_msg"] = ""
    except Exception as e:
        # also record a score of 0 so downstream aggregation never hits a
        # missing "correct"/"exact_match" key for failed rows
        row["exact_match"] = 0
        row["correct"] = 0
        row["error_db_exec"] = 1
        row["error_msg"] = f"QUERY EXECUTION ERROR: {e}"

    return row


def run_gemini_eval(args):
    # get params from args
    questions_file = args.questions_file
    prompt_file_list = args.prompt_file
    num_questions = args.num_questions
    public_data = not args.use_private_data
    model_name = args.model
    output_file_list = args.output_file
    k_shot = args.k_shot
    max_workers = args.parallel_threads
    db_type = args.db_type

    # get questions
    print("Preparing questions...")
    print(
        f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}"
    )
    df = prepare_questions_df(questions_file, db_type, num_questions, k_shot)

    for prompt_file, output_file in zip(prompt_file_list, output_file_list):
        # create a prompt for each question
        df["prompt"] = df[
            [
                "question",
                "db_name",
                "instructions",
                "k_shot_prompt",
                "glossary",
                "table_metadata_string",
            ]
        ].apply(
            lambda row: generate_prompt(
                prompt_file,
                row["question"],
                row["db_name"],
                row["instructions"],
                row["k_shot_prompt"],
                row["glossary"],
                row["table_metadata_string"],
                public_data,
            ),
            axis=1,
        )

        total_tried = 0
        total_correct = 0
        output_rows = []

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = []
            for row in df.to_dict("records"):
                futures.append(executor.submit(process_row, row, model_name))

            with tqdm(as_completed(futures), total=len(futures)) as pbar:
                for f in pbar:
                    row = f.result()
                    output_rows.append(row)
                    if row["correct"]:
                        total_correct += 1
                    total_tried += 1
                    # iterating over the tqdm wrapper already advances the bar,
                    # so no separate pbar.update(1) is needed
                    pbar.set_description(
                        f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)"
                    )

        output_df = pd.DataFrame(output_rows)
        del output_df["prompt"]
        print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean())
        output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
        # get directory of output_file and create it if it doesn't exist
        output_dir = os.path.dirname(output_file)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        try:
            output_df.to_csv(output_file, index=False, float_format="%.2f")
        except Exception:
            # fall back to pickle if the results can't be written as CSV
            output_df.to_pickle(output_file)

        results = output_df.to_dict("records")
        # upload results
        with open(prompt_file, "r") as f:
            prompt = f.read()
        if args.upload_url is not None:
            upload_results(
                results=results,
                url=args.upload_url,
                runner_type="api_runner",
                prompt=prompt,
                args=args,
            )
4 changes: 4 additions & 0 deletions main.py
@@ -84,6 +84,10 @@
    from eval.mlx_runner import run_mlx_eval

    run_mlx_eval(args)
elif args.model_type == "gemini":
    from eval.gemini_runner import run_gemini_eval

    run_gemini_eval(args)
elif args.model_type == "mistral":
    from eval.mistral_runner import run_mistral_eval