From 21512255240ab631ecfcc497a4e5fcd30130f699 Mon Sep 17 00:00:00 2001
From: Sukriti Sharma
Date: Sun, 22 Sep 2024 21:58:53 -0600
Subject: [PATCH] feat: Refactor post-processing of adapters (#345)

* refactor saving tokens metadata

Signed-off-by: Sukriti-Sharma4

* remove extra check

Signed-off-by: Sukriti-Sharma4

* post processing script

Signed-off-by: Sukriti-Sharma4

* post processing script

Signed-off-by: Sukriti-Sharma4

* fix: unit test args

Signed-off-by: Sukriti-Sharma4

* undo post_process_vLLm flag

Signed-off-by: Sukriti-Sharma4

---------

Signed-off-by: Sukriti-Sharma4
---
 README.md                             |  2 -
 scripts/post_process_adapters_vLLM.py | 53 +++++++++++++++++++++++++++
 tests/build/test_launch_script.py     |  2 +-
 tests/test_sft_trainer.py             |  6 +--
 tuning/data/tokenizer_data_utils.py   | 14 +++++--
 tuning/sft_trainer.py                 | 51 +++++++++++++-------------
 tuning/utils/merge_model_utils.py     | 14 ++++++-
 7 files changed, 106 insertions(+), 36 deletions(-)
 create mode 100644 scripts/post_process_adapters_vLLM.py

diff --git a/README.md b/README.md
index 214da6d8c..40e78a838 100644
--- a/README.md
+++ b/README.md
@@ -665,8 +665,6 @@ The `fms_acceleration.cli` can do more to search for all available configs, plug
 
 ## Inference
 Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time.
 
-If you are trying to run LoRA inference on vLLM, set the `--post_process_vllm` flag to `True`.
-
 ### Running a single example
 If you want to run a single example through a model, you can pass it with the `--text` flag.
diff --git a/scripts/post_process_adapters_vLLM.py b/scripts/post_process_adapters_vLLM.py
new file mode 100644
index 000000000..577e0384c
--- /dev/null
+++ b/scripts/post_process_adapters_vLLM.py
@@ -0,0 +1,53 @@
+# Standard
+import argparse
+import json
+import os
+
+# Local
+from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens
+
+
+### Main & arg parsing
+def main():
+    parser = argparse.ArgumentParser(
+        description="Post-processes adapters due to addition of new tokens, as needed by vLLM"
+    )
+    parser.add_argument(
+        "--model_path",
+        help="Path to tuned model containing either one or multiple checkpoints. \
+            Path should have file added_tokens_info.json produced by tuning. \
+            Hint: This will be either the output_dir or save_model_dir argument used while tuning. \
+            If multiple checkpoints are present, each checkpoint folder name \
+            should begin with 'checkpoint-'",
+        required=True,
+    )
+    parser.add_argument(
+        "--output_model_path",
+        help="Output directory where post-processed artifacts will be stored. \
+            If not provided, artifacts will be modified in place",
+        default=None,
+    )
+    args = parser.parse_args()
+
+    if os.path.exists(os.path.join(args.model_path, "added_tokens_info.json")):
+        with open(
+            os.path.join(args.model_path, "added_tokens_info.json"), encoding="utf-8"
+        ) as json_data:
+            added_tokens_info = json.load(json_data)
+            num_added_tokens = added_tokens_info["num_new_tokens"]
+    else:
+        print("file added_tokens_info.json not in model_path. Cannot post-process")
+
+    if os.path.exists(os.path.join(args.model_path, "adapter_model.safetensors")):
+        post_process_vLLM_adapters_new_tokens(
+            args.model_path, args.output_model_path, num_added_tokens
+        )
+    # if multiple checkpoints in directory, process each checkpoint
+    for _, dirs, _ in os.walk(args.model_path, topdown=False):
+        for name in dirs:
+            if "checkpoint-" in name.lower():
+                post_process_vLLM_adapters_new_tokens(
+                    os.path.join(args.model_path, name),
+                    os.path.join(args.output_model_path, name),
+                    num_added_tokens,
+                )
diff --git a/tests/build/test_launch_script.py b/tests/build/test_launch_script.py
index 0028c59ab..030c92965 100644
--- a/tests/build/test_launch_script.py
+++ b/tests/build/test_launch_script.py
@@ -155,7 +155,7 @@ def test_lora_save_model_dir_separate_dirs():
     _validate_termination_files_when_tuning_succeeds(output_dir)
     _validate_training_output(save_model_dir, "lora")
 
-    assert len(os.listdir(output_dir)) == 3
+    # purpose here is to see if only one checkpoint is saved
     checkpoints = glob.glob(os.path.join(output_dir, "checkpoint-*"))
     assert len(checkpoints) == 1
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 4049690e5..8deca782a 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -335,7 +335,6 @@ def test_parse_arguments(job_config):
         _,
         _,
         _,
-        _,
     ) = sft_trainer.parse_arguments(parser, job_config_copy)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert data_args.dataset_text_field == "output"
@@ -361,7 +360,6 @@ def test_parse_arguments_defaults(job_config):
         _,
         _,
         _,
-        _,
     ) = sft_trainer.parse_arguments(parser, job_config_defaults)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert model_args.use_flash_attn is False
@@ -372,14 +370,14 @@ def test_parse_arguments_peft_method(job_config):
     parser = sft_trainer.get_parser()
     job_config_pt = copy.deepcopy(job_config)
     job_config_pt["peft_method"] = "pt"
-    _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+    _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_pt
     )
     assert isinstance(tune_config, peft_config.PromptTuningConfig)
 
     job_config_lora = copy.deepcopy(job_config)
     job_config_lora["peft_method"] = "lora"
-    _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+    _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_lora
     )
     assert isinstance(tune_config, peft_config.LoraConfig)
diff --git a/tuning/data/tokenizer_data_utils.py b/tuning/data/tokenizer_data_utils.py
index 36d08f8d3..ef0662d59 100644
--- a/tuning/data/tokenizer_data_utils.py
+++ b/tuning/data/tokenizer_data_utils.py
@@ -25,8 +25,16 @@ def tokenizer_and_embedding_resize(
     tokenizer: transformers.PreTrainedTokenizer,
     model: transformers.PreTrainedModel,
     multiple_of: int = 1,
-):
-    """Resize tokenizer and embedding."""
+) -> dict:
+    """Resize tokenizer and embedding.
+    Args:
+        special_tokens_dict: Dict containing special tokens to be added.
+        tokenizer: transformers.PreTrainedTokenizer.
+        model: transformers.PreTrainedModel.
+        multiple_of: int, embeddings are resized to a multiple of this.
+    Return:
+        dict: Metadata on number of added tokens
+    """
     num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
     embedding_size = int(multiple_of * math.ceil(len(tokenizer) / multiple_of))
     num_new_tokens = num_new_tokens + embedding_size - len(tokenizer)
@@ -44,4 +52,4 @@ def tokenizer_and_embedding_resize(
         input_embeddings[-num_new_tokens:] = input_embeddings_avg
         output_embeddings[-num_new_tokens:] = output_embeddings_avg
 
-    return num_new_tokens
+    return {"num_new_tokens": num_new_tokens, "new_embedding_size": embedding_size}
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 9b808547e..bd6a4db48 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -64,7 +64,6 @@
     write_termination_log,
 )
 from tuning.utils.logging import set_log_level
-from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens
 from tuning.utils.preprocessing_utils import (
     format_dataset,
     get_data_collator,
@@ -291,7 +290,7 @@ def train(
 
     # TODO: lower priority but understand if resizing impacts inference quality and why its needed.
     # It makes sense if we manipulate tokenizer that we also save it and provide it to inference.
-    num_added_tokens = tokenizer_data_utils.tokenizer_and_embedding_resize(
+    added_tokens_dict = tokenizer_data_utils.tokenizer_and_embedding_resize(
         special_tokens_dict=special_tokens_dict,
         tokenizer=tokenizer,
         model=model,
@@ -411,8 +410,9 @@ def train(
     )
 
     trainer.train(resume_from_checkpoint)
-
-    return trainer, num_added_tokens
+    additional_metadata = {}
+    additional_metadata["added_tokens_info"] = added_tokens_dict
+    return trainer, additional_metadata
 
 
 def save(path: str, trainer: SFTTrainer, log_level="WARNING"):
@@ -463,13 +463,7 @@ def get_parser():
         choices=["pt", "lora", None, "none"],
         default="none",
     )
-    parser.add_argument(
-        "--post_process_vllm",
-        type=bool,
-        default=False,
-        help="Bool to indicate if post processing of LoRA adapters for vLLM \
-            is required.",
-    )
+
     parser.add_argument(
         "--exp_metadata",
         type=str,
@@ -529,7 +523,6 @@ def parse_arguments(parser, json_config=None):
         ) = parser.parse_dict(json_config, allow_extra_keys=True)
         peft_method = json_config.get("peft_method")
         exp_metadata = json_config.get("exp_metadata")
-        post_process_vllm = json_config.get("post_process_vllm")
     else:
         (
             model_args,
@@ -549,7 +542,6 @@
 
         peft_method = additional.peft_method
         exp_metadata = additional.exp_metadata
-        post_process_vllm = additional.post_process_vllm
 
     if peft_method == "lora":
         tune_config = lora_config
@@ -570,7 +562,6 @@
         fusedops_kernels_config,
         attention_and_distributed_packing_config,
         exp_metadata,
-        post_process_vllm,
     )
 
 
@@ -592,7 +583,6 @@ def main():
         fusedops_kernels_config,
         attention_and_distributed_packing_config,
         exp_metadata,
-        post_process_vllm,
     ) = parse_arguments(parser, job_config)
 
     # Function to set log level for python native logger and transformers training logger
@@ -644,7 +634,7 @@ def main():
         combined_tracker_configs.aim_config = aim_config
 
     try:
-        trainer, num_added_tokens = train(
+        trainer, additional_train_info = train(
             model_args=model_args,
             data_args=data_args,
             train_args=training_args,
@@ -697,19 +687,30 @@
             )
             sys.exit(INTERNAL_ERROR_EXIT_CODE)
 
-    # post process lora
-    if post_process_vllm and isinstance(tune_config, peft_config.LoraConfig):
+    if isinstance(tune_config, peft_config.LoraConfig):
         try:
-            checkpoint_dir = training_args.save_model_dir
-            if checkpoint_dir:
-                print(f"Post processing LoRA adapters in {checkpoint_dir}")
-                post_process_vLLM_adapters_new_tokens(
-                    path_to_checkpoint=checkpoint_dir, num_added_tokens=num_added_tokens
-                )
+            if training_args.save_model_dir:
+                # Write number of added tokens to artifacts
+                with open(
+                    os.path.join(
+                        training_args.save_model_dir, "added_tokens_info.json"
+                    ),
+                    "w",
+                    encoding="utf-8",
+                ) as f:
+                    json.dump(additional_train_info["added_tokens_info"], f)
+            if training_args.output_dir:
+                # Write number of added tokens to artifacts
+                with open(
+                    os.path.join(training_args.output_dir, "added_tokens_info.json"),
+                    "w",
+                    encoding="utf-8",
+                ) as f:
+                    json.dump(additional_train_info["added_tokens_info"], f)
         except Exception as e:  # pylint: disable=broad-except
             logging.error(traceback.format_exc())
             write_termination_log(
-                f"Exception encountered while lora post-processing model: {e}"
+                f"Exception encountered when saving metadata with model artifacts: {e}"
             )
             sys.exit(INTERNAL_ERROR_EXIT_CODE)
 
diff --git a/tuning/utils/merge_model_utils.py b/tuning/utils/merge_model_utils.py
index 233ba9df1..72738e155 100644
--- a/tuning/utils/merge_model_utils.py
+++ b/tuning/utils/merge_model_utils.py
@@ -124,7 +124,19 @@ def post_process_vLLM_adapters_new_tokens(
     modified_checkpoint_path: str = None,
     num_added_tokens: int = 0,
 ):
-    # if not set, original checkpoint will be modified
+    """Post-process adapters to allow inference on vLLM.
+    vLLM needs new token embedding weights added during tuning to be moved \
+    to a new file new_embeddings.safetensors. \
+    This function copies the embedding weights for the added tokens from \
+    adapter_model.safetensors to new_embeddings.safetensors.
+    Args:
+        path_to_checkpoint: Path to folder containing adapter_model.safetensors.
+        modified_checkpoint_path: Output path where to save modified artifacts \
+            after post-processing. If not provided, artifacts will be processed \
+            in place in the same folder.
+        num_added_tokens: int. Number of tokens that were added during tuning.
+    """
+    # if not set, original checkpoint will be modified in place
     if not modified_checkpoint_path:
         modified_checkpoint_path = path_to_checkpoint
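
Note: once a tuning run has written added_tokens_info.json, the new standalone entry point is run as
`python scripts/post_process_adapters_vLLM.py --model_path <output_dir or save_model_dir> --output_model_path <destination>`.
The core operation it delegates to post_process_vLLM_adapters_new_tokens (copying the embedding rows of
the newly added tokens out of the adapter checkpoint into new_embeddings.safetensors, as the docstring
above describes) can be sketched roughly as follows. This is an illustrative sketch only, not the
implementation in tuning/utils/merge_model_utils.py; the tensor key substrings ("embed_tokens",
"lm_head") and the flat safetensors layout are assumptions.

# Illustrative sketch only; key names and file layout are assumptions, not the repo's exact code.
import os

from safetensors.torch import load_file, save_file


def copy_new_token_embeddings(path_to_checkpoint, output_path, num_added_tokens):
    """Copy the embedding rows of newly added tokens into new_embeddings.safetensors."""
    adapters = load_file(os.path.join(path_to_checkpoint, "adapter_model.safetensors"))

    new_embeddings = {}
    for name, tensor in adapters.items():
        # Assumption: the resized input/output embedding weights are stored under keys
        # containing these substrings, and their last `num_added_tokens` rows belong to
        # the tokens added during tuning.
        if "embed_tokens" in name or "lm_head" in name:
            new_embeddings[name] = tensor[-num_added_tokens:].clone()

    os.makedirs(output_path, exist_ok=True)
    save_file(new_embeddings, os.path.join(output_path, "new_embeddings.safetensors"))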