forked from foundation-model-stack/fms-hf-tuning
feat: DataPreProcessor v1 - framework and backend (foundation-model-stack#381)

* Add initial implementation of dataloader v1
  Signed-off-by: Dushyant Behl <[email protected]>
* tests: reformat mock.patch to inside unit tests
  Signed-off-by: Will Johnson <[email protected]>
  fmt
  Signed-off-by: Will Johnson <[email protected]>
* Add data config argument to data preprocessor
  Signed-off-by: Dushyant Behl <[email protected]>
* fix: Changes to support current implementation
  Signed-off-by: Abhishek <[email protected]>
* Ensure data handling is done within process dataargs. Removes unused dead code after adding the new framework and refactors some test cases and files.
  Signed-off-by: Dushyant Behl <[email protected]>
* Remove accelerator in favor of torch distributed check for multi-node data preprocessing
  Signed-off-by: Dushyant Behl <[email protected]>
* Refactor data util tests as data handler tests.
  Signed-off-by: Dushyant Behl <[email protected]>
* fix: add __init__.py to add tuning.data to python package
  Signed-off-by: Will Johnson <[email protected]>
* fix: multi GPU prepare training dataset
  Signed-off-by: Will Johnson <[email protected]>
* fix: lint
  Signed-off-by: Will Johnson <[email protected]>
* fix: Add TODO
  Signed-off-by: Will Johnson <[email protected]>
* test: add test for process_dataset_configs in HFBasedDataPreProcessor
  Signed-off-by: Will Johnson <[email protected]>
* add: test cases for framework
  Signed-off-by: Abhishek <[email protected]>
* fix: update function name get_dataprocessor->get_datapreprocessor
  Signed-off-by: Will Johnson <[email protected]>
* Rename loader to processor
  Signed-off-by: Dushyant Behl <[email protected]>
* data folders should be together
  Signed-off-by: Dushyant Behl <[email protected]>
* Add code comments and make code path clearer. Remove packing check as packing support for pretokenised data is merged to trl. See huggingface/trl#2011
  Signed-off-by: Dushyant Behl <[email protected]>

---------

Signed-off-by: Dushyant Behl <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Abhishek <[email protected]>
Co-authored-by: Will Johnson <[email protected]>
Co-authored-by: Abhishek <[email protected]>
1 parent 268ac80 · commit 7df3416

Showing 21 changed files with 1,426 additions and 824 deletions.
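For orientation before the per-file diff: a minimal sketch of how the new framework might be driven end to end. Only the names get_datapreprocessor, HFBasedDataPreProcessor, and process_dataset_configs come from the commit message above; the module path, keyword arguments, and return value here are assumptions for illustration, not the confirmed API.

# Hypothetical driver for the new data preprocessing framework. The import
# path and every signature below are assumptions; only the function and
# class names appear in the commit message.
from transformers import AutoTokenizer

from tuning.data.setup_dataprocessor import get_datapreprocessor  # path assumed

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Build the (HF-backed) preprocessor from a data config YAML such as the
# predefined test configs added in this commit.
processor = get_datapreprocessor(
    data_config_path="my_data_config.yaml",  # kwarg name assumed
    tokenizer=tokenizer,
)

# Run every dataset + data_handlers pipeline declared in the config.
train_dataset = processor.process_dataset_configs()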
tests/artifacts/predefined_data_configs/__init__.py (30 additions, 0 deletions)
@@ -0,0 +1,30 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpful datasets for configuring individual unit tests.
"""
# Standard
import os

### Constants used for data
PREDEFINED_DATA_CONFIGS = os.path.join(os.path.dirname(__file__))
APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml"
)
PRETOKENIZE_JSON_DATA_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml"
)
TOKENIZE_AND_APPLY_INPUT_MASKING_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking.yaml"
)
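A hedged sketch of how a test might consume these constants: parse one of the predefined configs with PyYAML and point its FILE_PATH placeholder at a real dataset before handing it to the preprocessor. The substitution step is an assumption; the keys mirror the YAML files below.

# Sketch only: assumes tests rewrite the FILE_PATH placeholder before use.
# Third Party
import yaml

# First Party
from tests.artifacts.predefined_data_configs import APPLY_CUSTOM_TEMPLATE_YAML

with open(APPLY_CUSTOM_TEMPLATE_YAML, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Each dataset entry lists its source files under data_paths (path illustrative).
config["datasets"][0]["data_paths"] = ["/tmp/twitter_complaints.jsonl"]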
tests/artifacts/predefined_data_configs/apply_custom_template.yaml (14 additions, 0 deletions)
@@ -0,0 +1,14 @@
dataprocessor:
  type: default
datasets:
  - name: apply_custom_data_template
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: apply_custom_data_formatting_template
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            dataset_text_field: "dataset_text_field"
            dataset_template: "dataset_template"
tests/artifacts/predefined_data_configs/pretokenized_json_data.yaml (6 additions, 0 deletions)
@@ -0,0 +1,6 @@
dataprocessor:
  type: default
datasets:
  - name: pretokenized_dataset
    data_paths:
      - "FILE_PATH"
tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml (14 additions, 0 deletions)
@@ -0,0 +1,14 @@
dataprocessor:
  type: default
datasets:
  - name: text_dataset_input_output_masking
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_field: "INPUT"
            output_field: "OUTPUT"
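Each data_handlers entry plausibly translates to a datasets.Dataset.map call: the handler name resolves to a function in tuning.data.data_handlers, the arguments feed map itself, and fn_kwargs reach the handler. A rough sketch under that assumption (the dispatch mechanism and the injected tokenizer kwarg are not confirmed by this diff; the .map pattern mirrors the handler tests below):

# Rough hand-written equivalent of the YAML above. The name-to-function
# dispatch is assumed; only the .map usage pattern is confirmed by the
# handler tests in this commit.
import datasets
from transformers import AutoTokenizer

from tuning.data.data_handlers import tokenize_and_apply_input_masking

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
raw = datasets.load_dataset("json", data_files="input_output.jsonl")["train"]

processed = raw.map(
    tokenize_and_apply_input_masking,
    batched=False,                        # batched: false
    remove_columns=raw.column_names,      # remove_columns: all
    fn_kwargs={
        "tokenizer": tokenizer,           # assumed to be injected by the framework
        "input_field": "INPUT",           # from fn_kwargs in the YAML
        "output_field": "OUTPUT",
    },
)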
tests/data/test_data_handlers.py (110 additions, 0 deletions)
@@ -0,0 +1,110 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Third Party
from transformers import AutoTokenizer
import datasets
import pytest

# First Party
from tests.artifacts.testdata import MODEL_NAME, TWITTER_COMPLAINTS_DATA_JSONL

# Local
from tuning.data.data_handlers import (
    apply_custom_data_formatting_template,
    combine_sequence,
)


def test_apply_custom_formatting_template():
    json_dataset = datasets.load_dataset(
        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
    )
    template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    formatted_dataset_field = "formatted_data_field"
    formatted_dataset = json_dataset.map(
        apply_custom_data_formatting_template,
        fn_kwargs={
            "tokenizer": tokenizer,
            "dataset_text_field": formatted_dataset_field,
            "template": template,
        },
    )
    # First response from the data file that is read.
    expected_response = (
        "### Input: @HMRCcustomers No this is my first job"
        + " \n\n ### Response: no complaint"
        + tokenizer.eos_token
    )

    # A new dataset_text_field is created in the dataset.
    assert formatted_dataset_field in formatted_dataset["train"][0]
    assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response


def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
    """Test that the formatting function raises a KeyError if the template references missing keys."""
    json_dataset = datasets.load_dataset(
        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
    )
    template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
    formatted_dataset_field = "formatted_data_field"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    with pytest.raises(KeyError):
        json_dataset.map(
            apply_custom_data_formatting_template,
            fn_kwargs={
                "tokenizer": tokenizer,
                "dataset_text_field": formatted_dataset_field,
                "template": template,
            },
        )


@pytest.mark.parametrize(
    "input_element,output_element,expected_res",
    [
        ("foo ", "bar", "foo bar"),
        ("foo\n", "bar", "foo\nbar"),
        ("foo\t", "bar", "foo\tbar"),
        ("foo", "bar", "foo bar"),
    ],
)
def test_combine_sequence(input_element, output_element, expected_res):
    """Ensure that input / output elements are combined with correct whitespace handling."""
    comb_seq = combine_sequence(input_element, output_element)
    assert isinstance(comb_seq, str)
    assert comb_seq == expected_res


@pytest.mark.parametrize(
    "input_element,output_element,expected_res",
    [
        ("foo ", "bar", "foo bar"),
        ("foo\n", "bar", "foo\nbar"),
        ("foo\t", "bar", "foo\tbar"),
        ("foo", "bar", "foo bar"),
    ],
)
def test_combine_sequence_adds_eos(input_element, output_element, expected_res):
    """Ensure that the EOS token is appended when one is supplied."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token)
    expected_res += tokenizer.eos_token
    assert isinstance(comb_seq, str)
    assert comb_seq == expected_res
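Taken together, the parametrized cases pin down combine_sequence's contract: a single space is inserted between input and output only when the input does not already end in whitespace (space, newline, or tab), and the EOS token is appended only when one is explicitly passed.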