Skip to content

Commit

Permalink
Formatting consolidation main (#216)
Browse files Browse the repository at this point in the history
* Add preprocessing utils for pretokenized datasets

Signed-off-by: Alex-Brooks <[email protected]>

* Add twitter input/output data

Signed-off-by: Alex-Brooks <[email protected]>

* Add tests for data preprocessing utilities

Signed-off-by: Alex-Brooks <[email protected]>

* Formatting, add hack for sidestepping validation

Signed-off-by: Alex-Brooks <[email protected]>

* Fix linting errors in data gen

Signed-off-by: Alex-Brooks <[email protected]>

* Add end to end pretokenized tests, formatting

Signed-off-by: Alex-Brooks <[email protected]>

* Add docstrings for preprocessor utils

Signed-off-by: Alex-Brooks <[email protected]>

* Rebase tests to new structure

Signed-off-by: Alex-Brooks <[email protected]>

* fix formatting

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* fix linting

Signed-off-by: Sukriti-Sharma4 <[email protected]>

---------

Signed-off-by: Alex-Brooks <[email protected]>
Signed-off-by: Sukriti-Sharma4 <[email protected]>
Co-authored-by: Alex-Brooks <[email protected]>
  • Loading branch information
Ssukriti and alex-jw-brooks authored Jun 27, 2024
1 parent 3f05c67 commit 4334d6c
Show file tree
Hide file tree
Showing 4 changed files with 588 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tests/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
### Constants used for data
# Directory containing the test data fixture files (this package's own directory)
DATA_DIR = os.path.join(os.path.dirname(__file__))
# Small twitter complaints dataset with raw text columns (e.g. "Tweet text")
TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json")
# Twitter complaints dataset already mapped to "input" / "output" keys
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT = os.path.join(
    DATA_DIR, "twitter_complaints_input_output.json"
)
# Twitter complaints dataset in JSON format
TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join(DATA_DIR, "twitter_complaints_json.json")
# Dataset file containing no records — presumably used to exercise empty-input handling
EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json")
# Badly formatted dataset file — presumably used to exercise load failures
MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json")
50 changes: 50 additions & 0 deletions tests/data/twitter_complaints_input_output.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{"ID": 0, "Label": 2, "input": "@HMRCcustomers No this is my first job", "output": "no complaint"}
{"ID": 1, "Label": 2, "input": "@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.", "output": "no complaint"}
{"ID": 2, "Label": 1, "input": "If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService", "output": "complaint"}
{"ID": 3, "Label": 1, "input": "@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.", "output": "complaint"}
{"ID": 4, "Label": 2, "input": "Couples wallpaper, so cute. :) #BrothersAtHome", "output": "no complaint"}
{"ID": 5, "Label": 2, "input": "@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https://t.co/WRtNsokblG", "output": "no complaint"}
{"ID": 6, "Label": 2, "input": "@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?", "output": "no complaint"}
{"ID": 7, "Label": 1, "input": "@nationalgridus I have no water and the bill is current and paid. Can you do something about this?", "output": "complaint"}
{"ID": 8, "Label": 1, "input": "Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude/condescending. I'll take my $$ to @Sephora", "output": "complaint"}
{"ID": 9, "Label": 2, "input": "@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd", "output": "no complaint"}
{"ID": 10, "Label": 2, "input": "@NortonSupport Thanks much.", "output": "no complaint"}
{"ID": 11, "Label": 2, "input": "@VerizonSupport all of a sudden I can't connect to my primary wireless network but guest one works", "output": "no complaint"}
{"ID": 12, "Label": 2, "input": "Aaaahhhhh!!!! My @Razer @PlayOverwatch d.va meka headset came in!!! I didn't even know it had shipped!!! So excited\u2026 https://t.co/4gXy9xED8d", "output": "no complaint"}
{"ID": 13, "Label": 2, "input": "@Lin_Manuel @jmessinaphoto @VAMNit Omg a little squish!!!!! Enjoy and congrats!!!! I miss mine being so young! \ufffd\ufffd\ufffd\ufffd\ufffd\ufffd", "output": "no complaint"}
{"ID": 14, "Label": 2, "input": "@IanJamesPoulter What's your secret to poaching eggs? Mine NEVER look that good.", "output": "no complaint"}
{"ID": 15, "Label": 2, "input": "@AWSSupport When will be able Kinesis Firehose compatible with Elasticsearch 6.0? Thank you!", "output": "no complaint"}
{"ID": 16, "Label": 2, "input": "@NCIS_CBS https://t.co/eeVL9Eu3bE", "output": "no complaint"}
{"ID": 17, "Label": 2, "input": "@msetchell Via the settings? That\u2019s how I do it on master T\u2019s", "output": "no complaint"}
{"ID": 18, "Label": 2, "input": "Today at work there was a low flying duck heading toward a crowd of people, and I yelled \"watch out! and I'm very disappointed with myself.", "output": "no complaint"}
{"ID": 19, "Label": 1, "input": "@NortonSupport @NortonOnline What the hell is a dm 5-10 days to get money back bank account now overdrawn thanks guys", "output": "complaint"}
{"ID": 20, "Label": 1, "input": "@united not happy with this delay from Newark to Manchester tonight :( only 30 mins free Wi-fi sucks ...", "output": "complaint"}
{"ID": 21, "Label": 1, "input": "@ZARA_Care I've been waiting on a reply to my tweets and DMs for days now?", "output": "complaint"}
{"ID": 22, "Label": 2, "input": "New Listing! Large 2 Family Home for Sale in #Passaic Park, #NJ #realestate #homesforsale Great Location!\u2026 https://t.co/IV4OrLXkMk", "output": "no complaint"}
{"ID": 23, "Label": 1, "input": "@SouthwestAir I love you but when sending me flight changes please don't use military time #ignoranceisbliss", "output": "complaint"}
{"ID": 24, "Label": 2, "input": "@JetBlue Completely understand but would prefer being on time to filling out forms....", "output": "no complaint"}
{"ID": 25, "Label": 2, "input": "@nvidiacc I own two gtx 460 in sli. I want to try windows 8 dev preview. Which driver should I use. Can I use the windows 7 one.", "output": "no complaint"}
{"ID": 26, "Label": 2, "input": "Just posted a photo https://t.co/RShFwCjPHu", "output": "no complaint"}
{"ID": 27, "Label": 2, "input": "Love crescent rolls? Try adding pesto @PerdueChicken to them and you\u2019re going to love it! #Promotion #PerdueCrew -\u2026 https://t.co/KBHOfqCukH", "output": "no complaint"}
{"ID": 28, "Label": 1, "input": "@TopmanAskUs please just give me my money back.", "output": "complaint"}
{"ID": 29, "Label": 2, "input": "I just gave 5 stars to Tracee at @neimanmarcus for the great service I received!", "output": "no complaint"}
{"ID": 30, "Label": 2, "input": "@FitbitSupport when are you launching new clock faces for Indian market", "output": "no complaint"}
{"ID": 31, "Label": 1, "input": "@HPSupport my printer will not allow me to choose color instead it only prints monochrome #hppsdr #ijkhelp", "output": "complaint"}
{"ID": 32, "Label": 1, "input": "@DIRECTV can I get a monthly charge double refund when it sprinkles outside and we lose reception? #IamEmbarrasedForYou", "output": "complaint"}
{"ID": 33, "Label": 1, "input": "@AlfaRomeoCares Hi thanks for replying, could be my internet but link doesn't seem to be working", "output": "complaint"}
{"ID": 34, "Label": 2, "input": "Looks tasty! Going to share with everyone I know #FebrezeONE #sponsored https://t.co/4AQI53npei", "output": "no complaint"}
{"ID": 35, "Label": 2, "input": "@OnePlus_IN can OnePlus 5T do front camera portrait?", "output": "no complaint"}
{"ID": 36, "Label": 1, "input": "@sho_help @showtime your arrive is terrible streaming is stop and start every couple mins. Get it together it's xmas", "output": "complaint"}
{"ID": 37, "Label": 2, "input": "@KandraKPTV I just witnessed a huge building fire in Santa Monica California", "output": "no complaint"}
{"ID": 38, "Label": 2, "input": "@fernrocks most definitely the latter for me", "output": "no complaint"}
{"ID": 39, "Label": 1, "input": "@greateranglia Could I ask why the Area in front of BIC Station was not gritted withh all the snow.", "output": "complaint"}
{"ID": 40, "Label": 2, "input": "I'm earning points with #CricketRewards https://t.co/GfpGhqqnhE", "output": "no complaint"}
{"ID": 41, "Label": 2, "input": "@Schrapnel @comcast RIP me", "output": "no complaint"}
{"ID": 42, "Label": 2, "input": "The wait is finally over, just joined @SquareUK, hope to get started real soon!", "output": "no complaint"}
{"ID": 43, "Label": 2, "input": "@WholeFoods what's the best way to give feedback on a particular store to the regional/national office?", "output": "no complaint"}
{"ID": 44, "Label": 2, "input": "@DanielNewman I honestly would believe anything. People are...too much sometimes.", "output": "no complaint"}
{"ID": 45, "Label": 2, "input": "@asblough Yep! It should send you a notification with your driver\u2019s name and what time they\u2019ll be showing up!", "output": "no complaint"}
{"ID": 46, "Label": 2, "input": "@Wavy2Timez for real", "output": "no complaint"}
{"ID": 47, "Label": 1, "input": "@KenyaPower_Care no power in south b area... is it scheduled.", "output": "complaint"}
{"ID": 48, "Label": 1, "input": "Honda won't do anything about water leaking in brand new car. Frustrated! @HondaCustSvc @AmericanHonda", "output": "complaint"}
{"ID": 49, "Label": 1, "input": "@CBSNews @Dodge @ChryslerCares My driver side air bag has been recalled and replaced, but what about the passenger side?", "output": "complaint"}
193 changes: 193 additions & 0 deletions tests/utils/test_preprocessing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Third Party
from datasets import Dataset
from datasets.exceptions import DatasetGenerationError
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
from trl import DataCollatorForCompletionOnlyLM
import pytest

# First Party
from tests.data import (
MALFORMATTED_DATA,
TWITTER_COMPLAINTS_DATA,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT,
)

# Local
from tuning.utils.preprocessing_utils import (
combine_sequence,
get_data_trainer_kwargs,
get_preprocessed_dataset,
load_hf_dataset_from_jsonl_file,
validate_data_args,
)


@pytest.mark.parametrize(
    "input_element,output_element,expected_res",
    [
        ("foo ", "bar", "foo bar"),
        ("foo\n", "bar", "foo\nbar"),
        ("foo\t", "bar", "foo\tbar"),
        ("foo", "bar", "foo bar"),
    ],
)
def test_combine_sequence(input_element, output_element, expected_res):
    """Verify whitespace-aware joining of an input / output pair into one string."""
    combined = combine_sequence(input_element, output_element)
    # The result must be a plain string joined with exactly the expected separator
    assert isinstance(combined, str)
    assert combined == expected_res


# Tests for loading the dataset from disk
def test_load_hf_dataset_from_jsonl_file():
    """Loaded records should expose both the requested input and output fields."""
    in_field = "Tweet text"
    out_field = "text_label"
    dataset = load_hf_dataset_from_jsonl_file(
        TWITTER_COMPLAINTS_DATA,
        input_field_name=in_field,
        output_field_name=out_field,
    )
    # Inspect the first record to confirm both field names survived loading
    first_record = next(iter(dataset))
    assert in_field in first_record
    assert out_field in first_record


def test_load_hf_dataset_from_jsonl_file_wrong_keys():
    """Requesting field names that are absent from the jsonl file should raise."""
    with pytest.raises(DatasetGenerationError):
        load_hf_dataset_from_jsonl_file(
            TWITTER_COMPLAINTS_DATA, input_field_name="foo", output_field_name="bar"
        )


def test_load_hf_dataset_from_malformatted_data():
    """Improperly formatted data should fail to load with a generation error."""
    # The field names are irrelevant here; the file itself is malformed
    with pytest.raises(DatasetGenerationError):
        load_hf_dataset_from_jsonl_file(
            MALFORMATTED_DATA, input_field_name="foo", output_field_name="bar"
        )


def test_load_hf_dataset_from_jsonl_file_duplicate_keys():
    """Using the same column as both input and output should be rejected."""
    shared_field = "Tweet text"
    with pytest.raises(ValueError):
        load_hf_dataset_from_jsonl_file(
            TWITTER_COMPLAINTS_DATA,
            input_field_name=shared_field,
            output_field_name=shared_field,
        )


# Tests for custom masking / preprocessing logic
@pytest.mark.parametrize("max_sequence_length", [1, 10, 100, 1000])
def test_get_preprocessed_dataset(max_sequence_length):
    """Tokenized records should be unpadded, label-masked at the start, and truncated."""
    tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0")
    preprocessed_data = get_preprocessed_dataset(
        data_path=TWITTER_COMPLAINTS_DATA,
        tokenizer=tokenizer,
        max_sequence_length=max_sequence_length,
        input_field_name="Tweet text",
        output_field_name="text_label",
    )
    for record in preprocessed_data:
        attention_mask = record["attention_mask"]
        # Padding is deferred to the collator, so every attention position is active
        assert sum(attention_mask) == len(attention_mask)
        # Nonempty source text means the leading label token is masked out
        assert record["labels"][0] == -100
        # Every key in the record must map to a sequence of the same length...
        lengths = {len(record[key]) for key in record.keys()}
        assert len(lengths) == 1
        # ...and that shared length may not exceed the truncation limit; padding
        # up to the max size is handled separately by the collator
        assert lengths.pop() <= max_sequence_length


# Tests for fetching train args
@pytest.mark.parametrize(
    "use_validation_data, collator_type, packing",
    [
        (True, None, True),
        (False, None, True),
        (True, DataCollatorForCompletionOnlyLM, False),
        (False, DataCollatorForCompletionOnlyLM, False),
    ],
)
def test_get_trainer_kwargs_with_response_template_and_text_field(
    use_validation_data, collator_type, packing
):
    """Trainer kwargs built from a text field + response template are well formed."""
    train_path = TWITTER_COMPLAINTS_DATA
    eval_path = train_path if use_validation_data else None
    # Raw columns we expect the loaded twitter dataset to carry
    expected_columns = {"Tweet text", "ID", "Label", "text_label", "output"}
    trainer_kwargs = get_data_trainer_kwargs(
        training_data_path=train_path,
        validation_data_path=eval_path,
        packing=packing,
        response_template="\n### Label:",
        max_sequence_length=100,
        tokenizer=AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0"),
        dataset_text_field="output",
    )
    assert len(trainer_kwargs) == 3
    # Packing drops the collator entirely; otherwise a completion-only collator is built
    if collator_type is None:
        assert trainer_kwargs["data_collator"] is None
    else:
        assert isinstance(trainer_kwargs["data_collator"], collator_type)

    # An eval dataset should exist exactly when a validation path was provided
    if eval_path is None:
        assert trainer_kwargs["eval_dataset"] is None
    else:
        eval_dataset = trainer_kwargs["eval_dataset"]
        assert isinstance(eval_dataset, Dataset)
        assert set(eval_dataset.column_names) == expected_columns

    train_dataset = trainer_kwargs["train_dataset"]
    assert isinstance(train_dataset, Dataset)
    assert set(train_dataset.column_names) == expected_columns


@pytest.mark.parametrize("use_validation_data", [True, False])
def test_get_trainer_kwargs_with_custom_masking(use_validation_data):
    """Pretokenized input/output data should yield seq2seq-collated trainer kwargs."""
    train_path = TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT
    eval_path = train_path if use_validation_data else None
    # Pretokenization leaves only the model-ready columns in the dataset
    expected_columns = {"input_ids", "attention_mask", "labels"}
    trainer_kwargs = get_data_trainer_kwargs(
        training_data_path=train_path,
        validation_data_path=eval_path,
        packing=False,
        response_template=None,
        max_sequence_length=100,
        tokenizer=AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0"),
        dataset_text_field=None,
    )
    assert len(trainer_kwargs) == 4
    # Custom masking always pads dynamically via a seq2seq collator
    assert isinstance(trainer_kwargs["data_collator"], DataCollatorForSeq2Seq)

    # An eval dataset should exist exactly when a validation path was provided
    if eval_path is None:
        assert trainer_kwargs["eval_dataset"] is None
    else:
        eval_dataset = trainer_kwargs["eval_dataset"]
        assert isinstance(eval_dataset, Dataset)
        assert set(eval_dataset.column_names) == expected_columns

    train_dataset = trainer_kwargs["train_dataset"]
    assert isinstance(train_dataset, Dataset)
    assert set(train_dataset.column_names) == expected_columns
    # A formatting func must be present to sidestep TRL's dataset validation
    assert trainer_kwargs["formatting_func"] is not None


# Tests for validating data args
@pytest.mark.parametrize(
    "dataset_text_field, response_template",
    [
        ("input", None),
        (None, "output"),
    ],
)
def test_validate_args(dataset_text_field, response_template):
    """Providing only one of text field / response template should raise."""
    with pytest.raises(ValueError):
        validate_data_args(dataset_text_field, response_template)
Loading

0 comments on commit 4334d6c

Please sign in to comment.