[BUG] KerasLayer has unintended side effects on its QNode #5723

Closed
isaacdevlugt opened this issue May 21, 2024 · 1 comment · Fixed by #5800

Labels
bug 🐛 Something isn't working
Milestone
v0.37

Comments

@isaacdevlugt
Contributor

isaacdevlugt commented May 21, 2024

https://app.shortcut.com/xanaduai/story/63723/bug-keraslayer-has-unintended-side-effects-on-its-qnode

Expected behavior

Defining a Keras layer from a QNode has no side effects on the QNode itself.

Actual behavior

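Constructing a `KerasLayer` mutates the QNode it wraps: the layer sets the QNode's interface to `"tf"`. If the same QNode is also wrapped in a `TorchLayer`, calling the `TorchLayer` then fails with `ValueError: Tensors contain mixed types; cannot determine dispatch library` (see the traceback below).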
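In the reproduction, the `TorchLayer` and the `KerasLayer` wrap the same QNode object, so an attribute that one layer assigns during construction is visible through the other. The minimal pure-Python sketch below models that shared-state problem. `QNodeStub`, `TorchLayerStub`, and `KerasLayerStub` are hypothetical stand-ins (not PennyLane classes) that only model the pre-fix interface assignment described in the comments on this issue.

```python
class QNodeStub:
    """Hypothetical stand-in for a QNode that only records its interface."""

    def __init__(self, interface="auto"):
        self.interface = interface


class TorchLayerStub:
    """Models the pre-fix TorchLayer: overwrites the wrapped QNode's interface."""

    def __init__(self, qnode):
        self.qnode = qnode
        self.qnode.interface = "torch"  # mutates the caller's QNode


class KerasLayerStub:
    """Models the pre-fix KerasLayer: overwrites the wrapped QNode's interface."""

    def __init__(self, qnode):
        self.qnode = qnode
        self.qnode.interface = "tf"  # mutates the caller's QNode


circuit = QNodeStub()
torch_layer = TorchLayerStub(circuit)  # circuit.interface is now "torch"
keras_layer = KerasLayerStub(circuit)  # circuit.interface is now "tf"

# Both layers hold the same object, so the torch layer now sees "tf".
print(torch_layer.qnode is keras_layer.qnode)  # True
print(torch_layer.qnode.interface)             # tf
```

The order matches the reproduction: the `TorchLayer` is built first, and the later `KerasLayer` construction leaves the shared QNode set to `"tf"`.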

Additional information

No response

Source code

import pennylane as qml
import torch

dev = qml.device('default.qubit')

@qml.qnode(dev)
def circuit(inputs, weights):
    qml.AmplitudeEmbedding(inputs, wires=[0, 1], normalize=True)
    qml.RY(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    return qml.vn_entropy(wires=[1])

weight_shapes = {"weights": (2,)}

qlayer_torch = qml.qnn.TorchLayer(circuit, weight_shapes=weight_shapes)
qlayer_keras = qml.qnn.KerasLayer(circuit, weight_shapes=weight_shapes, output_dim=1)  # side effect: sets circuit.interface = "tf"

inputs = torch.rand(4, requires_grad=False)

clayer = torch.nn.Softmax()
qlayer_torch(clayer(inputs))
model = torch.nn.Sequential(clayer, qlayer_torch)

model(inputs)

Tracebacks

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 21
     18 inputs = torch.rand(4, requires_grad=False)
     20 clayer = torch.nn.Softmax()
---> 21 qlayer_torch(clayer(inputs))
     22 model = torch.nn.Sequential(clayer, qlayer_torch)
     24 model(inputs)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/qnn/torch.py:402, in TorchLayer.forward(self, inputs)
    399     inputs = torch.reshape(inputs, (-1, inputs.shape[-1]))
    401 # calculate the forward pass as usual
--> 402 results = self._evaluate_qnode(inputs)
    404 if isinstance(results, tuple):
    405     if has_batch_dim:

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/qnn/torch.py:428, in TorchLayer._evaluate_qnode(self, x)
    416 """Evaluates the QNode for a single input datapoint.
    417 
    418 Args:
   (...)
    422     tensor: output datapoint
    423 """
    424 kwargs = {
    425     **{self.input_arg: x},
    426     **{arg: weight.to(x) for arg, weight in self.qnode_weights.items()},
    427 }
--> 428 res = self.qnode(**kwargs)
    430 if isinstance(res, torch.Tensor):
    431     return res.type(x.dtype)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/workflow/qnode.py:1098, in QNode.__call__(self, *args, **kwargs)
   1095 self._update_gradient_fn(shots=override_shots, tape=self._tape)
   1097 try:
-> 1098     res = self._execution_component(args, kwargs, override_shots=override_shots)
   1099 finally:
   1100     if old_interface == "auto":

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/workflow/qnode.py:1052, in QNode._execution_component(self, args, kwargs, override_shots)
   1049 full_transform_program.prune_dynamic_transform()
   1051 # pylint: disable=unexpected-keyword-arg
-> 1052 res = qml.execute(
   1053     (self._tape,),
   1054     device=self.device,
   1055     gradient_fn=self.gradient_fn,
   1056     interface=self.interface,
   1057     transform_program=full_transform_program,
   1058     config=config,
   1059     gradient_kwargs=self.gradient_kwargs,
   1060     override_shots=override_shots,
   1061     **self.execute_kwargs,
   1062 )
   1063 res = res[0]
   1065 # convert result to the interface in case the qfunc has no parameters

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/workflow/execution.py:616, in execute(tapes, device, gradient_fn, interface, transform_program, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform, device_vjp)
    614 # Exiting early if we do not need to deal with an interface boundary
    615 if no_interface_boundary_required:
--> 616     results = inner_execute(tapes)
    617     return post_processing(results)
    619 _grad_on_execution = False

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/workflow/execution.py:297, in _make_inner_execute.<locals>.inner_execute(tapes, **_)
    294 transformed_tapes, transform_post_processing = transform_program(tapes)
    296 if transformed_tapes:
--> 297     results = device_execution(transformed_tapes)
    298 else:
    299     results = ()

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/modifiers/simulator_tracking.py:30, in _track_execute.<locals>.execute(self, circuits, execution_config)
     28 @wraps(untracked_execute)
     29 def execute(self, circuits, execution_config=DefaultExecutionConfig):
---> 30     results = untracked_execute(self, circuits, execution_config)
     31     if isinstance(circuits, QuantumScript):
     32         batch = (circuits,)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/modifiers/single_tape_support.py:32, in _make_execute.<locals>.execute(self, circuits, execution_config)
     30     is_single_circuit = True
     31     circuits = (circuits,)
---> 32 results = batch_execute(self, circuits, execution_config)
     33 return results[0] if is_single_circuit else results

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/default_qubit.py:593, in DefaultQubit.execute(self, circuits, execution_config)
    590 prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
    592 if max_workers is None:
--> 593     return tuple(
    594         _simulate_wrapper(
    595             c,
    596             {
    597                 "rng": self._rng,
    598                 "debugger": self._debugger,
    599                 "interface": interface,
    600                 "state_cache": self._state_cache,
    601                 "prng_key": _key,
    602             },
    603         )
    604         for c, _key in zip(circuits, prng_keys)
    605     )
    607 vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
    608 seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/default_qubit.py:594, in <genexpr>(.0)
    590 prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
    592 if max_workers is None:
    593     return tuple(
--> 594         _simulate_wrapper(
    595             c,
    596             {
    597                 "rng": self._rng,
    598                 "debugger": self._debugger,
    599                 "interface": interface,
    600                 "state_cache": self._state_cache,
    601                 "prng_key": _key,
    602             },
    603         )
    604         for c, _key in zip(circuits, prng_keys)
    605     )
    607 vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
    608 seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/default_qubit.py:841, in _simulate_wrapper(circuit, kwargs)
    840 def _simulate_wrapper(circuit, kwargs):
--> 841     return simulate(circuit, **kwargs)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/qubit/simulate.py:287, in simulate(circuit, debugger, state_cache, **execution_kwargs)
    282     return simulate_one_shot_native_mcm(
    283         circuit, debugger=debugger, rng=rng, prng_key=prng_key, interface=interface
    284     )
    286 ops_key, meas_key = jax_random_split(prng_key)
--> 287 state, is_state_batched = get_final_state(
    288     circuit, debugger=debugger, rng=rng, prng_key=ops_key, interface=interface
    289 )
    290 if state_cache is not None:
    291     state_cache[circuit.hash] = state

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/qubit/simulate.py:150, in get_final_state(circuit, debugger, **execution_kwargs)
    148 if isinstance(op, MidMeasureMP):
    149     prng_key, key = jax_random_split(prng_key)
--> 150 state = apply_operation(
    151     op,
    152     state,
    153     is_state_batched=is_state_batched,
    154     debugger=debugger,
    155     mid_measurements=mid_measurements,
    156     rng=rng,
    157     prng_key=key,
    158 )
    159 # Handle postselection on mid-circuit measurements
    160 if isinstance(op, qml.Projector):

File /opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/qubit/apply_operation.py:206, in apply_operation(op, state, is_state_batched, debugger, **_)
    152 @singledispatch
    153 def apply_operation(
    154     op: qml.operation.Operator,
   (...)
    158     **_,
    159 ):
    160     """Apply and operator to a given state.
    161 
    162     Args:
   (...)
    204 
    205     """
--> 206     return _apply_operation_default(op, state, is_state_batched, debugger)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/qubit/apply_operation.py:216, in _apply_operation_default(op, state, is_state_batched, debugger)
    210 """The default behaviour of apply_operation, accessed through the standard dispatch
    211 of apply_operation, as well as conditionally in other dispatches."""
    212 if (
    213     len(op.wires) < EINSUM_OP_WIRECOUNT_PERF_THRESHOLD
    214     and math.ndim(state) < EINSUM_STATE_WIRECOUNT_PERF_THRESHOLD
    215 ) or (op.batch_size and is_state_batched):
--> 216     return apply_operation_einsum(op, state, is_state_batched=is_state_batched)
    217 return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/devices/qubit/apply_operation.py:102, in apply_operation_einsum(op, state, is_state_batched)
     99         op._batch_size = batch_size  # pylint:disable=protected-access
    100 reshaped_mat = math.reshape(mat, new_mat_shape)
--> 102 return math.einsum(einsum_indices, reshaped_mat, state)

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/math/multi_dispatch.py:547, in einsum(indices, like, optimize, *operands)
    501 """Evaluates the Einstein summation convention on the operands.
    502 
    503 Args:
   (...)
    544 array([ 30,  80, 130, 180, 230])
    545 """
    546 if like is None:
--> 547     like = get_interface(*operands)
    548 operands = np.coerce(operands, like=like)
    549 if optimize is None or like == "torch":
    550     # torch einsum doesn't support the optimize keyword argument

File ~/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages/pennylane/math/utils.py:221, in get_interface(*values)
    217 interfaces = {_get_interface_of_single_tensor(v) for v in values}
    219 if len(interfaces - {"numpy", "scipy", "autograd"}) > 1:
    220     # contains multiple non-autograd interfaces
--> 221     raise ValueError("Tensors contain mixed types; cannot determine dispatch library")
    223 non_numpy_scipy_interfaces = set(interfaces) - {"numpy", "scipy"}
    225 if len(non_numpy_scipy_interfaces) > 1:
    226     # contains autograd and another interface

ValueError: Tensors contain mixed types; cannot determine dispatch library
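The last frame shows how PennyLane chooses a dispatch library. It collects the interface of every operand and raises if more than one framework other than NumPy, SciPy, or Autograd is present. Here the operands reaching `einsum` appear to come from two frameworks, which is consistent with the torch tensors from the `TorchLayer` being executed by a QNode whose interface had been switched to `"tf"`. The sketch below is a simplified model of that check. `interface_of` is a hypothetical helper that guesses a framework from the type's module name; PennyLane's real per-tensor detection is more thorough, and the fake tensor classes are illustrative only.

```python
def interface_of(value):
    """Hypothetical helper: guess a tensor's framework from its type's module."""
    root = type(value).__module__.split(".")[0]
    if root in ("torch", "tensorflow", "jax"):
        return root
    return "numpy"  # plain Python / NumPy values


def get_interface(*values):
    """Simplified model of the check that raises at the end of the traceback."""
    interfaces = {interface_of(v) for v in values}
    others = interfaces - {"numpy", "scipy", "autograd"}
    if len(others) > 1:
        # e.g. {"torch", "tensorflow"}: no single library can evaluate the einsum
        raise ValueError("Tensors contain mixed types; cannot determine dispatch library")
    return others.pop() if others else "numpy"


# Hypothetical tensor types whose module names mimic torch and TensorFlow.
class FakeTorchTensor:
    pass


class FakeTFTensor:
    pass


FakeTorchTensor.__module__ = "torch"
FakeTFTensor.__module__ = "tensorflow"

print(get_interface(FakeTorchTensor(), 0.5))  # torch
try:
    get_interface(FakeTorchTensor(), FakeTFTensor())
except ValueError as err:
    print(err)
```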

System information

Name: PennyLane
Version: 0.36.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /Users/isaac/.virtualenvs/pl-qiskit-1.0/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-qiskit, PennyLane_Lightning

Platform info:           macOS-14.5-arm64-arm-64bit
Python version:          3.11.8
Numpy version:           1.26.4
Scipy version:           1.13.0
Installed devices:
- default.clifford (PennyLane-0.36.0)
- default.gaussian (PennyLane-0.36.0)
- default.mixed (PennyLane-0.36.0)
- default.qubit (PennyLane-0.36.0)
- default.qubit.autograd (PennyLane-0.36.0)
- default.qubit.jax (PennyLane-0.36.0)
- default.qubit.legacy (PennyLane-0.36.0)
- default.qubit.tf (PennyLane-0.36.0)
- default.qubit.torch (PennyLane-0.36.0)
- default.qutrit (PennyLane-0.36.0)
- default.qutrit.mixed (PennyLane-0.36.0)
- null.qubit (PennyLane-0.36.0)
- qiskit.aer (PennyLane-qiskit-0.36.0)
- qiskit.basicaer (PennyLane-qiskit-0.36.0)
- qiskit.basicsim (PennyLane-qiskit-0.36.0)
- qiskit.ibmq (PennyLane-qiskit-0.36.0)
- qiskit.ibmq.circuit_runner (PennyLane-qiskit-0.36.0)
- qiskit.ibmq.sampler (PennyLane-qiskit-0.36.0)
- qiskit.remote (PennyLane-qiskit-0.36.0)
- lightning.qubit (PennyLane_Lightning-0.36.0)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
@isaacdevlugt isaacdevlugt added the bug 🐛 Something isn't working label May 21, 2024
@isaacdevlugt isaacdevlugt added this to the v0.37 milestone May 21, 2024
@albi3ro
Contributor

albi3ro commented May 21, 2024

This is due to:

self.qnode.interface = "tf"

in `KerasLayer`, which sets the QNode's interface to `"tf"`. We could remove this line and replace it with a validation check that ensures the interface is `"auto"` or `"tf"`.
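The sketch below applies that suggestion to both layers, following the approach the linked pull request describes. `QNodeStub`, `validate_interface`, the layer stubs, and the alias sets are all hypothetical. PennyLane keeps its own table of accepted interface names, which may include spellings not listed here, and the merged fix in #5800 is the authoritative version. Instead of assigning to the QNode, each layer inspects it and raises an error if the interface is incompatible.

```python
# Hypothetical alias sets; the exact spellings PennyLane accepts may differ.
KERAS_COMPATIBLE = {"auto", "tf", "tensorflow"}
TORCH_COMPATIBLE = {"auto", "torch", "pytorch"}


class QNodeStub:
    """Hypothetical stand-in for a QNode that only records its interface."""

    def __init__(self, interface="auto"):
        self.interface = interface


def validate_interface(qnode, allowed, layer_name):
    """Raise instead of silently overwriting qnode.interface."""
    if qnode.interface not in allowed:
        raise ValueError(
            f"{layer_name} requires a QNode whose interface is one of "
            f"{sorted(allowed)}; got {qnode.interface!r}."
        )


class KerasLayerStub:
    def __init__(self, qnode):
        validate_interface(qnode, KERAS_COMPATIBLE, "KerasLayer")
        self.qnode = qnode  # stored, never modified


class TorchLayerStub:
    def __init__(self, qnode):
        validate_interface(qnode, TORCH_COMPATIBLE, "TorchLayer")
        self.qnode = qnode  # stored, never modified


shared = QNodeStub()  # interface="auto"
TorchLayerStub(shared)
KerasLayerStub(shared)
print(shared.interface)  # auto: neither layer changed it
```

With `"auto"` left in place, one QNode can back both layers. A QNode explicitly pinned to the other framework now fails early with a clear error instead of failing later inside `einsum`.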

@mudit2812 mudit2812 linked a pull request Jun 4, 2024 that will close this issue
mudit2812 added a commit that referenced this issue Jun 6, 2024
**Context:**
Since `KerasLayer` and `TorchLayer` both update the input QNode's
interface, using the same QNode to construct both a `KerasLayer` and a
`TorchLayer` caused the failure reported in this issue.

**Description of the Change:**
* Instead of mutating the QNode's interface, check that the QNode's
interface is as expected (`"tf"` or equivalent for `KerasLayer`, and
`"torch"` or equivalent for `TorchLayer`)
* Tests

**Benefits:**
* QNodes are no longer sneakily mutated.

**Possible Drawbacks:**
* Introduces a new error, so this is potentially a breaking change,
though likely minor enough that we need not worry about it.

**Related GitHub Issues:**
#5723