diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 2ab8f7de0..bcc42aeaf 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -40,6 +40,7 @@ import transformers # Local +from build.utils import get_highest_checkpoint from tuning.config import configs, peft_config from tuning.config.acceleration_configs import ( AccelerationFrameworkConfig, @@ -68,7 +69,9 @@ is_pretokenized_dataset, validate_data_args, ) - +from tuning.utils.merge_model_utils import( + post_process_vLLM_adapters_new_tokens, +) def train( model_args: configs.ModelArguments, @@ -633,6 +636,22 @@ def main(): ) sys.exit(INTERNAL_ERROR_EXIT_CODE) + + # post process lora + if isinstance(tune_config, peft_config.LoraConfig): + try: + checkpoint_dir = job_config.get("save_model_dir") + if not checkpoint_dir: + checkpoint_dir = os.path.join( + training_args.output_dir, get_highest_checkpoint(training_args.output_dir) + ) + print(training_args) + print(f"Post processing LoRA adapters in {checkpoint_dir}") + post_process_vLLM_adapters_new_tokens(path_to_checkpoint=checkpoint_dir) + except Exception as e: # pylint: disable=broad-except + logging.error(traceback.format_exc()) + write_termination_log(f"Exception encountered while lora post-processing model: {e}") + sys.exit(INTERNAL_ERROR_EXIT_CODE) if __name__ == "__main__": main()