[BUG] qml.Hermitian doesn’t work with jax-jit #6640

Closed
1 task done
CatalinaAlbornoz opened this issue Nov 26, 2024 · 0 comments
Labels
bug 🐛 Something isn't working


@CatalinaAlbornoz
Contributor

Expected behavior

Taking the expectation value of a `qml.Hermitian` observable should be jittable.

Actual behavior

It raises a `TracerBoolConversionError` while `qml.Hermitian` is being constructed (see the traceback below).
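The failure does not depend on PennyLane itself. The sketch below reproduces the same pattern in plain JAX, assuming only `jax` and `jax.numpy`; the helper `check_hermitian` is hypothetical and merely mirrors what `Hermitian._validate_input` does in the traceback. A Python `if` consumes the result of `jnp.allclose` during `jax.jit` tracing, so it fails even though the matrix is a closed-over constant, as in the report.

```python
import jax
import jax.numpy as jnp

# Same integer matrix as in the report; captured by closure, not passed as an argument.
Y = jnp.array([[1, 0], [0, -1]])


def check_hermitian(A):
    # Hypothetical stand-in for Hermitian._validate_input: a Python `if`
    # needs a concrete bool, but under jit jnp.allclose returns a traced bool[].
    if not jnp.allclose(A, jnp.conj(A).T):
        raise ValueError("Observable must be Hermitian.")


check_hermitian(Y)  # eager call: works, allclose returns a concrete value


@jax.jit
def circuit_like(x):
    check_hermitian(Y)  # traced call: the comparison is staged into the jaxpr
    return x * 2.0


try:
    circuit_like(jnp.array(0.5))
    raised = None
except jax.errors.TracerBoolConversionError as err:
    raised = type(err).__name__

print(raised)
```

Under `jit`, JAX stages every operation it sees, including operations on constants, which is why capturing `A` as a constant is not enough on its own.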

Additional information

Originally posted in this forum thread.

Indenting the `Hermitian._validate_input(A, expected_mx_shape)` call in `observables.py` only partly solves the issue.

`A` is captured as a constant, but the check still fails at the `allclose` comparison: under `jit`, `qml.math.allclose(A, qml.math.T(qml.math.conj(A)))` returns a tracer, which cannot be used as a plain Python `bool`.
Even if `A` is captured and traced, the code still seems to fail, because the `_validate_input` static method is also called from other methods without checking whether the value is abstract. Most likely, only the `Hermitian` class needs to be updated.
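One way to keep a check that still works for constants without breaking tracing is to run it with NumPy on concrete inputs and skip it for tracers. This is a hypothetical sketch, not PennyLane API: the helper `validate_hermitian`, its `atol` default, and its boolean return value are assumptions for illustration. It relies only on `np.asarray` raising `jax.errors.TracerArrayConversionError` when given a traced value.

```python
import jax
import jax.numpy as jnp
import numpy as np


def validate_hermitian(A, atol=1e-8):
    """Hypothetical validator: returns True if the check ran, False if skipped.

    Concrete inputs (including arrays closed over by a jitted function) are
    checked eagerly with NumPy, so nothing is staged into the jaxpr.
    Traced inputs have no value at trace time, so the check is skipped.
    """
    try:
        A_np = np.asarray(A)
    except jax.errors.TracerArrayConversionError:
        return False  # abstract value: cannot be validated while tracing
    if not np.allclose(A_np, A_np.conj().T, atol=atol):
        raise ValueError("Observable must be Hermitian.")
    return True


Y = jnp.array([[1, 0], [0, -1]])
log = []


@jax.jit
def f(A):
    log.append(("closure", validate_hermitian(Y)))  # concrete constant: checked
    log.append(("argument", validate_hermitian(A)))  # tracer: skipped
    return jnp.sum(A * Y)


f(jnp.eye(2))
print(log)  # [('closure', True), ('argument', False)]
```

This keeps the full comparison cost for every concrete matrix, which is exactly the concern that later led the maintainers to drop the check instead.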

Source code

```python
import jax
import jax.numpy as jnp
import pennylane as qml

# jax.config.update("jax_enable_x64", True)
dev = qml.device("default.qubit", wires=2)
y = jnp.array([[1, 0], [0, -1]])

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.Hermitian(y, wires=[1]))

print(f"Result: {repr(circuit(jnp.array([0.123, 0.25])))}")
```

Tracebacks

TracerBoolConversionError                 Traceback (most recent call last)
Cell In[2], line 16
     13     qml.CNOT(wires=[0, 1])
     14     return qml.expval(qml.Hermitian(y, wires=[1]))
---> 16 print(f"Result: {repr(circuit(jnp.array([0.123,0.25])))}")

    [... skipping hidden 11 frame]

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/workflow/qnode.py:987, in QNode.__call__(self, *args, **kwargs)
    985 if qml.capture.enabled():
    986     return qml.capture.qnode_call(self, *args, **kwargs)
--> 987 return self._impl_call(*args, **kwargs)

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/workflow/qnode.py:963, in QNode._impl_call(self, *args, **kwargs)
    960 def _impl_call(self, *args, **kwargs) -> qml.typing.Result:
    961 
    962     # construct the tape
--> 963     self.construct(args, kwargs)
    965     old_interface = self.interface
    966     if old_interface == "auto":

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/workflow/qnode.py:857, in QNode.construct(self, args, kwargs)
    855 with pldb_device_manager(self.device):
    856     with qml.queuing.AnnotatedQueue() as q:
--> 857         self._qfunc_output = self.func(*args, **kwargs)
    859 self._tape = QuantumScript.from_queue(q, shots)
    861 params = self.tape.get_parameters(trainable_only=False)

Cell In[2], line 14, in circuit(param)
     12 qml.RX(param, wires=0)
     13 qml.CNOT(wires=[0, 1])
---> 14 return qml.expval(qml.Hermitian(y, wires=[1]))

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/capture/capture_meta.py:89, in CaptureMeta.__call__(cls, *args, **kwargs)
     85 if enabled():
     86     # when tracing is enabled, we want to
     87     # use bind to construct the class if we want class construction to add it to the jaxpr
     88     return cls._primitive_bind_call(*args, **kwargs)
---> 89 return type.__call__(cls, *args, **kwargs)

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/ops/qubit/observables.py:87, in Hermitian.__init__(self, A, wires, id)
     83     else:
     84         # Assumably wires is an int; further validation checks are performed by calling super().__init__
     85         expected_mx_shape = self._num_basis_states
---> 87     Hermitian._validate_input(A, expected_mx_shape)
     89 super().__init__(A, wires=wires, id=id)

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/pennylane/ops/qubit/observables.py:103, in Hermitian._validate_input(A, expected_mx_shape)
     97 if expected_mx_shape is not None and A.shape[0] != expected_mx_shape:
     98     raise ValueError(
     99         f"Expected input matrix to have shape {expected_mx_shape}x{expected_mx_shape}, but "
    100         f"a matrix with shape {A.shape[0]}x{A.shape[0]} was passed."
    101     )
--> 103 if not qml.math.allclose(A, qml.math.T(qml.math.conj(A))):
    104     raise ValueError("Observable must be Hermitian.")

    [... skipping hidden 1 frame]

File ~/anaconda3/envs/research2/lib/python3.13/site-packages/jax/_src/core.py:1554, in concretization_function_error.<locals>.error(self, arg)
   1553 def error(self, arg):
-> 1554   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function circuit at /tmp/ipykernel_20529/364763606.py:9 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[2,2] = transpose[permutation=(1, 0)] b
    from line /tmp/ipykernel_20529/364763606.py:14:22 (circuit)

  operation a:bool[] = pjit[
  name=allclose
  jaxpr={ lambda ; b:i32[2,2] c:i32[2,2] d:f32[] e:f32[]. let
      f:bool[2,2] = pjit[
        name=isclose
        jaxpr={ lambda ; g:i32[2,2] h:i32[2,2] i:f32[] j:f32[]. let
            k:f32[2,2] = convert_element_type[new_dtype=float32 weak_type=False] g
            l:f32[2,2] = convert_element_type[new_dtype=float32 weak_type=False] h
            m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] i
            n:f32[] = convert_element_type[new_dtype=float32 weak_type=False] j
            o:f32[2,2] = sub k l
            p:f32[2,2] = abs o
            q:f32[2,2] = abs l
            r:f32[2,2] = mul m q
            s:f32[2,2] = add n r
            t:bool[2,2] = le p s
            u:bool[2,2] = pjit[
              name=isinf
              jaxpr={ lambda ; v:f32[2,2]. let
                  w:f32[2,2] = abs v
                  x:bool[2,2] = eq w inf
                in (x,) }
            ] k
            y:bool[2,2] = pjit[
              name=isinf
              jaxpr={ lambda ; v:f32[2,2]. let
                  w:f32[2,2] = abs v
                  x:bool[2,2] = eq w inf
                in (x,) }
            ] l
            z:bool[2,2] = or u y
            ba:bool[2,2] = and u y
            bb:bool[2,2] = not z
            bc:bool[2,2] = and t bb
            bd:bool[2,2] = eq k l
            be:bool[2,2] = and ba bd
            bf:bool[2,2] = or bc be
            bg:bool[2,2] = ne k k
            bh:bool[2,2] = ne l l
            bi:bool[2,2] = or bg bh
            bj:bool[2,2] = not bi
            bk:bool[2,2] = and bf bj
          in (bk,) }
      ] b c e d
      bl:bool[] = reduce_and[axes=(0, 1)] f
    in (bl,) }
] bm bn bo bp
    from line /tmp/ipykernel_20529/364763606.py:14:22 (circuit)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

System information

Version of Jax:
jax==0.4.35
jaxlib==0.4.35

Output of qml.about():
Name: PennyLane
Version: 0.39.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /home/ubuntu2022/anaconda3/envs/research2/lib/python3.13/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: PennyLane-qiskit, PennyLane_Lightning

Platform info: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.13.0
Numpy version: 2.0.2
Scipy version: 1.14.1
Installed devices:

qiskit.aer (PennyLane-qiskit-0.39.0)
qiskit.basicaer (PennyLane-qiskit-0.39.0)
qiskit.basicsim (PennyLane-qiskit-0.39.0)
qiskit.remote (PennyLane-qiskit-0.39.0)
default.clifford (PennyLane-0.39.0)
default.gaussian (PennyLane-0.39.0)
default.mixed (PennyLane-0.39.0)
default.qubit (PennyLane-0.39.0)
default.qutrit (PennyLane-0.39.0)
default.qutrit.mixed (PennyLane-0.39.0)
default.tensor (PennyLane-0.39.0)
null.qubit (PennyLane-0.39.0)
reference.qubit (PennyLane-0.39.0)
lightning.qubit (PennyLane_Lightning-0.39.0)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
CatalinaAlbornoz added the bug 🐛 label on Nov 26, 2024
PietropaoloFrisoni added a commit that referenced this issue Dec 9, 2024
**Context:** After some discussion, we decided to completely remove the
'Hermitianity' input validation check from the `qml.Hermitian` class, since
it might be expensive for very large matrices. Furthermore, it runs
every time the `compute_matrix` and `compute_decomposition` methods
are called.

Our current validation strategy is to raise errors if the corresponding
error would otherwise be rather difficult to debug and/or the sanity check is
sufficiently cheap and jittable.

In this case, the validation check is neither cheap nor jittable. And
the user might actually want to provide a non-hermitian matrix-based
observable.

**Description of the Change:** As above.

**Benefits:** Faster execution and jittability.

**Possible Drawbacks:** The main drawback is that if the user
accidentally provides a non-Hermitian matrix to the `qml.Hermitian` class,
believing it to be Hermitian, they might encounter confusing
errors, because the input matrix is no longer explicitly
checked.

**Related GitHub Issues:** #6640 

**Related Shortcut Stories:** [sc-79122]
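The cost argument in the commit message can be made concrete with a rough count. This is an illustration, not a measurement: the helper `hermitian_check_cost` is hypothetical, and it assumes a dense `complex128` matrix (16 bytes per entry) and counts only the matrix itself plus one temporary of the same size, such as the `conj(A)` copy that `allclose(A, T(conj(A)))` materializes.

```python
def hermitian_check_cost(num_wires, bytes_per_entry=16):
    """Rough size of a Hermitian-ity check on a dense 2**n x 2**n matrix."""
    dim = 2 ** num_wires
    entries = dim * dim  # allclose compares every entry with its conjugate partner
    # The matrix itself plus one same-sized temporary (e.g. conj(A)).
    bytes_touched = 2 * entries * bytes_per_entry
    return dim, entries, bytes_touched


for n in (2, 10, 14):
    dim, entries, nbytes = hermitian_check_cost(n)
    print(f"{n:>2} wires: {dim}x{dim} matrix, {entries:,} comparisons, "
          f"~{nbytes / 2**30:.3f} GiB")
```

The work grows as 4^n in the number of wires, and because the check ran on every `compute_matrix` and `compute_decomposition` call, it was paid repeatedly rather than once.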