diff --git a/pennylane/devices/qubit/dq_interpreter.py b/pennylane/devices/qubit/dq_interpreter.py index c2b3c4b4c76..a6d42b851dc 100644 --- a/pennylane/devices/qubit/dq_interpreter.py +++ b/pennylane/devices/qubit/dq_interpreter.py @@ -122,6 +122,7 @@ def interpret_operation(self, op): self.state = apply_operation(op, self.state, is_state_batched=self.is_state_batched) if op.batch_size: self.is_state_batched = True + return op def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"): if "mcm" in eqn.primitive.name: diff --git a/pennylane/ops/functions/iterative_qpe.py b/pennylane/ops/functions/iterative_qpe.py index 98a05b105c7..c21b9489e0f 100644 --- a/pennylane/ops/functions/iterative_qpe.py +++ b/pennylane/ops/functions/iterative_qpe.py @@ -28,12 +28,12 @@ def iterative_qpe(base, aux_wire, iters): estimation and returns a list of mid-circuit measurements with qubit reset. Args: - base (Operator): the phase estimation unitary, specified as an :class:`~.Operator` - aux_wire (Union[Wires, int, str]): the wire to be used for the estimation - iters (int): the number of measurements to be performed + base (Operator): the phase estimation unitary, specified as an :class:`~.Operator` + aux_wire (Union[Wires, int, str]): the wire to be used for the estimation + iters (int): the number of measurements to be performed Returns: - list[MeasurementValue]: the abstract results of the mid circuit measurements + list[MeasurementValue]: the abstract results of the mid circuit measurements .. seealso:: :class:`~.QuantumPhaseEstimation`, :func:`~.measure` @@ -46,13 +46,13 @@ def iterative_qpe(base, aux_wire, iters): @qml.qnode(dev) def circuit(): - # Initial state - qml.X(0) + # Initial state + qml.X(0) - # Iterative QPE - measurements = qml.iterative_qpe(qml.RZ(2.0, wires=[0]), aux_wire=1, iters=3) + # Iterative QPE + measurements = qml.iterative_qpe(qml.RZ(2.0, wires=[0]), aux_wire=1, iters=3) - return qml.sample(measurements) + return qml.sample(measurements) .. code-block:: pycon @@ -96,7 +96,7 @@ def cond_func(k): qml.cond(meas, cond_func)(j) - g() + g() # pylint: disable=no-value-for-parameter qml.Hadamard(wires=aux_wire) m = qml.measure(wires=aux_wire, reset=True) @@ -107,4 +107,4 @@ def cond_func(k): return measurements, target - return f(measurements, base)[0] + return f(measurements, base)[0] # pylint: disable=no-value-for-parameter