Faster intermediate checkpoints with DCP async save in TorchTune #2006

Merged: 9 commits, Dec 13, 2024
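For context, a minimal sketch of the DCP async-save pattern this PR builds on, using torch.distributed.checkpoint.async_save from recent PyTorch releases. The toy model and path are placeholders, and the recipes themselves go through the new CheckpointClient rather than calling DCP directly:

```python
import torch
import torch.nn as nn
import torch.distributed.checkpoint as dcp

# Toy model/optimizer purely for illustration; the recipe checkpoints the
# FSDP-sharded model and its optimizer. A torch.distributed process group is
# assumed to already be initialized in a real run.
model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters())
state_dict = {"model": model.state_dict(), "optim": optimizer.state_dict()}

# async_save stages the state dict and writes it in the background, returning a
# Future, so the training loop only blocks for the staging copy.
save_future = dcp.async_save(state_dict, checkpoint_id="/tmp/ckpt/epoch_0")

# ... training continues while the checkpoint is persisted ...

# Wait for the previous async save before starting the next one (or at shutdown).
save_future.result()
```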
157 changes: 50 additions & 107 deletions recipes/full_finetune_distributed.py
@@ -26,6 +26,10 @@
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training.activations import apply_selective_activation_checkpointing
from torchtune.training.checkpointing._checkpoint_client import (
CheckpointClient,
TrainingProgress,
)
from torchtune.training.lr_schedulers import get_lr

from tqdm import tqdm
@@ -138,9 +142,11 @@ def __init__(self, cfg: DictConfig) -> None:

# Training cfg
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
self._clip_grad_norm = cfg.get("clip_grad_norm", None)
self._checkpoint_client = CheckpointClient(cfg)

# Optimizer in backward is not compatible with gradient accumulation or gradient clipping
if self._optimizer_in_bwd:
@@ -189,21 +195,6 @@ def __init__(self, cfg: DictConfig) -> None:
self.max_steps_per_epoch = cfg.max_steps_per_epoch
self.global_step = 0

def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
Extract the checkpoint state from file and validate. If resume_from_checkpoint
is True, this also includes the recipe state.
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()

if self._resume_from_checkpoint:
self._update_recipe_state(checkpoint_dict)
return checkpoint_dict

def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
Contributor:

Let's make this a utility to reduce the footprint of checkpointing in the recipe. Wdyt?

Contributor Author:

I thought about it. The challenge is that it would be a slightly bigger refactor: the recipe keys are defined in the training module, which causes a circular dependency. We would need to refactor that out of training first and then create a util for this method. I will leave it as it is for now since it's an existing method in the recipe.
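For illustration only, a rough sketch of what such a utility might look like once the keys no longer live in the training module (hypothetical constants and function, not part of this PR):

```python
from typing import Any, Dict

# Hypothetical key constants; in torchtune these currently live in the training
# module (training.SEED_KEY, training.EPOCHS_KEY, ...), which is what creates
# the circular-dependency concern discussed above.
SEED_KEY = "seed"
EPOCHS_KEY = "epochs_run"
TOTAL_EPOCHS_KEY = "total_epochs"
MAX_STEPS_KEY = "max_steps_per_epoch"


def update_recipe_state(recipe, ckpt_dict: Dict[str, Any]) -> None:
    """Restore training-progress bookkeeping on ``recipe`` from a checkpoint dict."""
    recipe.seed = ckpt_dict[SEED_KEY]
    recipe.epochs_run = ckpt_dict[EPOCHS_KEY]
    recipe.total_epochs = ckpt_dict[TOTAL_EPOCHS_KEY]
    recipe.max_steps_per_epoch = ckpt_dict[MAX_STEPS_KEY]
```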

"""
Updates the recipe state from checkpoint.
@@ -255,7 +246,8 @@ def setup(self, cfg: DictConfig) -> None:
# log config with parameter override
self._metric_logger.log_config(cfg)

checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
# Load the base model
checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

self._compile = cfg.get("compile", False)
self._model = self._setup_model(
@@ -276,11 +268,36 @@
optimizer_in_bwd=self._optimizer_in_bwd,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
if self._resume_from_checkpoint
if training.OPT_KEY in checkpoint_dict
else None
),
)

if self._resume_from_checkpoint:
# If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
# using the DistributedCheckpointer.
# Therefore the recipe needs to load the distributed checkpoint to restore the training
# progress.
if self._enable_async_checkpointing:
try:
checkpoint_dict = (
self._checkpoint_client.load_distributed_checkpoint(
self._model,
(
self._optim_ckpt_wrapper
if self._optimizer_in_bwd
else self._optimizer
),
)
)
except Exception as e:
log.warning(
f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
)

# Update the recipe state from the checkpoint state dict.
self._update_recipe_state(checkpoint_dict)

# initialize loss
self._loss_fn = config.instantiate(cfg.loss)

@@ -547,6 +564,7 @@ def _setup_model(
log,
f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
)

if self._is_rank_zero:
memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats)
@@ -661,95 +679,6 @@ def _setup_data(

return sampler, dataloader

def save_checkpoint(
self,
epoch: int,
) -> None:
"""
Checkpoint the state of the recipe. The constructed checkpoint state dict
contains the following information:
- Model weights with key training.MODEL_KEY
- Relevant recipe state if training is not complete

Checkpointer will save the model weights and recipe state in
different checkpoint files. To correctly resume training from an intermediate checkpoint,
the model weights and recipe state must be provided.
"""
# final dict passed onto the checkpointer
checkpoint_dict = {}

intermediate_checkpoint = epoch + 1 < self.total_epochs

utils.log_rank_zero(
log,
"Saving checkpoint. This may take some time. Retrieving full model state dict...",
)
start = time.perf_counter()

# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._is_rank_zero,
device=self._device,
)

utils.log_rank_zero(
log,
f"Getting full model state dict took {time.perf_counter() - start:.2f} secs",
)

if intermediate_checkpoint:
start = time.perf_counter()
utils.log_rank_zero(log, "Getting optimizer state dict...")
if not self._optimizer_in_bwd:
opt_state_dict = training.get_full_optimizer_state_dict(
self._optimizer,
self._is_rank_zero,
device=self._device,
)
else:
opt_state_dict = {}
for param, opt in self._optim_ckpt_wrapper.optim_map.items():
opt_state_dict[param] = training.get_full_optimizer_state_dict(
opt, self._is_rank_zero, device=self._device
)
utils.log_rank_zero(
log,
f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs",
)
else:
opt_state_dict = None

# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file

if self._is_rank_zero:
start = time.perf_counter()
checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict})

# if training is in-progress, checkpoint the optimizer state and recipe state
# as well.
if intermediate_checkpoint:
checkpoint_dict.update(
{
training.OPT_KEY: opt_state_dict,
training.SEED_KEY: self.seed,
training.EPOCHS_KEY: self.epochs_run,
training.TOTAL_EPOCHS_KEY: self.total_epochs,
training.MAX_STEPS_KEY: self.max_steps_per_epoch,
}
)

self._checkpointer.save_checkpoint(
checkpoint_dict,
epoch=epoch,
intermediate_checkpoint=intermediate_checkpoint,
)
log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")

torch.distributed.barrier()

def train(self) -> None:
"""
The core training loop.
@@ -922,7 +851,21 @@ def train(self) -> None:
self._profiler.step()

self.epochs_run += 1
self.save_checkpoint(epoch=curr_epoch)
self._checkpoint_client.save_checkpoint(
model=self._model,
optimizer=(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper
),
training_progress=TrainingProgress(
seed=self.seed,
epochs_run=self.epochs_run,
total_epochs=self.total_epochs,
max_steps_per_epoch=self.max_steps_per_epoch,
),
epoch=curr_epoch,
)

self._profiler.stop()

2 changes: 1 addition & 1 deletion recipes/full_finetune_single_device.py
@@ -197,7 +197,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
Contributor:

This will break every config or system that may be using this function. Although I agree it's a better name, I don't think it's worth changing in this PR. A possibly better alternative is to keep both, handle it in the class, and raise a deprecation warning. We should check what others think.

Contributor Author (@saumishr, Dec 2, 2024):

If users have custom recipes that use resume_from_checkpoint, then it may be better to deprecate resume_from_checkpoint and introduce should_load_recipe_state; otherwise it would be a backward-incompatible change. However, if that's not the case, it seems okay to me to update all of our recipes, which is low risk since there is no functionality change. This PR already updates those.

cc @ebsmothers who had this comment in the earlier version of the PR.
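For illustration, a rough sketch of the "keep both and deprecate" option discussed above (hypothetical class and parameter handling, not what was merged):

```python
import warnings
from typing import Optional


class MyCheckpointer:  # stand-in for any torchtune checkpointer class; sketch only
    def __init__(
        self,
        *,
        should_load_recipe_state: bool = False,
        resume_from_checkpoint: Optional[bool] = None,  # deprecated alias
    ) -> None:
        if resume_from_checkpoint is not None:
            warnings.warn(
                "'resume_from_checkpoint' is deprecated; use "
                "'should_load_recipe_state' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            should_load_recipe_state = resume_from_checkpoint
        self._should_load_recipe_state = should_load_recipe_state
```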

)
checkpoint_dict = self._checkpointer.load_checkpoint()

2 changes: 1 addition & 1 deletion recipes/knowledge_distillation_distributed.py
@@ -149,7 +149,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()

2 changes: 1 addition & 1 deletion recipes/knowledge_distillation_single_device.py
@@ -147,7 +147,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()

2 changes: 1 addition & 1 deletion recipes/lora_dpo_distributed.py
@@ -188,7 +188,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()

2 changes: 1 addition & 1 deletion recipes/lora_dpo_single_device.py
@@ -145,7 +145,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()

2 changes: 1 addition & 1 deletion recipes/lora_finetune_distributed.py
@@ -197,7 +197,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()

2 changes: 1 addition & 1 deletion recipes/lora_finetune_single_device.py
@@ -188,7 +188,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()

8 changes: 4 additions & 4 deletions recipes/ppo_full_finetune_single_device.py
@@ -377,22 +377,22 @@ def _setup_checkpointers(

policy_checkpointer = config.instantiate(
policy_cfg,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
)

ref_policy_checkpointer = config.instantiate(
ref_policy_cfg,
resume_from_checkpoint=False,
should_load_recipe_state=False,
)

value_checkpointer = config.instantiate(
value_cfg,
resume_from_checkpoint=False,
should_load_recipe_state=False,
)

reward_checkpointer = config.instantiate(
reward_cfg,
resume_from_checkpoint=False,
should_load_recipe_state=False,
)

return (
2 changes: 1 addition & 1 deletion recipes/qat_distributed.py
@@ -209,7 +209,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()

2 changes: 1 addition & 1 deletion recipes/qat_lora_finetune_distributed.py
@@ -213,7 +213,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()
