feat: Add support for smoothly resuming training from a saved checkpoint #300
Conversation
Signed-off-by: Abhishek <[email protected]>
Signed-off-by: Abhishek <[email protected]>
tuning/utils/config_utils.py (outdated)
@@ -135,3 +135,32 @@ def txt_to_obj(txt):
    except UnicodeDecodeError:
        # Otherwise the bytes are a pickled python dictionary
        return pickle.loads(message_bytes)


def get_last_checkpoint(train_args):
Why is this logic needed? It seems to be duplicated from what's inside the HF trainer.
This is also doing similar work to what exists in build/utils, so either that method or this one could be refactored to be used in both places.
That's correct, I missed that one. Now importing it via from transformers.trainer_utils import get_last_checkpoint.
I think it would be good to reuse this flag and pass it to the train() function as the user provides it. This would save us from too much custom handling, which limits other potential use cases like starting from a particular checkpoint.
Just to get clarity @kmehant: does the user pass the resume_from_checkpoint flag if they want to resume from a particular checkpoint?
And in the case where the user doesn't pass the flag, do we assume the user doesn't intend to resume the tuning and wants to start a new one, or do we still need to check for the latest checkpoint path and use it to resume tuning?
Thanks!
Just to get clarity @kmehant: does the user pass the resume_from_checkpoint flag if they want to resume from a particular checkpoint?

Yes, --resume_from_checkpoint=true to start from the last checkpoint, or a path if it has to be some arbitrary checkpoint. And then we pass this as-is to train().

And in the case where the user doesn't pass the flag, do we assume the user doesn't intend to resume the tuning and wants to start a new one, or do we still need to check for the latest checkpoint path and use it to resume tuning?

I have no particular suggestion on the default flow when the user does not pass the flag 🙂
Yes, --resume_from_checkpoint=true to start from the last checkpoint, or a path if it has to be some arbitrary checkpoint. And then we pass this as-is to train().

Okay
How is this different from simply passing the flag https://github.com/huggingface/transformers/blob/8820fe8b8c4b9da94cf1e4761876f85c562e0efe/src/transformers/training_args.py#L663? Why not just set the flag?
Signed-off-by: Abhishek <[email protected]>
tuning/sft_trainer.py (outdated)
is_checkpoint_available = get_last_checkpoint(training_args.output_dir)
if is_checkpoint_available:
    logger.info("Tuning resumes from the last checkpoint")
trainer.train(resume_from_checkpoint=is_checkpoint_available)
resume_from_checkpoint=True gives the same effect, doesn't it?
Yes, I tested it: resume_from_checkpoint=True gives the same effect if output_dir has checkpoints available, or else it throws an error. Hence I put a check on line 365, so that is_checkpoint_available can be passed as a boolean or a string.
@ashokponkumar It would also be helpful for me to know your requirements on the default flow when the user does not pass the resume_from_checkpoint flag. Do we still need to consider resuming from a checkpoint if output_dir has checkpoints?
@ashokponkumar Any comment/input on the above?
What error do we get when resume_from_checkpoint is set to True and there is no checkpoint? We can go with this fix for now, but it would be good to check in the HF Trainer whether they are okay with failing gracefully. Even if resume_from_checkpoint is set, if there is no checkpoint in the output dir, it should start like a normal run from scratch.
Also, if the user explicitly provides false, we should honor it.
What error do we get when resume_from_checkpoint is set to True and there is no checkpoint?

If resume_from_checkpoint is set to True and there is no checkpoint in output_dir, the run fails and exits with:

ValueError: No valid checkpoint found in output directory (outputs/lora-tuning)
Even if resume_from_checkpoint is set, if there is no checkpoint in the output dir, it should start like a normal run from scratch.
Also, if user actually provides a false explicitly we should honor it.
Okay.
Based on the above comments, I have pushed a few changes. Cases covered and tested:
1- If the user passes the flag as true and output_dir has checkpoints, tuning resumes from the last checkpoint.
2- If the user passes the flag as false, tuning starts from scratch.
3- If the user passes the flag with a checkpoint path, tuning resumes from that checkpoint.
4- If the user doesn't pass the flag, tuning resumes from the last checkpoint in output_dir if one exists, otherwise it starts from scratch.
Thanks @Abhishek-TAMU
Sounds good.
@Abhishek-TAMU can we add simple unit tests for each of the cases you mentioned above? Thanks
Signed-off-by: Abhishek <[email protected]>
Sure! Have added the test cases and pushed the changes. Thanks!
One query:
resume_from_checkpoint = None
# Check if resume flag is not passed (None), or if flag is true and
# output_dir has checkpoints then get last checkpoint from output_dir
if (
    training_args.resume_from_checkpoint is None
    or training_args.resume_from_checkpoint.lower() == "true"
):
    resume_from_checkpoint = get_last_checkpoint(training_args.output_dir)
else:
    # `training_args.resume_from_checkpoint` gives string values
    # Check if flag is false OR flag has checkpoint value for resuming tuning
    resume_from_checkpoint = (
        training_args.resume_from_checkpoint
        if training_args.resume_from_checkpoint.lower() != "false"
        else False
    )
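For context, the resolved value would then be handed to the trainer as-is; a minimal sketch of that hand-off (the surrounding trainer construction in sft_trainer.py is assumed):

# `resume_from_checkpoint` is now either False, a checkpoint path string, or None,
# all of which transformers' Trainer.train() accepts as-is.
trainer.train(resume_from_checkpoint=resume_from_checkpoint)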
Why not simplify this logic even further and move it into the post_init of training_args?
We can check if resume_from_checkpoint is None or True, and if so, check whether get_last_checkpoint(training_args.output_dir) returns a non-empty string. If it returns a non-empty string, then we set resume_from_checkpoint to True, else we set it to False. In the current condition added, if get_last_checkpoint does not return a checkpoint, doesn't it get set to an empty string, which is not what we want?
We can leave every other case untouched, can't we? HF will already handle the other cases well, or is it not handling the case when resume_from_checkpoint is the string "false"?
If get_last_checkpoint does not return a checkpoint, doesn't it get set to an empty string, which is not what we want?

get_last_checkpoint either returns a checkpoint path string or None (not an empty string).
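For reference, a small hedged illustration of that return behaviour (training_args here is assumed from the surrounding sft_trainer code):

from transformers.trainer_utils import get_last_checkpoint

# Returns the path of the newest "checkpoint-*" folder inside output_dir,
# or None when the directory contains no checkpoints (never an empty string).
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None:
    print("No checkpoints yet; a fresh run would start from scratch")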
Is it not handling the case when resume_from_checkpoint is the string "false"?

When resume_from_checkpoint is the string "false", it has to be converted to a boolean, because the trainer's train() method needs the resume_from_checkpoint argument as a boolean. train() accepts resume_from_checkpoint as a boolean True/False, None, or a string with the path of a checkpoint.
So we are additionally adding support for the strings true and false, even though the HF Trainer does not support them?
Wanted your input on exactly that, as mentioned in the code comments. Training arguments accept the resume_from_checkpoint flag as a string, hence the value passed in this flag will be a string, unless we change it manually or change the flag name?
What is the behaviour if we set the training argument resume_from_checkpoint to True or False? Isn't it handled as a boolean True or False, behaving accordingly?
The training argument resume_from_checkpoint takes in the path to a folder with a valid checkpoint for the model, as mentioned here, and doesn't work with a boolean/string True/False value.
Then let's go ahead with this change. LGTM. We can refine later if required.
# Check if resume flag is not passed (None), or if flag is true and
# output_dir has checkpoints then get last checkpoint from output_dir
if (
    training_args.resume_from_checkpoint is None
Is this saying that if resume_from_checkpoint is not set, then it would enable resume from checkpoint? That seems off... If a user doesn't pass anything in, why would it be enabled? How does the trainer differentiate between training the first time and resuming from training?
According to my understanding of the requirements, I have listed the cases here. The 4th case is: if the user doesn't pass the resume_from_checkpoint flag, then tuning will start from scratch if output_dir doesn't have checkpoints, else tuning will resume from the last checkpoint.
Feel free to give your input on this case.
After reading through the tests I understand this better, but this should be documented in our README so users don't need to go into the tests to understand how it works. Or please point to the transformers docs if they exist.
I see that transformers.trainer_utils.get_last_checkpoint will return nothing if there are no checkpoints.
So the logic is:
- By default, check whether the output_dir the user specifies has checkpoints. If so, resume from checkpoint.
- If the user sets False, do not resume from checkpoint.
This is an important note, as if a user reuses an output directory, the training behavior is now different than it was previously.
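To make the default concrete, here is a hedged sketch mirroring the unit tests added in this PR (MODEL_ARGS, DATA_ARGS, train_args, and the import path follow tests/test_sft_trainer.py; the exact fixture values are assumptions):

from tuning import sft_trainer

# First run in a fresh output_dir: no checkpoints yet, so tuning starts from scratch.
sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)

# Second run with the same output_dir and a higher epoch count: because
# resume_from_checkpoint was not passed, tuning resumes from the last checkpoint.
train_args.num_train_epochs += 5
sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)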
Yes, I can update the README. Is there any specific section in the README you would recommend where I can mention the above details for users (regarding the resume_from_checkpoint flag)?
You can add the details to the "Tips on Parameters to Set" section
else:
    # `training_args.resume_from_checkpoint` gives string values
    # Check if flag is false OR flag has checkpoint value for resuming tuning
    resume_from_checkpoint = (
Am I understanding the logic here correctly -- that resume_from_checkpoint can be passed in as a boolean value or as a path to a directory to resume training from? It would be good to provide docs on this if the behavior is different from how HF transformers handles the argument, as I see HF transformers only accepts a path to a checkpoint:
resume_from_checkpoint (str, optional) — The path to a folder with a valid checkpoint for your model.
Sure! Quoting the trainer.train() argument docs:

resume_from_checkpoint (`str` or `bool`, *optional*):
    If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
    `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
    of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.

If the value of resume_from_checkpoint passed to trainer.train() is None or False, it behaves the same as not passing the flag at all.
Ah I was looking at the docs here - https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.resume_from_checkpoint
Great changes and much appreciate the tests Abhishek! Have a few refactors that can be made to the tests to create cleaner code. Do you know how much time this adds to the tests by running tuning an additional 7 times?
tests/test_sft_trainer.py (outdated)
    json_data = json.loads(line)
    data_list.append(json_data)

if epoch is not None:
Small refactor: replace

if epoch is not None:

with

if epoch:

This checks whether epoch has a value (and is not None) already.
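A hedged aside (not from the PR): the two checks only differ for falsy values such as 0; since epochs here are 1-based, they behave the same in these tests.

# Truthiness check vs. explicit None check; they disagree only when epoch is falsy but not None.
for epoch in (None, 0, 1):
    print(epoch, bool(epoch), epoch is not None)
# None False False
# 0 False True   <- the single case where `if epoch:` and `if epoch is not None:` differ
# 1 True True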
Thank you for reviewing test cases. This is done!
with open(f"{dir_path}/training_logs.jsonl", "r", encoding="utf-8") as file: | ||
for line in file: | ||
json_data = json.loads(line) | ||
data_list.append(json_data) |
Could refactor out getting the train logs so that this method and _validate_logfile() (https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tests/test_sft_trainer.py#L567-L570) both use the same method.
Can be a future improvement
Good point. Can look into this one as a future improvement.
tests/test_sft_trainer.py (outdated)
if epoch is not None:
    mod_data_list = []
    for value in data_list:
        if int(value["data"]["epoch"]) == int(epoch):
Aren't both of these values already ints? Why do you have to convert them? You can also make this method clearer by providing type hints:

def _get_training_logs(dir_path: str, epoch: int = None):

Also, if an epoch is provided, do you return only the loss from that epoch? Please add a description to this function.
Yes, this change is helpful. Thanks!
Also, I guess the apt function name could be _get_training_logs_by_epoch. Do you agree with this?
I could go either way on this; it's only getting logs by epoch if the user provides an epoch, otherwise it returns the full training logs. So it's not necessarily by epoch, but either name is fine with me.
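Pulling the suggestions from this thread together, a hedged sketch of what the helper might end up looking like (the name and the data["epoch"] log structure follow the snippets above; everything else is an assumption):

import json
from typing import List, Optional

def _get_training_logs_by_epoch(dir_path: str, epoch: Optional[int] = None) -> List[dict]:
    """Load training_logs.jsonl from dir_path; if epoch is given, return only the
    entries whose logged epoch matches it, otherwise return all entries."""
    data_list = []
    with open(f"{dir_path}/training_logs.jsonl", "r", encoding="utf-8") as file:
        for line in file:
            data_list.append(json.loads(line))
    if epoch:
        # Logged epochs may be floats (e.g. 1.0), so compare as ints.
        data_list = [d for d in data_list if int(d["data"]["epoch"]) == int(epoch)]
    return data_list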
tests/test_sft_trainer.py (outdated)
return trainer_state, last_checkpoint


def _get_training_logs(dir_path, epoch=None):
Add description of what this method is doing.
tests/test_sft_trainer.py (outdated)
assert final_trainer_state["total_flos"] == init_trainer_state["total_flos"]


def _get_latest_checkpoint_trainer_state(dir_path, checkpoint_index=-1):
Add description of what this method is doing
Added the description!
tests/test_sft_trainer.py (outdated)
Test feature of resume training from checkpoint using resume_from_checkpoint flag
to ensure that when the value of flag is True and output_dir has checkpoints, the
tuning resumes from the latest checkpoint creating new checkpoints and the checkpoints
created before resuming tuning is not affected. When the value of flag is True and
output_dir does not have checkpoints, then the tuning will start from scratch.
This test also checks if timestamp of 1st epoch after first tuning is same after
resuming tuning and hence confirming that earlier checkpoints are not overwritten.
Since this is the exact same test as the one above, I'm not sure if it's needed...
Sure! Have removed the description.
Not the description, I was referring to the test. All tests should have descriptions but this test looks to be doing the same work as the one above
Okay, got it. Yes, this test also checks that tuning resumes from a checkpoint, but here I am setting the resume_from_checkpoint flag to true, to test the logic in sft_trainer when the flag value is True.
Let me know if you think we should remove this test.
tests/test_sft_trainer.py (outdated)
Test feature of resume training from checkpoint using resume_from_checkpoint flag
to ensure that when the value of flag is False the tuning will start from scratch.
This test also checks multiple entry of 1st epoch in training logs after resuming
tuning and hence confirming that the tuning has started from scratch.
More succinct:

Test when setting resume_from_checkpoint=False that tuning will start from scratch.
You don't have to explain the checks you make, like "checks multiple entry of 1st epoch in training logs", since users can read that in the test code. You just need a summary of what the test accomplishes and how it differs from other tests. You also don't need to name the method you are checking, because the test shows it clearly.
Okay got it. Thanks!
tests/test_sft_trainer.py (outdated)
init_training_logs = _get_training_logs(tempdir, epoch=1)
assert len(init_training_logs) == 1

# Resume training with higher epoch and same output dir
This is not resuming training, correct? It will overwrite the existing checkpoints and retune the model, but since the training_logs are appended, that's why there are two epoch 1 values.
Yes, correct, that's the logic. The comment also seemed misleading, hence I changed it to:
# Training again with higher epoch and same output dir
tests/test_sft_trainer.py (outdated)
Test feature of resume training from checkpoint using resume_from_checkpoint flag
to ensure that when the value of flag is a checkpoint-x path, the tuning will resume
from that checkpoint. This test checks if total_flos of checkpoint-x has not changed
after resuming tuning, hence confirming that the tuning has started from checkpoint-x.
Try to make this one more succinct. Also, what is total_flos?
Let's make this a LoRA test to test resuming training from diff tuning techniques
total_flos is basically the total floating point operations (FLOPs).
Formal definition: total_flos is a measure of the computational workload performed by the model during training. It is computed as the cumulative sum of FLOPs (floating point operations) over each training step.
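For context, a hedged sketch of the kind of check that description refers to: each checkpoint-* folder saved by the HF Trainer contains a trainer_state.json with a cumulative total_flos field (the helper name and path below are hypothetical):

import json
import os

def _read_total_flos(checkpoint_dir):
    # e.g. checkpoint_dir = "outputs/lora-tuning/checkpoint-35" (hypothetical path)
    with open(os.path.join(checkpoint_dir, "trainer_state.json"), encoding="utf-8") as f:
        return json.load(f)["total_flos"]

# Comparing the value read before and after resuming confirms the run really
# continued from that checkpoint rather than starting over.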
Let's make this a LoRA test to test resuming training from diff tuning techniques

Sure! Made this test a LoRA test. Thanks!
# output_dir has checkpoints then get last checkpoint from output_dir
if (
    training_args.resume_from_checkpoint is None
    or training_args.resume_from_checkpoint.lower() == "true"
Since this does accept a boolean value, shouldn't this just check the bool value? A user isn't passing in a string "True" but the boolean "True" right?
Yes, resume_from_checkpoint does accept a bool in trainer.train(). It's just that when the user passes the resume_from_checkpoint flag as a bool on the command line, it gets taken as a string, because it is treated as a flag in the training arguments, which only accept a string.
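A hedged illustration of that point, using transformers' stock TrainingArguments purely for demonstration (the flag value and output_dir below are made up):

from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "outputs/lora-tuning", "--resume_from_checkpoint", "true"]
)
# resume_from_checkpoint is declared as an optional string on TrainingArguments,
# so the parsed value is the string "true", not the boolean True.
print(repr(args.resume_from_checkpoint))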
ahh thanks for the clarity
Thank you @anhuong for the review on the test cases. I have made the changes mentioned and pushed them. I also measured the time these 4 test cases add to the tests: the difference in running with and without them is 15 seconds. Thanks! Also wanted to ask about the linting warning for too many lines in test_sft_trainer.py.
tests/test_sft_trainer.py (outdated)
Load and optionally filter training logs from a training_logs JSON Lines file.
This function reads a JSON Lines (`.jsonl`) file containing training logs and
nit: better clarity and shorter

Load and optionally filter training_logs.jsonl file.
If an epoch number is specified, the function filters the logs
and returns only the entries corresponding to the specified epoch.
Thank you!
tests/test_sft_trainer.py (outdated)
Get the trainer state from the specified checkpoint directory.
This function gets the latest or specific checkpoint based on the
provided checkpoint_index from the checkpoint directory, and loads
the `trainer_state.json` file from that checkpoint. The trainer
state is returned along with the path to the checkpoint.
nit: kinda repetitive, so making it more succinct:

Get the trainer state from the latest or specified checkpoint directory.
The trainer state is returned along with the path to the checkpoint.
Thank you!
tests/test_sft_trainer.py (outdated)
    for entry in final_trainer_state["log_history"]
    if entry["epoch"] == 1.0
][0]
assert init_loss_epoch == final_loss_epoch
true
tests/test_sft_trainer.py (outdated)
Test when setting resume_from_checkpoint=path/to/checkpoint-x
that the tuning will resume from the checkpoint-x.
nit: Could be a little clearer and should note that this is for LoRA:

Test resume checkpoint from a specified checkpoint path for LoRA tuning.
Thank you!
tests/test_sft_trainer.py (outdated)
train_args.num_train_epochs += 5
sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
_validate_training(tempdir)

# Get Training log entry for epoch 1
final_training_logs = _get_training_logs_by_epoch(tempdir, epoch=1)
assert len(final_training_logs) == 2


def test_resume_training_from_checkpoint_with_flag_checkpoint_path():
Should include LoRA here:

def test_resume_training_from_checkpoint_with_flag_checkpoint_path_lora():
Not sure why all of my comments got split up into different reviews, sorry about that. There are a few nits on the descriptions and small changes that can be addressed, as well as a test case that I'm not sure is useful. In addition, a quick note should be added to the docs explaining how this works, pointing at the HF TrainingArguments / Trainer docs.
In terms of the linting and too many lines in test_sft_trainer.py, I think it's okay to increase this value in pylint. We can create an issue to look into the tests and see how we can split them up into separate files. For now, let's increase to 1200.
Signed-off-by: Abhishek <[email protected]>
Signed-off-by: Abhishek <[email protected]>
Have pushed changes related to the function descriptions, adding the README note, and increasing the line limit in pylint.
Other than the question on the test, the change looks good. Please answer the question, but I'll merge this as-is.
@@ -122,6 +122,11 @@ def test_resume_training_from_checkpoint():


def test_resume_training_from_checkpoint_with_flag_true():
I'm not sure if this method is useful or if it's doing the same as the one above.
Yes, this test also checks that tuning resumes from a checkpoint, but here I am setting the resume_from_checkpoint flag to true, to test the logic in sft_trainer.py when the flag value is True, though its behavior would be similar to when the value of resume_from_checkpoint is None.
Let me know if you think we should remove this test. @anhuong
It's a bit redundant as resume_from_checkpoint=None creates the same behavior as resume_from_checkpoint=True but fine to leave for now
@kmehant Since you reviewed this PR and requested changes, the PR is waiting for your approval before it can be merged
@Abhishek-TAMU let's rebase, thank you.
Description of the change
In this PR, I have added the feature of resuming tuning from the last saved checkpoint of the previous tuning run, using the flag --output_dir, which stores all the checkpoints.

Related issue number
#1007
How to verify the PR
Resume the tuning of a model after an initial tuning run and verify the output checkpoints.
Was the PR tested