From e799f7db788f17d0e0b7fed2472a6486bd252e29 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Fri, 20 Sep 2024 15:52:18 -0600 Subject: [PATCH] refactor saving tokens metadata Signed-off-by: Sukriti-Sharma4 --- tuning/data/tokenizer_data_utils.py | 2 +- tuning/sft_trainer.py | 36 +++++++++++++---------------- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/tuning/data/tokenizer_data_utils.py b/tuning/data/tokenizer_data_utils.py index 36d08f8d3..a5e0cab08 100644 --- a/tuning/data/tokenizer_data_utils.py +++ b/tuning/data/tokenizer_data_utils.py @@ -44,4 +44,4 @@ def tokenizer_and_embedding_resize( input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg - return num_new_tokens + return {"num_new_tokens": num_new_tokens, "new_embedding_size": embedding_size} diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 9b808547e..900e2de50 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -291,7 +291,7 @@ def train( # TODO: lower priority but understand if resizing impacts inference quality and why its needed. # It makes sense if we manipulate tokenizer that we also save it and provide it to inference. - num_added_tokens = tokenizer_data_utils.tokenizer_and_embedding_resize( + added_tokens_dict = tokenizer_data_utils.tokenizer_and_embedding_resize( special_tokens_dict=special_tokens_dict, tokenizer=tokenizer, model=model, @@ -411,8 +411,9 @@ def train( ) trainer.train(resume_from_checkpoint) - - return trainer, num_added_tokens + additional_metadata = {} + additional_metadata["added_tokens_info"] = added_tokens_dict + return trainer, additional_metadata def save(path: str, trainer: SFTTrainer, log_level="WARNING"): @@ -463,13 +464,7 @@ def get_parser(): choices=["pt", "lora", None, "none"], default="none", ) - parser.add_argument( - "--post_process_vllm", - type=bool, - default=False, - help="Bool to indicate if post processing of LoRA adapters for vLLM \ - is required.", - ) + parser.add_argument( "--exp_metadata", type=str, @@ -644,7 +639,7 @@ def main(): combined_tracker_configs.aim_config = aim_config try: - trainer, num_added_tokens = train( + trainer, additional_train_info = train( model_args=model_args, data_args=data_args, train_args=training_args, @@ -697,19 +692,20 @@ def main(): ) sys.exit(INTERNAL_ERROR_EXIT_CODE) - # post process lora - if post_process_vllm and isinstance(tune_config, peft_config.LoraConfig): + if isinstance(tune_config, peft_config.LoraConfig): try: - checkpoint_dir = training_args.save_model_dir - if checkpoint_dir: - print(f"Post processing LoRA adapters in {checkpoint_dir}") - post_process_vLLM_adapters_new_tokens( - path_to_checkpoint=checkpoint_dir, num_added_tokens=num_added_tokens - ) + if training_args.save_model_dir: + # Write number of added tokens to artifacts + with open(os.path.join(training_args.save_model_dir,'added_tokens_info.json'), 'w') as f: + json.dump(additional_train_info["added_tokens_info"], f) + if training_args.output_dir: + # Write number of added tokens to artifacts + with open(os.path.join(training_args.output_dir,'added_tokens_info.json'), 'w') as f: + json.dump(additional_train_info["added_tokens_info"], f) except Exception as e: # pylint: disable=broad-except logging.error(traceback.format_exc()) write_termination_log( - f"Exception encountered while lora post-processing model: {e}" + f"Exception encountered when saving metadata with model artifacts: {e}" ) sys.exit(INTERNAL_ERROR_EXIT_CODE)