Add formatting function alpaca #161

Merged · 17 commits · May 29, 2024
Changes from 14 commits
59 changes: 58 additions & 1 deletion README.md
@@ -24,7 +24,10 @@ pip install -e ".[aim]"
```

## Data format
We support two data formats:

1. #### Pre-process the JSON/JSONL dataset
Pre-process the JSON/JSONL dataset so that each data instance contains a single sequence combining the input and the response. The trainer is configured to expect a response template as a string. For example, `alpaca`-format data can be prepared for this trainer with the following code.

```python
PROMPT_DICT = {
@@ -56,6 +59,23 @@ The `response template` corresponding to the above dataset and the `Llama` token

The same approach can be applied to any dataset; more information can be found [here](https://huggingface.co/docs/trl/main/en/sft_trainer#format-your-input-prompts).

Once the JSON is converted using the formatting function, pass the `dataset_text_field` containing the single sequence to the trainer.
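
As an illustration, a minimal pre-processing sketch might look like the following (assuming the standard `alpaca` schema with `instruction`, `input`, and `output` fields; the `preprocess` helper, the prompt string, and the file names are illustrative, not part of this PR):

```python
import json

# Stand-in prompt; in practice, build it from the PROMPT_DICT entries above.
PROMPT = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"

def preprocess(input_path: str, output_path: str) -> None:
    """Collapse each alpaca record into a single 'output' sequence."""
    with open(input_path, "r") as f:
        records = json.load(f)
    formatted = [
        {"output": PROMPT.format(**rec) + " " + rec["output"]} for rec in records
    ]
    with open(output_path, "w") as f:
        json.dump(formatted, f)

preprocess("alpaca_data.json", "alpaca_data_formatted.json")
```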

2. #### Format JSON/JSONL on the fly
Pass a JSON/JSONL file and a `data_formatter_template` to apply the formatting function on the fly while tuning. The template should reference JSON fields as `{{field}}`. During tuning, the data will be converted to a single sequence using the template.

Example: Train.json

```json
[
  {
    "input": "<text>",
    "output": "<text>"
  },
  ...
]
```

data_formatter_template: `### Input: {{input}} \n\n## Label: {{output}}`

Formatting happens on the fly during tuning. The keys in the template must match fields in the JSON file. The `response template` corresponding to the above template must also be supplied; in this case, `response template` = `\n## Label:`.
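
Conceptually, the on-the-fly substitution behaves like the sketch below; the `render` helper is illustrative rather than the trainer's actual implementation, but a missing key raises a `KeyError`, matching what the new tests in this PR expect:

```python
import re

def render(template: str, record: dict) -> str:
    """Replace each {{key}} placeholder with the matching field from the record."""
    return re.sub(
        r"\{\{([^}]+)\}\}",
        lambda m: str(record[m.group(1).strip()]),  # KeyError if the key is missing
        template,
    )

record = {"input": "hello", "output": "world"}
print(repr(render("### Input: {{input}} \n\n## Label: {{output}}", record)))
# -> '### Input: hello \n\n## Label: world'
```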


##### In conclusion, either the `data_formatter_template` argument or `dataset_text_field` needs to be supplied to the trainer.

## Supported Models

@@ -64,12 +84,16 @@ Currently supported and tested models are `Llama2` (7 and 13B configurations have
## Training

### Single GPU

1. Using a pre-processed dataset for training.

```bash
# if you want to use one GPU on multi-gpu machine
export CUDA_VISIBLE_DEVICES=0

# MODEL_PATH=meta-llama/Llama-2-7b-hf # Huggingface model id or path to a checkpoint
# TRAIN_DATA_PATH=twitter_complaints.json # Path to the dataset
# contains data in single sequence {"output": "### Input: text \n\n### Response: text"}
# OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved

python tuning/sft_trainer.py \
@@ -94,6 +118,39 @@ python tuning/sft_trainer.py \

```

2. Using a formatter with JSON/JSONL files

```bash
# if you want to use one GPU on multi-gpu machine
export CUDA_VISIBLE_DEVICES=0

# MODEL_PATH=meta-llama/Llama-2-7b-hf # Huggingface model id or path to a checkpoint
# TRAIN_DATA_PATH=twitter_complaints.json # Path to the dataset
# contains data in the form [{"input": text, "output": text}]
# OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved

python tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
--training_data_path $TRAIN_DATA_PATH \
--output_dir $OUTPUT_PATH \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--eval_strategy "no" \
--save_strategy "epoch" \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--include_tokens_per_second \
--packing False \
--response_template "\n## Label:" \
--data_formatter_template "### Input: {{input}} \n\n## Label: {{output}}"

```

### Multiple GPUs with FSDP

The recommendation is to use [huggingface accelerate](https://huggingface.co/docs/accelerate/en/index) to launch multi-gpu jobs, in particular when using FSDP:
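
The full launch example is collapsed in this diff; a representative invocation might look like the sketch below (the process count and flags are assumptions, not the PR's exact command):

```bash
# minimal sketch of an accelerate-launched FSDP run; adjust flags to your setup
accelerate launch \
    --num_processes 4 \
    --use_fsdp \
    tuning/sft_trainer.py \
    --model_name_or_path $MODEL_PATH \
    --training_data_path $TRAIN_DATA_PATH \
    --output_dir $OUTPUT_PATH \
    --response_template "\n## Label:" \
    --data_formatter_template "### Input: {{input}} \n\n## Label: {{output}}"
```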
1 change: 1 addition & 0 deletions tests/data/__init__.py
@@ -20,5 +20,6 @@
### Constants used for data
DATA_DIR = os.path.join(os.path.dirname(__file__))
TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json")
TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join(DATA_DIR, "twitter_complaints_json.json")
EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json")
MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json")
12 changes: 12 additions & 0 deletions tests/data/twitter_complaints_json.json
@@ -0,0 +1,12 @@
[
{"Tweet text":"@HMRCcustomers No this is my first job","ID":0,"Label":2,"text_label":"no complaint","output":"### Text: @HMRCcustomers No this is my first job\n\n### Label: no complaint"},
{"Tweet text":"@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.","ID":1,"Label":2,"text_label":"no complaint","output":"### Text: @KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.\n\n### Label: no complaint"},
{"Tweet text":"If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService","ID":2,"Label":1,"text_label":"complaint","output":"### Text: If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService\n\n### Label: complaint"},
{"Tweet text":"@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.","ID":3,"Label":1,"text_label":"complaint","output":"### Text: @EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.\n\n### Label: complaint"},
{"Tweet text":"Couples wallpaper, so cute. :) #BrothersAtHome","ID":4,"Label":2,"text_label":"no complaint","output":"### Text: Couples wallpaper, so cute. :) #BrothersAtHome\n\n### Label: no complaint"},
{"Tweet text":"@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG","ID":5,"Label":2,"text_label":"no complaint","output":"### Text: @mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG\n\n### Label: no complaint"},
{"Tweet text":"@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?","ID":6,"Label":2,"text_label":"no complaint","output":"### Text: @Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?\n\n### Label: no complaint"},
{"Tweet text":"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?","ID":7,"Label":1,"text_label":"complaint","output":"### Text: @nationalgridus I have no water and the bill is current and paid. Can you do something about this?\n\n### Label: complaint"},
{"Tweet text":"Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora","ID":8,"Label":1,"text_label":"complaint","output":"### Text: Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora\n\n### Label: complaint"},
{"Tweet text":"@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd","ID":9,"Label":2,"text_label":"no complaint","output":"### Text: @JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd\n\n### Label: no complaint"}
]
116 changes: 114 additions & 2 deletions tests/test_sft_trainer.py
@@ -29,7 +29,12 @@

# First Party
from scripts.run_inference import TunedCausalLM
from tests.data import (
    EMPTY_DATA,
    MALFORMATTED_DATA,
    TWITTER_COMPLAINTS_DATA,
    TWITTER_COMPLAINTS_JSON_FORMAT,
)

# Local
from tuning import sft_trainer
@@ -114,6 +119,70 @@ def test_run_causallm_pt_and_inference():
        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference


def test_run_causallm_pt_and_inference_with_formatting_data():
    """Check if we can bootstrap and peft tune causallm models.
    This test needs the trainer to format data to a single sequence internally.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        data_formatting_args = copy.deepcopy(DATA_ARGS)
        data_formatting_args.dataset_text_field = None
        data_formatting_args.data_formatter_template = (
            "### Text: {{Tweet text}} \n\n### Label: {{text_label}}"
        )

        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir

        sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args, PEFT_PT_ARGS)

        # validate peft tuning configs
        _validate_training(tempdir)
        checkpoint_path = _get_checkpoint_path(tempdir)
        adapter_config = _get_adapter_config(checkpoint_path)
        _validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS)

        # Load the model
        loaded_model = TunedCausalLM.load(checkpoint_path)

        # Run inference on the text
        output_inference = loaded_model.run(
            "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
        )
        assert len(output_inference) > 0
        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference


def test_run_causallm_pt_and_inference_JSON_file_formatter():
    """Check if we can bootstrap and peft tune causallm models with JSON train file format"""
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir
        data_args = copy.deepcopy(DATA_ARGS)
        data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT
        data_args.dataset_text_field = None
        data_args.data_formatter_template = (
            "### Text: {{Tweet text}} \n\n### Label: {{text_label}}"
        )

        sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)

        # validate peft tuning configs
        _validate_training(tempdir)
        checkpoint_path = _get_checkpoint_path(tempdir)
        adapter_config = _get_adapter_config(checkpoint_path)
        _validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS)

        # Load the model
        loaded_model = TunedCausalLM.load(checkpoint_path)

        # Run inference on the text
        output_inference = loaded_model.run(
            "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
        )
        assert len(output_inference) > 0
        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference


def test_run_causallm_pt_init_text():
    """Check if we can bootstrap and peft tune causallm models with init text as 'TEXT'"""
    with tempfile.TemporaryDirectory() as tempdir:
@@ -174,6 +243,23 @@ def test_run_causallm_pt_with_validation():
        _validate_training(tempdir, check_eval=True)


def test_run_causallm_pt_with_validation_data_formatting():
    """Check if we can bootstrap and peft tune causallm models with a validation dataset
    and a data formatter template."""
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir
        train_args.eval_strategy = "epoch"
        data_args = copy.deepcopy(DATA_ARGS)
        data_args.validation_data_path = TWITTER_COMPLAINTS_DATA
        data_args.dataset_text_field = None
        data_args.data_formatter_template = (
            "### Text: {{Tweet text}} \n\n### Label: {{text_label}}"
        )

        sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)
        _validate_training(tempdir, check_eval=True)


############################# Lora Tests #############################

target_modules_val_map = [
@@ -335,6 +421,30 @@ def test_invalid_dataset_text_field():
        sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS)


### Tests that supplying dataset_text_field as well as formatter template gives an error
def test_invalid_dataset_text_field_and_formatter_template():
    """Only one of dataset_text_field or data_formatter_template can be supplied"""
    data_args = copy.deepcopy(DATA_ARGS)
    data_args.data_formatter_template = (
        "### Text: {{Tweet text}} \n\n### Label: {{text_label}}"
    )

    with pytest.raises(ValueError):
        sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS)


### Tests that passing a formatter template with invalid keys gives an error
def test_invalid_formatter_template():
    """Formatter templates referencing keys missing from the dataset should raise a KeyError"""
    data_args = copy.deepcopy(DATA_ARGS)
    data_args.dataset_text_field = None
    data_args.data_formatter_template = (
        "### Text: {{not found}} \n\n### Label: {{text_label}}"
    )

    with pytest.raises(KeyError):
        sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS)


### Tests for bad training data (i.e., data_path is an unhappy value or points to an unhappy thing)
def test_malformatted_data():
"""Ensure that malformatted data explodes due to failure to generate the dataset."""
@@ -382,13 +492,15 @@ def test_run_causallm_lora_with_invalid_modules():


### Direct validation tests based on whether or not packing is enabled
def test_no_packing_needs_dataset_text_field_or_data_formatter_template():
    """Ensure we need to set the dataset text field or formatter template if packing is False"""
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir
        data_args = copy.deepcopy(DATA_ARGS)
        # One of dataset_text_field or data_formatter_template should be set
        data_args.dataset_text_field = None
        data_args.data_formatter_template = None

        with pytest.raises(ValueError):
            sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)
66 changes: 66 additions & 0 deletions tests/utils/test_data_utils.py
@@ -0,0 +1,66 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Third Party
import datasets
import pytest

# First Party
from tests.data import TWITTER_COMPLAINTS_DATA

# Local
from tuning.utils import data_utils


def test_apply_custom_formatting_template():
    json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA)
    template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
    # First response from the data file that is read.
    expected_response = (
        "### Input: @HMRCcustomers No this is my first job"
        + " \n\n ### Response: no complaint"
    )
    formatted_dataset, dataset_text_field = data_utils.apply_custom_formatting_template(
        json_dataset, template
    )
    # a new dataset_text_field is created in Dataset
    assert dataset_text_field in formatted_dataset["train"][0]
    assert formatted_dataset["train"][0][dataset_text_field] == expected_response


def test_apply_custom_formatting_template_adds_eos_token():
    json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA)
    template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
    # First response from the data file that is read.
    expected_response = (
        "### Input: @HMRCcustomers No this is my first job"
        + " \n\n ### Response: no complaintEOS"
    )
    formatted_dataset, dataset_text_field = data_utils.apply_custom_formatting_template(
        json_dataset, template, "EOS"
    )
    # a new dataset_text_field is created in Dataset
    assert dataset_text_field in formatted_dataset["train"][0]
    assert formatted_dataset["train"][0][dataset_text_field] == expected_response


def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
"""Tests that the formatting function will throw error if wrong keys are passed to template"""
json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA)
template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
with pytest.raises(KeyError):
data_utils.apply_custom_formatting_template(json_dataset, template, "EOS")
23 changes: 20 additions & 3 deletions tuning/config/configs.py
@@ -43,17 +43,34 @@ class ModelArguments:
@dataclass
class DataArguments:
    training_data_path: str = field(
        default=None,
        metadata={"help": "Path to the training data in JSON/JSONL format."},
    )
    response_template: str = field(
        default=None,
        metadata={"help": "Response template, separator to train on completions only"},
    )
    dataset_text_field: str = field(
        default=None,
        metadata={
            "help": "Training dataset text field containing single sequence. \
                    Either the dataset_text_field \
                    or data_formatter_template needs to be supplied."
        },
    )
    validation_data_path: str = field(
        default=None,
        metadata={"help": "Path to the validation data in JSON/JSONL format."},
    )
    data_formatter_template: str = field(
        default=None,
        metadata={
            "help": "Formatter template to format a single sequence \
                    from each instance in JSON/JSONL files. \
                    Keys of the JSON can be referred to as {{key}} in the template. \
                    Either the dataset_text_field \
                    or data_formatter_template needs to be supplied."
        },
    )


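For context, the new tests assert a `ValueError` when both or neither of `dataset_text_field` and `data_formatter_template` are supplied; a validation along these lines would capture that contract (a sketch with a hypothetical `validate_data_args` helper, not necessarily the PR's actual check):

```python
def validate_data_args(data_args: DataArguments) -> None:
    """Exactly one of dataset_text_field or data_formatter_template must be set."""
    if (data_args.dataset_text_field is None) == (
        data_args.data_formatter_template is None
    ):
        raise ValueError(
            "Exactly one of dataset_text_field or data_formatter_template "
            "must be supplied when packing is disabled."
        )
```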