Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gradient scaling to account for world_size normalization #2172

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions recipes/full_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,9 @@ def train(self) -> None:
if self._optimizer_in_bwd:
torch.distributed.all_reduce(num_tokens)
torch.distributed.all_reduce(running_loss)
current_loss = current_loss / num_tokens

# We multiply by world_size to undo FSDP2 gradient normalization.
current_loss = current_loss * (world_size / num_tokens)

current_loss.backward()

Expand All @@ -777,7 +779,8 @@ def train(self) -> None:
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, world_size / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand Down
4 changes: 2 additions & 2 deletions recipes/knowledge_distillation_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,6 @@ def save_checkpoint(self, epoch: int) -> None:
def _loss_step(
self, batch: Dict[str, torch.Tensor]
) -> (torch.Tensor, torch.Tensor):

# Both are shape [b, s]
tokens, labels = batch["tokens"], batch["labels"]

Expand Down Expand Up @@ -876,7 +875,8 @@ def train(self) -> None:
torch.distributed.all_reduce(running_class_loss)
torch.distributed.all_reduce(running_kd_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, world_size / num_tokens)
class_loss_to_log = running_class_loss.item() / num_tokens
kd_loss_to_log = running_kd_loss.item() / num_tokens
self._optimizer.step()
Expand Down
3 changes: 2 additions & 1 deletion recipes/lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,8 @@ def train(self) -> None:
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, world_size / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand Down
3 changes: 2 additions & 1 deletion recipes/lora_finetune_distributed_multi_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,8 @@ def train(self) -> None:
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, world_size / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand Down
7 changes: 5 additions & 2 deletions recipes/qat_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,9 @@ def train(self) -> None:
if self._optimizer_in_bwd:
torch.distributed.all_reduce(num_tokens)
torch.distributed.all_reduce(running_loss)
current_loss = current_loss / num_tokens

# We multiply by world_size to undo FSDP2 gradient normalization.
current_loss = current_loss * (world_size / num_tokens)

current_loss.backward()

Expand All @@ -841,7 +843,8 @@ def train(self) -> None:
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, world_size / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand Down
3 changes: 2 additions & 1 deletion recipes/qat_lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,8 @@ def train(self) -> None:
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, world_size / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand Down
Loading