From e0da345c65347ade0d7a3ea0889ec8cbcfdd1fde Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Mon, 5 Aug 2024 22:42:07 +0530 Subject: [PATCH] fix: do not add special tokens for custom tokenizer (#279) Signed-off-by: Mehant Kammakomati --- tests/build/test_launch_script.py | 1 - tests/test_sft_trainer.py | 35 +++++++------- tuning/config/configs.py | 12 ++--- tuning/sft_trainer.py | 78 +++++++++++++++++++------------ 4 files changed, 72 insertions(+), 54 deletions(-) diff --git a/tests/build/test_launch_script.py b/tests/build/test_launch_script.py index 824b7125c..421849b1f 100644 --- a/tests/build/test_launch_script.py +++ b/tests/build/test_launch_script.py @@ -61,7 +61,6 @@ "prompt_tuning_init": "RANDOM", "num_virtual_tokens": 8, "prompt_tuning_init_text": "hello", - "tokenizer_name_or_path": MODEL_NAME, "save_strategy": "epoch", "output_dir": "tmp", }, diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index eb2ca855a..26067124a 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -176,11 +176,9 @@ def test_run_causallm_pt_and_inference(): _validate_training(tempdir) checkpoint_path = _get_checkpoint_path(tempdir) adapter_config = _get_adapter_config(checkpoint_path) - # tokenizer_name_or_path from model arguments is passed - # while preparing the prompt tuning config which - # defaults to model_name_or_path if not explicitly set. + _validate_adapter_config( - adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path + adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path ) # Load the model @@ -214,11 +212,8 @@ def test_run_causallm_pt_and_inference_with_formatting_data(): _validate_training(tempdir) checkpoint_path = _get_checkpoint_path(tempdir) adapter_config = _get_adapter_config(checkpoint_path) - # tokenizer_name_or_path from model arguments is passed - # while preparing the prompt tuning config which - # defaults to model_name_or_path if not explicitly set. _validate_adapter_config( - adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path + adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path ) # Load the model @@ -250,11 +245,8 @@ def test_run_causallm_pt_and_inference_JSON_file_formatter(): _validate_training(tempdir) checkpoint_path = _get_checkpoint_path(tempdir) adapter_config = _get_adapter_config(checkpoint_path) - # tokenizer_name_or_path from model arguments is passed - # while preparing the prompt tuning config which - # defaults to model_name_or_path if not explicitly set. _validate_adapter_config( - adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path + adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path ) # Load the model @@ -285,11 +277,8 @@ def test_run_causallm_pt_init_text(): _validate_training(tempdir) checkpoint_path = _get_checkpoint_path(tempdir) adapter_config = _get_adapter_config(checkpoint_path) - # tokenizer_name_or_path from model arguments is passed - # while preparing the prompt tuning config which - # defaults to model_name_or_path if not explicitly set. 
_validate_adapter_config( - adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path + adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path ) @@ -349,6 +338,20 @@ def test_run_causallm_pt_with_validation_data_formatting(): _validate_training(tempdir, check_eval=True) +def test_run_causallm_pt_with_custom_tokenizer(): + """Check if we fail when custom tokenizer not having pad token is used in prompt tuning""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + model_args = copy.deepcopy(MODEL_ARGS) + model_args.tokenizer_name_or_path = model_args.model_name_or_path + train_args.output_dir = tempdir + train_args.eval_strategy = "epoch" + data_args = copy.deepcopy(DATA_ARGS) + data_args.validation_data_path = TWITTER_COMPLAINTS_DATA + with pytest.raises(ValueError): + sft_trainer.train(model_args, data_args, train_args, PEFT_PT_ARGS) + + ############################# Lora Tests ############################# target_modules_val_map = [ diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 92fb4f8f8..c08c90b12 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -51,15 +51,15 @@ class ModelArguments: tokenizer_name_or_path: Optional[str] = field( default=None, metadata={ - "help": "Path to custom tokenizer.\ - If not provided it defaults to model_name_or_path" + "help": "Path to custom tokenizer. \ + If not provided it defaults to model_name_or_path \ + and special tokens will be added as needed for specific tokenizer classes. \ + For prompt tuning, if tokenizer_name_or_path provided, special tokens are not added, \ + otherwise, it defaults to model_name_or_path with special tokens for specific \ + tokenizer classes." }, ) - def __post_init__(self): - if not self.tokenizer_name_or_path: - self.tokenizer_name_or_path = self.model_name_or_path - @dataclass class DataArguments: diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index d889c67e7..30e768ce4 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -190,32 +190,46 @@ def train( # TODO: Move these to a config as well tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name_or_path, cache_dir=train_args.cache_dir, use_fast=True + ( + model_args.tokenizer_name_or_path + if model_args.tokenizer_name_or_path + else model_args.model_name_or_path + ), + cache_dir=train_args.cache_dir, + use_fast=True, ) # Calculate and save additional metrics to track later. 
    additional_metrics["model_load_time"] = time.time() - model_load_time
 
     peft_config = get_hf_peft_config(
-        task_type, peft_config, model_args.tokenizer_name_or_path
+        task_type,
+        peft_config,
+        (
+            model_args.tokenizer_name_or_path
+            if model_args.tokenizer_name_or_path
+            else model_args.model_name_or_path
+        ),
     )
 
-    # TODO: understand if we need to hardcode these here or just use defaults in model
-    if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
-        tokenizer.add_special_tokens(
-            {
-                "bos_token": "<s>",
-                "eos_token": "</s>",
-                "unk_token": "<unk>",
-                "pad_token": "<pad>",
-            }
-        )
-    elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
-        tokenizer.add_special_tokens(
-            {
-                "pad_token": "<pad>",
-            }
-        )
+    # add special tokens only when a custom tokenizer is not passed
+    if not model_args.tokenizer_name_or_path:
+        # TODO: understand if we need to hardcode these here or just use defaults in model
+        if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
+            tokenizer.add_special_tokens(
+                {
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                    "unk_token": "<unk>",
+                    "pad_token": "<pad>",
+                }
+            )
+        elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
+            tokenizer.add_special_tokens(
+                {
+                    "pad_token": "<pad>",
+                }
+            )
 
     max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
     logger.info("Max sequence length is %s", max_seq_length)
@@ -228,20 +242,22 @@ def train(
             tokenizer.model_max_length,
         )
 
-    # TODO: we need to change this, perhaps follow what open instruct does?
+    # add special tokens only when a custom tokenizer is not passed
     special_tokens_dict = {}
-    if tokenizer.pad_token is None:
-        logger.warning("PAD token set to default, missing in tokenizer")
-        special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
-    if tokenizer.eos_token is None:
-        logger.warning("EOS token set to default, missing in tokenizer")
-        special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN
-    if tokenizer.bos_token is None:
-        logger.warning("BOS token set to default, missing in tokenizer")
-        special_tokens_dict["bos_token"] = configs.DEFAULT_BOS_TOKEN
-    if tokenizer.unk_token is None:
-        logger.warning("UNK token set to default, missing in tokenizer")
-        special_tokens_dict["unk_token"] = configs.DEFAULT_UNK_TOKEN
+    if not model_args.tokenizer_name_or_path:
+        # TODO: we need to change this, perhaps follow what open instruct does?
+        if tokenizer.pad_token is None:
+            logger.warning("PAD token set to default, missing in tokenizer")
+            special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
+        if tokenizer.eos_token is None:
+            logger.warning("EOS token set to default, missing in tokenizer")
+            special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN
+        if tokenizer.bos_token is None:
+            logger.warning("BOS token set to default, missing in tokenizer")
+            special_tokens_dict["bos_token"] = configs.DEFAULT_BOS_TOKEN
+        if tokenizer.unk_token is None:
+            logger.warning("UNK token set to default, missing in tokenizer")
+            special_tokens_dict["unk_token"] = configs.DEFAULT_UNK_TOKEN
 
     # TODO: lower priority but understand if resizing impacts inference quality and why its needed.
     # It makes sense if we manipulate tokenizer that we also save it and provide it to inference.
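
---
Illustrative sketch (not part of the diff above): a minimal standalone example of how the tokenizer fallback and the conditional special-token insertion introduced by this patch are meant to behave. The names model_name_or_path and tokenizer_name_or_path mirror the ModelArguments fields, and "my-org/my-model" is a hypothetical model id used only for illustration.

# Sketch: tokenizer resolution and conditional special-token handling.
from transformers import (
    AutoTokenizer,
    GPT2Tokenizer,
    GPTNeoXTokenizerFast,
    LlamaTokenizer,
    LlamaTokenizerFast,
)

model_name_or_path = "my-org/my-model"  # hypothetical placeholder model id
tokenizer_name_or_path = None           # set to a path/id to use a custom tokenizer

# Fall back to the model path when no custom tokenizer is provided.
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_name_or_path if tokenizer_name_or_path else model_name_or_path,
    use_fast=True,
)

# Special tokens are only injected for the default tokenizer; a custom
# tokenizer is taken as-is, so a missing pad token is expected to surface
# as a ValueError during prompt tuning (what the new
# test_run_causallm_pt_with_custom_tokenizer test asserts).
if not tokenizer_name_or_path:
    if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
        tokenizer.add_special_tokens(
            {
                "bos_token": "<s>",
                "eos_token": "</s>",
                "unk_token": "<unk>",
                "pad_token": "<pad>",
            }
        )
    elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
        tokenizer.add_special_tokens({"pad_token": "<pad>"})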