A more encompassing fix for offloading + ac #1936
@@ -10,6 +10,8 @@
 from torch import nn
 from torchtune.training import OffloadActivations
 
+NUM_GPU_CYCLES_IN_ONE_SEC = 2000000000  # 2e9 is ~1s worth of GPU cycles
+
 
 @gpu_test(gpu_count=1)
 @pytest.mark.parametrize("use_streams", [True, False])
@@ -46,7 +48,8 @@ def test_offloading_is_same_as_without(use_streams) -> None:
 def test_offloading_works_with_view_outputs() -> None:
     """
     This test is quite contrived but tests against a very obscure situation where
-    any of the outputs of a backward node are a view of the unpacked tensor.
+    any of the outputs of a backward node are a view of the unpacked tensor. (See
+    the first line item under Note: [Track views of the unpacked]).
 
     We want to ensure that if an unpacked tensor may be used later that we do not
     free it too early.
@@ -98,7 +101,7 @@ def forward(ctx, activation):
 
         @staticmethod
         def backward(ctx, viewed_activation):
-            torch.cuda._sleep(2000000000)  # 2e9 is ~1s worth of GPU cycles
+            torch.cuda._sleep(NUM_GPU_CYCLES_IN_ONE_SEC)
             return viewed_activation == 1
 
     class InspectEarlierActivation(torch.autograd.Function):
@@ -129,3 +132,96 @@ def fwd(t):
     # delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
     ctx.fwd_stash = {}
     loss_c.backward()
+
+
+@gpu_test(gpu_count=1)
+def test_offloading_works_with_view_ac_cached_buffers() -> None:
+    """
+    Similar to test_offloading_works_with_view_outputs, but for when AC stashes
+    a view of the unpacked tensor. See the second line item under Note: [Track
+    views of the unpacked].
+
+    For details on how the following custom autograd function was contrived,
+    please see the image attached to the PR description in #1936. The visual
+    is more helpful than me trying to write a blob of text here.
+    """
+
+    class A(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, ones):
+            ctx.save_for_backward(ones * 5)  # corruptedly saving 5s
+            return ones
+
+        @staticmethod
+        def backward(ctx, activation_is_ones):
+            fives = ctx.saved_tensors[0]
+            assert torch.all(activation_is_ones)
+            return activation_is_ones
+
+    class B(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, ones):
+            ctx.save_for_backward(ones.clone())
+            return ones.clone()  # important, a view of 1s will be saved in C
+
+        @staticmethod
+        def backward(ctx, activation_is_ones):
+            saved_tensor = ctx.saved_tensors[0]
+            return activation_is_ones.clone()
+
+    class C(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, ones):
+            ctx.save_for_backward(ones.t().t())
+            return ones.clone()
+
+        @staticmethod
+        def backward(ctx, grad):
+            saved_tensor = ctx.saved_tensors[0]
+            return saved_tensor == 1
+
+    class D(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, ones):
+            ctx.save_for_backward(torch.rand_like(ones))
+            return torch.rand_like(ones)
+
+        @staticmethod
+        def backward(ctx, grad):
+            saved_tensor = ctx.saved_tensors[0]
+            torch.cuda._sleep(NUM_GPU_CYCLES_IN_ONE_SEC)
+            return torch.rand_like(grad)
+
+    class E(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, ones):
+            ctx.save_for_backward(torch.rand_like(ones))
+            return torch.rand_like(ones)
+
+        @staticmethod
+        def backward(ctx, grad):
+            # It doesn't matter what E saves, but it needs to save something
+            # just to trigger AC recompute to fill in this tensor.
+            saved_tensor = ctx.saved_tensors[0]
+            return torch.rand_like(grad)
+
+    def checkpointed_region(b):
+        c = C.apply(b)
+        d = D.apply(c)
+        return E.apply(d)
+
+    def fwd(t):
+        a = A.apply(t)
+        b = B.apply(a)
+        e = torch.utils.checkpoint.checkpoint(
+            checkpointed_region, b, use_reentrant=False
+        )
+        return e.sum()
+
+    tensor = torch.ones(256, 1024, device="cuda", requires_grad=True)
+    ctx = OffloadActivations(use_streams=True)
+    with ctx:
+        loss = fwd(tensor)
+    # delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
+    ctx.fwd_stash = {}
+    loss.backward()
Review comment (on the new test): what are you checking here?
Author reply: the assert is line 158.
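For readers unfamiliar with the aliasing that test exercises, here is a small standalone sketch (not part of this PR; the names x, view, and clone are placeholders): operations like .t() produce views that share the original tensor's untyped storage, while .clone() allocates fresh memory. C.forward above saves ones.t().t(), a view of its input, which is the kind of alias that can keep the unpacked tensor's memory in use after the backward node runs.

import torch

# Standalone illustration, not torchtune code: a chain of views still aliases
# the base tensor's storage, whereas a clone does not.
x = torch.ones(256, 1024)
view = x.t().t()
clone = x.clone()

assert view.untyped_storage().data_ptr() == x.untyped_storage().data_ptr()
assert clone.untyped_storage().data_ptr() != x.untyped_storage().data_ptr()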
@@ -289,19 +289,43 @@ def wait_and_del_remaining_references() -> None:
                 # Stash the tensor to keep memory alive until compute stream is complete
                 self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor
 
+                # Note: [Track views of the unpacked]
+                # Why do we get the use count of the unpacked tensor here? We want an
+                # initial count to compare to later, during the post-hook of the
+                # backward node, when we need to decide whether we're allowed to free
+                # the tensor yet. In what obscure cases must we delay freeing the
+                # tensor (and thus call record_stream)?
+                # 1. Any of the outputs of the backward node is a view of the unpacked
+                #    tensor.
+                # 2. In the case that this unpacked tensor will be used in a
+                #    checkpointed region, if one of the recomputed saved tensors ends
+                #    up as a view of the unpacked tensor.
+                # 3. The user abuses the system somehow and manually relies on the
+                #    unpacked tensor to exist after the backward node has executed.
+                storage_refcount = torch._C._storage_Use_Count(
+                    maybe_gpu_tensor.untyped_storage()._cdata
+                )
Review comment: Under what conditions would the refcount not be 1? If …
Author reply: so it's usually 2 because calling the python untyped_storage() increases the count.
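To make that exchange concrete, here is an illustrative snippet (not part of this diff; t, alias, and baseline are placeholder names, and torch._C._storage_Use_Count is a private API taken from the code above). The absolute value is not meaningful on its own, since the temporary UntypedStorage object created by .untyped_storage() already contributes a reference; what the new code relies on is the count growing relative to a baseline whenever some other holder, such as a view, keeps the same storage alive.

import torch

# Illustrative only. The comparison is relative: capture a baseline first,
# then any additional tensor sharing the storage makes the count exceed it.
t = torch.ones(2, 3)
baseline = torch._C._storage_Use_Count(t.untyped_storage()._cdata)

alias = t.view(-1)  # a view holds another reference to the same storage
assert torch._C._storage_Use_Count(t.untyped_storage()._cdata) > baseline

del alias  # dropping the view lets the count fall back to the baseline
assert torch._C._storage_Use_Count(t.untyped_storage()._cdata) == baseline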
+
 
             def hook(outputs, inputs):
                 # create events for the current node inputs/outputs if they were streamed in
                 if brought_back_from_cpu:
-                    # if any of the outputs is a view of the tensor, meaning the tensor might be used later,
-                    # we cannot presume to delete it after only the current node is done! So we use our frenemy,
-                    # record_stream, to ensure the Tensor stays unmessed with until it's done getting used
-                    # in the compute stream (s0 here). Note that the con here is we introduce non-deterministic
-                    # memory usage, but this case should not happen often.
+                    # See Note: [Track views of the unpacked]
+                    # IF any of the outputs is a view of the tensor, OR if a view of
+                    # the tensor has been saved as a part of checkpoint's recompute
+                    # process, OR the user has abusedly incurred a reference on the
+                    # unpacked tensor, THEN the tensor might be used later and we
+                    # cannot presume to delete it after only the current node is
+                    # done! So we use our frenemy, record_stream, to ensure the
+                    # Tensor stays unmessed with until it's done getting used in the
+                    # compute stream (s0 here). Note that the con here is we introduce
+                    # non-deterministic (thus higher) memory usage, but this case
+                    # should not happen often.
                     unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
-                    if any(
-                        o.untyped_storage() is unpacked_tensor.untyped_storage()
-                        for o in outputs
-                        if o is not None
+                    if (
+                        torch._C._storage_Use_Count(
+                            unpacked_tensor.untyped_storage()._cdata
+                        )
+                        > storage_refcount
                     ):
                         unpacked_tensor.record_stream(self.s0)
                     del self.bwd_tensor_stash[unpack_tensor_id]
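Since record_stream does the heavy lifting in the branch above, a minimal standalone sketch of the pattern follows (assumptions: a CUDA device is available; s0, s1, staging, and out are placeholder names, with s0 mirroring the compute stream in the comment). A tensor produced on one stream but still being read on another must not be returned to the caching allocator until the reading stream has caught up.

import torch

# Minimal sketch of the record_stream pattern, not torchtune code.
s0 = torch.cuda.Stream()  # stands in for the compute stream (s0 above)
s1 = torch.cuda.Stream()  # stream that allocated and will "free" the tensor

with torch.cuda.stream(s1):
    staging = torch.ones(1024, 1024, device="cuda")

s0.wait_stream(s1)  # make the producer's work visible to the consumer stream
with torch.cuda.stream(s0):
    out = staging * 2  # s0 reads `staging` asynchronously

# Tell the allocator not to recycle staging's memory until s0 is done with it;
# after this, dropping the last Python reference is safe even if s0 is behind.
staging.record_stream(s0)
del staging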
Author reply (on verification): Yes, I ran the repro script to confirm that memory usage + correctness were as expected. We also have a super small unit test to ensure correctness. I did not run a recipe though.
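For readers who want to reproduce the kind of check described in that reply, a rough sketch follows. This is not the author's repro script; the model, input shapes, and the run helper are placeholders, and only OffloadActivations(use_streams=True) comes from the code above. It runs the same seeded forward/backward with and without offloading and compares gradients and peak memory.

import torch
from torchtune.training import OffloadActivations

def run(offload: bool):
    # Seed so both runs see identical weights and inputs.
    torch.manual_seed(0)
    model = torch.nn.Sequential(
        *[torch.nn.Linear(1024, 1024) for _ in range(4)]
    ).cuda()
    inp = torch.randn(8, 1024, device="cuda")
    torch.cuda.reset_peak_memory_stats()
    if offload:
        with OffloadActivations(use_streams=True):
            loss = model(inp).sum()
    else:
        loss = model(inp).sum()
    loss.backward()
    grads = [p.grad.clone() for p in model.parameters()]
    return torch.cuda.max_memory_allocated(), grads

mem_offload, grads_offload = run(offload=True)
mem_baseline, grads_baseline = run(offload=False)

# Correctness: gradients should match between the two runs. Peak memory can be
# inspected manually; for a model this small the savings may be negligible.
assert all(torch.allclose(a, b) for a, b in zip(grads_offload, grads_baseline))
print(f"peak memory with offload: {mem_offload}, without: {mem_baseline}")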