From 1942947924962d1274cfd5abeb0aba565fb75465 Mon Sep 17 00:00:00 2001 From: wendy-aw <96569828+wendy-aw@users.noreply.github.com> Date: Mon, 25 Mar 2024 09:59:09 +0800 Subject: [PATCH] Add args for follow-on questions (#98) * add args for prev questions * linted --- eval/api_runner.py | 8 ++++++++ eval/gemini_runner.py | 8 ++++++++ eval/hf_runner.py | 8 ++++++++ eval/llama_cpp_runner.py | 8 ++++++++ eval/mistral_runner.py | 8 ++++++++ eval/mlx_runner.py | 8 ++++++++ eval/vllm_runner.py | 8 ++++++++ utils/gen_prompt.py | 21 +++++++++++++++------ utils/questions.py | 20 +++++++++++++++++++- 9 files changed, 90 insertions(+), 7 deletions(-) diff --git a/eval/api_runner.py b/eval/api_runner.py index c74e105..008916c 100644 --- a/eval/api_runner.py +++ b/eval/api_runner.py @@ -102,6 +102,10 @@ def run_api_eval(args): "table_metadata_string", "prev_invalid_sql", "prev_error_msg", + "question_0", + "query_0", + "question_1", + "query_1", ] ].apply( lambda row: generate_prompt( @@ -114,6 +118,10 @@ def run_api_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], + row["question_0"], + row["query_0"], + row["question_1"], + row["query_1"], public_data, args.num_columns, args.shuffle_metadata, diff --git a/eval/gemini_runner.py b/eval/gemini_runner.py index 1c4ca33..f7ce6c6 100644 --- a/eval/gemini_runner.py +++ b/eval/gemini_runner.py @@ -129,6 +129,10 @@ def run_gemini_eval(args): "table_metadata_string", "prev_invalid_sql", "prev_error_msg", + "question_0", + "query_0", + "question_1", + "query_1", ] ].apply( lambda row: generate_prompt( @@ -141,6 +145,10 @@ def run_gemini_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], + row["question_0"], + row["query_0"], + row["question_1"], + row["query_1"], public_data, args.num_columns, args.shuffle_metadata, diff --git a/eval/hf_runner.py b/eval/hf_runner.py index 495c7b6..25891ec 100644 --- a/eval/hf_runner.py +++ b/eval/hf_runner.py @@ -133,6 +133,10 @@ def run_hf_eval(args): "table_metadata_string", "prev_invalid_sql", "prev_error_msg", + "question_0", + "query_0", + "question_1", + "query_1", ] ].apply( lambda row: generate_prompt( @@ -145,6 +149,10 @@ def run_hf_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], + row["question_0"], + row["query_0"], + row["question_1"], + row["query_1"], public_data, args.num_columns, args.shuffle_metadata, diff --git a/eval/llama_cpp_runner.py b/eval/llama_cpp_runner.py index 4995a05..56e1fb6 100644 --- a/eval/llama_cpp_runner.py +++ b/eval/llama_cpp_runner.py @@ -91,6 +91,10 @@ def run_llama_cpp_eval(args): "table_metadata_string", "prev_invalid_sql", "prev_error_msg", + "question_0", + "query_0", + "question_1", + "query_1", ] ].apply( lambda row: generate_prompt( @@ -103,6 +107,10 @@ def run_llama_cpp_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], + row["question_0"], + row["query_0"], + row["question_1"], + row["query_1"], public_data, args.num_columns, args.shuffle_metadata, diff --git a/eval/mistral_runner.py b/eval/mistral_runner.py index bdbab9b..7eb9bf2 100644 --- a/eval/mistral_runner.py +++ b/eval/mistral_runner.py @@ -155,6 +155,10 @@ def run_mistral_eval(args): "table_metadata_string", "prev_invalid_sql", "prev_error_msg", + "question_0", + "query_0", + "question_1", + "query_1", ] ].apply( lambda row: generate_prompt( @@ -167,6 +171,10 @@ def run_mistral_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], + row["question_0"], + row["query_0"], + row["question_1"], + row["query_1"], public_data, args.num_columns, args.shuffle_metadata, diff --git a/eval/mlx_runner.py b/eval/mlx_runner.py index 135a269..9b3a5af 100644 --- a/eval/mlx_runner.py +++ b/eval/mlx_runner.py @@ -84,6 +84,10 @@ def run_mlx_eval(args): "table_metadata_string", "prev_invalid_sql", "prev_error_msg", + "question_0", + "query_0", + "question_1", + "query_1", ] ].apply( lambda row: generate_prompt( @@ -96,6 +100,10 @@ def run_mlx_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], + row["question_0"], + row["query_0"], + row["question_1"], + row["query_1"], public_data, args.num_columns, args.shuffle_metadata, diff --git a/eval/vllm_runner.py b/eval/vllm_runner.py index d5555ac..1425f6c 100644 --- a/eval/vllm_runner.py +++ b/eval/vllm_runner.py @@ -66,6 +66,10 @@ def run_vllm_eval(args): "table_metadata_string", "prev_invalid_sql", "prev_error_msg", + "question_0", + "query_0", + "question_1", + "query_1", ] ].apply( lambda row: generate_prompt( @@ -78,6 +82,10 @@ def run_vllm_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], + row["question_0"], + row["query_0"], + row["question_1"], + row["query_1"], public_data, args.num_columns, args.shuffle_metadata, diff --git a/utils/gen_prompt.py b/utils/gen_prompt.py index 637788f..94aceda 100644 --- a/utils/gen_prompt.py +++ b/utils/gen_prompt.py @@ -11,9 +11,13 @@ def generate_prompt( table_metadata_string="", prev_invalid_sql="", prev_error_msg="", + question_0="", + query_0="", + question_1="", + query_1="", public_data=True, - columns_to_keep=20, - shuffle=True, + columns_to_keep=40, + shuffle_metadata=False, ): with open(prompt_file, "r") as f: prompt = f.read() @@ -21,14 +25,15 @@ def generate_prompt( if table_metadata_string == "": pruned_metadata_str = prune_metadata_str( - question_instructions, db_name, public_data, columns_to_keep, shuffle + question_instructions, + db_name, + public_data, + columns_to_keep, + shuffle_metadata, ) else: pruned_metadata_str = table_metadata_string - if instructions != "": - instructions = "\n\n### Instructions\n" + instructions - prompt = prompt.format( user_question=question, instructions=instructions, @@ -37,5 +42,9 @@ def generate_prompt( glossary=glossary, prev_invalid_sql=prev_invalid_sql, prev_error_msg=prev_error_msg, + question_0=question_0, + query_0=query_0, + question_1=question_1, + query_1=query_1, ) return prompt diff --git a/utils/questions.py b/utils/questions.py index 34e9f0e..1bdd532 100644 --- a/utils/questions.py +++ b/utils/questions.py @@ -29,7 +29,7 @@ def prepare_questions_df( lambda x: x.replace(". ", ".\n") ) question_query_df["instructions"] = question_query_df["instructions"].apply( - lambda x: f"Instructions:\n{x}\n" + lambda x: f"\n### Instructions:\n{x}\n" ) else: question_query_df["instructions"] = "" @@ -90,4 +90,22 @@ def prepare_questions_df( else: question_query_df["prev_error_msg"] = "" + # get question_0, query_0, question_1, query_1 if applicable + if "question_0" in question_query_df.columns: + question_query_df["question_0"] = question_query_df["question_0"].fillna("") + else: + question_query_df["question_0"] = "" + if "query_0" in question_query_df.columns: + question_query_df["query_0"] = question_query_df["query_0"].fillna("") + else: + question_query_df["query_0"] = "" + if "question_1" in question_query_df.columns: + question_query_df["question_1"] = question_query_df["question_1"].fillna("") + else: + question_query_df["question_1"] = "" + if "query_1" in question_query_df.columns: + question_query_df["query_1"] = question_query_df["query_1"].fillna("") + else: + question_query_df["query_1"] = "" + return question_query_df