Support Optimizer-in-the-backward (#1833)
mori360 authored Oct 23, 2024
1 parent b02825a commit dc0591c
Showing 6 changed files with 164 additions and 43 deletions.
1 change: 1 addition & 0 deletions docs/source/api_ref_training.rst
@@ -86,6 +86,7 @@ Utilities to control lr during the training process.
:nosignatures:

get_cosine_schedule_with_warmup
get_lr

.. _metric_logging_label:

121 changes: 92 additions & 29 deletions recipes/full_finetune_distributed.py
@@ -26,6 +26,7 @@
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training.activations import apply_selective_activation_checkpointing
from torchtune.training.lr_schedulers import get_lr

from tqdm import tqdm

@@ -129,6 +130,7 @@ def __init__(self, cfg: DictConfig) -> None:
# Training cfg
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)

# These are public properties which are updated by the checkpoint loader
# when ``resume_from_checkpoint`` is `True` or validated in tests
@@ -222,6 +224,7 @@ def setup(self, cfg: DictConfig) -> None:

self._optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
optimizer_in_bwd=self._optimizer_in_bwd,
opt_state_dict=checkpoint_dict[training.OPT_KEY]
if self._resume_from_checkpoint
else None,
@@ -456,19 +459,57 @@ def _is_layer_fqn(s: str) -> bool:
return model

def _setup_optimizer(
self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
) -> Optimizer:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
optimizer,
opt_state_dict,
self._device,
self,
cfg_optimizer: DictConfig,
optimizer_in_bwd: bool = False,
opt_state_dict: Optional[Dict[str, Any]] = None,
) -> Optional[Optimizer]:
if optimizer_in_bwd:
# Maintain a dict of optims for every parameter.
optim_dict = {
param: config.instantiate(cfg_optimizer, [param])
for param in self._model.parameters()
}

# Register optimizer step hooks on the model to run optimizer in backward.
training.register_optim_in_bwd_hooks(
model=self._model, optim_dict=optim_dict
)
# Create a wrapper for checkpoint save/load of optimizer states when running in backward.
self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
model=self._model, optim_dict=optim_dict
)
# Load optimizer states for each param. When restoring into an optimizer-in-backward run,
# the states must have been saved with optimizer-in-backward enabled; checkpoints from runs
# that did not use it cannot be restored here.
if opt_state_dict is not None:
for param in opt_state_dict.keys():
try:
training.load_from_full_optimizer_state_dict(
self._optim_ckpt_wrapper.state_dict()[param],
opt_state_dict[param],
self._device,
)
except BaseException as e:
raise RuntimeError(
"Failed loading in-backward optimizer checkpoints."
"Please make sure run being restored from was using in-backward optimizer."
) from e
if self._is_rank_zero:
log.info("In-backward optimizers are set up.")
return None
else:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
optimizer,
opt_state_dict,
self._device,
)

if self._is_rank_zero:
log.info("Optimizer is initialized.")
return optimizer
if self._is_rank_zero:
log.info("Optimizer is initialized.")
return optimizer

def _setup_data(
self,
@@ -509,13 +550,15 @@ def _setup_data(
sampler=sampler,
# dropping last avoids shape issues with compile + flex attention
drop_last=True,
collate_fn=partial(
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else padded_collate_packed,
collate_fn=(
partial(
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else padded_collate_packed
),
)

if self._is_rank_zero:
@@ -550,18 +593,24 @@ def save_checkpoint(
)

if intermediate_checkpoint:
opt_state_dict = training.get_full_optimizer_state_dict(
self._optimizer,
self._is_rank_zero,
device=self._device,
)
if not self._optimizer_in_bwd:
opt_state_dict = training.get_full_optimizer_state_dict(
self._optimizer,
self._is_rank_zero,
device=self._device,
)
else:
opt_state_dict = {}
for param, opt in self._optim_ckpt_wrapper.optim_map.items():
opt_state_dict[param] = training.get_full_optimizer_state_dict(
opt, self._is_rank_zero, device=self._device
)
else:
opt_state_dict = None

# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file
if self._is_rank_zero:

checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict})

# if training is in-progress, checkpoint the optimizer state and recipe state
Expand Down Expand Up @@ -593,7 +642,11 @@ def train(self) -> None:
_, rank = training.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
if not self._optimizer_in_bwd:
self._optimizer.zero_grad()
else:
for opt in self._optim_ckpt_wrapper.optim_map.values():
opt.zero_grad()

# Initialize tokens count and running loss (for grad accumulation)
t0 = time.perf_counter()
@@ -603,7 +656,6 @@
self._profiler.start()
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):

# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
self._sampler.set_epoch(curr_epoch)
@@ -657,12 +709,17 @@
# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
if self._clip_grad_norm is not None:
if self._optimizer_in_bwd:
raise NotImplementedError(
"Gradient clipping is not supported after optimizer-in-the-backward."
)
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
)
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
if not self._optimizer_in_bwd:
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Update the number of steps when the weights are updated
self.global_step += 1
@@ -681,7 +738,13 @@
time_per_step = time.perf_counter() - t0
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"lr": get_lr(
(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper
),
),
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
if self._log_peak_memory_stats:
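For context on the pattern this recipe change adopts, the sketch below shows the optimizer-in-backward idea in plain PyTorch, using Tensor.register_post_accumulate_grad_hook (available since PyTorch 2.1) in place of torchtune's register_optim_in_bwd_hooks helper; the tiny model, learning rate, and function names are illustrative assumptions, not part of this diff.

import torch

model = torch.nn.Linear(16, 4)

# One optimizer per parameter, mirroring the optim_dict built in _setup_optimizer above.
optim_dict = {p: torch.optim.AdamW([p], lr=2e-5) for p in model.parameters()}

def optim_step_in_backward(param: torch.Tensor) -> None:
    # Runs right after the gradient for `param` has been accumulated:
    # step that parameter's optimizer, then free the gradient immediately.
    optim_dict[param].step()
    optim_dict[param].zero_grad(set_to_none=True)

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optim_step_in_backward)

# Parameters are updated during backward(); no separate optimizer.step() is needed,
# which is why the recipe skips step()/zero_grad() when optimizer_in_bwd=True.
loss = model(torch.randn(8, 16)).sum()
loss.backward()

This is also why the recipe raises NotImplementedError for clip_grad_norm in this mode: by the time gradient clipping would run, the gradients have already been applied and freed.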
25 changes: 14 additions & 11 deletions recipes/full_finetune_single_device.py
@@ -23,6 +23,7 @@
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training.lr_schedulers import get_lr

from tqdm import tqdm

@@ -517,13 +518,15 @@ def _setup_data(
sampler=sampler,
# dropping last avoids shape issues with compile + flex attention
drop_last=True,
collate_fn=partial(
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else padded_collate_packed,
collate_fn=(
partial(
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else padded_collate_packed
),
)

log.info("Dataset and Sampler are initialized.")
@@ -658,10 +661,10 @@ def train(self) -> None:
"loss": loss_to_log,
# NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently
# true since we don't expose the ability to configure this yet.
"lr": (
self._optim_ckpt_wrapper.get_optim_key("lr")
if self._optimizer_in_bwd
else self._optimizer.param_groups[0]["lr"]
"lr": get_lr(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper,
),
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
18 changes: 16 additions & 2 deletions tests/recipes/test_full_finetune_distributed.py
@@ -41,7 +41,6 @@ def _get_test_config_overrides(self):
"optimizer=torch.optim.AdamW",
"optimizer.lr=2e-5",
"log_every_n_steps=1",
"clip_grad_norm=100",
] + dummy_alpaca_dataset_config()

def _fetch_expected_loss_values(self, model_type):
@@ -60,9 +59,17 @@ def _fetch_expected_loss_values(self, model_type):
("llama3/8B_full", "llama3", "tune", "NO_SHARD"),
],
)
@pytest.mark.parametrize("optim_in_bwd", [True, False])
@gpu_test(gpu_count=2)
def test_loss(
self, config, model_type, ckpt_type, fsdp_sharding_strategy, tmpdir, monkeypatch
self,
config,
model_type,
ckpt_type,
fsdp_sharding_strategy,
optim_in_bwd,
tmpdir,
monkeypatch,
):
ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
ckpt = model_type + "_" + ckpt_type
@@ -91,6 +98,13 @@ def test_loss(
cmd.append(f"fsdp_sharding_strategy={fsdp_sharding_strategy}")
model_config = MODEL_TEST_CONFIGS[model_type]
cmd = cmd + self._get_test_config_overrides() + model_config
# "optimizer_in_bwd=True" would free gradient info before clip_grad, causing
# wrong grad_norm, so we only test one of them each time. But loss values
# should be the same.
if not optim_in_bwd:
cmd.append("clip_grad_norm=100")
else:
cmd.append("optimizer_in_bwd=True")

monkeypatch.setattr(sys, "argv", cmd)
runpy.run_path(TUNE_PATH, run_name="__main__")
3 changes: 2 additions & 1 deletion torchtune/training/__init__.py
@@ -51,7 +51,7 @@
TOTAL_EPOCHS_KEY,
update_state_dict_for_classifier,
)
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup, get_lr
from torchtune.training.memory import (
cleanup_before_training,
create_optim_in_bwd_wrapper,
@@ -93,6 +93,7 @@
"TOTAL_EPOCHS_KEY",
"get_quantizer_mode",
"get_cosine_schedule_with_warmup",
"get_lr",
"cleanup_before_training",
"create_optim_in_bwd_wrapper",
"get_memory_stats",
39 changes: 39 additions & 0 deletions torchtune/training/lr_schedulers.py
@@ -5,9 +5,11 @@
# LICENSE file in the root directory of this source tree.

import math
from typing import Union

import torch
from torch.optim.lr_scheduler import LambdaLR
from torchtune.training.memory import OptimizerInBackwardWrapper


def get_cosine_schedule_with_warmup(
@@ -54,3 +56,40 @@ def lr_lambda(current_step: int) -> float:
return max(0.0, cosine_lr_multiple)

return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_lr(
optimizer: Union[torch.optim.Optimizer, OptimizerInBackwardWrapper]
) -> float:
"""
Full_finetune_distributed and full_finetune_single_device assume all optimizers have
the same LR, here to validate whether all the LR are the same and return if True.
Args:
optimizer (Union[torch.optim.Optimizer, OptimizerInBackwardWrapper]): A general
optimizer input that could whether be a general optimizer or an optimizer
warpper based on optimizer_in_backward.
Returns:
lr (float): The learning rate of the input optimizers.
Raises:
RuntimeError: If the learning rates of the input optimizer are not the same.
"""
if isinstance(optimizer, OptimizerInBackwardWrapper):
param_groups = []
for param in optimizer.state_dict().values():
param_groups.append(param["param_groups"][0])
else:
param_groups = optimizer.param_groups
if len(param_groups) < 1:
raise RuntimeError(
f"Invalid optimizer param groups with len of: {len(param_groups)}"
)

# LR Schedulers are the same across all param groups for full_finetune right now
lr = param_groups[0]["lr"]
for group in param_groups:
if group["lr"] != lr:
raise RuntimeError("LR Schedulers are different across all param groups ")
return lr
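
A minimal usage sketch for the new helper, assuming only what this diff shows: get_lr returns the shared learning rate whether it is given a plain optimizer or the in-backward wrapper, which is what lets both recipes log "lr" with a single call.

import torch
from torchtune.training.lr_schedulers import get_lr

optimizer = torch.optim.AdamW(torch.nn.Linear(4, 4).parameters(), lr=2e-5)
assert get_lr(optimizer) == 2e-5  # single param group, so its LR is returned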
