Add formatting function alpaca #161

Merged: 17 commits, merged on May 29, 2024
Changes from 3 commits
61 changes: 60 additions & 1 deletion tests/test_sft_trainer.py
@@ -63,7 +63,6 @@
    "save_strategy": "epoch",
    "output_dir": "tmp",
}

BASE_LORA_KWARGS = copy.deepcopy(BASE_PEFT_KWARGS)
BASE_LORA_KWARGS["peft_method"] = "lora"

@@ -148,6 +147,40 @@ def test_run_causallm_pt_and_inference():
    assert len(output_inference) > 0
    assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference

def test_run_causallm_pt_and_inference_with_formatting_data():
    """Check if we can bootstrap and peft tune causallm models.
    This test needs the trainer to format data to a single sequence internally.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        data_formatting_args = copy.deepcopy(BASE_PEFT_KWARGS)
        del data_formatting_args["dataset_text_field"]
        data_formatting_args["data_formatter_template"] = (
            "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
        )

        TRAIN_KWARGS = {**data_formatting_args, **{"output_dir": tempdir}}

        model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
            TRAIN_KWARGS
        )
        # Just double checking that the formatter is set
        assert data_args.data_formatter_template is not None

        sft_trainer.train(model_args, data_args, training_args, tune_config)

        # validate peft tuning configs
        _validate_training(tempdir)
        checkpoint_path = _get_checkpoint_path(tempdir)
        adapter_config = _get_adapter_config(checkpoint_path)
        _validate_adapter_config(adapter_config, "PROMPT_TUNING", data_formatting_args)

        # Load the model
        loaded_model = TunedCausalLM.load(checkpoint_path)

        # Run inference on the text
        output_inference = loaded_model.run(
            "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
        )
        assert len(output_inference) > 0
        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference

def test_run_causallm_pt_init_text():
    """Check if we can bootstrap and peft tune causallm models with init text as 'TEXT'"""

@@ -381,6 +414,32 @@ def test_invalid_dataset_text_field():
    with pytest.raises(KeyError):
        sft_trainer.train(model_args, data_args, training_args, tune_config)

### Tests that supplying both dataset_text_field and a formatter template gives an error
def test_invalid_dataset_text_field_and_formatter_template():
    """Only one of dataset_text_field or data_formatter_template can be supplied"""
    TRAIN_KWARGS = {
        **BASE_PEFT_KWARGS,
        **{"data_formatter_template": "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"},
    }
    model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
        TRAIN_KWARGS
    )
    with pytest.raises(ValueError):
        sft_trainer.train(model_args, data_args, training_args, tune_config)


### Tests that passing a formatter template with invalid keys gives an error
def test_invalid_formatter_template():
    data_formatting_args = copy.deepcopy(BASE_PEFT_KWARGS)
    del data_formatting_args["dataset_text_field"]
    TRAIN_KWARGS = {
        **data_formatting_args,
        **{"data_formatter_template": "### Input: {{not found}} \n\n ### Response: {{text_label}}"},
    }
    model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
        TRAIN_KWARGS
    )
    with pytest.raises(KeyError):
        sft_trainer.train(model_args, data_args, training_args, tune_config)

### Tests for bad training data (i.e., data_path is an unhappy value or points to an unhappy thing)
def test_malformatted_data():
51 changes: 51 additions & 0 deletions tests/utils/test_data_utils.py
@@ -0,0 +1,51 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# First Party
from tests.data import TWITTER_COMPLAINTS_DATA
from tuning.utils import data_utils

# Third Party
import datasets
import pytest

def test_formatting_function():
    json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA)
    template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
    # First response from the data file that is read.
    expected_response = "### Input: @HMRCcustomers No this is my first job \n\n ### Response: no complaint"
    formatted_dataset, dataset_text_field = data_utils.formatting_function(json_dataset, template)
    # A new dataset_text_field is created in the Dataset
    assert dataset_text_field in formatted_dataset["train"][0]
    assert formatted_dataset["train"][0][dataset_text_field] == expected_response


def test_formatting_function_adds_eos_token():
    json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA)
    template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
    # First response from the data file that is read.
    expected_response = "### Input: @HMRCcustomers No this is my first job \n\n ### Response: no complaintEOS"
    formatted_dataset, dataset_text_field = data_utils.formatting_function(json_dataset, template, "EOS")
    # A new dataset_text_field is created in the Dataset
    assert dataset_text_field in formatted_dataset["train"][0]
    assert formatted_dataset["train"][0][dataset_text_field] == expected_response


def test_formatting_function_gives_error_with_wrong_keys():
    """Tests that the formatting function throws an error if wrong keys are passed to the template"""
    json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA)
    template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
    with pytest.raises(KeyError):
        data_utils.formatting_function(json_dataset, template, "EOS")
9 changes: 8 additions & 1 deletion tuning/config/configs.py
@@ -50,11 +50,18 @@ class DataArguments:
        metadata={"help": "Response template, separator to train on completions only"},
    )
    dataset_text_field: str = field(
        default=None,
        metadata={
            "help": "Training dataset text field containing a single sequence. \
                Either dataset_text_field or data_formatter_template needs to be supplied."
        },
    )
    validation_data_path: str = field(
        default=None, metadata={"help": "Path to the validation data in JSONL format."}
    )
    data_formatter_template: str = field(
        default=None,
        metadata={
            "help": "Formatter template to format a single sequence from each instance in JSONL files. \
                Keys of the JSON can be referred to as {{key}} in the template. Either dataset_text_field \
                or data_formatter_template needs to be supplied."
        },
    )


@dataclass
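For illustration, a minimal sketch of how the two mutually exclusive fields might be populated in code (the data path and column name are hypothetical; field availability follows this diff, not a released API):

# Sketch: exactly one of dataset_text_field / data_formatter_template is set.
from tuning.config import configs

# Template-driven formatting; dataset_text_field stays unset.
data_args = configs.DataArguments(
    training_data_path="twitter_complaints.json",  # hypothetical path
    data_formatter_template="### Input: {{Tweet text}} \n\n ### Response: {{text_label}}",
)

# Pre-formatted data; the full sequence already lives in one column.
data_args_plain = configs.DataArguments(
    training_data_path="twitter_complaints.json",  # hypothetical path
    dataset_text_field="output",  # hypothetical column name
)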
24 changes: 20 additions & 4 deletions tuning/sft_trainer.py
@@ -46,6 +46,7 @@
from tuning.trainercontroller import TrainerControllerCallback
from tuning.utils.config_utils import get_hf_peft_config
from tuning.utils.data_type_utils import get_torch_dtype
from tuning.utils.data_utils import formatting_function


def train(
@@ -218,15 +219,22 @@ def train(
        # TODO: Fix this, currently unreachable due to crashing in batch encoding tokenization
        # We should do this validation up front, then do the encoding, then handle the collator
        raise ValueError("Response template is None, needs to be set for training")
    data_collator = DataCollatorForCompletionOnlyLM(
        response_template_ids,
        tokenizer=tokenizer,
        ignore_index=configs.IGNORE_INDEX,
    )
    packing = False

    # Currently we support formatted datasets with single-sequence instances.
    if (data_args.dataset_text_field is None) and (data_args.data_formatter_template is None):
        raise ValueError(
            "dataset_text_field and data_formatter_template are None. "
            "One of them needs to be set for training."
        )
    # Only one of dataset_text_field or data_formatter_template should be set.
    if data_args.dataset_text_field and data_args.data_formatter_template:
        raise ValueError(
            "dataset_text_field and data_formatter_template are both set. "
            "Only one of them should be set for training."
        )

    # load the data by parsing JSON
    data_files = {"train": data_args.training_data_path}
    if data_args.validation_data_path:
@@ -238,12 +246,20 @@
    }

    json_dataset = datasets.load_dataset("json", data_files=data_files)
    if data_args.data_formatter_template:
        formatted_train_dataset, data_args.dataset_text_field = formatting_function(
            json_dataset["train"], data_args.data_formatter_template, tokenizer.eos_token
        )
    else:
        formatted_train_dataset = json_dataset["train"].map(format_dataset)
    logger.info("Training dataset length is %s", len(formatted_train_dataset))

    formatted_validation_dataset = None
    if data_args.validation_data_path:
        if data_args.data_formatter_template:
            formatted_validation_dataset, data_args.dataset_text_field = formatting_function(
                json_dataset["validation"],
                data_args.data_formatter_template,
                tokenizer.eos_token,
            )
        else:
            formatted_validation_dataset = json_dataset["validation"].map(format_dataset)
        logger.info(
            "Validation dataset length is %s", len(formatted_validation_dataset)
        )
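To make the new branch concrete, here is a hedged sketch of what formatting_function produces when it is handed tokenizer.eos_token, mirroring the call above (the gpt2 tokenizer and the toy data row are assumptions chosen only for illustration):

# Sketch: each formatted row ends with the tokenizer's EOS token, so training
# examples terminate cleanly. The tokenizer choice and data row are toy values.
import datasets
from transformers import AutoTokenizer
from tuning.utils.data_utils import formatting_function

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical model choice
train_split = datasets.Dataset.from_list(
    [{"Tweet text": "@HMRCcustomers No this is my first job", "text_label": "no complaint"}]
)
formatted, field = formatting_function(
    train_split,
    "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}",
    tokenizer.eos_token,
)
assert formatted[0][field].endswith(tokenizer.eos_token)  # "<|endoftext|>" for gpt2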
34 changes: 34 additions & 0 deletions tuning/utils/data_utils.py
@@ -0,0 +1,34 @@
# Standard
import re

# Third Party
from datasets import Dataset


def formatting_function(dataset, template, eos_token=""):
    """Function to format datasets with Alpaca style / other templates.
    Args:
        dataset: the HF Dataset element loaded from a JSON file, or a DatasetDict object.
        template: Template to format the data with. Features of the Dataset
            should be referred to by {{key}}.
        eos_token: EOS token to append to each formatted sequence.
    Returns:
        Formatted HF Dataset, and the name of the dataset field that contains the formatted data.
    """

    formatted_dataset_field = "formatted_data_field"
    template += eos_token

    def formatter(element):
        def replace_text(match_obj):
            captured_groups = match_obj.groups()
            if len(captured_groups) != 1:
                raise ValueError(
                    "Unexpectedly captured multiple groups in verbalizer rendering"
                )

            index_object = captured_groups[0]
            if index_object not in element:
                raise KeyError("Requested template string is not a valid key in dict")

            return element[index_object]

        return {
            formatted_dataset_field: re.sub(
                r"{{([_a-z\sA-Z0-9]+)}}", replace_text, template
            )
        }

    return dataset.map(formatter), formatted_dataset_field
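A minimal usage sketch of this helper on an in-memory dataset (toy rows, no files required), matching the behavior the tests above assert:

# Sketch: render the template per row and append a custom EOS marker.
import datasets
from tuning.utils.data_utils import formatting_function

ds = datasets.Dataset.from_list(
    [{"Tweet text": "@HMRCcustomers No this is my first job", "text_label": "no complaint"}]
)
formatted, field = formatting_function(
    ds, "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}", "EOS"
)
assert formatted[0][field] == (
    "### Input: @HMRCcustomers No this is my first job \n\n ### Response: no complaintEOS"
)
# An unknown {{key}} raises KeyError, as test_formatting_function_gives_error_with_wrong_keys checks.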