From 57cadc3c3b5379b6352697fbfee8188369a272b9 Mon Sep 17 00:00:00 2001
From: Will Johnson
Date: Wed, 18 Sep 2024 18:06:05 -0400
Subject: [PATCH] fix: Add post processing flag so post processing is only done for vLLM

Signed-off-by: Will Johnson
---
 README.md                 |  2 ++
 tests/test_sft_trainer.py |  7 ++++---
 tuning/sft_trainer.py     | 29 ++++++++++++++++-------------
 3 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/README.md b/README.md
index 7fd8fd5d7..f93b00a6c 100644
--- a/README.md
+++ b/README.md
@@ -654,6 +654,8 @@ The `fms_acceleration.cli` can do more to search for all available configs, plug
 ## Inference
 Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time.
 
+If you are trying to run LoRA inference on vLLM, set the `--post_process_vllm` flag to `True`.
+
 ### Running a single example
 If you want to run a single example through a model, you can pass it with the `--text` flag.
 
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 2d55b7de4..b20547002 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -334,6 +334,7 @@ def test_parse_arguments(job_config):
         _,
         _,
         _,
+        _,
     ) = sft_trainer.parse_arguments(parser, job_config_copy)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert data_args.dataset_text_field == "output"
@@ -347,7 +348,7 @@ def test_parse_arguments_defaults(job_config):
     assert "torch_dtype" not in job_config_defaults
     assert job_config_defaults["use_flash_attn"] is False
     assert "save_strategy" not in job_config_defaults
-    model_args, _, training_args, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+    model_args, _, training_args, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_defaults
     )
     assert str(model_args.torch_dtype) == "torch.bfloat16"
@@ -359,14 +360,14 @@ def test_parse_arguments_peft_method(job_config):
     parser = sft_trainer.get_parser()
     job_config_pt = copy.deepcopy(job_config)
     job_config_pt["peft_method"] = "pt"
-    _, _, _, _, tune_config, _, _, _, _, _ = sft_trainer.parse_arguments(
+    _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_pt
     )
     assert isinstance(tune_config, peft_config.PromptTuningConfig)
 
     job_config_lora = copy.deepcopy(job_config)
     job_config_lora["peft_method"] = "lora"
-    _, _, _, _, tune_config, _, _, _, _, _ = sft_trainer.parse_arguments(
+    _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_lora
     )
     assert isinstance(tune_config, peft_config.LoraConfig)
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 1c50b9610..2463ccadc 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -40,9 +40,6 @@
 from trl import SFTConfig, SFTTrainer
 import transformers
 
-# First Party
-from build.utils import get_highest_checkpoint
-
 # Local
 from tuning.config import configs, peft_config
 from tuning.config.acceleration_configs import (
@@ -440,6 +437,13 @@
         choices=["pt", "lora", None, "none"],
         default="none",
     )
+    parser.add_argument(
+        "--post_process_vllm",
+        type=bool,
+        default=False,
+        help="Bool to indicate if post processing of LoRA adapters for vLLM \
+            is required.",
+    )
     parser.add_argument(
         "--exp_metadata",
         type=str,
@@ -496,6 +500,7 @@
         ) = parser.parse_dict(json_config, allow_extra_keys=True)
         peft_method = json_config.get("peft_method")
         exp_metadata = json_config.get("exp_metadata")
+        post_process_vllm = json_config.get("post_process_vllm")
     else:
         (
             model_args,
@@ -514,6 +519,7 @@
 
         peft_method = additional.peft_method
         exp_metadata = additional.exp_metadata
+        post_process_vllm = additional.post_process_vllm
 
     if peft_method == "lora":
         tune_config = lora_config
@@ -533,6 +539,7 @@
         quantized_lora_config,
         fusedops_kernels_config,
         exp_metadata,
+        post_process_vllm,
     )
 
 
@@ -553,6 +560,7 @@
         quantized_lora_config,
         fusedops_kernels_config,
         exp_metadata,
+        post_process_vllm,
     ) = parse_arguments(parser, job_config)
 
     # Function to set log level for python native logger and transformers training logger
@@ -656,17 +664,12 @@
         sys.exit(INTERNAL_ERROR_EXIT_CODE)
 
     # post process lora
-    if isinstance(tune_config, peft_config.LoraConfig):
+    if post_process_vllm and isinstance(tune_config, peft_config.LoraConfig):
         try:
-            checkpoint_dir = job_config.get("save_model_dir")
-            if not checkpoint_dir:
-                checkpoint_dir = os.path.join(
-                    training_args.output_dir,
-                    get_highest_checkpoint(training_args.output_dir),
-                )
-            print(training_args)
-            print(f"Post processing LoRA adapters in {checkpoint_dir}")
-            post_process_vLLM_adapters_new_tokens(path_to_checkpoint=checkpoint_dir)
+            checkpoint_dir = training_args.save_model_dir
+            if checkpoint_dir:
+                print(f"Post processing LoRA adapters in {checkpoint_dir}")
+                post_process_vLLM_adapters_new_tokens(path_to_checkpoint=checkpoint_dir)
         except Exception as e:  # pylint: disable=broad-except
             logging.error(traceback.format_exc())
             write_termination_log(
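
Below is a minimal sketch of how the new `post_process_vllm` value surfaces from `parse_arguments`, mirroring the updated tests above; the `from tuning import sft_trainer` import and every `job_config` key other than `peft_method` and `post_process_vllm` are illustrative assumptions, not values taken from this patch.

```python
# Sketch only: exercises the post_process_vllm plumbing added by this patch.
# The job_config values below are illustrative placeholders; a real config
# needs whatever model/data/training fields the parser requires.
from tuning import sft_trainer

job_config = {
    "model_name_or_path": "my-base-model",  # placeholder
    "training_data_path": "train.json",     # placeholder
    "output_dir": "tuned-model",            # placeholder
    "peft_method": "lora",
    "post_process_vllm": True,  # new flag introduced by this patch
}

parser = sft_trainer.get_parser()
# parse_arguments now returns post_process_vllm as the last of eleven values.
(
    model_args,
    _,
    training_args,
    _,
    tune_config,
    _,
    _,
    _,
    _,
    _,
    post_process_vllm,
) = sft_trainer.parse_arguments(parser, job_config)
assert post_process_vllm is True
```

Note that with argparse, `type=bool` treats any non-empty string as `True`, so on the command line the flag is effectively enabled by passing `--post_process_vllm True` (as the README note suggests) and disabled by omitting it.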