diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 0d32ad071..9d063ce96 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -95,6 +95,17 @@ def train(
 
     logger = logging.get_logger("sft_trainer")
 
+    # Validate parameters
+    if (not isinstance(train_args.num_train_epochs, float)) or (
+        train_args.num_train_epochs <= 0
+    ):
+        raise ValueError("num_train_epochs has to be an integer/float >= 1")
+    if (not isinstance(train_args.gradient_accumulation_steps, int)) or (
+        train_args.gradient_accumulation_steps <= 0
+    ):
+        raise ValueError("gradient_accumulation_steps has to be an integer >= 1")
+
+
     task_type = "CAUSAL_LM"
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
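
A minimal sketch of how the added guards behave in isolation, assuming a hypothetical DummyTrainArgs dataclass as a stand-in for the train_args object that train() receives (it is not defined in the repository):

from dataclasses import dataclass


@dataclass
class DummyTrainArgs:
    # Hypothetical stand-in for train_args; only the two validated fields are modeled.
    num_train_epochs: float = 1.0
    gradient_accumulation_steps: int = 1


def validate(train_args):
    # Same guard logic as the lines added in the diff above.
    if (not isinstance(train_args.num_train_epochs, float)) or (
        train_args.num_train_epochs <= 0
    ):
        raise ValueError("num_train_epochs has to be an integer/float >= 1")
    if (not isinstance(train_args.gradient_accumulation_steps, int)) or (
        train_args.gradient_accumulation_steps <= 0
    ):
        raise ValueError("gradient_accumulation_steps has to be an integer >= 1")


validate(DummyTrainArgs(num_train_epochs=3.0, gradient_accumulation_steps=4))  # passes
try:
    validate(DummyTrainArgs(num_train_epochs=0.0))
except ValueError as err:
    print(err)  # num_train_epochs has to be an integer/float >= 1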