
feat: Add support for smoothly resuming training from a saved checkpoint #300

Merged · 16 commits · Sep 16, 2024
2 changes: 1 addition & 1 deletion .pylintrc
@@ -333,7 +333,7 @@ indent-string=' '
max-line-length=100

# Maximum number of lines in a module.
-max-module-lines=1100
+max-module-lines=1200

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
5 changes: 5 additions & 0 deletions README.md
@@ -278,6 +278,11 @@ You can set `output_dir` to a local directory and set `save_model_dir` to COS to

In order to achieve the fastest training time, set `save_strategy="no"`, as saving no checkpoints except for the final model removes intermediate write operations altogether.

#### Resuming tuning from checkpoints
If the output directory already contains checkpoints, tuning will automatically resume from the latest checkpoint in the directory specified by the `output_dir` flag. To start tuning from scratch and ignore existing checkpoints, set the `resume_from_checkpoint` flag to False.

You can also use the `resume_from_checkpoint` flag to resume tuning from a specific checkpoint by providing the full path to the desired checkpoint as a string. This flag is passed as an argument to the [trainer.train()](https://github.com/huggingface/transformers/blob/db70426854fe7850f2c5834d633aff637f14772e/src/transformers/trainer.py#L1901) function of the `SFTTrainer`.
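For illustration, a minimal sketch of resuming via the Python API, mirroring the pattern used in the tests added by this PR (the `MODEL_ARGS`, `DATA_ARGS`, and `train_args` configuration objects are assumed to be built elsewhere, and the checkpoint path is a placeholder):

```python
from tuning import sft_trainer

# Directory that already holds checkpoints from a previous run.
train_args.output_dir = "/path/to/output_dir"

# Full path to the desired checkpoint as a string, or "False" to ignore
# existing checkpoints and start from scratch. Leaving the flag unset (None)
# resumes from the latest checkpoint in output_dir, if one exists.
train_args.resume_from_checkpoint = "/path/to/output_dir/checkpoint-100"

sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
```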

## Tuning Techniques:

### LoRA Tuning Example
208 changes: 208 additions & 0 deletions tests/test_sft_trainer.py
@@ -80,6 +80,214 @@
PEFT_LORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)


def test_resume_training_from_checkpoint():
"""
    Test that tuning resumes from the latest checkpoint, creates new checkpoints,
    and leaves the checkpoints created before resuming tuning untouched.
"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
_validate_training(tempdir)

# Get trainer state of latest checkpoint
init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
assert init_trainer_state is not None

# Resume training with higher epoch and same output dir
train_args.num_train_epochs += 5
sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
_validate_training(tempdir)

# Get trainer state of latest checkpoint
final_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
assert final_trainer_state is not None

assert final_trainer_state["epoch"] == init_trainer_state["epoch"] + 5
assert final_trainer_state["global_step"] > init_trainer_state["global_step"]

        # Check that the loss logged for the 1st epoch in the first run is
        # unchanged (not overwritten) after resuming tuning
assert len(init_trainer_state["log_history"]) > 0

init_log_history = init_trainer_state["log_history"][0]
assert init_log_history["epoch"] == 1

final_log_history = final_trainer_state["log_history"][0]
assert final_log_history["epoch"] == 1

assert init_log_history["loss"] == final_log_history["loss"]


def test_resume_training_from_checkpoint_with_flag_true():
Collaborator: I'm not sure if this method is useful or if it's doing the same as the one above.

Author (@Abhishek-TAMU, Sep 13, 2024): Yes, this test also checks that tuning resumes from a checkpoint, but here I set the resume_from_checkpoint flag to true to exercise the logic in sft_trainer.py when the flag value is True, even though the behavior is the same as when resume_from_checkpoint is None.

Let me know if you think we should remove this test. @anhuong

Collaborator: It's a bit redundant, as resume_from_checkpoint=None creates the same behavior as resume_from_checkpoint=True, but it's fine to leave for now.

"""
    Test that tuning resumes from the latest checkpoint when the flag is true,
    creates new checkpoints, and leaves the checkpoints created before resuming
    tuning untouched.
"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
train_args.resume_from_checkpoint = "True"

sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
_validate_training(tempdir)

# Get trainer state of latest checkpoint
init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
assert init_trainer_state is not None

# Get Training logs
init_training_logs = _get_training_logs_by_epoch(tempdir)

# Resume training with higher epoch and same output dir
train_args.num_train_epochs += 5
sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
_validate_training(tempdir)

# Get trainer state of latest checkpoint
final_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
assert final_trainer_state is not None

assert final_trainer_state["epoch"] == init_trainer_state["epoch"] + 5
assert final_trainer_state["global_step"] > init_trainer_state["global_step"]

final_training_logs = _get_training_logs_by_epoch(tempdir)

assert (
init_training_logs[0]["data"]["timestamp"]
== final_training_logs[0]["data"]["timestamp"]
)


def test_resume_training_from_checkpoint_with_flag_false():
"""
    Test that tuning starts from scratch when resume_from_checkpoint=False.
"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
train_args.resume_from_checkpoint = "False"

sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
_validate_training(tempdir)

# Get trainer state of latest checkpoint
init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
assert init_trainer_state is not None

# Get Training log entry for epoch 1
init_training_logs = _get_training_logs_by_epoch(tempdir, epoch=1)
assert len(init_training_logs) == 1

# Training again with higher epoch and same output dir
train_args.num_train_epochs += 5
sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
_validate_training(tempdir)

# Get Training log entry for epoch 1
final_training_logs = _get_training_logs_by_epoch(tempdir, epoch=1)
assert len(final_training_logs) == 2


def test_resume_training_from_checkpoint_with_flag_checkpoint_path_lora():
"""
    Test resuming tuning from a specified checkpoint path for LoRA tuning.
"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
lora_config = copy.deepcopy(PEFT_LORA_ARGS)
train_args.output_dir = tempdir

sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, lora_config)
_validate_training(tempdir)

# Get trainer state and checkpoint_path of second last checkpoint
init_trainer_state, checkpoint_path = _get_latest_checkpoint_trainer_state(
tempdir, checkpoint_index=-2
)
assert init_trainer_state is not None

# Resume training with higher epoch and same output dir
train_args.num_train_epochs += 5
train_args.resume_from_checkpoint = checkpoint_path
sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, lora_config)
_validate_training(tempdir)

        # Get total_flos from the trainer state at checkpoint_path and check it is unchanged
final_trainer_state = None
trainer_state_file = os.path.join(checkpoint_path, "trainer_state.json")
with open(trainer_state_file, "r", encoding="utf-8") as f:
final_trainer_state = json.load(f)

assert final_trainer_state["total_flos"] == init_trainer_state["total_flos"]


def _get_latest_checkpoint_trainer_state(dir_path: str, checkpoint_index: int = -1):
"""
Get the trainer state from the latest or specified checkpoint directory.
The trainer state is returned along with the path to the checkpoint.

Args:
dir_path (str): The directory path where checkpoint folders are located.
checkpoint_index (int, optional): The index of the checkpoint to retrieve,
based on the checkpoint number. The default
is -1, which returns the latest checkpoint.

Returns:
trainer_state: The trainer state loaded from `trainer_state.json` in the
checkpoint directory.
last_checkpoint: The path to the checkpoint directory.
"""
trainer_state = None
last_checkpoint = None
checkpoints = [
os.path.join(dir_path, d)
for d in os.listdir(dir_path)
if d.startswith("checkpoint")
]
if checkpoints:
last_checkpoint = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))[
checkpoint_index
]
Comment on lines +246 to +254

Collaborator: Could we also use the transformers.utils.get_last_checkpoint() here?

Author (@Abhishek-TAMU): In lines 248-250 we don't necessarily take the last checkpoint; we get a checkpoint based on checkpoint_index, so I guess get_last_checkpoint() may not be useful here.

trainer_state_file = os.path.join(last_checkpoint, "trainer_state.json")
with open(trainer_state_file, "r", encoding="utf-8") as f:
trainer_state = json.load(f)
return trainer_state, last_checkpoint


def _get_training_logs_by_epoch(dir_path: str, epoch: int = None):
"""
Load and optionally filter training_logs.jsonl file.
If an epoch number is specified, the function filters the logs
and returns only the entries corresponding to the specified epoch.

Args:
dir_path (str): The directory path where the `training_logs.jsonl` file is located.
epoch (int, optional): The epoch number to filter logs by. If not specified,
all logs are returned.

Returns:
list: A list containing the training logs. If `epoch` is specified,
only logs from the specified epoch are returned; otherwise, all logs are returned.
"""
data_list = []
with open(f"{dir_path}/training_logs.jsonl", "r", encoding="utf-8") as file:
for line in file:
json_data = json.loads(line)
data_list.append(json_data)
Comment on lines +277 to +280

Collaborator: Could refactor out getting train logs so that this method and _validate_logfile() (https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tests/test_sft_trainer.py#L567-L570) both use the same method.

Collaborator: Can be a future improvement.

Author (@Abhishek-TAMU): Good point. Can look into this one as a future improvement.


if epoch:
mod_data_list = []
for value in data_list:
if value["data"]["epoch"] == epoch:
mod_data_list.append(value)
return mod_data_list
return data_list


def test_run_train_requires_output_dir():
"""Check fails when output dir not provided."""
updated_output_dir_train_args = copy.deepcopy(TRAIN_ARGS)
Expand Down
22 changes: 20 additions & 2 deletions tuning/sft_trainer.py
@@ -35,6 +35,7 @@
LlamaTokenizerFast,
TrainerCallback,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_accelerate_available
from trl import SFTConfig, SFTTrainer
import transformers
@@ -215,7 +216,7 @@ def train(
),
)

-    # add special tokens only when a custom tokenizer is not passed
+    # Add special tokens only when a custom tokenizer is not passed
if not model_args.tokenizer_name_or_path:
# TODO: understand if we need to hardcode these here or just use defaults in model
if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
@@ -366,7 +367,24 @@ def train(
for x in framework.get_callbacks_and_ready_for_train(model, accelerator):
trainer.add_callback(x)

trainer.train()
resume_from_checkpoint = None
    # If the resume flag is not passed (None) or is "true", resume from the
    # last checkpoint in output_dir (None is returned if there are no checkpoints)
if (
training_args.resume_from_checkpoint is None
Collaborator (@anhuong, Aug 30, 2024): Is this saying that if resume_from_checkpoint is not set, then it would enable resume from checkpoint? That seems off... If a user doesn't pass anything in, why would it be enabled? How does the trainer differentiate between training for the first time and resuming from training?

Author (@Abhishek-TAMU, Aug 30, 2024): Based on my understanding of the requirements, I have listed the cases here. The 4th case is: if the user doesn't pass the resume_from_checkpoint flag, tuning starts from scratch if output_dir has no checkpoints, otherwise it resumes from the last checkpoint. Feel free to give your input on this case.

Collaborator: After reading through the tests I understand this better, but this should be documented in our README so users don't need to go into the tests to understand how it works. Or please point to the transformers docs if they exist.

I see that transformers.utils.get_last_checkpoint will return nothing if there are no checkpoints. So the logic is:

  • By default, check whether the output_dir the user specifies has checkpoints; if so, resume from the latest checkpoint.
  • If the user sets False, do not resume from a checkpoint.

This is an important note: if a user reuses an output directory, the training behavior is now different than it was previously.
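To make that behavior concrete, here is a simplified sketch of the decision the added code makes, pieced together from the diff in tuning/sft_trainer.py (not the verbatim diff; training_args and trainer come from the surrounding function):

```python
from transformers.trainer_utils import get_last_checkpoint

# training_args.resume_from_checkpoint arrives as a string (or None),
# because it is parsed as a command-line flag.
flag = training_args.resume_from_checkpoint

if flag is None or flag.lower() == "true":
    # Default / "true": resume from the newest checkpoint in output_dir;
    # get_last_checkpoint() returns None if there are no checkpoints yet,
    # in which case tuning starts from scratch.
    resume_from_checkpoint = get_last_checkpoint(training_args.output_dir)
elif flag.lower() == "false":
    # "false": always start from scratch, even if checkpoints exist.
    resume_from_checkpoint = False
else:
    # Any other string is treated as an explicit checkpoint path.
    resume_from_checkpoint = flag

trainer.train(resume_from_checkpoint)
```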

Author (@Abhishek-TAMU): Yes, I can update the README. Is there any specific section recommended where I can mention the above details for users (regarding the resume_from_checkpoint flag)?

Collaborator: You can add the details to the "Tips on Parameters to Set" section.

or training_args.resume_from_checkpoint.lower() == "true"
Collaborator: Since this does accept a boolean value, shouldn't this just check the bool value? A user isn't passing in a string "True" but the boolean True, right?

Author (@Abhishek-TAMU): Yes, resume_from_checkpoint does accept a bool in trainer.train(). It's just that when the user passes the resume_from_checkpoint flag as a bool, it is taken as a string, because it is treated as a flag in the training arguments, which only accepts strings.

Collaborator: Ah, thanks for the clarity.

):
resume_from_checkpoint = get_last_checkpoint(training_args.output_dir)
else:
        # `training_args.resume_from_checkpoint` arrives as a string: "false"
        # disables resuming; any other value is treated as a checkpoint path
resume_from_checkpoint = (
Collaborator: Am I understanding the logic here correctly -- that resume_from_checkpoint can be passed in as a boolean value or as a path to a directory to resume training from? It would be good to provide docs on this if the behavior is different from how HF transformers handles the argument, as I see HF transformers only accepts a path to a checkpoint:

    resume_from_checkpoint (str, optional) — The path to a folder with a valid checkpoint for your model.

Author (@Abhishek-TAMU, Aug 30, 2024): Sure! Quoting the trainer.train() arguments:

    resume_from_checkpoint (`str` or `bool`, *optional*):
        If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
        `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
        of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.

If resume_from_checkpoint is passed as None or False to trainer.train(), it behaves the same as not passing the flag to trainer.train() at all.

training_args.resume_from_checkpoint
if training_args.resume_from_checkpoint.lower() != "false"
else False
)
Comment on lines +370 to +385

Collaborator: Why not simplify this logic even further and move it into the post_init of training_args?

We can check if resume_from_checkpoint is None or True, and if so, check whether get_last_checkpoint(training_args.output_dir) returns a non-empty string. If it does, we set resume_from_checkpoint to True, else we set it to False. In the condition currently added, if get_last_checkpoint does not return a checkpoint, doesn't resume_from_checkpoint get set to an empty string, which is not what we want?

We can leave every other case untouched, can't we? HF will already handle the other cases well, or is it not handling the case where resume_from_checkpoint is the string "false"?

Author (@Abhishek-TAMU, Aug 27, 2024):

> if get_last_checkpoint does not return a checkpoint, doesn't it get set to empty string, which is not what we want, isn't it?

get_last_checkpoint either returns a checkpoint string or None (not an empty string).

Author (@Abhishek-TAMU, Aug 27, 2024):

> is it not handling when resume_from_checkpoint is a string false?

When resume_from_checkpoint is the string "false", it has to be converted to a boolean, since trainer.train() expects resume_from_checkpoint as a boolean. trainer.train() accepts resume_from_checkpoint as a boolean True/False, None, or a string with the path of a checkpoint.

Collaborator: So we are additionally adding support for the strings "true" and "false", even though the HF Trainer does not support them?

Author (@Abhishek-TAMU): I wanted your input on exactly that, as mentioned in the code comments. The training arguments accept the resume_from_checkpoint flag as a string, so the value passed in this flag will be a string, unless we change it manually or change the flag name?

Collaborator: What is the behaviour if we set the training argument resume_from_checkpoint to True or False? Isn't it handled as a boolean True or False, with the corresponding behavior?

Author (@Abhishek-TAMU): The training argument resume_from_checkpoint takes the path to a folder with a valid checkpoint for the model, as mentioned here, and doesn't work with a boolean/string value of True/False.

Collaborator: Then let's go ahead with this change. LGTM. We can refine later if required.
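For reference, the reviewer's alternative of normalizing the flag in a post-init hook might look roughly like the hypothetical sketch below; the subclass and its relationship to this repo's training-arguments dataclass are assumptions, not code from this PR:

```python
import os
from dataclasses import dataclass

import transformers
from transformers.trainer_utils import get_last_checkpoint


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Hypothetical subclass that normalizes resume_from_checkpoint up front."""

    def __post_init__(self):
        super().__post_init__()
        flag = self.resume_from_checkpoint
        if flag is None or (isinstance(flag, str) and flag.lower() == "true"):
            # Resume from the newest checkpoint in output_dir, or start fresh
            # (None) if there are no checkpoints yet. get_last_checkpoint
            # requires an existing directory, hence the isdir guard.
            self.resume_from_checkpoint = (
                get_last_checkpoint(self.output_dir)
                if os.path.isdir(self.output_dir)
                else None
            )
        elif isinstance(flag, str) and flag.lower() == "false":
            # Explicitly start from scratch.
            self.resume_from_checkpoint = False
        # Any other string is kept as an explicit checkpoint path.
```

With that normalization in place, train() could pass training_args.resume_from_checkpoint straight through to trainer.train().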


trainer.train(resume_from_checkpoint)

return trainer
