support datasets with input and output and no templates #2

Draft · wants to merge 3 commits into base: main
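What this enables, as a minimal usage sketch (not part of the diff; the checkpoint name is the small test model TRL's test suite uses, and the toy dataset is an illustrative assumption): a dataset with plain `input` / `output` columns can be passed to `SFTTrainer` with no `dataset_text_field`, `formatting_func`, or template, and the trainer should fall back to a `DataCollatorForSeq2Seq` and mask the prompt tokens in the labels.

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

# Toy instruction-style dataset with explicit input / output columns and no template.
train_dataset = Dataset.from_dict(
    {
        "input": ["What is the capital of France?", "2 + 2 ="],
        "output": ["Paris", "4"],
    }
)

model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"  # tiny checkpoint used in TRL tests
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # DataCollatorForSeq2Seq needs a pad token

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=TrainingArguments(output_dir="out", per_device_train_batch_size=2, max_steps=1, report_to="none"),
    train_dataset=train_dataset,
    dataset_text_field=None,  # intentionally omitted
    formatting_func=None,     # intentionally omitted
    max_seq_length=16,
    packing=False,            # the new path only applies to the non-packed case
)
trainer.train()
```

This mirrors the new test below; with neither `input`/`output` columns nor a `dataset_text_field`/`formatting_func`, construction is expected to raise a `KeyError` instead.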
40 changes: 39 additions & 1 deletion tests/test_sft_trainer.py
@@ -20,7 +20,7 @@
import pytest
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForSeq2Seq, TrainingArguments

from trl import SFTTrainer
from trl.import_utils import is_peft_available
@@ -486,6 +486,44 @@ def test_sft_trainer_with_model(self):

assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")

# Tests for no packing, with no formatting_func or dataset_text_field
# If no input/output columns exist, a KeyError should be raised
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
)
with pytest.raises(KeyError):
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
dataset_text_field=None,
formatting_func=None,
max_seq_length=16,
packing=False,
)

# If the dataset has input/output columns, training should work without issue
dataset_with_input_output = self.dummy_dataset.rename_column("question", "input").rename_column(
    "answer", "output"
)
trainer = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=dataset_with_input_output,
dataset_text_field=None,
formatting_func=None,
max_seq_length=16,
packing=False,
)
assert isinstance(trainer.data_collator, DataCollatorForSeq2Seq)
trainer.train()
assert trainer.state.log_history[-1]["train_loss"] is not None
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")

def test_sft_trainer_with_multiple_eval_datasets(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
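Reviewer note on the `DataCollatorForSeq2Seq` assertion above: that collator is used because it pads a pre-computed `labels` column with `-100`, so the prompt mask produced by the new tokenization path survives batching, whereas `DataCollatorForLanguageModeling` rebuilds labels from `input_ids` and would drop the mask. A small self-contained sketch of the padding behaviour (token ids are arbitrary; `gpt2` is just a convenient tokenizer):

```python
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)

# Two features of different lengths, shaped like the output of tokenize_input_output:
# the prompt positions are already masked with -100 in the labels.
features = [
    {"input_ids": [11, 22, 33], "attention_mask": [1, 1, 1], "labels": [-100, 22, 33]},
    {"input_ids": [11, 22], "attention_mask": [1, 1], "labels": [-100, 22]},
]

batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([2, 3])
print(batch["labels"])           # the shorter example's labels are padded with -100
```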
73 changes: 66 additions & 7 deletions trl/trainer/sft_trainer.py
@@ -28,6 +28,7 @@
AutoTokenizer,
DataCollator,
DataCollatorForLanguageModeling,
DataCollatorForSeq2Seq,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
@@ -169,6 +170,12 @@ def __init__(
"You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument."
)

# TODO: revisit this error handling and whether we want to enforce a seq2seq collator
if (
    not packing
    and formatting_func is None
    and dataset_text_field is None
    and data_collator is not None
    and not isinstance(data_collator, DataCollatorForSeq2Seq)
):
    raise ValueError(
        "If no `formatting_func` or `dataset_text_field` is provided, the `data_collator` should be a "
        "`DataCollatorForSeq2Seq` instance."
    )

if is_peft_available() and peft_config is not None:
if not isinstance(peft_config, PeftConfig):
raise ValueError(
@@ -245,14 +252,14 @@ def make_inputs_require_grad(module, input, output):
# if not, it stays None
formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)

requires_input_output_keys = False
if not packing:
    requires_input_output_keys = dataset_text_field is None and formatting_func is None
    if data_collator is None:
        # Fall back to the appropriate collator: a seq2seq collator when the dataset must
        # provide input/output keys, the standard language-modeling collator otherwise.
        data_collator = (
            DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
            if requires_input_output_keys
            else DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        )

# Pre-process the datasets only once per node. The remaining processes will use the cache.
with PartialState().local_main_process_first():
@@ -269,6 +276,7 @@ def make_inputs_require_grad(module, input, output):
num_of_sequences,
chars_per_token,
remove_unused_columns=args.remove_unused_columns if args is not None else True,
requires_input_output_keys=requires_input_output_keys,
**dataset_kwargs,
)
if eval_dataset is not None:
@@ -365,6 +373,7 @@ def _prepare_dataset(
num_of_sequences,
chars_per_token,
remove_unused_columns=True,
requires_input_output_keys=False,
append_concat_token=True,
add_special_tokens=True,
):
@@ -384,6 +393,7 @@ def _prepare_dataset(
formatting_func,
add_special_tokens,
remove_unused_columns,
requires_input_output_keys,
)

else:
@@ -408,10 +418,47 @@ def _prepare_non_packed_dataloader(
formatting_func=None,
add_special_tokens=True,
remove_unused_columns=True,
requires_input_output_keys=False,
):
use_formatting_func = formatting_func is not None and dataset_text_field is None
self._dataset_sanity_checked = False

# TODO : fix how EOS tokens are handled
# Inspired from https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L266
def tokenize_input_output(element):

# It is difficult to add special tokens here: separator / EOS tokens added when tokenizing the
# input text on its own may differ from those in the concatenated text, which would make the
# length-based masking of the input incorrect.
# EOS and BOS tokens can be added to the input / output texts beforehand by the user if needed.
# TODO: we may need to change the default of add_special_tokens to False.
if add_special_tokens:
    warnings.warn(
        "`add_special_tokens` is not supported for this data format and will be ignored."
    )

new_source = []
for (input_element, output_element) in zip(element['input'], element['output']):
if not input_element.endswith((' ', '\n', '\t')) and not output_element.startswith((' ', '\n', '\t')):
new_source.append(input_element + ' ' + output_element)
else:
new_source.append(input_element + output_element)

tokenized_example = tokenizer(new_source, max_length=max_seq_length, truncation=True, padding=False)
input_ids = tokenized_example.input_ids
labels = input_ids

# Mask the prompt tokens so they do not contribute to the loss
tokenized_prompt = tokenizer(element['input'], max_length=max_seq_length, truncation=True, padding=False)

new_labels = [
    [-100] * len(tokenized_instance) + label_instance[len(tokenized_instance):]
    for tokenized_instance, label_instance in zip(tokenized_prompt.input_ids, labels)
]
attention_mask = tokenized_example.attention_mask

return {
'input_ids': input_ids,
'labels': new_labels,
'attention_mask': attention_mask,
}

# Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
def tokenize(element):
outputs = tokenizer(
@@ -444,8 +491,20 @@ def tokenize(element):
f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns."
)

if requires_input_output_keys:
if "input" in dataset.column_names and "output" in dataset.column_names:
# TODO: if we take this input path, a seq2seq collator is expected. In that case the
# tokenizer must have a pad_token; it is set to the EOS token automatically when no
# tokenizer is provided, but we should properly handle a user-supplied tokenizer that
# has no padding token.
tokenize_func = tokenize_input_output
else:
raise KeyError(
    "The dataset must contain `input` and `output` columns when no `dataset_text_field` or `formatting_func` is provided."
)
else:
tokenize_func = tokenize

tokenized_dataset = dataset.map(
tokenize_func,
batched=True,
remove_columns=dataset.column_names if remove_unused_columns else None,
num_proc=self.dataset_num_proc,
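To make the masking in `tokenize_input_output` easier to review, here is a standalone sketch of the label construction (the helper name and token ids are made up). It also spells out the assumption the TODOs above point at: the tokenized prompt has to be a strict prefix of the tokenized prompt+completion for the length-based mask to line up.

```python
# Standalone sketch of the label-masking step in tokenize_input_output; token ids are invented.

def mask_prompt_labels(prompt_ids_batch, full_ids_batch):
    """Replicate: labels = [-100] * len(prompt_ids) + full_ids[len(prompt_ids):] per example."""
    return [
        [-100] * len(prompt_ids) + full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(prompt_ids_batch, full_ids_batch)
    ]

# "input" tokenized on its own vs. "input + output" tokenized together
prompt_ids_batch = [[101, 2054, 2003], [101, 2129]]
full_ids_batch = [[101, 2054, 2003, 3000, 102], [101, 2129, 2116, 102]]

print(mask_prompt_labels(prompt_ids_batch, full_ids_batch))
# [[-100, -100, -100, 3000, 102], [-100, -100, 2116, 102]]
```

Only the completion tokens contribute to the loss; if the prompt tokenization were not a prefix of the joint tokenization, the mask boundary would be off by the difference in length.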