Restore backward after each batch for grad accum #1917

Merged
merged 10 commits into from
Oct 31, 2024
10 changes: 6 additions & 4 deletions recipes/full_finetune_distributed.py
@@ -697,15 +697,17 @@ def train(self) -> None:
# Compute loss
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_fn(logits, labels) * current_num_tokens
current_loss = self._loss_fn(logits, labels) * current_num_tokens
Contributor

@pbontrager pbontrager Oct 31, 2024

If there was ever an issue with numerical stability, another option for scaling the loss would be:

if grad_accumulation_step == 0:
	base_num_tokens = current_num_tokens
	torch.distributed.broadcast(base_num_tokens, src=0)

current_loss = loss_fn(logits, labels) * current_num_tokens / base_num_tokens

This might overcomplicate things, but I wanted to leave this here in case it turns out in the future that a reduced gradient/loss is necessary for smaller dtypes.

Contributor Author

Agreed, we can take a look at this in the future
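
A minimal sketch of the arithmetic behind the suggestion above, not code from this PR. It runs in a single process with made-up numbers, so the `torch.distributed.broadcast` step is left out, and the final correction factor `base_num_tokens / num_tokens` is an assumption about how the alternative would be normalized at step time.

```python
# Made-up mean losses and token counts for four accumulation micro-batches.
mean_losses = [2.5, 1.5, 3.0, 2.0]
token_counts = [4000, 6000, 2000, 8000]

base_num_tokens = token_counts[0]  # token count of the first micro-batch
num_tokens = sum(token_counts)

# Approach in the diff: accumulate mean_loss * tokens, normalize once at the end.
unscaled = [l * n for l, n in zip(mean_losses, token_counts)]

# Suggested alternative: divide each term by base_num_tokens so every
# per-batch value stays on the order of a mean loss.
rescaled = [l * n / base_num_tokens for l, n in zip(mean_losses, token_counts)]

final_unscaled = sum(unscaled) / num_tokens
# Assumed final correction for the alternative (not shown in the PR).
final_rescaled = sum(rescaled) * base_num_tokens / num_tokens

print(max(unscaled), max(rescaled))    # largest per-batch magnitude, each variant
print(final_unscaled, final_rescaled)  # both normalize to the same value
```

Both variants end at the same token-normalized loss; only the size of the intermediate values differs, which is the part that could matter for low-precision dtypes.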


# free logits otherwise it peaks backward memory
del logits

running_loss += current_loss
felipemello1 marked this conversation as resolved.
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
if self._optimizer_in_bwd:
raise NotImplementedError(
@@ -722,7 +724,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
Contributor Author

Should probably normalize by local_num_tokens?

Contributor Author

Update: I am probably gonna keep it like this since it should be representative of the loss we are actually using to step (even though it means our loss curves will look slightly different than they do today)

Contributor

I think it makes sense. Will it break all the regression tests, though?
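
To make the two normalization options in this thread concrete, here is a small sketch with invented numbers. The two ranks, their loss sums and their token counts are all hypothetical, and the global variant assumes the loss sums are combined across ranks, which this diff does not show.

```python
# Hypothetical per-rank totals for one optimizer step:
# rank -> (sum of per-token losses, number of tokens).
per_rank = {0: (90.0, 50), 1: (30.0, 10)}

# Option raised in the question: each rank divides by its own token count.
local_losses = {rank: s / n for rank, (s, n) in per_rank.items()}

# Option matching "the loss we are actually using to step": one value
# normalized by the token count over all ranks (assumed to be pooled).
total_sum = sum(s for s, _ in per_rank.values())
total_tokens = sum(n for _, n in per_rank.values())
global_loss = total_sum / total_tokens

print(local_losses)  # {0: 1.8, 1: 3.0}
print(global_loss)   # 2.0
```

The per-rank values differ from the pooled one whenever ranks see different token counts, which is one way logged loss curves could shift after a change like this.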

pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
9 changes: 5 additions & 4 deletions recipes/full_finetune_single_device.py
@@ -641,12 +641,13 @@ def train(self) -> None:

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_step(batch) * current_num_tokens
current_loss = self._loss_step(batch) * current_num_tokens
running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
@@ -661,7 +662,7 @@ def train(self) -> None:
self._lr_scheduler.step()
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
15 changes: 7 additions & 8 deletions recipes/knowledge_distillation_single_device.py
@@ -704,15 +704,14 @@ def train(self) -> None:
class_loss, kd_loss = self._loss_step(batch)
running_class_loss += class_loss * current_num_tokens
running_kd_loss += kd_loss * current_num_tokens
current_loss = (
1 - self._kd_ratio
) * class_loss + self._kd_ratio * kd_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
class_loss = running_class_loss / num_tokens
kd_loss = running_kd_loss / num_tokens
loss = (
1 - self._kd_ratio
) * class_loss + self._kd_ratio * kd_loss
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
@@ -724,8 +723,8 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

class_loss_to_log = class_loss.item()
kd_loss_to_log = kd_loss.item()
class_loss_to_log = running_class_loss.item() / num_tokens
kd_loss_to_log = running_kd_loss.item() / num_tokens
loss_to_log = (
1 - self._kd_ratio
) * class_loss_to_log + self._kd_ratio * kd_loss_to_log
10 changes: 6 additions & 4 deletions recipes/lora_finetune_distributed.py
@@ -797,15 +797,17 @@ def train(self) -> None:
# Compute loss
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_fn(logits, labels) * current_num_tokens
current_loss = self._loss_fn(logits, labels) * current_num_tokens

# free logits otherwise it peaks backward memory
del logits

running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
@@ -818,7 +820,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
9 changes: 5 additions & 4 deletions recipes/lora_finetune_single_device.py
@@ -694,12 +694,13 @@ def train(self) -> None:

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_step(batch) * current_num_tokens
current_loss = self._loss_step(batch) * current_num_tokens
running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
@@ -711,7 +712,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
11 changes: 7 additions & 4 deletions recipes/qat_distributed.py
@@ -692,22 +692,25 @@ def train(self) -> None:
logits = logits.reshape(-1, logits.size(-1))

# Compute loss
running_loss += self._loss_fn(logits, labels) * current_num_tokens
current_loss = self._loss_fn(logits, labels) * current_num_tokens

# free logits otherwise it peaks backward memory
del logits

running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)

self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
2 changes: 2 additions & 0 deletions torchtune/training/__init__.py
@@ -23,6 +23,7 @@
shard_model,
validate_no_params_on_meta_device,
)
from torchtune.training._grad_scaler import scale_grads
from torchtune.training._profiler import (
DEFAULT_PROFILE_DIR,
DEFAULT_PROFILER_ACTIVITIES,
@@ -132,4 +133,5 @@
"NoOpManager",
"OffloadActivations",
"FormattedCheckpointFiles",
"scale_grads",
]
14 changes: 14 additions & 0 deletions torchtune/training/_grad_scaler.py
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor

Does this really need its own file?

Contributor

No

Contributor Author

Where do you wanna put it then? Otherwise I am gonna copy-paste this in every recipe which is worse imo

# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn


def scale_grads(m: nn.Module, scaler: torch.Tensor) -> None:
    # Multiply every existing gradient in place; the recipes pass scaler = 1 / num_tokens.
    for p in m.parameters():
        if p.grad is not None:
            p.grad *= scaler
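
The following is a minimal sketch, not part of the PR, of why calling `backward()` after each micro-batch and rescaling the gradients once should give the same gradients as the previous single `backward()` on the normalized running loss. The toy `nn.Linear` model, the random data, the `make_model` helper and the local `scale_grads` stand-in (which multiplies by the scaler, matching the `1 / num_tokens` call sites) are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn


def scale_grads(m: nn.Module, scaler: float) -> None:
    # Local stand-in for the helper in this diff: multiply grads in place.
    for p in m.parameters():
        if p.grad is not None:
            p.grad *= scaler


torch.manual_seed(0)
vocab, dim = 11, 8
# Three micro-batches with different token counts (hypothetical toy data).
batches = [
    (torch.randn(n, dim), torch.randint(0, vocab, (n,))) for n in (3, 5, 2)
]
num_tokens = sum(x.shape[0] for x, _ in batches)


def make_model() -> nn.Module:
    torch.manual_seed(1)  # identical initial weights for both variants
    return nn.Linear(dim, vocab)


# Variant A (old flow): accumulate the loss, call backward once at the end.
model_a = make_model()
running_loss = torch.zeros(())
for x, y in batches:
    running_loss = running_loss + F.cross_entropy(model_a(x), y) * x.shape[0]
(running_loss / num_tokens).backward()

# Variant B (new flow): backward after every micro-batch, rescale grads once.
model_b = make_model()
for x, y in batches:
    current_loss = F.cross_entropy(model_b(x), y) * x.shape[0]
    current_loss.backward()  # this micro-batch's graph can be freed right away
scale_grads(model_b, 1 / num_tokens)

# Gradients agree up to floating-point rounding.
for pa, pb in zip(model_a.parameters(), model_b.parameters()):
    assert torch.allclose(pa.grad, pb.grad, atol=1e-5)
```

Because differentiation is linear, the sum of per-batch gradients scaled by `1 / num_tokens` equals the gradient of the normalized sum; the difference is that Variant B does not keep every micro-batch's autograd graph alive until the optimizer step.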