
Bugfix: TorchLayer does not work correctly with broadcasting and tuple returns #5816

Merged: 14 commits merged into master on Jun 12, 2024

Conversation

@PietropaoloFrisoni (Contributor) commented Jun 7, 2024

Context: This bug was caught using the following code:

import numpy as np
import pennylane as qml
import torch

n_qubits = 2
dev = qml.device("default.qubit", wires=n_qubits)


@qml.qnode(dev)
def qnode(inputs, weights):
    qml.templates.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.templates.StronglyEntanglingLayers(weights, wires=range(n_qubits))
    return qml.expval(qml.Z(0)), qml.expval(qml.Z(1))


weight_shapes = {"weights": [3, n_qubits, 3]}

qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)
x = torch.tensor(np.random.random((5, 2)))  # Batched inputs with batch dim 5 for 2 qubits


qlayer.forward(x)

The forward function in pennylane/qnn/torch.py calls the _evaluate_qnode function under the hood. The problem is that _evaluate_qnode only handles tuple results whose entries are lists of torch.Tensor. In the code above, it instead receives a tuple of torch.Tensor (one batched tensor per measurement), which causes the error.
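
For illustration only (the exact internal objects are an assumption based on the description above, not PennyLane's actual representation), the two result structures look roughly like this:

import torch

# Structure the post-processing previously expected: a tuple with one entry per
# measurement, each entry being a list of per-example tensors.
expected = (
    [torch.tensor(0.1), torch.tensor(0.2)],  # expval(Z(0)) for each batch element
    [torch.tensor(0.3), torch.tensor(0.4)],  # expval(Z(1)) for each batch element
)

# Structure produced with parameter broadcasting: a tuple with one entry per
# measurement, each entry already a batched tensor of shape (batch_size,).
produced = (
    torch.tensor([0.1, 0.2]),  # expval(Z(0)), batched
    torch.tensor([0.3, 0.4]),  # expval(Z(1)), batched
)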

Description of the Change: We intercept this case and temporarily transform the tuple of torch.Tensor into a tuple of lists, so that the code works as originally expected. This is clearly not the only possible solution, but it is a non-invasive way to fix the problem without changing the originally intended tuple workflow (which is beyond the scope of this PR).
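
A minimal sketch of the idea, assuming the interception happens on the raw QNode result; the helper name and its placement are hypothetical, and the actual change lives inside _evaluate_qnode:

import torch

def _as_tuple_of_lists(res):
    # Hypothetical helper illustrating the fix: if the QNode returned a tuple of
    # batched tensors, view each tensor as a list of its rows so that the existing
    # tuple-of-lists post-processing applies unchanged.
    if isinstance(res, tuple) and all(isinstance(r, torch.Tensor) and r.ndim > 0 for r in res):
        # list(t) splits t along its first (batch) dimension
        return tuple(list(r) for r in res)
    return res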

Benefits: The above code no longer throws an error.

Possible Drawbacks: None that I can think of, since this PR does not change the original workflow; it only fixes an existing bug by intercepting one specific case that can occur.

Related GitHub Issues: #5762

Related Shortcut Stories: [sc-64283]

@PennyLaneAI deleted a comment from the codecov bot (Jun 11, 2024)
@PietropaoloFrisoni (Contributor, Author) commented:

I don't know why @codecov complains, since the new line is covered by the test implemented in this PR (to check this, just remove the line and verify that the test fails).

@PennyLaneAI deleted a comment from the codecov bot (Jun 11, 2024)
codecov bot commented Jun 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.67%. Comparing base (c6a7cb0) to head (d6049da).
Report is 254 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5816      +/-   ##
==========================================
- Coverage   99.67%   99.67%   -0.01%     
==========================================
  Files         418      420       +2     
  Lines       40080    39788     -292     
==========================================
- Hits        39951    39658     -293     
- Misses        129      130       +1     


@astralcai self-requested a review (June 11, 2024, 18:23)
@astralcai (Contributor) left a comment:

This is fine as a patch, but I wonder if this is going to work with different combinations of parameter broadcasting/no broadcasting, shot vectors/no shot vectors, multiple measurements/single measurement. A more permanent fix would probably be to extract information such as self.qnode.tape.batch_size, self.qnode.tape.shots.has_partitioned_shots and len(self.qnode.tape.measurements) and use this information to decide what to do with the qnode output. See the current implementation of split_non_commuting for an example.
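
A rough sketch of that suggestion; the tape attributes are the ones named in the comment above, but the dispatch logic is only an assumption about what a fuller fix might look like:

import torch

def postprocess(results, tape):
    # Illustrative only, not PennyLane's implementation: use tape metadata,
    # as suggested above, to decide how to assemble the QNode output.
    broadcasted = tape.batch_size is not None
    partitioned_shots = tape.shots.has_partitioned_shots
    multiple_measurements = len(tape.measurements) > 1

    if multiple_measurements and broadcasted and not partitioned_shots:
        # tuple of batched tensors -> single tensor of shape (batch_size, n_measurements)
        return torch.stack([torch.as_tensor(r) for r in results], dim=1)
    # ... the remaining combinations would be handled analogously
    return results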

@PietropaoloFrisoni (Contributor, Author) commented Jun 11, 2024

@astralcai I agree. As mentioned in the PR description, this is just a non-invasive (and temporary) fix for the original issue, and it would be appropriate to revisit the entire workflow at some point (here we simply patch an exception without modifying the original workflow). Unfortunately, I don't have time to implement deeper changes right now. Maybe we can add this to the technical roadmap for next quarter?

@PietropaoloFrisoni marked this pull request as ready for review (June 12, 2024, 11:16)
@dwierichs (Contributor) left a comment:

Thanks @PietropaoloFrisoni for the fast bugfix!
I'd second your viewpoint: it does feel like a hotfix, but this logical branch is already defined in a very specific manner anyway, and I would attribute it to the requirement of interfacing two libraries in TorchLayer.
LGTM!

@astralcai (Contributor) left a comment:

Makes sense, LGTM then 👍🏻

@PietropaoloFrisoni enabled auto-merge (squash) on June 12, 2024 (17:00, and again at 18:17)
@PietropaoloFrisoni merged commit d5f6313 into master on Jun 12, 2024 (40 checks passed)
@PietropaoloFrisoni deleted the bugfix-TorchLayer_tuples branch (June 12, 2024, 19:31)