diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index fa7d0875c..ab4af54fa 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -29,8 +29,6 @@
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
-    GPT2Tokenizer,
-    GPTNeoXTokenizerFast,
     LlamaTokenizer,
     LlamaTokenizerFast,
     TrainerCallback,
@@ -136,6 +134,7 @@ def train(
     ):
         raise ValueError("gradient_accumulation_steps has to be an integer >= 1")
 
+    padding_free = False
     if (
         attention_and_distributed_packing_config is not None
         and attention_and_distributed_packing_config.padding_free is not None
@@ -153,6 +152,7 @@ def train(
"`--padding_free` argument was called with `packing=True`, "
"Trainer should not perform packing when using `--padding_free`"
)
+ padding_free = True
task_type = "CAUSAL_LM"
additional_metrics = {}
@@ -253,9 +253,8 @@ def train(
special_tokens_dict["bos_token"] = ""
special_tokens_dict["eos_token"] = ""
special_tokens_dict["unk_token"] = ""
- special_tokens_dict["pad_token"] = ""
- elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
- special_tokens_dict["pad_token"] = ""
+ if not padding_free:
+ special_tokens_dict["pad_token"] = ""
max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
logger.info("Max sequence length is %s", max_seq_length)
@@ -271,7 +270,7 @@ def train(
     # add special tokens only when a custom tokenizer is not passed
     if not model_args.tokenizer_name_or_path:
         # TODO: we need to change this, perhaps follow what open instruct does?
-        if tokenizer.pad_token is None:
+        if tokenizer.pad_token is None and not padding_free:
             logger.warning("PAD token set to default, missing in tokenizer")
             special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
         if tokenizer.eos_token is None:
@@ -288,7 +287,9 @@ def train(
"PAD token set to default, to make it different from eos token"
)
if tokenizer.eos_token != configs.DEFAULT_PAD_TOKEN:
- tokenizer.pad_token = configs.DEFAULT_PAD_TOKEN
+ tokenizer.pad_token = (
+ configs.DEFAULT_PAD_TOKEN if not padding_free else None
+ )
else:
tokenizer.eos_token = configs.DEFAULT_EOS_TOKEN