
Add unit tests for various edge cases #97

Merged
2 changes: 2 additions & 0 deletions tests/data/__init__.py
@@ -20,3 +20,5 @@
### Constants used for data
DATA_DIR = os.path.join(os.path.dirname(__file__))
TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json")
EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json")
MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json")
Empty file added tests/data/empty_data.json
1 change: 1 addition & 0 deletions tests/data/malformatted_data.json
@@ -0,0 +1 @@
This data is bad! We can't use it to tune.
197 changes: 189 additions & 8 deletions tests/test_sft_trainer.py
@@ -22,11 +22,14 @@
import tempfile

# Third Party
from datasets.exceptions import DatasetGenerationError
import pytest
import torch
import transformers

# First Party
from scripts.run_inference import TunedCausalLM
from tests.data import TWITTER_COMPLAINTS_DATA
from tests.data import EMPTY_DATA, MALFORMATTED_DATA, TWITTER_COMPLAINTS_DATA
from tests.helpers import causal_lm_train_kwargs

# Local
@@ -122,9 +125,10 @@ def test_run_train_fails_training_data_path_not_exist():
def test_run_causallm_pt_and_inference():
"""Check if we can bootstrap and peft tune causallm models"""
with tempfile.TemporaryDirectory() as tempdir:
BASE_PEFT_KWARGS["output_dir"] = tempdir
TRAIN_KWARGS = {**BASE_PEFT_KWARGS, **{"output_dir": tempdir}}

model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
BASE_PEFT_KWARGS
TRAIN_KWARGS
)
sft_trainer.train(model_args, data_args, training_args, tune_config)

@@ -148,19 +152,20 @@ def test_run_causallm_pt_and_inference():
def test_run_causallm_pt_init_text():
"""Check if we can bootstrap and peft tune causallm models with init text as 'TEXT'"""
with tempfile.TemporaryDirectory() as tempdir:
pt_init_text = copy.deepcopy(BASE_PEFT_KWARGS)
pt_init_text["output_dir"] = tempdir
pt_init_text["prompt_tuning_init"] = "TEXT"
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"output_dir": tempdir, "prompt_tuning_init": "TEXT"},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
pt_init_text
TRAIN_KWARGS
)
sft_trainer.train(model_args, data_args, training_args, tune_config)

# validate peft tuning configs
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "PROMPT_TUNING", pt_init_text)
_validate_adapter_config(adapter_config, "PROMPT_TUNING", TRAIN_KWARGS)


invalid_params_map = [
@@ -326,3 +331,179 @@ def _validate_adapter_config(adapter_config, peft_type, base_kwargs):
if peft_type == "PROMPT_TUNING"
else True
)


### Tests for a variety of edge cases and potentially problematic cases;
# some of these directly test validation within external dependencies
# and validate errors that we expect to get from them, which might be unintuitive.
# In such cases, it would probably be best for us to handle these things directly
# for better error messages, etc.

### Tests related to tokenizer configuration
def test_tokenizer_has_no_eos_token():
"""Ensure that if the model has no EOS token, it sets the default before formatting."""
# This is a bit roundabout, but patch the tokenizer and export it and the model to a tempdir
# that we can then reload out of for the train call, and clean up afterwards.
tokenizer = transformers.AutoTokenizer.from_pretrained(
BASE_PEFT_KWARGS["model_name_or_path"]
)
model = transformers.AutoModelForCausalLM.from_pretrained(
BASE_PEFT_KWARGS["model_name_or_path"]
)
tokenizer.eos_token = None
with tempfile.TemporaryDirectory() as tempdir:
tokenizer.save_pretrained(tempdir)
model.save_pretrained(tempdir)
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"model_name_or_path": tempdir, "output_dir": tempdir},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
# If we handled this badly, we would probably get something like a
# TypeError: can only concatenate str (not "NoneType") to str error
# when we go to apply the data formatter.
sft_trainer.train(model_args, data_args, training_args, tune_config)
_validate_training(tempdir)


### Tests for Bad dataset specification, i.e., data is valid, but the field we point it at isn't
def test_invalid_dataset_text_field():
"""Ensure that if we specify a dataset_text_field that doesn't exist, we get a KeyError."""
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"dataset_text_field": "not found", "output_dir": "foo/bar/baz"},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
with pytest.raises(KeyError):
sft_trainer.train(model_args, data_args, training_args, tune_config)


### Tests for bad training data (i.e., data_path is an unhappy value or points to an unhappy thing)
def test_malformatted_data():
"""Ensure that malformatted data explodes due to failure to generate the dataset."""
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"training_data_path": MALFORMATTED_DATA, "output_dir": "foo/bar/baz"},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
with pytest.raises(DatasetGenerationError):
sft_trainer.train(model_args, data_args, training_args, tune_config)


def test_empty_data():
"""Ensure that malformatted data explodes due to failure to generate the dataset."""
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"training_data_path": EMPTY_DATA, "output_dir": "foo/bar/baz"},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
with pytest.raises(DatasetGenerationError):
sft_trainer.train(model_args, data_args, training_args, tune_config)


def test_data_path_is_a_directory():
"""Ensure that we get FileNotFoundError if we point the data path at a dir, not a file."""
with tempfile.TemporaryDirectory() as tempdir:
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"training_data_path": tempdir, "output_dir": tempdir},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
# Confusingly, if we pass a directory for our data path, it will throw a
# FileNotFoundError saying "unable to find '<data_path>'", since it can't
# find a matchable file in the path.
with pytest.raises(FileNotFoundError):
sft_trainer.train(model_args, data_args, training_args, tune_config)


### Tests for bad tuning module configurations
def test_run_causallm_lora_with_invalid_modules():
"""Check that we throw a value error if the target modules for lora don't exist."""
with tempfile.TemporaryDirectory() as tempdir:
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"peft_method": "lora", "output_dir": tempdir},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
# Defaults are q_proj / v_proj; this will fail lora as the torch module doesn't have them
tune_config.target_modules = ["foo", "bar"]
# Peft should throw a value error about modules not matching the base module
with pytest.raises(ValueError):
sft_trainer.train(model_args, data_args, training_args, tune_config)


### Direct validation tests based on whether or not packing is enabled
def test_no_packing_needs_dataset_text_field():
"""Ensure we need to set the dataset text field if packing is False"""
with tempfile.TemporaryDirectory() as tempdir:
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"dataset_text_field": None, "output_dir": tempdir},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
with pytest.raises(ValueError):
sft_trainer.train(model_args, data_args, training_args, tune_config)


# TODO: Fix this case
@pytest.mark.skip(reason="currently crashes before validation is done")
def test_no_packing_needs_reponse_template():
"""Ensure we need to set the response template if packing is False"""
with tempfile.TemporaryDirectory() as tempdir:
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"response_template": None, "output_dir": tempdir},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
with pytest.raises(ValueError):
sft_trainer.train(model_args, data_args, training_args, tune_config)


### Tests for model dtype edge cases
@pytest.mark.skipif(
not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
Collaborator Author: The cuda check is needed here because the bf16 check throws if no Nvidia drivers are available.

reason="Only runs if bf16 is unsupported",
)
def test_bf16_still_tunes_if_unsupported():
"""Ensure that even if bf16 is not supported, tuning still works without problems."""
Collaborator: Interesting test case! Can you explain why it doesn't fail and tuning still works, and why this is the preferred expected behavior?

Collaborator Author: As far as I understand, on devices where bfloat16 is unsupported there is usually fallback behavior to a supported data type, which is usually float32, since bfloat16 and float32 have the same exponent size!

Collaborator: Interesting! Appreciate knowing the details.
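A minimal sketch of the fallback behavior described above (illustrative only, not code from this PR; the helper name resolve_torch_dtype is made up). It also reflects the earlier inline comment on the skipif marker: CUDA availability is checked before bf16 support, and float32 is a natural fallback since it shares bfloat16's 8-bit exponent, so only precision, not range, changes.

import torch

def resolve_torch_dtype(requested: str) -> torch.dtype:
    """Hypothetical helper, not the repository's API."""
    if requested == "bfloat16":
        # is_available() is checked first because is_bf16_supported() can throw
        # on hosts with no NVIDIA driver, per the comment on the skipif marker.
        bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        return torch.bfloat16 if bf16_ok else torch.float32
    return getattr(torch, requested)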

assert not torch.cuda.is_bf16_supported()
with tempfile.TemporaryDirectory() as tempdir:
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"torch_dtype": "bfloat16", "output_dir": tempdir},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
sft_trainer.train(model_args, data_args, training_args, tune_config)
_validate_training(tempdir)


def test_bad_torch_dtype():
"""Ensure that specifying an invalid torch dtype yields a ValueError."""
with tempfile.TemporaryDirectory() as tempdir:
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"torch_dtype": "not a type", "output_dir": tempdir},
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
with pytest.raises(ValueError):
sft_trainer.train(model_args, data_args, training_args, tune_config)
46 changes: 20 additions & 26 deletions tuning/sft_trainer.py
@@ -17,7 +17,6 @@
from typing import Optional, Union
import json
import os
import sys

# Third Party
from peft.utils.other import fsdp_auto_wrap_policy
@@ -202,6 +201,26 @@ def train(
model=model,
)

# Configure the collator and validate args related to packing prior to formatting the dataset
if train_args.packing:
logger.info("Packing is set to True")
data_collator = None
packing = True
else:
logger.info("Packing is set to False")
if data_args.response_template is None:
# TODO: Fix this, currently unreachable due to crashing in batch encoding tokenization
# We should do this validation up front, then do the encoding, then handle the collator
raise ValueError("Response template is None, needs to be set for training")
Comment on lines +212 to +214

Collaborator: Are you saying it fails before hitting this ValueError, perhaps on line 158 with

response_template_ids = tokenizer.encode(
        data_args.response_template, add_special_tokens=False
    )[2:]

in which case should this validation be moved up?

Collaborator Author: Yeah, you can't encode a None with a tokenizer, since tokenizers generally expect an input of type Union[TextInputSequence, Tuple[InputSequence, InputSequence]]. It would be best to do that in a separate PR to keep things atomic; even though it's a simple change, some of the validation logic is a little bit delicate.
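To illustrate the reply, a minimal sketch (not from this PR; the helper name is hypothetical) of doing the validation before the encode call, so a missing template raises the intended ValueError instead of failing inside the tokenizer:

def encode_response_template(tokenizer, response_template):
    # Validate first, then encode: tokenizer.encode(None) fails because
    # tokenizers expect a string (or pair of strings) as input.
    if response_template is None:
        raise ValueError("Response template is None, needs to be set for training")
    # The [2:] slice is copied from the snippet quoted in the thread above.
    return tokenizer.encode(response_template, add_special_tokens=False)[2:]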

if data_args.dataset_text_field is None:
raise ValueError("Dataset_text_field is None, needs to be set for training")
data_collator = DataCollatorForCompletionOnlyLM(
response_template_ids,
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)
packing = False

# load the data by parsing JSON
data_files = {"train": data_args.training_data_path}
if data_args.validation_data_path:
@@ -235,31 +254,6 @@ def train(
)
callbacks.append(tc_callback)

if train_args.packing:
logger.info("Packing is set to True")
data_collator = None
packing = True
else:
logger.info("Packing is set to False")
if data_args.response_template is None:
logger.error(
"Error, response template is None, needs to be set for training"
)
sys.exit(-1)

if data_args.dataset_text_field is None:
logger.error(
"Error, dataset_text_field is None, needs to be set for training"
)
sys.exit(-1)

data_collator = DataCollatorForCompletionOnlyLM(
response_template_ids,
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)
packing = False

trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,