Skip to content

Commit

Permalink
fix: allow for padding free + pretraining
Browse files Browse the repository at this point in the history
Signed-off-by: Harikrishnan Balagopal <[email protected]>
  • Loading branch information
HarikrishnanBalagopal committed Dec 24, 2024
1 parent 6f0c61d commit e0ac618
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,7 @@ class AttentionAndDistributedPackingConfig:
def __post_init__(self):
    """Post-construction hook run automatically after the dataclass __init__.

    Delegates to a project helper to make sure any nested dataclass fields
    are properly initialized (e.g. when they were supplied as raw dicts).
    NOTE(review): the helper is defined elsewhere in the project; its exact
    semantics are not visible in this view — confirm against its definition.
    """
    # ensure nested dataclasses initialized
    ensure_nested_dataclasses_initialized(self)

@property
def is_padding_free(self) -> bool:
    """Whether a padding-free configuration was provided.

    Returns:
        bool: True when ``self.padding_free`` is set (not None); False
        otherwise. Callers can use this instead of comparing the field
        to None directly.
    """
    return self.padding_free is not None
23 changes: 16 additions & 7 deletions tuning/data/setup_dataprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,21 @@ def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized):


### Data format 2
def _get_dataset_formatting_handlers(data_args, packing):
def _get_dataset_formatting_handlers(data_args, packing, padding_free=None):

if data_args.response_template is None:
if packing is False:
raise ValueError(
"Since dataset_text_field or data_formatter_template \
is provided and packing is disabled, \
needs a corresponding response template for masking"
)
if padding_free:
logger.debug(
"Assuming extended pretraining scenario because, packing is false"
+ ", padding_free is used and no response template was provided."
)
else:
raise ValueError(
"Since dataset_text_field or data_formatter_template \
is provided and packing is disabled, \
needs a corresponding response template for masking"
)

if data_args.response_template:
# To use Response template, pass datasets with single sequence instances \
Expand Down Expand Up @@ -209,6 +215,7 @@ def _process_raw_data_args(
packing: bool,
max_seq_length: int,
additional_data_handlers: Dict[str, Callable] = None,
**kwargs,
):

# Create a data processor with default processor config
Expand Down Expand Up @@ -266,7 +273,7 @@ def _process_raw_data_args(
elif data_args.data_formatter_template or data_args.dataset_text_field:
# Data Format 3: Single Sequence Dataset
handlers, dataset_text_field = _get_dataset_formatting_handlers(
data_args, packing
data_args, packing, **kwargs
)
else:
# Default Data Format: Dataset with Input/Output Fields
Expand Down Expand Up @@ -300,6 +307,7 @@ def process_dataargs(
tokenizer: AutoTokenizer,
train_args: TrainingArguments,
additional_data_handlers: Dict[str, Callable] = None,
**kwargs,
):
"""
Args:
Expand Down Expand Up @@ -345,6 +353,7 @@ def process_dataargs(
train_args.packing,
max_seq_length,
additional_data_handlers,
**kwargs,
)

# Note: This check should not be removed.
Expand Down
11 changes: 10 additions & 1 deletion tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ def train(
data_collator = None
logger.info("Packing is set to %s ", train_args.packing)

padding_free = None
if attention_and_distributed_packing_config is not None:
padding_free = attention_and_distributed_packing_config.padding_free
data_preprocessing_time = time.time()
(
formatted_train_dataset,
Expand All @@ -310,7 +313,13 @@ def train(
data_collator,
train_args.max_seq_length,
dataset_kwargs,
) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers)
) = process_dataargs(
data_args,
tokenizer,
train_args,
additional_data_handlers,
padding_free=padding_free,
)
additional_metrics["data_preprocessing_time"] = (
time.time() - data_preprocessing_time
)
Expand Down

0 comments on commit e0ac618

Please sign in to comment.