Add unit tests for various edge cases #97
Conversation
### Tests for model dtype edge cases
@pytest.mark.skipif(
    not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
The cuda check is needed here because the bf16 check throws if no NVIDIA drivers are available.
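As a small standalone sketch of that point (illustration only, not the exact repository code), the `and` short-circuits so the bf16 check is never reached on a machine without NVIDIA drivers:

import torch

# Illustration only: torch.cuda.is_available() is checked first, so
# torch.cuda.is_bf16_supported() is never called on a machine without NVIDIA
# drivers, where it can throw instead of returning False.
def bf16_usable() -> bool:
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()

print(bf16_usable())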
Great test cases, thanks Alex! Left a few questions...
tests/test_sft_trainer.py
Outdated
def test_data_path_does_not_exist():
    """Ensure that we get a FileNotFoundError if the data is missing completely."""
    TRAIN_KWARGS = {
        **BASE_PEFT_KWARGS,
        **{"training_data_path": "/foo/bar/foobar", "output_dir": "foo/bar/baz"},
    }
    model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
        TRAIN_KWARGS
    )
    with pytest.raises(FileNotFoundError):
        sft_trainer.train(model_args, data_args, training_args, tune_config)
this test already exists -- https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tests/test_sft_trainer.py
nice catch, removed, thanks!
# TODO: Fix this, currently unreachable due to crashing in batch encoding tokenization
# We should do this validation up front, then do the encoding, then handle the collator
raise ValueError("Response template is None, needs to be set for training")
Are you saying it fails before hitting this ValueError, perhaps on line 158 with
response_template_ids = tokenizer.encode(
    data_args.response_template, add_special_tokens=False
)[2:]
in which case should this validation be moved up?
Yeah, you can't encode a None type with a tokenizer, since tokenizers generally expect an input of type Union[TextInputSequence, Tuple[InputSequence, InputSequence]]. It would be best to do that in a separate PR to keep things atomic, even though it's a simple change, since some of the validation logic is a little delicate.
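For reference, a tiny standalone sketch (hypothetical, not code from this repo; any small model works) showing that the tokenizer rejects None before the explicit validation ever runs:

from transformers import AutoTokenizer

# Standalone illustration: passing None to encode() fails inside the tokenizer
# itself, which is why the explicit ValueError above is currently unreachable.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
try:
    tokenizer.encode(None, add_special_tokens=False)
except (TypeError, ValueError) as err:
    print(f"Tokenizer rejected None before any explicit validation ran: {err}")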
reason="Only runs if bf16 is unsupported", | ||
) | ||
def test_bf16_still_tunes_if_unsupported(): | ||
"""Ensure that even if bf16 is not supported, tuning still works without problems.""" |
Interesting test case! Can you explain why it doesn't fail, why tuning still works, and why this is the preferred expected behavior?
As far as I understand, on devices where bfloat16 is unsupported there is usually fallback behavior to a supported data type, typically float32, since bfloat16 and float32 have the same exponent size!
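A minimal sketch of that idea (an assumption about how one might pick the dtype, not the trainer's actual logic):

import torch

# Sketch only: prefer bf16 when the hardware reports support, otherwise fall
# back to float32, which shares bfloat16's 8-bit exponent and dynamic range.
def pick_training_dtype() -> torch.dtype:
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float32

print(pick_training_dtype())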
interesting! appreciate knowing the details
efa446f to 5b89073
Other than the linter error for a line being too long, the tests look good! Thanks Alex
tests/test_sft_trainer.py:422:0: C0301: Line too long (105/100) (line-too-long)
Thanks Alex!
* Add unit tests for various edge cases
  Signed-off-by: Alex-Brooks <[email protected]>
* Fix bf16 check in skipped test
  Signed-off-by: Alex-Brooks <[email protected]>
* Remove redundant test
  Signed-off-by: Alex-Brooks <[email protected]>
* Fix linting
  Signed-off-by: Alex-Brooks <[email protected]>
---------
Signed-off-by: Alex-Brooks <[email protected]>
Signed-off-by: aaron.chew1 <[email protected]>
This PR adds some extra unit tests for various edge cases. It also fixes (one of) the edge cases whose validation is currently unreachable, and changes a few stray sys.exit() calls to just raise errors, since I don't think there's a special reason to have those there at the moment.
NOTE: Some of these tests are testing errors that are thrown by HF libraries for things we don't explicitly validate ourselves, to ensure the behavior is stable with version bumps.
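As a hypothetical illustration of the sys.exit() change (the helper name below is made up, not the actual fms-hf-tuning code), raising an exception instead of exiting keeps the failure catchable by callers and assertable with pytest.raises:

import pytest

def validate_response_template(response_template):
    # Raising instead of calling sys.exit() keeps the failure catchable by
    # callers and by tests.
    if response_template is None:
        raise ValueError("Response template is None, needs to be set for training")

def test_missing_response_template_raises():
    with pytest.raises(ValueError):
        validate_response_template(None)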
Related issue number
This is a follow-up on: #74
We made the decision to split the PR into two parts, to get the tests covering the base functionality in earlier.
How to verify the PR
Was the PR tested