From 54b2624f8a43cc355410d52a1656c58e7464ac3b Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Thu, 16 May 2024 22:44:24 -0600 Subject: [PATCH 01/13] utility functions to format datasets using template Signed-off-by: Sukriti-Sharma4 --- tests/utils/test_data_utils.py | 43 ++++++++++++++++++++++++++++++++++ tuning/config/configs.py | 9 ++++++- tuning/sft_trainer.py | 11 +++++++-- tuning/utils/data_utils.py | 34 +++++++++++++++++++++++++++ 4 files changed, 94 insertions(+), 3 deletions(-) create mode 100644 tests/utils/test_data_utils.py create mode 100644 tuning/utils/data_utils.py diff --git a/tests/utils/test_data_utils.py b/tests/utils/test_data_utils.py new file mode 100644 index 000000000..7e98053c7 --- /dev/null +++ b/tests/utils/test_data_utils.py @@ -0,0 +1,43 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# First Party +from tests.data import TWITTER_COMPLAINTS_DATA +from tuning.utils import data_utils + +# Third Party +import datasets + +def test_formatting_function(): + json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA) + template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + # First response from the data file that is read. + expected_response = "### Input: @HMRCcustomers No this is my first job \n\n ### Response: no complaint" + formatted_dataset, dataset_text_field = data_utils.formatting_function(json_dataset, template) + # a new dataset_text_field is created in Dataset + assert dataset_text_field in formatted_dataset['train'][0] + assert formatted_dataset['train'][0][dataset_text_field] == expected_response + +def test_formatting_function_adds_eos_token(): + json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA) + template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + # First response from the data file that is read. + expected_response = "### Input: @HMRCcustomers No this is my first job \n\n ### Response: no complaintEOS" + formatted_dataset, dataset_text_field = data_utils.formatting_function(json_dataset, template, 'EOS') + # a new dataset_text_field is created in Dataset + assert dataset_text_field in formatted_dataset['train'][0] + assert formatted_dataset['train'][0][dataset_text_field] == expected_response \ No newline at end of file diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 247652b7c..13e57ebd0 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -50,11 +50,18 @@ class DataArguments: metadata={"help": "Response template, separator to train on completions only"}, ) dataset_text_field: str = field( - default=None, metadata={"help": "Training dataset text field"} + default=None, metadata={"help": "Training dataset text field containing single sequence. 
\ + Either the dataset_text_field or data_formatter_template need to be supplied."} ) validation_data_path: str = field( default=None, metadata={"help": "Path to the validation data in JSONL format."} ) + data_formatter_template: str = field( + default=None, metadata={"help": "formatter template to format a single sequence from each instance in JSONL files. \ + Keys of JSON can be referred to as {{key}} in template. Either the dataset_text_field \ + or data_formatter_template needs to be supplied." + } + ) @dataclass diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index b307505c0..a802cc831 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -218,8 +218,6 @@ def train( # TODO: Fix this, currently unreachable due to crashing in batch encoding tokenization # We should do this validation up front, then do the encoding, then handle the collator raise ValueError("Response template is None, needs to be set for training") - if data_args.dataset_text_field is None: - raise ValueError("Dataset_text_field is None, needs to be set for training") data_collator = DataCollatorForCompletionOnlyLM( response_template_ids, tokenizer=tokenizer, @@ -227,6 +225,15 @@ def train( ) packing = False + # Currently we support formatted datasets with single sequence instances. + if (data_args.dataset_text_field is None) and (data_args.data_formatter_template is None): + raise ValueError("Dataset_text_field and data_formatter_template are None. \ + One of them needs to be set for training") + # Only one of dataset_text_field or data_formatter_template should be set. + if data_args.dataset_text_field and data_args.data_formatter_template: + raise ValueError("Dataset_text_field and data_formatter_template are set. \ + Only one of them needs to be set for training") + # load the data by parsing JSON data_files = {"train": data_args.training_data_path} if data_args.validation_data_path: diff --git a/tuning/utils/data_utils.py b/tuning/utils/data_utils.py new file mode 100644 index 000000000..0546ef8e1 --- /dev/null +++ b/tuning/utils/data_utils.py @@ -0,0 +1,34 @@ +import re +from datasets import Dataset + +def formatting_function(dataset, template, eos_token=""): + """Function to format datasets with Alpaca style / other templates. + Args: + dataset: the HF Dataset element loaded from a JSON or DatasetDict object. + template: Template to format data with. Features of Dataset + should be referred to by {{key}} + Returns: + Formatted HF Dataset, dataset_field name that contains formatted data. 
+ """ + + formatted_dataset_field = "formatted_data_field" + template += eos_token + def formatter(element): + nonlocal template + + def replace_text(match_obj): + captured_groups = match_obj.groups() + if len(captured_groups) != 1: + raise ValueError( + "Unexpectedly captured multiple groups in verbalizer rendering" + ) + + index_object = captured_groups[0] + if index_object not in element: + raise KeyError("Requested template string is not a valid key in dict") + + return element[index_object] + + return {formatted_dataset_field : re.sub(r"{{([_a-z\sA-Z0-9]+)}}", replace_text, template)} + + return dataset.map(formatter), formatted_dataset_field \ No newline at end of file From d1cc787166e1d4537724d90485ad56e58f72eb58 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Fri, 17 May 2024 16:25:04 -0600 Subject: [PATCH 02/13] add tests and formatter as arg Signed-off-by: Sukriti-Sharma4 --- tests/test_sft_trainer.py | 61 +++++++++++++++++++++++++++++++++- tests/utils/test_data_utils.py | 10 +++++- tuning/sft_trainer.py | 15 +++++++-- 3 files changed, 81 insertions(+), 5 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index c23c7e2c5..fc226c5af 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -63,7 +63,6 @@ "save_strategy": "epoch", "output_dir": "tmp", } - BASE_LORA_KWARGS = copy.deepcopy(BASE_PEFT_KWARGS) BASE_LORA_KWARGS["peft_method"] = "lora" @@ -148,6 +147,40 @@ def test_run_causallm_pt_and_inference(): assert len(output_inference) > 0 assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference +def test_run_causallm_pt_and_inference_with_formatting_data(): + """Check if we can bootstrap and peft tune causallm models + This test needs the trainer to format data to a single sequence internally. 
+    """
+    with tempfile.TemporaryDirectory() as tempdir:
+        data_formatting_args = copy.deepcopy(BASE_PEFT_KWARGS)
+        del data_formatting_args["dataset_text_field"]
+        data_formatting_args["data_formatter_template"] = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
+
+        TRAIN_KWARGS = {**data_formatting_args, **{"output_dir": tempdir}}
+
+        model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
+            TRAIN_KWARGS
+        )
+        # Just double checking that formatter is set
+        assert data_args.data_formatter_template is not None
+
+        sft_trainer.train(model_args, data_args, training_args, tune_config)
+
+        # validate peft tuning configs
+        _validate_training(tempdir)
+        checkpoint_path = _get_checkpoint_path(tempdir)
+        adapter_config = _get_adapter_config(checkpoint_path)
+        _validate_adapter_config(adapter_config, "PROMPT_TUNING", data_formatting_args)
+
+        # Load the model
+        loaded_model = TunedCausalLM.load(checkpoint_path)
+
+        # Run inference on the text
+        output_inference = loaded_model.run(
+            "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
+        )
+        assert len(output_inference) > 0
+        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
 
 def test_run_causallm_pt_init_text():
     """Check if we can bootstrap and peft tune causallm models with init text as 'TEXT'"""
     with tempfile.TemporaryDirectory() as tempdir:
@@ -381,6 +414,32 @@ def test_invalid_dataset_text_field():
     with pytest.raises(KeyError):
         sft_trainer.train(model_args, data_args, training_args, tune_config)
 
+### Tests that giving dataset_text_field as well as formatter template gives error
+def test_invalid_dataset_text_field_and_formatter_template():
+    """Only one of dataset_text_field or formatter can be supplied"""
+    TRAIN_KWARGS = {
+        **BASE_PEFT_KWARGS,
+        **{"data_formatter_template": "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"},
+    }
+    model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
+        TRAIN_KWARGS
+    )
+    with pytest.raises(ValueError):
+        sft_trainer.train(model_args, data_args, training_args, tune_config)
+
+### Tests passing formatter with invalid keys gives error
+def test_invalid_formatter_template():
+    data_formatting_args = copy.deepcopy(BASE_PEFT_KWARGS)
+    del data_formatting_args["dataset_text_field"]
+    TRAIN_KWARGS = {
+        **data_formatting_args,
+        **{"data_formatter_template": "### Input: {{not found}} \n\n ### Response: {{text_label}}"},
+    }
+    model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
+        TRAIN_KWARGS
+    )
+    with pytest.raises(KeyError):
+        sft_trainer.train(model_args, data_args, training_args, tune_config)
 
 ### Tests for bad training data (i.e., data_path is an unhappy value or points to an unhappy thing)
 def test_malformatted_data():
     """Ensure that malformatted data explodes due to failure to generate the dataset."""
diff --git a/tests/utils/test_data_utils.py b/tests/utils/test_data_utils.py
index 7e98053c7..99322e4b0 100644
--- a/tests/utils/test_data_utils.py
+++ b/tests/utils/test_data_utils.py
@@ -21,6 +21,7 @@
 
 # Third Party
 import datasets
+import pytest
 
 def test_formatting_function():
     json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA)
@@ -40,4 +41,11 @@ def test_formatting_function_adds_eos_token():
     formatted_dataset, dataset_text_field = data_utils.formatting_function(json_dataset, template, 'EOS')
     # a new dataset_text_field is created in Dataset
     assert dataset_text_field in formatted_dataset['train'][0]
-    assert formatted_dataset['train'][0][dataset_text_field] == expected_response
\ No newline at end of file
+    assert formatted_dataset['train'][0][dataset_text_field] == 
expected_response + +def test_formatting_function_gives_error_with_wrong_keys(): + """Tests that the formatting function will throw error if wrong keys are passed to template""" + json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA) + template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" + with pytest.raises(KeyError): + data_utils.formatting_function(json_dataset, template, 'EOS') \ No newline at end of file diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index a802cc831..d40d63d80 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -46,6 +46,7 @@ from tuning.trainercontroller import TrainerControllerCallback from tuning.utils.config_utils import get_hf_peft_config from tuning.utils.data_type_utils import get_torch_dtype +from tuning.utils.data_utils import formatting_function def train( @@ -233,7 +234,7 @@ def train( if data_args.dataset_text_field and data_args.data_formatter_template: raise ValueError("Dataset_text_field and data_formatter_template are set. \ Only one of them needs to be set for training") - + # load the data by parsing JSON data_files = {"train": data_args.training_data_path} if data_args.validation_data_path: @@ -245,12 +246,20 @@ def train( } json_dataset = datasets.load_dataset("json", data_files=data_files) - formatted_train_dataset = json_dataset["train"].map(format_dataset) + if data_args.data_formatter_template: + formatted_train_dataset, data_args.dataset_text_field = \ + formatting_function(json_dataset["train"], data_args.data_formatter_template, tokenizer.eos_token) + else: + formatted_train_dataset = json_dataset["train"].map(format_dataset) logger.info("Training dataset length is %s", len(formatted_train_dataset)) formatted_validation_dataset = None if data_args.validation_data_path: - formatted_validation_dataset = json_dataset["validation"].map(format_dataset) + if data_args.data_formatter_template: + formatted_validation_dataset, data_args.dataset_text_field = \ + formatting_function(json_dataset["validation"], data_args.data_formatter_template, tokenizer.eos_token) + else: + formatted_validation_dataset = json_dataset["validation"].map(format_dataset) logger.info( "Validation dataset length is %s", len(formatted_validation_dataset) ) From 6d5593402f149e99473d6c0132232c7b3a5679e8 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Thu, 23 May 2024 15:52:27 -0600 Subject: [PATCH 03/13] update tests to use template to avoid warnings Signed-off-by: Sukriti-Sharma4 --- tests/test_sft_trainer.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 85c134209..3664cc0cf 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -120,7 +120,7 @@ def test_run_causallm_pt_and_inference_with_formatting_data(): with tempfile.TemporaryDirectory() as tempdir: data_formatting_args = copy.deepcopy(DATA_ARGS) data_formatting_args.dataset_text_field = None - data_formatting_args.data_formatter_template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + data_formatting_args.data_formatter_template = "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" train_args = copy.deepcopy(TRAIN_ARGS) train_args.output_dir = tempdir @@ -202,6 +202,19 @@ def test_run_causallm_pt_with_validation(): sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) _validate_training(tempdir, check_eval=True) +def test_run_causallm_pt_with_validation_data_formatting(): + """Check if we 
can bootstrap and peft tune causallm models with validation dataset""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + train_args.eval_strategy = "epoch" + data_args = copy.deepcopy(DATA_ARGS) + data_args.validation_data_path = TWITTER_COMPLAINTS_DATA + data_args.dataset_text_field = None + data_args.data_formatter_template = "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" + + sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) + _validate_training(tempdir, check_eval=True) ############################# Lora Tests ############################# @@ -367,7 +380,7 @@ def test_invalid_dataset_text_field(): def test_invalid_dataset_text_field_and_formatter_template(): """Only one of dataset_text_field or formatter can be supplied""" data_args = copy.deepcopy(DATA_ARGS) - data_args.data_formatter_template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + data_args.data_formatter_template = "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" with pytest.raises(ValueError): sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) @@ -376,7 +389,7 @@ def test_invalid_dataset_text_field_and_formatter_template(): def test_invalid_formatter_template(): data_args = copy.deepcopy(DATA_ARGS) data_args.dataset_text_field = None - data_args.data_formatter_template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" + data_args.data_formatter_template = "### Text: {{not found}} \n\n### Label: {{text_label}}" with pytest.raises(KeyError): sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) @@ -428,13 +441,15 @@ def test_run_causallm_lora_with_invalid_modules(): ### Direct validation tests based on whether or not packing is enabled -def test_no_packing_needs_dataset_text_field(): +def test_no_packing_needs_dataset_text_field_or_data_formatter_template(): """Ensure we need to set the dataset text field if packing is False""" with tempfile.TemporaryDirectory() as tempdir: train_args = copy.deepcopy(TRAIN_ARGS) train_args.output_dir = tempdir data_args = copy.deepcopy(DATA_ARGS) + # One of dataset_text_field or data_formatter_template should be set data_args.dataset_text_field = None + data_args.data_formatter_template = None with pytest.raises(ValueError): sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) From f84a222ae8146de42b9892644fd374189ff9530f Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Thu, 23 May 2024 19:58:06 -0600 Subject: [PATCH 04/13] update README and tests Signed-off-by: Sukriti-Sharma4 --- README.md | 55 ++++++++++++++++++++++++- tests/data/__init__.py | 1 + tests/data/twitter_complaints_json.json | 12 ++++++ tests/test_sft_trainer.py | 31 +++++++++++++- tuning/config/configs.py | 4 +- 5 files changed, 99 insertions(+), 4 deletions(-) create mode 100644 tests/data/twitter_complaints_json.json diff --git a/README.md b/README.md index bb1a95876..5cb448d56 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,9 @@ pip install -e ".[aim]" ``` ## Data format -The data format expectation is a single column text. The trainer is configured to expect a response template as a string. For example, if one wants to prepare the `alpaca` format data to feed into this trainer, it is quite easy and can be done with the following code. +We support two data formats: + +1. Pre-process the JSON/JSONL dataset to contain a single sequence of each data instance containing input + Response. 
The trainer is configured to expect a response template as a string. For example, if one wants to prepare the `alpaca` format data to feed into this trainer, it is quite easy and can be done with the following code. ```python PROMPT_DICT = { @@ -56,6 +58,21 @@ The `response template` corresponding to the above dataset and the `Llama` token The same way can be applied to any dataset, with more info can be found [here](https://huggingface.co/docs/trl/main/en/sft_trainer#format-your-input-prompts). +Once the JSON is converted using the formatting function, pass the `dataset_text_field` containing the single sequence to the trainer. + +2. Pass a JSON/JSONL and a `data_formatter_template` to use the formatting function on the fly while tuning. The template should specify fields of JSON with `{{field}}`. While tuning, the data will be converted to a single sequence using the template. + +Example: Train.json +`[{ "input" : , + "output" : , + }, + ... +]` + +data_formatter_template: `### Input: {{input}} \n\n##Label: {{output}}` +Formatting will happen pn the fly while tuning. The keys in template should match fields in JSON file. The `response template` corresponding to the above template will need to be supplied. in this case, `response template` = `\n## Label:`. + +In conclusion, either the `data_formatter_template` argument or `dataset_text_field` needs to be supplied to the trainer. ## Supported Models @@ -64,6 +81,9 @@ Current supported and tested models are `Llama2` (7 and 13B configurations have ## Training ### Single GPU + +1. Using pre-processed dataset for training. + ```bash # if you want to use one GPU on multi-gpu machine export CUDA_VISIBLE_DEVICES=0 @@ -94,6 +114,39 @@ python tuning/sft_trainer.py \ ``` +2. Using formatter with JSON/JSONL files + +```bash +# if you want to use one GPU on multi-gpu machine +export CUDA_VISIBLE_DEVICES=0 + +# MODEL_PATH=meta-llama/Llama-2-7b-hf # Huggingface model id or path to a checkpoint +# TRAIN_DATA_PATH=twitter_complaints.json # Path to the dataset + # contains data in form of [{"input": text , "output": text}] +# OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved + +python tuning/sft_trainer.py \ +--model_name_or_path $MODEL_PATH \ +--training_data_path $TRAIN_DATA_PATH \ +--output_dir $OUTPUT_PATH \ +--num_train_epochs 5 \ +--per_device_train_batch_size 4 \ +--per_device_eval_batch_size 4 \ +--gradient_accumulation_steps 4 \ +--eval_strategy "no" \ +--save_strategy "epoch" \ +--learning_rate 1e-5 \ +--weight_decay 0. 
\
+--warmup_ratio 0.03 \
+--lr_scheduler_type "cosine" \
+--logging_steps 1 \
+--include_tokens_per_second \
+--packing False \
+--response_template "\n## Label:" \
+--data_formatter_template "### Input: {{input}} \n\n##Label: {{output}}"
+
+```
 
 ### Multiple GPUs with FSDP
 
 The recommendation is to use [huggingface accelerate](https://huggingface.co/docs/accelerate/en/index) to launch multi-gpu jobs, in particular when using FSDP:
diff --git a/tests/data/__init__.py b/tests/data/__init__.py
index 6df7802cd..b81ccaff2 100644
--- a/tests/data/__init__.py
+++ b/tests/data/__init__.py
@@ -20,5 +20,6 @@
 ### Constants used for data
 DATA_DIR = os.path.join(os.path.dirname(__file__))
 TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json")
+TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join(DATA_DIR, "twitter_complaints_json.json")
 EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json")
 MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json")
diff --git a/tests/data/twitter_complaints_json.json b/tests/data/twitter_complaints_json.json
new file mode 100644
index 000000000..fba22a9fd
--- /dev/null
+++ b/tests/data/twitter_complaints_json.json
@@ -0,0 +1,12 @@
+[
+    {"Tweet text":"@HMRCcustomers No this is my first job","ID":0,"Label":2,"text_label":"no complaint","output":"### Text: @HMRCcustomers No this is my first job\n\n### Label: no complaint"},
+    {"Tweet text":"@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.","ID":1,"Label":2,"text_label":"no complaint","output":"### Text: @KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.\n\n### Label: no complaint"},
+    {"Tweet text":"If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService","ID":2,"Label":1,"text_label":"complaint","output":"### Text: If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService\n\n### Label: complaint"},
+    {"Tweet text":"@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.","ID":3,"Label":1,"text_label":"complaint","output":"### Text: @EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.\n\n### Label: complaint"},
+    {"Tweet text":"Couples wallpaper, so cute. :) #BrothersAtHome","ID":4,"Label":2,"text_label":"no complaint","output":"### Text: Couples wallpaper, so cute. :) #BrothersAtHome\n\n### Label: no complaint"},
+    {"Tweet text":"@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG","ID":5,"Label":2,"text_label":"no complaint","output":"### Text: @mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG\n\n### Label: no complaint"},
+    {"Tweet text":"@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?","ID":6,"Label":2,"text_label":"no complaint","output":"### Text: @Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?\n\n### Label: no complaint"},
+    {"Tweet text":"@nationalgridus I have no water and the bill is current and paid. 
Can you do something about this?","ID":7,"Label":1,"text_label":"complaint","output":"### Text: @nationalgridus I have no water and the bill is current and paid. Can you do something about this?\n\n### Label: complaint"}, + {"Tweet text":"Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora","ID":8,"Label":1,"text_label":"complaint","output":"### Text: Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora\n\n### Label: complaint"}, + {"Tweet text":"@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd","ID":9,"Label":2,"text_label":"no complaint","output":"### Text: @JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd\n\n### Label: no complaint"} +] \ No newline at end of file diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 3664cc0cf..e7eb65aa5 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -29,7 +29,7 @@ # First Party from scripts.run_inference import TunedCausalLM -from tests.data import EMPTY_DATA, MALFORMATTED_DATA, TWITTER_COMPLAINTS_DATA +from tests.data import EMPTY_DATA, MALFORMATTED_DATA, TWITTER_COMPLAINTS_DATA, TWITTER_COMPLAINTS_JSON_FORMAT # Local from tuning import sft_trainer @@ -113,6 +113,7 @@ def test_run_causallm_pt_and_inference(): assert len(output_inference) > 0 assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference + def test_run_causallm_pt_and_inference_with_formatting_data(): """Check if we can bootstrap and peft tune causallm models This test needs the trainer to format data to a single sequence internally. 
@@ -142,6 +143,34 @@ def test_run_causallm_pt_and_inference_with_formatting_data(): ) assert len(output_inference) > 0 assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference + +def test_run_causallm_pt_and_inference_JSON_file_formatter(): + """Check if we can bootstrap and peft tune causallm models with JSON train file format""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + data_args = copy.deepcopy(DATA_ARGS) + data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT + data_args.dataset_text_field = None + data_args.data_formatter_template = "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" + + sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) + + # validate peft tuning configs + _validate_training(tempdir) + checkpoint_path = _get_checkpoint_path(tempdir) + adapter_config = _get_adapter_config(checkpoint_path) + _validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS) + + # Load the model + loaded_model = TunedCausalLM.load(checkpoint_path) + + # Run inference on the text + output_inference = loaded_model.run( + "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50 + ) + assert len(output_inference) > 0 + assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference def test_run_causallm_pt_init_text(): """Check if we can bootstrap and peft tune causallm models with init text as 'TEXT'""" diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 13e57ebd0..0a4a9dda0 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -43,7 +43,7 @@ class ModelArguments: @dataclass class DataArguments: training_data_path: str = field( - default=None, metadata={"help": "Path to the training data in JSONL format."} + default=None, metadata={"help": "Path to the training data in JSON/JSONL format."} ) response_template: str = field( default=None, @@ -54,7 +54,7 @@ class DataArguments: Either the dataset_text_field or data_formatter_template need to be supplied."} ) validation_data_path: str = field( - default=None, metadata={"help": "Path to the validation data in JSONL format."} + default=None, metadata={"help": "Path to the validation data in JSON/JSONL format."} ) data_formatter_template: str = field( default=None, metadata={"help": "formatter template to format a single sequence from each instance in JSONL files. 
\ From e95661df51604033aba5e2b1bb53ccb972fbe00e Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Thu, 23 May 2024 20:13:59 -0600 Subject: [PATCH 05/13] fix:formatter Signed-off-by: Sukriti-Sharma4 --- tests/test_sft_trainer.py | 40 ++++++++++++++++++++++++------- tests/utils/test_data_utils.py | 29 +++++++++++++++-------- tuning/config/configs.py | 19 ++++++++++----- tuning/sft_trainer.py | 43 ++++++++++++++++++++++++---------- tuning/utils/data_utils.py | 23 ++++++++++++------ 5 files changed, 110 insertions(+), 44 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index e7eb65aa5..42368879f 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -29,7 +29,12 @@ # First Party from scripts.run_inference import TunedCausalLM -from tests.data import EMPTY_DATA, MALFORMATTED_DATA, TWITTER_COMPLAINTS_DATA, TWITTER_COMPLAINTS_JSON_FORMAT +from tests.data import ( + EMPTY_DATA, + MALFORMATTED_DATA, + TWITTER_COMPLAINTS_DATA, + TWITTER_COMPLAINTS_JSON_FORMAT, +) # Local from tuning import sft_trainer @@ -121,11 +126,13 @@ def test_run_causallm_pt_and_inference_with_formatting_data(): with tempfile.TemporaryDirectory() as tempdir: data_formatting_args = copy.deepcopy(DATA_ARGS) data_formatting_args.dataset_text_field = None - data_formatting_args.data_formatter_template = "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" + data_formatting_args.data_formatter_template = ( + "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" + ) train_args = copy.deepcopy(TRAIN_ARGS) train_args.output_dir = tempdir - + sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args, PEFT_PT_ARGS) # validate peft tuning configs @@ -143,7 +150,8 @@ def test_run_causallm_pt_and_inference_with_formatting_data(): ) assert len(output_inference) > 0 assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference - + + def test_run_causallm_pt_and_inference_JSON_file_formatter(): """Check if we can bootstrap and peft tune causallm models with JSON train file format""" with tempfile.TemporaryDirectory() as tempdir: @@ -152,7 +160,9 @@ def test_run_causallm_pt_and_inference_JSON_file_formatter(): data_args = copy.deepcopy(DATA_ARGS) data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT data_args.dataset_text_field = None - data_args.data_formatter_template = "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" + data_args.data_formatter_template = ( + "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" + ) sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) @@ -172,6 +182,7 @@ def test_run_causallm_pt_and_inference_JSON_file_formatter(): assert len(output_inference) > 0 assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference + def test_run_causallm_pt_init_text(): """Check if we can bootstrap and peft tune causallm models with init text as 'TEXT'""" with tempfile.TemporaryDirectory() as tempdir: @@ -231,6 +242,7 @@ def test_run_causallm_pt_with_validation(): sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) _validate_training(tempdir, check_eval=True) + def test_run_causallm_pt_with_validation_data_formatting(): """Check if we can bootstrap and peft tune causallm models with validation dataset""" with tempfile.TemporaryDirectory() as tempdir: @@ -240,11 +252,14 @@ def test_run_causallm_pt_with_validation_data_formatting(): data_args = copy.deepcopy(DATA_ARGS) data_args.validation_data_path = TWITTER_COMPLAINTS_DATA data_args.dataset_text_field = None - 
data_args.data_formatter_template = "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" + data_args.data_formatter_template = ( + "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" + ) sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) _validate_training(tempdir, check_eval=True) + ############################# Lora Tests ############################# target_modules_val_map = [ @@ -405,24 +420,31 @@ def test_invalid_dataset_text_field(): with pytest.raises(KeyError): sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) + ### Tests that giving dataset_text_field as well as formatter template gives error def test_invalid_dataset_text_field_and_formatter_template(): """Only one of dataset_text_field or formatter can be supplied""" data_args = copy.deepcopy(DATA_ARGS) - data_args.data_formatter_template = "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" - + data_args.data_formatter_template = ( + "### Text: {{Tweet text}} \n\n### Label: {{text_label}}" + ) + with pytest.raises(ValueError): sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) + ### Tests passing formatter with invalid keys gives error def test_invalid_formatter_template(): data_args = copy.deepcopy(DATA_ARGS) data_args.dataset_text_field = None - data_args.data_formatter_template = "### Text: {{not found}} \n\n### Label: {{text_label}}" + data_args.data_formatter_template = ( + "### Text: {{not found}} \n\n### Label: {{text_label}}" + ) with pytest.raises(KeyError): sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_PT_ARGS) + ### Tests for bad training data (i.e., data_path is an unhappy value or points to an unhappy thing) def test_malformatted_data(): """Ensure that malformatted data explodes due to failure to generate the dataset.""" diff --git a/tests/utils/test_data_utils.py b/tests/utils/test_data_utils.py index 99322e4b0..26b151cb9 100644 --- a/tests/utils/test_data_utils.py +++ b/tests/utils/test_data_utils.py @@ -15,37 +15,46 @@ # SPDX-License-Identifier: Apache-2.0 # https://spdx.dev/learn/handling-license-info/ +# Third Party +import datasets +import pytest + # First Party from tests.data import TWITTER_COMPLAINTS_DATA + +# Local from tuning.utils import data_utils -# Third Party -import datasets -import pytest def test_formatting_function(): json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA) template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" # First response from the data file that is read. expected_response = "### Input: @HMRCcustomers No this is my first job \n\n ### Response: no complaint" - formatted_dataset, dataset_text_field = data_utils.formatting_function(json_dataset, template) + formatted_dataset, dataset_text_field = data_utils.formatting_function( + json_dataset, template + ) # a new dataset_text_field is created in Dataset - assert dataset_text_field in formatted_dataset['train'][0] - assert formatted_dataset['train'][0][dataset_text_field] == expected_response + assert dataset_text_field in formatted_dataset["train"][0] + assert formatted_dataset["train"][0][dataset_text_field] == expected_response + def test_formatting_function_adds_eos_token(): json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA) template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" # First response from the data file that is read. 
expected_response = "### Input: @HMRCcustomers No this is my first job \n\n ### Response: no complaintEOS" - formatted_dataset, dataset_text_field = data_utils.formatting_function(json_dataset, template, 'EOS') + formatted_dataset, dataset_text_field = data_utils.formatting_function( + json_dataset, template, "EOS" + ) # a new dataset_text_field is created in Dataset - assert dataset_text_field in formatted_dataset['train'][0] - assert formatted_dataset['train'][0][dataset_text_field] == expected_response + assert dataset_text_field in formatted_dataset["train"][0] + assert formatted_dataset["train"][0][dataset_text_field] == expected_response + def test_formatting_function_gives_error_with_wrong_keys(): """Tests that the formatting function will throw error if wrong keys are passed to template""" json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA) template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" with pytest.raises(KeyError): - data_utils.formatting_function(json_dataset, template, 'EOS') \ No newline at end of file + data_utils.formatting_function(json_dataset, template, "EOS") diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 0a4a9dda0..28c9e4ab2 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -43,24 +43,31 @@ class ModelArguments: @dataclass class DataArguments: training_data_path: str = field( - default=None, metadata={"help": "Path to the training data in JSON/JSONL format."} + default=None, + metadata={"help": "Path to the training data in JSON/JSONL format."}, ) response_template: str = field( default=None, metadata={"help": "Response template, separator to train on completions only"}, ) dataset_text_field: str = field( - default=None, metadata={"help": "Training dataset text field containing single sequence. \ - Either the dataset_text_field or data_formatter_template need to be supplied."} + default=None, + metadata={ + "help": "Training dataset text field containing single sequence. \ + Either the dataset_text_field or data_formatter_template need to be supplied." + }, ) validation_data_path: str = field( - default=None, metadata={"help": "Path to the validation data in JSON/JSONL format."} + default=None, + metadata={"help": "Path to the validation data in JSON/JSONL format."}, ) data_formatter_template: str = field( - default=None, metadata={"help": "formatter template to format a single sequence from each instance in JSONL files. \ + default=None, + metadata={ + "help": "formatter template to format a single sequence from each instance in JSONL files. \ Keys of JSON can be referred to as {{key}} in template. Either the dataset_text_field \ or data_formatter_template needs to be supplied." - } + }, ) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index d40d63d80..3e7887df8 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -29,11 +29,13 @@ TrainerCallback, ) from transformers.utils import logging -from trl import DataCollatorForCompletionOnlyLM, SFTTrainer import datasets import fire import transformers +# First Party +from trl import DataCollatorForCompletionOnlyLM, SFTTrainer + # Local from tuning.config import configs, peft_config from tuning.config.tracker_configs import ( @@ -227,14 +229,20 @@ def train( packing = False # Currently we support formatted datasets with single sequence instances. 
- if (data_args.dataset_text_field is None) and (data_args.data_formatter_template is None): - raise ValueError("Dataset_text_field and data_formatter_template are None. \ - One of them needs to be set for training") + if (data_args.dataset_text_field is None) and ( + data_args.data_formatter_template is None + ): + raise ValueError( + "Dataset_text_field and data_formatter_template are None. \ + One of them needs to be set for training" + ) # Only one of dataset_text_field or data_formatter_template should be set. if data_args.dataset_text_field and data_args.data_formatter_template: - raise ValueError("Dataset_text_field and data_formatter_template are set. \ - Only one of them needs to be set for training") - + raise ValueError( + "Dataset_text_field and data_formatter_template are set. \ + Only one of them needs to be set for training" + ) + # load the data by parsing JSON data_files = {"train": data_args.training_data_path} if data_args.validation_data_path: @@ -247,8 +255,11 @@ def train( json_dataset = datasets.load_dataset("json", data_files=data_files) if data_args.data_formatter_template: - formatted_train_dataset, data_args.dataset_text_field = \ - formatting_function(json_dataset["train"], data_args.data_formatter_template, tokenizer.eos_token) + formatted_train_dataset, data_args.dataset_text_field = formatting_function( + json_dataset["train"], + data_args.data_formatter_template, + tokenizer.eos_token, + ) else: formatted_train_dataset = json_dataset["train"].map(format_dataset) logger.info("Training dataset length is %s", len(formatted_train_dataset)) @@ -256,10 +267,18 @@ def train( formatted_validation_dataset = None if data_args.validation_data_path: if data_args.data_formatter_template: - formatted_validation_dataset, data_args.dataset_text_field = \ - formatting_function(json_dataset["validation"], data_args.data_formatter_template, tokenizer.eos_token) + ( + formatted_validation_dataset, + data_args.dataset_text_field, + ) = formatting_function( + json_dataset["validation"], + data_args.data_formatter_template, + tokenizer.eos_token, + ) else: - formatted_validation_dataset = json_dataset["validation"].map(format_dataset) + formatted_validation_dataset = json_dataset["validation"].map( + format_dataset + ) logger.info( "Validation dataset length is %s", len(formatted_validation_dataset) ) diff --git a/tuning/utils/data_utils.py b/tuning/utils/data_utils.py index 0546ef8e1..11ec3476f 100644 --- a/tuning/utils/data_utils.py +++ b/tuning/utils/data_utils.py @@ -1,11 +1,15 @@ +# Standard import re + +# Third Party from datasets import Dataset + def formatting_function(dataset, template, eos_token=""): """Function to format datasets with Alpaca style / other templates. Args: dataset: the HF Dataset element loaded from a JSON or DatasetDict object. - template: Template to format data with. Features of Dataset + template: Template to format data with. Features of Dataset should be referred to by {{key}} Returns: Formatted HF Dataset, dataset_field name that contains formatted data. 
@@ -13,6 +17,7 @@ def formatting_function(dataset, template, eos_token=""): formatted_dataset_field = "formatted_data_field" template += eos_token + def formatter(element): nonlocal template @@ -20,15 +25,19 @@ def replace_text(match_obj): captured_groups = match_obj.groups() if len(captured_groups) != 1: raise ValueError( - "Unexpectedly captured multiple groups in verbalizer rendering" - ) + "Unexpectedly captured multiple groups in verbalizer rendering" + ) index_object = captured_groups[0] if index_object not in element: raise KeyError("Requested template string is not a valid key in dict") - + return element[index_object] - return {formatted_dataset_field : re.sub(r"{{([_a-z\sA-Z0-9]+)}}", replace_text, template)} - - return dataset.map(formatter), formatted_dataset_field \ No newline at end of file + return { + formatted_dataset_field: re.sub( + r"{{([_a-z\sA-Z0-9]+)}}", replace_text, template + ) + } + + return dataset.map(formatter), formatted_dataset_field From 5f1b3e15a0b18d031bcec6a97013326eaff4c66c Mon Sep 17 00:00:00 2001 From: Sukriti Sharma Date: Thu, 23 May 2024 21:33:06 -0600 Subject: [PATCH 06/13] Update README.md Signed-off-by: Sukriti Sharma --- README.md | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 5cb448d56..fff11dd7e 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,8 @@ pip install -e ".[aim]" ## Data format We support two data formats: -1. Pre-process the JSON/JSONL dataset to contain a single sequence of each data instance containing input + Response. The trainer is configured to expect a response template as a string. For example, if one wants to prepare the `alpaca` format data to feed into this trainer, it is quite easy and can be done with the following code. +1. #### Pre-process the JSON/JSONL dataset + Pre-process the JSON/JSONL dataset to contain a single sequence of each data instance containing input + Response. The trainer is configured to expect a response template as a string. For example, if one wants to prepare the `alpaca` format data to feed into this trainer, it is quite easy and can be done with the following code. ```python PROMPT_DICT = { @@ -60,19 +61,21 @@ The same way can be applied to any dataset, with more info can be found [here](h Once the JSON is converted using the formatting function, pass the `dataset_text_field` containing the single sequence to the trainer. -2. Pass a JSON/JSONL and a `data_formatter_template` to use the formatting function on the fly while tuning. The template should specify fields of JSON with `{{field}}`. While tuning, the data will be converted to a single sequence using the template. +2. #### Format JSON/JSONL on the fly + Pass a JSON/JSONL and a `data_formatter_template` to use the formatting function on the fly while tuning. The template should specify fields of JSON with `{{field}}`. While tuning, the data will be converted to a single sequence using the template. Example: Train.json `[{ "input" : , "output" : , }, ... -]` +]` +data_formatter_template: `### Input: {{input}} \n\n##Label: {{output}}` -data_formatter_template: `### Input: {{input}} \n\n##Label: {{output}}` -Formatting will happen pn the fly while tuning. The keys in template should match fields in JSON file. The `response template` corresponding to the above template will need to be supplied. in this case, `response template` = `\n## Label:`. +Formatting will happen on the fly while tuning. The keys in template should match fields in JSON file. 
The `response template` corresponding to the above template will need to be supplied. In this case, `response template` = `\n## Label:`.
+
+##### In conclusion, either the `data_formatter_template` argument or `dataset_text_field` needs to be supplied to the trainer.
 
 ## Supported Models
 
@@ -90,6 +93,7 @@ export CUDA_VISIBLE_DEVICES=0
 
 # MODEL_PATH=meta-llama/Llama-2-7b-hf # Huggingface model id or path to a checkpoint
 # TRAIN_DATA_PATH=twitter_complaints.json # Path to the dataset
+                                          # contains data in single sequence {"output": "### Input: text \n\n### Response: text"}
 # OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved
 
 python tuning/sft_trainer.py \

From 2f70e34b6d39a4d72a6e90aa27f801aa817df429 Mon Sep 17 00:00:00 2001
From: Sukriti-Sharma4 
Date: Thu, 23 May 2024 21:41:53 -0600
Subject: [PATCH 07/13] fix imports

Signed-off-by: Sukriti-Sharma4 
---
 tuning/sft_trainer.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 3e7887df8..a9e1aebc1 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -29,13 +29,11 @@
     TrainerCallback,
 )
 from transformers.utils import logging
+from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
 import datasets
 import fire
 import transformers
 
-# First Party
-from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
-
 # Local
 from tuning.config import configs, peft_config

From 45827ce9c7b8b8aaac97ecb2c73e16620f2ec30f Mon Sep 17 00:00:00 2001
From: Sukriti-Sharma4 
Date: Thu, 23 May 2024 21:56:08 -0600
Subject: [PATCH 08/13] fix pylint

Signed-off-by: Sukriti-Sharma4 
---
 tests/utils/test_data_utils.py | 8 ++++++--
 tuning/config/configs.py | 11 +++++++----
 tuning/utils/data_utils.py | 3 ---
 3 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/tests/utils/test_data_utils.py b/tests/utils/test_data_utils.py
index 26b151cb9..f20790895 100644
--- a/tests/utils/test_data_utils.py
+++ b/tests/utils/test_data_utils.py
@@ -30,7 +30,9 @@ def test_formatting_function():
     json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA)
     template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
     # First response from the data file that is read. 
- expected_response = "### Input: @HMRCcustomers No this is my first job \n\n ### Response: no complaintEOS" + expected_response = ( + "### Input: No this is my first job \n\n ### Response: no complaintEOS" + ) formatted_dataset, dataset_text_field = data_utils.formatting_function( json_dataset, template, "EOS" ) diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 28c9e4ab2..bccf5d15b 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -54,7 +54,8 @@ class DataArguments: default=None, metadata={ "help": "Training dataset text field containing single sequence. \ - Either the dataset_text_field or data_formatter_template need to be supplied." + Either the dataset_text_field \ + or data_formatter_template need to be supplied." }, ) validation_data_path: str = field( @@ -64,9 +65,11 @@ class DataArguments: data_formatter_template: str = field( default=None, metadata={ - "help": "formatter template to format a single sequence from each instance in JSONL files. \ - Keys of JSON can be referred to as {{key}} in template. Either the dataset_text_field \ - or data_formatter_template needs to be supplied." + "help": "formatter template to format a single sequence \ + from each instance in JSONL files. \ + Keys of JSON can be referred to as {{key}} in template. \ + Either the dataset_text_field \ + or data_formatter_template needs to be supplied." }, ) diff --git a/tuning/utils/data_utils.py b/tuning/utils/data_utils.py index 11ec3476f..3f66c07da 100644 --- a/tuning/utils/data_utils.py +++ b/tuning/utils/data_utils.py @@ -1,9 +1,6 @@ # Standard import re -# Third Party -from datasets import Dataset - def formatting_function(dataset, template, eos_token=""): """Function to format datasets with Alpaca style / other templates. From 1f6bb04fc8b0c994e223b23938a979c6e9c5d0b7 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Thu, 23 May 2024 22:02:55 -0600 Subject: [PATCH 09/13] fix tests Signed-off-by: Sukriti-Sharma4 --- tests/utils/test_data_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_data_utils.py b/tests/utils/test_data_utils.py index f20790895..3857fd533 100644 --- a/tests/utils/test_data_utils.py +++ b/tests/utils/test_data_utils.py @@ -31,7 +31,8 @@ def test_formatting_function(): template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" # First response from the data file that is read. expected_response = ( - "### Input: No this is my first job \n\n ### Response: no complaint" + "### Input: @HMRCcustomers No this is my first job" + + " \n\n ### Response: no complaint" ) formatted_dataset, dataset_text_field = data_utils.formatting_function( json_dataset, template @@ -46,7 +47,8 @@ def test_formatting_function_adds_eos_token(): template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" # First response from the data file that is read. 
expected_response = ( - "### Input: No this is my first job \n\n ### Response: no complaintEOS" + "### Input: @HMRCcustomers No this is my first job" + + " \n\n ### Response: no complaintEOS" ) formatted_dataset, dataset_text_field = data_utils.formatting_function( json_dataset, template, "EOS" From 4579f6f898959772d9bfb85978a95bac33674900 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Fri, 24 May 2024 15:52:24 -0600 Subject: [PATCH 10/13] address review comments- function names Signed-off-by: Sukriti-Sharma4 --- tests/utils/test_data_utils.py | 12 ++++++------ tuning/sft_trainer.py | 15 ++++++++------- tuning/utils/data_utils.py | 9 +++++---- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/tests/utils/test_data_utils.py b/tests/utils/test_data_utils.py index 3857fd533..471f28590 100644 --- a/tests/utils/test_data_utils.py +++ b/tests/utils/test_data_utils.py @@ -26,7 +26,7 @@ from tuning.utils import data_utils -def test_formatting_function(): +def test_apply_custom_formatting_template(): json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA) template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" # First response from the data file that is read. @@ -34,7 +34,7 @@ def test_formatting_function(): "### Input: @HMRCcustomers No this is my first job" + " \n\n ### Response: no complaint" ) - formatted_dataset, dataset_text_field = data_utils.formatting_function( + formatted_dataset, dataset_text_field = data_utils.apply_custom_formatting_template( json_dataset, template ) # a new dataset_text_field is created in Dataset @@ -42,7 +42,7 @@ def test_formatting_function(): assert formatted_dataset["train"][0][dataset_text_field] == expected_response -def test_formatting_function_adds_eos_token(): +def test_apply_custom_formatting_template_adds_eos_token(): json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA) template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" # First response from the data file that is read. 
@@ -50,7 +50,7 @@ def test_formatting_function_adds_eos_token(): "### Input: @HMRCcustomers No this is my first job" + " \n\n ### Response: no complaintEOS" ) - formatted_dataset, dataset_text_field = data_utils.formatting_function( + formatted_dataset, dataset_text_field = data_utils.apply_custom_formatting_template( json_dataset, template, "EOS" ) # a new dataset_text_field is created in Dataset @@ -58,9 +58,9 @@ def test_formatting_function_adds_eos_token(): assert formatted_dataset["train"][0][dataset_text_field] == expected_response -def test_formatting_function_gives_error_with_wrong_keys(): +def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): """Tests that the formatting function will throw error if wrong keys are passed to template""" json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA) template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" with pytest.raises(KeyError): - data_utils.formatting_function(json_dataset, template, "EOS") + data_utils.apply_custom_formatting_template(json_dataset, template, "EOS") diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index a9e1aebc1..cc5cb548b 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -29,11 +29,13 @@ TrainerCallback, ) from transformers.utils import logging -from trl import DataCollatorForCompletionOnlyLM, SFTTrainer import datasets import fire import transformers +# First Party +from trl import DataCollatorForCompletionOnlyLM, SFTTrainer + # Local from tuning.config import configs, peft_config from tuning.config.tracker_configs import ( @@ -46,7 +48,7 @@ from tuning.trainercontroller import TrainerControllerCallback from tuning.utils.config_utils import get_hf_peft_config from tuning.utils.data_type_utils import get_torch_dtype -from tuning.utils.data_utils import formatting_function +from tuning.utils.data_utils import apply_custom_formatting_template def train( @@ -227,9 +229,8 @@ def train( packing = False # Currently we support formatted datasets with single sequence instances. - if (data_args.dataset_text_field is None) and ( - data_args.data_formatter_template is None - ): + if not (data_args.dataset_text_field or + data_args.data_formatter_template ): raise ValueError( "Dataset_text_field and data_formatter_template are None. \ One of them needs to be set for training" @@ -253,7 +254,7 @@ def train( json_dataset = datasets.load_dataset("json", data_files=data_files) if data_args.data_formatter_template: - formatted_train_dataset, data_args.dataset_text_field = formatting_function( + formatted_train_dataset, data_args.dataset_text_field = apply_custom_formatting_template( json_dataset["train"], data_args.data_formatter_template, tokenizer.eos_token, @@ -268,7 +269,7 @@ def train( ( formatted_validation_dataset, data_args.dataset_text_field, - ) = formatting_function( + ) = apply_custom_formatting_template( json_dataset["validation"], data_args.data_formatter_template, tokenizer.eos_token, diff --git a/tuning/utils/data_utils.py b/tuning/utils/data_utils.py index 3f66c07da..b60ef677b 100644 --- a/tuning/utils/data_utils.py +++ b/tuning/utils/data_utils.py @@ -2,12 +2,14 @@ import re -def formatting_function(dataset, template, eos_token=""): +def apply_custom_formatting_template(dataset, template, eos_token=""): """Function to format datasets with Alpaca style / other templates. Args: dataset: the HF Dataset element loaded from a JSON or DatasetDict object. template: Template to format data with. 
Features of Dataset should be referred to by {{key}} + eos_token: string EOS token to be appended while formatting data to a single sequence. + Defaults to empty Returns: Formatted HF Dataset, dataset_field name that contains formatted data. """ @@ -16,13 +18,12 @@ def formatting_function(dataset, template, eos_token=""): template += eos_token def formatter(element): - nonlocal template def replace_text(match_obj): captured_groups = match_obj.groups() if len(captured_groups) != 1: raise ValueError( - "Unexpectedly captured multiple groups in verbalizer rendering" + "Unexpectedly captured multiple groups in template formatting" ) index_object = captured_groups[0] @@ -33,7 +34,7 @@ def replace_text(match_obj): return { formatted_dataset_field: re.sub( - r"{{([_a-z\sA-Z0-9]+)}}", replace_text, template + r"{{(.+)}}", replace_text, template ) } From 6cc6d41c8f830ef7e8e9430ab3a3963fb7707af6 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Fri, 24 May 2024 16:07:55 -0600 Subject: [PATCH 11/13] formatting fix Signed-off-by: Sukriti-Sharma4 --- tuning/sft_trainer.py | 12 ++++++------ tuning/utils/data_utils.py | 7 +------ 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index cc5cb548b..1ad4a5463 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -29,13 +29,11 @@ TrainerCallback, ) from transformers.utils import logging +from trl import DataCollatorForCompletionOnlyLM, SFTTrainer import datasets import fire import transformers -# First Party -from trl import DataCollatorForCompletionOnlyLM, SFTTrainer - # Local from tuning.config import configs, peft_config from tuning.config.tracker_configs import ( @@ -229,8 +227,7 @@ def train( packing = False # Currently we support formatted datasets with single sequence instances. - if not (data_args.dataset_text_field or - data_args.data_formatter_template ): + if not (data_args.dataset_text_field or data_args.data_formatter_template): raise ValueError( "Dataset_text_field and data_formatter_template are None. 
\ One of them needs to be set for training" @@ -254,7 +251,10 @@ def train( json_dataset = datasets.load_dataset("json", data_files=data_files) if data_args.data_formatter_template: - formatted_train_dataset, data_args.dataset_text_field = apply_custom_formatting_template( + ( + formatted_train_dataset, + data_args.dataset_text_field, + ) = apply_custom_formatting_template( json_dataset["train"], data_args.data_formatter_template, tokenizer.eos_token, diff --git a/tuning/utils/data_utils.py b/tuning/utils/data_utils.py index b60ef677b..01bc4d1d2 100644 --- a/tuning/utils/data_utils.py +++ b/tuning/utils/data_utils.py @@ -18,7 +18,6 @@ def apply_custom_formatting_template(dataset, template, eos_token=""): template += eos_token def formatter(element): - def replace_text(match_obj): captured_groups = match_obj.groups() if len(captured_groups) != 1: @@ -32,10 +31,6 @@ def replace_text(match_obj): return element[index_object] - return { - formatted_dataset_field: re.sub( - r"{{(.+)}}", replace_text, template - ) - } + return {formatted_dataset_field: re.sub(r"{{(.+)}}", replace_text, template)} return dataset.map(formatter), formatted_dataset_field From 6c09d687f1ea1a6ff028805b93b8ab9f2e8d03db Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Tue, 28 May 2024 17:11:50 -0600 Subject: [PATCH 12/13] update error message Signed-off-by: Sukriti-Sharma4 --- tuning/sft_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 1ad4a5463..fdf7efc8d 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -229,14 +229,14 @@ def train( # Currently we support formatted datasets with single sequence instances. if not (data_args.dataset_text_field or data_args.data_formatter_template): raise ValueError( - "Dataset_text_field and data_formatter_template are None. \ + "dataset_text_field and data_formatter_template are None. \ One of them needs to be set for training" ) # Only one of dataset_text_field or data_formatter_template should be set. if data_args.dataset_text_field and data_args.data_formatter_template: raise ValueError( - "Dataset_text_field and data_formatter_template are set. \ - Only one of them needs to be set for training" + "dataset_text_field and data_formatter_template are both set,\ + but are mutually exclusive options" ) # load the data by parsing JSON From 3f5cc6ba0fb72c4a24f6c6ec16a34aea18a59498 Mon Sep 17 00:00:00 2001 From: Sukriti-Sharma4 Date: Tue, 28 May 2024 18:06:24 -0600 Subject: [PATCH 13/13] restrict JSON fields templates Signed-off-by: Sukriti-Sharma4 --- README.md | 1 + tuning/utils/data_utils.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index fff11dd7e..f552f06fe 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ Once the JSON is converted using the formatting function, pass the `dataset_text 2. #### Format JSON/JSONL on the fly Pass a JSON/JSONL and a `data_formatter_template` to use the formatting function on the fly while tuning. The template should specify fields of JSON with `{{field}}`. While tuning, the data will be converted to a single sequence using the template. + JSON fields can contain alpha-numeric characters, spaces and the following special symbols - "." , "_", "-". 
Example: Train.json `[{ "input" : , diff --git a/tuning/utils/data_utils.py b/tuning/utils/data_utils.py index 01bc4d1d2..3e67cc56f 100644 --- a/tuning/utils/data_utils.py +++ b/tuning/utils/data_utils.py @@ -31,6 +31,10 @@ def replace_text(match_obj): return element[index_object] - return {formatted_dataset_field: re.sub(r"{{(.+)}}", replace_text, template)} + return { + formatted_dataset_field: re.sub( + r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template + ) + } return dataset.map(formatter), formatted_dataset_field
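
For reference, below is a minimal usage sketch of the `apply_custom_formatting_template` helper as it stands after this final patch. The sample records mirror the twitter_complaints test fixtures, and the `"</s>"` EOS string is an illustrative stand-in for `tokenizer.eos_token` (which `sft_trainer.py` passes in); neither is part of the diffs above.

```python
# Minimal sketch, assuming the final state of tuning/utils/data_utils.py above.
from datasets import Dataset

from tuning.utils.data_utils import apply_custom_formatting_template

ds = Dataset.from_list(
    [
        {"Tweet text": "@HMRCcustomers No this is my first job", "text_label": "no complaint"},
        {"Tweet text": "@EE terrible latency 200ms. Why is this.", "text_label": "complaint"},
    ]
)

# Template keys in {{...}} must be dataset columns; an unknown key raises KeyError.
template = "### Text: {{Tweet text}} \n\n### Label: {{text_label}}"

# "</s>" is an illustrative EOS token appended to every formatted sequence.
formatted_ds, text_field = apply_custom_formatting_template(ds, template, "</s>")

assert text_field == "formatted_data_field"
print(formatted_ds[0][text_field])
# ### Text: @HMRCcustomers No this is my first job
#
# ### Label: no complaint</s>
```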