TRL SFTTrainer Examples #2211
Conversation
Thanks Sara - this looks really nice! Are there any other features we should flex? I am thinking we might want to look at:
Sure, I'll test both of these scenarios, but if it ends up being more than tweaking to get FSDP working I'm going to leave that for another ticket :)

Edit: both worked with some minor tweaks!
Looks great overall - we should probably move the examples out of src and add a brief README to go along with them. Debatable whether or not we'd want SFTTrainer out of src as well.
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=0.6,
    logging_steps=50,
    gradient_checkpointing=True,
)
Is it important at all that the TrainingArguments comes from SparseML?
A few things won't work if the native transformers TrainingArguments is used: no support for recipe overrides, no compressed save, and no multistage training runs. The mix-in uses these params, so if we wanted to support the non-SparseML TrainingArguments we would have to check for them each time we reference them. I don't think it's worth the extra lines personally.
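To make that distinction concrete, here is a minimal sketch of how the SparseML-specific arguments would be used. The import path and the recipe field are assumptions for illustration only; they are not confirmed by this PR:

# Hypothetical illustration: SparseML's TrainingArguments subclasses the
# transformers class and adds recipe-related fields that the mix-in reads.
from sparseml.transformers import TrainingArguments  # assumed import path

training_args = TrainingArguments(
    output_dir="./output",
    num_train_epochs=0.6,
    logging_steps=50,
    gradient_checkpointing=True,
    recipe="recipe.yaml",  # assumed field: recipe overrides, absent from the native class
)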
- Refactored Trainer to be a barebones class definition; everything is handled by SessionManagerMixIn. Removed a bunch of old code for loading recipes, as this is handled by SparseAutoModelForCausalLM.
- Added an SFTTrainer class which adds our mix-in to trl's SFTTrainer. The only added code here is support for passing a tokenized dataset in to SFTTrainer.
- Added examples of using SFTTrainer for sparse finetuning, both with our dataset preprocessing and with TRL's dataset preprocessing. A sketch of the mix-in composition follows this list.

Asana ticket: https://app.asana.com/0/1201735099598270/1206486351032763/f
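As a rough illustration of the composition described above (trl's SFTTrainer is real, but the SessionManagerMixIn import path and the class body here are assumptions, not the PR's actual code):

# Sketch only: the SessionManagerMixIn import path is an assumption.
from sparseml.transformers.finetune.session_mixin import SessionManagerMixIn
from trl import SFTTrainer as TRLSFTTrainer

class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer):
    # The mix-in hooks SparseML's session lifecycle (recipes, compressed
    # saves, multistage runs) into the standard TRL training loop; the only
    # logic added in this PR is support for pre-tokenized datasets.
    ...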
Testing
See examples in integrations/huggingface-transformers/tutorials/text-generation/trl_mixin
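For orientation, a hedged sketch of the shape those examples take. The SFTTrainer call arguments mirror trl's API, but the exact variable names here are assumptions based on the discussion above, not the actual example code:

from sparseml.transformers import SparseAutoModelForCausalLM

# Load a sparse model; recipe loading lives in SparseAutoModelForCausalLM.
model = SparseAutoModelForCausalLM.from_pretrained("path/to/sparse-model")

trainer = SFTTrainer(
    model=model,
    args=training_args,           # the SparseML TrainingArguments from above
    train_dataset=train_dataset,  # raw text (TRL preprocessing) or pre-tokenized
)
trainer.train()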