From 669c86c28e20bd37df62cc7c8c498288baf27662 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 29 Nov 2023 15:29:04 -0500 Subject: [PATCH] Update `qml.math.norm` for dispatching `autograd`; stop rounding when postselecting (#4766) **Context:** Differentiation of `qml.math.norm` does not work for L2 norm. This was causing incorrect gradients with autograd. Moreover, due to the rounding function, the gradient of `QNode`s being postselected is incorrect, so that needs to be removed. **Description of the Change:** * Add private function to compute the norm for `autograd` interface with `qml.math.norm` when `ord=None` and `axis=None`. Otherwise, we dispatch to `scipy.linalg.norm` as we did before. * Stop rounding the norm of the state when renormalizing the state vector after postselection. Instead, we check if the norm is close to 0 and set it to exactly 0 if it is. This condition is only checked if the state vector is not abstract. * Added warning to `qml.measure` docs about jitting with postselection on zero probability states. **Benefits:** * `qml.math.norm` is differentiable for all interfaces for `ord=None` and `axis=None`. * Postselection doesn't lead to incorrect gradients. **Possible Drawbacks:** Postselection with jitting can lead to incorrect results and errors if postselecting on a state with zero probability. However, this is an edge case that is not causing problems frequently. 
**Related GitHub Issues:** #4867 --------- Co-authored-by: Christina Lee --- pennylane/devices/qubit/simulate.py | 7 ++- pennylane/math/multi_dispatch.py | 13 +++++ pennylane/measurements/mid_measure.py | 20 +++++--- .../default_qubit/test_default_qubit.py | 47 ++++++++++++++++--- .../test_autograd_qnode_default_qubit_2.py | 43 +++++++++++++++++ .../test_jax_jit_qnode_default_qubit_2.py | 33 +++++++++++++ .../test_jax_qnode_default_qubit_2.py | 33 +++++++++++++ .../test_tensorflow_qnode_default_qubit_2.py | 39 +++++++++++++++ .../test_torch_qnode_default_qubit_2.py | 43 +++++++++++++++++ tests/math/test_multi_dispatch.py | 35 ++++++++++++++ 10 files changed, 297 insertions(+), 16 deletions(-) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index 670b2a77a96..b16fe178ba1 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -74,7 +74,10 @@ def _postselection_postprocess(state, is_state_batched, shots): # equal to zero so that the state can become invalid. This way, execution can continue, and # bad postselection gives results that are invalid rather than results that look valid but # are incorrect. 
- norm = qml.math.floor(qml.math.real(qml.math.norm(state)) * 1e15) * 1e-15 + norm = qml.math.norm(state) + + if not qml.math.is_abstract(state) and qml.math.allclose(norm, 0.0): + norm = 0.0 if shots: # Clip the number of shots using a binomial distribution using the probability of @@ -89,7 +92,7 @@ def _postselection_postprocess(state, is_state_batched, shots): # valid samples shots = _FlexShots(postselected_shots) - state = state / qml.math.cast_like(norm, state) + state = state / norm return state, shots diff --git a/pennylane/math/multi_dispatch.py b/pennylane/math/multi_dispatch.py index 46354429843..bda3b308ef9 100644 --- a/pennylane/math/multi_dispatch.py +++ b/pennylane/math/multi_dispatch.py @@ -859,12 +859,25 @@ def norm(tensor, like=None, **kwargs): axis_val = kwargs.pop("axis") kwargs["dim"] = axis_val + elif ( + like == "autograd" and kwargs.get("ord", None) is None and kwargs.get("axis", None) is None + ): + norm = _flat_autograd_norm + else: from scipy.linalg import norm return norm(tensor, **kwargs) +def _flat_autograd_norm(tensor, **kwargs): # pylint: disable=unused-argument + """Helper function for computing the norm of an autograd tensor when the order or axes are not + specified. This is used for differentiability.""" + x = np.ravel(tensor) + sq_norm = np.dot(x, np.conj(x)) + return np.real(np.sqrt(sq_norm)) + + @multi_dispatch(argnum=[1]) def gammainc(m, t, like=None): r"""Return the lower incomplete Gamma function. diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py index 9becfac983d..f272bb583ce 100644 --- a/pennylane/measurements/mid_measure.py +++ b/pennylane/measurements/mid_measure.py @@ -164,15 +164,21 @@ def func(x): .. warning:: All measurements are supported when using postselection. However, postselection on a zero probability - state can cause some measurements to break. 
+ state can cause some measurements to break: - With finite shots, one must be careful when measuring ``qml.probs`` or ``qml.counts``, as these - measurements will raise errors if there are no valid samples after postselection. This will occur - with postselection states that have zero or close to zero probability. + * With finite shots, one must be careful when measuring ``qml.probs`` or ``qml.counts``, as these + measurements will raise errors if there are no valid samples after postselection. This will occur + with postselection states that have zero or close to zero probability. + + * With analytic execution, ``qml.mutual_info`` will raise errors when using any interfaces except + ``jax``, and ``qml.vn_entropy`` will raise an error with the ``tensorflow`` interface when the + postselection state has zero probability. + + * When using JIT, ``QNode``'s may have unexpected behaviour when postselection on a zero + probability state is performed. Due to floating point precision, the zero probability may not be + detected, thus letting execution continue as normal without ``NaN`` or ``Inf`` values or empty + samples, leading to unexpected or incorrect results. - With analytic execution, ``qml.mutual_info`` will raise errors when using any interfaces except - ``jax``, and ``qml.vn_entropy`` will raise an error with the ``tensorflow`` interface when the - postselection state has zero probability. 
""" wire = Wires(wires) diff --git a/tests/devices/default_qubit/test_default_qubit.py b/tests/devices/default_qubit/test_default_qubit.py index 3d56d299ef0..2722b5b9e9f 100644 --- a/tests/devices/default_qubit/test_default_qubit.py +++ b/tests/devices/default_qubit/test_default_qubit.py @@ -1830,8 +1830,8 @@ def test_postselection_invalid_analytic( ): pytest.skip("Unsupported measurements and interfaces.") - if use_jit and interface != "jax": - pytest.skip("Can't jit with non-jax interfaces.") + if use_jit: + pytest.skip("Jitting tested in different test.") # Wires are specified so that the shape for measurements can be determined correctly dev = qml.device("default.qubit", wires=2) @@ -1843,22 +1843,55 @@ def circ(): qml.measure(0, postselect=0) return qml.apply(mp) - if use_jit: - import jax - - circ = jax.jit(circ) - res = circ() if interface == "autograd": assert qml.math.get_interface(res) == autograd_interface else: assert qml.math.get_interface(res) == interface + assert qml.math.shape(res) == mp.shape(dev, qml.measurements.Shots(None)) if is_nan: assert qml.math.all(qml.math.isnan(res)) else: assert qml.math.allclose(res, 0.0) + @pytest.mark.parametrize( + "mp, expected", + [ + (qml.expval(qml.PauliZ(0)), 1.0), + (qml.var(qml.PauliZ(0)), 0.0), + (qml.probs(wires=[0, 1]), [1.0, 0.0, 0.0, 0.0]), + (qml.density_matrix(wires=0), [[1.0, 0.0], [0.0, 0.0]]), + (qml.purity(0), 1.0), + (qml.vn_entropy(0), 0.0), + (qml.mutual_info(0, 1), 0.0), + ], + ) + def test_postselection_invalid_analytic_jit(self, mp, expected, interface, use_jit): + """Test that the results of a qnode give the postselected results even when the + probability of the postselected state is zero when jitting.""" + if interface != "jax" or not use_jit: + pytest.skip("Test is only for jitting.") + + import jax + + # Wires are specified so that the shape for measurements can be determined correctly + dev = qml.device("default.qubit", wires=2) + + @jax.jit + @qml.qnode(dev, interface=interface) + def 
circ(): + qml.RX(np.pi, 0) + qml.CNOT([0, 1]) + qml.measure(0, postselect=0) + return qml.apply(mp) + + res = circ() + + assert qml.math.get_interface(res) == "jax" + assert qml.math.shape(res) == mp.shape(dev, qml.measurements.Shots(None)) + assert qml.math.allclose(res, expected) + @pytest.mark.parametrize( "mp, expected_shape", [ diff --git a/tests/interfaces/default_qubit_2_integration/test_autograd_qnode_default_qubit_2.py b/tests/interfaces/default_qubit_2_integration/test_autograd_qnode_default_qubit_2.py index b199cf02d41..94dc7627c74 100644 --- a/tests/interfaces/default_qubit_2_integration/test_autograd_qnode_default_qubit_2.py +++ b/tests/interfaces/default_qubit_2_integration/test_autograd_qnode_default_qubit_2.py @@ -1404,6 +1404,49 @@ def circuit(x, y): assert np.allclose(jac, expected, atol=tol, rtol=0) + def test_postselection_differentiation( + self, interface, dev, diff_method, grad_on_execution, device_vjp + ): + """Test that when postselecting with default.qubit, differentiation works correctly.""" + + if diff_method in ["adjoint", "spsa", "hadamard"]: + pytest.skip("Diff method does not support postselection.") + + @qml.qnode( + dev, + diff_method=diff_method, + interface=interface, + grad_on_execution=grad_on_execution, + device_vjp=device_vjp, + ) + def circuit(phi, theta): + qml.RX(phi, wires=0) + qml.CNOT([0, 1]) + qml.measure(wires=0, postselect=1) + qml.RX(theta, wires=1) + return qml.expval(qml.PauliZ(1)) + + @qml.qnode( + dev, + diff_method=diff_method, + interface=interface, + grad_on_execution=grad_on_execution, + device_vjp=device_vjp, + ) + def expected_circuit(theta): + qml.PauliX(1) + qml.RX(theta, wires=1) + return qml.expval(qml.PauliZ(1)) + + phi = np.array(1.23, requires_grad=True) + theta = np.array(4.56, requires_grad=True) + + assert np.allclose(circuit(phi, theta), expected_circuit(theta)) + + gradient = qml.grad(circuit)(phi, theta) + exp_theta_grad = qml.grad(expected_circuit)(theta) + assert np.allclose(gradient, [0.0, 
exp_theta_grad]) + @pytest.mark.parametrize( "dev,diff_method,grad_on_execution, device_vjp", qubit_device_and_diff_method diff --git a/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py b/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py index aff589f5c96..1b97a52a71f 100644 --- a/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py +++ b/tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py @@ -921,6 +921,39 @@ def cost(weights): assert len(res) == 2 + def test_postselection_differentiation(self, dev, diff_method, grad_on_execution, interface): + """Test that when postselecting with default.qubit, differentiation works correctly.""" + + if diff_method in ["adjoint", "spsa", "hadamard"]: + pytest.skip("Diff method does not support postselection.") + + @qml.qnode( + dev, diff_method=diff_method, interface=interface, grad_on_execution=grad_on_execution + ) + def circuit(phi, theta): + qml.RX(phi, wires=0) + qml.CNOT([0, 1]) + qml.measure(wires=0, postselect=1) + qml.RX(theta, wires=1) + return qml.expval(qml.PauliZ(1)) + + @qml.qnode( + dev, diff_method=diff_method, interface=interface, grad_on_execution=grad_on_execution + ) + def expected_circuit(theta): + qml.PauliX(1) + qml.RX(theta, wires=1) + return qml.expval(qml.PauliZ(1)) + + phi = jax.numpy.array(1.23) + theta = jax.numpy.array(4.56) + + assert np.allclose(jax.jit(circuit)(phi, theta), jax.jit(expected_circuit)(theta)) + + gradient = jax.jit(jax.grad(circuit, argnums=[0, 1]))(phi, theta) + exp_theta_grad = jax.jit(jax.grad(expected_circuit))(theta) + assert np.allclose(gradient, [0.0, exp_theta_grad]) + @pytest.mark.parametrize( "interface,dev,diff_method,grad_on_execution", interface_and_qubit_device_and_diff_method diff --git a/tests/interfaces/default_qubit_2_integration/test_jax_qnode_default_qubit_2.py 
b/tests/interfaces/default_qubit_2_integration/test_jax_qnode_default_qubit_2.py index fb8de4444a6..b6547334fd5 100644 --- a/tests/interfaces/default_qubit_2_integration/test_jax_qnode_default_qubit_2.py +++ b/tests/interfaces/default_qubit_2_integration/test_jax_qnode_default_qubit_2.py @@ -809,6 +809,39 @@ def cost(weights): assert len(res) == 2 + def test_postselection_differentiation(self, dev, diff_method, grad_on_execution): + """Test that when postselecting with default.qubit, differentiation works correctly.""" + + if diff_method in ["adjoint", "spsa", "hadamard"]: + pytest.skip("Diff method does not support postselection.") + + @qml.qnode( + dev, diff_method=diff_method, interface="jax", grad_on_execution=grad_on_execution + ) + def circuit(phi, theta): + qml.RX(phi, wires=0) + qml.CNOT([0, 1]) + qml.measure(wires=0, postselect=1) + qml.RX(theta, wires=1) + return qml.expval(qml.PauliZ(1)) + + @qml.qnode( + dev, diff_method=diff_method, interface="jax", grad_on_execution=grad_on_execution + ) + def expected_circuit(theta): + qml.PauliX(1) + qml.RX(theta, wires=1) + return qml.expval(qml.PauliZ(1)) + + phi = jax.numpy.array(1.23) + theta = jax.numpy.array(4.56) + + assert np.allclose(circuit(phi, theta), expected_circuit(theta)) + + gradient = jax.grad(circuit, argnums=[0, 1])(phi, theta) + exp_theta_grad = jax.grad(expected_circuit)(theta) + assert np.allclose(gradient, [0.0, exp_theta_grad]) + @pytest.mark.parametrize( "interface,dev,diff_method,grad_on_execution", interface_and_device_and_diff_method diff --git a/tests/interfaces/default_qubit_2_integration/test_tensorflow_qnode_default_qubit_2.py b/tests/interfaces/default_qubit_2_integration/test_tensorflow_qnode_default_qubit_2.py index 6a8ea43db89..18b045f5914 100644 --- a/tests/interfaces/default_qubit_2_integration/test_tensorflow_qnode_default_qubit_2.py +++ b/tests/interfaces/default_qubit_2_integration/test_tensorflow_qnode_default_qubit_2.py @@ -916,6 +916,45 @@ def circuit(weights): ] assert 
np.allclose(grad, expected, atol=tol, rtol=0) + def test_postselection_differentiation(self, dev, diff_method, grad_on_execution, interface): + """Test that when postselecting with default.qubit, differentiation works correctly.""" + + if diff_method in ["adjoint", "spsa", "hadamard"]: + pytest.skip("Diff method does not support postselection.") + + @qml.qnode( + dev, diff_method=diff_method, interface=interface, grad_on_execution=grad_on_execution + ) + def circuit(phi, theta): + qml.RX(phi, wires=0) + qml.CNOT([0, 1]) + qml.measure(wires=0, postselect=1) + qml.RX(theta, wires=1) + return qml.expval(qml.PauliZ(1)) + + @qml.qnode( + dev, diff_method=diff_method, interface=interface, grad_on_execution=grad_on_execution + ) + def expected_circuit(theta): + qml.PauliX(1) + qml.RX(theta, wires=1) + return qml.expval(qml.PauliZ(1)) + + phi = tf.Variable(1.23) + theta = tf.Variable(4.56) + + assert np.allclose(circuit(phi, theta), expected_circuit(theta)) + + with tf.GradientTape() as res_tape: + res = circuit(phi, theta) + gradient = res_tape.gradient(res, [phi, theta]) + + with tf.GradientTape() as expected_tape: + expected = expected_circuit(theta) + exp_theta_grad = expected_tape.gradient(expected, theta) + + assert np.allclose(gradient, [0.0, exp_theta_grad]) + @pytest.mark.parametrize( "interface,dev,diff_method,grad_on_execution", interface_and_qubit_device_and_diff_method diff --git a/tests/interfaces/default_qubit_2_integration/test_torch_qnode_default_qubit_2.py b/tests/interfaces/default_qubit_2_integration/test_torch_qnode_default_qubit_2.py index c2020c2d070..39f171e7180 100644 --- a/tests/interfaces/default_qubit_2_integration/test_torch_qnode_default_qubit_2.py +++ b/tests/interfaces/default_qubit_2_integration/test_torch_qnode_default_qubit_2.py @@ -1141,6 +1141,49 @@ def circuit(x, y): ) assert np.allclose(weights.grad.detach(), expected, atol=tol, rtol=0) + def test_postselection_differentiation( + self, interface, dev, diff_method, grad_on_execution, 
device_vjp + ): + """Test that when postselecting with default.qubit, differentiation works correctly.""" + + if diff_method in ["adjoint", "spsa", "hadamard"]: + pytest.skip("Diff method does not support postselection.") + + @qml.qnode( + dev, + diff_method=diff_method, + interface=interface, + grad_on_execution=grad_on_execution, + device_vjp=device_vjp, + ) + def circuit(phi, theta): + qml.RX(phi, wires=0) + qml.CNOT([0, 1]) + qml.measure(wires=0, postselect=1) + qml.RX(theta, wires=1) + return qml.expval(qml.PauliZ(1)) + + @qml.qnode( + dev, + diff_method=diff_method, + interface=interface, + grad_on_execution=grad_on_execution, + device_vjp=device_vjp, + ) + def expected_circuit(theta): + qml.PauliX(1) + qml.RX(theta, wires=1) + return qml.expval(qml.PauliZ(1)) + + phi = torch.tensor(1.23, requires_grad=True) + theta = torch.tensor(4.56, requires_grad=True) + + assert qml.math.allclose(circuit(phi, theta), expected_circuit(theta)) + + gradient = torch.autograd.grad(circuit(phi, theta), [phi, theta]) + exp_theta_grad = torch.autograd.grad(expected_circuit(theta), theta)[0] + assert qml.math.allclose(gradient, [0.0, exp_theta_grad]) + @pytest.mark.parametrize( "dev,diff_method,grad_on_execution, device_vjp", qubit_device_and_diff_method diff --git a/tests/math/test_multi_dispatch.py b/tests/math/test_multi_dispatch.py index 1ff28227733..519d86af2c8 100644 --- a/tests/math/test_multi_dispatch.py +++ b/tests/math/test_multi_dispatch.py @@ -293,3 +293,38 @@ def test_inf_norm(self, arr, expected_intrf, expected_norm, kwargs): computed_norm = fn.norm(arr, ord=np.inf, **kwargs) assert np.allclose(computed_norm, expected_norm) assert fn.get_interface(computed_norm) == expected_intrf + + @pytest.mark.parametrize( + "arr", + [ + np.array([1.0, 2.0, 3.0, 4.0, 5.0]), + np.array( + [ + [[0.123, 0.456, 0.789], [-0.123, -0.456, -0.789]], + [[1.23, 4.56, 7.89], [-1.23, -4.56, -7.89]], + ] + ), + np.array( + [ + [ + [0.123 - 0.789j, 0.456 + 0.456j, 0.789 - 0.123j], + [-0.123 + 
0.789j, -0.456 - 0.456j, -0.789 + 0.123j], + ], + [ + [1.23 + 4.56j, 4.56 - 7.89j, 7.89 + 1.23j], + [-1.23 - 7.89j, -4.56 + 1.23j, -7.89 - 4.56j], + ], + ] + ), + ], + ) + def test_autograd_norm_gradient(self, arr): + """Test that qml.math.norm has the correct gradient with autograd + when the order and axis are not specified.""" + norm = fn.norm(arr) + expected_norm = onp.linalg.norm(arr) + assert np.isclose(norm, expected_norm) + + grad = qml_grad(fn.norm)(arr) + expected_grad = (norm**-1) * arr.conj() + assert fn.allclose(grad, expected_grad)