
Commit

address comments, some optim in bwd cleanup
ebsmothers committed Oct 31, 2024
1 parent 83cba27 commit 408e521
Showing 3 changed files with 45 additions and 33 deletions.
32 changes: 17 additions & 15 deletions recipes/full_finetune_distributed.py
@@ -151,12 +151,20 @@ def __init__(self, cfg: DictConfig) -> None:
        self._resume_from_checkpoint = cfg.resume_from_checkpoint
        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
        self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)

-        if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd:
-            raise RuntimeError(
-                "Gradient accumulation is not supported with optimizer in bwd."
-                "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
-            )
+        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
+        if self._optimizer_in_bwd:
+            if self._clip_grad_norm is not None:
+                raise RuntimeError(
+                    "Gradient clipping is not supported with optimizer in bwd."
+                    "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
+                )
+            if self._gradient_accumulation_steps > 1:
+                raise RuntimeError(
+                    "Gradient accumulation is not supported with optimizer in bwd."
+                    "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
+                )

        # activation checkpointing/offloading
        self._enable_activation_checkpointing = cfg.get(
@@ -187,7 +195,6 @@ def __init__(self, cfg: DictConfig) -> None:
        self.total_epochs = cfg.epochs
        self.max_steps_per_epoch = cfg.max_steps_per_epoch
        self.global_step = 0
-        self._clip_grad_norm = cfg.get("clip_grad_norm", None)

    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
        """
@@ -796,16 +803,11 @@ def train(self) -> None:
                    torch.distributed.all_reduce(running_loss)
                    # Manually scale the gradients from unnormalized loss by total # of tokens
                    training.scale_grads(self._model, 1 / num_tokens)
-                    if self._clip_grad_norm is not None:
-                        if self._optimizer_in_bwd:
-                            raise NotImplementedError(
-                                "Gradient clipping is not supported after optimizer-in-the-backward."
-                            )
-                        grad_norm = torch.nn.utils.clip_grad_norm_(
-                            self._model.parameters(),
-                            max_norm=float(self._clip_grad_norm),
-                        )
+                    if self._clip_grad_norm is not None:
+                        grad_norm = torch.nn.utils.clip_grad_norm_(
+                            self._model.parameters(),
+                            max_norm=float(self._clip_grad_norm),
+                        )
                    if not self._optimizer_in_bwd:
                        self._optimizer.step()
                        self._optimizer.zero_grad(set_to_none=True)

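Taken together, the distributed-recipe hunks above do two things: the gradient-clipping check moves out of the training loop and into __init__ alongside the existing gradient-accumulation check, so an incompatible config fails at construction time, and the optimizer step in train() becomes "scale grads, optionally clip, then step only if the optimizer did not already run in backward". The following is a minimal standalone sketch of that flow, not the recipe itself; validate_cfg, optimizer_step, and the plain parameter loop are illustrative stand-ins for the recipe's __init__ checks and torchtune's training.scale_grads helper.

import torch
import torch.nn as nn

def validate_cfg(optimizer_in_bwd, gradient_accumulation_steps, clip_grad_norm):
    # Mirrors the new __init__-time checks: fail fast on incompatible flags.
    if optimizer_in_bwd:
        if clip_grad_norm is not None:
            raise RuntimeError("Gradient clipping is not supported with optimizer in bwd.")
        if gradient_accumulation_steps > 1:
            raise RuntimeError("Gradient accumulation is not supported with optimizer in bwd.")

def optimizer_step(model, optimizer, num_tokens, clip_grad_norm, optimizer_in_bwd):
    # Gradients were accumulated from an unnormalized (summed) loss; rescale by token count.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(1 / num_tokens)
    # Clip only when gradients are still present, i.e. not optimizer-in-backward.
    if clip_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float(clip_grad_norm))
    # With optimizer-in-backward the step already happened during backward, so skip it here.
    if not optimizer_in_bwd:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

if __name__ == "__main__":
    model = nn.Linear(8, 2)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
    validate_cfg(optimizer_in_bwd=False, gradient_accumulation_steps=1, clip_grad_norm=1.0)
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer_step(model, optim, num_tokens=4, clip_grad_norm=1.0, optimizer_in_bwd=False)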
35 changes: 20 additions & 15 deletions recipes/full_finetune_single_device.py
@@ -141,6 +141,20 @@ def __init__(self, cfg: DictConfig) -> None:
        self._resume_from_checkpoint = cfg.resume_from_checkpoint
        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
        self._optimizer_in_bwd = cfg.optimizer_in_bwd
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+
+        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
+        if self._optimizer_in_bwd:
+            if self._clip_grad_norm is not None:
+                raise RuntimeError(
+                    "Gradient clipping is not supported with optimizer in bwd."
+                    "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
+                )
+            if self._gradient_accumulation_steps > 1:
+                raise RuntimeError(
+                    "Gradient accumulation is not supported with optimizer in bwd."
+                    "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
+                )

        # activation checkpointing/offloading
        self._enable_activation_checkpointing = cfg.get(
@@ -164,22 +178,13 @@ def __init__(self, cfg: DictConfig) -> None:
"Enabling activation offloading should reduce memory further."
)

# TODO: find a better place / way to perform validation of args that don't yet
# compose with each other.
if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd:
raise RuntimeError(
"Gradient accumulation is not supported with optimizer in bwd."
"Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
)

# These are public properties which are updated by the checkpoint loader
# when ``resume_from_checkpoint`` is `True` or validated in tests
self.seed = training.set_seed(seed=cfg.seed)
self.epochs_run = 0
self.total_epochs = cfg.epochs
self.max_steps_per_epoch = cfg.max_steps_per_epoch
self.global_step = 0
self._clip_grad_norm = cfg.get("clip_grad_norm", None)

def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
@@ -692,13 +697,13 @@ def train(self) -> None:

            # Step with optimizer
            if (idx + 1) % self._gradient_accumulation_steps == 0:
-                training.scale_grads(self._model, 1 / num_tokens)
-                if self._clip_grad_norm is not None:
-                    grad_norm = torch.nn.utils.clip_grad_norm_(
-                        self._model.parameters(),
-                        max_norm=float(self._clip_grad_norm),
-                    )
                if not self._optimizer_in_bwd:
+                    training.scale_grads(self._model, 1 / num_tokens)
+                    if self._clip_grad_norm is not None:
+                        grad_norm = torch.nn.utils.clip_grad_norm_(
+                            self._model.parameters(),
+                            max_norm=float(self._clip_grad_norm),
+                        )
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)

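The constraint both recipes now enforce, and the reason the single-device loop moves scaling and clipping inside the not-optimizer_in_bwd branch, is that with the optimizer fused into the backward pass each parameter's optimizer runs inside a gradient hook and the gradient is freed immediately afterwards, so a later clip_grad_norm_ would find no gradients and report a meaningless norm. The sketch below illustrates that behavior with PyTorch's register_post_accumulate_grad_hook; it is a simplified stand-in for the idea, not torchtune's actual optimizer-in-backward implementation.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# One optimizer per parameter, stepped as soon as that parameter's gradient is ready.
optims = {p: torch.optim.SGD([p], lr=0.1) for p in model.parameters()}

def make_hook(opt):
    def hook(param):
        opt.step()
        opt.zero_grad()  # frees param.grad (set_to_none=True by default)
    return hook

for p in model.parameters():
    p.register_post_accumulate_grad_hook(make_hook(optims[p]))

loss = model(torch.randn(3, 4)).sum()
loss.backward()

# By the time the training loop regains control there are no gradients left,
# so calling clip_grad_norm_ here could not do anything useful.
print([p.grad for p in model.parameters()])  # -> [None, None]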
11 changes: 8 additions & 3 deletions tests/recipes/test_full_finetune_single_device.py
@@ -46,7 +46,6 @@ def _get_test_config_overrides(self):
"lr_scheduler.num_warmup_steps=0",
"lr_scheduler.num_cycles=0",
"log_every_n_steps=1",
"clip_grad_norm=100",
] + dummy_alpaca_dataset_config()

def _fetch_expected_loss_values(self, model_type):
@@ -94,7 +93,6 @@ def test_loss(
            --config {config} \
            batch_size={micro_batch_size} \
            gradient_accumulation_steps={gradient_accumulation_steps} \
-            optimizer_in_bwd={optimizer_in_bwd} \
            output_dir={tmpdir} \
            checkpointer._component_={ckpt_component} \
            checkpointer.checkpoint_dir='{ckpt_dir}' \
@@ -109,7 +107,14 @@

        model_config = MODEL_TEST_CONFIGS[model_type]
        cmd = cmd + self._get_test_config_overrides() + model_config
-
+        # "optimizer_in_bwd=True" would free gradient info before clip_grad, causing
+        # wrong grad_norm, so we only test one of them each time. But loss values
+        # should be the same.
+        if not optimizer_in_bwd:
+            cmd.append("clip_grad_norm=100")
+            cmd.append("optimizer_in_bwd=False")
+        else:
+            cmd.append("optimizer_in_bwd=True")
        monkeypatch.setattr(sys, "argv", cmd)
        with pytest.raises(SystemExit, match=""):
            runpy.run_path(TUNE_PATH, run_name="__main__")
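The test change follows from the same constraint: clip_grad_norm=100 can no longer sit in the shared config overrides, because some of the parametrized runs set optimizer_in_bwd=True, so it is appended only for the optimizer_in_bwd=False runs. A trimmed-down, hypothetical illustration of that selection logic (not the repository's actual test) could look like:

import pytest

@pytest.mark.parametrize("optimizer_in_bwd", [False, True])
def test_override_selection(optimizer_in_bwd):
    # Build the override list the way the updated test does: gradient clipping is
    # only exercised when the optimizer step is not fused into the backward pass.
    cmd = ["batch_size=8", "gradient_accumulation_steps=1"]
    if not optimizer_in_bwd:
        cmd.append("clip_grad_norm=100")
        cmd.append("optimizer_in_bwd=False")
    else:
        cmd.append("optimizer_in_bwd=True")

    assert ("clip_grad_norm=100" in cmd) == (not optimizer_in_bwd)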
