Skip to content

Commit

Permalink
Fix memory leak in TraceEnumELBO (#3131)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored Aug 29, 2022
1 parent d7f6474 commit 7102cf5
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 3 deletions.
5 changes: 4 additions & 1 deletion pyro/ops/einsum/torch_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ def einsum(equation, *operands):

inputs, output = equation.split("->")
if inputs == output:
return operands[0][...] # create a new object
# Originally we return `operands[0][...]` but that caused
# memory leak in PyTorch >= 1.11 (issue #3068). Hence we
# return `operands[0].clone()` here.
return operands[0].clone() # create a new object
inputs = inputs.split(",")

shifts = []
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_pickle(Dist):
try:
dist = Dist(*args)
except Exception:
pytest.skip(msg="cannot construct distribution")
pytest.skip(reason="cannot construct distribution")

buffer = io.BytesIO()
# Note that pickling torch.Size() requires protocol >= 2
Expand Down
2 changes: 1 addition & 1 deletion tests/infer/test_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def model():
num_samples=num_samples,
parallel=parallel,
)
actual = predictive.get_samples()
actual = predictive()
assert set(actual) == set(expected)
assert actual["x"].shape == expected["x"].shape
assert actual["y"].shape == expected["y"].shape
Expand Down
27 changes: 27 additions & 0 deletions tests/ops/einsum/test_adjoint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import gc
import itertools

import pytest
Expand Down Expand Up @@ -114,3 +115,29 @@ def test_marginal(equation):
)
actual = operand._pyro_backward_result
assert_equal(expected, actual)


@pytest.mark.filterwarnings("ignore:.*reduce_op is deprecated")
def test_require_backward_memory_leak():
tensors = [o for o in gc.get_objects() if torch.is_tensor(o)]
num_global_tensors = len(tensors)
del tensors

# Using clone resolves memory leak.
for i in range(10):
x = torch.tensor(0.0)
require_backward(x)
x._pyro_backward.process(x.clone())

tensors = [o for o in gc.get_objects() if torch.is_tensor(o)]
assert len(tensors) <= 5 + num_global_tensors
del tensors

# Using [...] creates memory leak.
for i in range(10):
x = torch.tensor(0.0)
require_backward(x)
x._pyro_backward.process(x[...])

tensors = [o for o in gc.get_objects() if torch.is_tensor(o)]
assert len(tensors) >= 15 + num_global_tensors

0 comments on commit 7102cf5

Please sign in to comment.