Skip to content

Commit

Permalink
Add tests for exponential extrapolation (#953)
Browse files Browse the repository at this point in the history
**Context:** The last PR
([first](#806),
[second](PennyLaneAI/pennylane#5972)) to add
exponential extrapolation capabilities to catalyst.

**Description of the Change:** This PR ensures that
`pennylane.transforms.exponential_extrapolate` works inside catalyst's
`mitigate_with_zne` function.

**Benefits:** ZNE has increased functionality.

**Possible Drawbacks:** As part of the testing, I've removed 0.1 and 0.2
from the list of values being used to create different circuits. This is
because exponential extrapolation is not necessarily as stable as
polynomial fitting. For example, when extrapolating near constant values
exponential extrapolation can struggle due to the fact that a linear fit
is first performed to understand the "direction" (or `sign`) of the
exponential. If the slope is positive, the data is flipped to fit
something that looks like $A\mathrm{e}^{-Bx} + C$ where
$B\in\mathbb{R}_{> 0}$.

An example of this happening can be seen here.
```py
>>> exponential_extrapolate([1,2,3], [0.3894, 0.3894183, 0.38941])
-1.0000000000000059e-06
>>> richardson_extrapolate([1,2,3], [0.3894, 0.3894183, 0.38941])
0.3893551000000072
```

If we want to ensure to continue testing values 0.1 and 0.2 for
polynomial extrapolation, we can create separate tests. Let me know what
would be best.

**Related GitHub Issues:**

fixes #754
  • Loading branch information
natestemen authored Jul 30, 2024
1 parent 4f2b9f1 commit 0055f9e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 21 deletions.
27 changes: 21 additions & 6 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,20 @@

```

* Exponential extrapolation is now a supported method of extrapolation when using `mitigate_with_zne`.
[(#953)](https://github.com/PennyLaneAI/catalyst/pull/953)

This new functionality fits the data from noise-scaled circuits with an exponential function,
and returns the zero-noise value. This functionality is available through the pennylane module
as follows
```py
from pennylane.transforms import exponential_extrapolate

catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=exponential_extrapolate
)
```

<h3>Improvements</h3>

* Catalyst is now compatible with Enzyme `v0.0.130`
Expand Down Expand Up @@ -182,6 +196,12 @@
* Support for TOML files in Schema 1 has been disabled.
[(#960)](https://github.com/PennyLaneAI/catalyst/pull/960)

* The `mitigate_with_zne` function no longer accepts a `degree` parameter for polynomial fitting
and instead accepts a callable to perform extrapolation. Any qjit-compatible extrapolation
function is valid. Keyword arguments can be passed to this function using the
`extrapolate_kwargs` keyword argument in `mitigate_with_zne`.
[(#806)](https://github.com/PennyLaneAI/catalyst/pull/806)

<h3>Bug fixes</h3>

* Static arguments can now be passed through a QNode when specified
Expand Down Expand Up @@ -288,6 +308,7 @@ Mehrdad Malekmohammadi,
Romain Moyard,
Erick Ochoa,
Mudit Pandey,
nate stemen,
Raul Torres,
Tzung-Han Juang,
Paul Haochen Wang,
Expand Down Expand Up @@ -803,12 +824,6 @@ Paul Haochen Wang,

<h3>Breaking changes</h3>

* The `mitigate_with_zne` function no longer accepts a `degree` parameter for polynomial fitting
and instead accepts a callable to perform extrapolation. Any qjit-compatible extrapolation
function is valid. Keyword arguments can be passed to this function using the
`extrapolate_kwargs` keyword argument in `mitigate_with_zne`.
[(#806)](https://github.com/PennyLaneAI/catalyst/pull/806)

* Binary distributions for Linux are now based on `manylinux_2_28` instead of `manylinux_2014`.
As a result, Catalyst will only be compatible on systems with `glibc` versions `2.28` and above
(e.g., Ubuntu 20.04 and above).
Expand Down
74 changes: 59 additions & 15 deletions frontend/test/pytest/test_mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,26 @@
import numpy as np
import pennylane as qml
import pytest
from pennylane.transforms import exponential_extrapolate

import catalyst
from catalyst.api_extensions.error_mitigation import polynomial_extrapolation

quadratic_extrapolation = polynomial_extrapolation(2)


def skip_if_exponential_extrapolation_unstable(circuit_param, extrapolation_func):
"""skip test if exponential extrapolation will be unstable"""
if circuit_param < 0.3 and extrapolation_func == exponential_extrapolate:
pytest.skip("Exponential extrapolation unstable in this region.")


@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
def test_single_measurement(params):
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_single_measurement(params, extrapolation):
"""Test that without noise the same results are returned for single measurements."""
skip_if_exponential_extrapolation_unstable(params, extrapolation)

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(device=dev)
Expand All @@ -42,15 +52,18 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

assert np.allclose(mitigated_qnode(params), circuit(params))


@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
def test_multiple_measurements(params):
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_multiple_measurements(params, extrapolation):
"""Test that without noise the same results are returned for multiple measurements"""
skip_if_exponential_extrapolation_unstable(params, extrapolation)

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(device=dev)
Expand All @@ -65,7 +78,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

assert np.allclose(mitigated_qnode(params), circuit(params))
Expand Down Expand Up @@ -121,7 +134,8 @@ def mitigated_function(args):
mitigated_function(0.1)


def test_dtype_error():
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_dtype_error(extrapolation):
"""Test that an error is raised when multiple results do not have the same dtype."""
dev = qml.device("lightning.qubit", wires=2)

Expand All @@ -137,7 +151,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

with pytest.raises(
Expand All @@ -146,7 +160,8 @@ def mitigated_qnode(args):
mitigated_qnode(0.1)


def test_dtype_not_float_error():
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_dtype_not_float_error(extrapolation):
"""Test that an error is raised when results are not float."""
dev = qml.device("lightning.qubit", wires=2)

Expand All @@ -162,7 +177,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

with pytest.raises(
Expand All @@ -171,7 +186,8 @@ def mitigated_qnode(args):
mitigated_qnode(0.1)


def test_shape_error():
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_shape_error(extrapolation):
"""Test that an error is raised when results have shape."""
dev = qml.device("lightning.qubit", wires=2)

Expand All @@ -187,7 +203,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

with pytest.raises(
Expand Down Expand Up @@ -229,8 +245,11 @@ def mitigated_qnode():


@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
def test_zne_usage_patterns(params):
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_zne_usage_patterns(params, extrapolation):
"""Test usage patterns of catalyst.zne."""
skip_if_exponential_extrapolation_unstable(params, extrapolation)

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(device=dev)
Expand All @@ -245,13 +264,13 @@ def fn(x):
@catalyst.qjit
def mitigated_qnode_fn_as_argument(args):
return catalyst.mitigate_with_zne(
fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

@catalyst.qjit
def mitigated_qnode_partial(args):
return catalyst.mitigate_with_zne(
scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(fn)(args)

assert np.allclose(mitigated_qnode_fn_as_argument(params), fn(params))
Expand All @@ -271,13 +290,13 @@ def circuit():
qml.Hadamard(wires=1)
return qml.expval(qml.PauliY(wires=0))

def jax_extrap(scale_factors, results):
def jax_extrapolation(scale_factors, results):
return jax.numpy.polyfit(scale_factors, results, 2)[-1]

@catalyst.qjit
def mitigated_qnode():
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=jax_extrap
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=jax_extrapolation
)()

assert np.allclose(mitigated_qnode(), circuit())
Expand Down Expand Up @@ -308,5 +327,30 @@ def mitigated_qnode():
assert np.allclose(mitigated_qnode(), circuit())


def test_exponential_extrapolation_with_kwargs():
"""test mitigate_with_zne with keyword arguments for exponential extrapolation function"""
dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(device=dev)
def circuit():
qml.Hadamard(wires=0)
qml.RZ(0.1, wires=0)
qml.RZ(0.2, wires=0)
qml.CNOT(wires=[1, 0])
qml.Hadamard(wires=1)
return qml.expval(qml.PauliY(wires=0))

@catalyst.qjit
def mitigated_qnode():
return catalyst.mitigate_with_zne(
circuit,
scale_factors=jax.numpy.array([1, 2, 3]),
extrapolate=qml.transforms.exponential_extrapolate,
extrapolate_kwargs={"asymptote": 3},
)()

assert np.allclose(mitigated_qnode(), circuit())


if __name__ == "__main__":
pytest.main(["-x", __file__])

0 comments on commit 0055f9e

Please sign in to comment.