diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index fa7d0875c..ab4af54fa 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -29,8 +29,6 @@ from transformers import ( AutoModelForCausalLM, AutoTokenizer, - GPT2Tokenizer, - GPTNeoXTokenizerFast, LlamaTokenizer, LlamaTokenizerFast, TrainerCallback, @@ -136,6 +134,7 @@ def train( ): raise ValueError("gradient_accumulation_steps has to be an integer >= 1") + padding_free = False if ( attention_and_distributed_packing_config is not None and attention_and_distributed_packing_config.padding_free is not None @@ -153,6 +152,7 @@ def train( "`--padding_free` argument was called with `packing=True`, " "Trainer should not perform packing when using `--padding_free`" ) + padding_free = True task_type = "CAUSAL_LM" additional_metrics = {} @@ -253,9 +253,8 @@ def train( special_tokens_dict["bos_token"] = "" special_tokens_dict["eos_token"] = "" special_tokens_dict["unk_token"] = "" - special_tokens_dict["pad_token"] = "" - elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)): - special_tokens_dict["pad_token"] = "" + if not padding_free: + special_tokens_dict["pad_token"] = "" max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length) logger.info("Max sequence length is %s", max_seq_length) @@ -271,7 +270,7 @@ def train( # add special tokens only when a custom tokenizer is not passed if not model_args.tokenizer_name_or_path: # TODO: we need to change this, perhaps follow what open instruct does? - if tokenizer.pad_token is None: + if tokenizer.pad_token is None and not padding_free: logger.warning("PAD token set to default, missing in tokenizer") special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN if tokenizer.eos_token is None: @@ -288,7 +287,9 @@ def train( "PAD token set to default, to make it different from eos token" ) if tokenizer.eos_token != configs.DEFAULT_PAD_TOKEN: - tokenizer.pad_token = configs.DEFAULT_PAD_TOKEN + tokenizer.pad_token = ( + configs.DEFAULT_PAD_TOKEN if not padding_free else None + ) else: tokenizer.eos_token = configs.DEFAULT_EOS_TOKEN