Skip to content

Commit

Permalink
Add QSVT.data property to improve backend inference (#5226)
Browse files Browse the repository at this point in the history
**Context:**

The QSVT example at the bottom of [this demo](https://pennylane.ai/qml/demos/tutorial_apply_qsvt) does not run
with `@jax.jit`:

```python
import jax

import pennylane as qml
from pennylane import numpy as np

# ------------------------------------------------------------------------------

kappa = 4
s = 0.10145775

# ------------------------------------------------------------------------------

# Required to avoid:
#   jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Executable expected
#   parameter 0 of size 204 but got buffer with incompatible size 408
jax.config.update("jax_enable_x64", True)

# Required to avoid:
#   TypeError: cannot unpack non-iterable ShapedArray object
jax.config.update("jax_dynamic_shapes", True)

# ------------------------------------------------------------------------------

def sum_even_odd_circ(x, phi, ancilla_wire, wires):
    phi1, phi2 = phi[: len(phi) // 2], phi[len(phi) // 2:]

    qml.Hadamard(wires=ancilla_wire)  # equal superposition

    # apply even and odd polynomial approx
    qml.ctrl(qml.qsvt, control=(ancilla_wire,), control_values=(0,))(x, phi1, wires=wires)
    qml.ctrl(qml.qsvt, control=(ancilla_wire,), control_values=(1,))(x, phi2, wires=wires)

    qml.Hadamard(wires=ancilla_wire)  # un-prepare superposition

# ------------------------------------------------------------------------------

np.random.seed(42)  # set seed for reproducibility
phi = np.random.rand(51)

# ------------------------------------------------------------------------------

samples_x = np.linspace(1 / kappa, 1, 100)

def target_func(x):
    return s * (1 / x)

# Note: This @jax.jit is NOT part of the original demo. 
@jax.jit
def loss_func(phi):
    sum_square_error = 0
    for x in samples_x:
        qsvt_matrix = qml.matrix(sum_even_odd_circ)(x, phi, ancilla_wire="ancilla", wires=[0])
        qsvt_val = qsvt_matrix[0, 0]
        sum_square_error += (np.real(qsvt_val) - target_func(x)) ** 2

    return sum_square_error / len(samples_x)

# ------------------------------------------------------------------------------

result = loss_func(phi)
print(result)
```

due to

```
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape complex128[4,4].
```

**Description of the Change:**

* ~Removed the explicit interface passed to `qml.math.matmul()` in `qml.matrix()`. This enables PennyLane to infer the correct backend (JAX) when multiplying NumPy arrays and JAX tracers.~
* Added a `data` property to the QSVT class so that the `qml.matrix()` transform infers the correct backend when the `QuantumScript` contains a QSVT operation.
* Removed trailing whitespace from the changelog (done automatically by VS Code).

**Benefits:**

* The QSVT workflow in the [QSVT in Practice](https://pennylane.ai/qml/demos/tutorial_apply_qsvt) demo can
be JIT-compiled and optimized by JAX (as long as `JAX_DYNAMIC_SHAPES` is explicitly set).

**Possible Drawbacks:**

None.

**Related GitHub Issues:**

None.
  • Loading branch information
Mandrenkov authored Feb 21, 2024
1 parent b97f7a5 commit a8be574
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 16 deletions.
31 changes: 17 additions & 14 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<h4>Easy to import circuits 💾</h4>

* An error message provides clearer instructions for installing PennyLane-Qiskit if the `qml.from_qiskit`
* An error message provides clearer instructions for installing PennyLane-Qiskit if the `qml.from_qiskit`
function fails because the Qiskit converter is missing.
[(#5218)](https://github.com/PennyLaneAI/pennylane/pull/5218)

Expand Down Expand Up @@ -75,13 +75,13 @@
[(#5003)](https://github.com/PennyLaneAI/pennylane/pull/5003)
[(#5017)](https://github.com/PennyLaneAI/pennylane/pull/5017)
[(#5027)](https://github.com/PennyLaneAI/pennylane/pull/5027)

```pycon
>>> op = X(0) + Y(0)
>>> type(op.pauli_rep)
pennylane.pauli.pauli_arithmetic.PauliSentence
```

The `PauliWord` and `PauliSentence` objects in the
[pauli](https://docs.pennylane.ai/en/stable/code/qml_pauli.html#classes) module provide an
efficient representation and can be combined using basic arithmetic like addition, products, and
Expand All @@ -99,7 +99,7 @@
>>> 0.5 * (X(0) + Y(1))
0.5 * (X(0) + Y(1))
```

Sums with many terms are broken up into multiple lines, but can still be copied back as valid
code:

Expand All @@ -111,13 +111,13 @@
+ 0.8 * (X(2) @ X(3))
)
```

* The `Sum` and `Prod` classes have been updated to reach feature parity with `Hamiltonian`
and `Tensor`, respectively. This includes support for grouping via the `pauli` module:
[(#5070)](https://github.com/PennyLaneAI/pennylane/pull/5070)
[(#5132)](https://github.com/PennyLaneAI/pennylane/pull/5132)
[(#5133)](https://github.com/PennyLaneAI/pennylane/pull/5133)

```pycon
>>> obs = [X(0) @ Y(1), Z(0), Y(0) @ Z(1), Y(1)]
>>> qml.pauli.group_observables(obs)
Expand Down Expand Up @@ -167,7 +167,7 @@
[0, 0, 0, 1, 0],
[1, 0, 0, 1, 1]])
```

The `default.clifford` device also supports the `PauliError`, `DepolarizingChannel`, `BitFlip` and
`PhaseFlip`
[noise channels](https://docs.pennylane.ai/en/latest/introduction/operations.html#noisy-channels)
Expand Down Expand Up @@ -244,12 +244,12 @@

* You can now multiply `PauliWord` and `PauliSentence` instances by scalars, e.g.
`0.5 * PauliWord({0:"X"})` or `0.5 * PauliSentence({PauliWord({0:"X"}): 1.})`.

* You can now intuitively add together
`PauliWord` and `PauliSentence` as well as scalars, which are treated implicitly as identities.
For example, `ps1 + pw1 + 1.` for some Pauli word `pw1 = PauliWord({0: "X", 1: "Y"})` and Pauli
sentence `ps1 = PauliSentence({pw1: 3.})`.

* You can now subtract `PauliWord` and `PauliSentence` instances, as well as scalars, from each
other. For example `ps1 - pw1 - 1`.

Expand Down Expand Up @@ -320,6 +320,9 @@

<h4>Other improvements</h4>

* The `QSVT` operation now determines its `data` from the block encoding and projector operator data.
[(#5226)](https://github.com/PennyLaneAI/pennylane/pull/5226)

* Faster `qml.probs` measurements due to an optimization in `_samples_to_counts`.
[(#5145)](https://github.com/PennyLaneAI/pennylane/pull/5145)

Expand Down Expand Up @@ -369,11 +372,11 @@
when converting a `QuantumCircuit` using `qml.from_qiskit`.
[(#5168)](https://github.com/PennyLaneAI/pennylane/pull/5168)

* Added new error tracking and propagation functionality.
* Added new error tracking and propagation functionality.
[(#5115)](https://github.com/PennyLaneAI/pennylane/pull/5115)
[(#5121)](https://github.com/PennyLaneAI/pennylane/pull/5121)

* Replacing `map_batch_transform` in the source code with the method `_batch_transform`
* Replacing `map_batch_transform` in the source code with the method `_batch_transform`
implemented in `TransformDispatcher`.
[(#5212)](https://github.com/PennyLaneAI/pennylane/pull/5212)

Expand All @@ -389,15 +392,15 @@
To allow for packages to register multiple compilers with PennyLane,
the `entry_points` convention under the designated group name
`pennylane.compilers` has been modified.

Previously, compilers would register `qjit` (JIT decorator),
`ops` (compiler-specific operations), and `context` (for tracing and
program capture).

Now, compilers must register `compiler_name.qjit`, `compiler_name.ops`,
and `compiler_name.context`, where `compiler_name` is replaced
by the name of the provided compiler.

For more information, please see the
[documentation on adding compilers](https://docs.pennylane.ai/en/stable/code/qml_compiler.html#adding-a-compiler).

Expand Down
28 changes: 26 additions & 2 deletions pennylane/templates/subroutines/qsvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,34 @@ def __init__(self, UA, projectors, id=None):
total_wires = ua_wires.union(proj_wires)
super().__init__(wires=total_wires, id=id)

@property
def data(self):
r"""Flattened list of operator data in this QSVT operation.
This ensures that the backend of a ``QuantumScript`` which contains a
``QSVT`` operation can be inferred with respect to the types of the
``QSVT`` block encoding and projector-controlled phase shift data.
"""
return tuple(datum for op in self._operators for datum in op.data)

@data.setter
def data(self, new_data):
# We need to check if ``new_data`` is empty because ``Operator.__init__()`` will attempt to
# assign the QSVT data to an empty tuple (since no positional arguments are provided).
if new_data:
for op in self._operators:
if op.num_params > 0:
op.data = new_data[: op.num_params]
new_data = new_data[op.num_params :]

@property
def _operators(self) -> list[qml.operation.Operator]:
"""Flattened list of operators that compose this QSVT operation."""
return [self._hyperparameters["UA"], *self._hyperparameters["projectors"]]

@staticmethod
def compute_decomposition(
UA, projectors, **hyperparameters
*_data, UA, projectors, **_kwargs
): # pylint: disable=arguments-differ
r"""Representation of the operator as a product of other operators.
Expand Down Expand Up @@ -317,7 +342,6 @@ def compute_decomposition(
UA (Operator): the block encoding circuit, specified as a :class:`~.Operator`
projectors (list[Operator]): a list of projector-controlled phase
shift circuits that implement the desired polynomial
wires (Iterable): wires that the template acts on
Returns:
list[.Operator]: decomposition of the operator
Expand Down
35 changes: 35 additions & 0 deletions tests/templates/test_subroutines/test_qsvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,34 @@ def test_QSVT_jax(self, input_matrix, angles, wires):
assert np.allclose(qml.matrix(op), default_matrix)
assert qml.math.get_interface(qml.matrix(op)) == "jax"

@pytest.mark.jax
@pytest.mark.parametrize(
("input_matrix", "angles", "wires"),
[([[0.1, 0.2], [0.3, 0.4]], [0.1, 0.2], [0, 1])],
)
def test_QSVT_jax_with_identity(self, input_matrix, angles, wires):
"""Test that applying the identity operation before the qsvt function in
a JIT context does not affect the matrix for jax.
The main purpose of this test is to ensure that the types of the block
encoding and projector-controlled phase shift data in a QSVT instance
are taken into account when inferring the backend of a QuantumScript.
"""
import jax

def identity_and_qsvt(angles):
qml.Identity(wires=wires[0])
qml.qsvt(input_matrix, angles, wires)

@jax.jit
def get_matrix_with_identity(angles):
return qml.matrix(identity_and_qsvt, wire_order=wires)(angles)

matrix = qml.matrix(qml.qsvt(input_matrix, angles, wires))
matrix_with_identity = get_matrix_with_identity(angles)

assert np.allclose(matrix, matrix_with_identity)

@pytest.mark.tf
@pytest.mark.parametrize(
("input_matrix", "angles", "wires"),
Expand Down Expand Up @@ -330,6 +358,13 @@ def test_label(self):
assert op.label() == "QSVT"
assert op.label(base_label="custom_label") == "custom_label"

def test_data(self):
"""Test that the data property gets and sets the correct values"""
op = qml.QSVT(qml.RX(1, wires=0), [qml.RY(2, wires=0), qml.RZ(3, wires=0)])
assert op.data == (1, 2, 3)
op.data = [4, 5, 6]
assert op.data == (4, 5, 6)


class Testqsvt:
"""Test the qml.qsvt function."""
Expand Down

0 comments on commit a8be574

Please sign in to comment.