From 6c7ac3bda42ca0a9f8a6c85e54e675e85ffa7bc4 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Thu, 7 Mar 2024 15:37:21 +0000
Subject: [PATCH] revert deletion of validation checks on some train args

---
 tuning/sft_trainer.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 0d32ad071..9d063ce96 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -95,6 +95,17 @@ def train(
 
     logger = logging.get_logger("sft_trainer")
 
+    # Validate parameters
+    if (not isinstance(train_args.num_train_epochs, float)) or (
+        train_args.num_train_epochs <= 0
+    ):
+        raise ValueError("num_train_epochs has to be an integer/float >= 1")
+    if (not isinstance(train_args.gradient_accumulation_steps, int)) or (
+        train_args.gradient_accumulation_steps <= 0
+    ):
+        raise ValueError("gradient_accumulation_steps has to be an integer >= 1")
+
+
     task_type = "CAUSAL_LM"
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
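
For context, the sketch below replays the restored checks outside of sft_trainer.py. It assumes only that train_args exposes num_train_epochs and gradient_accumulation_steps, as transformers.TrainingArguments does; TrainArgsStub and validate are hypothetical stand-ins used here for illustration and are not part of the patch.

from dataclasses import dataclass


@dataclass
class TrainArgsStub:
    # Hypothetical stand-in for transformers.TrainingArguments, which defaults
    # num_train_epochs to the float 3.0 and gradient_accumulation_steps to 1.
    num_train_epochs: float = 3.0
    gradient_accumulation_steps: int = 1


def validate(train_args):
    # Same logic as the restored block: num_train_epochs must be a positive
    # float, gradient_accumulation_steps a positive int.
    if (not isinstance(train_args.num_train_epochs, float)) or (
        train_args.num_train_epochs <= 0
    ):
        raise ValueError("num_train_epochs has to be an integer/float >= 1")
    if (not isinstance(train_args.gradient_accumulation_steps, int)) or (
        train_args.gradient_accumulation_steps <= 0
    ):
        raise ValueError("gradient_accumulation_steps has to be an integer >= 1")


validate(TrainArgsStub())  # default arguments pass

try:
    validate(TrainArgsStub(gradient_accumulation_steps=0))
except ValueError as exc:
    print(exc)  # gradient_accumulation_steps has to be an integer >= 1

One side effect of the isinstance(..., float) test worth noting: an integer value such as num_train_epochs=3 (rather than 3.0) is also rejected, because isinstance(3, float) is False.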