A more encompassing fix for offloading + ac #1936
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1936
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 3c9b33c with merge base f560cbb.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Needed by pytorch/torchtune#1936 In favor over #139109, as exposing an existing API is better than adding a new one (and this enables a more robust fix) [ghstack-poisoned]
The branch was updated from a326ab7 to 3c9b33c (Compare).
Cool!
Codecov Report
Attention: Patch coverage is

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1936      +/-   ##
==========================================
+ Coverage   68.39%   68.43%   +0.04%
==========================================
  Files         311      311
  Lines       16901    16967      +66
==========================================
+ Hits        11560    11612      +52
- Misses       5341     5355      +14

☔ View full report in Codecov by Sentry.
# unpacked tensor to exist after the backward node has executed.
storage_refcount = torch._C._storage_Use_Count(
    maybe_gpu_tensor.untyped_storage()._cdata
)
Under what conditions would the refcount not be 1? If maybe_gpu_tensor was already on GPU and was used somewhere?
So it's usually 2, because calling the Python untyped_storage() itself increases the count.
loss = fwd(tensor)
# delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
ctx.fwd_stash = {}
loss.backward()
what are you checking here?
The assert is on line 158.
Would be nice to replace the torch._C._storage_Use_Count call in pytorch/torchtune#1936, at least without needing to know about _cdata in OSS code. Initially keeping it private, as Tensor._use_count is also private. In favor over #139109 in solving the same problem, as exposing an existing API is better than adding a new one (and this enables a more robust fix).

Pull Request resolved: #139426
Approved by: https://github.com/soulitzer
4/5 dentists prefer "A more encompassing fix for offloading + ac" compared to other leading brands
Thanks for the fix! QQ: just as a sanity check, were you able to run some model with offloading=True and confirm it worked as before?
@@ -10,6 +10,8 @@
from torch import nn
from torchtune.training import OffloadActivations

NUM_GPU_CYCLES_IN_ONE_SEC = 2000000000  # 2e9 is ~1s worth of GPU cycles
Suggested change:
- NUM_GPU_CYCLES_IN_ONE_SEC = 2000000000  # 2e9 is ~1s worth of GPU cycles
+ NUM_GPU_CYCLES_IN_ONE_SEC = 2_000_000_000  # 2e9 is ~1s worth of GPU cycles
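The suggestion above is purely cosmetic: Python's underscore digit separators (PEP 515) are ignored by the parser and don't change the literal's value. A quick check:

```python
# underscores in numeric literals are ignored by the parser (PEP 515)
assert 2_000_000_000 == 2000000000
# and both match the "2e9" from the comment, as an integer
assert 2_000_000_000 == int(2e9)
```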
Yes, I ran the repro script to confirm that memory usage + correctness were as expected. We also have a super small unit test to ensure correctness. I did not run a recipe though.
Context
What is the purpose of this PR?
Please link to any issues this PR addresses.
Fixes #1867
Changelog
What are the changes made in this PR?
- Changed the condition for calling record_stream to be whenever the storage count is greater.

Test plan
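The record_stream condition change described in the changelog might be sketched roughly as follows (names and the baseline of 2 are assumptions drawn from the review discussion, not torchtune's exact code):

```python
# Baseline refcount for the unpacked tensor's storage: one reference held by
# the tensor itself plus one from the untyped_storage() call (see the review
# thread). Anything higher means another consumer still holds the storage
# after the backward node, so record_stream should be called to keep the
# memory from being reused too early.
EXPECTED_STORAGE_REFCOUNT = 2

def should_record_stream(storage_refcount: int) -> bool:
    return storage_refcount > EXPECTED_STORAGE_REFCOUNT

# only the tensor + the accessor hold the storage: safe, no record_stream
assert not should_record_stream(2)
# an extra holder survives past the backward node: record_stream needed
assert should_record_stream(3)
```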
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks (pre-commit install)
- run unit tests (pytest tests)
- run integration tests (pytest tests -m integration_test)
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.