From b48249fab3df124d6b85cc8ce59b9e5a66ea6dcb Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 14 Apr 2024 18:09:35 -0600 Subject: [PATCH] rstrip eos in evaluation (#121) Signed-off-by: Alex-Brooks Co-authored-by: Sukriti Sharma --- scripts/run_evaluation.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/scripts/run_evaluation.py b/scripts/run_evaluation.py index ee4566cb6..f5d34a539 100644 --- a/scripts/run_evaluation.py +++ b/scripts/run_evaluation.py @@ -47,6 +47,10 @@ def parse_and_validate_args(): help="Delimiter to be used for multilabel multiclass evaluation", default=None, ) + parser.add_argument( + "--eos_token", + help="EOS token emitted by the model; passing will rstrip() the token if present", + ) parser.add_argument("--purge_results", action=argparse.BooleanOptionalAction) parsed_args = parser.parse_args() @@ -119,6 +123,7 @@ def get_prediction_results( data: datasets.arrow_dataset.Dataset, max_new_tokens: int, delimiter: Optional[str], + eos_token: Optional[str], ) -> tuple[list]: """Runs the model over the alpaca formatted data to get the predictions / references to be used when computing the metrics of interest. @@ -132,7 +137,8 @@ def get_prediction_results( Max number of tokens to be used for generation. delimiter: Optional[str] Delimiter to be used for splitting apart multioutput instances. - + eos_token: Optional[str] + EOS token emitted by the model, which will be rstripped from predictions. Returns: tuple[list] Tuple containing: @@ -153,8 +159,9 @@ def get_prediction_results( ret_gen_text_only=True, ) # Save the raw output / predicted texts - processed_pred = postprocess_output(prediction, delimiter) - processed_ref = postprocess_output(formatted_datum["output"], delimiter) + processed_pred = postprocess_output(prediction, delimiter, eos_token) + # The reference text should not have an EOS to strip + processed_ref = postprocess_output(formatted_datum["output"], delimiter, None) preds.append(processed_pred) refs.append(processed_ref) model_pred_info.append( @@ -167,18 +174,24 @@ def get_prediction_results( return preds, refs, model_pred_info -def postprocess_output(output_text: str, delimiter: Optional[str]) -> list[str]: +def postprocess_output( + output_text: str, delimiter: Optional[str], eos_token: Optional[str] +) -> list[str]: """NOTE: We are returning a list here, since that is what the one hot encoder module expects. Args: output_text: str Raw text to be split into one or more (potentially) delimited instances. delimiter: Optional[str] Delimiter to be used for splitting apart multioutput instances. + delimiter: Optional[str] + Delimiter to be used for splitting apart multioutput instances. Returns list[str] List of one or more labels. """ + if eos_token is not None: + output_text = output_text.rstrip(eos_token) if delimiter is not None: return [text_substr.strip() for text_substr in output_text.split(delimiter)] return [output_text.strip()] @@ -417,7 +430,11 @@ def export_experiment_info( "json", data_files=args.data_path, split=args.split ) predictions, references, model_pred_file_info = get_prediction_results( - tuned_model, eval_data, args.max_new_tokens, args.delimiter + tuned_model, + eval_data, + args.max_new_tokens, + args.delimiter, + args.eos_token, ) (