Gemini 2 #224

Merged 1 commit on Dec 13, 2024

15 changes: 7 additions & 8 deletions README.md
@@ -363,18 +363,17 @@ python -W ignore main.py \

### Gemini

Before running this, you must create an account with [Google AI](https://ai.google.dev/) and set your credentials with `export GOOGLE_APPLICATION_CREDENTIALS=</path/to/service_account.json>`. Then, install these packages with `pip install vertexai google-cloud-aiplatform`.
Before running this, set your API key with `export GEMINI_API_KEY=<your_api_key>`. Then, install the required package with `pip install google-generativeai`.

```bash
python -W ignore main.py \
python main.py \
-db postgres \
-q "data/questions_gen_postgres.csv" \
-o "results/gemini_pro.csv" \
-q "data/questions_gen_postgres.csv" "data/instruct_basic_postgres.csv" "data/instruct_advanced_postgres.csv" \
-o "results/gemini_flash_basic.csv" "results/gemini_flash_basic.csv" "results/gemini_flash_advanced.csv" \
-g gemini \
-f "prompts/prompt_gemini.md" \
-m gemini-pro \
-p 1 \
-n 5
-f "prompts/prompt_gemini.md" "prompts/prompt_gemini.md" "prompts/prompt_gemini.md" \
-m gemini-2.0-flash-exp \
-p 10
```
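
If you want to sanity-check the API key and package before kicking off a full run, a minimal sketch along these lines should work (it mirrors the client setup used in `eval/gemini_runner.py`; the prompt text is only an illustration):

```python
import os

import google.generativeai as genai

# Configure the client from the same environment variable the eval runner reads
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

# Same model and generation settings as eval/gemini_runner.py
model = genai.GenerativeModel(
    model_name="gemini-2.0-flash-exp",
    generation_config={"max_output_tokens": 600, "temperature": 0, "top_p": 1.0},
)
chat = model.start_chat()
response = chat.send_message("Write a SQL query that selects the number 1.")
print(response.text)
```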

### Mistral
131 changes: 100 additions & 31 deletions eval/gemini_runner.py
@@ -3,27 +3,44 @@
from time import time

import pandas as pd
import sqlparse
from tqdm import tqdm
from vertexai.preview.generative_models import GenerativeModel

from eval.eval import compare_query_results
from utils.creds import db_creds_all
from utils.dialects import convert_postgres_ddl_to_dialect
from utils.gen_prompt import to_prompt_schema
from utils.pruning import prune_metadata_str
from utils.questions import prepare_questions_df
from utils.reporting import upload_results


def multiturn_generate_content(model_name="gemini-pro"):
config = {"max_output_tokens": 600, "temperature": 0, "top_p": 1}
model = GenerativeModel(model_name, generation_config=config)
chat = model.start_chat()
return chat
def setup_genai(api_key=None):
"""Initialize the Google GenAI client"""
if api_key is None:
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
raise ValueError("GEMINI_API_KEY environment variable must be set")
import google.generativeai as genai

genai.configure(api_key=api_key)
return genai


def get_chat_model(genai, model_name="gemini-pro"):
"""Get a chat model instance with configured parameters"""
generation_config = {"max_output_tokens": 600, "temperature": 0, "top_p": 1.0}
model = genai.GenerativeModel(
model_name=model_name, generation_config=generation_config
)
return model.start_chat()
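
As a rough usage sketch (not part of the diff), these two helpers compose the way `process_row` uses them below: configure the client once, then open a fresh chat per question. The model name is taken from the README example and the question is made up:

```python
genai = setup_genai()  # reads GEMINI_API_KEY from the environment
chat = get_chat_model(genai, model_name="gemini-2.0-flash-exp")
response = chat.send_message("Generate a SQL query that counts the rows in the users table.")
print(response.text)  # the runner then extracts the SQL between ```sql fences
```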


def generate_prompt(
prompt_file,
question,
db_name,
db_type,
instructions="",
k_shot_prompt="",
glossary="",
@@ -37,20 +54,62 @@ def generate_prompt(
if "gemini" not in prompt_file:
raise ValueError("Invalid prompt file. Please use prompt_gemini.md")

if public_data:
from defog_data.metadata import dbs
import defog_data.supplementary as sup
else:
# raise Exception("Replace this with your private data import")
from defog_data_private.metadata import dbs
with open(prompt_file, "r") as f:
prompt = f.read()
question_instructions = question + " " + instructions

if table_metadata_string == "":
pruned_metadata_ddl, join_str = prune_metadata_str(
question_instructions, db_name, public_data, num_columns_to_keep, shuffle
)
pruned_metadata_str = pruned_metadata_ddl + join_str
if num_columns_to_keep > 0:
pruned_metadata_ddl, join_str = prune_metadata_str(
question_instructions,
db_name,
public_data,
num_columns_to_keep,
shuffle,
)
pruned_metadata_ddl = convert_postgres_ddl_to_dialect(
postgres_ddl=pruned_metadata_ddl,
to_dialect=db_type,
db_name=db_name,
)
pruned_metadata_str = pruned_metadata_ddl + join_str
elif num_columns_to_keep == 0:
md = dbs[db_name]["table_metadata"]
pruned_metadata_str = to_prompt_schema(md, shuffle)
pruned_metadata_str = convert_postgres_ddl_to_dialect(
postgres_ddl=pruned_metadata_str,
to_dialect=db_type,
db_name=db_name,
)
column_join = sup.columns_join.get(db_name, {})
# get join_str from column_join
join_list = []
for values in column_join.values():
col_1, col_2 = values[0]
# add to join_list
join_str = f"{col_1} can be joined with {col_2}"
if join_str not in join_list:
join_list.append(join_str)
if len(join_list) > 0:
join_str = "\nHere is a list of joinable columns:\n" + "\n".join(
join_list
)
else:
join_str = ""
pruned_metadata_str = pruned_metadata_str + join_str
else:
raise ValueError("columns_to_keep must be >= 0")
else:
pruned_metadata_str = table_metadata_string

prompt = prompt.format(
user_question=question,
db_type=db_type,
instructions=instructions,
table_metadata_string=pruned_metadata_str,
k_shot_prompt=k_shot_prompt,
@@ -61,14 +120,22 @@ def process_row(row, model_name, args):
return prompt
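
To make the substitution step concrete, here is a simplified sketch of what the `prompt.format(...)` call above produces; the template is an abridged version of `prompts/prompt_gemini.md` and the schema string is made up (the real one comes from `prune_metadata_str` or `to_prompt_schema`):

```python
# Abridged template; the real prompts/prompt_gemini.md also has {k_shot_prompt} etc.
template = (
    "Your task is to convert a text question to a {db_type} query, "
    "given a database schema.\n\n"
    "Generate a SQL query that answers the question `{user_question}`.\n"
    "{instructions}\n"
    "This query will run on a database whose schema is represented in this SQL DDL:\n"
    "{table_metadata_string}\n\n"
    "Return the SQL query that answers the question `{user_question}`\n"
    "```sql\n"
)

prompt = template.format(
    user_question="How many users signed up in 2024?",
    db_type="postgres",
    instructions="",
    table_metadata_string="CREATE TABLE users (id int, created_at timestamp);",
)
print(prompt)
```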


def process_row(row, model_name, args):
def process_row(row, genai, model_name, args):
start_time = time()
chat = multiturn_generate_content(model_name=model_name)
chat = get_chat_model(genai, model_name=model_name)
response = chat.send_message(row["prompt"])

end_time = time()
generated_query = response.text.split("```sql")[-1].split("```")[0].strip()

generated_query = response.text.split("```sql", 1)[-1].split("```", 1)[0].strip()
try:
generated_query = sqlparse.format(
generated_query,
strip_comments=True,
strip_whitespace=True,
keyword_case="upper",
)
except:
pass
row["generated_query"] = generated_query
row["latency_seconds"] = end_time - start_time
golden_query = row["query"]
@@ -100,6 +167,9 @@ def process_row(row, model_name, args):


def run_gemini_eval(args):
# Initialize Google GenAI
genai = setup_genai()

# get params from args
questions_file_list = args.questions_file
prompt_file_list = args.prompt_file
@@ -116,15 +186,14 @@
questions_file_list, prompt_file_list, output_file_list
):
print(f"Using prompt file {prompt_file}")
# get questions
print("Preparing questions...")
print(
f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}"
)
df = prepare_questions_df(
questions_file, db_type, num_questions, k_shot, cot_table_alias
)
# create a prompt for each question

df["prompt"] = df.apply(
lambda row: generate_prompt(
prompt_file,
@@ -137,8 +206,6 @@
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["cot_instructions"],
row["cot_pregen"],
public_data,
args.num_columns,
args.shuffle_metadata,
@@ -154,7 +221,9 @@
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for row in df.to_dict("records"):
futures.append(executor.submit(process_row, row, model_name, args))
futures.append(
executor.submit(process_row, row, genai, model_name, args)
)

with tqdm(as_completed(futures), total=len(futures)) as pbar:
for f in pbar:
@@ -172,7 +241,7 @@
del output_df["prompt"]
print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean())
output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
# get directory of output_file and create if not exist

output_dir = os.path.dirname(output_file)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
@@ -182,14 +251,14 @@
output_df.to_pickle(output_file)

results = output_df.to_dict("records")
# upload results
with open(prompt_file, "r") as f:
prompt = f.read()

if args.upload_url is not None:
upload_results(
results=results,
url=args.upload_url,
runner_type="api_runner",
prompt=prompt,
args=args,
)
with open(prompt_file, "r") as f:
prompt = f.read()
upload_results(
results=results,
url=args.upload_url,
runner_type="api_runner",
prompt=prompt,
args=args,
)
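
For reference, the post-processing that `process_row` now applies to a model reply (fence extraction followed by `sqlparse.format`) behaves roughly like this sketch; the reply text is made up:

```python
import sqlparse

# Hypothetical Gemini reply: prose wrapped around a fenced SQL block
response_text = (
    "Here is the query you asked for:\n"
    "```sql\n"
    "-- count the users\n"
    "select count(*) from users;\n"
    "```\n"
    "Let me know if you need anything else."
)

# Same extraction as process_row: keep only the text between ```sql and the next ```
generated_query = response_text.split("```sql", 1)[-1].split("```", 1)[0].strip()

# Same normalization: drop comments, collapse whitespace, upper-case keywords
generated_query = sqlparse.format(
    generated_query,
    strip_comments=True,
    strip_whitespace=True,
    keyword_case="upper",
)
print(generated_query)  # a single cleaned-up SELECT statement
```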
9 changes: 6 additions & 3 deletions prompts/prompt_gemini.md
@@ -1,6 +1,9 @@
Generate a PostgreSQL query to answer the following question: `{user_question}`
Your task is to convert a text question to a {db_type} query, given a database schema.

The query will run on a database with the following schema:
Generate a SQL query that answers the question `{user_question}`.
{instructions}
This query will run on a database whose schema is represented in this SQL DDL:
{table_metadata_string}

Please return only the SQL query in your response, nothing else.
Return the SQL query that answers the question `{user_question}`
```sql
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,6 +1,7 @@
anthropic
argparse
func_timeout
google-generativeai
mistralai
mysql-connector-python
numpy==2.1.2