Skip to content

Commit

Permalink
Gemini 2
Browse files Browse the repository at this point in the history
  • Loading branch information
wongjingping committed Dec 13, 2024
1 parent 0262d0e commit 00457a5
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 42 deletions.
15 changes: 7 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -363,18 +363,17 @@ python -W ignore main.py \

### Gemini

Before running this, you must create an account with [Google AI](https://ai.google.dev/) and set your credentials with `export GOOGLE_APPLICATION_CREDENTIALS=</path/to/service_account.json>`. Then, install these packages with `pip install vertexai google-cloud-aiplatform`.
Before running this, you need to set your credentials with `export GEMINI_API_KEY=<your_api_key>`. Then, install these packages with `pip install google-generativeai`.

```bash
python -W ignore main.py \
python main.py \
-db postgres \
-q "data/questions_gen_postgres.csv" \
-o "results/gemini_pro.csv" \
-q "data/questions_gen_postgres.csv" "data/instruct_basic_postgres.csv" "data/instruct_advanced_postgres.csv" \
-o "results/gemini_flash.csv" "results/gemini_flash_basic.csv" "results/gemini_flash_advanced.csv" \
-g gemini \
-f "prompts/prompt_gemini.md" \
-m gemini-pro \
-p 1 \
-n 5
-f "prompts/prompt_gemini.md" "prompts/prompt_gemini.md" "prompts/prompt_gemini.md" \
-m gemini-2.0-flash-exp \
-p 10
```

### Mistral
Expand Down
126 changes: 95 additions & 31 deletions eval/gemini_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,46 @@
from time import time

import pandas as pd
import sqlparse
from tqdm import tqdm
from vertexai.preview.generative_models import GenerativeModel

from eval.eval import compare_query_results
from utils.creds import db_creds_all
from utils.dialects import convert_postgres_ddl_to_dialect
from utils.gen_prompt import to_prompt_schema
from utils.pruning import prune_metadata_str
from utils.questions import prepare_questions_df
from utils.reporting import upload_results


def multiturn_generate_content(model_name="gemini-pro"):
    """Start a fresh multi-turn chat session for *model_name*.

    The session is configured for deterministic output (temperature 0,
    top_p 1) with responses capped at 600 tokens.
    """
    generation_config = dict(max_output_tokens=600, temperature=0, top_p=1)
    return GenerativeModel(
        model_name, generation_config=generation_config
    ).start_chat()
def setup_genai(api_key=None):
    """Configure and return the ``google.generativeai`` module.

    When *api_key* is ``None``, the key is read from the ``GEMINI_API_KEY``
    environment variable. Raises ``ValueError`` if no key is available.
    """
    key = api_key if api_key is not None else os.getenv("GEMINI_API_KEY")
    if not key:
        raise ValueError("GEMINI_API_KEY environment variable must be set")
    import google.generativeai as genai

    genai.configure(api_key=key)
    return genai


def get_chat_model(genai, model_name="gemini-pro"):
    """Create and return a chat session on *model_name*.

    Generation is capped at 600 output tokens with temperature 0 and
    top_p 1.0, favoring a single deterministic completion.
    """
    config = dict(max_output_tokens=600, temperature=0, top_p=1.0)
    model = genai.GenerativeModel(
        model_name=model_name, generation_config=config
    )
    return model.start_chat()


def generate_prompt(
prompt_file,
question,
db_name,
db_type,
instructions="",
k_shot_prompt="",
glossary="",
Expand All @@ -37,20 +56,62 @@ def generate_prompt(
if "gemini" not in prompt_file:
raise ValueError("Invalid prompt file. Please use prompt_gemini.md")

if public_data:
from defog_data.metadata import dbs
import defog_data.supplementary as sup
else:
# raise Exception("Replace this with your private data import")
from defog_data_private.metadata import dbs
with open(prompt_file, "r") as f:
prompt = f.read()
question_instructions = question + " " + instructions

if table_metadata_string == "":
pruned_metadata_ddl, join_str = prune_metadata_str(
question_instructions, db_name, public_data, num_columns_to_keep, shuffle
)
pruned_metadata_str = pruned_metadata_ddl + join_str
if num_columns_to_keep > 0:
pruned_metadata_ddl, join_str = prune_metadata_str(
question_instructions,
db_name,
public_data,
num_columns_to_keep,
shuffle,
)
pruned_metadata_ddl = convert_postgres_ddl_to_dialect(
postgres_ddl=pruned_metadata_ddl,
to_dialect=db_type,
db_name=db_name,
)
pruned_metadata_str = pruned_metadata_ddl + join_str
elif num_columns_to_keep == 0:
md = dbs[db_name]["table_metadata"]
pruned_metadata_str = to_prompt_schema(md, shuffle)
pruned_metadata_str = convert_postgres_ddl_to_dialect(
postgres_ddl=pruned_metadata_str,
to_dialect=db_type,
db_name=db_name,
)
column_join = sup.columns_join.get(db_name, {})
# get join_str from column_join
join_list = []
for values in column_join.values():
col_1, col_2 = values[0]
# add to join_list
join_str = f"{col_1} can be joined with {col_2}"
if join_str not in join_list:
join_list.append(join_str)
if len(join_list) > 0:
join_str = "\nHere is a list of joinable columns:\n" + "\n".join(
join_list
)
else:
join_str = ""
pruned_metadata_str = pruned_metadata_str + join_str
else:
raise ValueError("columns_to_keep must be >= 0")
else:
pruned_metadata_str = table_metadata_string

prompt = prompt.format(
user_question=question,
db_type=db_type,
instructions=instructions,
table_metadata_string=pruned_metadata_str,
k_shot_prompt=k_shot_prompt,
Expand All @@ -61,14 +122,17 @@ def generate_prompt(
return prompt


def process_row(row, model_name, args):
def process_row(row, genai, model_name, args):
start_time = time()
chat = multiturn_generate_content(model_name=model_name)
chat = get_chat_model(genai, model_name=model_name)
response = chat.send_message(row["prompt"])

end_time = time()
generated_query = response.text.split("```sql")[-1].split("```")[0].strip()

generated_query = response.text.split("```sql", 1)[-1].split("```", 1)[0].strip()
try:
generated_query = sqlparse.format(generated_query, strip_comments=True, strip_whitespace=True, keyword_case="upper")
except:
pass
row["generated_query"] = generated_query
row["latency_seconds"] = end_time - start_time
golden_query = row["query"]
Expand Down Expand Up @@ -100,6 +164,9 @@ def process_row(row, model_name, args):


def run_gemini_eval(args):
# Initialize Google GenAI
genai = setup_genai()

# get params from args
questions_file_list = args.questions_file
prompt_file_list = args.prompt_file
Expand All @@ -116,15 +183,14 @@ def run_gemini_eval(args):
questions_file_list, prompt_file_list, output_file_list
):
print(f"Using prompt file {prompt_file}")
# get questions
print("Preparing questions...")
print(
f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}"
)
df = prepare_questions_df(
questions_file, db_type, num_questions, k_shot, cot_table_alias
)
# create a prompt for each question

df["prompt"] = df.apply(
lambda row: generate_prompt(
prompt_file,
Expand All @@ -137,8 +203,6 @@ def run_gemini_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["cot_instructions"],
row["cot_pregen"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand All @@ -154,7 +218,7 @@ def run_gemini_eval(args):
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for row in df.to_dict("records"):
futures.append(executor.submit(process_row, row, model_name, args))
futures.append(executor.submit(process_row, row, genai, model_name, args))

with tqdm(as_completed(futures), total=len(futures)) as pbar:
for f in pbar:
Expand All @@ -172,7 +236,7 @@ def run_gemini_eval(args):
del output_df["prompt"]
print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean())
output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
# get directory of output_file and create if not exist

output_dir = os.path.dirname(output_file)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
Expand All @@ -182,14 +246,14 @@ def run_gemini_eval(args):
output_df.to_pickle(output_file)

results = output_df.to_dict("records")
# upload results
with open(prompt_file, "r") as f:
prompt = f.read()

if args.upload_url is not None:
upload_results(
results=results,
url=args.upload_url,
runner_type="api_runner",
prompt=prompt,
args=args,
)
with open(prompt_file, "r") as f:
prompt = f.read()
upload_results(
results=results,
url=args.upload_url,
runner_type="api_runner",
prompt=prompt,
args=args,
)
9 changes: 6 additions & 3 deletions prompts/prompt_gemini.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
Generate a PostgreSQL query to answer the following question: `{user_question}`
Your task is to convert a text question to a {db_type} query, given a database schema.

The query will run on a database with the following schema:
Generate a SQL query that answers the question `{user_question}`.
{instructions}
This query will run on a database whose schema is represented in this SQL DDL:
{table_metadata_string}

Please return only the SQL query in your response, nothing else.
Return the SQL query that answers the question `{user_question}`
```sql
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
anthropic
argparse
func_timeout
google-generativeai
mistralai
mysql-connector-python
numpy==2.1.2
Expand Down

0 comments on commit 00457a5

Please sign in to comment.