From 5b640b76db46655cc87bbfc985d46b9334b6917b Mon Sep 17 00:00:00 2001
From: jp
Date: Fri, 13 Dec 2024 14:05:45 +0800
Subject: [PATCH] Gemini 2

---
 README.md                |  15 +++--
 eval/gemini_runner.py    | 131 ++++++++++++++++++++++++++++++---------
 prompts/prompt_gemini.md |   9 ++-
 requirements.txt         |   1 +
 4 files changed, 114 insertions(+), 42 deletions(-)

diff --git a/README.md b/README.md
index cde70d7..9b319d3 100644
--- a/README.md
+++ b/README.md
@@ -363,18 +363,17 @@ python -W ignore main.py \
 
 ### Gemini
 
-Before running this, you must create an account with [Google AI](https://ai.google.dev/) and set your credentials with `export GOOGLE_APPLICATION_CREDENTIALS=`. Then, install these packages with `pip install vertexai google-cloud-aiplatform`.
+Before running this, you need to set your API key with `export GEMINI_API_KEY=`. Then, install the required package with `pip install google-generativeai`.
 
 ```bash
-python -W ignore main.py \
+python main.py \
   -db postgres \
-  -q "data/questions_gen_postgres.csv" \
-  -o "results/gemini_pro.csv" \
+  -q "data/questions_gen_postgres.csv" "data/instruct_basic_postgres.csv" "data/instruct_advanced_postgres.csv" \
+  -o "results/gemini_flash.csv" "results/gemini_flash_basic.csv" "results/gemini_flash_advanced.csv" \
   -g gemini \
-  -f "prompts/prompt_gemini.md" \
-  -m gemini-pro \
-  -p 1 \
-  -n 5
+  -f "prompts/prompt_gemini.md" "prompts/prompt_gemini.md" "prompts/prompt_gemini.md" \
+  -m gemini-2.0-flash-exp \
+  -p 10
 ```
 
 ### Mistral
diff --git a/eval/gemini_runner.py b/eval/gemini_runner.py
index 031e73a..af46fb4 100644
--- a/eval/gemini_runner.py
+++ b/eval/gemini_runner.py
@@ -3,27 +3,44 @@
 from time import time
 
 import pandas as pd
+import sqlparse
 from tqdm import tqdm
-from vertexai.preview.generative_models import GenerativeModel
 
 from eval.eval import compare_query_results
 from utils.creds import db_creds_all
+from utils.dialects import convert_postgres_ddl_to_dialect
+from utils.gen_prompt import to_prompt_schema
 from utils.pruning import prune_metadata_str
 from utils.questions import prepare_questions_df
 from utils.reporting import upload_results
 
 
-def multiturn_generate_content(model_name="gemini-pro"):
-    config = {"max_output_tokens": 600, "temperature": 0, "top_p": 1}
-    model = GenerativeModel(model_name, generation_config=config)
-    chat = model.start_chat()
-    return chat
+def setup_genai(api_key=None):
+    """Initialize the Google GenAI client"""
+    if api_key is None:
+        api_key = os.getenv("GEMINI_API_KEY")
+    if not api_key:
+        raise ValueError("GEMINI_API_KEY environment variable must be set")
+    import google.generativeai as genai
+
+    genai.configure(api_key=api_key)
+    return genai
+
+
+def get_chat_model(genai, model_name="gemini-pro"):
+    """Get a chat model instance with configured parameters"""
+    generation_config = {"max_output_tokens": 600, "temperature": 0, "top_p": 1.0}
+    model = genai.GenerativeModel(
+        model_name=model_name, generation_config=generation_config
+    )
+    return model.start_chat()
 
 
 def generate_prompt(
     prompt_file,
     question,
     db_name,
+    db_type,
     instructions="",
     k_shot_prompt="",
     glossary="",
@@ -37,20 +54,62 @@
     if "gemini" not in prompt_file:
         raise ValueError("Invalid prompt file. Please use prompt_gemini.md")
+    if public_data:
+        from defog_data.metadata import dbs
+        import defog_data.supplementary as sup
+    else:
+        # raise Exception("Replace this with your private data import")
+        from defog_data_private.metadata import dbs
     with open(prompt_file, "r") as f:
         prompt = f.read()
     question_instructions = question + " " + instructions
 
     if table_metadata_string == "":
-        pruned_metadata_ddl, join_str = prune_metadata_str(
-            question_instructions, db_name, public_data, num_columns_to_keep, shuffle
-        )
-        pruned_metadata_str = pruned_metadata_ddl + join_str
+        if num_columns_to_keep > 0:
+            pruned_metadata_ddl, join_str = prune_metadata_str(
+                question_instructions,
+                db_name,
+                public_data,
+                num_columns_to_keep,
+                shuffle,
+            )
+            pruned_metadata_ddl = convert_postgres_ddl_to_dialect(
+                postgres_ddl=pruned_metadata_ddl,
+                to_dialect=db_type,
+                db_name=db_name,
+            )
+            pruned_metadata_str = pruned_metadata_ddl + join_str
+        elif num_columns_to_keep == 0:
+            md = dbs[db_name]["table_metadata"]
+            pruned_metadata_str = to_prompt_schema(md, shuffle)
+            pruned_metadata_str = convert_postgres_ddl_to_dialect(
+                postgres_ddl=pruned_metadata_str,
+                to_dialect=db_type,
+                db_name=db_name,
+            )
+            column_join = sup.columns_join.get(db_name, {})
+            # get join_str from column_join
+            join_list = []
+            for values in column_join.values():
+                col_1, col_2 = values[0]
+                # add to join_list
+                join_str = f"{col_1} can be joined with {col_2}"
+                if join_str not in join_list:
+                    join_list.append(join_str)
+            if len(join_list) > 0:
+                join_str = "\nHere is a list of joinable columns:\n" + "\n".join(
+                    join_list
+                )
+            else:
+                join_str = ""
+            pruned_metadata_str = pruned_metadata_str + join_str
+        else:
+            raise ValueError("columns_to_keep must be >= 0")
     else:
         pruned_metadata_str = table_metadata_string
-
     prompt = prompt.format(
         user_question=question,
+        db_type=db_type,
         instructions=instructions,
         table_metadata_string=pruned_metadata_str,
         k_shot_prompt=k_shot_prompt,
@@ -61,14 +120,22 @@
     return prompt
 
 
-def process_row(row, model_name, args):
+def process_row(row, genai, model_name, args):
     start_time = time()
-    chat = multiturn_generate_content(model_name=model_name)
+    chat = get_chat_model(genai, model_name=model_name)
     response = chat.send_message(row["prompt"])
     end_time = time()
-    generated_query = response.text.split("```sql")[-1].split("```")[0].strip()
-
+    generated_query = response.text.split("```sql", 1)[-1].split("```", 1)[0].strip()
+    try:
+        generated_query = sqlparse.format(
+            generated_query,
+            strip_comments=True,
+            strip_whitespace=True,
+            keyword_case="upper",
+        )
+    except:
+        pass
     row["generated_query"] = generated_query
     row["latency_seconds"] = end_time - start_time
     golden_query = row["query"]
@@ -100,6 +167,9 @@
 
 
 def run_gemini_eval(args):
+    # Initialize Google GenAI
+    genai = setup_genai()
+
     # get params from args
     questions_file_list = args.questions_file
     prompt_file_list = args.prompt_file
@@ -116,7 +186,6 @@ def run_gemini_eval(args):
         questions_file_list, prompt_file_list, output_file_list
     ):
         print(f"Using prompt file {prompt_file}")
-        # get questions
         print("Preparing questions...")
         print(
             f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}"
@@ -124,7 +193,7 @@ def run_gemini_eval(args):
         df = prepare_questions_df(
             questions_file, db_type, num_questions, k_shot, cot_table_alias
         )
-        # create a prompt for each question
+
         df["prompt"] = df.apply(
             lambda row: generate_prompt(
                 prompt_file,
@@ -137,8 +206,6 @@ def run_gemini_eval(args):
                 row["table_metadata_string"],
                 row["prev_invalid_sql"],
                 row["prev_error_msg"],
-                row["cot_instructions"],
-                row["cot_pregen"],
                 public_data,
                 args.num_columns,
                 args.shuffle_metadata,
@@ -154,7 +221,9 @@ def run_gemini_eval(args):
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
             futures = []
             for row in df.to_dict("records"):
-                futures.append(executor.submit(process_row, row, model_name, args))
+                futures.append(
+                    executor.submit(process_row, row, genai, model_name, args)
+                )
 
             with tqdm(as_completed(futures), total=len(futures)) as pbar:
                 for f in pbar:
@@ -172,7 +241,7 @@ def run_gemini_eval(args):
         del output_df["prompt"]
         print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean())
         output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
-        # get directory of output_file and create if not exist
+
         output_dir = os.path.dirname(output_file)
         if not os.path.exists(output_dir):
             os.makedirs(output_dir)
@@ -182,14 +251,14 @@ def run_gemini_eval(args):
         output_df.to_pickle(output_file)
 
         results = output_df.to_dict("records")
-        # upload results
-        with open(prompt_file, "r") as f:
-            prompt = f.read()
+
         if args.upload_url is not None:
-            upload_results(
-                results=results,
-                url=args.upload_url,
-                runner_type="api_runner",
-                prompt=prompt,
-                args=args,
-            )
+            with open(prompt_file, "r") as f:
+                prompt = f.read()
+            upload_results(
+                results=results,
+                url=args.upload_url,
+                runner_type="api_runner",
+                prompt=prompt,
+                args=args,
+            )
diff --git a/prompts/prompt_gemini.md b/prompts/prompt_gemini.md
index ade3cce..bd49df9 100644
--- a/prompts/prompt_gemini.md
+++ b/prompts/prompt_gemini.md
@@ -1,6 +1,9 @@
-Generate a PostgreSQL query to answer the following question: `{user_question}`
+Your task is to convert a text question to a {db_type} query, given a database schema.
 
-The query will run on a database with the following schema:
+Generate a SQL query that answers the question `{user_question}`.
+{instructions}
+This query will run on a database whose schema is represented in this SQL DDL:
 {table_metadata_string}
 
-Please return only the SQL query in your response, nothing else.
+Return the SQL query that answers the question `{user_question}`
+```sql
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 31360ca..8e9afc5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 anthropic
 argparse
 func_timeout
+google-generativeai
 mistralai
 mysql-connector-python
 numpy==2.1.2
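
For reference, below is a minimal, self-contained sketch of the request path this patch wires into `eval/gemini_runner.py`: configure `google-generativeai` from `GEMINI_API_KEY`, open a chat on `gemini-2.0-flash-exp` with the same generation settings as `get_chat_model()`, pull the SQL out of the first ```sql fenced block, and normalize it with `sqlparse` as `process_row()` does. The prompt string here is a placeholder; in the runner it is built by `generate_prompt()`.

```python
import os

import google.generativeai as genai
import sqlparse

# Configure the client from the same environment variable the runner uses.
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

# Same model and generation config as get_chat_model() in the patch.
model = genai.GenerativeModel(
    model_name="gemini-2.0-flash-exp",
    generation_config={"max_output_tokens": 600, "temperature": 0, "top_p": 1.0},
)
chat = model.start_chat()

# Placeholder prompt; the runner fills prompts/prompt_gemini.md via generate_prompt().
prompt = "Your task is to convert a text question to a postgres query, given a database schema. ..."
response = chat.send_message(prompt)

# Extract the SQL from the first ```sql fenced block, then normalize it.
generated_query = response.text.split("```sql", 1)[-1].split("```", 1)[0].strip()
generated_query = sqlparse.format(
    generated_query,
    strip_comments=True,
    strip_whitespace=True,
    keyword_case="upper",
)
print(generated_query)
```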
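The `num_columns_to_keep == 0` branch added to `generate_prompt()` builds a join hint from `defog_data.supplementary.columns_join`. The sketch below runs that same loop against a hypothetical per-database mapping (the table and column names are made up) to show the data shape the code assumes: each value is a list whose first element is a `(column_1, column_2)` pair.

```python
# Hypothetical columns_join entry for one database; in the patch this comes from
# defog_data.supplementary.columns_join.get(db_name, {}).
columns_join = {
    ("orders", "customers"): [("orders.customer_id", "customers.id")],
    ("orders", "products"): [("orders.product_id", "products.id")],
}

# Mirrors the loop in the patch: one hint per joinable column pair, de-duplicated.
join_list = []
for values in columns_join.values():
    col_1, col_2 = values[0]
    join_str = f"{col_1} can be joined with {col_2}"
    if join_str not in join_list:
        join_list.append(join_str)

join_str = (
    "\nHere is a list of joinable columns:\n" + "\n".join(join_list)
    if join_list
    else ""
)
print(join_str)
```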