From 9612ac41248eb7423e21788d998c6001c99e0fa0 Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Thu, 2 Jan 2025 15:16:12 +0530 Subject: [PATCH] add data collator for padding free plugin scenario to be used for extended pretraining Signed-off-by: Dushyant Behl --- tests/data/test_data_preprocessing_utils.py | 62 ++++++++++++++++----- tuning/data/data_preprocessing_utils.py | 13 +++++ tuning/data/setup_dataprocessor.py | 20 ++++--- tuning/sft_trainer.py | 7 ++- 4 files changed, 78 insertions(+), 24 deletions(-) diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index 578daffbf..8de5dfc36 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -489,7 +489,7 @@ def test_is_pretokenized_data(data, result): @pytest.mark.parametrize( "packing, response_template, formatted_train_dataset,\ - max_seq_length, instruction_template, expected_collator", + max_seq_length, instruction_template, is_padding_free, expected_collator", [ ( False, @@ -501,6 +501,7 @@ def test_is_pretokenized_data(data, result): ), 1024, None, + False, DataCollatorForCompletionOnlyLM, ), ( @@ -517,6 +518,7 @@ def test_is_pretokenized_data(data, result): ), 1024, None, + False, DataCollatorForSeq2Seq, ), ( @@ -529,6 +531,7 @@ def test_is_pretokenized_data(data, result): ), 1024, "\n### Text:", + False, DataCollatorForCompletionOnlyLM, ), ( @@ -545,6 +548,20 @@ def test_is_pretokenized_data(data, result): ), 1024, "\n### Text:", + False, + DataCollatorForSeq2Seq, + ), + ( + False, + None, + datasets.load_dataset( + "json", + data_files=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + split="train", + ), + 1024, + None, + True, DataCollatorForSeq2Seq, ), ], @@ -555,6 +572,7 @@ def test_get_data_collator( formatted_train_dataset, max_seq_length, instruction_template, + is_padding_free, expected_collator, ): """Ensure that the correct collator type is fetched based on the data args""" @@ -565,6 +583,7 @@ def test_get_data_collator( is_pretokenized_dataset(formatted_train_dataset), max_seq_length, instruction_template, + is_padding_free, ) assert isinstance(collator, expected_collator) @@ -1044,7 +1063,7 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( @pytest.mark.parametrize( - "data_args", + "data_args, is_padding_free", [ # single sequence JSON and response template ( @@ -1053,7 +1072,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( validation_data_path=TWITTER_COMPLAINTS_DATA_JSON, dataset_text_field="output", response_template="\n### Label:", - ) + ), + False, ), # single sequence JSONL and response template ( @@ -1062,7 +1082,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( validation_data_path=TWITTER_COMPLAINTS_DATA_JSONL, dataset_text_field="output", response_template="\n### Label:", - ) + ), + False, ), # single sequence PARQUET and response template ( @@ -1071,7 +1092,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( validation_data_path=TWITTER_COMPLAINTS_DATA_PARQUET, dataset_text_field="output", response_template="\n### Label:", - ) + ), + False, ), # data formatter template with input/output JSON ( @@ -1080,7 +1102,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}", response_template="\n### Label:", - ) + ), + False, ), # data formatter template with input/output JSONL ( @@ -1089,7 
+1112,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
                 validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
                 data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
                 response_template="\n### Label:",
-            )
+            ),
+            False,
         ),
         # data formatter template with input/output PARQUET
         (
@@ -1098,32 +1122,44 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
                 validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
                 data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
                 response_template="\n### Label:",
-            )
+            ),
+            False,
         ),
         # input/output JSON with masking on input
         (
             configs.DataArguments(
                 training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
                 validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
-            )
+            ),
+            False,
         ),
         # input/output JSONL with masking on input
         (
             configs.DataArguments(
                 training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
                 validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
-            )
+            ),
+            False,
         ),
         # input/output PARQUET with masking on input
         (
             configs.DataArguments(
                 training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
                 validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
-            )
+            ),
+            False,
+        ),
+        (
+            configs.DataArguments(
+                training_data_path=TWITTER_COMPLAINTS_DATA_JSON,
+                validation_data_path=TWITTER_COMPLAINTS_DATA_JSON,
+                dataset_text_field="output",
+            ),
+            True,
         ),
     ],
 )
-def test_process_dataargs(data_args):
+def test_process_dataargs(data_args, is_padding_free):
     """Ensure that the train/eval data are properly formatted based on the data args / text field"""
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     TRAIN_ARGS = configs.TrainingArguments(
@@ -1132,7 +1168,7 @@ def test_process_dataargs(data_args):
         output_dir="tmp",  # Not needed but positional
     )
     (train_set, eval_set, dataset_text_field, _, _, _) = process_dataargs(
-        data_args, tokenizer, TRAIN_ARGS
+        data_args, tokenizer, TRAIN_ARGS, is_padding_free=is_padding_free
     )
     assert isinstance(train_set, Dataset)
     assert isinstance(eval_set, Dataset)
diff --git a/tuning/data/data_preprocessing_utils.py b/tuning/data/data_preprocessing_utils.py
index 2c3386e34..b77fdba1d 100644
--- a/tuning/data/data_preprocessing_utils.py
+++ b/tuning/data/data_preprocessing_utils.py
@@ -29,6 +29,7 @@ def get_data_collator(
     is_traindata_tokenized: bool,
     max_seq_length: int,
     instruction_template: Optional[str],
+    is_padding_free: bool = False,
 ) -> Callable:
     """Create and return the the appropriate collator type based on the configuration
     for packing, response_template, and dataset_text_field.
@@ -46,6 +47,8 @@ def get_data_collator(
             Max sequence length expected
         instruction_template: str
             str representing the human response in a chat template
+        is_padding_free: bool
+            whether the padding-free plugin is used or not

     Returns:
         Callable
@@ -74,6 +77,16 @@ def get_data_collator(
             tokenizer=tokenizer,
             ignore_index=configs.IGNORE_INDEX,
         )
+
+    if is_padding_free:
+        # When packing is False but padding_free is used and no response
+        # template is provided, this is an extended pretraining scenario.
+        # The current plugin in fms-acceleration is compatible with the
+        # `DataCollatorForSeq2Seq` collator, hence we use it.
+        return DataCollatorForSeq2Seq(
+            tokenizer=tokenizer, padding=False, max_length=max_seq_length
+        )
+
     # Note that this automatically pads labels with -100
     # TODO check if this is sufficient for preprocessed
     if is_traindata_tokenized:
diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py
index 037a49630..f8d38a4e7 100644
--- a/tuning/data/setup_dataprocessor.py
+++ b/tuning/data/setup_dataprocessor.py
@@ -107,14 +107,14 @@ def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized):


 ### Data format 2
-def _get_dataset_formatting_handlers(data_args, packing, padding_free=None):
+def _get_dataset_formatting_handlers(data_args, packing, is_padding_free=False):

     if data_args.response_template is None:
         if packing is False:
-            if padding_free:
+            if is_padding_free:
                 logger.debug(
-                    "Assuming extended pretraining scenario because, packing is false"
-                    + ", padding_free is used and no response template was provided."
+                    "Assuming extended pretraining scenario because packing is false,"
+                    + " the padding_free plugin is used, and no response template was provided."
                 )
             else:
                 raise ValueError(
@@ -215,7 +215,7 @@ def _process_raw_data_args(
     packing: bool,
     max_seq_length: int,
     additional_data_handlers: Dict[str, Callable] = None,
-    **kwargs,
+    is_padding_free: bool = False,
 ):

     # Create a data processor with default processor config
@@ -255,6 +255,7 @@ def _process_raw_data_args(
     tokenizer_kwargs = {}
     tokenizer_kwargs["max_length"] = max_seq_length
     tokenizer_kwargs["truncation"] = True
+    # Let's not pad in the tokenizer; we can handle that in the collator
     tokenizer_kwargs["padding"] = False

     handlers = None
@@ -273,7 +274,7 @@ def _process_raw_data_args(
     elif data_args.data_formatter_template or data_args.dataset_text_field:
         # Data Format 3: Single Sequence Dataset
         handlers, dataset_text_field = _get_dataset_formatting_handlers(
-            data_args, packing, **kwargs
+            data_args, packing, is_padding_free
         )
     else:
         # Default Data Format: Dataset with Input/Output Fields
@@ -307,7 +308,7 @@ def process_dataargs(
     tokenizer: AutoTokenizer,
     train_args: TrainingArguments,
     additional_data_handlers: Dict[str, Callable] = None,
-    **kwargs,
+    is_padding_free: bool = False,
 ):
     """
     Args:
@@ -318,6 +319,8 @@ def process_dataargs(
                     Used for packing and max_seq_length
         additional_data_handlers: A Dict of [str, callable] data handlers which need
                                   to be registered with the data preprocessor
+        is_padding_free: A bool indicating whether the padding-free plugin is enabled.
+                         Defaults to False.
     Returns:
         Tuple(Dataset, Dataset, str, DataCollator, int, Dict)
             tuple containing
@@ -353,7 +356,7 @@ def process_dataargs(
         train_args.packing,
         max_seq_length,
         additional_data_handlers,
-        **kwargs,
+        is_padding_free,
     )

     # Note: This check should not be removed.
@@ -368,6 +371,7 @@ def process_dataargs(
         is_tokenized_dataset,
         max_seq_length,
         data_args.instruction_template,
+        is_padding_free=is_padding_free,
     )

     dataset_kwargs = {}
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index fee080d68..32b8735cb 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -306,9 +306,10 @@ def train(
     data_collator = None
     logger.info("Packing is set to %s ", train_args.packing)

-    padding_free = None
+    is_padding_free = False
     if attention_and_distributed_packing_config is not None:
-        padding_free = attention_and_distributed_packing_config.padding_free
+        is_padding_free = attention_and_distributed_packing_config.is_padding_free
+
     data_preprocessing_time = time.time()
     (
         formatted_train_dataset,
@@ -322,7 +323,7 @@ def train(
         tokenizer,
         train_args,
         additional_data_handlers,
-        padding_free=padding_free,
+        is_padding_free=is_padding_free,
     )
     additional_metrics["data_preprocessing_time"] = (
         time.time() - data_preprocessing_time
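
Below is a minimal usage sketch of the new flag, based only on the signatures visible in this patch; it mirrors the new unit-test case. With packing disabled, no response template, and the padding-free plugin enabled, get_data_collator is expected to return a plain DataCollatorForSeq2Seq. The model name is a placeholder, and the positional argument order is assumed from the existing call site in the tests.

    # Minimal sketch: extended pretraining with the padding-free plugin enabled.
    # Assumes the get_data_collator signature shown in this patch; the model
    # name below is only a placeholder.
    from transformers import AutoTokenizer, DataCollatorForSeq2Seq

    from tuning.data.data_preprocessing_utils import get_data_collator

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # placeholder model

    collator = get_data_collator(
        False,      # packing
        None,       # response_template
        tokenizer,
        False,      # is_traindata_tokenized
        1024,       # max_seq_length
        None,       # instruction_template
        True,       # is_padding_free
    )

    # Per the patch, this collator is built with padding=False; the
    # fms-acceleration padding-free plugin then avoids padding at attention time.
    assert isinstance(collator, DataCollatorForSeq2Seq)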