From 6cc43d83113c877b96f14070b9c66ff7fc890ad0 Mon Sep 17 00:00:00 2001
From: Rishabh Srivastava
Date: Fri, 16 Feb 2024 00:39:32 +0530
Subject: [PATCH 1/2] Add gemini runner

---
 README.md             |  12 +++
 eval/gemini_runner.py | 193 ++++++++++++++++++++++++++++++++++++++++++
 main.py               |   4 +
 3 files changed, 209 insertions(+)
 create mode 100644 eval/gemini_runner.py

diff --git a/README.md b/README.md
index 88ffc27..bd3cb11 100644
--- a/README.md
+++ b/README.md
@@ -251,6 +251,18 @@ python -W ignore main.py \
   -m mlx-community/defog-sqlcoder-7b-2
 ```
 
+### Gemini
+```bash
+python -W ignore main.py \
+  -db postgres \
+  -o "results/results.csv" \
+  -g gemini \
+  -f "prompts/prompt.md" \
+  -m gemini-pro \
+  -p 1 \
+  -n 5
+```
+
 ### CLI Flags
 You can use the following flags in the command line to change the configurations of your evaluation runs.
 | CLI Flags | Description |
diff --git a/eval/gemini_runner.py b/eval/gemini_runner.py
new file mode 100644
index 0000000..1f1672e
--- /dev/null
+++ b/eval/gemini_runner.py
@@ -0,0 +1,193 @@
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Optional
+
+from eval.eval import compare_query_results
+import pandas as pd
+from utils.pruning import prune_metadata_str
+from utils.questions import prepare_questions_df
+from utils.creds import db_creds_all
+from tqdm import tqdm
+from time import time
+import requests
+from utils.reporting import upload_results
+
+import vertexai
+from vertexai.preview.generative_models import GenerativeModel, Part
+
+def multiturn_generate_content(model_name="gemini-pro"):
+    config = {
+        "max_output_tokens": 600,
+        "temperature": 0,
+        "top_p": 1
+    }
+    model = GenerativeModel(model_name, generation_config=config)
+    chat = model.start_chat()
+    return chat
+
+def generate_prompt(
+    prompt_file,
+    question,
+    db_name,
+    instructions="",
+    k_shot_prompt="",
+    glossary="",
+    table_metadata_string="",
+    public_data=True,
+):
+    with open(prompt_file, "r") as f:
+        prompt = f.read()
+    question_instructions = question + " " + instructions
+
+    if table_metadata_string == "":
+        pruned_metadata_str = prune_metadata_str(
+            question_instructions, db_name, public_data
+        )
+    else:
+        pruned_metadata_str = table_metadata_string
+
+    prompt = """Generate a PostgreSQL query to answer the following question: `{user_question}`
+
+The query will run on a database with the following schema:
+{table_metadata_string}
+
+Please return only the SQL query in your response, nothing else."""
+
+    prompt = prompt.format(
+        user_question=question,
+        # instructions=instructions,
+        table_metadata_string=pruned_metadata_str,
+        # k_shot_prompt=k_shot_prompt,
+        # glossary=glossary,
+    )
+    return prompt
+
+
+def process_row(row, model_name):
+    start_time = time()
+    chat = multiturn_generate_content(model_name=model_name)
+    response = chat.send_message(row["prompt"])
+
+    end_time = time()
+    generated_query = response.text.split("```sql")[-1].split("```")[0].strip()
+    print(generated_query)
+
+    row["generated_query"] = generated_query
+    row["latency_seconds"] = end_time - start_time
+    golden_query = row["query"]
+    db_name = row["db_name"]
+    db_type = row["db_type"]
+    question = row["question"]
+    query_category = row["query_category"]
+    exact_match = correct = 0
+
+    try:
+        exact_match, correct = compare_query_results(
+            query_gold=golden_query,
+            query_gen=generated_query,
+            db_name=db_name,
+            db_type=db_type,
+            db_creds=db_creds_all[row["db_type"]],
+            question=question,
+            query_category=query_category,
+        )
+        row["exact_match"] = int(exact_match)
+        row["correct"] = int(correct)
+        row["error_msg"] = ""
+    except Exception as e:
+        row["error_db_exec"] = 1
+        row["error_msg"] = f"QUERY EXECUTION ERROR: {e}"
+
+    return row
+
+
+def run_gemini_eval(args):
+    # get params from args
+    questions_file = args.questions_file
+    prompt_file_list = args.prompt_file
+    num_questions = args.num_questions
+    public_data = not args.use_private_data
+    model_name = args.model
+    output_file_list = args.output_file
+    k_shot = args.k_shot
+    max_workers = args.parallel_threads
+    db_type = args.db_type
+
+    # get questions
+    print("Preparing questions...")
+    print(
+        f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}"
+    )
+    df = prepare_questions_df(questions_file, db_type, num_questions, k_shot)
+
+    for prompt_file, output_file in zip(prompt_file_list, output_file_list):
+        # create a prompt for each question
+        df["prompt"] = df[
+            [
+                "question",
+                "db_name",
+                "instructions",
+                "k_shot_prompt",
+                "glossary",
+                "table_metadata_string",
+            ]
+        ].apply(
+            lambda row: generate_prompt(
+                prompt_file,
+                row["question"],
+                row["db_name"],
+                row["instructions"],
+                row["k_shot_prompt"],
+                row["glossary"],
+                row["table_metadata_string"],
+                public_data,
+            ),
+            axis=1,
+        )
+
+        total_tried = 0
+        total_correct = 0
+        output_rows = []
+
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = []
+            for row in df.to_dict("records"):
+                futures.append(executor.submit(process_row, row, model_name))
+
+            with tqdm(as_completed(futures), total=len(futures)) as pbar:
+                for f in pbar:
+                    row = f.result()
+                    output_rows.append(row)
+                    if row["correct"]:
+                        total_correct += 1
+                    total_tried += 1
+                    pbar.update(1)
+                    pbar.set_description(
+                        f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)"
+                    )
+
+        output_df = pd.DataFrame(output_rows)
+        del output_df["prompt"]
+        print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean())
+        output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
+        # get directory of output_file and create if not exist
+        output_dir = os.path.dirname(output_file)
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+        try:
+            output_df.to_csv(output_file, index=False, float_format="%.2f")
+        except:
+            output_df.to_pickle(output_file)
+
+        results = output_df.to_dict("records")
+        # upload results
+        with open(prompt_file, "r") as f:
+            prompt = f.read()
+        if args.upload_url is not None:
+            upload_results(
+                results=results,
+                url=args.upload_url,
+                runner_type="api_runner",
+                prompt=prompt,
+                args=args,
+            )
diff --git a/main.py b/main.py
index a9c9504..deeedfb 100644
--- a/main.py
+++ b/main.py
@@ -84,6 +84,10 @@
         from eval.mlx_runner import run_mlx_eval
 
         run_mlx_eval(args)
+    elif args.model_type == "gemini":
+        from eval.gemini_runner import run_gemini_eval
+
+        run_gemini_eval(args)
     else:
         raise ValueError(
             f"Invalid model type: {args.model_type}. Model type must be one of: 'oa', 'hf', 'anthropic', 'vllm', 'api'"

From 84e8d554a4a154cc7ce115d8391b1ddf980bedd0 Mon Sep 17 00:00:00 2001
From: Rishabh Srivastava
Date: Fri, 16 Feb 2024 00:43:42 +0530
Subject: [PATCH 2/2] linting

---
 eval/gemini_runner.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/eval/gemini_runner.py b/eval/gemini_runner.py
index 1f1672e..232363c 100644
--- a/eval/gemini_runner.py
+++ b/eval/gemini_runner.py
@@ -15,16 +15,14 @@
 import vertexai
 from vertexai.preview.generative_models import GenerativeModel, Part
 
+
 def multiturn_generate_content(model_name="gemini-pro"):
-    config = {
-        "max_output_tokens": 600,
-        "temperature": 0,
-        "top_p": 1
-    }
+    config = {"max_output_tokens": 600, "temperature": 0, "top_p": 1}
     model = GenerativeModel(model_name, generation_config=config)
     chat = model.start_chat()
    return chat
 
+
 def generate_prompt(
     prompt_file,
     question,
@@ -45,7 +43,7 @@ def generate_prompt(
         )
     else:
         pruned_metadata_str = table_metadata_string
-    
+
     prompt = """Generate a PostgreSQL query to answer the following question: `{user_question}`
 
 The query will run on a database with the following schema:
@@ -67,7 +65,7 @@ def process_row(row, model_name):
     start_time = time()
     chat = multiturn_generate_content(model_name=model_name)
     response = chat.send_message(row["prompt"])
-    
+
     end_time = time()
     generated_query = response.text.split("```sql")[-1].split("```")[0].strip()
     print(generated_query)
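
Usage note: `eval/gemini_runner.py` imports `vertexai` but never calls `vertexai.init()`, so the runner relies on the Vertex AI SDK picking up the GCP project, region, and credentials from the environment. Below is a minimal sketch of that setup; the project ID and region are placeholder values, and Application Default Credentials are assumed to be available. The `GenerativeModel` call mirrors `multiturn_generate_content` from the patch.

```python
import vertexai
from vertexai.preview.generative_models import GenerativeModel

# Placeholder project/region -- substitute your own GCP settings. Without an
# explicit init(), the SDK falls back to Application Default Credentials
# (e.g. from `gcloud auth application-default login`) and the
# GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_REGION environment variables.
vertexai.init(project="your-gcp-project", location="us-central1")

# Same generation settings as multiturn_generate_content() in the patch:
# temperature 0 for deterministic SQL, 600 output tokens is ample for one query.
model = GenerativeModel(
    "gemini-pro",
    generation_config={"max_output_tokens": 600, "temperature": 0, "top_p": 1},
)
chat = model.start_chat()
response = chat.send_message("Generate a PostgreSQL query that counts all users.")
print(response.text)
```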