diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 59943a750..29c5fd299 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -44,17 +44,6 @@
 from tuning.utils.data_type_utils import get_torch_dtype
 
 
-class PeftSavingCallback(TrainerCallback):
-    def on_save(self, args, state, control, **kwargs):
-        checkpoint_path = os.path.join(
-            args.output_dir, f"checkpoint-{state.global_step}"
-        )
-        kwargs["model"].save_pretrained(checkpoint_path)
-
-        if "pytorch_model.bin" in os.listdir(checkpoint_path):
-            os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
-
-
 class FileLoggingCallback(TrainerCallback):
     """Exports metrics, e.g., training loss to a file in the checkpoint directory."""
 
@@ -118,7 +107,6 @@ def train(
         None for fine tuning
             The peft configuration to pass to trainer
     """
-    run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
 
     logger = logging.get_logger("sft_trainer")
 
@@ -132,11 +120,6 @@ def train(
     ):
         raise ValueError("gradient_accumulation_steps has to be an integer >= 1")
 
-    # make sure to unset FSDP args when running on single gpu
-    if not run_distributed:
-        train_args.fsdp = ""
-        train_args.fsdp_config = {"xla": False}
-
     task_type = "CAUSAL_LM"
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
@@ -147,8 +130,6 @@ def train(
 
     peft_config = get_hf_peft_config(task_type, peft_config)
 
-    model.gradient_checkpointing_enable()
-
     # TODO: Move these to a config as well
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, cache_dir=train_args.cache_dir, use_fast=True
@@ -239,8 +220,7 @@ def train(
 
     aim_callback = get_aimstack_callback()
     file_logger_callback = FileLoggingCallback(logger)
-    peft_saving_callback = PeftSavingCallback()
-    callbacks = [aim_callback, peft_saving_callback, file_logger_callback]
+    callbacks = [aim_callback, file_logger_callback]
 
     if train_args.packing:
         logger.info("Packing is set to True")
@@ -281,7 +261,7 @@ def train(
         peft_config=peft_config,
     )
 
-    if run_distributed and peft_config is not None:
+    if trainer.is_fsdp_enabled and peft_config is not None:
         trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
             model
        )
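
Note (reviewer addition, not part of the diff): the old run_distributed flag was derived from the WORLD_SIZE environment variable, so any multi-process launch took the FSDP code path even when no FSDP plugin was configured; trainer.is_fsdp_enabled instead reflects whether Accelerate actually set one up. The PeftSavingCallback removal likewise appears to lean on newer transformers releases, which save PEFT adapter checkpoints natively. A minimal sketch of the post-diff guard follows, assuming peft's public fsdp_auto_wrap_policy helper; install_peft_fsdp_wrap_policy is a hypothetical name used only for illustration:

    # Hypothetical helper, not code from this PR; it restates the post-diff
    # guard so the condition can be exercised in isolation.
    from peft.utils.other import fsdp_auto_wrap_policy

    def install_peft_fsdp_wrap_policy(trainer, model, peft_config):
        # Install the PEFT-aware auto-wrap policy only when an FSDP plugin
        # is actually present on the Accelerate state AND parameter-efficient
        # tuning is in use; single-GPU and full fine-tuning runs are left
        # untouched.
        if trainer.is_fsdp_enabled and peft_config is not None:
            trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = (
                fsdp_auto_wrap_policy(model)
            )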