Skip to content

Commit

Permalink
A more encompassing fix for offloading + ac
Browse files Browse the repository at this point in the history
  • Loading branch information
janeyx99 committed Nov 1, 2024
1 parent f560cbb commit 3c9b33c
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 11 deletions.
100 changes: 98 additions & 2 deletions tests/torchtune/training/test_activation_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from torch import nn
from torchtune.training import OffloadActivations

NUM_GPU_CYCLES_IN_ONE_SEC = 2000000000 # 2e9 is ~1s worth of GPU cycles


@gpu_test(gpu_count=1)
@pytest.mark.parametrize("use_streams", [True, False])
Expand Down Expand Up @@ -46,7 +48,8 @@ def test_offloading_is_same_as_without(use_streams) -> None:
def test_offloading_works_with_view_outputs() -> None:
"""
This test is quite contrived but tests against a very obscure situation where
any of the outputs of a backward node are a view of the unpacked tensor.
any of the outputs of a backward node are a view of the unpacked tensor. (See
the first line item under Note: [Track views of the unpacked]).
We want to ensure that if an unpacked tensor may be used later that we do not
free it too early.
Expand Down Expand Up @@ -98,7 +101,7 @@ def forward(ctx, activation):

@staticmethod
def backward(ctx, viewed_activation):
torch.cuda._sleep(2000000000) # 2e9 is ~1s worth of GPU cycles
torch.cuda._sleep(NUM_GPU_CYCLES_IN_ONE_SEC)
return viewed_activation == 1

class InspectEarlierActivation(torch.autograd.Function):
Expand Down Expand Up @@ -129,3 +132,96 @@ def fwd(t):
# delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
ctx.fwd_stash = {}
loss_c.backward()


@gpu_test(gpu_count=1)
def test_offloading_works_with_view_ac_cached_buffers() -> None:
"""
Similar to test_offloading_works_with_view_outputs, but for when AC stashes
a view of the unpacked tensor. See the second line item under Note: [Track
views of the unpacked].
For details on how the following custom autograd function was contrived,
please see the image attached to the PR description in #1936. The visual
is more helpful than me trying to write a blob of text here.
"""

class A(torch.autograd.Function):
@staticmethod
def forward(ctx, ones):
ctx.save_for_backward(ones * 5) # corruptedly saving 5s
return ones

@staticmethod
def backward(ctx, activation_is_ones):
fives = ctx.saved_tensors[0]
assert torch.all(activation_is_ones)
return activation_is_ones

class B(torch.autograd.Function):
@staticmethod
def forward(ctx, ones):
ctx.save_for_backward(ones.clone())
return ones.clone() # important, a view of 1s will be saved in C

@staticmethod
def backward(ctx, activation_is_ones):
saved_tensor = ctx.saved_tensors[0]
return activation_is_ones.clone()

class C(torch.autograd.Function):
@staticmethod
def forward(ctx, ones):
ctx.save_for_backward(ones.t().t())
return ones.clone()

@staticmethod
def backward(ctx, grad):
saved_tensor = ctx.saved_tensors[0]
return saved_tensor == 1

class D(torch.autograd.Function):
@staticmethod
def forward(ctx, ones):
ctx.save_for_backward(torch.rand_like(ones))
return torch.rand_like(ones)

@staticmethod
def backward(ctx, grad):
saved_tensor = ctx.saved_tensors[0]
torch.cuda._sleep(NUM_GPU_CYCLES_IN_ONE_SEC)
return torch.rand_like(grad)

class E(torch.autograd.Function):
@staticmethod
def forward(ctx, ones):
ctx.save_for_backward(torch.rand_like(ones))
return torch.rand_like(ones)

@staticmethod
def backward(ctx, grad):
# It doesn't matter what E saves, but it needs to save something
# just to trigger AC recompute to fill in this tensor.
saved_tensor = ctx.saved_tensors[0]
return torch.rand_like(grad)

def checkpointed_region(b):
c = C.apply(b)
d = D.apply(c)
return E.apply(d)

def fwd(t):
a = A.apply(t)
b = B.apply(a)
e = torch.utils.checkpoint.checkpoint(
checkpointed_region, b, use_reentrant=False
)
return e.sum()

tensor = torch.ones(256, 1024, device="cuda", requires_grad=True)
ctx = OffloadActivations(use_streams=True)
with ctx:
loss = fwd(tensor)
# delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
ctx.fwd_stash = {}
loss.backward()
42 changes: 33 additions & 9 deletions torchtune/training/_activation_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,19 +289,43 @@ def wait_and_del_remaining_references() -> None:
# Stash the tensor to keep memory alive until compute stream is complete
self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor

# Note: [Track views of the unpacked]
# Why do we get the use count of the unpacked tensor here? We want an
# initial count to compare to later, during the post-hook of the
# backward node, when we need to decide whether we're allowed to free
# the tensor yet. In what obscure cases must we delay freeing the
# tensor (and thus call record_stream)?
# 1. Any of the outputs of the backward node is a view of the unpacked
# tensor.
# 2. In the case that this unpacked tensor will be used in a
# checkpointed region, if one of the recomputed saved tensors ends
# up as a view of the unpacked tensor.
# 3. The user abuses the system somehow and manually relies on the
# unpacked tensor to exist after the backward node has executed.
storage_refcount = torch._C._storage_Use_Count(
maybe_gpu_tensor.untyped_storage()._cdata
)

def hook(outputs, inputs):
# create events for the current node inputs/outputs if they were streamed in
if brought_back_from_cpu:
# if any of the outputs is a view of the tensor, meaning the tensor might be used later,
# we cannot presume to delete it after only the current node is done! So we use our frenemy,
# record_stream, to ensure the Tensor stays unmessed with until it's done getting used
# in the compute stream (s0 here). Note that the con here is we introduce non-deterministic
# memory usage, but this case should not happen often.
# See Note: [Track views of the unpacked]
# IF any of the outputs is a view of the tensor, OR if a view of
# the tensor has been saved as a part of checkpoint's recompute
# process, OR the user has abusedly incurred a reference on the
# unpacked tensor, THEN the tensor might be used later and we
# cannot presume to delete it after only the current node is
# done! So we use our frenemy, record_stream, to ensure the
# Tensor stays unmessed with until it's done getting used in the
# compute stream (s0 here). Note that the con here is we introduce
# non-deterministic (thus higher) memory usage, but this case
# should not happen often.
unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
if any(
o.untyped_storage() is unpacked_tensor.untyped_storage()
for o in outputs
if o is not None
if (
torch._C._storage_Use_Count(
unpacked_tensor.untyped_storage()._cdata
)
> storage_refcount
):
unpacked_tensor.record_stream(self.s0)
del self.bwd_tensor_stash[unpack_tensor_id]
Expand Down

0 comments on commit 3c9b33c

Please sign in to comment.