Skip to content

Commit

Permalink
fix: do not add special tokens for custom tokenizer (#279)
Browse files Browse the repository at this point in the history
Signed-off-by: Mehant Kammakomati <[email protected]>
  • Loading branch information
kmehant authored Aug 5, 2024
1 parent d35a139 commit e0da345
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 54 deletions.
1 change: 0 additions & 1 deletion tests/build/test_launch_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
"prompt_tuning_init": "RANDOM",
"num_virtual_tokens": 8,
"prompt_tuning_init_text": "hello",
"tokenizer_name_or_path": MODEL_NAME,
"save_strategy": "epoch",
"output_dir": "tmp",
},
Expand Down
35 changes: 19 additions & 16 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,9 @@ def test_run_causallm_pt_and_inference():
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
# tokenizer_name_or_path from model arguments is passed
# while preparing the prompt tuning config which
# defaults to model_name_or_path if not explicitly set.

_validate_adapter_config(
adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path
adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path
)

# Load the model
Expand Down Expand Up @@ -214,11 +212,8 @@ def test_run_causallm_pt_and_inference_with_formatting_data():
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
# tokenizer_name_or_path from model arguments is passed
# while preparing the prompt tuning config which
# defaults to model_name_or_path if not explicitly set.
_validate_adapter_config(
adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path
adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path
)

# Load the model
Expand Down Expand Up @@ -250,11 +245,8 @@ def test_run_causallm_pt_and_inference_JSON_file_formatter():
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
# tokenizer_name_or_path from model arguments is passed
# while preparing the prompt tuning config which
# defaults to model_name_or_path if not explicitly set.
_validate_adapter_config(
adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path
adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path
)

# Load the model
Expand Down Expand Up @@ -285,11 +277,8 @@ def test_run_causallm_pt_init_text():
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
# tokenizer_name_or_path from model arguments is passed
# while preparing the prompt tuning config which
# defaults to model_name_or_path if not explicitly set.
_validate_adapter_config(
adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path
adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path
)


Expand Down Expand Up @@ -349,6 +338,20 @@ def test_run_causallm_pt_with_validation_data_formatting():
_validate_training(tempdir, check_eval=True)


def test_run_causallm_pt_with_custom_tokenizer():
"""Check if we fail when custom tokenizer not having pad token is used in prompt tuning"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
model_args = copy.deepcopy(MODEL_ARGS)
model_args.tokenizer_name_or_path = model_args.model_name_or_path
train_args.output_dir = tempdir
train_args.eval_strategy = "epoch"
data_args = copy.deepcopy(DATA_ARGS)
data_args.validation_data_path = TWITTER_COMPLAINTS_DATA
with pytest.raises(ValueError):
sft_trainer.train(model_args, data_args, train_args, PEFT_PT_ARGS)


############################# Lora Tests #############################

target_modules_val_map = [
Expand Down
12 changes: 6 additions & 6 deletions tuning/config/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ class ModelArguments:
tokenizer_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "Path to custom tokenizer.\
If not provided it defaults to model_name_or_path"
"help": "Path to custom tokenizer. \
If not provided it defaults to model_name_or_path \
and special tokens will be added as needed for specific tokenizer classes. \
For prompt tuning, if tokenizer_name_or_path provided, special tokens are not added, \
otherwise, it defaults to model_name_or_path with special tokens for specific \
tokenizer classes."
},
)

def __post_init__(self):
if not self.tokenizer_name_or_path:
self.tokenizer_name_or_path = self.model_name_or_path


@dataclass
class DataArguments:
Expand Down
78 changes: 47 additions & 31 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,32 +190,46 @@ def train(

# TODO: Move these to a config as well
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name_or_path, cache_dir=train_args.cache_dir, use_fast=True
(
model_args.tokenizer_name_or_path
if model_args.tokenizer_name_or_path
else model_args.model_name_or_path
),
cache_dir=train_args.cache_dir,
use_fast=True,
)

# Calculate and save additional metrics to track later.
additional_metrics["model_load_time"] = time.time() - model_load_time

peft_config = get_hf_peft_config(
task_type, peft_config, model_args.tokenizer_name_or_path
task_type,
peft_config,
(
model_args.tokenizer_name_or_path
if model_args.tokenizer_name_or_path
else model_args.model_name_or_path
),
)

# TODO: understand if we need to hardcode these here or just use defaults in model
if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
"pad_token": "<pad>",
}
)
elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
tokenizer.add_special_tokens(
{
"pad_token": "<pad>",
}
)
# add special tokens only when a custom tokenizer is not passed
if not model_args.tokenizer_name_or_path:
# TODO: understand if we need to hardcode these here or just use defaults in model
if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
"pad_token": "<pad>",
}
)
elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
tokenizer.add_special_tokens(
{
"pad_token": "<pad>",
}
)

max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
logger.info("Max sequence length is %s", max_seq_length)
Expand All @@ -228,20 +242,22 @@ def train(
tokenizer.model_max_length,
)

# TODO: we need to change this, perhaps follow what open instruct does?
# add special tokens only when a custom tokenizer is not passed
special_tokens_dict = {}
if tokenizer.pad_token is None:
logger.warning("PAD token set to default, missing in tokenizer")
special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
if tokenizer.eos_token is None:
logger.warning("EOS token set to default, missing in tokenizer")
special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN
if tokenizer.bos_token is None:
logger.warning("BOS token set to default, missing in tokenizer")
special_tokens_dict["bos_token"] = configs.DEFAULT_BOS_TOKEN
if tokenizer.unk_token is None:
logger.warning("UNK token set to default, missing in tokenizer")
special_tokens_dict["unk_token"] = configs.DEFAULT_UNK_TOKEN
if not model_args.tokenizer_name_or_path:
# TODO: we need to change this, perhaps follow what open instruct does?
if tokenizer.pad_token is None:
logger.warning("PAD token set to default, missing in tokenizer")
special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
if tokenizer.eos_token is None:
logger.warning("EOS token set to default, missing in tokenizer")
special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN
if tokenizer.bos_token is None:
logger.warning("BOS token set to default, missing in tokenizer")
special_tokens_dict["bos_token"] = configs.DEFAULT_BOS_TOKEN
if tokenizer.unk_token is None:
logger.warning("UNK token set to default, missing in tokenizer")
special_tokens_dict["unk_token"] = configs.DEFAULT_UNK_TOKEN

# TODO: lower priority but understand if resizing impacts inference quality and why its needed.
# It makes sense if we manipulate tokenizer that we also save it and provide it to inference.
Expand Down

0 comments on commit e0da345

Please sign in to comment.