From e799f7db788f17d0e0b7fed2472a6486bd252e29 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Fri, 20 Sep 2024 15:52:18 -0600 Subject: [PATCH 1/6] refactor saving tokens metadata Signed-off-by: Sukriti-Sharma4 --- tuning/data/tokenizer_data_utils.py | 2 +- tuning/sft_trainer.py | 36 +++++++++++++---------------- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/tuning/data/tokenizer_data_utils.py b/tuning/data/tokenizer_data_utils.py index 36d08f8d3..a5e0cab08 100644 --- a/tuning/data/tokenizer_data_utils.py +++ b/tuning/data/tokenizer_data_utils.py @@ -44,4 +44,4 @@ def tokenizer_and_embedding_resize( input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg - return num_new_tokens + return {"num_new_tokens": num_new_tokens, "new_embedding_size": embedding_size} diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 9b808547e..900e2de50 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -291,7 +291,7 @@ def train( # TODO: lower priority but understand if resizing impacts inference quality and why its needed. # It makes sense if we manipulate tokenizer that we also save it and provide it to inference. - num_added_tokens = tokenizer_data_utils.tokenizer_and_embedding_resize( + added_tokens_dict = tokenizer_data_utils.tokenizer_and_embedding_resize( special_tokens_dict=special_tokens_dict, tokenizer=tokenizer, model=model, @@ -411,8 +411,9 @@ def train( ) trainer.train(resume_from_checkpoint) - - return trainer, num_added_tokens + additional_metadata = {} + additional_metadata["added_tokens_info"] = added_tokens_dict + return trainer, additional_metadata def save(path: str, trainer: SFTTrainer, log_level="WARNING"): @@ -463,13 +464,7 @@ def get_parser(): choices=["pt", "lora", None, "none"], default="none", ) - parser.add_argument( - "--post_process_vllm", - type=bool, - default=False, - help="Bool to indicate if post processing of LoRA adapters for vLLM \ - is required.", - ) + parser.add_argument( "--exp_metadata", type=str, @@ -644,7 +639,7 @@ def main(): combined_tracker_configs.aim_config = aim_config try: - trainer, num_added_tokens = train( + trainer, additional_train_info = train( model_args=model_args, data_args=data_args, train_args=training_args, @@ -697,19 +692,20 @@ def main(): ) sys.exit(INTERNAL_ERROR_EXIT_CODE) - # post process lora - if post_process_vllm and isinstance(tune_config, peft_config.LoraConfig): + if isinstance(tune_config, peft_config.LoraConfig): try: - checkpoint_dir = training_args.save_model_dir - if checkpoint_dir: - print(f"Post processing LoRA adapters in {checkpoint_dir}") - post_process_vLLM_adapters_new_tokens( - path_to_checkpoint=checkpoint_dir, num_added_tokens=num_added_tokens - ) + if training_args.save_model_dir: + # Write number of added tokens to artifacts + with open(os.path.join(training_args.save_model_dir,'added_tokens_info.json'), 'w') as f: + json.dump(additional_train_info["added_tokens_info"], f) + if training_args.output_dir: + # Write number of added tokens to artifacts + with open(os.path.join(training_args.output_dir,'added_tokens_info.json'), 'w') as f: + json.dump(additional_train_info["added_tokens_info"], f) except Exception as e: # pylint: disable=broad-except logging.error(traceback.format_exc()) write_termination_log( - f"Exception encountered while lora post-processing model: {e}" + f"Exception encountered when saving metadata with model artifacts: {e}" ) sys.exit(INTERNAL_ERROR_EXIT_CODE) From 
2c0c20631453e57300084a28f108f552bffca7b2 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Fri, 20 Sep 2024 16:04:16 -0600 Subject: [PATCH 2/6] remove extra check Signed-off-by: Sukriti-Sharma4 --- tests/build/test_launch_script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/build/test_launch_script.py b/tests/build/test_launch_script.py index 0028c59ab..030c92965 100644 --- a/tests/build/test_launch_script.py +++ b/tests/build/test_launch_script.py @@ -155,7 +155,7 @@ def test_lora_save_model_dir_separate_dirs(): _validate_termination_files_when_tuning_succeeds(output_dir) _validate_training_output(save_model_dir, "lora") - assert len(os.listdir(output_dir)) == 3 + # purpose here is to see if only one checkpoint is saved checkpoints = glob.glob(os.path.join(output_dir, "checkpoint-*")) assert len(checkpoints) == 1 From 5e4b79620e74e31fd56f58f94b63871bd5018867 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Fri, 20 Sep 2024 17:25:12 -0600 Subject: [PATCH 3/6] post processing script Signed-off-by: Sukriti-Sharma4 --- scripts/post_process_adapters_vLLM.py | 32 +++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 scripts/post_process_adapters_vLLM.py diff --git a/scripts/post_process_adapters_vLLM.py b/scripts/post_process_adapters_vLLM.py new file mode 100644 index 000000000..ae9885591 --- /dev/null +++ b/scripts/post_process_adapters_vLLM.py @@ -0,0 +1,32 @@ +# Standard +import argparse +import json +import os + + +# Local +from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens + +### Main & arg parsing +def main(): + parser = argparse.ArgumentParser( + description="Post processes adapters due to addition of new tokens, as needed by vLLM" + ) + parser.add_argument( + "--model_path", help="Path to tuned model containing either one or multiple checkpoints \ + Path should have file added_tokens_info.json produced by tuning \ + Hint: This will be either output_dir specified while tuning or save_model_dir", required=True + ) + parser.add_argument( + "--output_model_path", help="Output directory where post-processed artifacts will be stored. \ + If not provided, artifacts will be modified in place", default=None + ) + args = parser.parse_args() + + if os.path.exists(os.path.join(args.model_path, 'added_tokens_info.json')): + with open(os.path.join(args.model_path, 'added_tokens_info.json')) as json_data: + added_tokens_info = json.loads(json_data) + num_added_tokens = added_tokens_info["num_added_tokens"] + else: + print("file added_tokens_info.json not in model_path. 
Cannot post-processes") + post_process_vLLM_adapters_new_tokens(args.model_path, args.output_model_path, num_added_tokens) From 3b903f34b81003d3a3a8017bd7eff0f0198ac84a Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Sun, 22 Sep 2024 21:17:09 -0600 Subject: [PATCH 4/6] post processing script Signed-off-by: Sukriti-Sharma4 --- scripts/post_process_adapters_vLLM.py | 43 ++++++++++++++++++++------- tuning/data/tokenizer_data_utils.py | 12 ++++++-- tuning/sft_trainer.py | 18 +++++++---- tuning/utils/merge_model_utils.py | 14 ++++++++- 4 files changed, 68 insertions(+), 19 deletions(-) diff --git a/scripts/post_process_adapters_vLLM.py b/scripts/post_process_adapters_vLLM.py index ae9885591..577e0384c 100644 --- a/scripts/post_process_adapters_vLLM.py +++ b/scripts/post_process_adapters_vLLM.py @@ -3,30 +3,51 @@ import json import os - # Local from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens + ### Main & arg parsing def main(): parser = argparse.ArgumentParser( description="Post processes adapters due to addition of new tokens, as needed by vLLM" ) parser.add_argument( - "--model_path", help="Path to tuned model containing either one or multiple checkpoints \ + "--model_path", + help="Path to tuned model containing either one or multiple checkpoints \ Path should have file added_tokens_info.json produced by tuning \ - Hint: This will be either output_dir specified while tuning or save_model_dir", required=True + Hint: This will be either output_dir or save_model_dir arguments while tuning \ + If multiple checkpoints are present, each checkpoint folder name \ + should begin with 'checkpoint-'", + required=True, ) parser.add_argument( - "--output_model_path", help="Output directory where post-processed artifacts will be stored. \ - If not provided, artifacts will be modified in place", default=None + "--output_model_path", + help="Output directory where post-processed artifacts will be stored. \ + If not provided, artifacts will be modified in place", + default=None, ) args = parser.parse_args() - - if os.path.exists(os.path.join(args.model_path, 'added_tokens_info.json')): - with open(os.path.join(args.model_path, 'added_tokens_info.json')) as json_data: - added_tokens_info = json.loads(json_data) - num_added_tokens = added_tokens_info["num_added_tokens"] + + if os.path.exists(os.path.join(args.model_path, "added_tokens_info.json")): + with open( + os.path.join(args.model_path, "added_tokens_info.json"), encoding="utf-8" + ) as json_data: + added_tokens_info = json.loads(json_data) + num_added_tokens = added_tokens_info["num_added_tokens"] else: print("file added_tokens_info.json not in model_path. 
Cannot post-processes") - post_process_vLLM_adapters_new_tokens(args.model_path, args.output_model_path, num_added_tokens) + + if os.path.exists(os.path.join(args.model_path, "adapter_model.safetensors")): + post_process_vLLM_adapters_new_tokens( + args.model_path, args.output_model_path, num_added_tokens + ) + # if multiple checkpoints in directory, process each checkpoint + for _, dirs, _ in os.walk(args.model_path, topdown=False): + for name in dirs: + if "checkpoint-" in name.lower(): + post_process_vLLM_adapters_new_tokens( + os.path.join(args.model_path, name), + os.path.join(args.output_model_path, name), + num_added_tokens, + ) diff --git a/tuning/data/tokenizer_data_utils.py b/tuning/data/tokenizer_data_utils.py index a5e0cab08..ef0662d59 100644 --- a/tuning/data/tokenizer_data_utils.py +++ b/tuning/data/tokenizer_data_utils.py @@ -25,8 +25,16 @@ def tokenizer_and_embedding_resize( tokenizer: transformers.PreTrainedTokenizer, model: transformers.PreTrainedModel, multiple_of: int = 1, -): - """Resize tokenizer and embedding.""" +) -> dict: + """Resize tokenizer and embedding. + Args: + special_tokens_dict: Dict containing special tokens to be added. + tokenizer: transformers.PreTrainedTokenizer. + model: transformers.PreTrainedModel + multiple_of: int , embeddings are resized to multiple of this. + Return: + dict: Metadata on number of added tokens + """ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) embedding_size = int(multiple_of * math.ceil(len(tokenizer) / multiple_of)) num_new_tokens = num_new_tokens + embedding_size - len(tokenizer) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 900e2de50..67df5eaa0 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -64,7 +64,6 @@ write_termination_log, ) from tuning.utils.logging import set_log_level -from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens from tuning.utils.preprocessing_utils import ( format_dataset, get_data_collator, @@ -587,7 +586,6 @@ def main(): fusedops_kernels_config, attention_and_distributed_packing_config, exp_metadata, - post_process_vllm, ) = parse_arguments(parser, job_config) # Function to set log level for python native logger and transformers training logger @@ -696,11 +694,21 @@ def main(): try: if training_args.save_model_dir: # Write number of added tokens to artifacts - with open(os.path.join(training_args.save_model_dir,'added_tokens_info.json'), 'w') as f: + with open( + os.path.join( + training_args.save_model_dir, "added_tokens_info.json" + ), + "w", + encoding="utf-8", + ) as f: json.dump(additional_train_info["added_tokens_info"], f) - if training_args.output_dir: + if training_args.output_dir: # Write number of added tokens to artifacts - with open(os.path.join(training_args.output_dir,'added_tokens_info.json'), 'w') as f: + with open( + os.path.join(training_args.output_dir, "added_tokens_info.json"), + "w", + encoding="utf-8", + ) as f: json.dump(additional_train_info["added_tokens_info"], f) except Exception as e: # pylint: disable=broad-except logging.error(traceback.format_exc()) diff --git a/tuning/utils/merge_model_utils.py b/tuning/utils/merge_model_utils.py index 233ba9df1..72738e155 100644 --- a/tuning/utils/merge_model_utils.py +++ b/tuning/utils/merge_model_utils.py @@ -124,7 +124,19 @@ def post_process_vLLM_adapters_new_tokens( modified_checkpoint_path: str = None, num_added_tokens: int = 0, ): - # if not set, original checkpoint will be modified + """Post process adapters to allow inferencing on vLLM. 
+ vLLM needs new token embedding weights added during tuning to be moved \ + to a new file new_embeddings.safetensors . \ + This function copies the embeddings weights for the added tokens from \ + adapters.safetnsors to new_embeddings.safetensors. + Args: + path_to_checkpoint: Path to folder containing adapters.safetensors. + modified_checkpoint_path: Output path where to save modified artifacts \ + after post-processing. If not provided, artifacts will be processed \ + in place in same folder. + num_added_tokens: int. Number of tokens that were added during tuning. + """ + # if not set, original checkpoint will be modified in place if not modified_checkpoint_path: modified_checkpoint_path = path_to_checkpoint From 738239e7718e7f63eccc9118b1ef5df40766c194 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Sun, 22 Sep 2024 21:26:29 -0600 Subject: [PATCH 5/6] fix: unit test args Signed-off-by: Sukriti-Sharma4 --- tests/test_sft_trainer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 4049690e5..8deca782a 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -335,7 +335,6 @@ def test_parse_arguments(job_config): _, _, _, - _, ) = sft_trainer.parse_arguments(parser, job_config_copy) assert str(model_args.torch_dtype) == "torch.bfloat16" assert data_args.dataset_text_field == "output" @@ -361,7 +360,6 @@ def test_parse_arguments_defaults(job_config): _, _, _, - _, ) = sft_trainer.parse_arguments(parser, job_config_defaults) assert str(model_args.torch_dtype) == "torch.bfloat16" assert model_args.use_flash_attn is False @@ -372,14 +370,14 @@ def test_parse_arguments_peft_method(job_config): parser = sft_trainer.get_parser() job_config_pt = copy.deepcopy(job_config) job_config_pt["peft_method"] = "pt" - _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( + _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_pt ) assert isinstance(tune_config, peft_config.PromptTuningConfig) job_config_lora = copy.deepcopy(job_config) job_config_lora["peft_method"] = "lora" - _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( + _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_lora ) assert isinstance(tune_config, peft_config.LoraConfig) From 6afdbfefb0cc934a8a3535c9ddcb62cfb6b37f0f Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Sun, 22 Sep 2024 21:50:48 -0600 Subject: [PATCH 6/6] undo post_process_vLLm flag Signed-off-by: Sukriti-Sharma4 --- README.md | 2 -- tuning/sft_trainer.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/README.md b/README.md index 214da6d8c..40e78a838 100644 --- a/README.md +++ b/README.md @@ -665,8 +665,6 @@ The `fms_acceleration.cli` can do more to search for all available configs, plug ## Inference Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time. -If you are trying to run LoRA inference on vLLM, set the `--post_process_vllm` flag to `True`. - ### Running a single example If you want to run a single example through a model, you can pass it with the `--text` flag. 
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 67df5eaa0..bd6a4db48 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -523,7 +523,6 @@ def parse_arguments(parser, json_config=None): ) = parser.parse_dict(json_config, allow_extra_keys=True) peft_method = json_config.get("peft_method") exp_metadata = json_config.get("exp_metadata") - post_process_vllm = json_config.get("post_process_vllm") else: ( model_args, @@ -543,7 +542,6 @@ def parse_arguments(parser, json_config=None): peft_method = additional.peft_method exp_metadata = additional.exp_metadata - post_process_vllm = additional.post_process_vllm if peft_method == "lora": tune_config = lora_config @@ -564,7 +562,6 @@ def parse_arguments(parser, json_config=None): fusedops_kernels_config, attention_and_distributed_packing_config, exp_metadata, - post_process_vllm, )
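---

For reference, a minimal sketch (not part of the patches) of the post-tuning flow this series sets up: `sft_trainer` writes `added_tokens_info.json` next to the model artifacts, and the standalone post-processing step consumes it to produce `new_embeddings.safetensors` for vLLM. The JSON key read below (`num_new_tokens`) follows the dict returned by `tokenizer_and_embedding_resize` in the first patch; treat the exact key name and the example path as assumptions and adjust them to whatever your checkout writes.

```python
# Sketch only: post-process a single LoRA checkpoint for vLLM using the
# metadata file written by sft_trainer. Paths are placeholders.

# Standard
import json
import os

# Local
from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens

# Wherever --save_model_dir (or --output_dir) pointed during tuning.
save_model_dir = "/path/to/save_model_dir"

# sft_trainer saves this metadata file alongside the model artifacts.
info_path = os.path.join(save_model_dir, "added_tokens_info.json")
with open(info_path, encoding="utf-8") as f:
    added_tokens_info = json.load(f)

# Key name assumed from the dict returned by tokenizer_and_embedding_resize;
# change it if your metadata schema differs.
num_added_tokens = added_tokens_info["num_new_tokens"]

# Copy the embedding rows for the newly added tokens out of the adapter
# checkpoint into new_embeddings.safetensors so vLLM can serve the adapter.
post_process_vLLM_adapters_new_tokens(
    path_to_checkpoint=save_model_dir,
    modified_checkpoint_path=None,  # None => modify the checkpoint in place
    num_added_tokens=num_added_tokens,
)
```

The same flow is exposed on the command line as `python scripts/post_process_adapters_vLLM.py --model_path <output_dir or save_model_dir> [--output_model_path <dir>]`, which additionally walks `checkpoint-*` subdirectories when the given path contains multiple checkpoints.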