Skip to content

Commit

Permalink
Add args for follow-on questions (#98)
Browse files: browse the repository at this point in the history
* add args for prev questions

* linted
  • Loading branch information
wendy-aw authored Mar 25, 2024
1 parent 405542f commit 1942947
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 7 deletions.
8 changes: 8 additions & 0 deletions eval/api_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ def run_api_eval(args):
"table_metadata_string",
"prev_invalid_sql",
"prev_error_msg",
"question_0",
"query_0",
"question_1",
"query_1",
]
].apply(
lambda row: generate_prompt(
Expand All @@ -114,6 +118,10 @@ def run_api_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand Down
8 changes: 8 additions & 0 deletions eval/gemini_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ def run_gemini_eval(args):
"table_metadata_string",
"prev_invalid_sql",
"prev_error_msg",
"question_0",
"query_0",
"question_1",
"query_1",
]
].apply(
lambda row: generate_prompt(
Expand All @@ -141,6 +145,10 @@ def run_gemini_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand Down
8 changes: 8 additions & 0 deletions eval/hf_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def run_hf_eval(args):
"table_metadata_string",
"prev_invalid_sql",
"prev_error_msg",
"question_0",
"query_0",
"question_1",
"query_1",
]
].apply(
lambda row: generate_prompt(
Expand All @@ -145,6 +149,10 @@ def run_hf_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand Down
8 changes: 8 additions & 0 deletions eval/llama_cpp_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def run_llama_cpp_eval(args):
"table_metadata_string",
"prev_invalid_sql",
"prev_error_msg",
"question_0",
"query_0",
"question_1",
"query_1",
]
].apply(
lambda row: generate_prompt(
Expand All @@ -103,6 +107,10 @@ def run_llama_cpp_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand Down
8 changes: 8 additions & 0 deletions eval/mistral_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ def run_mistral_eval(args):
"table_metadata_string",
"prev_invalid_sql",
"prev_error_msg",
"question_0",
"query_0",
"question_1",
"query_1",
]
].apply(
lambda row: generate_prompt(
Expand All @@ -167,6 +171,10 @@ def run_mistral_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand Down
8 changes: 8 additions & 0 deletions eval/mlx_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def run_mlx_eval(args):
"table_metadata_string",
"prev_invalid_sql",
"prev_error_msg",
"question_0",
"query_0",
"question_1",
"query_1",
]
].apply(
lambda row: generate_prompt(
Expand All @@ -96,6 +100,10 @@ def run_mlx_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand Down
8 changes: 8 additions & 0 deletions eval/vllm_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def run_vllm_eval(args):
"table_metadata_string",
"prev_invalid_sql",
"prev_error_msg",
"question_0",
"query_0",
"question_1",
"query_1",
]
].apply(
lambda row: generate_prompt(
Expand All @@ -78,6 +82,10 @@ def run_vllm_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand Down
21 changes: 15 additions & 6 deletions utils/gen_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,29 @@ def generate_prompt(
table_metadata_string="",
prev_invalid_sql="",
prev_error_msg="",
question_0="",
query_0="",
question_1="",
query_1="",
public_data=True,
columns_to_keep=20,
shuffle=True,
columns_to_keep=40,
shuffle_metadata=False,
):
with open(prompt_file, "r") as f:
prompt = f.read()
question_instructions = question + " " + instructions

if table_metadata_string == "":
pruned_metadata_str = prune_metadata_str(
question_instructions, db_name, public_data, columns_to_keep, shuffle
question_instructions,
db_name,
public_data,
columns_to_keep,
shuffle_metadata,
)
else:
pruned_metadata_str = table_metadata_string

if instructions != "":
instructions = "\n\n### Instructions\n" + instructions

prompt = prompt.format(
user_question=question,
instructions=instructions,
Expand All @@ -37,5 +42,9 @@ def generate_prompt(
glossary=glossary,
prev_invalid_sql=prev_invalid_sql,
prev_error_msg=prev_error_msg,
question_0=question_0,
query_0=query_0,
question_1=question_1,
query_1=query_1,
)
return prompt
20 changes: 19 additions & 1 deletion utils/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def prepare_questions_df(
lambda x: x.replace(". ", ".\n")
)
question_query_df["instructions"] = question_query_df["instructions"].apply(
lambda x: f"Instructions:\n{x}\n"
lambda x: f"\n### Instructions:\n{x}\n"
)
else:
question_query_df["instructions"] = ""
Expand Down Expand Up @@ -90,4 +90,22 @@ def prepare_questions_df(
else:
question_query_df["prev_error_msg"] = ""

# get question_0, query_0, question_1, query_1 if applicable
if "question_0" in question_query_df.columns:
question_query_df["question_0"] = question_query_df["question_0"].fillna("")
else:
question_query_df["question_0"] = ""
if "query_0" in question_query_df.columns:
question_query_df["query_0"] = question_query_df["query_0"].fillna("")
else:
question_query_df["query_0"] = ""
if "question_1" in question_query_df.columns:
question_query_df["question_1"] = question_query_df["question_1"].fillna("")
else:
question_query_df["question_1"] = ""
if "query_1" in question_query_df.columns:
question_query_df["query_1"] = question_query_df["query_1"].fillna("")
else:
question_query_df["query_1"] = ""

return question_query_df

0 comments on commit 1942947

Please sign in to comment.