
A more encompassing fix for offloading + ac #1936

Merged 1 commit into pytorch:main from storage_use_count_fix on Nov 4, 2024

Conversation

@janeyx99 janeyx99 (Contributor) commented Oct 31, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.
Fixes #1867

Changelog

What are the changes made in this PR?

  • Broaden the condition for calling record_stream so that it is triggered whenever the storage use count is higher than the expected baseline (a hedged sketch of the idea follows this list).
  • Add a test case (which fails before this PR) to make sure the issue is fixed. How does the test case work? See the notes I wrote for myself when contriving it:
    temp
  • Note: this uses a private storage use_count API. I'm working on a less sketchy looking API in core, but we don't want to halt this fix due to the progress there. We can always come back and change this to use the newer, nicer API.
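To make the first bullet concrete, here is a minimal, hedged sketch of the broadened check. The function name, the expected_refcount baseline, and the stream argument are illustrative assumptions, not the exact torchtune implementation; only the torch._C._storage_Use_Count call mirrors the PR.

import torch

# Hedged sketch only: names and the baseline are assumptions, not torchtune's
# exact code.
def maybe_record_stream(
    maybe_gpu_tensor: torch.Tensor,
    stream: torch.cuda.Stream,
    expected_refcount: int,
) -> None:
    # Private API used by this PR: returns the use count of the underlying
    # StorageImpl. Materializing the Python storage object below itself
    # contributes one reference (see the review discussion further down).
    storage = maybe_gpu_tensor.untyped_storage()
    storage_refcount = torch._C._storage_Use_Count(storage._cdata)
    # If anything beyond the expected references still points at this storage,
    # the unpacked tensor may be needed after the backward node runs, so tell
    # the CUDA caching allocator which stream consumes it.
    if storage_refcount > expected_refcount:
        maybe_gpu_tensor.record_stream(stream)

record_stream keeps the caching allocator from handing the memory back out until the recorded stream's pending work has finished, which is what prevents the tensor-deletion data race from #1867.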

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Oct 31, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1936

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3c9b33c with merge base f560cbb:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 31, 2024
janeyx99 added a commit to pytorch/pytorch that referenced this pull request Nov 1, 2024

Needed by pytorch/torchtune#1936

In favor over #139109, as exposing an existing API is better than adding a new one (and this enables a more robust fix)
@janeyx99 janeyx99 force-pushed the storage_use_count_fix branch from a326ab7 to 3c9b33c on November 1, 2024 20:54
@janeyx99 janeyx99 marked this pull request as ready for review November 1, 2024 20:55
@soulitzer soulitzer left a comment

Cool!

@codecov-commenter

Codecov Report

Attention: Patch coverage is 77.94118% with 15 lines in your changes missing coverage. Please review.

Project coverage is 68.43%. Comparing base (f560cbb) to head (3c9b33c).

Files with missing lines                                   Patch %   Lines
...s/torchtune/training/test_activation_offloading.py      80.30%    13 Missing ⚠️
torchtune/training/_activation_offloading.py                 0.00%    2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1936      +/-   ##
==========================================
+ Coverage   68.39%   68.43%   +0.04%     
==========================================
  Files         311      311              
  Lines       16901    16967      +66     
==========================================
+ Hits        11560    11612      +52     
- Misses       5341     5355      +14     


# unpacked tensor to exist after the backward node has executed.
storage_refcount = torch._C._storage_Use_Count(
    maybe_gpu_tensor.untyped_storage()._cdata
)

Under what conditions would the refcount not be 1? If maybe_gpu_tensor was gpu already and was used somewhere?

janeyx99 (Contributor Author):

It's usually 2, because calling the Python untyped_storage() creates a storage object that increases the count.
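For illustration, here is a small hedged snippet of that effect (it relies on a private API, so the exact count may vary across PyTorch versions):

import torch

t = torch.randn(4)

# Holding the Python UntypedStorage wrapper adds a reference to the underlying
# StorageImpl, so the reported count is typically 2: one from the tensor and
# one from `storage`.
storage = t.untyped_storage()
print(torch._C._storage_Use_Count(storage._cdata))  # usually 2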

loss = fwd(tensor)
# delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
ctx.fwd_stash = {}
loss.backward()

what are you checking here?

janeyx99 (Contributor Author):

The assert is on line 158.

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Nov 2, 2024
Would be nice to replace the torch._C._storage_Use_Count call in pytorch/torchtune#1936, at least without needing to know about _cdata in OSS code.

Initially keeping it private as Tensor._use_count is also private.

In favor over #139109 in solving the same problem, as exposing an existing API is better than adding a new one (and this enables a more robust fix)

Pull Request resolved: #139426
Approved by: https://github.com/soulitzer
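As a sidenote, the commit message above implies the replacement will mirror Tensor._use_count on the storage object. A hedged sketch of what call sites might eventually look like (the storage method name is an assumption inferred from the message, not confirmed in this thread):

import torch

t = torch.randn(4)
storage = t.untyped_storage()

# Current path used in this PR (requires reaching into the private _cdata):
count_via_cdata = torch._C._storage_Use_Count(storage._cdata)

# Hypothetical newer path hinted at by the commit message (method name assumed):
# count_via_method = storage._use_count()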
@joecummings joecummings (Contributor) left a comment

4/5 dentists prefer "A more encompassing fix for offloading + ac" compared to other leading brands

@felipemello1 felipemello1 (Contributor) left a comment

Thanks for the fix! Quick question, just as a sanity check: were you able to run a model with offloading=True and confirm it worked as before?

@@ -10,6 +10,8 @@
from torch import nn
from torchtune.training import OffloadActivations

NUM_GPU_CYCLES_IN_ONE_SEC = 2000000000 # 2e9 is ~1s worth of GPU cycles

Suggested change:
- NUM_GPU_CYCLES_IN_ONE_SEC = 2000000000 # 2e9 is ~1s worth of GPU cycles
+ NUM_GPU_CYCLES_IN_ONE_SEC = 2_000_000_000 # 2e9 is ~1s worth of GPU cycles
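As context for this constant, here is one plausible, hedged sketch of how a cycle count like this is used in stream-race tests, assuming the private torch.cuda._sleep helper (which enqueues a kernel that spins for roughly the given number of cycles); this is not necessarily the exact mechanism used in the test:

import torch

NUM_GPU_CYCLES_IN_ONE_SEC = 2_000_000_000  # 2e9 is ~1s worth of GPU cycles

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        # Keep the side stream busy for ~1s so that cross-stream reuse of
        # freed memory has a chance to surface as a visible race.
        torch.cuda._sleep(NUM_GPU_CYCLES_IN_ONE_SEC)
    torch.cuda.synchronize()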

janeyx99 (Contributor Author):

Yes, I ran the repro script to confirm that memory usage + correctness were as expected. We also have a super small unit test to ensure correctness. I did not run a recipe though.

@janeyx99 janeyx99 merged commit 9eced21 into pytorch:main Nov 4, 2024
17 checks passed
@janeyx99 janeyx99 deleted the storage_use_count_fix branch November 4, 2024 15:58
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
@ebsmothers ebsmothers mentioned this pull request Nov 26, 2024
44 tasks

Successfully merging this pull request may close these issues.

OffloadActivations(use_streams=True) producing NaN gradients: a tensor deletion data race
7 participants