Skip to content

Commit

Permalink
Add unit tests for various edge cases (#97)
Browse files Browse the repository at this point in the history
* Add unit tests for various edge cases

Signed-off-by: Alex-Brooks <[email protected]>

* Fix bf16 check in skipped test

Signed-off-by: Alex-Brooks <[email protected]>

* Remove redundant test

Signed-off-by: Alex-Brooks <[email protected]>

* Fix linting

Signed-off-by: Alex-Brooks <[email protected]>

---------

Signed-off-by: Alex-Brooks <[email protected]>
  • Loading branch information
alex-jw-brooks authored Apr 24, 2024
1 parent db99b28 commit 8548a6d
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 34 deletions.
2 changes: 2 additions & 0 deletions tests/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@
### Constants used for data
DATA_DIR = os.path.join(os.path.dirname(__file__))
TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json")
EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json")
MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json")
Empty file added tests/data/empty_data.json
Empty file.
1 change: 1 addition & 0 deletions tests/data/malformatted_data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This data is bad! We can't use it to tune.
197 changes: 189 additions & 8 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
import tempfile

# Third Party
from datasets.exceptions import DatasetGenerationError
import pytest
import torch
import transformers

# First Party
from scripts.run_inference import TunedCausalLM
from tests.data import TWITTER_COMPLAINTS_DATA
from tests.data import EMPTY_DATA, MALFORMATTED_DATA, TWITTER_COMPLAINTS_DATA
from tests.helpers import causal_lm_train_kwargs

# Local
Expand Down Expand Up @@ -122,9 +125,10 @@ def test_run_train_fails_training_data_path_not_exist():
def test_run_causallm_pt_and_inference():
"""Check if we can bootstrap and peft tune causallm models"""
with tempfile.TemporaryDirectory() as tempdir:
BASE_PEFT_KWARGS["output_dir"] = tempdir
TRAIN_KWARGS = {**BASE_PEFT_KWARGS, **{"output_dir": tempdir}}

model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
BASE_PEFT_KWARGS
TRAIN_KWARGS
)
sft_trainer.train(model_args, data_args, training_args, tune_config)

Expand All @@ -148,19 +152,20 @@ def test_run_causallm_pt_and_inference():
def test_run_causallm_pt_init_text():
"""Check if we can bootstrap and peft tune causallm models with init text as 'TEXT'"""
with tempfile.TemporaryDirectory() as tempdir:
pt_init_text = copy.deepcopy(BASE_PEFT_KWARGS)
pt_init_text["output_dir"] = tempdir
pt_init_text["prompt_tuning_init"] = "TEXT"
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"output_dir": tempdir, "prompt_tuning_init": "TEXT"},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
pt_init_text
TRAIN_KWARGS
)
sft_trainer.train(model_args, data_args, training_args, tune_config)

# validate peft tuning configs
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "PROMPT_TUNING", pt_init_text)
_validate_adapter_config(adapter_config, "PROMPT_TUNING", TRAIN_KWARGS)


invalid_params_map = [
Expand Down Expand Up @@ -326,3 +331,179 @@ def _validate_adapter_config(adapter_config, peft_type, base_kwargs):
if peft_type == "PROMPT_TUNING"
else True
)


### Tests for a variety of edge cases and potentially problematic cases;
# some of these test directly test validation within external dependencies
# and validate errors that we expect to get from them which might be unintuitive.
# In such cases, it would probably be best for us to handle these things directly
# for better error messages, etc.

### Tests related to tokenizer configuration
def test_tokenizer_has_no_eos_token():
"""Ensure that if the model has no EOS token, it sets the default before formatting."""
# This is a bit roundabout, but patch the tokenizer and export it and the model to a tempdir
# that we can then reload out of for the train call, and clean up afterwards.
tokenizer = transformers.AutoTokenizer.from_pretrained(
BASE_PEFT_KWARGS["model_name_or_path"]
)
model = transformers.AutoModelForCausalLM.from_pretrained(
BASE_PEFT_KWARGS["model_name_or_path"]
)
tokenizer.eos_token = None
with tempfile.TemporaryDirectory() as tempdir:
tokenizer.save_pretrained(tempdir)
model.save_pretrained(tempdir)
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"model_name_or_path": tempdir, "output_dir": tempdir},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
# If we handled this badly, we would probably get something like a
# TypeError: can only concatenate str (not "NoneType") to str error
# when we go to apply the data formatter.
sft_trainer.train(model_args, data_args, training_args, tune_config)
_validate_training(tempdir)


### Tests for Bad dataset specification, i.e., data is valid, but the field we point it at isn't
def test_invalid_dataset_text_field():
"""Ensure that if we specify a dataset_text_field that doesn't exist, we get a KeyError."""
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"dataset_text_field": "not found", "output_dir": "foo/bar/baz"},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
with pytest.raises(KeyError):
sft_trainer.train(model_args, data_args, training_args, tune_config)


### Tests for bad training data (i.e., data_path is an unhappy value or points to an unhappy thing)
def test_malformatted_data():
"""Ensure that malformatted data explodes due to failure to generate the dataset."""
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"training_data_path": MALFORMATTED_DATA, "output_dir": "foo/bar/baz"},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
with pytest.raises(DatasetGenerationError):
sft_trainer.train(model_args, data_args, training_args, tune_config)


def test_empty_data():
"""Ensure that malformatted data explodes due to failure to generate the dataset."""
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"training_data_path": EMPTY_DATA, "output_dir": "foo/bar/baz"},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
with pytest.raises(DatasetGenerationError):
sft_trainer.train(model_args, data_args, training_args, tune_config)


def test_data_path_is_a_directory():
"""Ensure that we get FileNotFoundError if we point the data path at a dir, not a file."""
with tempfile.TemporaryDirectory() as tempdir:
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"training_data_path": tempdir, "output_dir": tempdir},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
# Confusingly, if we pass a directory for our data path, it will throw a
# FileNotFoundError saying "unable to find '<data_path>'", since it can't
# find a matchable file in the path.
with pytest.raises(FileNotFoundError):
sft_trainer.train(model_args, data_args, training_args, tune_config)


### Tests for bad tuning module configurations
def test_run_causallm_lora_with_invalid_modules():
"""Check that we throw a value error if the target modules for lora don't exist."""
with tempfile.TemporaryDirectory() as tempdir:
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"peft_method": "lora", "output_dir": tempdir},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
# Defaults are q_proj / v_proj; this will fail lora as the torch module doesn't have them
tune_config.target_modules = ["foo", "bar"]
# Peft should throw a value error about modules not matching the base module
with pytest.raises(ValueError):
sft_trainer.train(model_args, data_args, training_args, tune_config)


### Direct validation tests based on whether or not packing is enabled
def test_no_packing_needs_dataset_text_field():
"""Ensure we need to set the dataset text field if packing is False"""
with tempfile.TemporaryDirectory() as tempdir:
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"dataset_text_field": None, "output_dir": tempdir},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
with pytest.raises(ValueError):
sft_trainer.train(model_args, data_args, training_args, tune_config)


# TODO: Fix this case
@pytest.mark.skip(reason="currently crashes before validation is done")
def test_no_packing_needs_reponse_template():
"""Ensure we need to set the response template if packing is False"""
with tempfile.TemporaryDirectory() as tempdir:
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"response_template": None, "output_dir": tempdir},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
with pytest.raises(ValueError):
sft_trainer.train(model_args, data_args, training_args, tune_config)


### Tests for model dtype edge cases
@pytest.mark.skipif(
not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
reason="Only runs if bf16 is unsupported",
)
def test_bf16_still_tunes_if_unsupported():
"""Ensure that even if bf16 is not supported, tuning still works without problems."""
assert not torch.cuda.is_bf16_supported()
with tempfile.TemporaryDirectory() as tempdir:
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"torch_dtype": "bfloat16", "output_dir": tempdir},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
sft_trainer.train(model_args, data_args, training_args, tune_config)
_validate_training(tempdir)


def test_bad_torch_dtype():
"""Ensure that specifying an invalid torch dtype yields a ValueError."""
with tempfile.TemporaryDirectory() as tempdir:
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"torch_dtype": "not a type", "output_dir": tempdir},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
with pytest.raises(ValueError):
sft_trainer.train(model_args, data_args, training_args, tune_config)
46 changes: 20 additions & 26 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import Optional, Union
import json
import os
import sys

# Third Party
from peft.utils.other import fsdp_auto_wrap_policy
Expand Down Expand Up @@ -202,6 +201,26 @@ def train(
model=model,
)

# Configure the collator and validate args related to packing prior to formatting the dataset
if train_args.packing:
logger.info("Packing is set to True")
data_collator = None
packing = True
else:
logger.info("Packing is set to False")
if data_args.response_template is None:
# TODO: Fix this, currently unreachable due to crashing in batch encoding tokenization
# We should do this validation up front, then do the encoding, then handle the collator
raise ValueError("Response template is None, needs to be set for training")
if data_args.dataset_text_field is None:
raise ValueError("Dataset_text_field is None, needs to be set for training")
data_collator = DataCollatorForCompletionOnlyLM(
response_template_ids,
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)
packing = False

# load the data by parsing JSON
data_files = {"train": data_args.training_data_path}
if data_args.validation_data_path:
Expand Down Expand Up @@ -235,31 +254,6 @@ def train(
)
callbacks.append(tc_callback)

if train_args.packing:
logger.info("Packing is set to True")
data_collator = None
packing = True
else:
logger.info("Packing is set to False")
if data_args.response_template is None:
logger.error(
"Error, response template is None, needs to be set for training"
)
sys.exit(-1)

if data_args.dataset_text_field is None:
logger.error(
"Error, dataset_text_field is None, needs to be set for training"
)
sys.exit(-1)

data_collator = DataCollatorForCompletionOnlyLM(
response_template_ids,
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)
packing = False

trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
Expand Down

0 comments on commit 8548a6d

Please sign in to comment.