Commit 9612ac4

add data collator for padding free plugin scenario to be used for extended pretraining

Signed-off-by: Dushyant Behl <[email protected]>
dushyantbehl authored and kmehant committed Jan 9, 2025
1 parent e1b1ff0 commit 9612ac4
Showing 4 changed files with 78 additions and 24 deletions.
62 changes: 49 additions & 13 deletions tests/data/test_data_preprocessing_utils.py
@@ -489,7 +489,7 @@ def test_is_pretokenized_data(data, result):

@pytest.mark.parametrize(
"packing, response_template, formatted_train_dataset,\
max_seq_length, instruction_template, expected_collator",
max_seq_length, instruction_template, is_padding_free, expected_collator",
[
(
False,
@@ -501,6 +501,7 @@ def test_is_pretokenized_data(data, result):
),
1024,
None,
False,
DataCollatorForCompletionOnlyLM,
),
(
@@ -517,6 +518,7 @@ def test_is_pretokenized_data(data, result):
),
1024,
None,
False,
DataCollatorForSeq2Seq,
),
(
@@ -529,6 +531,7 @@ def test_is_pretokenized_data(data, result):
),
1024,
"\n### Text:",
False,
DataCollatorForCompletionOnlyLM,
),
(
@@ -545,6 +548,20 @@ def test_is_pretokenized_data(data, result):
),
1024,
"\n### Text:",
False,
DataCollatorForSeq2Seq,
),
(
False,
None,
datasets.load_dataset(
"json",
data_files=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
split="train",
),
1024,
None,
True,
DataCollatorForSeq2Seq,
),
],
@@ -555,6 +572,7 @@ def test_get_data_collator(
formatted_train_dataset,
max_seq_length,
instruction_template,
is_padding_free,
expected_collator,
):
"""Ensure that the correct collator type is fetched based on the data args"""
@@ -565,6 +583,7 @@ def test_get_data_collator(
is_pretokenized_dataset(formatted_train_dataset),
max_seq_length,
instruction_template,
is_padding_free,
)
assert isinstance(collator, expected_collator)

@@ -1044,7 +1063,7 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(


@pytest.mark.parametrize(
"data_args",
"data_args, is_padding_free",
[
# single sequence JSON and response template
(
@@ -1053,7 +1072,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_JSON,
dataset_text_field="output",
response_template="\n### Label:",
)
),
False,
),
# single sequence JSONL and response template
(
@@ -1062,7 +1082,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_JSONL,
dataset_text_field="output",
response_template="\n### Label:",
)
),
False,
),
# single sequence PARQUET and response template
(
@@ -1071,7 +1092,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_PARQUET,
dataset_text_field="output",
response_template="\n### Label:",
)
),
False,
),
# data formatter template with input/output JSON
(
@@ -1080,7 +1102,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
response_template="\n### Label:",
)
),
False,
),
# data formatter template with input/output JSONL
(
@@ -1089,7 +1112,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
response_template="\n### Label:",
)
),
False,
),
# data formatter template with input/output PARQUET
(
@@ -1098,32 +1122,44 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
response_template="\n### Label:",
)
),
False,
),
# input/output JSON with masking on input
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
)
),
False,
),
# input/output JSONL with masking on input
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
)
),
False,
),
# input/output PARQUET with masking on input
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
)
),
False,
),
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_JSON,
validation_data_path=TWITTER_COMPLAINTS_DATA_JSON,
dataset_text_field="output",
),
True,
),
],
)
def test_process_dataargs(data_args):
def test_process_dataargs(data_args, is_padding_free):
"""Ensure that the train/eval data are properly formatted based on the data args / text field"""
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
TRAIN_ARGS = configs.TrainingArguments(
@@ -1132,7 +1168,7 @@ def test_process_dataargs(data_args):
output_dir="tmp", # Not needed but positional
)
(train_set, eval_set, dataset_text_field, _, _, _) = process_dataargs(
data_args, tokenizer, TRAIN_ARGS
data_args, tokenizer, TRAIN_ARGS, is_padding_free=is_padding_free
)
assert isinstance(train_set, Dataset)
assert isinstance(eval_set, Dataset)
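For reference, the new parametrized case reduces to roughly the following standalone check. This is a hedged sketch: the positional order of the first arguments is inferred from the parametrize list above, and the data file path and tokenizer stand in for the test module's `TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON` and `MODEL_NAME` constants.

```python
import datasets
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

from tuning.data.data_preprocessing_utils import get_data_collator

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for MODEL_NAME

# Stand-in path for TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON
train_dataset = datasets.load_dataset(
    "json", data_files="twitter_complaints_input_output.json", split="train"
)

# Packing off, no response template, padding-free plugin on:
# the expected collator is DataCollatorForSeq2Seq.
collator = get_data_collator(
    False,      # packing
    None,       # response_template
    tokenizer,
    False,      # is_traindata_tokenized
    1024,       # max_seq_length
    None,       # instruction_template
    is_padding_free=True,
)
assert isinstance(collator, DataCollatorForSeq2Seq)
```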
13 changes: 13 additions & 0 deletions tuning/data/data_preprocessing_utils.py
@@ -29,6 +29,7 @@ def get_data_collator(
is_traindata_tokenized: bool,
max_seq_length: int,
instruction_template: Optional[str],
is_padding_free: bool = False,
) -> Callable:
"""Create and return the the appropriate collator type based on the configuration for packing,
response_template, and dataset_text_field.
@@ -46,6 +47,8 @@ def get_data_collator(
Max sequence length expected
instruction_template: str
str representing the human response in a chat template
is_padding_free: bool
Whether the padding-free plugin is enabled.
Returns:
Callable
@@ -74,6 +77,16 @@ def get_data_collator(
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)

if is_padding_free:
# When packing is false but padding_free is used and no response
# template is provided, this is an extended pretraining scenario.
# The current plugin in fms-acceleration is compatible with the
# `DataCollatorForSeq2Seq` collator, hence we use it.
return DataCollatorForSeq2Seq(
tokenizer=tokenizer, padding=False, max_length=max_seq_length
)

# Note that this automatically pads labels with -100
# TODO check if this is sufficient for preprocessed
if is_traindata_tokenized:
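As a minimal sketch of what the new branch returns (using only the `transformers` API shown in the diff), the padding-free path hands back a `DataCollatorForSeq2Seq` with padding disabled; handling the resulting unpadded sequences is left to the fms-acceleration plugin:

```python
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer

# Mirrors the branch above: with padding=False the collator leaves each
# feature at its own length instead of padding the batch to a common size;
# the padding-free plugin is then responsible for the unpadded sequences.
collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer, padding=False, max_length=1024
)
```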
20 changes: 12 additions & 8 deletions tuning/data/setup_dataprocessor.py
@@ -107,14 +107,14 @@ def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized):


### Data format 2
def _get_dataset_formatting_handlers(data_args, packing, padding_free=None):
def _get_dataset_formatting_handlers(data_args, packing, is_padding_free=False):

if data_args.response_template is None:
if packing is False:
if padding_free:
if is_padding_free:
logger.debug(
"Assuming extended pretraining scenario because, packing is false"
+ ", padding_free is used and no response template was provided."
"Assuming extended pretraining scenario because, packing is false,"
+ " padding_free plugin is used and no response template was provided."
)
else:
raise ValueError(
@@ -215,7 +215,7 @@ def _process_raw_data_args(
packing: bool,
max_seq_length: int,
additional_data_handlers: Dict[str, Callable] = None,
**kwargs,
is_padding_free: bool = False,
):

# Create a data processor with default processor config
@@ -255,6 +255,7 @@ def _process_raw_data_args(
tokenizer_kwargs = {}
tokenizer_kwargs["max_length"] = max_seq_length
tokenizer_kwargs["truncation"] = True
# Let's not pad in the tokenizer; we can handle that in the collator
tokenizer_kwargs["padding"] = False

handlers = None
@@ -273,7 +274,7 @@ def _process_raw_data_args(
elif data_args.data_formatter_template or data_args.dataset_text_field:
# Data Format 3: Single Sequence Dataset
handlers, dataset_text_field = _get_dataset_formatting_handlers(
data_args, packing, **kwargs
data_args, packing, is_padding_free
)
else:
# Default Data Format: Dataset with Input/Output Fields
@@ -307,7 +308,7 @@ def process_dataargs(
tokenizer: AutoTokenizer,
train_args: TrainingArguments,
additional_data_handlers: Dict[str, Callable] = None,
**kwargs,
is_padding_free: bool = False,
):
"""
Args:
@@ -318,6 +319,8 @@ def process_dataargs(
Used for packing and max_seq_length
additional_data_handlers: A Dict of [str, callable] data handlers
which need to be registered with the data preprocessor
is_padding_free: A bool indicating whether the padding-free plugin is enabled.
Defaults to False.
Returns:
Tuple(Dataset, Dataset, str, DataCollator, int, Dict)
tuple containing
@@ -353,7 +356,7 @@ def process_dataargs(
train_args.packing,
max_seq_length,
additional_data_handlers,
**kwargs,
is_padding_free,
)

# Note: This check should not be removed.
@@ -368,6 +371,7 @@ def process_dataargs(
is_tokenized_dataset,
max_seq_length,
data_args.instruction_template,
is_padding_free=is_padding_free,
)

dataset_kwargs = {}
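A hedged end-to-end sketch of the updated entry point, matching the extended-pretraining case added to the tests. The `configs` import path is assumed from how the tests reference `configs.DataArguments`, and the data path is a stand-in:

```python
from transformers import AutoTokenizer

from tuning.config import configs  # import path assumed
from tuning.data.setup_dataprocessor import process_dataargs

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model

data_args = configs.DataArguments(
    training_data_path="twitter_complaints.json",   # stand-in path
    validation_data_path="twitter_complaints.json",
    dataset_text_field="output",                    # note: no response_template
)
train_args = configs.TrainingArguments(
    packing=False,
    max_seq_length=1024,
    output_dir="tmp",  # not needed but positional
)

# With packing off, no response template, and is_padding_free=True, the
# pipeline assumes extended pretraining and, per the docstring's return
# tuple, hands back a no-padding DataCollatorForSeq2Seq as the collator.
train_set, eval_set, text_field, collator, max_len, ds_kwargs = process_dataargs(
    data_args, tokenizer, train_args, is_padding_free=True
)
```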
7 changes: 4 additions & 3 deletions tuning/sft_trainer.py
@@ -306,9 +306,10 @@ def train(
data_collator = None
logger.info("Packing is set to %s ", train_args.packing)

padding_free = None
is_padding_free = False
if attention_and_distributed_packing_config is not None:
padding_free = attention_and_distributed_packing_config.padding_free
is_padding_free = attention_and_distributed_packing_config.is_padding_free

data_preprocessing_time = time.time()
(
formatted_train_dataset,
Expand All @@ -322,7 +323,7 @@ def train(
tokenizer,
train_args,
additional_data_handlers,
padding_free=padding_free,
is_padding_free=is_padding_free,
)
additional_metrics["data_preprocessing_time"] = (
time.time() - data_preprocessing_time
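The `is_padding_free` property read here is defined on the acceleration config outside the four files in this diff; as a hedged reduction, the new wiring amounts to:

```python
# Illustrative reduction of the change in train(). The acceleration config
# object and its is_padding_free property come from code outside this diff.
def resolve_padding_free(attention_and_distributed_packing_config) -> bool:
    """Return True only when the padding-free acceleration plugin is enabled."""
    if attention_and_distributed_packing_config is None:
        return False
    return bool(attention_and_distributed_packing_config.is_padding_free)
```

The flag then flows as a plain boolean through `process_dataargs` into `get_data_collator`, replacing the earlier pattern of threading the raw `padding_free` sub-config through `**kwargs`.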
