From d336ef099b276d52d84169389219d77fe49c1638 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 17 Oct 2023 18:10:32 -0400 Subject: [PATCH] Updated tests; fixed docs --- doc/introduction/measurements.rst | 4 +- .../default_qubit/test_default_qubit.py | 40 +++++++++++-------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index e9fccdcb7b0..b15b2bd6ceb 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -380,8 +380,8 @@ be equivalent to projecting the state vector onto the :math:`|1\rangle` state on qml.cond(m0, qml.PauliX)(wires=1) return qml.sample(wires=1) -By postselecting on ``1``, we only consider the ``1`` measurement outcome. So, the probability of measuring 1 -on wire 1 should be 100%. Executing this QNode with 10 shots: +By postselecting on ``1``, we only consider the ``1`` measurement outcome on wire 0. So, the probability of +measuring ``1`` on wire 1 after postselection should also be 1. Executing this QNode with 10 shots: >>> func(np.pi / 2, shots=10) array([1, 1, 1, 1, 1, 1, 1]) diff --git a/tests/devices/default_qubit/test_default_qubit.py b/tests/devices/default_qubit/test_default_qubit.py index b85e8b76bf3..e52595e51ec 100644 --- a/tests/devices/default_qubit/test_default_qubit.py +++ b/tests/devices/default_qubit/test_default_qubit.py @@ -1704,14 +1704,14 @@ def circ_expected(): ], ) @pytest.mark.parametrize("param", np.linspace(np.pi / 4, 3 * np.pi / 4, 3)) - @pytest.mark.parametrize("shots", [20000]) + @pytest.mark.parametrize("shots", [50000, (50000, 50000)]) def test_postselection_valid_finite_shots( self, param, mp, shots, interface, use_jit, tol_stochastic ): """Test that the results of a circuit with postselection is expected with finite shots.""" - if use_jit and interface != "jax": - pytest.skip("Cannot JIT in non-JAX interfaces.") + if use_jit and (interface != "jax" or isinstance(shots, tuple)): + pytest.skip("Cannot JIT in non-JAX interfaces, or with shot vectors.") dev = qml.device("default.qubit", shots=shots) param = qml.math.asarray(param, like=interface) @@ -1742,30 +1742,35 @@ def circ_expected(): assert qml.math.get_interface(res) == qml.math.get_interface(expected) else: - # No testing with shot vectors currently, but keeping this here so that - # we can just use it once shot vectors are supported assert isinstance(res, tuple) for r, e in zip(res, expected): assert qml.math.allclose(r, e, atol=tol_stochastic, rtol=0) assert qml.math.get_interface(r) == qml.math.get_interface(e) @pytest.mark.parametrize( - "mp, autograd_interface", + "mp, autograd_interface, is_nan", [ - (qml.expval(qml.PauliZ(0)), "autograd"), - (qml.var(qml.PauliZ(0)), "autograd"), - (qml.probs(wires=[0, 1]), "autograd"), - (qml.state(), "autograd"), - (qml.density_matrix(wires=0), "autograd"), - (qml.purity(0), "numpy"), - # qml.vn_entropy(0), - # qml.mutual_info(0, 1), + (qml.expval(qml.PauliZ(0)), "autograd", True), + (qml.var(qml.PauliZ(0)), "autograd", True), + (qml.probs(wires=[0, 1]), "autograd", True), + (qml.state(), "autograd", True), + (qml.density_matrix(wires=0), "autograd", True), + (qml.purity(0), "numpy", True), + (qml.vn_entropy(0), "numpy", False), + (qml.mutual_info(0, 1), "numpy", False), ], ) - def test_postselection_invalid_analytic(self, mp, autograd_interface, interface, use_jit): + def test_postselection_invalid_analytic( + self, mp, autograd_interface, is_nan, interface, use_jit + ): """Test that the results of a qnode are nan values of the correct shape if the state that we are postselecting has a zero probability of occurring.""" + if (isinstance(mp, qml.measurements.MutualInfoMP) and interface != "jax") or ( + isinstance(mp, qml.measurements.VnEntropyMP) and interface == "tensorflow" + ): + pytest.skip("Unsupported measurements and interfaces.") + if use_jit and interface != "jax": pytest.skip("Can't jit with non-jax interfaces.") @@ -1790,7 +1795,10 @@ def circ(): else: assert qml.math.get_interface(res) == interface assert qml.math.shape(res) == mp.shape(dev, qml.measurements.Shots(None)) - assert qml.math.all(qml.math.isnan(res)) + if is_nan: + assert qml.math.all(qml.math.isnan(res)) + else: + assert qml.math.allclose(res, 0.0) @pytest.mark.parametrize( "mp, expected_shape",