Skip to content

Commit

Permalink
fix: allow for padding free + pretraining
Browse files Browse the repository at this point in the history
Signed-off-by: Harikrishnan Balagopal <[email protected]>
  • Loading branch information
HarikrishnanBalagopal committed Dec 20, 2024
1 parent d7f06f5 commit af1961c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
27 changes: 20 additions & 7 deletions tuning/data/setup_dataprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,21 @@ def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized):


### Data format 2
def _get_dataset_formatting_handlers(data_args, packing):
def _get_dataset_formatting_handlers(data_args, packing, padding_free: str = ""):

if data_args.response_template is None:
if packing is False:
raise ValueError(
"Since dataset_text_field or data_formatter_template \
is provided and packing is disabled, \
needs a corresponding response template for masking"
)
if padding_free:
logger.debug(
"when packing is false but padding_free is used and"
+ " no response template is used then its a pretrained scenario."
)
else:
raise ValueError(
"Since dataset_text_field or data_formatter_template \
is provided and packing is disabled, \
needs a corresponding response template for masking"
)

if data_args.response_template:
# To use Response template, pass datasets with single sequence instances \
Expand Down Expand Up @@ -209,6 +215,7 @@ def _process_raw_data_args(
packing: bool,
max_seq_length: int,
additional_data_handlers: Dict[str, Callable] = None,
padding_free: str = "",
):

# Create a data processor with default processor config
Expand Down Expand Up @@ -266,7 +273,9 @@ def _process_raw_data_args(
elif data_args.data_formatter_template or data_args.dataset_text_field:
# Data Format 3: Single Sequence Dataset
handlers, dataset_text_field = _get_dataset_formatting_handlers(
data_args, packing
data_args,
packing,
padding_free=padding_free,
)
else:
# Default Data Format: Dataset with Input/Output Fields
Expand Down Expand Up @@ -300,6 +309,7 @@ def process_dataargs(
tokenizer: AutoTokenizer,
train_args: TrainingArguments,
additional_data_handlers: Dict[str, Callable] = None,
padding_free: str = "",
):
"""
Args:
Expand All @@ -310,6 +320,8 @@ def process_dataargs(
Used for packing and max_seq_length
additional_data_handlers: A Dict of [str, callable] data handlers
which need to be registered with the data preprocessor
padding_free: str
padding free method
Returns:
Tuple(Dataset, Dataset, str, DataCollator, int, Dict)
tuple containing
Expand Down Expand Up @@ -345,6 +357,7 @@ def process_dataargs(
train_args.packing,
max_seq_length,
additional_data_handlers,
padding_free=padding_free,
)

# Note: This check should not be removed.
Expand Down
12 changes: 11 additions & 1 deletion tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,14 +303,24 @@ def train(
logger.info("Packing is set to %s ", train_args.packing)

data_preprocessing_time = time.time()
padding_free = ""
if attention_and_distributed_packing_config:
if attention_and_distributed_packing_config.padding_free:
padding_free = attention_and_distributed_packing_config.padding_free
(
formatted_train_dataset,
formatted_validation_dataset,
data_args.dataset_text_field,
data_collator,
train_args.max_seq_length,
dataset_kwargs,
) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers)
) = process_dataargs(
data_args,
tokenizer,
train_args,
additional_data_handlers,
padding_free=padding_free,
)
additional_metrics["data_preprocessing_time"] = (
time.time() - data_preprocessing_time
)
Expand Down

0 comments on commit af1961c

Please sign in to comment.