
Add unit tests for various edge cases #97

Merged

Conversation

@alex-jw-brooks (Collaborator) commented Mar 21, 2024

This PR adds some extra unit tests for various edge cases. It also fixes one of the edge cases whose validation is currently unreachable, and changes a few stray sys.exit() calls to raised errors, since I don't think there's a special reason to have them there at the moment.
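For illustration, a minimal sketch of the sys.exit() change (the function and message below are made up, not the PR's actual call sites): raising an exception lets callers and tests handle the failure, whereas sys.exit() terminates the process.

import os

def validate_data_path(path: str) -> None:
    # Instead of sys.exit(...), raise an error so callers and tests
    # (e.g. pytest.raises(FileNotFoundError)) can handle it.
    if not os.path.exists(path):
        raise FileNotFoundError(f"training data path does not exist: {path}")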

NOTE: Some of these tests exercise errors thrown by HF libraries for things we don't explicitly validate ourselves, to ensure that behavior stays stable across version bumps.
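A minimal sketch of that pattern, for illustration only: the model id below is hypothetical, and the asserted error type is an assumption about current transformers behavior, which is exactly the kind of thing such a test is meant to pin down.

import pytest
from transformers import AutoTokenizer

def test_nonexistent_model_name_raises():
    # We rely on transformers to raise when the model id cannot be resolved,
    # and only assert the error type so that a library version bump which
    # changes it is caught by the test suite.
    with pytest.raises(OSError):
        AutoTokenizer.from_pretrained("not-a-real-org/not-a-real-model")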

Related issue number

This is a follow-up on: #74

We decided to split the PR into two parts so that the tests covering the base functionality could land earlier.

How to verify the PR

Was the PR tested

  • I have added >=1 unit test(s) for every new method I have added.
  • I have ensured all unit tests pass.


### Tests for model dtype edge cases
@pytest.mark.skipif(
    not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
@alex-jw-brooks (Collaborator, Author) commented:

The cuda check is needed here because the bf16 check throws if no NVIDIA drivers are available.
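A minimal sketch of the guard being discussed (the helper name is made up, not from the PR): the key point is that torch.cuda.is_bf16_supported() is only evaluated after torch.cuda.is_available() confirms a CUDA device, because the support probe can raise on machines without NVIDIA drivers.

import torch

def bf16_usable() -> bool:
    # Short-circuit evaluation: only probe bf16 support when a CUDA device
    # exists, since the probe itself can fail without NVIDIA drivers present.
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()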

@anhuong (Collaborator) left a comment:

Great test cases, thanks Alex! Left a few questions...

Comment on lines 408 to 418
def test_data_path_does_not_exist():
    """Ensure that we get a FileNotFoundError if the data is missing completely."""
    TRAIN_KWARGS = {
        **BASE_PEFT_KWARGS,
        **{"training_data_path": "/foo/bar/foobar", "output_dir": "foo/bar/baz"},
    }
    model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
        TRAIN_KWARGS
    )
    with pytest.raises(FileNotFoundError):
        sft_trainer.train(model_args, data_args, training_args, tune_config)

@alex-jw-brooks (Collaborator, Author) replied:

nice catch, removed, thanks!

Comment on lines +204 to +214
# TODO: Fix this, currently unreachable due to crashing in batch encoding tokenization
# We should do this validation up front, then do the encoding, then handle the collator
raise ValueError("Response template is None, needs to be set for training")
@anhuong (Collaborator) commented:

are you saying it fails before hitting this ValueError, perhaps on line 158 with

    response_template_ids = tokenizer.encode(
        data_args.response_template, add_special_tokens=False
    )[2:]

in which case, should this validation be moved up?

@alex-jw-brooks (Collaborator, Author) replied:

Yeah, you can't encode a None type with a tokenizer, since tokenizers generally expect an input of type Union[TextInputSequence, Tuple[InputSequence, InputSequence]]. It would be best to do that in a separate PR to keep things atomic, even though it's a simple change, since some of the validation logic is a little delicate.
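A rough sketch of the up-front validation being described, with illustrative names rather than the repo's actual function signature:

def encode_response_template(tokenizer, response_template):
    # Validate before calling the tokenizer: encoding None fails deep inside
    # the tokenizer with an opaque type error, so raise a clear error first.
    if response_template is None:
        raise ValueError("Response template is None, needs to be set for training")
    return tokenizer.encode(response_template, add_special_tokens=False)[2:]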

reason="Only runs if bf16 is unsupported",
)
def test_bf16_still_tunes_if_unsupported():
"""Ensure that even if bf16 is not supported, tuning still works without problems."""
@anhuong (Collaborator) commented:

interesting test case! can you explain why it doesn't fail and tuning still works and why this is the preferred expected behavior?

@alex-jw-brooks (Collaborator, Author) replied:

As far as I understand, on devices where bfloat16 is unsupported, there is usually fallback behavior to a supported data type, which is usually float32 since bfloat16 and float32 have the same exponent size!
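A hedged sketch of that kind of dtype selection (not the project's actual logic): prefer bfloat16 where it is usable, otherwise fall back to float32, which shares bfloat16's 8-bit exponent and therefore the same dynamic range.

import torch

def pick_compute_dtype() -> torch.dtype:
    # Prefer bfloat16 on hardware that supports it; otherwise fall back to
    # float32, which covers the same numeric range (same exponent width).
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float32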

@anhuong (Collaborator) replied:

interesting! appreciate knowing the details

@anhuong previously approved these changes Apr 18, 2024
@anhuong (Collaborator) left a comment:

Other than the linter error for a line being too long, the tests look good! Thanks Alex

tests/test_sft_trainer.py:422:0: C0301: Line too long (105/100) (line-too-long)

Signed-off-by: Alex-Brooks <[email protected]>
@anhuong (Collaborator) left a comment:

Thanks Alex!

@anhuong merged commit 8548a6d into foundation-model-stack:main Apr 24, 2024
5 checks passed
achew010 pushed a commit to achew010/fms-hf-tuning that referenced this pull request May 6, 2024
* Add unit tests for various edge cases

Signed-off-by: Alex-Brooks <[email protected]>

* Fix bf16 check in skipped test

Signed-off-by: Alex-Brooks <[email protected]>

* Remove redundant test

Signed-off-by: Alex-Brooks <[email protected]>

* Fix linting

Signed-off-by: Alex-Brooks <[email protected]>

---------

Signed-off-by: Alex-Brooks <[email protected]>
Signed-off-by: aaron.chew1 <[email protected]>