Add unit tests for various edge cases #97
@@ -0,0 +1 @@ | ||
This data is bad! We can't use it to tune. |
@@ -22,11 +22,14 @@ | |
import tempfile | ||
|
||
# Third Party | ||
from datasets.exceptions import DatasetGenerationError | ||
import pytest | ||
import torch | ||
import transformers | ||
|
||
# First Party | ||
from scripts.run_inference import TunedCausalLM | ||
from tests.data import TWITTER_COMPLAINTS_DATA | ||
from tests.data import EMPTY_DATA, MALFORMATTED_DATA, TWITTER_COMPLAINTS_DATA | ||
from tests.helpers import causal_lm_train_kwargs | ||
|
||
# Local | ||
|
@@ -122,9 +125,10 @@ def test_run_train_fails_training_data_path_not_exist(): | |
def test_run_causallm_pt_and_inference(): | ||
"""Check if we can bootstrap and peft tune causallm models""" | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
BASE_PEFT_KWARGS["output_dir"] = tempdir | ||
TRAIN_KWARGS = {**BASE_PEFT_KWARGS, **{"output_dir": tempdir}} | ||
|
||
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( | ||
BASE_PEFT_KWARGS | ||
TRAIN_KWARGS | ||
) | ||
sft_trainer.train(model_args, data_args, training_args, tune_config) | ||
|
||
|
@@ -148,19 +152,20 @@ def test_run_causallm_pt_and_inference(): | |
def test_run_causallm_pt_init_text(): | ||
"""Check if we can bootstrap and peft tune causallm models with init text as 'TEXT'""" | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
pt_init_text = copy.deepcopy(BASE_PEFT_KWARGS) | ||
pt_init_text["output_dir"] = tempdir | ||
pt_init_text["prompt_tuning_init"] = "TEXT" | ||
TRAIN_KWARGS = { | ||
**BASE_PEFT_KWARGS, | ||
**{"output_dir": tempdir, "prompt_tuning_init": "TEXT"}, | ||
} | ||
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( | ||
pt_init_text | ||
TRAIN_KWARGS | ||
) | ||
sft_trainer.train(model_args, data_args, training_args, tune_config) | ||
|
||
# validate peft tuning configs | ||
_validate_training(tempdir) | ||
checkpoint_path = _get_checkpoint_path(tempdir) | ||
adapter_config = _get_adapter_config(checkpoint_path) | ||
_validate_adapter_config(adapter_config, "PROMPT_TUNING", pt_init_text) | ||
_validate_adapter_config(adapter_config, "PROMPT_TUNING", TRAIN_KWARGS) | ||
|
||
|
||
invalid_params_map = [ | ||
|
@@ -326,3 +331,179 @@ def _validate_adapter_config(adapter_config, peft_type, base_kwargs): | |
if peft_type == "PROMPT_TUNING" | ||
else True | ||
) | ||
|
||
|
||
### Tests for a variety of edge cases and potentially problematic cases; | ||
# some of these tests directly test validation within external dependencies | |
# and validate errors that we expect to get from them which might be unintuitive. | ||
# In such cases, it would probably be best for us to handle these things directly | ||
# for better error messages, etc. | ||
|
||
### Tests related to tokenizer configuration | ||
def test_tokenizer_has_no_eos_token(): | ||
"""Ensure that if the model has no EOS token, it sets the default before formatting.""" | ||
# This is a bit roundabout, but patch the tokenizer and export it and the model to a tempdir | ||
# that we can then reload out of for the train call, and clean up afterwards. | ||
tokenizer = transformers.AutoTokenizer.from_pretrained( | ||
BASE_PEFT_KWARGS["model_name_or_path"] | ||
) | ||
model = transformers.AutoModelForCausalLM.from_pretrained( | ||
BASE_PEFT_KWARGS["model_name_or_path"] | ||
) | ||
tokenizer.eos_token = None | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
tokenizer.save_pretrained(tempdir) | ||
model.save_pretrained(tempdir) | ||
TRAIN_KWARGS = { | ||
**BASE_PEFT_KWARGS, | ||
**{"model_name_or_path": tempdir, "output_dir": tempdir}, | ||
} | ||
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( | ||
TRAIN_KWARGS | ||
) | ||
# If we handled this badly, we would probably get something like a | ||
# TypeError: can only concatenate str (not "NoneType") to str error | ||
# when we go to apply the data formatter. | ||
sft_trainer.train(model_args, data_args, training_args, tune_config) | ||
_validate_training(tempdir) | ||
|
||
|
||
### Tests for Bad dataset specification, i.e., data is valid, but the field we point it at isn't | ||
def test_invalid_dataset_text_field(): | ||
"""Ensure that if we specify a dataset_text_field that doesn't exist, we get a KeyError.""" | ||
TRAIN_KWARGS = { | ||
**BASE_PEFT_KWARGS, | ||
**{"dataset_text_field": "not found", "output_dir": "foo/bar/baz"}, | ||
} | ||
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( | ||
TRAIN_KWARGS | ||
) | ||
with pytest.raises(KeyError): | ||
sft_trainer.train(model_args, data_args, training_args, tune_config) | ||
|
||
|
||
### Tests for bad training data (i.e., data_path is an unhappy value or points to an unhappy thing) | ||
def test_malformatted_data(): | ||
"""Ensure that malformatted data explodes due to failure to generate the dataset.""" | ||
TRAIN_KWARGS = { | ||
**BASE_PEFT_KWARGS, | ||
**{"training_data_path": MALFORMATTED_DATA, "output_dir": "foo/bar/baz"}, | ||
} | ||
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( | ||
TRAIN_KWARGS | ||
) | ||
with pytest.raises(DatasetGenerationError): | ||
sft_trainer.train(model_args, data_args, training_args, tune_config) | ||
|
||
|
||
def test_empty_data(): | ||
"""Ensure that malformatted data explodes due to failure to generate the dataset.""" | ||
TRAIN_KWARGS = { | ||
**BASE_PEFT_KWARGS, | ||
**{"training_data_path": EMPTY_DATA, "output_dir": "foo/bar/baz"}, | ||
} | ||
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( | ||
TRAIN_KWARGS | ||
) | ||
with pytest.raises(DatasetGenerationError): | ||
sft_trainer.train(model_args, data_args, training_args, tune_config) | ||
|
||
|
||
def test_data_path_is_a_directory(): | ||
"""Ensure that we get FileNotFoundError if we point the data path at a dir, not a file.""" | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
TRAIN_KWARGS = { | ||
**BASE_PEFT_KWARGS, | ||
**{"training_data_path": tempdir, "output_dir": tempdir}, | ||
} | ||
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( | ||
TRAIN_KWARGS | ||
) | ||
# Confusingly, if we pass a directory for our data path, it will throw a | ||
# FileNotFoundError saying "unable to find '<data_path>'", since it can't | ||
# find a matchable file in the path. | ||
with pytest.raises(FileNotFoundError): | ||
sft_trainer.train(model_args, data_args, training_args, tune_config) | ||
|
||
|
||
### Tests for bad tuning module configurations | ||
def test_run_causallm_lora_with_invalid_modules(): | ||
"""Check that we throw a value error if the target modules for lora don't exist.""" | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
TRAIN_KWARGS = { | ||
**BASE_PEFT_KWARGS, | ||
**{"peft_method": "lora", "output_dir": tempdir}, | ||
} | ||
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( | ||
TRAIN_KWARGS | ||
) | ||
# Defaults are q_proj / v_proj; this will fail lora as the torch module doesn't have them | ||
tune_config.target_modules = ["foo", "bar"] | ||
# Peft should throw a value error about modules not matching the base module | ||
with pytest.raises(ValueError): | ||
sft_trainer.train(model_args, data_args, training_args, tune_config) | ||
|
||
|
||
### Direct validation tests based on whether or not packing is enabled | ||
def test_no_packing_needs_dataset_text_field(): | ||
"""Ensure we need to set the dataset text field if packing is False""" | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
TRAIN_KWARGS = { | ||
**BASE_PEFT_KWARGS, | ||
**{"dataset_text_field": None, "output_dir": tempdir}, | ||
} | ||
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( | ||
TRAIN_KWARGS | ||
) | ||
with pytest.raises(ValueError): | ||
sft_trainer.train(model_args, data_args, training_args, tune_config) | ||
|
||
|
||
# TODO: Fix this case | ||
@pytest.mark.skip(reason="currently crashes before validation is done") | ||
def test_no_packing_needs_response_template(): | ||
"""Ensure we need to set the response template if packing is False""" | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
TRAIN_KWARGS = { | ||
**BASE_PEFT_KWARGS, | ||
**{"response_template": None, "output_dir": tempdir}, | ||
} | ||
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( | ||
TRAIN_KWARGS | ||
) | ||
with pytest.raises(ValueError): | ||
sft_trainer.train(model_args, data_args, training_args, tune_config) | ||
|
||
|
||
### Tests for model dtype edge cases | ||
@pytest.mark.skipif( | ||
not torch.cuda.is_available() or torch.cuda.is_bf16_supported(), | ||
reason="Only runs if bf16 is unsupported", | ||
) | ||
def test_bf16_still_tunes_if_unsupported(): | ||
"""Ensure that even if bf16 is not supported, tuning still works without problems.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. interesting test case! can you explain why it doesn't fail and tuning still works and why this is the preferred expected behavior? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As far as I understand, in devices where bfloat16 is unsupported, there is usually fallback behavior to a supported data type, which is usually float32 since bfloat16 and float32 have the same exponent size! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. interesting! appreciate knowing the details |
||
assert not torch.cuda.is_bf16_supported() | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
TRAIN_KWARGS = { | ||
**BASE_PEFT_KWARGS, | ||
**{"torch_dtype": "bfloat16", "output_dir": tempdir}, | ||
} | ||
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( | ||
TRAIN_KWARGS | ||
) | ||
sft_trainer.train(model_args, data_args, training_args, tune_config) | ||
_validate_training(tempdir) | ||
|
||
|
||
def test_bad_torch_dtype(): | ||
"""Ensure that specifying an invalid torch dtype yields a ValueError.""" | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
TRAIN_KWARGS = { | ||
**BASE_PEFT_KWARGS, | ||
**{"torch_dtype": "not a type", "output_dir": tempdir}, | ||
} | ||
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( | ||
TRAIN_KWARGS | ||
) | ||
with pytest.raises(ValueError): | ||
sft_trainer.train(model_args, data_args, training_args, tune_config) |
@@ -17,7 +17,6 @@ | |
from typing import Optional, Union | ||
import json | ||
import os | ||
import sys | ||
|
||
# Third Party | ||
from peft.utils.other import fsdp_auto_wrap_policy | ||
|
@@ -202,6 +201,26 @@ def train( | |
model=model, | ||
) | ||
|
||
# Configure the collator and validate args related to packing prior to formatting the dataset | ||
if train_args.packing: | ||
logger.info("Packing is set to True") | ||
data_collator = None | ||
packing = True | ||
else: | ||
logger.info("Packing is set to False") | ||
if data_args.response_template is None: | ||
# TODO: Fix this, currently unreachable due to crashing in batch encoding tokenization | ||
# We should do this validation up front, then do the encoding, then handle the collator | ||
raise ValueError("Response template is None, needs to be set for training") | ||
Comment on lines +212 to +214

Are you saying it fails before hitting this ValueError, perhaps on line 158 with response_template_ids = tokenizer.encode(data_args.response_template, add_special_tokens=False)[2:]? In which case, should this validation be moved up?

Yeah, you can't encode a None type with a tokenizer, since tokenizers generally expect an input of type str.
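To make that thread concrete, here is a rough sketch, not the PR's actual code, of doing the None checks before the response template is encoded; the helper name and argument order are illustrative, and the DataCollatorForCompletionOnlyLM import is assumed to come from trl, matching the collator used in the diff above.

```python
from trl import DataCollatorForCompletionOnlyLM


def configure_collator(packing, response_template, dataset_text_field, tokenizer, ignore_index=-100):
    """Validate packing-related args before encoding the response template."""
    if packing:
        # Packing concatenates examples itself; no completion-only collator is needed.
        return None, True
    # Validate up front: tokenizer.encode(None) is exactly the crash discussed above.
    if response_template is None:
        raise ValueError("Response template is None, needs to be set for training")
    if dataset_text_field is None:
        raise ValueError("dataset_text_field is None, needs to be set for training")
    # Only now is it safe to encode, since the template is guaranteed to be a string.
    response_template_ids = tokenizer.encode(response_template, add_special_tokens=False)[2:]
    collator = DataCollatorForCompletionOnlyLM(
        response_template_ids, tokenizer=tokenizer, ignore_index=ignore_index
    )
    return collator, False
```

Ordering the checks this way surfaces a ValueError that names the missing configuration value instead of an opaque traceback from inside the tokenizer.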
||
if data_args.dataset_text_field is None: | ||
raise ValueError("Dataset_text_field is None, needs to be set for training") | ||
data_collator = DataCollatorForCompletionOnlyLM( | ||
response_template_ids, | ||
tokenizer=tokenizer, | ||
ignore_index=configs.IGNORE_INDEX, | ||
) | ||
packing = False | ||
|
||
# load the data by parsing JSON | ||
data_files = {"train": data_args.training_data_path} | ||
if data_args.validation_data_path: | ||
|
@@ -235,31 +254,6 @@ def train( | |
) | ||
callbacks.append(tc_callback) | ||
|
||
if train_args.packing: | ||
logger.info("Packing is set to True") | ||
data_collator = None | ||
packing = True | ||
else: | ||
logger.info("Packing is set to False") | ||
if data_args.response_template is None: | ||
logger.error( | ||
"Error, response template is None, needs to be set for training" | ||
) | ||
sys.exit(-1) | ||
|
||
if data_args.dataset_text_field is None: | ||
logger.error( | ||
"Error, dataset_text_field is None, needs to be set for training" | ||
) | ||
sys.exit(-1) | ||
|
||
data_collator = DataCollatorForCompletionOnlyLM( | ||
response_template_ids, | ||
tokenizer=tokenizer, | ||
ignore_index=configs.IGNORE_INDEX, | ||
) | ||
packing = False | ||
|
||
trainer = SFTTrainer( | ||
model=model, | ||
tokenizer=tokenizer, | ||
|
The cuda check is needed here because the bf16 check throws if no Nvidia drivers are available.
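A small, self-contained illustration of that point (plain PyTorch; the helper name is made up): probe is_bf16_supported() only after confirming a CUDA device is present, which is why the skip condition checks availability first.

```python
import torch


def can_run_bf16_fallback_test() -> bool:
    """True only when a CUDA device exists but does not support bfloat16."""
    if not torch.cuda.is_available():
        # torch.cuda.is_bf16_supported() can raise when no NVIDIA driver is present,
        # so never call it unless a CUDA device was detected.
        return False
    return not torch.cuda.is_bf16_supported()
```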