Memory leak occurs while using lightning.kokkos device for machine learning [BUG] #6297

Closed
1 task done
JustinS6626 opened this issue Sep 24, 2024 · 28 comments
Labels
bug 🐛 Something isn't working

Comments

@JustinS6626

Expected behavior

I apologize for reposting this issue from the forum https://discuss.pennylane.ai/t/memory-leak-in-when-using-lighning-kokkos-device/5218, but it is a major roadblock for a time-sensitive project. I am reporting a possible bug that may be in the QNode class or in one of the PennyLane devices. It causes a memory leak when the QNode object is called while PyTorch is set to calculate a gradient: instead of releasing the memory once the gradient is calculated, PennyLane/PyTorch keeps holding onto it.

Actual behavior

The memory leak is shown by the profiling tool used in the code example below. The problem was originally spotted in a large-scale quantum machine learning project, where the training process halted early as a result of running out of memory. The output shown is not an actual error, but rather the result of tracking memory usage in the example.

Additional information

No response

Source code

import pennylane as qml
from pennylane import numpy as np
import torch
from matplotlib import pyplot as plt
from memory_profiler import profile

# we can use the dataset hosted on PennyLane
[pm] = qml.data.load('other', name='plus-minus')

X_train = pm.img_train  
X_test = pm.img_test  
Y_train = pm.labels_train 
Y_test = pm.labels_test  


x_vis = [
    (X_train[Y_train == 0])[0],
    (X_train[Y_train == 1])[0],
    (X_train[Y_train == 2])[0],
    (X_train[Y_train == 3])[0],
]
y_vis = [0, 1, 2, 3]



def visualize_data(x, y, pred=None):
    n_img = len(x)
    labels_list = ["\u2212", "\u002b", "\ua714", "\u02e7"]
    fig, axes = plt.subplots(1, 4, figsize=(8, 2))
    for i in range(n_img):
        axes[i].imshow(x[i], cmap="gray")
        if pred is None:
            axes[i].set_title("Label: {}".format(labels_list[y[i]]))
        else:
            axes[i].set_title("Label: {}, Pred: {}".format(labels_list[y[i]], labels_list[pred[i]]))
    plt.tight_layout(w_pad=2)
    # plt.show()


visualize_data(x_vis, y_vis)


input_dim = 256
num_classes = 4
num_layers = 32
num_qubits = 8
num_reup = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


class QML_classifier(torch.nn.Module):
   
    def __init__(self, input_dim, output_dim, num_qubits, num_layers):
        super().__init__()
        torch.manual_seed(1337) 
        self.num_qubits = num_qubits
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.device = qml.device("lightning.kokkos", wires=self.num_qubits)
        self.weights_shape = qml.StronglyEntanglingLayers.shape(
            n_layers=self.num_layers, n_wires=self.num_qubits
        )

        @qml.qnode(self.device)
        def circuit(inputs, weights, bias):
            inputs = torch.reshape(inputs, self.weights_shape)
            qml.StronglyEntanglingLayers(
                weights=weights * inputs + bias, wires=range(self.num_qubits)
            )
            return [qml.expval(qml.PauliZ(i)) for i in range(self.output_dim)]

        param_shapes = {"weights": self.weights_shape, "bias": self.weights_shape}
        init_vals = {
            "weights": 0.1 * torch.rand(self.weights_shape),
            "bias": 0.1 * torch.rand(self.weights_shape),
        }


        self.qcircuit = qml.qnn.TorchLayer(
            qnode=circuit, weight_shapes=param_shapes, init_method=init_vals
        )

    @profile
    def forward(self, x):
        inputs_stack = torch.hstack([x] * num_reup)
        results = self.qcircuit(inputs_stack)
        return results

learning_rate = 0.1
epochs = 5
batch_size = 50

feats_train = torch.from_numpy(X_train[:200]).reshape(200, -1).to(device)
feats_test = torch.from_numpy(X_test[:50]).reshape(50, -1).to(device)
labels_train = torch.from_numpy(Y_train[:200]).to(device)
labels_test = torch.from_numpy(Y_test[:50]).to(device)
num_train = feats_train.shape[0]

# initialize the model, loss function and optimization algorithm (Adam optimizer)
qml_model = QML_classifier(input_dim, num_classes, num_qubits, num_layers)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(qml_model.parameters(), lr=learning_rate)
num_batches = feats_train.shape[0] // batch_size

def accuracy(labels, predictions):
    acc = 0
    for l, p in zip(labels, predictions):
        if torch.argmax(p) == l:
            acc += 1
    acc = acc / len(labels)
    return acc


# generate randomly permutated batches to speed up training
def gen_batches(num_samples, num_batches):
    assert num_samples % num_batches == 0
    perm_ind = torch.reshape(torch.randperm(num_samples), (num_batches, -1))
    return perm_ind


# display accuracy and loss after each epoch (to speed up runtime, only do this for first 100 samples)
def print_acc(epoch, max_ep=5):
    predictions_train = [qml_model(f) for f in feats_train[:50]]
    predictions_test = [qml_model(f) for f in feats_test]
    cost_approx_train = loss(torch.stack(predictions_train), labels_train[:50])
    cost_approx_test = loss(torch.stack(predictions_test), labels_test)
    acc_approx_train = accuracy(labels_train[:50], predictions_train)
    acc_approx_test = accuracy(labels_test, predictions_test)
    print(
        f"Epoch {epoch}/{max_ep} | Approx Cost (train): {cost_approx_train:0.7f} | Cost (val): {cost_approx_test:0.7f} |"
        f" Approx Acc train: {acc_approx_train:0.7f} | Acc val: {acc_approx_test:0.7f}"
    )


print(
    f"Starting training loop for quantum variational classifier ({num_qubits} qubits, {num_layers} layers)..."
)

for ep in range(0, epochs):
    batch_ind = gen_batches(num_train, num_batches)
    print_acc(epoch=ep)

    for it in range(num_batches):
        optimizer.zero_grad()
        feats_train_batch = feats_train[batch_ind[it]]
        labels_train_batch = labels_train[batch_ind[it]]

        outputs = [qml_model(f) for f in feats_train_batch]
        batch_loss = loss(torch.stack(outputs), labels_train_batch)
        # if REG:
        #    loss = loss + lipschitz_regularizer(regularization_rate, model.qcircuit.weights)
        batch_loss.backward()
        optimizer.step()

print_acc(epochs)


# show accuracy
x_vis_torch = torch.from_numpy(np.array(x_vis).reshape(4, -1))
y_vis_torch = torch.from_numpy(np.array(y_vis))
benign_preds = [qml_model(f) for f in x_vis_torch]

benign_class_output = [torch.argmax(p) for p in benign_preds]
visualize_data(x_vis, y_vis, benign_class_output)

Tracebacks

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    689.2 MiB    689.2 MiB           1       @profile
    94                                             def forward(self, x):
    95    693.5 MiB      4.2 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    738.0 MiB     44.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    738.0 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    738.0 MiB    738.0 MiB           1       @profile
    94                                             def forward(self, x):
    95    738.0 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    740.7 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    740.7 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    740.7 MiB    740.7 MiB           1       @profile
    94                                             def forward(self, x):
    95    740.7 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    743.5 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    743.5 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    743.5 MiB    743.5 MiB           1       @profile
    94                                             def forward(self, x):
    95    743.5 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    746.0 MiB      2.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    746.0 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    746.0 MiB    746.0 MiB           1       @profile
    94                                             def forward(self, x):
    95    746.0 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    748.7 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    748.7 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    748.7 MiB    748.7 MiB           1       @profile
    94                                             def forward(self, x):
    95    748.7 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    751.5 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    751.5 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    751.5 MiB    751.5 MiB           1       @profile
    94                                             def forward(self, x):
    95    751.5 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    754.0 MiB      2.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    754.0 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    754.0 MiB    754.0 MiB           1       @profile
    94                                             def forward(self, x):
    95    754.0 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    756.7 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    756.7 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    756.7 MiB    756.7 MiB           1       @profile
    94                                             def forward(self, x):
    95    756.7 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    759.5 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    759.5 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    759.5 MiB    759.5 MiB           1       @profile
    94                                             def forward(self, x):
    95    759.5 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    762.0 MiB      2.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    762.0 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    762.0 MiB    762.0 MiB           1       @profile
    94                                             def forward(self, x):
    95    762.0 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    764.7 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    764.7 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    764.7 MiB    764.7 MiB           1       @profile
    94                                             def forward(self, x):
    95    764.7 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    767.2 MiB      2.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    767.2 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    767.2 MiB    767.2 MiB           1       @profile
    94                                             def forward(self, x):
    95    767.2 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    769.7 MiB      2.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    769.7 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    769.7 MiB    769.7 MiB           1       @profile
    94                                             def forward(self, x):
    95    769.7 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    772.2 MiB      2.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    772.2 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    772.2 MiB    772.2 MiB           1       @profile
    94                                             def forward(self, x):
    95    772.2 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    775.0 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    775.0 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    775.0 MiB    775.0 MiB           1       @profile
    94                                             def forward(self, x):
    95    775.0 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    777.7 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    777.7 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    777.7 MiB    777.7 MiB           1       @profile
    94                                             def forward(self, x):
    95    777.7 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    780.5 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    780.5 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    780.5 MiB    780.5 MiB           1       @profile
    94                                             def forward(self, x):
    95    780.5 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    783.0 MiB      2.5 MiB           1           results = self.qcircuit(inputs_stack)
    97    783.0 MiB      0.0 MiB           1           return results


Filename: /home/owner/quantum_adv_demo.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    93    783.0 MiB    783.0 MiB           1       @profile
    94                                             def forward(self, x):
    95    783.0 MiB      0.0 MiB           1           inputs_stack = torch.hstack([x] * num_reup)
    96    785.7 MiB      2.8 MiB           1           results = self.qcircuit(inputs_stack)
    97    785.7 MiB      0.0 MiB           1           return results

System information

Name: PennyLane
Version: 0.38.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /usr/local/lib/python3.11/dist-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-qiskit, pennylane-qulacs, PennyLane_Lightning, PennyLane_Lightning_GPU, PennyLane_Lightning_Kokkos

Platform info:           Linux-6.8.0-40-generic-x86_64-with-glibc2.35
Python version:          3.11.0
Numpy version:           1.26.3
Scipy version:           1.12.0
Installed devices:
- lightning.kokkos (PennyLane_Lightning_Kokkos-0.38.0)
- qiskit.aer (PennyLane-qiskit-0.37.0)
- qiskit.basicaer (PennyLane-qiskit-0.37.0)
- qiskit.basicsim (PennyLane-qiskit-0.37.0)
- qiskit.ibmq (PennyLane-qiskit-0.37.0)
- qiskit.ibmq.circuit_runner (PennyLane-qiskit-0.37.0)
- qiskit.ibmq.sampler (PennyLane-qiskit-0.37.0)
- qiskit.remote (PennyLane-qiskit-0.37.0)
- default.clifford (PennyLane-0.38.0)
- default.gaussian (PennyLane-0.38.0)
- default.mixed (PennyLane-0.38.0)
- default.qubit (PennyLane-0.38.0)
- default.qubit.autograd (PennyLane-0.38.0)
- default.qubit.jax (PennyLane-0.38.0)
- default.qubit.legacy (PennyLane-0.38.0)
- default.qubit.tf (PennyLane-0.38.0)
- default.qubit.torch (PennyLane-0.38.0)
- default.qutrit (PennyLane-0.38.0)
- default.qutrit.mixed (PennyLane-0.38.0)
- default.tensor (PennyLane-0.38.0)
- null.qubit (PennyLane-0.38.0)
- lightning.qubit (PennyLane_Lightning-0.38.0)
- lightning.gpu (PennyLane_Lightning_GPU-0.35.1)
- qulacs.simulator (pennylane-qulacs-0.36.0)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
JustinS6626 added the bug 🐛 Something isn't working label Sep 24, 2024
@albi3ro
Contributor

albi3ro commented Sep 24, 2024

Thanks for opening this @JustinS6626. Is this memory leak specific to lightning.kokkos, or does it also occur with default.qubit or lightning.qubit? This would help us isolate whether the issue is in the device or in the surrounding infrastructure.

@JustinS6626
Author

Thank you very much for getting back to me! I tried my large-scale model with lightning.qubit for comparison, and the same thing happens with that device as well. The model runs very slowly with default.qubit, so it will take a while to be able to tell, but I could try it if that helps.

@JustinS6626
Author

It looks like the same thing happens when I run the code with default.qubit.

@CatalinaAlbornoz
Contributor

CatalinaAlbornoz commented Sep 24, 2024

Thanks for confirming @JustinS6626

We're looking into this to check whether it's due to the scale of the system or something else. Knowing that it's not only a lightning.kokkos issue helps a lot.

@JustinS6626
Author

Thank you! I tried my code out on PennyLane V0.38.1 and the same thing happens on that version.

@mlxd
Member

mlxd commented Sep 25, 2024

Hi @JustinS6626, thanks again for providing the above example.
After spending some time looking into the provided code, and based on my current view of the workload, I am not sure there is a memory leak. Running locally on my machine, I see the heap memory remain bounded over the course of the workflow, always staying under 6 GB of total usage. What you observe may be Torch caching data, combined with some of the intermediates generated over circuit execution, which (as far as I can see) reach a steady state over time.

Since you are using Torch's CUDA device target, it is known that Torch will cache CUDA data (e.g. see https://stackoverflow.com/questions/55322434/how-to-clear-cuda-memory-in-pytorch and https://discuss.pytorch.org/t/about-torch-cuda-empty-cache/34232 for two representative discussions). It may be worth checking whether freeing these caches gives you back some memory for your workload.
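
For illustration, a minimal sketch of freeing these caches between training iterations (the placement inside your loop is an assumption on my part, not part of your code):

import torch

# between training iterations, after references to intermediate tensors have been dropped
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # return Torch's cached CUDA blocks to the driver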

As you have a somewhat deep circuit with many parameters (32*8*3, as far as I can currently see), this may also take some time (and use additional memory) when handling gradients. Intermediate values and allocations will likely occur during execution, either on the PennyLane side for the quantum part or on the Torch side when handling the backpropagation and dealing with the outputs.

I will attach the output of https://github.com/bloomberg/memray 's live analysis, which shows memory ownership and allocations on the heap for the given workload:
[Screenshot: memray live analysis, 2024-09-25 10-32-05]

You can repeat the above with pip install memray && python -m memray run --live your_script.py, which should give you direct snapshots of what is happening across the stack, rather than just at the single-function level with memory_profiler.

If you think I may be incorrect about the above analysis, feel free to let us know. In that case, some more details about the full scale of your problem, the hardware it runs on, the Torch version you are using, or a smaller problem size that replicates the given report would help us identify the issue.

@JustinS6626
Author

Thank you very much! I tried running my code with memray and I got some results from the analysis. My code is an on-policy reinforcement learning model, so it alternates between a policy sampling phase where gradient calculation is off and an update phase where gradient calculation is on.
Here is a screenshot of memray during the policy sampling phase:
[Screenshot: memray output, 2024-09-25 12-09-09]
and another from the update phase:
[Screenshot: memray output, 2024-09-25 13-14-07]
I will try using the torch.cuda.empty_cache() function, but I am wondering if the results that I got from memray indicate any other solutions.

@mlxd
Member

mlxd commented Sep 25, 2024

Hey @JustinS6626

If you click the own bytes tab in the visualization, it will order the list from largest memory ownership to smallest. Keeping an eye on this may help show where the changes occur --- though expect some of these values to increase and decrease during execution as memory is allocated and freed from the heap. You can also try periodically calling the Python garbage collector after the CUDA empty-cache calls, which may help. Lastly, when zeroing the gradients, you can additionally set them to None, which will release the previously used memory entirely (if this doesn't help, it is best to avoid doing so, as a deallocation and reallocation will occur on the next gradient pass). This may also modify behaviour depending upon how your workload is set up.
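
As a rough, self-contained sketch of the above (the tiny Linear model and loop here are placeholders, not your workload):

import gc
import torch

model = torch.nn.Linear(4, 2)                    # stand-in model; yours is the QML classifier
opt = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

for _ in range(3):                               # stand-in training loop
    opt.zero_grad(set_to_none=True)              # drop gradient tensors entirely rather than zeroing them
    out = model(torch.randn(8, 4))
    loss_fn(out, torch.zeros(8, 2)).backward()
    opt.step()
    gc.collect()                                 # periodically run the Python garbage collector
    if torch.cuda.is_available():
        torch.cuda.empty_cache()                 # then release Torch's cached CUDA memory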

@JustinS6626
Author

Thanks! I tried torch.cuda.empty_cache() along with the garbage collector, which reduced the memory leak to some extent but didn't fix the problem completely. The Optimizer.zero_grad() function already sets the gradients to None by default, and it doesn't have any effect. I wonder if the problem is occurring in how PennyLane connects to PyTorch during gradient calculation. Could that be the source of the issue?

@JustinS6626
Author

Based on the feedback that I got from memray, it looks like the trouble spots are in the PennyLane-PyTorch interface.

@mlxd
Member

mlxd commented Sep 26, 2024

Hi @JustinS6626
Unfortunately, I cannot reproduce the issue locally, as I still see memory use remaining bounded with the provided example when profiling the forward method.

Some suggestions that may be useful to identify the root cause:

  • Can you let me know which version of PyTorch you are using? If it is an older version, have you tried updating to the latest release?
  • If you are running validation checks as part of your loop, you may still be accumulating gradients. It could be worth adding @torch.no_grad() as a function decorator for such cases, as this will prevent anything from being tracked (see https://pytorch.org/docs/stable/generated/torch.no_grad.html).
  • If you remove the quantum circuit execution entirely from your code (as in, just create a dummy function that returns a torch array supporting gradients in place of the circuit execution), does the memory usage grow over time? If so, the issue is likely confined to Torch. If not, then we at least know there is some interaction producing the problem. A short sketch of this and the previous suggestion follows this list.
  • If you can reduce your larger workload to a smaller example that reproduces the issue, that would be a big help, as I cannot see leaking with the current example on my machines (PL 0.38, Torch 2.4.1). Maybe a rewrite of the above script into function calls that can be inspected independently could help here.
  • You mention the PennyLane-PyTorch interface being a source of the issue. Do you have any more information we could use to inspect the problem here?
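
For those two points, a rough sketch of what I mean (evaluate and DummyCircuit are hypothetical names, not taken from your code):

import torch

@torch.no_grad()                    # validation passes run without building an autograd graph
def evaluate(model, feats, labels):
    preds = torch.stack([model(f) for f in feats])
    return (preds.argmax(dim=1) == labels).float().mean().item()

# Hypothetical stand-in for the quantum circuit: a plain differentiable layer,
# used only to check whether the memory growth persists without PennyLane.
class DummyCircuit(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.linear(x)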

@JustinS6626
Author

Thanks for getting back to me again! My PyTorch version is V2.4.1, which I think is the latest stable version. The loss calculation only happens for actual updates, without validation checks. From looking at the memray feedback, it seems that at least part of the root of the problem is in the execute_and_cache_jacobian function of the DeviceDerivatives class in the pennylane.workflow.jacobian_products module. I looked at the code for that function, and from what I can tell, it stores the jacobian as well as the result of the calculation done with it. In the case of my code, it seems there is some instance of the DeviceDerivatives class that is not getting the signal to get rid of this stored data once it's no longer needed. Based on that, I am wondering if there is a way to access all instances of the class manually and tell them to free the memory that they are holding onto once the gradients have been calculated and applied.

@albi3ro
Contributor

albi3ro commented Sep 27, 2024

Thanks for your patience @JustinS6626 .

When you saw the issues with default.qubit, were you using adjoint or backprop? If you manually specified diff_method="adjoint", can you try using the backprop default instead? If the issue is still present with end-to-end torch backpropagation, we can eliminate any adjoint-specific logic as the source of the problem.

Could you also try setting device_vjp=True in the qnode?

@qml.qnode(device, device_vjp=True)
...

That will follow a slightly different logical pathway, and should actually be substantially more efficient for your type of problem (an overall scalar function, with a quantum component that has many observables). If device_vjp=True also shows the issue, we can also eliminate the full adjoint-jacobian logic as the source of the issue.

Each instance of the class should only exist inside a single execution, so I'd be rather concerned if the Python garbage collector were somehow broken and not collecting local variables between function calls. The DeviceDerivatives class just serves to hold onto the jacobian between the forward and backward passes, and should be fully cleaned up by the time the VJPs are registered with torch and the results are returned.

EDIT: Additional idea: If you want to avoid caching the jacobian on the forward pass, you can also set grad_on_execution=False. This way the jacobian won't need to be cached between the forward and backward passes, but we will have to perform additional work on the backward pass.
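
For reference, a minimal sketch of a QNode configured this way (a standalone toy circuit, not your full model):

import pennylane as qml

dev = qml.device("lightning.qubit", wires=8)

# grad_on_execution=False skips caching the jacobian on the forward pass;
# the derivative work is performed on the backward pass instead.
@qml.qnode(dev, diff_method="adjoint", grad_on_execution=False)
def circuit(weights):
    qml.StronglyEntanglingLayers(weights=weights, wires=range(8))
    return [qml.expval(qml.PauliZ(i)) for i in range(4)]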

@JustinS6626
Author

Thanks! I tried it with both the default and adjoint methods, and the memory leak still happens in both cases. I also tried the device_vjp=True setting, but the same thing occurred in that case as well. I will try it with the grad_on_execution=False setting, and I am also running a test with a completely classical version of my code in order to rule out the possibility that it is a PyTorch issue.

@JustinS6626
Author

I tried out the grad_on_execution=False option and the same thing happened again. When I look at the memray outputs, one common observation regardless of which settings I use is that the pennylane.devices._legacy_device and pennylane.devices.legacy_facade modules seem to have a notable impact on memory usage. Does that provide any helpful clues? Also, I apologize for asking, but I am wondering if someone would be willing to keep in touch with me on this thread over the weekend.

@albi3ro
Contributor

albi3ro commented Sep 27, 2024

Given it happens with both default.qubit and lightning.qubit, it would not be a byproduct of the legacy device interface. I would expect the simulation (occurring inside the device) to be one of the most memory-intensive parts of the task. Are you seeing the amount of memory these modules consume growing between iterations of an optimization?

Just to double-check, how many wires are you using? Because if you are running into simulation memory issues with 8 wires, I am even more confused.

@mlxd
Member

mlxd commented Sep 27, 2024

Just as a follow-up, were you able to run the classical part without hitting PennyLane too?
Did the memory grow in this scenario?

@JustinS6626
Author

The same issue did not appear to be happening in the classical version, so I think it is PennyLane-specific. The amount of memory consumed increases steadily between optimization iterations and does not decrease to its original level when the model begins its next update process. The attached file contains the code that is producing the memory leak I am observing. It is a modification of the code published with the following paper: https://proceedings.neurips.cc/paper_files/paper/2022/file/69413f87e5a34897cd010ca698097d0a-Paper-Conference.pdf It should allow the memory leak problem to be duplicated. In order to run it, you will need to install the Gymnasium and Minigrid packages from Farama, as well as imageio. The code can be executed with Multi-Agent-Transformer/mat/scripts/train_minigrid.sh. The main process for gradient calculation and updates is controlled from the file Multi-Agent-Transformer/mat/algorithms/mat/mat_trainer.py. Also, within the latest PennyLane version, there is an error on line 539 of the file pennylane.workflow.execution.py: the LightningVJPs class should not be initialized with a gradient_kwargs argument.
Multi-Agent-Transformer.zip

@mlxd
Member

mlxd commented Sep 27, 2024

Thanks again for your input @JustinS6626 as we recognize this is an important issue for your workload.

At the current time, the best we can do is track this as an item to investigate on our roadmap and provide feedback when possible. Per your comment about availability over the weekend, we likely will not be able to respond again until next week. If you have any insights before then, we can try to help out once we are again free to look into it.

@JustinS6626
Author

Thank you very much! In the meantime, is there anything I can do to implement manual control over the memory use of the simulation process while the gradient is being calculated? Also, if someone would be willing to try running my code to reproduce the issue and see if I have made any mistakes that might trick the simulator into holding onto memory, I would really appreciate it.

@CatalinaAlbornoz
Contributor

Hi @JustinS6626, unfortunately we're not able to test your code at the moment, and since we couldn't replicate your issue with the code you shared earlier, there's not much more we can do.

At this point, we know that it's not a specific device (it happens on default.qubit, lightning.qubit, and lightning.kokkos) and not a specific differentiation method (both backprop and adjoint). Given it still occurs with backprop, it has nothing to do with how we bind jacobians to torch, as we don't do so manually with backprop.

I suggest we keep this issue open as well as the thread you made in the Forum. This way if someone else has either the same issue or a solution, they can post it and we can look into it further.

We've noted the issue and will keep an eye on anything that could give us clues about the memory issues you're seeing.

We'll post here if we find anything.

@JustinS6626
Author

Thanks for getting back to me again! I am trying a workaround right now, and I will let you know what happens.

@JustinS6626
Author

Update: My workaround was successful! It seems that when you implement PennyLane variational quantum circuit layers as PyTorch nn.Module instances and set them up as elements of an nn.Sequential object, the forward pass needs to be called on the nn.Sequential instance itself, not on the individual elements in turn; otherwise it produces stranded data that accumulates in memory during the gradient calculation process.

@CatalinaAlbornoz
Contributor

Thanks for letting us know @JustinS6626 !! And great work finding the workaround!!
Are you able to post the line(s) of code that didn't work and the one(s) that did work so that it's clear for anyone running into the same issue?

@JustinS6626
Author

For sure! I can give an example here:

import torch
from torch import nn

n_qubits = 8
n_layers = 4
# For this example, assume that QuantumBlock is an nn.Module subclass that calls an instance of qml.QNode,
# and that x is an input batch tensor of the appropriate shape.
qblocks = nn.Sequential(*[QuantumBlock(n_qubits) for _ in range(n_layers)])
# Call forward passes like this :)
y = qblocks(x)
# Not like this :(
for block in qblocks:
    x = block(x)


@CatalinaAlbornoz
Contributor

CatalinaAlbornoz commented Oct 10, 2024

Thanks @JustinS6626 ! Good example.

@CatalinaAlbornoz
Contributor

I'll close the issue now that we know it's a Torch issue.

@JustinS6626
Author

Apologies for the delay! I was able to fix the memory leak when I moved the quantum network architecture to the code for a different model, but the issue remains in the original code, and I have learned that it is actually caused by something other than the setup of the nn.Sequential calls. I will update again once I determine the cause.
