A more encompassing fix for offloading + ac #1936

Merged 1 commit on Nov 4, 2024
100 changes: 98 additions & 2 deletions tests/torchtune/training/test_activation_offloading.py
@@ -10,6 +10,8 @@
from torch import nn
from torchtune.training import OffloadActivations

NUM_GPU_CYCLES_IN_ONE_SEC = 2000000000 # 2e9 is ~1s worth of GPU cycles
Contributor:
Suggested change:
-NUM_GPU_CYCLES_IN_ONE_SEC = 2000000000 # 2e9 is ~1s worth of GPU cycles
+NUM_GPU_CYCLES_IN_ONE_SEC = 2_000_000_000 # 2e9 is ~1s worth of GPU cycles
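For context, the suggestion only affects readability: Python treats underscores in numeric literals as digit separators (PEP 515), so the two spellings denote the same value. A quick check (illustrative, not part of the PR):

assert 2_000_000_000 == 2000000000 == int(2e9)  # underscores are purely cosmetic (PEP 515)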

Contributor Author:
Yes, I ran the repro script to confirm that memory usage + correctness were as expected. We also have a super small unit test to ensure correctness. I did not run a recipe though.



@gpu_test(gpu_count=1)
@pytest.mark.parametrize("use_streams", [True, False])
@@ -46,7 +48,8 @@ def test_offloading_is_same_as_without(use_streams) -> None:
def test_offloading_works_with_view_outputs() -> None:
    """
    This test is quite contrived but tests against a very obscure situation where
-   any of the outputs of a backward node are a view of the unpacked tensor.
+   any of the outputs of a backward node are a view of the unpacked tensor. (See
+   the first line item under Note: [Track views of the unpacked]).

    We want to ensure that if an unpacked tensor may be used later that we do not
    free it too early.
@@ -98,7 +101,7 @@ def forward(ctx, activation):

        @staticmethod
        def backward(ctx, viewed_activation):
-           torch.cuda._sleep(2000000000)  # 2e9 is ~1s worth of GPU cycles
+           torch.cuda._sleep(NUM_GPU_CYCLES_IN_ONE_SEC)
            return viewed_activation == 1

    class InspectEarlierActivation(torch.autograd.Function):
@@ -129,3 +132,96 @@ def fwd(t):
    # delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
    ctx.fwd_stash = {}
    loss_c.backward()


@gpu_test(gpu_count=1)
def test_offloading_works_with_view_ac_cached_buffers() -> None:
    """
    Similar to test_offloading_works_with_view_outputs, but for when AC stashes
    a view of the unpacked tensor. See the second line item under Note: [Track
    views of the unpacked].

    For details on how the following custom autograd function was contrived,
    please see the image attached to the PR description in #1936. The visual
    is more helpful than me trying to write a blob of text here.
    """

    class A(torch.autograd.Function):
        @staticmethod
        def forward(ctx, ones):
            ctx.save_for_backward(ones * 5)  # corruptedly saving 5s
            return ones

        @staticmethod
        def backward(ctx, activation_is_ones):
            fives = ctx.saved_tensors[0]
            assert torch.all(activation_is_ones)
            return activation_is_ones

    class B(torch.autograd.Function):
        @staticmethod
        def forward(ctx, ones):
            ctx.save_for_backward(ones.clone())
            return ones.clone()  # important, a view of 1s will be saved in C

        @staticmethod
        def backward(ctx, activation_is_ones):
            saved_tensor = ctx.saved_tensors[0]
            return activation_is_ones.clone()

    class C(torch.autograd.Function):
        @staticmethod
        def forward(ctx, ones):
            ctx.save_for_backward(ones.t().t())
            return ones.clone()

        @staticmethod
        def backward(ctx, grad):
            saved_tensor = ctx.saved_tensors[0]
            return saved_tensor == 1

    class D(torch.autograd.Function):
        @staticmethod
        def forward(ctx, ones):
            ctx.save_for_backward(torch.rand_like(ones))
            return torch.rand_like(ones)

        @staticmethod
        def backward(ctx, grad):
            saved_tensor = ctx.saved_tensors[0]
            torch.cuda._sleep(NUM_GPU_CYCLES_IN_ONE_SEC)
            return torch.rand_like(grad)

    class E(torch.autograd.Function):
        @staticmethod
        def forward(ctx, ones):
            ctx.save_for_backward(torch.rand_like(ones))
            return torch.rand_like(ones)

        @staticmethod
        def backward(ctx, grad):
            # It doesn't matter what E saves, but it needs to save something
            # just to trigger AC recompute to fill in this tensor.
            saved_tensor = ctx.saved_tensors[0]
            return torch.rand_like(grad)

    def checkpointed_region(b):
        c = C.apply(b)
        d = D.apply(c)
        return E.apply(d)

    def fwd(t):
        a = A.apply(t)
        b = B.apply(a)
        e = torch.utils.checkpoint.checkpoint(
            checkpointed_region, b, use_reentrant=False
        )
        return e.sum()

    tensor = torch.ones(256, 1024, device="cuda", requires_grad=True)
    ctx = OffloadActivations(use_streams=True)
    with ctx:
        loss = fwd(tensor)
    # delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
    ctx.fwd_stash = {}
    loss.backward()
Comment:
What are you checking here?

Contributor Author:
The assert is on line 158 of the new test file, i.e. the assert torch.all(activation_is_ones) in A.backward above.
42 changes: 33 additions & 9 deletions torchtune/training/_activation_offloading.py
@@ -289,19 +289,43 @@ def wait_and_del_remaining_references() -> None:
# Stash the tensor to keep memory alive until compute stream is complete
self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor

# Note: [Track views of the unpacked]
# Why do we get the use count of the unpacked tensor here? We want an
# initial count to compare to later, during the post-hook of the
# backward node, when we need to decide whether we're allowed to free
# the tensor yet. In what obscure cases must we delay freeing the
# tensor (and thus call record_stream)?
# 1. Any of the outputs of the backward node is a view of the unpacked
# tensor.
# 2. In the case that this unpacked tensor will be used in a
# checkpointed region, if one of the recomputed saved tensors ends
# up as a view of the unpacked tensor.
# 3. The user abuses the system somehow and manually relies on the
# unpacked tensor to exist after the backward node has executed.
storage_refcount = torch._C._storage_Use_Count(
    maybe_gpu_tensor.untyped_storage()._cdata
)
Comment:
Under what conditions would the refcount not be 1? If maybe_gpu_tensor was gpu already and was used somewhere?

Contributor Author:
So it's usually 2, because calling the Python untyped_storage() increases the count.
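For illustration, a minimal standalone sketch (not part of the PR) of the behavior described in this thread, using the same private torch._C._storage_Use_Count helper the diff relies on; the tensor shape is arbitrary:

import torch

t = torch.ones(4)
storage = t.untyped_storage()  # the Python storage object itself holds a reference
baseline = torch._C._storage_Use_Count(storage._cdata)  # typically 2: the tensor + `storage`

view = t.view(2, 2)  # a view shares the same storage, so its use count goes up
assert torch._C._storage_Use_Count(storage._cdata) > baseline  # the condition the post-hook checks for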


def hook(outputs, inputs):
    # create events for the current node inputs/outputs if they were streamed in
    if brought_back_from_cpu:
-       # if any of the outputs is a view of the tensor, meaning the tensor might be used later,
-       # we cannot presume to delete it after only the current node is done! So we use our frenemy,
-       # record_stream, to ensure the Tensor stays unmessed with until it's done getting used
-       # in the compute stream (s0 here). Note that the con here is we introduce non-deterministic
-       # memory usage, but this case should not happen often.
+       # See Note: [Track views of the unpacked]
+       # IF any of the outputs is a view of the tensor, OR if a view of
+       # the tensor has been saved as a part of checkpoint's recompute
+       # process, OR the user has abusedly incurred a reference on the
+       # unpacked tensor, THEN the tensor might be used later and we
+       # cannot presume to delete it after only the current node is
+       # done! So we use our frenemy, record_stream, to ensure the
+       # Tensor stays unmessed with until it's done getting used in the
+       # compute stream (s0 here). Note that the con here is we introduce
+       # non-deterministic (thus higher) memory usage, but this case
+       # should not happen often.
        unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
-       if any(
-           o.untyped_storage() is unpacked_tensor.untyped_storage()
-           for o in outputs
-           if o is not None
+       if (
+           torch._C._storage_Use_Count(
+               unpacked_tensor.untyped_storage()._cdata
+           )
+           > storage_refcount
        ):
            unpacked_tensor.record_stream(self.s0)
            del self.bwd_tensor_stash[unpack_tensor_id]
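For reference, a minimal usage sketch modeled on the tests above (the module, shapes, and loss are illustrative, not from the PR): OffloadActivations is a context manager that offloads activations saved for backward to CPU during the forward pass and, with use_streams=True, brings them back on a side stream during backward.

import torch
from torchtune.training import OffloadActivations

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).cuda()
inputs = torch.randn(8, 1024, device="cuda", requires_grad=True)

with OffloadActivations(use_streams=True):
    loss = model(inputs).sum()
loss.backward()  # as in the tests, backward runs after the forward context exits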