
Fix gradient scaling to account for world_size normalization #2172

Open · wants to merge 3 commits into base: main

Conversation

@mirceamironenco (Contributor) commented Dec 18, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

  • FSDP/FSDP2/DDP normalize the gradients by world_size when performing the all_reduce. For a sequence-processing task where the desired loss is normalized by the total number of non-padded, non-ignored tokens, this world_size normalization needs to be undone. For example, if world_size = 2 and we have two sets A, B of gradient-producing tokens, the total loss we want is (loss(A) + loss(B)) / (|A| + |B|), where the normalization factor 1 / (|A| + |B|) is currently handled by scale_grads:
    training.scale_grads(self._model, 1 / num_tokens)

If A and B are processed on separate data-parallel workers, the current gradients correspond to loss(A) / 2 + loss(B) / 2, so after the normalization above the effective loss becomes (loss(A) + loss(B)) / (2 * (|A| + |B|)). This PR multiplies the scaling factor by world_size so that the extra 1 / world_size introduced by the all_reduce cancels out (see the sketch below).
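A minimal sketch of the corrected scaling, for illustration only: the wrapper function `scale_grads_for_ddp` and its arguments are hypothetical names, `torchtune.training.scale_grads` is the helper referenced above, and the exact placement inside the recipe's training step may differ.

```python
# Sketch only: scale gradients by world_size / num_tokens instead of 1 / num_tokens,
# undoing the implicit division by world_size done by the FSDP/DDP gradient all_reduce.
import torch
import torch.distributed as dist
from torchtune import training


def scale_grads_for_ddp(model: torch.nn.Module, num_tokens: torch.Tensor) -> None:
    """Scale gradients so the effective loss is sum(loss) / total_tokens across ranks."""
    world_size = dist.get_world_size()
    # Sum the per-rank token counts so every worker uses the same global normalizer.
    dist.all_reduce(num_tokens)
    # world_size cancels the 1 / world_size applied by the gradient all_reduce.
    training.scale_grads(model, world_size / num_tokens)
```

The all_reduce on the token count makes every rank divide by the same global number of loss-producing tokens; the world_size factor is the only change this PR introduces relative to the existing 1 / num_tokens scaling.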

I haven't seen very large differences in the loss curves in my preliminary experiments after this change:

[Screenshots: loss curves comparing the world_size-scaled run against the baseline]

Here "world_size" denotes the run where the gradient scaling factor is world_size / num_tokens; the other run uses 1 / num_tokens. The commands to reproduce these plots:

tune run --nproc_per_node 2 full_finetune_distributed --config llama3_2/3B_full metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=llama3.23b_fix metric_logger.name=world_size dataset.packed=True tokenizer.max_seq_len=512 compile=True

tune run --nproc_per_node 2 full_finetune_distributed.py --config configs/llama3_2/3B_full metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=llama3.23b_fix_noprompt metric_logger.name=world_size dataset.packed=True dataset.train_on_input=False tokenizer.max_seq_len=512 compile=True

Someone with more compute budget can probably get a better idea of the effect for larger models.

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Dec 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2172

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 34906b2 with merge base 3518492:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 18, 2024
@ebsmothers (Contributor) commented:

Thanks @mirceamironenco for finding this bug and for making the fix! Apologies for the delay in getting back to it, I wanted to put together a minimal repro to validate this myself (I trust the code pointers, but I like seeing numerical parity). So I put together the following script(s) to convince myself. Can confirm that on identical toy models with identical data we see (grad on N devices) == (grad on single device) / N. Let me run some more experiments to see to what extent this will affect loss curves on larger world sizes. If there is an impact, we should give people an fyi before landing. Will get back to you soon once I run the experiments!
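For reference, a minimal parity check along the same lines (an illustrative sketch, not the script referenced above), runnable on CPU via e.g. `torchrun --nproc_per_node 2 repro.py`:

```python
# Illustrative repro: with a sum-reduced loss over a sharded batch, the DDP-averaged
# gradient equals the single-process full-batch gradient divided by world_size.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main() -> None:
    dist.init_process_group("gloo")  # CPU-friendly backend; launch with torchrun
    rank, world_size = dist.get_rank(), dist.get_world_size()
    torch.manual_seed(0)  # identical weights and data on every rank

    model = torch.nn.Linear(8, 1, bias=False)
    x, y = torch.randn(4 * world_size, 8), torch.randn(4 * world_size, 1)

    # Reference: single-process gradient over the full batch with a sum-reduced loss.
    ((model(x) - y) ** 2).sum().backward()
    ref_grad = model.weight.grad.clone()
    model.weight.grad = None

    # Distributed: each rank sees a disjoint shard; DDP all-reduces (averages) grads.
    ddp_model = DDP(model)
    xs, ys = x[rank::world_size], y[rank::world_size]
    ((ddp_model(xs) - ys) ** 2).sum().backward()
    ddp_grad = model.weight.grad

    # The averaged gradient is the full-batch gradient divided by world_size.
    assert torch.allclose(ddp_grad * world_size, ref_grad, atol=1e-5)
    if rank == 0:
        print(f"world_size={world_size}: ddp_grad * world_size matches single-process grad")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```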
