From 719647999cef455025165a9002bdec3575604c27 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Fri, 23 Feb 2024 07:37:18 +0000
Subject: [PATCH] remove run_distribtued flag and peft_saving callback

---
 tuning/sft_trainer.py | 17 ++---------------
 1 file changed, 2 insertions(+), 15 deletions(-)

diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 0cf98c31a..0d32ad071 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -29,17 +29,6 @@
 from tuning.utils.data_type_utils import get_torch_dtype
 
 
-class PeftSavingCallback(TrainerCallback):
-    def on_save(self, args, state, control, **kwargs):
-        checkpoint_path = os.path.join(
-            args.output_dir, f"checkpoint-{state.global_step}"
-        )
-        kwargs["model"].save_pretrained(checkpoint_path)
-
-        if "pytorch_model.bin" in os.listdir(checkpoint_path):
-            os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
-
-
 class FileLoggingCallback(TrainerCallback):
     """Exports metrics, e.g., training loss to a file in the checkpoint directory."""
 
@@ -103,7 +92,6 @@ def train(
         None for fine tuning
             The peft configuration to pass to trainer
     """
-    run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
 
     logger = logging.get_logger("sft_trainer")
 
@@ -204,8 +192,7 @@ def train(
 
     aim_callback = get_aimstack_callback()
     file_logger_callback = FileLoggingCallback(logger)
-    peft_saving_callback = PeftSavingCallback()
-    callbacks = [aim_callback, peft_saving_callback, file_logger_callback]
+    callbacks = [aim_callback, file_logger_callback]
 
     if train_args.packing:
         logger.info("Packing is set to True")
@@ -246,7 +233,7 @@ def train(
         peft_config=peft_config,
     )
 
-    if run_distributed and peft_config is not None:
+    if trainer.is_fsdp_enabled and peft_config is not None:
         trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
             model
         )