feat: Add support for smoothly resuming training from a saved checkpoint #300
@@ -80,6 +80,214 @@

PEFT_LORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)

def test_resume_training_from_checkpoint():
    """
    Test that tuning resumes from the latest checkpoint, creating new checkpoints,
    and that the checkpoints created before resuming tuning are not affected.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir

        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert init_trainer_state is not None

        # Resume training with higher epoch and same output dir
        train_args.num_train_epochs += 5
        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        final_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert final_trainer_state is not None

        assert final_trainer_state["epoch"] == init_trainer_state["epoch"] + 5
        assert final_trainer_state["global_step"] > init_trainer_state["global_step"]

        # Check that the loss logged for the 1st epoch in the first tuning run
        # is the same after resuming tuning and has not been overwritten
        assert len(init_trainer_state["log_history"]) > 0

        init_log_history = init_trainer_state["log_history"][0]
        assert init_log_history["epoch"] == 1

        final_log_history = final_trainer_state["log_history"][0]
        assert final_log_history["epoch"] == 1

        assert init_log_history["loss"] == final_log_history["loss"]

def test_resume_training_from_checkpoint_with_flag_true():
    """
    Test that tuning resumes from the latest checkpoint when the flag is true,
    creating new checkpoints, and that the checkpoints created before resuming
    tuning are not affected.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir
        train_args.resume_from_checkpoint = "True"

        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert init_trainer_state is not None

        # Get training logs
        init_training_logs = _get_training_logs_by_epoch(tempdir)

        # Resume training with higher epoch and same output dir
        train_args.num_train_epochs += 5
        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        final_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert final_trainer_state is not None

        assert final_trainer_state["epoch"] == init_trainer_state["epoch"] + 5
        assert final_trainer_state["global_step"] > init_trainer_state["global_step"]

        final_training_logs = _get_training_logs_by_epoch(tempdir)

        assert (
            init_training_logs[0]["data"]["timestamp"]
            == final_training_logs[0]["data"]["timestamp"]
        )

def test_resume_training_from_checkpoint_with_flag_false():
    """
    Test that tuning starts from scratch when resume_from_checkpoint is set to False.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir
        train_args.resume_from_checkpoint = "False"

        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert init_trainer_state is not None

        # Get training log entries for epoch 1
        init_training_logs = _get_training_logs_by_epoch(tempdir, epoch=1)
        assert len(init_training_logs) == 1

        # Train again with higher epoch and same output dir
        train_args.num_train_epochs += 5
        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get training log entries for epoch 1
        final_training_logs = _get_training_logs_by_epoch(tempdir, epoch=1)
        assert len(final_training_logs) == 2

def test_resume_training_from_checkpoint_with_flag_checkpoint_path_lora():
    """
    Test resuming tuning from a specified checkpoint path for LoRA tuning.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        lora_config = copy.deepcopy(PEFT_LORA_ARGS)
        train_args.output_dir = tempdir

        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, lora_config)
        _validate_training(tempdir)

        # Get trainer state and checkpoint_path of second-last checkpoint
        init_trainer_state, checkpoint_path = _get_latest_checkpoint_trainer_state(
            tempdir, checkpoint_index=-2
        )
        assert init_trainer_state is not None

        # Resume training with higher epoch and same output dir
        train_args.num_train_epochs += 5
        train_args.resume_from_checkpoint = checkpoint_path
        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, lora_config)
        _validate_training(tempdir)

        # Get total_flos from the trainer state of checkpoint_path and check that it is unchanged
        final_trainer_state = None
        trainer_state_file = os.path.join(checkpoint_path, "trainer_state.json")
        with open(trainer_state_file, "r", encoding="utf-8") as f:
            final_trainer_state = json.load(f)

        assert final_trainer_state["total_flos"] == init_trainer_state["total_flos"]

def _get_latest_checkpoint_trainer_state(dir_path: str, checkpoint_index: int = -1):
    """
    Get the trainer state from the latest or specified checkpoint directory.
    The trainer state is returned along with the path to the checkpoint.

    Args:
        dir_path (str): The directory path where checkpoint folders are located.
        checkpoint_index (int, optional): The index of the checkpoint to retrieve,
                                          based on the checkpoint number. The default
                                          is -1, which returns the latest checkpoint.

    Returns:
        trainer_state: The trainer state loaded from `trainer_state.json` in the
                       checkpoint directory.
        last_checkpoint: The path to the checkpoint directory.
    """
    trainer_state = None
    last_checkpoint = None
    checkpoints = [
        os.path.join(dir_path, d)
        for d in os.listdir(dir_path)
        if d.startswith("checkpoint")
    ]
    if checkpoints:
        last_checkpoint = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))[
            checkpoint_index
        ]
Comment on lines +246 to +254

Could we also use the `get_last_checkpoint` utility here?

In lines 248-250 we don't necessarily take the last checkpoint; instead we get a checkpoint based on `checkpoint_index`, so an earlier checkpoint (e.g. the second-last) can be selected as well.
    trainer_state_file = os.path.join(last_checkpoint, "trainer_state.json")
    with open(trainer_state_file, "r", encoding="utf-8") as f:
        trainer_state = json.load(f)
    return trainer_state, last_checkpoint
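For reference, a minimal runnable sketch of how the `get_last_checkpoint` utility suggested in the thread above compares with the index-based selection used in this helper (the temporary directory layout below is illustrative only):

import os
import tempfile

from transformers.trainer_utils import get_last_checkpoint

with tempfile.TemporaryDirectory() as tmp:
    # Fake a Hugging Face style output directory with three checkpoints.
    for step in (10, 20, 30):
        os.makedirs(os.path.join(tmp, f"checkpoint-{step}"))

    # get_last_checkpoint only ever returns the newest checkpoint,
    # or None when the directory contains no checkpoint-* folders.
    print(get_last_checkpoint(tmp))  # .../checkpoint-30

    # The index-based selection above can also reach earlier checkpoints,
    # e.g. the second-latest one via checkpoint_index=-2.
    checkpoints = sorted(
        (d for d in os.listdir(tmp) if d.startswith("checkpoint")),
        key=lambda name: int(name.split("-")[-1]),
    )
    print(os.path.join(tmp, checkpoints[-2]))  # .../checkpoint-20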


def _get_training_logs_by_epoch(dir_path: str, epoch: int = None):
    """
    Load and optionally filter the training_logs.jsonl file.
    If an epoch number is specified, the function filters the logs
    and returns only the entries corresponding to the specified epoch.

    Args:
        dir_path (str): The directory path where the `training_logs.jsonl` file is located.
        epoch (int, optional): The epoch number to filter logs by. If not specified,
                               all logs are returned.

    Returns:
        list: A list containing the training logs. If `epoch` is specified,
        only logs from the specified epoch are returned; otherwise, all logs are returned.
    """
    data_list = []
    with open(f"{dir_path}/training_logs.jsonl", "r", encoding="utf-8") as file:
        for line in file:
            json_data = json.loads(line)
            data_list.append(json_data)
Comment on lines +277 to +280

Could we refactor out reading the training logs so that this method and the other helpers that parse training_logs.jsonl can share it?

Can be a future improvement

Good point. Can look into this one as a future improvement.

    if epoch:
        mod_data_list = []
        for value in data_list:
            if value["data"]["epoch"] == epoch:
                mod_data_list.append(value)
        return mod_data_list
    return data_list
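A minimal sketch of the shared-helper refactor suggested in the thread above; the `_load_training_logs` name is hypothetical and not part of this PR:

import json


def _load_training_logs(dir_path: str) -> list:
    """Read every JSON line of <dir_path>/training_logs.jsonl into a list."""
    with open(f"{dir_path}/training_logs.jsonl", "r", encoding="utf-8") as file:
        return [json.loads(line) for line in file]


def _get_training_logs_by_epoch(dir_path: str, epoch: int = None):
    """Same behaviour as the helper above, delegating the file reading to the shared helper."""
    data_list = _load_training_logs(dir_path)
    if epoch:
        return [entry for entry in data_list if entry["data"]["epoch"] == epoch]
    return data_list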


def test_run_train_requires_output_dir():
    """Check fails when output dir not provided."""
    updated_output_dir_train_args = copy.deepcopy(TRAIN_ARGS)
sft_trainer.py
@@ -35,6 +35,7 @@

    LlamaTokenizerFast,
    TrainerCallback,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_accelerate_available
from trl import SFTConfig, SFTTrainer
import transformers
@@ -215,7 +216,7 @@ def train(

        ),
    )

    # add special tokens only when a custom tokenizer is not passed
    # Add special tokens only when a custom tokenizer is not passed
    if not model_args.tokenizer_name_or_path:
        # TODO: understand if we need to hardcode these here or just use defaults in model
        if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
@@ -366,7 +367,24 @@ def train(

    for x in framework.get_callbacks_and_ready_for_train(model, accelerator):
        trainer.add_callback(x)

    trainer.train()
    resume_from_checkpoint = None
    # Check if resume flag is not passed (None), or if flag is true and
    # output_dir has checkpoints, then get last checkpoint from output_dir
    if (
        training_args.resume_from_checkpoint is None

Is this saying that if resume_from_checkpoint is not set, then it would enable resume from checkpoint? That seems off... If a user doesn't pass anything in, why would it be enabled? How does the trainer differentiate between training the first time and resuming from training?

According to my understanding of the requirements, I have listed down the cases here. The 4th case is when the user doesn't pass the flag. Feel free to give your input on this case.

After reading through the tests I understand this better, but this should be documented in our README so users don't need to go into tests to understand how it works. Or please point to the transformers docs if they exist. I see that transformers.utils.get_last_checkpoint will return nothing if there are no checkpoints. So the logic is: when the flag is not passed (or is "true"), training resumes from the latest checkpoint in output_dir if one exists, and starts from scratch otherwise. This is an important note, because if a user reuses an output directory, the training behavior is now different than it was previously.

Yes, I can update the README. In this README, is there any specific section recommended where I can mention the above details for users (regarding the resume_from_checkpoint flag)?

You can add the details to the "Tips on Parameters to Set" section.
        or training_args.resume_from_checkpoint.lower() == "true"

Since this does accept a boolean value, shouldn't this just check the bool value? A user isn't passing in a string "True" but the boolean True, right?

Yes, but `training_args.resume_from_checkpoint` comes through as a string value (a boolean-like value arrives as the string "True"/"False"), which is why the lowercased string is compared here.

ahh thanks for the clarity
    ):
        resume_from_checkpoint = get_last_checkpoint(training_args.output_dir)
    else:
        # `training_args.resume_from_checkpoint` gives string values
        # Check if flag is false OR flag has checkpoint value for resuming tuning
        resume_from_checkpoint = (

Am I understanding the logic here correctly -- that "false" disables resuming, and any other value is treated as a checkpoint path to resume from?

Sure! Mentioning the behaviour here: if the value of resume_from_checkpoint is "false", we pass False so tuning starts from scratch; any other value is passed through unchanged as a checkpoint path to resume from.

Ah, I was looking at the docs here - https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.resume_from_checkpoint
            training_args.resume_from_checkpoint
            if training_args.resume_from_checkpoint.lower() != "false"
            else False
        )
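As background for the string-comparison discussion above, a small runnable sketch (the `_DemoArgs` dataclass is an illustrative stand-in, not code from this PR) showing that an `Optional[str]` field such as resume_from_checkpoint arrives from the command line as the string "True", not the boolean True:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class _DemoArgs:
    # Same typing as TrainingArguments.resume_from_checkpoint: Optional[str]
    resume_from_checkpoint: Optional[str] = field(default=None)


parser = HfArgumentParser(_DemoArgs)
(args,) = parser.parse_args_into_dataclasses(args=["--resume_from_checkpoint", "True"])

# The value arrives as the string "True", not the boolean True,
# which is why the code above lowercases and string-compares it.
print(repr(args.resume_from_checkpoint))  # 'True'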
Comment on lines +370 to +385

Why not simplify this logic even further and move it into the post_init of training_args? We can check there if resume_from_checkpoint is the string "true"/"false" and convert it. We can leave every other case untouched, isn't it? HF will already handle the other cases well, or is it not handling the case when the flag is not passed at all?

When resume_from_checkpoint is not passed, HF does not look up existing checkpoints in output_dir on its own, which is why get_last_checkpoint is called here before invoking trainer.train.

So we are adding support for string "true"/"false" values of the flag as well?

Wanted your input on exactly that, as mentioned in the code comments. Training arguments accept the flag as a string, so boolean-like values come in as the strings "True"/"False" and need to be converted before being handed to trainer.train.

What is the behaviour if we set the Training argument resume_from_checkpoint to a checkpoint path directly?

Training argument resume_from_checkpoint is then passed through unchanged, so tuning resumes from that checkpoint.

Then let's go ahead with this change. LGTM. We can refine later if required.

    trainer.train(resume_from_checkpoint)

    return trainer

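To summarize the decision table discussed in the threads above, here is a minimal, self-contained sketch; resolve_resume_arg is a hypothetical helper for illustration, not the function added by this PR:

from typing import Optional, Union


def resolve_resume_arg(flag: Optional[str], last_checkpoint: Optional[str]) -> Union[str, bool, None]:
    """Mirror of the decision table above: flag is the raw string value, last_checkpoint the newest checkpoint in output_dir (or None)."""
    if flag is None or flag.lower() == "true":
        # Not passed, or "true": fall back to the latest checkpoint in
        # output_dir (None when the directory holds no checkpoints yet).
        return last_checkpoint
    if flag.lower() == "false":
        # "false": explicitly start from scratch.
        return False
    # Anything else is treated as a concrete checkpoint path.
    return flag


assert resolve_resume_arg(None, None) is None                                  # first run, fresh start
assert resolve_resume_arg(None, "out/checkpoint-5") == "out/checkpoint-5"      # reused output dir
assert resolve_resume_arg("True", "out/checkpoint-5") == "out/checkpoint-5"    # explicit resume
assert resolve_resume_arg("False", "out/checkpoint-5") is False                # force fresh start
assert resolve_resume_arg("out/checkpoint-3", "out/checkpoint-5") == "out/checkpoint-3"  # specific checkpoint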

I'm not sure if this method is useful or if it's doing the same as the one above.

Yes, this test also checks that tuning resumes from a checkpoint, but here I am setting the resume_from_checkpoint flag to true, to test the logic in sft_trainer.py when the flag value is True, though its behavior would be similar to when the value of resume_from_checkpoint is None. Let me know if you think we should remove this test. @anhuong

It's a bit redundant, as resume_from_checkpoint=None creates the same behavior as resume_from_checkpoint=True, but fine to leave for now.