-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Formatting consolidation main (#216)
* Add prepreprocessing utils for pretokenized datasets Signed-off-by: Alex-Brooks <[email protected]> * Add twitter input/output data Signed-off-by: Alex-Brooks <[email protected]> * Add tests for data preprocessing utilities Signed-off-by: Alex-Brooks <[email protected]> * Formatting, add hack for sidestepping validation Signed-off-by: Alex-Brooks <[email protected]> * Fix linting errors in data gen Signed-off-by: Alex-Brooks <[email protected]> * Add end to end pretokenized tests, formatting Signed-off-by: Alex-Brooks <[email protected]> * Add docstrings for preprocessor utils Signed-off-by: Alex-Brooks <[email protected]> * Rebase tests to new structure Signed-off-by: Alex-Brooks <[email protected]> * fix formatting Signed-off-by: Sukriti-Sharma4 <[email protected]> * fix linting Signed-off-by: Sukriti-Sharma4 <[email protected]> --------- Signed-off-by: Alex-Brooks <[email protected]> Signed-off-by: Sukriti-Sharma4 <[email protected]> Co-authored-by: Alex-Brooks <[email protected]>
- Loading branch information
1 parent
3f05c67
commit 4334d6c
Showing
4 changed files
with
588 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
{"ID": 0, "Label": 2, "input": "@HMRCcustomers No this is my first job", "output": "no complaint"} | ||
{"ID": 1, "Label": 2, "input": "@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.", "output": "no complaint"} | ||
{"ID": 2, "Label": 1, "input": "If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService", "output": "complaint"} | ||
{"ID": 3, "Label": 1, "input": "@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.", "output": "complaint"} | ||
{"ID": 4, "Label": 2, "input": "Couples wallpaper, so cute. :) #BrothersAtHome", "output": "no complaint"} | ||
{"ID": 5, "Label": 2, "input": "@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https://t.co/WRtNsokblG", "output": "no complaint"} | ||
{"ID": 6, "Label": 2, "input": "@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?", "output": "no complaint"} | ||
{"ID": 7, "Label": 1, "input": "@nationalgridus I have no water and the bill is current and paid. Can you do something about this?", "output": "complaint"} | ||
{"ID": 8, "Label": 1, "input": "Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude/condescending. I'll take my $$ to @Sephora", "output": "complaint"} | ||
{"ID": 9, "Label": 2, "input": "@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd", "output": "no complaint"} | ||
{"ID": 10, "Label": 2, "input": "@NortonSupport Thanks much.", "output": "no complaint"} | ||
{"ID": 11, "Label": 2, "input": "@VerizonSupport all of a sudden I can't connect to my primary wireless network but guest one works", "output": "no complaint"} | ||
{"ID": 12, "Label": 2, "input": "Aaaahhhhh!!!! My @Razer @PlayOverwatch d.va meka headset came in!!! I didn't even know it had shipped!!! So excited\u2026 https://t.co/4gXy9xED8d", "output": "no complaint"} | ||
{"ID": 13, "Label": 2, "input": "@Lin_Manuel @jmessinaphoto @VAMNit Omg a little squish!!!!! Enjoy and congrats!!!! I miss mine being so young! \ufffd\ufffd\ufffd\ufffd\ufffd\ufffd", "output": "no complaint"} | ||
{"ID": 14, "Label": 2, "input": "@IanJamesPoulter What's your secret to poaching eggs? Mine NEVER look that good.", "output": "no complaint"} | ||
{"ID": 15, "Label": 2, "input": "@AWSSupport When will be able Kinesis Firehose compatible with Elasticsearch 6.0? Thank you!", "output": "no complaint"} | ||
{"ID": 16, "Label": 2, "input": "@NCIS_CBS https://t.co/eeVL9Eu3bE", "output": "no complaint"} | ||
{"ID": 17, "Label": 2, "input": "@msetchell Via the settings? That\u2019s how I do it on master T\u2019s", "output": "no complaint"} | ||
{"ID": 18, "Label": 2, "input": "Today at work there was a low flying duck heading toward a crowd of people, and I yelled \"watch out! and I'm very disappointed with myself.", "output": "no complaint"} | ||
{"ID": 19, "Label": 1, "input": "@NortonSupport @NortonOnline What the hell is a dm 5-10 days to get money back bank account now overdrawn thanks guys", "output": "complaint"} | ||
{"ID": 20, "Label": 1, "input": "@united not happy with this delay from Newark to Manchester tonight :( only 30 mins free Wi-fi sucks ...", "output": "complaint"} | ||
{"ID": 21, "Label": 1, "input": "@ZARA_Care I've been waiting on a reply to my tweets and DMs for days now?", "output": "complaint"} | ||
{"ID": 22, "Label": 2, "input": "New Listing! Large 2 Family Home for Sale in #Passaic Park, #NJ #realestate #homesforsale Great Location!\u2026 https://t.co/IV4OrLXkMk", "output": "no complaint"} | ||
{"ID": 23, "Label": 1, "input": "@SouthwestAir I love you but when sending me flight changes please don't use military time #ignoranceisbliss", "output": "complaint"} | ||
{"ID": 24, "Label": 2, "input": "@JetBlue Completely understand but would prefer being on time to filling out forms....", "output": "no complaint"} | ||
{"ID": 25, "Label": 2, "input": "@nvidiacc I own two gtx 460 in sli. I want to try windows 8 dev preview. Which driver should I use. Can I use the windows 7 one.", "output": "no complaint"} | ||
{"ID": 26, "Label": 2, "input": "Just posted a photo https://t.co/RShFwCjPHu", "output": "no complaint"} | ||
{"ID": 27, "Label": 2, "input": "Love crescent rolls? Try adding pesto @PerdueChicken to them and you\u2019re going to love it! #Promotion #PerdueCrew -\u2026 https://t.co/KBHOfqCukH", "output": "no complaint"} | ||
{"ID": 28, "Label": 1, "input": "@TopmanAskUs please just give me my money back.", "output": "complaint"} | ||
{"ID": 29, "Label": 2, "input": "I just gave 5 stars to Tracee at @neimanmarcus for the great service I received!", "output": "no complaint"} | ||
{"ID": 30, "Label": 2, "input": "@FitbitSupport when are you launching new clock faces for Indian market", "output": "no complaint"} | ||
{"ID": 31, "Label": 1, "input": "@HPSupport my printer will not allow me to choose color instead it only prints monochrome #hppsdr #ijkhelp", "output": "complaint"} | ||
{"ID": 32, "Label": 1, "input": "@DIRECTV can I get a monthly charge double refund when it sprinkles outside and we lose reception? #IamEmbarrasedForYou", "output": "complaint"} | ||
{"ID": 33, "Label": 1, "input": "@AlfaRomeoCares Hi thanks for replying, could be my internet but link doesn't seem to be working", "output": "complaint"} | ||
{"ID": 34, "Label": 2, "input": "Looks tasty! Going to share with everyone I know #FebrezeONE #sponsored https://t.co/4AQI53npei", "output": "no complaint"} | ||
{"ID": 35, "Label": 2, "input": "@OnePlus_IN can OnePlus 5T do front camera portrait?", "output": "no complaint"} | ||
{"ID": 36, "Label": 1, "input": "@sho_help @showtime your arrive is terrible streaming is stop and start every couple mins. Get it together it's xmas", "output": "complaint"} | ||
{"ID": 37, "Label": 2, "input": "@KandraKPTV I just witnessed a huge building fire in Santa Monica California", "output": "no complaint"} | ||
{"ID": 38, "Label": 2, "input": "@fernrocks most definitely the latter for me", "output": "no complaint"} | ||
{"ID": 39, "Label": 1, "input": "@greateranglia Could I ask why the Area in front of BIC Station was not gritted withh all the snow.", "output": "complaint"} | ||
{"ID": 40, "Label": 2, "input": "I'm earning points with #CricketRewards https://t.co/GfpGhqqnhE", "output": "no complaint"} | ||
{"ID": 41, "Label": 2, "input": "@Schrapnel @comcast RIP me", "output": "no complaint"} | ||
{"ID": 42, "Label": 2, "input": "The wait is finally over, just joined @SquareUK, hope to get started real soon!", "output": "no complaint"} | ||
{"ID": 43, "Label": 2, "input": "@WholeFoods what's the best way to give feedback on a particular store to the regional/national office?", "output": "no complaint"} | ||
{"ID": 44, "Label": 2, "input": "@DanielNewman I honestly would believe anything. People are...too much sometimes.", "output": "no complaint"} | ||
{"ID": 45, "Label": 2, "input": "@asblough Yep! It should send you a notification with your driver\u2019s name and what time they\u2019ll be showing up!", "output": "no complaint"} | ||
{"ID": 46, "Label": 2, "input": "@Wavy2Timez for real", "output": "no complaint"} | ||
{"ID": 47, "Label": 1, "input": "@KenyaPower_Care no power in south b area... is it scheduled.", "output": "complaint"} | ||
{"ID": 48, "Label": 1, "input": "Honda won't do anything about water leaking in brand new car. Frustrated! @HondaCustSvc @AmericanHonda", "output": "complaint"} | ||
{"ID": 49, "Label": 1, "input": "@CBSNews @Dodge @ChryslerCares My driver side air bag has been recalled and replaced, but what about the passenger side?", "output": "complaint"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
# Third Party | ||
from datasets import Dataset | ||
from datasets.exceptions import DatasetGenerationError | ||
from transformers import AutoTokenizer, DataCollatorForSeq2Seq | ||
from trl import DataCollatorForCompletionOnlyLM | ||
import pytest | ||
|
||
# First Party | ||
from tests.data import ( | ||
MALFORMATTED_DATA, | ||
TWITTER_COMPLAINTS_DATA, | ||
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT, | ||
) | ||
|
||
# Local | ||
from tuning.utils.preprocessing_utils import ( | ||
combine_sequence, | ||
get_data_trainer_kwargs, | ||
get_preprocessed_dataset, | ||
load_hf_dataset_from_jsonl_file, | ||
validate_data_args, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize(
    "input_element,output_element,expected_res",
    [
        ("foo ", "bar", "foo bar"),
        ("foo\n", "bar", "foo\nbar"),
        ("foo\t", "bar", "foo\tbar"),
        ("foo", "bar", "foo bar"),
    ],
)
def test_combine_sequence(input_element, output_element, expected_res):
    """Combining input / output should keep existing trailing whitespace or inject a space."""
    combined = combine_sequence(input_element, output_element)
    # The combined sequence must be a plain string matching the expected concatenation
    assert isinstance(combined, str)
    assert combined == expected_res
|
||
|
||
# Tests for loading the dataset from disk
def test_load_hf_dataset_from_jsonl_file():
    """Loading the twitter jsonl file exposes records keyed by the requested field names."""
    in_field, out_field = "Tweet text", "text_label"
    dataset = load_hf_dataset_from_jsonl_file(
        TWITTER_COMPLAINTS_DATA,
        input_field_name=in_field,
        output_field_name=out_field,
    )
    # Grab the first record and check that both requested fields are present as keys
    first_record = next(iter(dataset))
    assert in_field in first_record
    assert out_field in first_record
|
||
|
||
def test_load_hf_dataset_from_jsonl_file_wrong_keys():
    """Requesting field names absent from the jsonl file should raise a generation error."""
    with pytest.raises(DatasetGenerationError):
        load_hf_dataset_from_jsonl_file(
            TWITTER_COMPLAINTS_DATA,
            input_field_name="foo",
            output_field_name="bar",
        )
|
||
|
||
def test_load_hf_dataset_from_malformatted_data():
    """Improperly formatted data should raise a generation error regardless of keys."""
    # NOTE: the field names are irrelevant here; the file itself is malformed
    with pytest.raises(DatasetGenerationError):
        load_hf_dataset_from_jsonl_file(
            MALFORMATTED_DATA,
            input_field_name="foo",
            output_field_name="bar",
        )
|
||
|
||
def test_load_hf_dataset_from_jsonl_file_duplicate_keys():
    """Using the same field name for both input and output should be rejected."""
    with pytest.raises(ValueError):
        load_hf_dataset_from_jsonl_file(
            TWITTER_COMPLAINTS_DATA,
            input_field_name="Tweet text",
            output_field_name="Tweet text",
        )
|
||
|
||
# Tests for custom masking / preprocessing logic
@pytest.mark.parametrize("max_sequence_length", [1, 10, 100, 1000])
def test_get_preprocessed_dataset(max_sequence_length):
    """Tokenized records should be unpadded, source-masked, aligned, and truncated."""
    tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0")
    dataset = get_preprocessed_dataset(
        data_path=TWITTER_COMPLAINTS_DATA,
        tokenizer=tokenizer,
        max_sequence_length=max_sequence_length,
        input_field_name="Tweet text",
        output_field_name="text_label",
    )
    for record in dataset:
        attention_mask = record["attention_mask"]
        # Padding is deferred to the collator, so no position may be masked out yet
        assert sum(attention_mask) == len(attention_mask)
        # A nonempty source text means the label sequence starts with ignored tokens
        assert record["labels"][0] == -100
        # Every field in the record must share one common length...
        field_lengths = {len(value) for value in record.values()}
        assert len(field_lengths) == 1
        # ...and truncation keeps that length within the max; padding is handled elsewhere
        assert field_lengths.pop() <= max_sequence_length
|
||
|
||
# Tests for fetching train args
@pytest.mark.parametrize(
    "use_validation_data, collator_type, packing",
    [
        (True, None, True),
        (False, None, True),
        (True, DataCollatorForCompletionOnlyLM, False),
        (False, DataCollatorForCompletionOnlyLM, False),
    ],
)
def test_get_trainer_kwargs_with_response_template_and_text_field(
    use_validation_data, collator_type, packing
):
    """Trainer kwargs built from a text field + response template have the expected shape."""
    train_path = TWITTER_COMPLAINTS_DATA
    eval_path = train_path if use_validation_data else None
    # Columns we expect to see in the raw loaded twitter dataset
    expected_columns = {"Tweet text", "ID", "Label", "text_label", "output"}
    trainer_kwargs = get_data_trainer_kwargs(
        training_data_path=train_path,
        validation_data_path=eval_path,
        packing=packing,
        response_template="\n### Label:",
        max_sequence_length=100,
        tokenizer=AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0"),
        dataset_text_field="output",
    )
    assert len(trainer_kwargs) == 3

    # Packing implies no collator; otherwise the completion-only LM collator is used
    if collator_type is None:
        assert trainer_kwargs["data_collator"] is None
    else:
        assert isinstance(trainer_kwargs["data_collator"], collator_type)

    # An eval dataset should exist only when a validation path was supplied
    if eval_path is None:
        assert trainer_kwargs["eval_dataset"] is None
    else:
        assert isinstance(trainer_kwargs["eval_dataset"], Dataset)
        assert set(trainer_kwargs["eval_dataset"].column_names) == expected_columns

    assert isinstance(trainer_kwargs["train_dataset"], Dataset)
    assert set(trainer_kwargs["train_dataset"].column_names) == expected_columns
|
||
|
||
@pytest.mark.parametrize("use_validation_data", [True, False])
def test_get_trainer_kwargs_with_custom_masking(use_validation_data):
    """Trainer kwargs for pretokenized data use a seq2seq collator and a formatting func."""
    train_path = TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT
    eval_path = train_path if use_validation_data else None
    # Pretokenized data should already carry the model-ready columns
    expected_columns = {"input_ids", "attention_mask", "labels"}
    trainer_kwargs = get_data_trainer_kwargs(
        training_data_path=train_path,
        validation_data_path=eval_path,
        packing=False,
        response_template=None,
        max_sequence_length=100,
        tokenizer=AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0"),
        dataset_text_field=None,
    )
    assert len(trainer_kwargs) == 4
    # With packing disabled and pretokenized data, padding falls to a seq2seq collator
    assert isinstance(trainer_kwargs["data_collator"], DataCollatorForSeq2Seq)

    # An eval dataset should exist only when a validation path was supplied
    if eval_path is None:
        assert trainer_kwargs["eval_dataset"] is None
    else:
        assert isinstance(trainer_kwargs["eval_dataset"], Dataset)
        assert set(trainer_kwargs["eval_dataset"].column_names) == expected_columns

    assert isinstance(trainer_kwargs["train_dataset"], Dataset)
    assert set(trainer_kwargs["train_dataset"].column_names) == expected_columns
    # Needed to sidestep TRL validation
    assert trainer_kwargs["formatting_func"] is not None
|
||
|
||
# Tests for data argument validation
@pytest.mark.parametrize(
    "dataset_text_field, response_template",
    [
        ("input", None),
        (None, "output"),
    ],
)
def test_validate_args(dataset_text_field, response_template):
    """Supplying only one of dataset text field / response template should be rejected."""
    with pytest.raises(ValueError):
        validate_data_args(dataset_text_field, response_template)
Oops, something went wrong.