[BUG] TorchLayer does not work correctly with broadcasting and tuple returns
#5762
Closed
Labels: bug 🐛 (Something isn't working)
PietropaoloFrisoni added a commit that referenced this issue on Jun 12, 2024:
…uple` returns (#5816)

**Context:** This bug was caught using the following code:

```
import numpy as np
import pennylane as qml
import torch

n_qubits = 2
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev)
def qnode(inputs, weights):
    qml.templates.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.templates.StronglyEntanglingLayers(weights, wires=range(n_qubits))
    # Returning a tuple of measurements triggers the bug
    return qml.expval(qml.Z(0)), qml.expval(qml.Z(1))

weight_shapes = {"weights": [3, n_qubits, 3]}
qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)

x = torch.tensor(np.random.random((5, 2)))  # Batched inputs with batch dim 5 for 2 qubits
qlayer.forward(x)
```

The `forward` function in `pennylane/qnn/torch.py` calls `_evaluate_qnode` under the hood. The issue is that `_evaluate_qnode` only handles tuples whose elements are lists of `torch.Tensor`. In the code above, `_evaluate_qnode` has to handle a tuple of bare `torch.Tensor` objects instead, which causes the error.

**Description of the Change:** We intercept this case and temporarily transform a tuple of `torch.Tensor` into a tuple of lists, so that the code works as originally intended. This is not the only possible solution, but it is a non-invasive one that leaves the original tuple workflow unchanged (changing that workflow is beyond the scope of this PR).

**Benefits:** The code above no longer throws an error.

**Possible Drawbacks:** None that I can think of, since this PR does not change the original workflow; it only fixes an existing bug by intercepting a specific case that can occur.

**Related GitHub Issues:** #5762

**Related Shortcut Stories:** [sc-64283]
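The failure mode and the fix described in the commit can be sketched in plain NumPy. This is an illustrative analogue, not PennyLane's code: NumPy arrays stand in for `torch.Tensor`, and `combine_dimensions` and `evaluate_tuple` are hypothetical helpers modeled on the description of `_evaluate_qnode`. The sketch assumes the internal helper iterates over each tuple element, reshapes every item to `(batch_size, -1)`, and stacks the results column-wise.

```python
import numpy as np

BATCH = 5  # batch dimension from the issue's example

def combine_dimensions(group, batch_size):
    """Hypothetical stand-in for the internal helper: reshape each entry of
    `group` to (batch_size, -1) and stack the results column-wise."""
    return np.hstack([np.reshape(r, (batch_size, -1)) for r in group])

def evaluate_tuple(res, batch_size):
    """Hypothetical: combine a tuple of results, normalizing bare arrays."""
    # The fix as described: if the tuple holds bare arrays rather than
    # lists of arrays, wrap each one in a list before combining.
    if all(isinstance(r, np.ndarray) for r in res):
        res = tuple([r] for r in res)
    return tuple(combine_dimensions(r, batch_size) for r in res)

# Two expectation values, each batched over 5 inputs (shape (5,)).
res = (np.random.random(BATCH), np.random.random(BATCH))

# Without normalization, iterating a bare array yields scalars, and a single
# scalar cannot be reshaped to (5, -1), so NumPy raises ValueError.
try:
    tuple(combine_dimensions(r, BATCH) for r in res)
except ValueError as err:
    print("unnormalized tuple fails:", err)

out = evaluate_tuple(res, BATCH)
print([o.shape for o in out])  # [(5, 1), (5, 1)]
```

Under these assumptions, a list of arrays goes straight through `combine_dimensions` without any wrapping, which is consistent with the issue's observation that returning a list from the QNode works.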
Expected behavior
The code below runs.
Actual behavior
It errors out because of an invalid internal reshape.
Additional information
If we return a list from the QNode instead, the example works fine.
Source code
Tracebacks
System information
Existing GitHub issues