2024-12-14 nightly release (c2c6f4a)
pytorchbot committed Dec 14, 2024
1 parent c56f08e commit a8e5a25
Showing 30 changed files with 1,247 additions and 227 deletions.
1 change: 0 additions & 1 deletion docs/source/api_ref_training.rst
@@ -52,7 +52,6 @@ Utilities for enabling and working with distributed training.

init_distributed
is_distributed
get_world_size_and_rank
gather_cpu_state_dict

.. _ac_label:
1 change: 1 addition & 0 deletions docs/source/api_ref_utilities.rst
@@ -18,3 +18,4 @@ Miscellaneous
get_device
get_logger
torch_version_ge
get_world_size_and_rank
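
The two docs edits above track an API move: get_world_size_and_rank is dropped from the torchtune.training reference and listed under the torchtune.utils utilities instead. A minimal sketch of the updated call site, assuming the helper is now exposed as torchtune.utils.get_world_size_and_rank (as the recipe changes below suggest):

# Minimal sketch; assumes get_world_size_and_rank now lives in torchtune.utils
# and that the old torchtune.training entry point is removed, per this commit.
from torchtune import utils

world_size, rank = utils.get_world_size_and_rank()
is_rank_zero = rank == 0  # rank 0 typically owns logging and checkpoint writes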
6 changes: 3 additions & 3 deletions recipes/dev/early_exit_finetune_distributed.py
@@ -183,7 +183,7 @@ def __init__(self, cfg: DictConfig) -> None:

# _is_rank_zero is used primarily for logging. In the future, the logger
# should directly take care of this
_, rank = training.get_world_size_and_rank()
_, rank = utils.get_world_size_and_rank()
self._is_rank_zero = rank == 0

# Training cfg
@@ -646,7 +646,7 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
@@ -826,7 +826,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

# zero out the gradients before starting training
if not self._optimizer_in_bwd:
163 changes: 53 additions & 110 deletions recipes/full_finetune_distributed.py
@@ -26,6 +26,10 @@
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training.activations import apply_selective_activation_checkpointing
from torchtune.training.checkpointing._checkpoint_client import (
CheckpointClient,
TrainingProgress,
)
from torchtune.training.lr_schedulers import get_lr

from tqdm import tqdm
@@ -133,14 +137,16 @@ def __init__(self, cfg: DictConfig) -> None:
)
self._log_peak_memory_stats = False

_, rank = training.get_world_size_and_rank()
_, rank = utils.get_world_size_and_rank()
self._is_rank_zero = rank == 0

# Training cfg
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
self._clip_grad_norm = cfg.get("clip_grad_norm", None)
self._checkpoint_client = CheckpointClient(cfg)

# Optimizer in backward is not compatible with gradient accumulation or gradient clipping
if self._optimizer_in_bwd:
@@ -189,21 +195,6 @@ def __init__(self, cfg: DictConfig) -> None:
self.max_steps_per_epoch = cfg.max_steps_per_epoch
self.global_step = 0

def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
Extract the checkpoint state from file and validate. If resume_from_checkpoint
is True, this also includes the recipe state.
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()

if self._resume_from_checkpoint:
self._update_recipe_state(checkpoint_dict)
return checkpoint_dict

def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
"""
Updates the recipe state from checkpoint.
@@ -255,7 +246,8 @@ def setup(self, cfg: DictConfig) -> None:
# log config with parameter override
self._metric_logger.log_config(cfg)

checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
# Load the base model
checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

self._compile = cfg.get("compile", False)
self._model = self._setup_model(
@@ -276,11 +268,36 @@
optimizer_in_bwd=self._optimizer_in_bwd,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
if self._resume_from_checkpoint
if training.OPT_KEY in checkpoint_dict
else None
),
)

if self._resume_from_checkpoint:
# If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
# using the DistributedCheckpointer.
# Therefore the recipe needs to load the distributed checkpoint to restore the training
# progress.
if self._enable_async_checkpointing:
try:
checkpoint_dict = (
self._checkpoint_client.load_distributed_checkpoint(
self._model,
(
self._optim_ckpt_wrapper
if self._optimizer_in_bwd
else self._optimizer
),
)
)
except Exception as e:
log.warning(
f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
)

# Update the recipe state from the checkpoint state dict.
self._update_recipe_state(checkpoint_dict)

# initialize loss
self._loss_fn = config.instantiate(cfg.loss)

@@ -547,6 +564,7 @@ def _setup_model(
log,
f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
)

if self._is_rank_zero:
memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats)
@@ -619,7 +637,7 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
@@ -661,103 +679,14 @@

return sampler, dataloader

def save_checkpoint(
self,
epoch: int,
) -> None:
"""
Checkpoint the state of the recipe. The constructed checkpoint state dict
contains the following information:
- Model weights with key training.MODEL_KEY
- Relevant recipe state if training is not complete
Checkpointer will save the model weights and recipe state in
different checkpoint files. To correctly resume training from an intermediate checkpoint,
the model weights and recipe state must be provided.
"""
# final dict passed onto the checkpointer
checkpoint_dict = {}

intermediate_checkpoint = epoch + 1 < self.total_epochs

utils.log_rank_zero(
log,
"Saving checkpoint. This may take some time. Retrieving full model state dict...",
)
start = time.perf_counter()

# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._is_rank_zero,
device=self._device,
)

utils.log_rank_zero(
log,
f"Getting full model state dict took {time.perf_counter() - start:.2f} secs",
)

if intermediate_checkpoint:
start = time.perf_counter()
utils.log_rank_zero(log, "Getting optimizer state dict...")
if not self._optimizer_in_bwd:
opt_state_dict = training.get_full_optimizer_state_dict(
self._optimizer,
self._is_rank_zero,
device=self._device,
)
else:
opt_state_dict = {}
for param, opt in self._optim_ckpt_wrapper.optim_map.items():
opt_state_dict[param] = training.get_full_optimizer_state_dict(
opt, self._is_rank_zero, device=self._device
)
utils.log_rank_zero(
log,
f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs",
)
else:
opt_state_dict = None

# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file

if self._is_rank_zero:
start = time.perf_counter()
checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict})

# if training is in-progress, checkpoint the optimizer state and recipe state
# as well.
if intermediate_checkpoint:
checkpoint_dict.update(
{
training.OPT_KEY: opt_state_dict,
training.SEED_KEY: self.seed,
training.EPOCHS_KEY: self.epochs_run,
training.TOTAL_EPOCHS_KEY: self.total_epochs,
training.MAX_STEPS_KEY: self.max_steps_per_epoch,
}
)

self._checkpointer.save_checkpoint(
checkpoint_dict,
epoch=epoch,
intermediate_checkpoint=intermediate_checkpoint,
)
log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")

torch.distributed.barrier()

def train(self) -> None:
"""
The core training loop.
"""
# clean up before training begins
training.cleanup_before_training()

world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

# zero out the gradients before starting training
if not self._optimizer_in_bwd:
@@ -922,7 +851,21 @@ def train(self) -> None:
self._profiler.step()

self.epochs_run += 1
self.save_checkpoint(epoch=curr_epoch)
self._checkpoint_client.save_checkpoint(
model=self._model,
optimizer=(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper
),
training_progress=TrainingProgress(
seed=self.seed,
epochs_run=self.epochs_run,
total_epochs=self.total_epochs,
max_steps_per_epoch=self.max_steps_per_epoch,
),
epoch=curr_epoch,
)

self._profiler.stop()

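Taken together, the full_finetune_distributed.py changes drop the recipe's hand-rolled load_checkpoint and save_checkpoint methods in favor of the new CheckpointClient. The sketch below condenses the resulting flow using only the calls that appear in this diff; the wrapper function and its parameter names are illustrative placeholders, not the recipe's own structure.

# Condensed sketch of the CheckpointClient flow introduced above; the wrapper
# function and its parameters are illustrative, not part of the recipe.
from omegaconf import DictConfig
from torchtune.training.checkpointing._checkpoint_client import (
    CheckpointClient,
    TrainingProgress,
)

def checkpoint_flow(cfg: DictConfig, model, optimizer, seed, epochs_run,
                    total_epochs, max_steps_per_epoch, curr_epoch) -> None:
    client = CheckpointClient(cfg)
    # Base model weights (plus recipe state, if the checkpoint carries it).
    checkpoint_dict = client.load_base_checkpoint()

    # When resuming with async checkpointing enabled, training progress is
    # restored from the DistributedCheckpointer instead:
    #   checkpoint_dict = client.load_distributed_checkpoint(model, optimizer)

    # ... set up the model and optimizer from checkpoint_dict, run an epoch ...

    client.save_checkpoint(
        model=model,
        optimizer=optimizer,  # or the optimizer-in-backward wrapper
        training_progress=TrainingProgress(
            seed=seed,
            epochs_run=epochs_run,
            total_epochs=total_epochs,
            max_steps_per_epoch=max_steps_per_epoch,
        ),
        epoch=curr_epoch,
    )

This mirrors the setup() and train() hunks above: load_base_checkpoint replaces the recipe's load_checkpoint, and client.save_checkpoint replaces the long save_checkpoint method that previously gathered full model and optimizer state dicts on CPU for rank 0.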
2 changes: 1 addition & 1 deletion recipes/full_finetune_single_device.py
@@ -197,7 +197,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()

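This file and the recipes below all make the same one-line change: the recipe's resume flag is forwarded to the checkpointer under the renamed keyword should_load_recipe_state instead of resume_from_checkpoint. A small sketch of the pattern, pulled into a standalone helper for illustration (the helper itself is not part of the diff):

# Illustrative helper; only the should_load_recipe_state keyword is new here,
# the instantiate/load_checkpoint calls mirror the recipe code shown above.
from omegaconf import DictConfig
from torchtune import config

def setup_checkpointer(cfg_checkpointer: DictConfig, resume_from_checkpoint: bool):
    checkpointer = config.instantiate(
        cfg_checkpointer,
        should_load_recipe_state=resume_from_checkpoint,
    )
    # Recipe state (seed, epochs run, etc.) is only loaded when resuming.
    return checkpointer, checkpointer.load_checkpoint()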
8 changes: 4 additions & 4 deletions recipes/knowledge_distillation_distributed.py
@@ -116,7 +116,7 @@ def __init__(self, cfg: DictConfig) -> None:
"fp16 precision is not supported in this recipe. Please use fp32 or bf16."
)

_, rank = training.get_world_size_and_rank()
_, rank = utils.get_world_size_and_rank()

self._is_rank_zero = rank == 0

@@ -149,7 +149,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()

@@ -646,7 +646,7 @@ def _setup_data(
Map-style Datasets which fit into memory and an option for random shuffling.
Samplers, iterable datasets, and streaming datasets are not supported.
"""
world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
@@ -815,7 +815,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
2 changes: 1 addition & 1 deletion recipes/knowledge_distillation_single_device.py
@@ -147,7 +147,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()

8 changes: 4 additions & 4 deletions recipes/lora_dpo_distributed.py
@@ -131,7 +131,7 @@ def __init__(self, cfg: DictConfig) -> None:
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)

_, rank = training.get_world_size_and_rank()
_, rank = utils.get_world_size_and_rank()

self._is_rank_zero = rank == 0

@@ -188,7 +188,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()

@@ -492,7 +492,7 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
world_size, rank = training.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

if isinstance(cfg_dataset, ListConfig):
datasets = [
@@ -642,7 +642,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

_, rank = training.get_world_size_and_rank()
_, rank = utils.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
2 changes: 1 addition & 1 deletion recipes/lora_dpo_single_device.py
@@ -145,7 +145,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
should_load_recipe_state=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()
