diff --git a/tuning/config/acceleration_configs/attention_and_distributed_packing.py b/tuning/config/acceleration_configs/attention_and_distributed_packing.py index e1ed83a58..d522d9826 100644 --- a/tuning/config/acceleration_configs/attention_and_distributed_packing.py +++ b/tuning/config/acceleration_configs/attention_and_distributed_packing.py @@ -33,3 +33,7 @@ class AttentionAndDistributedPackingConfig: def __post_init__(self): # ensure nested dataclasses initialized ensure_nested_dataclasses_initialized(self) + + @property + def is_padding_free(self): + return self.padding_free is not None diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py index b6f09c323..037a49630 100644 --- a/tuning/data/setup_dataprocessor.py +++ b/tuning/data/setup_dataprocessor.py @@ -107,15 +107,21 @@ def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized): ### Data format 2 -def _get_dataset_formatting_handlers(data_args, packing): +def _get_dataset_formatting_handlers(data_args, packing, padding_free=None): if data_args.response_template is None: if packing is False: - raise ValueError( - "Since dataset_text_field or data_formatter_template \ - is provided and packing is disabled, \ - needs a corresponding response template for masking" - ) + if padding_free: + logger.debug( + "Assuming extended pretraining scenario because, packing is false" + + ", padding_free is used and no response template was provided." + ) + else: + raise ValueError( + "Since dataset_text_field or data_formatter_template \ + is provided and packing is disabled, \ + needs a corresponding response template for masking" + ) if data_args.response_template: # To use Response template, pass datasets with single sequence instances \ @@ -209,6 +215,7 @@ def _process_raw_data_args( packing: bool, max_seq_length: int, additional_data_handlers: Dict[str, Callable] = None, + **kwargs, ): # Create a data processor with default processor config @@ -266,7 +273,7 @@ def _process_raw_data_args( elif data_args.data_formatter_template or data_args.dataset_text_field: # Data Format 3: Single Sequence Dataset handlers, dataset_text_field = _get_dataset_formatting_handlers( - data_args, packing + data_args, packing, **kwargs ) else: # Default Data Format: Dataset with Input/Output Fields @@ -300,6 +307,7 @@ def process_dataargs( tokenizer: AutoTokenizer, train_args: TrainingArguments, additional_data_handlers: Dict[str, Callable] = None, + **kwargs, ): """ Args: @@ -345,6 +353,7 @@ def process_dataargs( train_args.packing, max_seq_length, additional_data_handlers, + **kwargs, ) # Note: This check should not be removed. diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 2afdd2dac..b116e08d8 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -302,6 +302,9 @@ def train( data_collator = None logger.info("Packing is set to %s ", train_args.packing) + padding_free = None + if attention_and_distributed_packing_config is not None: + padding_free = attention_and_distributed_packing_config.padding_free data_preprocessing_time = time.time() ( formatted_train_dataset, @@ -310,7 +313,13 @@ def train( data_collator, train_args.max_seq_length, dataset_kwargs, - ) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers) + ) = process_dataargs( + data_args, + tokenizer, + train_args, + additional_data_handlers, + padding_free=padding_free, + ) additional_metrics["data_preprocessing_time"] = ( time.time() - data_preprocessing_time )