Skip to content

Commit

Permalink
Update qml.math.norm for dispatching autograd; stop rounding when…
Browse files Browse the repository at this point in the history
… postselecting (#4766)

**Context:**
Differentiation of `qml.math.norm` does not work for L2 norm. This was
causing incorrect gradients with autograd. Moreover, due to the rounding
function, the gradient of `QNode`s being postselected is incorrect, so
that needs to be removed.

**Description of the Change:**
* Add private function to compute the norm for `autograd` interface with
`qml.math.norm` when `ord=None` and `axis=None`. Otherwise, we dispatch
to `scipy.linalg.norm` as we did before.
* Stop rounding the norm of the state when renormalizing the state
vector after postselection. Instead, we check if the norm is close to 0
and set it to exactly 0 if it is. This condition is only checked is the
state vector is not abstract.
* Added warning to `qml.measure` docs about jitting with postselection
on zero probability states.

**Benefits:**
* `qml.math.norm` is differentiable for all interfaces for `ord=None`
and `axis=None`.
* Postselection doesn't lead to incorrect gradients.

**Possible Drawbacks:**
Postselection with jitting can lead to incorrect results and errors if
postselecting on a state with zero probability. However, this is an edge
case that is not causing problems frequently.

**Related GitHub Issues:**
#4867

---------

Co-authored-by: Christina Lee <[email protected]>
  • Loading branch information
mudit2812 and albi3ro authored Nov 29, 2023
1 parent c7cda37 commit 669c86c
Show file tree
Hide file tree
Showing 10 changed files with 297 additions and 16 deletions.
7 changes: 5 additions & 2 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def _postselection_postprocess(state, is_state_batched, shots):
# equal to zero so that the state can become invalid. This way, execution can continue, and
# bad postselection gives results that are invalid rather than results that look valid but
# are incorrect.
norm = qml.math.floor(qml.math.real(qml.math.norm(state)) * 1e15) * 1e-15
norm = qml.math.norm(state)

if not qml.math.is_abstract(state) and qml.math.allclose(norm, 0.0):
norm = 0.0

if shots:
# Clip the number of shots using a binomial distribution using the probability of
Expand All @@ -89,7 +92,7 @@ def _postselection_postprocess(state, is_state_batched, shots):
# valid samples
shots = _FlexShots(postselected_shots)

state = state / qml.math.cast_like(norm, state)
state = state / norm
return state, shots


Expand Down
13 changes: 13 additions & 0 deletions pennylane/math/multi_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,12 +859,25 @@ def norm(tensor, like=None, **kwargs):
axis_val = kwargs.pop("axis")
kwargs["dim"] = axis_val

elif (
like == "autograd" and kwargs.get("ord", None) is None and kwargs.get("axis", None) is None
):
norm = _flat_autograd_norm

else:
from scipy.linalg import norm

return norm(tensor, **kwargs)


def _flat_autograd_norm(tensor, **kwargs): # pylint: disable=unused-argument
"""Helper function for computing the norm of an autograd tensor when the order or axes are not
specified. This is used for differentiability."""
x = np.ravel(tensor)
sq_norm = np.dot(x, np.conj(x))
return np.real(np.sqrt(sq_norm))


@multi_dispatch(argnum=[1])
def gammainc(m, t, like=None):
r"""Return the lower incomplete Gamma function.
Expand Down
20 changes: 13 additions & 7 deletions pennylane/measurements/mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,21 @@ def func(x):
.. warning::
All measurements are supported when using postselection. However, postselection on a zero probability
state can cause some measurements to break.
state can cause some measurements to break:
With finite shots, one must be careful when measuring ``qml.probs`` or ``qml.counts``, as these
measurements will raise errors if there are no valid samples after postselection. This will occur
with postselection states that have zero or close to zero probability.
* With finite shots, one must be careful when measuring ``qml.probs`` or ``qml.counts``, as these
measurements will raise errors if there are no valid samples after postselection. This will occur
with postselection states that have zero or close to zero probability.
* With analytic execution, ``qml.mutual_info`` will raise errors when using any interfaces except
``jax``, and ``qml.vn_entropy`` will raise an error with the ``tensorflow`` interface when the
postselection state has zero probability.
* When using JIT, ``QNode``'s may have unexpected behaviour when postselection on a zero
probability state is performed. Due to floating point precision, the zero probability may not be
detected, thus letting execution continue as normal without ``NaN`` or ``Inf`` values or empty
samples, leading to unexpected or incorrect results.
With analytic execution, ``qml.mutual_info`` will raise errors when using any interfaces except
``jax``, and ``qml.vn_entropy`` will raise an error with the ``tensorflow`` interface when the
postselection state has zero probability.
"""

wire = Wires(wires)
Expand Down
47 changes: 40 additions & 7 deletions tests/devices/default_qubit/test_default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1830,8 +1830,8 @@ def test_postselection_invalid_analytic(
):
pytest.skip("Unsupported measurements and interfaces.")

if use_jit and interface != "jax":
pytest.skip("Can't jit with non-jax interfaces.")
if use_jit:
pytest.skip("Jitting tested in different test.")

# Wires are specified so that the shape for measurements can be determined correctly
dev = qml.device("default.qubit", wires=2)
Expand All @@ -1843,22 +1843,55 @@ def circ():
qml.measure(0, postselect=0)
return qml.apply(mp)

if use_jit:
import jax

circ = jax.jit(circ)

res = circ()
if interface == "autograd":
assert qml.math.get_interface(res) == autograd_interface
else:
assert qml.math.get_interface(res) == interface

assert qml.math.shape(res) == mp.shape(dev, qml.measurements.Shots(None))
if is_nan:
assert qml.math.all(qml.math.isnan(res))
else:
assert qml.math.allclose(res, 0.0)

@pytest.mark.parametrize(
"mp, expected",
[
(qml.expval(qml.PauliZ(0)), 1.0),
(qml.var(qml.PauliZ(0)), 0.0),
(qml.probs(wires=[0, 1]), [1.0, 0.0, 0.0, 0.0]),
(qml.density_matrix(wires=0), [[1.0, 0.0], [0.0, 0.0]]),
(qml.purity(0), 1.0),
(qml.vn_entropy(0), 0.0),
(qml.mutual_info(0, 1), 0.0),
],
)
def test_postselection_invalid_analytic_jit(self, mp, expected, interface, use_jit):
"""Test that the results of a qnode give the postselected results even when the
probability of the postselected state is zero when jitting."""
if interface != "jax" or not use_jit:
pytest.skip("Test is only for jitting.")

import jax

# Wires are specified so that the shape for measurements can be determined correctly
dev = qml.device("default.qubit", wires=2)

@jax.jit
@qml.qnode(dev, interface=interface)
def circ():
qml.RX(np.pi, 0)
qml.CNOT([0, 1])
qml.measure(0, postselect=0)
return qml.apply(mp)

res = circ()

assert qml.math.get_interface(res) == "jax"
assert qml.math.shape(res) == mp.shape(dev, qml.measurements.Shots(None))
assert qml.math.allclose(res, expected)

@pytest.mark.parametrize(
"mp, expected_shape",
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1404,6 +1404,49 @@ def circuit(x, y):

assert np.allclose(jac, expected, atol=tol, rtol=0)

def test_postselection_differentiation(
self, interface, dev, diff_method, grad_on_execution, device_vjp
):
"""Test that when postselecting with default.qubit, differentiation works correctly."""

if diff_method in ["adjoint", "spsa", "hadamard"]:
pytest.skip("Diff method does not support postselection.")

@qml.qnode(
dev,
diff_method=diff_method,
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
)
def circuit(phi, theta):
qml.RX(phi, wires=0)
qml.CNOT([0, 1])
qml.measure(wires=0, postselect=1)
qml.RX(theta, wires=1)
return qml.expval(qml.PauliZ(1))

@qml.qnode(
dev,
diff_method=diff_method,
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
)
def expected_circuit(theta):
qml.PauliX(1)
qml.RX(theta, wires=1)
return qml.expval(qml.PauliZ(1))

phi = np.array(1.23, requires_grad=True)
theta = np.array(4.56, requires_grad=True)

assert np.allclose(circuit(phi, theta), expected_circuit(theta))

gradient = qml.grad(circuit)(phi, theta)
exp_theta_grad = qml.grad(expected_circuit)(theta)
assert np.allclose(gradient, [0.0, exp_theta_grad])


@pytest.mark.parametrize(
"dev,diff_method,grad_on_execution, device_vjp", qubit_device_and_diff_method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,39 @@ def cost(weights):

assert len(res) == 2

def test_postselection_differentiation(self, dev, diff_method, grad_on_execution, interface):
"""Test that when postselecting with default.qubit, differentiation works correctly."""

if diff_method in ["adjoint", "spsa", "hadamard"]:
pytest.skip("Diff method does not support postselection.")

@qml.qnode(
dev, diff_method=diff_method, interface=interface, grad_on_execution=grad_on_execution
)
def circuit(phi, theta):
qml.RX(phi, wires=0)
qml.CNOT([0, 1])
qml.measure(wires=0, postselect=1)
qml.RX(theta, wires=1)
return qml.expval(qml.PauliZ(1))

@qml.qnode(
dev, diff_method=diff_method, interface=interface, grad_on_execution=grad_on_execution
)
def expected_circuit(theta):
qml.PauliX(1)
qml.RX(theta, wires=1)
return qml.expval(qml.PauliZ(1))

phi = jax.numpy.array(1.23)
theta = jax.numpy.array(4.56)

assert np.allclose(jax.jit(circuit)(phi, theta), jax.jit(expected_circuit)(theta))

gradient = jax.jit(jax.grad(circuit, argnums=[0, 1]))(phi, theta)
exp_theta_grad = jax.jit(jax.grad(expected_circuit))(theta)
assert np.allclose(gradient, [0.0, exp_theta_grad])


@pytest.mark.parametrize(
"interface,dev,diff_method,grad_on_execution", interface_and_qubit_device_and_diff_method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,39 @@ def cost(weights):

assert len(res) == 2

def test_postselection_differentiation(self, dev, diff_method, grad_on_execution):
"""Test that when postselecting with default.qubit, differentiation works correctly."""

if diff_method in ["adjoint", "spsa", "hadamard"]:
pytest.skip("Diff method does not support postselection.")

@qml.qnode(
dev, diff_method=diff_method, interface="jax", grad_on_execution=grad_on_execution
)
def circuit(phi, theta):
qml.RX(phi, wires=0)
qml.CNOT([0, 1])
qml.measure(wires=0, postselect=1)
qml.RX(theta, wires=1)
return qml.expval(qml.PauliZ(1))

@qml.qnode(
dev, diff_method=diff_method, interface="jax", grad_on_execution=grad_on_execution
)
def expected_circuit(theta):
qml.PauliX(1)
qml.RX(theta, wires=1)
return qml.expval(qml.PauliZ(1))

phi = jax.numpy.array(1.23)
theta = jax.numpy.array(4.56)

assert np.allclose(circuit(phi, theta), expected_circuit(theta))

gradient = jax.grad(circuit, argnums=[0, 1])(phi, theta)
exp_theta_grad = jax.grad(expected_circuit)(theta)
assert np.allclose(gradient, [0.0, exp_theta_grad])


@pytest.mark.parametrize(
"interface,dev,diff_method,grad_on_execution", interface_and_device_and_diff_method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,45 @@ def circuit(weights):
]
assert np.allclose(grad, expected, atol=tol, rtol=0)

def test_postselection_differentiation(self, dev, diff_method, grad_on_execution, interface):
"""Test that when postselecting with default.qubit, differentiation works correctly."""

if diff_method in ["adjoint", "spsa", "hadamard"]:
pytest.skip("Diff method does not support postselection.")

@qml.qnode(
dev, diff_method=diff_method, interface=interface, grad_on_execution=grad_on_execution
)
def circuit(phi, theta):
qml.RX(phi, wires=0)
qml.CNOT([0, 1])
qml.measure(wires=0, postselect=1)
qml.RX(theta, wires=1)
return qml.expval(qml.PauliZ(1))

@qml.qnode(
dev, diff_method=diff_method, interface=interface, grad_on_execution=grad_on_execution
)
def expected_circuit(theta):
qml.PauliX(1)
qml.RX(theta, wires=1)
return qml.expval(qml.PauliZ(1))

phi = tf.Variable(1.23)
theta = tf.Variable(4.56)

assert np.allclose(circuit(phi, theta), expected_circuit(theta))

with tf.GradientTape() as res_tape:
res = circuit(phi, theta)
gradient = res_tape.gradient(res, [phi, theta])

with tf.GradientTape() as expected_tape:
expected = expected_circuit(theta)
exp_theta_grad = expected_tape.gradient(expected, theta)

assert np.allclose(gradient, [0.0, exp_theta_grad])


@pytest.mark.parametrize(
"interface,dev,diff_method,grad_on_execution", interface_and_qubit_device_and_diff_method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,49 @@ def circuit(x, y):
)
assert np.allclose(weights.grad.detach(), expected, atol=tol, rtol=0)

def test_postselection_differentiation(
self, interface, dev, diff_method, grad_on_execution, device_vjp
):
"""Test that when postselecting with default.qubit, differentiation works correctly."""

if diff_method in ["adjoint", "spsa", "hadamard"]:
pytest.skip("Diff method does not support postselection.")

@qml.qnode(
dev,
diff_method=diff_method,
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
)
def circuit(phi, theta):
qml.RX(phi, wires=0)
qml.CNOT([0, 1])
qml.measure(wires=0, postselect=1)
qml.RX(theta, wires=1)
return qml.expval(qml.PauliZ(1))

@qml.qnode(
dev,
diff_method=diff_method,
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
)
def expected_circuit(theta):
qml.PauliX(1)
qml.RX(theta, wires=1)
return qml.expval(qml.PauliZ(1))

phi = torch.tensor(1.23, requires_grad=True)
theta = torch.tensor(4.56, requires_grad=True)

assert qml.math.allclose(circuit(phi, theta), expected_circuit(theta))

gradient = torch.autograd.grad(circuit(phi, theta), [phi, theta])
exp_theta_grad = torch.autograd.grad(expected_circuit(theta), theta)[0]
assert qml.math.allclose(gradient, [0.0, exp_theta_grad])


@pytest.mark.parametrize(
"dev,diff_method,grad_on_execution, device_vjp", qubit_device_and_diff_method
Expand Down
Loading

0 comments on commit 669c86c

Please sign in to comment.