
[BUG] qjit/jit compatibility with operators that do similar things to Superposition is hit-or-miss #6790

Open
isaacdevlugt opened this issue Jan 9, 2025 · 0 comments
Labels
bug 🐛 Something isn't working

Comments

@isaacdevlugt
Contributor

Expected behavior

I expect templates like Superposition to handle their inputs more robustly, so that JAX array inputs work under jax.jit just as NumPy arrays do.

import jax
from jax import numpy as jnp
import pennylane as qml

coeffs = jnp.array([1/3, 1/3, 1/3])
bases = jnp.array([[1, 1, 1], [0, 1, 0], [0, 0, 0]])
wires = [0, 1, 2]
work_wire = 3

dev = qml.device('default.qubit')

@jax.jit
@qml.qnode(dev)
def circuit():
    qml.Superposition(jnp.sqrt(coeffs), bases, wires, work_wire)
    return qml.probs(wires)

circuit()

This fails with a TracerBoolConversionError because bases is a JAX array; the same code works when bases is a plain NumPy array. Similar errors occur when coeffs is defined as follows:

coeffs = jnp.sqrt(jnp.array([1/3, 1/3, 1/3]))
...
    qml.Superposition(coeffs, bases, wires, work_wire)
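The failure can be reproduced without PennyLane. The sketch below mimics the basis-state check shown in the traceback (Python `or` and `all` over element-wise isclose results) and runs it while tracing a jax.jit function. The isclose dispatch helper is a hypothetical stand-in for qml.math.isclose, assuming NumPy inputs are handled by NumPy and everything else by jax.numpy. With NumPy bases every result is a concrete boolean; with JAX bases, every jax.numpy call inside jit is staged out and returns a tracer, so bool() on it raises.

```python
import jax
import jax.numpy as jnp
import numpy as np


def isclose(a, b):
    # Hypothetical stand-in for qml.math.isclose's dispatch (assumption):
    # NumPy inputs use NumPy, everything else uses jax.numpy.
    if isinstance(a, (np.ndarray, np.generic)):
        return np.isclose(a, b)
    return jnp.isclose(a, b)


def validate_bases(bases):
    # Same structure as the check in Superposition.__init__ from the traceback:
    # `or` and `all` call bool() on every isclose result.
    return all(all(isclose(i, 0.0) or isclose(i, 1.0) for i in b) for b in bases)


def check_under_jit(bases):
    """Run the check while tracing a jitted function; return True or the error name."""

    @jax.jit
    def traced():
        return jnp.asarray(validate_bases(bases))

    try:
        return bool(traced())
    except jax.errors.ConcretizationTypeError as err:
        # TracerBoolConversionError is a subclass of ConcretizationTypeError.
        return type(err).__name__


np_bases = np.array([[1, 1, 1], [0, 1, 0], [0, 0, 0]])
jnp_bases = jnp.asarray(np_bases)

print(check_under_jit(np_bases))   # expected: True
print(check_under_jit(jnp_bases))  # expected: TracerBoolConversionError
```

The only difference between the two calls is the array type of the closed-over bases, which matches the behaviour reported above.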

Actual behavior

A TracerBoolConversionError is raised while jax.jit traces the circuit (full traceback below).

Additional information

This has been discussed internally, and it may be fixable with changes to is_abstract(coeffs). The same problem would also occur with other templates, such as BasisState.
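One possible direction, sketched here under stated assumptions and not as PennyLane's actual fix, is to run value checks only when the input is concrete at trace time, and to do them in a single vectorised NumPy call so that bool() is never applied to a tracer. The names concrete_or_none and validate_bases are hypothetical. Instead of PennyLane's qml.math.is_abstract, this sketch probes concreteness by attempting np.asarray, which raises jax.errors.TracerArrayConversionError for abstract tracers but succeeds for a closed-over jax.Array like bases in the report.

```python
import jax
import jax.numpy as jnp
import numpy as np


def concrete_or_none(x):
    """Return x as a NumPy array if its values are known at trace time, else None."""
    try:
        return np.asarray(x)
    except jax.errors.TracerArrayConversionError:
        # x is an abstract tracer (for example, an argument of a jitted function).
        return None


def validate_bases(bases):
    """Hypothetical 0/1 check that never calls bool() on a traced value.

    Returns True if the check ran and passed, False if it was skipped because
    the values are abstract, and raises ValueError on invalid concrete input.
    """
    values = concrete_or_none(bases)
    if values is None:
        return False
    # One vectorised NumPy check yields a single concrete boolean.
    if not np.all(np.isclose(values, 0.0) | np.isclose(values, 1.0)):
        raise ValueError("The elements of the basis states must be either 0 or 1.")
    return True


bases = jnp.array([[1, 1, 1], [0, 1, 0], [0, 0, 0]])


@jax.jit
def closed_over():
    # `bases` is a concrete jax.Array captured from the enclosing scope.
    return jnp.asarray(validate_bases(bases))


@jax.jit
def as_argument(b):
    # `b` is a tracer here, so the check is skipped instead of failing.
    return jnp.asarray(validate_bases(b))
```

The trade-off is that invalid bases passed as traced arguments are no longer caught at construction time; a runtime check (for example via jax.experimental.checkify) would be needed to cover that case.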

Source code

import jax
from jax import numpy as jnp
import pennylane as qml

coeffs = jnp.array([1/3, 1/3, 1/3])
bases = jnp.array([[1, 1, 1], [0, 1, 0], [0, 0, 0]])
wires = [0, 1, 2]
work_wire = 3

dev = qml.device('default.qubit')

@jax.jit
@qml.qnode(dev)
def circuit():
    qml.Superposition(jnp.sqrt(coeffs), bases, wires, work_wire)
    return qml.probs(wires)

circuit()

Tracebacks

---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[19], line 17
     14     qml.Superposition(np.sqrt(coeffs), bases, wires, work_wire)
     15     return qml.probs(wires)
---> 17 circuit()

    [... skipping hidden 11 frame]

File ~/.virtualenvs/pl-latest/lib/python3.11/site-packages/pennylane/workflow/qnode.py:909, in QNode.__call__(self, *args, **kwargs)
    907 if qml.capture.enabled():
    908     return capture_qnode(self, *args, **kwargs)
--> 909 return self._impl_call(*args, **kwargs)

File ~/.virtualenvs/pl-latest/lib/python3.11/site-packages/pennylane/workflow/qnode.py:872, in QNode._impl_call(self, *args, **kwargs)
    869 def _impl_call(self, *args, **kwargs) -> qml.typing.Result:
    870 
    871     # construct the tape
--> 872     tape = self.construct(args, kwargs)
    874     if self.interface == "auto":
    875         interface = qml.math.get_interface(*args, *list(kwargs.values()))

File ~/.virtualenvs/pl-latest/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pl-latest/lib/python3.11/site-packages/pennylane/workflow/qnode.py:858, in QNode.construct(self, args, kwargs)
    856 with pldb_device_manager(self.device):
    857     with qml.queuing.AnnotatedQueue() as q:
--> 858         self._qfunc_output = self.func(*args, **kwargs)
    860 tape = QuantumScript.from_queue(q, shots)
    862 params = tape.get_parameters(trainable_only=False)

Cell In[19], line 14
     11 @jax.jit
     12 @qml.qnode(dev)
     13 def circuit():
---> 14     qml.Superposition(np.sqrt(coeffs), bases, wires, work_wire)
     15     return qml.probs(wires)

File ~/.virtualenvs/pl-latest/lib/python3.11/site-packages/pennylane/capture/capture_meta.py:89, in CaptureMeta.__call__(cls, *args, **kwargs)
     85 if enabled():
     86     # when tracing is enabled, we want to
     87     # use bind to construct the class if we want class construction to add it to the jaxpr
     88     return cls._primitive_bind_call(*args, **kwargs)
---> 89 return type.__call__(cls, *args, **kwargs)

File ~/.virtualenvs/pl-latest/lib/python3.11/site-packages/pennylane/templates/state_preparations/superposition.py:213, in Superposition.__init__(self, coeffs, bases, wires, work_wire, id)
    209 def __init__(
    210     self, coeffs, bases, wires, work_wire, id=None
    211 ):  # pylint: disable=too-many-positional-arguments, too-many-arguments
--> 213     if not all(
    214         all(qml.math.isclose(i, 0.0) or qml.math.isclose(i, 1.0) for i in b) for b in bases
    215     ):
    216         raise ValueError("The elements of the basis states must be either 0 or 1.")
    218     basis_lengths = {len(b) for b in bases}

File ~/.virtualenvs/pl-latest/lib/python3.11/site-packages/pennylane/templates/state_preparations/superposition.py:214, in <genexpr>(.0)
    209 def __init__(
    210     self, coeffs, bases, wires, work_wire, id=None
    211 ):  # pylint: disable=too-many-positional-arguments, too-many-arguments
    213     if not all(
--> 214         all(qml.math.isclose(i, 0.0) or qml.math.isclose(i, 1.0) for i in b) for b in bases
    215     ):
    216         raise ValueError("The elements of the basis states must be either 0 or 1.")
    218     basis_lengths = {len(b) for b in bases}

File ~/.virtualenvs/pl-latest/lib/python3.11/site-packages/pennylane/templates/state_preparations/superposition.py:214, in <genexpr>(.0)
    209 def __init__(
    210     self, coeffs, bases, wires, work_wire, id=None
    211 ):  # pylint: disable=too-many-positional-arguments, too-many-arguments
    213     if not all(
--> 214         all(qml.math.isclose(i, 0.0) or qml.math.isclose(i, 1.0) for i in b) for b in bases
    215     ):
    216         raise ValueError("The elements of the basis states must be either 0 or 1.")
    218     basis_lengths = {len(b) for b in bases}

    [... skipping hidden 1 frame]

File ~/.virtualenvs/pl-latest/lib/python3.11/site-packages/jax/_src/core.py:1475, in concretization_function_error.<locals>.error(self, arg)
   1474 def error(self, arg):
-> 1475   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function circuit at /var/folders/cn/h46l05vn2qd9c7ldxf0g905c0000gq/T/ipykernel_19196/4016252023.py:11 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[3] b:i32[3] c:i32[3] = pjit[
  name=_unstack
  jaxpr={ lambda ; d:i32[3,3]. let
      e:i32[1,3] = slice[
        limit_indices=(1, 3)
        start_indices=(0, 0)
        strides=(1, 1)
      ] d
      f:i32[3] = squeeze[dimensions=(0,)] e
      g:i32[1,3] = slice[
        limit_indices=(2, 3)
        start_indices=(1, 0)
        strides=(1, 1)
      ] d
      h:i32[3] = squeeze[dimensions=(0,)] g
      i:i32[1,3] = slice[
        limit_indices=(3, 3)
        start_indices=(2, 0)
        strides=(1, 1)
      ] d
      j:i32[3] = squeeze[dimensions=(0,)] i
    in (f, h, j) }
] k
    from line /Users/isaac/.virtualenvs/pl-latest/lib/python3.11/site-packages/pennylane/templates/state_preparations/superposition.py:213:18 (Superposition.__init__.<locals>.<genexpr>)

  operation a:bool[] = pjit[
  name=isclose
  jaxpr={ lambda ; b:i32[] c:f32[]. let
      d:f32[] = convert_element_type[new_dtype=float32 weak_type=True] b
      e:f32[] = sub d c
      f:f32[] = abs e
      g:f32[] = abs c
      h:f32[] = mul 9.999999747378752e-06 g
      i:f32[] = add 9.99999993922529e-09 h
      j:bool[] = le f i
      k:bool[] = pjit[
        name=isinf
        jaxpr={ lambda ; l:f32[]. let
            m:f32[] = abs l
            n:bool[] = eq m inf
          in (n,) }
      ] d
      o:bool[] = pjit[
        name=isinf
        jaxpr={ lambda ; l:f32[]. let
            m:f32[] = abs l
            n:bool[] = eq m inf
          in (n,) }
      ] c
      p:bool[] = or k o
      q:bool[] = and k o
      r:bool[] = not p
      s:bool[] = and j r
      t:bool[] = eq d c
      u:bool[] = convert_element_type[new_dtype=bool weak_type=False] t
      v:bool[] = and q u
      w:bool[] = or s v
      x:bool[] = ne d d
      y:bool[] = ne c c
      z:bool[] = convert_element_type[new_dtype=bool weak_type=False] x
      ba:bool[] = convert_element_type[new_dtype=bool weak_type=False] y
      bb:bool[] = or z ba
      bc:bool[] = not bb
      bd:bool[] = and w bc
    in (bd,) }
] be bf
    from line /Users/isaac/.virtualenvs/pl-latest/lib/python3.11/site-packages/autoray/autoray.py:81:11 (do)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

System information

Name: PennyLane
Version: 0.40.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /Users/isaac/.virtualenvs/pl-latest/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing_extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info:           macOS-15.2-arm64-arm-64bit
Python version:          3.11.9
Numpy version:           2.0.2
Scipy version:           1.13.0
Installed devices:
- default.clifford (PennyLane-0.40.0)
- default.gaussian (PennyLane-0.40.0)
- default.mixed (PennyLane-0.40.0)
- default.qubit (PennyLane-0.40.0)
- default.qutrit (PennyLane-0.40.0)
- default.qutrit.mixed (PennyLane-0.40.0)
- default.tensor (PennyLane-0.40.0)
- null.qubit (PennyLane-0.40.0)
- reference.qubit (PennyLane-0.40.0)
- nvidia.custatevec (PennyLane-Catalyst-0.9.0)
- nvidia.cutensornet (PennyLane-Catalyst-0.9.0)
- oqc.cloud (PennyLane-Catalyst-0.9.0)
- softwareq.qpp (PennyLane-Catalyst-0.9.0)
- lightning.qubit (PennyLane_Lightning-0.39.0)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
@isaacdevlugt isaacdevlugt added the bug 🐛 Something isn't working label Jan 9, 2025