Skip to content

Commit

Permalink
rstrip eos in evaluation (#121)
Browse files Browse the repository at this point in the history
Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: Sukriti Sharma <[email protected]>
  • Loading branch information
alex-jw-brooks and Ssukriti authored Apr 15, 2024
1 parent b747c5f commit b48249f
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions scripts/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def parse_and_validate_args():
help="Delimiter to be used for multilabel multiclass evaluation",
default=None,
)
parser.add_argument(
"--eos_token",
help="EOS token emitted by the model; passing will rstrip() the token if present",
)
parser.add_argument("--purge_results", action=argparse.BooleanOptionalAction)

parsed_args = parser.parse_args()
Expand Down Expand Up @@ -119,6 +123,7 @@ def get_prediction_results(
data: datasets.arrow_dataset.Dataset,
max_new_tokens: int,
delimiter: Optional[str],
eos_token: Optional[str],
) -> tuple[list]:
"""Runs the model over the alpaca formatted data to get the predictions / references to be used
when computing the metrics of interest.
Expand All @@ -132,7 +137,8 @@ def get_prediction_results(
Max number of tokens to be used for generation.
delimiter: Optional[str]
Delimiter to be used for splitting apart multioutput instances.
eos_token: Optional[str]
EOS token emitted by the model, which will be rstripped from predictions.
Returns:
tuple[list]
Tuple containing:
Expand All @@ -153,8 +159,9 @@ def get_prediction_results(
ret_gen_text_only=True,
)
# Save the raw output / predicted texts
processed_pred = postprocess_output(prediction, delimiter)
processed_ref = postprocess_output(formatted_datum["output"], delimiter)
processed_pred = postprocess_output(prediction, delimiter, eos_token)
# The reference text should not have an EOS to strip
processed_ref = postprocess_output(formatted_datum["output"], delimiter, None)
preds.append(processed_pred)
refs.append(processed_ref)
model_pred_info.append(
Expand All @@ -167,18 +174,24 @@ def get_prediction_results(
return preds, refs, model_pred_info


def postprocess_output(output_text: str, delimiter: Optional[str]) -> list[str]:
def postprocess_output(
    output_text: str, delimiter: Optional[str], eos_token: Optional[str]
) -> list[str]:
    """Clean up raw text and split it into one or more labels.

    NOTE: We are returning a list here, since that is what the one hot encoder module expects.

    Args:
        output_text: str
            Raw text to be split into one or more (potentially) delimited instances.
        delimiter: Optional[str]
            Delimiter to be used for splitting apart multioutput instances.
        eos_token: Optional[str]
            EOS token emitted by the model. If provided (and non-empty), trailing
            occurrences of the full token are removed from the end of the text
            before splitting.

    Returns:
        list[str]
            List of one or more labels, each with surrounding whitespace stripped.
    """
    # str.rstrip(chars) treats its argument as a *set of characters*, so
    # rstrip("</s>") would also eat trailing 's', '/', '<', '>' that belong to
    # the label itself (e.g. "class</s>" -> "cla"). Remove the token as a whole
    # suffix instead. The truthiness check also guards against an empty token,
    # which would otherwise match forever in the loop below.
    if eos_token:
        while output_text.endswith(eos_token):
            output_text = output_text[: -len(eos_token)]
    if delimiter is not None:
        return [text_substr.strip() for text_substr in output_text.split(delimiter)]
    return [output_text.strip()]
Expand Down Expand Up @@ -417,7 +430,11 @@ def export_experiment_info(
"json", data_files=args.data_path, split=args.split
)
predictions, references, model_pred_file_info = get_prediction_results(
tuned_model, eval_data, args.max_new_tokens, args.delimiter
tuned_model,
eval_data,
args.max_new_tokens,
args.delimiter,
args.eos_token,
)

(
Expand Down

0 comments on commit b48249f

Please sign in to comment.