diff --git a/.github/workflows/install_deps/action.yml b/.github/workflows/install_deps/action.yml index b2084e4ea05..61a0f763d03 100644 --- a/.github/workflows/install_deps/action.yml +++ b/.github/workflows/install_deps/action.yml @@ -31,7 +31,7 @@ inputs: pytorch_version: description: The version of PyTorch to install for any job that requires PyTorch required: false - default: 2.2.0 + default: 2.3.0 install_pennylane_lightning_master: description: Indicate if PennyLane-Lightning should be installed from the master branch required: false diff --git a/pennylane/ops/qubit/hamiltonian.py b/pennylane/ops/qubit/hamiltonian.py index 19bf32982ba..f5117f42418 100644 --- a/pennylane/ops/qubit/hamiltonian.py +++ b/pennylane/ops/qubit/hamiltonian.py @@ -220,7 +220,15 @@ def __init__( self._grouping_indices = None if simplify: - self.simplify() + # simplify upon initialization changes ops such that they wouldnt be + # removed in self.queue() anymore, removing them here manually. + if qml.QueuingManager.recording(): + for o in observables: + qml.QueuingManager.remove(o) + + with qml.QueuingManager.stop_recording(): + self.simplify() + if grouping_type is not None: with qml.QueuingManager.stop_recording(): self._grouping_indices = _compute_grouping_indices( diff --git a/pennylane/templates/subroutines/trotter.py b/pennylane/templates/subroutines/trotter.py index 68946909e0b..c335996eafa 100644 --- a/pennylane/templates/subroutines/trotter.py +++ b/pennylane/templates/subroutines/trotter.py @@ -189,9 +189,13 @@ def __init__( # pylint: disable=too-many-arguments raise ValueError( "There should be at least 2 terms in the Hamiltonian. Otherwise use `qml.exp`" ) + if qml.QueuingManager.recording(): + qml.QueuingManager.remove(hamiltonian) hamiltonian = qml.dot(coeffs, ops) if isinstance(hamiltonian, SProd): + if qml.QueuingManager.recording(): + qml.QueuingManager.remove(hamiltonian) hamiltonian = hamiltonian.simplify() if len(hamiltonian.terms()[0]) < 2: raise ValueError( diff --git a/tests/devices/qubit/test_measure.py b/tests/devices/qubit/test_measure.py index 370db600248..d8ae9a98cc1 100644 --- a/tests/devices/qubit/test_measure.py +++ b/tests/devices/qubit/test_measure.py @@ -192,6 +192,10 @@ def qnode(t1, t2): def test_measure_identity_no_wires(self): """Test that measure can handle the expectation value of identity on no wires.""" + + if not qml.operation.active_new_opmath(): + pytest.skip("Identity with no wires is not supported with legacy opmath.") + state = np.random.random([2, 2, 2]) out = measure(qml.measurements.ExpectationMP(qml.I()), state) assert qml.math.allclose(out, 1.0) diff --git a/tests/devices/qubit/test_sampling.py b/tests/devices/qubit/test_sampling.py index 2d2bd98b80e..79c6efb993d 100644 --- a/tests/devices/qubit/test_sampling.py +++ b/tests/devices/qubit/test_sampling.py @@ -521,6 +521,9 @@ def test_identity_on_no_wires(self): def test_identity_on_no_wires_with_other_observables(self): """Test that measuring an identity on no wires can be used in conjunction with other measurements.""" + if not qml.operation.active_new_opmath(): + pytest.skip("Identity with no wires is not supported with legacy opmath.") + state = np.array([0, 1]) mps = [ diff --git a/tests/ops/op_math/test_linear_combination.py b/tests/ops/op_math/test_linear_combination.py index b1d515e7747..a71714b50dd 100644 --- a/tests/ops/op_math/test_linear_combination.py +++ b/tests/ops/op_math/test_linear_combination.py @@ -1628,6 +1628,16 @@ def circuit(): # simplify worked and added 1. and 2. assert pars == [0.1, 3.0] + @pytest.mark.usefixtures("use_legacy_and_new_opmath") + def test_queuing_behaviour(self): + """Tests that the base observables are correctly dequeued with simplify=True""" + + with qml.queuing.AnnotatedQueue() as q: + obs = qml.Hamiltonian([1, 1, 1], [qml.X(0), qml.X(0), qml.Z(0)], simplify=True) + + assert len(q) == 1 + assert q.queue[0] == obs + class TestLinearCombinationDifferentiation: """Test that the LinearCombination coefficients are differentiable""" diff --git a/tests/pauli/grouping/test_pauli_group_observables.py b/tests/pauli/grouping/test_pauli_group_observables.py index 63411c411fc..dd49115d216 100644 --- a/tests/pauli/grouping/test_pauli_group_observables.py +++ b/tests/pauli/grouping/test_pauli_group_observables.py @@ -469,6 +469,9 @@ def test_observables_on_no_wires_coeffs(self): """Test that observables on no wires are stuck in the first group and coefficients are tracked when provided.""" + if not qml.operation.active_new_opmath(): + pytest.skip("Identity with no wires is not supported with legacy opmath.") + observables = [ qml.X(0), qml.Z(0), diff --git a/tests/templates/test_subroutines/test_trotter.py b/tests/templates/test_subroutines/test_trotter.py index 88f87801a44..71fcc68a22c 100644 --- a/tests/templates/test_subroutines/test_trotter.py +++ b/tests/templates/test_subroutines/test_trotter.py @@ -399,11 +399,19 @@ def test_convention_approx_time_evolv(self, time, n): qml.matrix(op2, wire_order=hamiltonian.wires), ) - def test_queuing(self): + @pytest.mark.parametrize( + "make_H", + [ + lambda: qml.Hamiltonian([1, 1], [qml.PauliX(0), qml.PauliY(1)]), + lambda: qml.sum(qml.PauliX(0), qml.PauliY(1)), + lambda: qml.s_prod(1.2, qml.PauliX(0) + qml.PauliY(1)), + ], + ) + def test_queuing(self, make_H): """Test that the target operator is removed from the queue.""" with qml.queuing.AnnotatedQueue() as q: - H = qml.X(0) + qml.Y(1) + H = make_H() op = qml.TrotterProduct(H, time=2) assert len(q.queue) == 1 diff --git a/tests/test_qnode_legacy.py b/tests/test_qnode_legacy.py index e6070d0927d..900fe6771bf 100644 --- a/tests/test_qnode_legacy.py +++ b/tests/test_qnode_legacy.py @@ -1881,6 +1881,9 @@ def circuit(): def test_multiple_hamiltonian_expansion_finite_shots(self, grouping): """Test that multiple Hamiltonians works correctly (sum_expand should be used)""" + if not qml.operation.active_new_opmath(): + pytest.skip("expval of the legacy Hamiltonian does not support finite shots.") + dev = qml.device("default.qubit.legacy", wires=3, shots=50000) obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)]