Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
QSVT.data
property to improve backend inference (#5226)
**Context:** The QSVT example at the bottom of [this demo](https://pennylane.ai/qml/demos/tutorial_apply_qsvt) does not run with `@jax.jit`: ```python import jax import pennylane as qml from pennylane import numpy as np # ------------------------------------------------------------------------------ kappa = 4 s = 0.10145775 # ------------------------------------------------------------------------------ # Required to avoid: # jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Executable expected # parameter 0 of size 204 but got buffer with incompatible size 408 jax.config.update("jax_enable_x64", True) # Required to avoid: # TypeError: cannot unpack non-iterable ShapedArray object jax.config.update("jax_dynamic_shapes", True) # ------------------------------------------------------------------------------ def sum_even_odd_circ(x, phi, ancilla_wire, wires): phi1, phi2 = phi[: len(phi) // 2], phi[len(phi) // 2:] qml.Hadamard(wires=ancilla_wire) # equal superposition # apply even and odd polynomial approx qml.ctrl(qml.qsvt, control=(ancilla_wire,), control_values=(0,))(x, phi1, wires=wires) qml.ctrl(qml.qsvt, control=(ancilla_wire,), control_values=(1,))(x, phi2, wires=wires) qml.Hadamard(wires=ancilla_wire) # un-prepare superposition # ------------------------------------------------------------------------------ np.random.seed(42) # set seed for reproducibility phi = np.random.rand(51) # ------------------------------------------------------------------------------ samples_x = np.linspace(1 / kappa, 1, 100) def target_func(x): return s * (1 / x) # Note: This @jax.jit is NOT part of the original demo. @jax.jit def loss_func(phi): sum_square_error = 0 for x in samples_x: qsvt_matrix = qml.matrix(sum_even_odd_circ)(x, phi, ancilla_wire="ancilla", wires=[0]) qsvt_val = qsvt_matrix[0, 0] sum_square_error += (np.real(qsvt_val) - target_func(x)) ** 2 return sum_square_error / len(samples_x) # ------------------------------------------------------------------------------ result = loss_func(phi) print(result) ``` due to ``` jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape complex128[4,4]. ``` **Description of the Change:** * ~Removed the explicit interface passed to `qml.math.matmul()` in `qml.matrix()`. This enables PennyLane to infer the correct backend (JAX) when multiplying NumPy arrays and JAX tracers.~ * Added a `data` property to the QSVT class so that the `qml.matrix()` transform infers the correct backend when the `QuantumScript` contains a QSVT operation. * Removed trailing whitespace from the changelog (done automatically by VS Code). **Benefits:** * The QSVT workflow in the [QSVT in Practice](https://pennylane.ai/qml/demos/tutorial_apply_qsvt) demo can be JIT-compiled and optimized by JAX (as long as `JAX_DYNAMIC_SHAPES` is explicitly set). **Possible Drawbacks:** None. **Related GitHub Issues:** None.
- Loading branch information