Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add args for follow-on questions #98

Merged
merged 2 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions eval/api_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ def run_api_eval(args):
"table_metadata_string",
"prev_invalid_sql",
"prev_error_msg",
"question_0",
"query_0",
"question_1",
"query_1",
]
].apply(
lambda row: generate_prompt(
Expand All @@ -114,6 +118,10 @@ def run_api_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand Down
8 changes: 8 additions & 0 deletions eval/gemini_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ def run_gemini_eval(args):
"table_metadata_string",
"prev_invalid_sql",
"prev_error_msg",
"question_0",
"query_0",
"question_1",
"query_1",
]
].apply(
lambda row: generate_prompt(
Expand All @@ -141,6 +145,10 @@ def run_gemini_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand Down
8 changes: 8 additions & 0 deletions eval/hf_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def run_hf_eval(args):
"table_metadata_string",
"prev_invalid_sql",
"prev_error_msg",
"question_0",
"query_0",
"question_1",
"query_1",
]
].apply(
lambda row: generate_prompt(
Expand All @@ -145,6 +149,10 @@ def run_hf_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand Down
8 changes: 8 additions & 0 deletions eval/llama_cpp_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def run_llama_cpp_eval(args):
"table_metadata_string",
"prev_invalid_sql",
"prev_error_msg",
"question_0",
"query_0",
"question_1",
"query_1",
]
].apply(
lambda row: generate_prompt(
Expand All @@ -103,6 +107,10 @@ def run_llama_cpp_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand Down
8 changes: 8 additions & 0 deletions eval/mistral_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ def run_mistral_eval(args):
"table_metadata_string",
"prev_invalid_sql",
"prev_error_msg",
"question_0",
"query_0",
"question_1",
"query_1",
]
].apply(
lambda row: generate_prompt(
Expand All @@ -167,6 +171,10 @@ def run_mistral_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand Down
8 changes: 8 additions & 0 deletions eval/mlx_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def run_mlx_eval(args):
"table_metadata_string",
"prev_invalid_sql",
"prev_error_msg",
"question_0",
"query_0",
"question_1",
"query_1",
]
].apply(
lambda row: generate_prompt(
Expand All @@ -96,6 +100,10 @@ def run_mlx_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand Down
8 changes: 8 additions & 0 deletions eval/vllm_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def run_vllm_eval(args):
"table_metadata_string",
"prev_invalid_sql",
"prev_error_msg",
"question_0",
"query_0",
"question_1",
"query_1",
]
].apply(
lambda row: generate_prompt(
Expand All @@ -78,6 +82,10 @@ def run_vllm_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
public_data,
args.num_columns,
args.shuffle_metadata,
Expand Down
21 changes: 15 additions & 6 deletions utils/gen_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,29 @@ def generate_prompt(
 table_metadata_string="",
 prev_invalid_sql="",
 prev_error_msg="",
+question_0="",
+query_0="",
+question_1="",
+query_1="",
 public_data=True,
-columns_to_keep=20,
-shuffle=True,
+columns_to_keep=40,
+shuffle_metadata=False,
 ):
 with open(prompt_file, "r") as f:
 prompt = f.read()
 question_instructions = question + " " + instructions

 if table_metadata_string == "":
 pruned_metadata_str = prune_metadata_str(
-question_instructions, db_name, public_data, columns_to_keep, shuffle
+question_instructions,
+db_name,
+public_data,
+columns_to_keep,
+shuffle_metadata,
 )
 else:
 pruned_metadata_str = table_metadata_string

-if instructions != "":
-instructions = "\n\n### Instructions\n" + instructions
-
 prompt = prompt.format(
 user_question=question,
 instructions=instructions,
Expand All @@ -37,5 +42,9 @@ def generate_prompt(
 glossary=glossary,
 prev_invalid_sql=prev_invalid_sql,
 prev_error_msg=prev_error_msg,
+question_0=question_0,
+query_0=query_0,
+question_1=question_1,
+query_1=query_1,
 )
 return prompt
20 changes: 19 additions & 1 deletion utils/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def prepare_questions_df(
 lambda x: x.replace(". ", ".\n")
 )
 question_query_df["instructions"] = question_query_df["instructions"].apply(
-lambda x: f"Instructions:\n{x}\n"
+lambda x: f"\n### Instructions:\n{x}\n"
 )
 else:
 question_query_df["instructions"] = ""
Expand Down Expand Up @@ -90,4 +90,22 @@ def prepare_questions_df(
else:
question_query_df["prev_error_msg"] = ""

# get question_0, query_0, question_1, query_1 if applicable
if "question_0" in question_query_df.columns:
question_query_df["question_0"] = question_query_df["question_0"].fillna("")
else:
question_query_df["question_0"] = ""
if "query_0" in question_query_df.columns:
question_query_df["query_0"] = question_query_df["query_0"].fillna("")
else:
question_query_df["query_0"] = ""
if "question_1" in question_query_df.columns:
question_query_df["question_1"] = question_query_df["question_1"].fillna("")
else:
question_query_df["question_1"] = ""
if "query_1" in question_query_df.columns:
question_query_df["query_1"] = question_query_df["query_1"].fillna("")
else:
question_query_df["query_1"] = ""

return question_query_df
Loading