[Bug] Cannot differentiate any coefficients in new opmath expvals in a jitted qnode #5505

Closed
Qottmann opened this issue Apr 12, 2024 · 0 comments
Labels
bug 🐛 Something isn't working

Comments

@Qottmann
Contributor

The problem is that we hit the `is_hermitian` check in `pennylane/measurements/expval.py`, which does not work with abstract tracers.

The same problem occurs for `qml.sum(qml.s_prod(x, X(0)))` and `qml.prod(qml.s_prod(x, X(0)))`.
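The underlying JAX behavior can be reproduced without PennyLane: a Python-level `not` or `if` on a value-dependent check works eagerly but fails as soon as the argument becomes an abstract tracer under `jax.jit`. The minimal sketch below assumes `jnp.iscomplex` as a stand-in for `qml.math.iscomplex` (its JAX dispatch is not shown in this issue); `value_dependent_is_hermitian` and `expval_like` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def value_dependent_is_hermitian(scalar, base_is_hermitian=True):
    # Same shape as the SProd.is_hermitian line in the traceback, with
    # jnp.iscomplex as an assumed stand-in for qml.math.iscomplex.
    # jnp.iscomplex compares the imaginary part with 0, so its result depends
    # on the *value* of `scalar`; `not` then forces a Python bool.
    return base_is_hermitian and not jnp.iscomplex(scalar)


def expval_like(x):
    # Hypothetical mimic of the expval entry point: a Python-level branch.
    if not value_dependent_is_hermitian(x):
        print("might not be hermitian")
    return jnp.sum(x)


x = jnp.array(0.5)
print(expval_like(x))  # eager mode: x is concrete, so the check works

try:
    jax.jit(expval_like)(x)
except jax.errors.TracerBoolConversionError as err:
    # Under jit, x is an abstract tracer and has no concrete value.
    print(type(err).__name__)
```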

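```python
import jax
import jax.numpy as jnp
import pennylane as qml
from pennylane import X

dev = qml.device("default.qubit")

@jax.jit
@qml.qnode(dev, interface="jax")
def qnode(x):
    # The trainable coefficient x is placed inside the observable itself
    return qml.expval(qml.s_prod(x, X(0)))

x = jnp.array([0.5])
jax.grad(qnode)(x)  # raises TracerBoolConversionError (traceback below)
```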
```
File ~/Xanadu/pennylane/pennylane/measurements/expval.py:74, in expval(op)
     68 if isinstance(op, qml.Identity) and len(op.wires) == 0:
     69     # temporary solution to merge https://github.com/PennyLaneAI/pennylane/pull/5106
     70     raise NotImplementedError(
     71         "Expectation values of qml.Identity() without wires are currently not allowed."
     72     )
---> 74 if not op.is_hermitian:
     75     warnings.warn(f"{op.name} might not be hermitian.")
     77 return ExpectationMP(obs=op)

File ~/Xanadu/pennylane/pennylane/ops/op_math/sprod.py:201, in SProd.is_hermitian(self)
    197 @property
    198 def is_hermitian(self):
    199     """If the base operator is hermitian and the scalar is real,
    200     then the scalar product operator is hermitian."""
--> 201     return self.base.is_hermitian and not qml.math.iscomplex(self.scalar)

    [... skipping hidden 1 frame]

File ~/anaconda3/envs/pennylane311/lib/python3.11/site-packages/jax/_src/core.py:1510, in concretization_function_error.<locals>.error(self, arg)
   1509 def error(self, arg):
-> 1510   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[1]..
The error occurred while tracing the function qnode at /var/folders/yv/h7q98p6d2td8vzzbdszwc5l40000gq/T/ipykernel_49113/2688043293.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
```
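One way to make such a check trace-safe, sketched here as a hypothetical alternative and not as the change made in #5506, is to decide realness from the dtype, which is static during tracing, rather than from the value. `jnp.iscomplexobj` only inspects `x.dtype`, so it returns a plain Python bool even for tracers. The helper names `scalar_is_real`, `sprod_is_hermitian` and `cost` are made up for illustration.

```python
import warnings

import jax
import jax.numpy as jnp


def scalar_is_real(scalar) -> bool:
    """Hypothetical helper: decide realness from the dtype only.

    jnp.iscomplexobj reads x.dtype, which is known at trace time even for
    abstract tracers, so this returns a plain Python bool under jax.jit.
    """
    return not jnp.iscomplexobj(scalar)


def sprod_is_hermitian(scalar, base_is_hermitian: bool = True) -> bool:
    # Hypothetical stand-in for a trace-safe SProd.is_hermitian.
    return base_is_hermitian and scalar_is_real(scalar)


def cost(x):
    # Mimics expval: the check only emits a warning, and it never needs
    # the concrete value of x.
    if not sprod_is_hermitian(x):
        warnings.warn("SProd might not be hermitian.")
    return jnp.sum(x ** 2)


grad_cost = jax.jit(jax.grad(cost))
print(grad_cost(jnp.array(0.5)))  # d/dx x**2 at x = 0.5 is 2 * 0.5 = 1.0
```

The trade-off is that the dtype test is conservative: a complex-dtype scalar whose imaginary part happens to be zero is reported as possibly non-hermitian. Since `expval` only warns in that case, the cost is a spurious warning rather than a wrong result.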
Qottmann added the bug 🐛 Something isn't working label on Apr 12, 2024
Qottmann added a commit that referenced this issue Apr 15, 2024
… tracing (#5506)

Solves #5505 and also fixes the same issue in Catalyst.

TL;DR: The `is_hermitian` check breaks JIT compilation.