diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 4a9add0c35c..4c595ff07f9 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -4,6 +4,10 @@
New features since last release
+* `qml.measure` now includes a boolean keyword argument `reset` to reset a wire to the
+ $|0\rangle$ computational basis state after measurement.
+ [(#4402)](https://github.com/PennyLaneAI/pennylane/pull/4402/)
+
* `DefaultQubit2` accepts a `max_workers` argument which controls multiprocessing.
A `ProcessPoolExecutor` executes tapes asynchronously
using a pool of at most `max_workers` processes. If `max_workers` is `None`
@@ -63,6 +67,9 @@ array([False, False])
Improvements 🛠
+* Wires can now be reused after making a mid-circuit measurement on them.
+ [(#4402)](https://github.com/PennyLaneAI/pennylane/pull/4402/)
+
* Transform Programs, `qml.transforms.core.TransformProgram`, can now be called on a batch of circuits
and return a new batch of circuits and a single post processing function.
[(#4364)](https://github.com/PennyLaneAI/pennylane/pull/4364)
diff --git a/pennylane/_qubit_device.py b/pennylane/_qubit_device.py
index e89b2cb3b04..a6618ac4be3 100644
--- a/pennylane/_qubit_device.py
+++ b/pennylane/_qubit_device.py
@@ -136,6 +136,7 @@ class QubitDevice(Device):
_real = staticmethod(np.real)
_size = staticmethod(np.size)
_ndim = staticmethod(np.ndim)
+ _norm = staticmethod(np.linalg.norm)
@staticmethod
def _scatter(indices, array, new_dimensions):
diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py
index 0a7d739a879..7c0e4ae22e5 100644
--- a/pennylane/devices/default_qubit.py
+++ b/pennylane/devices/default_qubit.py
@@ -270,6 +270,7 @@ def apply(self, operations, rotations=None, **kwargs):
# apply the circuit operations
for i, operation in enumerate(operations):
+ print(self._state.shape)
if i > 0 and isinstance(operation, (QubitStateVector, BasisState)):
raise DeviceError(
f"Operation {operation.name} cannot be used after other Operations have already been applied "
@@ -333,12 +334,17 @@ def _apply_operation(self, state, operation):
matrix = self._asarray(self._get_unitary_matrix(operation), dtype=self.C_DTYPE)
if operation in diagonal_in_z_basis:
- return self._apply_diagonal_unitary(state, matrix, wires)
- if len(wires) <= 2:
+ new_state = self._apply_diagonal_unitary(state, matrix, wires)
+ elif len(wires) <= 2:
# Einsum is faster for small gates
- return self._apply_unitary_einsum(state, matrix, wires)
+ new_state = self._apply_unitary_einsum(state, matrix, wires)
+ else:
+ new_state = self._apply_unitary(state, matrix, wires)
+
+ if operation.__class__.__name__ in {"Projector", "_BasisStateProjector"}:
+ new_state = new_state / self._norm(new_state)
- return self._apply_unitary(state, matrix, wires)
+ return new_state
def _apply_x(self, state, axes, **kwargs):
"""Applies a PauliX gate by rolling 1 unit along the axis specified in ``axes``.
diff --git a/pennylane/devices/default_qubit_jax.py b/pennylane/devices/default_qubit_jax.py
index 411d77fea02..c29590367a1 100644
--- a/pennylane/devices/default_qubit_jax.py
+++ b/pennylane/devices/default_qubit_jax.py
@@ -162,6 +162,7 @@ def circuit():
_const_mul = staticmethod(jnp.multiply)
_size = staticmethod(jnp.size)
_ndim = staticmethod(jnp.ndim)
+ _norm = staticmethod(jnp.linalg.norm)
operations = DefaultQubit.operations.union({"ParametrizedEvolution"})
diff --git a/pennylane/devices/default_qubit_tf.py b/pennylane/devices/default_qubit_tf.py
index 702bf34af0b..bf12988e0af 100644
--- a/pennylane/devices/default_qubit_tf.py
+++ b/pennylane/devices/default_qubit_tf.py
@@ -136,6 +136,7 @@ class DefaultQubitTF(DefaultQubit):
_stack = staticmethod(tf.stack)
_size = staticmethod(tf.size)
_ndim = staticmethod(_ndim_tf)
+ _norm = staticmethod(tf.norm)
@staticmethod
def _const_mul(constant, array):
diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py
index 71155b2892f..7024aa9ba41 100644
--- a/pennylane/devices/qubit/apply_operation.py
+++ b/pennylane/devices/qubit/apply_operation.py
@@ -190,8 +190,12 @@ def _(op: type_op, state):
len(op.wires) < EINSUM_OP_WIRECOUNT_PERF_THRESHOLD
and math.ndim(state) < EINSUM_STATE_WIRECOUNT_PERF_THRESHOLD
) or (op.batch_size and is_state_batched):
- return apply_operation_einsum(op, state, is_state_batched=is_state_batched)
- return apply_operation_tensordot(op, state, is_state_batched=is_state_batched)
+ new_state = apply_operation_einsum(op, state, is_state_batched=is_state_batched)
+ else:
+ new_state = apply_operation_tensordot(op, state, is_state_batched=is_state_batched)
+
+ if op.__class__.__name__ in {"Projector", "_BasisStateProjector"}:
+ new_state = new_state / math.norm(new_state)
@apply_operation.register
diff --git a/pennylane/drawer/drawable_layers.py b/pennylane/drawer/drawable_layers.py
index 0042b9e3122..9a29b90e320 100644
--- a/pennylane/drawer/drawable_layers.py
+++ b/pennylane/drawer/drawable_layers.py
@@ -97,11 +97,6 @@ def drawable_layers(ops, wire_map=None):
# loop over operations
for op in ops:
is_mid_measure = is_conditional = False
- if set(measured_wires.values()).intersection({wire_map[w] for w in op.wires}):
- raise ValueError(
- f"Cannot apply operations on {op.wires} as some wires have been measured already."
- )
-
if isinstance(op, MidMeasureMP):
if len(op.wires) > 1:
raise ValueError("Cannot draw mid-circuit measurements with more than one wire.")
diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py
index 2de61e178a8..a4b51bf2a13 100644
--- a/pennylane/measurements/mid_measure.py
+++ b/pennylane/measurements/mid_measure.py
@@ -24,8 +24,10 @@
from .measurements import MeasurementProcess, MidMeasure
-def measure(wires): # TODO: Change name to mid_measure
- """Perform a mid-circuit measurement in the computational basis on the
+def measure(
+ wires: Wires, reset: Optional[bool] = False, postselect: Optional[int] = None
+): # TODO: Change name to mid_measure
+ r"""Perform a mid-circuit measurement in the computational basis on the
supplied qubit.
Measurement outcomes can be obtained and used to conditionally apply
@@ -38,7 +40,7 @@ def measure(wires): # TODO: Change name to mid_measure
.. code-block:: python3
- dev = qml.device("default.qubit", wires=2)
+ dev = qml.device("default.qubit", wires=3)
@qml.qnode(dev)
def func(x, y):
@@ -55,16 +57,40 @@ def func(x, y):
>>> func(*pars)
tensor([0.90165331, 0.09834669], requires_grad=True)
+ Wires can be reused after measurement. Moreover, measured wires can be reset
+ to the :math:`|0 \rangle` by setting ``reset=True``.
+
+ .. code-block:: python3
+
+ dev = qml.device("default.qubit", wires=3)
+
+ @qml.qnode(dev)
+ def func():
+ qml.PauliX(1)
+ m_0 = qml.measure(1, reset=True)
+ return qml.probs(wires=[1])
+
+ Executing this QNode:
+
+ >>> func()
+ tensor([1., 0.], requires_grad=True)
+
Mid circuit measurements can be manipulated using the following dunder methods
``+``, ``-``, ``*``, ``/``, ``~`` (not), ``&`` (and), ``|`` (or), ``==``, ``<=``,
``>=``, ``<``, ``>`` with other mid-circuit measurements or scalars.
- Note:
- python ``not``, ``and``, ``or``, do not work since these do not have dunder
- methods. Instead use ``~``, ``&``, ``|``.
+ .. Note ::
+
+ Python ``not``, ``and``, ``or``, do not work since these do not have dunder methods.
+ Instead use ``~``, ``&``, ``|``.
Args:
wires (Wires): The wire of the qubit the measurement process applies to.
+ reset (Optional[bool]): Whether to reset the wire to the :math:`|0 \rangle`
+ state after measurement.
+ postselect (Optional[int]): The measured computational basis state on which to
+ optionally postselect the circuit. Must be ``0`` or ``1`` if postselection
+ is requested.
Returns:
MidMeasureMP: measurement process instance
@@ -72,6 +98,7 @@ def func(x, y):
Raises:
QuantumFunctionError: if multiple wires were specified
"""
+
wire = Wires(wires)
if len(wire) > 1:
raise qml.QuantumFunctionError(
@@ -80,7 +107,7 @@ def func(x, y):
# Create a UUID and a map between MP and MV to support serialization
measurement_id = str(uuid.uuid4())[:8]
- mp = MidMeasureMP(wires=wire, id=measurement_id)
+ mp = MidMeasureMP(wires=wire, reset=reset, postselect=postselect, id=measurement_id)
return MeasurementValue([mp], processing_fn=lambda v: v)
@@ -90,17 +117,31 @@ def func(x, y):
class MidMeasureMP(MeasurementProcess):
"""Mid-circuit measurement.
+ This class additionally stores information about unknown measurement outcomes in the qubit model.
+ Measurements on a single qubit in the computational basis are assumed.
+
Please refer to :func:`measure` for detailed documentation.
Args:
wires (.Wires): The wires the measurement process applies to.
This can only be specified if an observable was not provided.
- id (str): custom label given to a measurement instance, can be useful for some applications
- where the instance has to be identified
+ reset (bool): Whether to reset the wire after measurement.
+ postselect (Optional[int]): The measured computational basis state on which to
+ optionally postselect the circuit. Must be ``0`` or ``1`` if postselection
+ is requested.
+ id (str): Custom label given to a measurement instance.
"""
- def __init__(self, wires: Optional[Wires] = None, id: Optional[str] = None):
- super().__init__(wires=wires, id=id)
+ def __init__(
+ self,
+ wires: Optional[Wires] = None,
+ reset: Optional[bool] = False,
+ postselect: Optional[int] = None,
+ id: Optional[str] = None,
+ ):
+ super().__init__(wires=Wires(wires), id=id)
+ self.reset = reset
+ self.postselect = postselect
@property
def return_type(self):
diff --git a/pennylane/ops/qubit/observables.py b/pennylane/ops/qubit/observables.py
index 9246de99dc7..3b48e83b058 100644
--- a/pennylane/ops/qubit/observables.py
+++ b/pennylane/ops/qubit/observables.py
@@ -23,7 +23,7 @@
from scipy.sparse import csr_matrix
import pennylane as qml
-from pennylane.operation import AnyWires, Observable
+from pennylane.operation import AnyWires, Observable, Operation
from pennylane.wires import Wires
from .matrix_ops import QubitUnitary
@@ -428,7 +428,7 @@ def __copy__(self):
return copied_op
-class _BasisStateProjector(Observable):
+class _BasisStateProjector(Observable, Operation):
# The call signature should be the same as Projector.__new__ for the positional
# arguments, but with free key word arguments.
def __init__(self, state, wires, id=None):
diff --git a/pennylane/transforms/condition.py b/pennylane/transforms/condition.py
index 1c3530b0eff..d5dcd4ea6e7 100644
--- a/pennylane/transforms/condition.py
+++ b/pennylane/transforms/condition.py
@@ -67,7 +67,7 @@ def cond(condition, true_fn, false_fn=None):
:func:`defer_measurements` transform.
Args:
- condition (.MeasurementValue[bool]): a conditional expression involving a mid-circuit
+ condition (.MeasurementValue): a conditional expression involving a mid-circuit
measurement value (see :func:`.pennylane.measure`)
true_fn (callable): The quantum function of PennyLane operation to
apply if ``condition`` is ``True``
@@ -114,6 +114,8 @@ def qnode(x, y):
Expressions with boolean logic flow using operators like ``and``,
``or`` and ``not`` are not supported as the ``condition`` argument.
+ While such statements may not result in errors, they may result in
+ incorrect behaviour.
.. details::
:title: Usage Details
diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py
index 29c13cdfe27..2cc04f656d8 100644
--- a/pennylane/transforms/defer_measurements.py
+++ b/pennylane/transforms/defer_measurements.py
@@ -16,12 +16,13 @@
from pennylane.measurements import MidMeasureMP
from pennylane.ops.op_math import ctrl
from pennylane.queuing import apply
+from pennylane.tape import QuantumTape
from pennylane.transforms import qfunc_transform
from pennylane.wires import Wires
@qfunc_transform
-def defer_measurements(tape):
+def defer_measurements(tape: QuantumTape):
"""Quantum function transform that substitutes operations conditioned on
measurement outcomes to controlled operations.
@@ -40,6 +41,13 @@ def defer_measurements(tape):
that can be controlled as such depends on the set of operations
supported by the chosen device.
+ .. note::
+
+ Devices that inherit `QubitDevice` **must** be initialized with an additional
+ wire for each mid-circuit measurement for `defer_measurements` to transform
+ the quantum tape correctly. Such devices should also be initialized without
+ custom wire labels for correct behaviour.
+
.. note::
This transform does not change the list of terminal measurements returned by
@@ -54,7 +62,7 @@ def defer_measurements(tape):
post-measurement states are considered.
Args:
- qfunc (function): a quantum function
+ tape (.QuantumTape): a quantum tape
**Example**
@@ -84,8 +92,30 @@ def qfunc(par):
>>> qml.grad(qnode)(par)
tensor(-0.49622252, requires_grad=True)
+
+ Reusing and reseting measured wires will work as expected with the
+ ``defer_measurements`` transform:
+
+ .. code-block:: python3
+
+ dev = qml.device("default.qubit", wires=3)
+
+ @qml.qnode(dev)
+ def func(x, y):
+ qml.RY(x, wires=0)
+ qml.CNOT(wires=[0, 1])
+ m_0 = qml.measure(1, reset=True)
+
+ qml.cond(m_0, qml.RY)(y, wires=0)
+ qml.RX(np.pi/4, wires=1)
+ return qml.probs(wires=[0, 1])
+
+ Executing this QNode:
+
+ >>> pars = np.array([0.643, 0.246], requires_grad=True)
+ >>> func(*pars)
+ tensor([0.76960924, 0.13204407, 0.08394415, 0.01440254], requires_grad=True)
"""
- measured_wires = {}
cv_types = (qml.operation.CVOperation, qml.operation.CVObservable)
ops_cv = any(isinstance(op, cv_types) for op in tape.operations)
@@ -93,25 +123,34 @@ def qfunc(par):
if ops_cv or obs_cv:
raise ValueError("Continuous variable operations and observables are not supported.")
- for op in tape:
- op_wires_measured = set(wire for wire in op.wires if wire in measured_wires.values())
- if len(op_wires_measured) > 0:
- raise ValueError(
- f"Cannot apply operations on {op.wires} as the following wires have been measured already: {op_wires_measured}."
- )
+ # Current wire in which pre-measurement state will be saved if dev_wires not specified
+ cur_wire = max(tape.wires) + 1
+ new_wires = {}
+ for op in tape:
if isinstance(op, MidMeasureMP):
- measured_wires[op.id] = op.wires[0]
+ new_wires[op.id] = cur_wire
+ qml.CNOT([op.wires[0], cur_wire])
+
+ if op.postselect is not None:
+ qml.Projector([op.postselect], wires=op.wires[0])
+
+ if op.reset:
+ qml.CNOT([cur_wire, op.wires[0]])
+
+ cur_wire += 1
elif op.__class__.__name__ == "Conditional":
- _add_control_gate(op, measured_wires)
+ _add_control_gate(op, new_wires)
else:
apply(op)
+ return tape._qfunc_output # pylint: disable=protected-access
+
-def _add_control_gate(op, measured_wires):
+def _add_control_gate(op, control_wires):
"""Helper function to add control gates"""
- control = [measured_wires[m.id] for m in op.meas_val.measurements]
+ control = [control_wires[m.id] for m in op.meas_val.measurements]
for branch, value in op.meas_val._items(): # pylint: disable=protected-access
if value:
ctrl(
diff --git a/tests/devices/qubit/test_preprocess.py b/tests/devices/qubit/test_preprocess.py
index 9c1ad571a73..281c32be9b2 100644
--- a/tests/devices/qubit/test_preprocess.py
+++ b/tests/devices/qubit/test_preprocess.py
@@ -230,7 +230,8 @@ def test_expand_fn_defer_measurement(self):
expanded_tape = expand_fn(tape)
expected = [
qml.Hadamard(0),
- qml.ops.Controlled(qml.RX(0.123, wires=1), 0),
+ qml.CNOT([0, 2]),
+ qml.ops.Controlled(qml.RX(0.123, wires=1), 2),
]
for op, exp in zip(expanded_tape, expected + measurements):
diff --git a/tests/drawer/test_drawable_layers.py b/tests/drawer/test_drawable_layers.py
index a842a3c915a..5bac4268d6a 100644
--- a/tests/drawer/test_drawable_layers.py
+++ b/tests/drawer/test_drawable_layers.py
@@ -173,23 +173,7 @@ def test_empty_layers_are_pruned(self):
layers = drawable_layers(ops, wire_map={i: i for i in range(3)})
assert layers == [[ops[1]], [ops[2], ops[0]], [ops[3]]]
- def test_cannot_reuse_wire_after_conditional(self):
- """Tests that a wire cannot be re-used after using a mid-circuit measurement."""
- with AnnotatedQueue() as q:
- m0 = qml.measure(0)
- qml.cond(m0, qml.PauliX)(1)
- qml.Hadamard(0)
-
- with pytest.raises(ValueError, match="some wires have been measured already"):
- drawable_layers(q.queue)
-
def test_cannot_draw_multi_wire_MidMeasureMP(self):
"""Tests that MidMeasureMP is only supported with one wire."""
with pytest.raises(ValueError, match="mid-circuit measurements with more than one wire."):
drawable_layers([MidMeasureMP([0, 1])])
-
- def test_cannot_use_measured_wire(self):
- """Tests error is raised when trying to use a measured wire."""
- ops = [MidMeasureMP([0]), qml.PauliX(0)]
- with pytest.raises(ValueError, match="some wires have been measured already"):
- drawable_layers(ops)
diff --git a/tests/measurements/test_mid_measure.py b/tests/measurements/test_mid_measure.py
index 06352b0a1de..6aa70a21457 100644
--- a/tests/measurements/test_mid_measure.py
+++ b/tests/measurements/test_mid_measure.py
@@ -26,7 +26,7 @@
def test_samples_computational_basis():
"""Test that samples_computational_basis is always false for mid circuit measurements."""
- m = qml.measurements.MidMeasureMP(Wires(0))
+ m = MidMeasureMP(Wires(0))
assert not m.samples_computational_basis
@@ -44,19 +44,19 @@ def test_many_wires_error(self):
def test_hash(self):
"""Test that the hash for `MidMeasureMP` is defined correctly."""
- m1 = MidMeasureMP(Wires(0), "m1")
- m2 = MidMeasureMP(Wires(0), "m2")
- m3 = MidMeasureMP(Wires(1), "m1")
- m4 = MidMeasureMP(Wires(0), "m1")
+ m1 = MidMeasureMP(Wires(0), id="m1")
+ m2 = MidMeasureMP(Wires(0), id="m2")
+ m3 = MidMeasureMP(Wires(1), id="m1")
+ m4 = MidMeasureMP(Wires(0), id="m1")
assert m1.hash != m2.hash
assert m1.hash != m3.hash
assert m1.hash == m4.hash
-mp1 = MidMeasureMP(Wires(0), "m0")
-mp2 = MidMeasureMP(Wires(1), "m1")
-mp3 = MidMeasureMP(Wires(2), "m2")
+mp1 = MidMeasureMP(Wires(0), id="m0")
+mp2 = MidMeasureMP(Wires(1), id="m1")
+mp3 = MidMeasureMP(Wires(2), id="m2")
class TestMeasurementValueManipulation:
diff --git a/tests/test_qnode.py b/tests/test_qnode.py
index 95067f2b2d0..30705d375ff 100644
--- a/tests/test_qnode.py
+++ b/tests/test_qnode.py
@@ -1077,7 +1077,7 @@ def test_defer_meas_if_mcm_unsupported(self, first_par, sec_par, return_type):
"""Tests that the transform using the deferred measurement principle is
applied if the device doesn't support mid-circuit measurements
natively."""
- dev = qml.device("default.qubit", wires=2)
+ dev = qml.device("default.qubit", wires=3)
@qml.qnode(dev)
def cry_qnode(x, y):
@@ -1105,7 +1105,7 @@ def conditional_ry_qnode(x, y):
def test_sampling_with_mcm(self, basis_state):
"""Tests that a QNode with qml.sample and mid-circuit measurements
returns the expected results."""
- dev = qml.device("default.qubit", wires=2, shots=1000)
+ dev = qml.device("default.qubit", wires=3, shots=1000)
first_par = np.pi
@@ -1135,7 +1135,7 @@ def test_conditional_ops_tensorflow(self, interface):
"""Test conditional operations with TensorFlow."""
import tensorflow as tf
- dev = qml.device("default.qubit", wires=2)
+ dev = qml.device("default.qubit", wires=3)
@qml.qnode(dev, interface=interface, diff_method="parameter-shift")
def cry_qnode(x):
@@ -1178,7 +1178,7 @@ def test_conditional_ops_torch(self, interface):
"""Test conditional operations with Torch."""
import torch
- dev = qml.device("default.qubit", wires=2)
+ dev = qml.device("default.qubit", wires=3)
@qml.qnode(dev, interface=interface, diff_method="parameter-shift")
def cry_qnode(x):
@@ -1217,7 +1217,7 @@ def test_conditional_ops_jax(self, jax_interface):
import jax
jnp = jax.numpy
- dev = qml.device("default.qubit", wires=2)
+ dev = qml.device("default.qubit", wires=3)
@qml.qnode(dev, interface=jax_interface, diff_method="parameter-shift")
def cry_qnode(x):
@@ -1246,20 +1246,6 @@ def conditional_ry_qnode(x):
assert np.allclose(r1, r2)
assert np.allclose(jax.grad(cry_qnode)(x1), jax.grad(conditional_ry_qnode)(x2))
- def test_already_measured_error_operation(self):
- """Test that attempting to apply an operation on a wires that has been
- measured raises an error."""
- dev = qml.device("default.qubit", wires=3)
-
- @qml.qnode(dev)
- def circuit():
- qml.measure(1)
- qml.PauliX(1)
- return qml.expval(qml.PauliZ(0))
-
- with pytest.raises(ValueError, match="wires have been measured already: {1}"):
- circuit()
-
def test_qnode_does_not_support_nested_queuing(self):
"""Test that operators in QNodes are not queued to surrounding contexts."""
dev = qml.device("default.qubit", wires=1)
diff --git a/tests/transforms/test_defer_measurements.py b/tests/transforms/test_defer_measurements.py
index 6e247e1c8f2..80cf97c2978 100644
--- a/tests/transforms/test_defer_measurements.py
+++ b/tests/transforms/test_defer_measurements.py
@@ -47,19 +47,29 @@ def qnode2():
assert isinstance(res1, type(res2))
assert res1.shape == res2.shape
- assert len(qnode1.qtape.operations) == len(qnode2.qtape.operations)
+ assert len(qnode2.qtape.operations) == 1
+ assert isinstance(qnode2.qtape.operations[0], qml.CNOT)
assert len(qnode1.qtape.measurements) == len(qnode2.qtape.measurements)
- # Check the operations
- for op1, op2 in zip(qnode1.qtape.operations, qnode2.qtape.operations):
- assert isinstance(op1, type(op2))
- assert op1.data == op2.data
-
# Check the measurements
for op1, op2 in zip(qnode1.qtape.measurements, qnode2.qtape.measurements):
assert isinstance(op1, type(op2))
assert op1.data == op2.data
+ def test_reuse_wire_after_measurement(self):
+ """Test that wires can be reused after measurement."""
+ dev = qml.device("default.qubit", wires=2)
+
+ @qml.qnode(dev)
+ @qml.defer_measurements
+ def qnode():
+ qml.Hadamard(0)
+ qml.measure(0)
+ qml.PauliZ(0)
+ return qml.expval(qml.PauliX(0))
+
+ _ = qnode()
+
def test_measure_between_ops(self):
"""Test that a quantum function that contains one operation before and
after a mid-circuit measurement yields the correct results and is
@@ -87,11 +97,14 @@ def func2():
assert isinstance(res1, type(res2))
assert res1.shape == res2.shape
- assert len(qnode1.qtape.operations) == len(qnode2.qtape.operations)
+ assert len(qnode2.qtape.operations) == len(qnode1.qtape.operations) + 1
assert len(qnode1.qtape.measurements) == len(qnode2.qtape.measurements)
# Check the operations
- for op1, op2 in zip(qnode1.qtape.operations, qnode2.qtape.operations):
+ deferred_ops = qnode2.qtape.operations
+ assert qml.equal(deferred_ops.pop(1), qml.CNOT([1, 2]))
+
+ for op1, op2 in zip(qnode1.qtape.operations, deferred_ops):
assert isinstance(op1, type(op2))
assert op1.data == op2.data
@@ -107,6 +120,9 @@ def test_measure_with_tensor_obs(self, mid_measure_wire, tp_wires):
"""Test that the defer_measurements transform works well even with
tensor observables in the tape."""
# pylint: disable=protected-access
+ if isinstance(mid_measure_wire, str):
+ pytest.skip("defer_measurements does not support custom wire labels.")
+
with qml.queuing.AnnotatedQueue() as q:
qml.measure(mid_measure_wire)
qml.expval(qml.operation.Tensor(*[qml.PauliZ(w) for w in tp_wires]))
@@ -115,7 +131,10 @@ def test_measure_with_tensor_obs(self, mid_measure_wire, tp_wires):
tape = qml.defer_measurements(tape)
# Check the operations and measurements in the tape
- assert tape._ops == []
+ assert len(tape._ops) == 1
+ assert qml.equal(
+ tape._ops[0], qml.CNOT([mid_measure_wire, max([mid_measure_wire] + tp_wires) + 1])
+ )
assert len(tape.measurements) == 1
measurement = tape.measurements[0]
@@ -128,37 +147,6 @@ def test_measure_with_tensor_obs(self, mid_measure_wire, tp_wires):
assert isinstance(ob, qml.PauliZ)
assert ob.wires == qml.wires.Wires(tp_wires[idx])
- def test_already_measured_error_operation(self):
- """Test that attempting to apply an operation on a wires that has been
- measured raises an error."""
- dev = qml.device("default.qubit", wires=3)
-
- def qfunc():
- qml.measure(1)
- qml.PauliX(1)
- return qml.expval(qml.PauliZ(0))
-
- tape_deferred_func = qml.defer_measurements(qfunc)
- qnode = qml.QNode(tape_deferred_func, dev)
-
- with pytest.raises(ValueError, match="wires have been measured already: {1}"):
- qnode()
-
- def test_already_measured_error_terminal_measurement(self):
- """Test that attempting to measure a wire at the end of the circuit
- that has been measured in the middle of the circuit raises an error."""
- dev = qml.device("default.qubit", wires=3)
-
- def qfunc():
- qml.measure(1)
- return qml.expval(qml.PauliZ(1))
-
- tape_deferred_func = qml.defer_measurements(qfunc)
- qnode = qml.QNode(tape_deferred_func, dev)
-
- with pytest.raises(ValueError, match="Cannot apply operations"):
- qnode()
-
def test_cv_op_error(self):
"""Test that CV operations are not supported."""
dev = qml.device("default.gaussian", wires=3)
@@ -216,15 +204,15 @@ def test_correct_ops_in_tape(self, terminal_measurement):
tape = qml.tape.QuantumScript.from_queue(q)
tape = qml.defer_measurements(tape)
- assert len(tape.operations) == 2
+ assert len(tape.operations) == 4
assert len(tape.measurements) == 1
# Check the two underlying Controlled instances
- first_ctrl_op = tape.operations[0]
+ first_ctrl_op = tape.operations[1]
assert isinstance(first_ctrl_op, qml.ops.op_math.Controlled)
assert qml.equal(first_ctrl_op.base, qml.RY(first_par, 1))
- sec_ctrl_op = tape.operations[1]
+ sec_ctrl_op = tape.operations[3]
assert isinstance(sec_ctrl_op, qml.ops.op_math.Controlled)
assert qml.equal(sec_ctrl_op.base, qml.RZ(sec_par, 1))
@@ -245,15 +233,15 @@ def test_correct_ops_in_tape_inversion(self):
tape = qml.defer_measurements(tape)
# Conditioned on 0 as the control value, PauliX is applied before and after
- assert len(tape.operations) == 1
+ assert len(tape.operations) == 2
assert len(tape.measurements) == 1
# Check the two underlying Controlled instance
- ctrl_op = tape.operations[0]
+ ctrl_op = tape.operations[1]
assert isinstance(ctrl_op, qml.ops.op_math.Controlled)
assert qml.equal(ctrl_op.base, qml.RY(first_par, 1))
- assert ctrl_op.wires == qml.wires.Wires([0, 1])
+ assert ctrl_op.wires == qml.wires.Wires([2, 1])
def test_correct_ops_in_tape_assert_zero_state(self):
"""Test that the underlying tape contains the correct operations if a
@@ -271,11 +259,11 @@ def test_correct_ops_in_tape_assert_zero_state(self):
tape = qml.defer_measurements(tape)
# Conditioned on 0 as the control value, PauliX is applied before and after
- assert len(tape.operations) == 1
+ assert len(tape.operations) == 2
assert len(tape.measurements) == 1
# Check the underlying Controlled instance
- ctrl_op = tape.operations[0]
+ ctrl_op = tape.operations[1]
assert isinstance(ctrl_op, qml.ops.op_math.Controlled)
assert qml.equal(ctrl_op.base, qml.RY(first_par, 1))
@@ -312,7 +300,9 @@ def test_quantum_teleportation(self, rads):
tape = qml.tape.QuantumScript.from_queue(q)
tape = qml.defer_measurements(tape)
- assert len(tape.operations) == 5 + 2 # 5 regular ops + 2 conditional ops
+ assert (
+ len(tape.operations) == 5 + 2 + 2
+ ) # 5 regular ops + 2 measurement ops + 2 conditional ops
assert len(tape.measurements) == 1
# Check the each operation
@@ -337,15 +327,25 @@ def test_quantum_teleportation(self, rads):
assert isinstance(op5, qml.Hadamard)
assert op5.wires == qml.wires.Wires([0])
- # Check the two underlying Controlled instances
- ctrl_op1 = tape.operations[5]
+ # Check the two underlying CNOTs for storing measurement state
+ meas_op1 = tape.operations[5]
+ assert isinstance(meas_op1, qml.CNOT)
+ assert meas_op1.wires == qml.wires.Wires([0, 3])
+
+ meas_op2 = tape.operations[6]
+ assert isinstance(meas_op2, qml.CNOT)
+ assert meas_op2.wires == qml.wires.Wires([1, 4])
+
+ # Check the two underlying Controlled instances
+ ctrl_op1 = tape.operations[7]
assert isinstance(ctrl_op1, qml.ops.op_math.Controlled)
assert qml.equal(ctrl_op1.base, qml.RX(math.pi, 2))
+ assert ctrl_op1.wires == qml.wires.Wires([4, 2])
- ctrl_op2 = tape.operations[6]
+ ctrl_op2 = tape.operations[8]
assert isinstance(ctrl_op2, qml.ops.op_math.Controlled)
assert qml.equal(ctrl_op2.base, qml.RZ(math.pi, 2))
- assert ctrl_op2.wires == qml.wires.Wires([0, 2])
+ assert ctrl_op2.wires == qml.wires.Wires([3, 2])
# Check the measurement
assert tape.measurements[0] == terminal_measurement
@@ -395,11 +395,16 @@ def test_hermitian_queued(self):
tape = qml.tape.QuantumScript.from_queue(q)
tape = qml.defer_measurements(tape)
- assert len(tape.operations) == 1
+ assert len(tape.operations) == 2
assert len(tape.measurements) == 1
+ # Check the underlying CNOT for storing measurement state
+ meas_op1 = tape.operations[0]
+ assert isinstance(meas_op1, qml.CNOT)
+ assert meas_op1.wires == qml.wires.Wires([0, 5])
+
# Check the underlying Controlled instances
- first_ctrl_op = tape.operations[0]
+ first_ctrl_op = tape.operations[1]
assert isinstance(first_ctrl_op, qml.ops.op_math.Controlled)
assert qml.equal(first_ctrl_op.base, qml.RY(rads, 4))
@@ -426,16 +431,21 @@ def test_hamiltonian_queued(self):
tape = qml.tape.QuantumScript.from_queue(q)
tape = qml.defer_measurements(tape)
- assert len(tape.operations) == 1
+ assert len(tape.operations) == 2
assert len(tape.measurements) == 1
+ # Check the underlying CNOT for storing measurement state
+ meas_op1 = tape.operations[0]
+ assert isinstance(meas_op1, qml.CNOT)
+ assert meas_op1.wires == qml.wires.Wires([0, 5])
+
# Check the underlying Controlled instance
- first_ctrl_op = tape.operations[0]
+ first_ctrl_op = tape.operations[1]
assert isinstance(first_ctrl_op, qml.ops.op_math.Controlled)
assert qml.equal(first_ctrl_op.base, qml.RY(rads, 4))
assert len(tape.measurements) == 1
assert isinstance(tape.measurements[0], qml.measurements.MeasurementProcess)
- assert tape.measurements[0].obs == H
+ assert qml.equal(tape.measurements[0].obs, H)
@pytest.mark.parametrize("device", ["default.qubit", "default.mixed", "lightning.qubit"])
@pytest.mark.parametrize("ops", [(qml.RX, qml.CRX), (qml.RY, qml.CRY), (qml.RZ, qml.CRZ)])
@@ -470,7 +480,7 @@ def quantum_control_circuit(rads):
@pytest.mark.parametrize("device", ["default.qubit", "default.mixed", "lightning.qubit"])
def test_conditional_rotations_with_else(self, device):
"""Test that an else operation can also defined using qml.cond."""
- dev = qml.device(device, wires=2)
+ dev = qml.device(device, wires=3)
r = 2.345
op1, controlled_op1 = qml.RY, qml.CRY
@@ -504,7 +514,7 @@ def test_keyword_syntax(self):
keyword syntax works."""
op = qml.RY
- dev = qml.device("default.qubit", wires=2)
+ dev = qml.device("default.qubit", wires=3)
@qml.qnode(dev)
def qnode1(par):
@@ -528,7 +538,7 @@ def qnode2(par):
def test_condition_using_measurement_outcome(self, control_val, expected):
"""Apply a conditional bitflip by selecting the measurement
outcome."""
- dev = qml.device("default.qubit", wires=2)
+ dev = qml.device("default.qubit", wires=3)
@qml.qnode(dev)
def qnode():
@@ -541,7 +551,7 @@ def qnode():
@pytest.mark.parametrize("device", ["default.qubit", "default.mixed", "lightning.qubit"])
def test_cond_qfunc(self, device):
"""Test that a qfunc can also used with qml.cond."""
- dev = qml.device(device, wires=2)
+ dev = qml.device(device, wires=3)
r = 2.324
@@ -576,7 +586,7 @@ def quantum_control_circuit(r):
def test_cond_qfunc_with_else(self, device):
"""Test that a qfunc can also used with qml.cond even when an else
qfunc is provided."""
- dev = qml.device(device, wires=2)
+ dev = qml.device(device, wires=3)
x = 0.3
y = 3.123
@@ -610,7 +620,25 @@ def cond_qnode(x, y):
return qml.probs(wires=[0])
assert np.allclose(normal_circuit(x, y), cond_qnode(x, y))
- assert np.allclose(qml.matrix(normal_circuit)(x, y), qml.matrix(cond_qnode)(x, y))
+
+ def test_cond_on_measured_wire(self):
+ """Test that applying a conditional operation on the same wire
+ that is measured works as expected."""
+ dev = qml.device("default.qubit", wires=2)
+
+ @qml.qnode(dev)
+ @qml.defer_measurements
+ def qnode():
+ qml.Hadamard(0)
+ m = qml.measure(0)
+ qml.cond(m, qml.PauliX)(0)
+ return qml.density_matrix(0)
+
+ # Above circuit will cause wire 0 to go back to the |0> computational
+ # basis state. We can inspect the reduced density matrix to confirm this
+ # without inspecting the extra wires
+ expected_dmat = np.array([[1, 0], [0, 0]])
+ assert np.allclose(qnode(), expected_dmat)
class TestExpressionConditionals:
@@ -621,7 +649,7 @@ class TestExpressionConditionals:
def test_conditional_rotations(self, r, op):
"""Test that the quantum conditional operations match the output of
controlled rotations. And additionally that summing measurements works as expected."""
- dev = qml.device("default.qubit", wires=3)
+ dev = qml.device("default.qubit", wires=5)
@qml.qnode(dev)
def normal_circuit(rads):
@@ -648,7 +676,7 @@ def quantum_control_circuit(rads):
@pytest.mark.parametrize("r", np.linspace(0.1, 2 * np.pi - 0.1, 4))
def test_triple_measurement_condition_expression(self, r):
"""Test that combining the results of three mid-circuit measurements works as expected."""
- dev = qml.device("default.qubit", wires=4)
+ dev = qml.device("default.qubit", wires=7)
@qml.qnode(dev)
@qml.defer_measurements
@@ -690,7 +718,7 @@ def test_multiple_conditions(self):
"""Test that when multiple "branches" of the mid-circuit measurements all satisfy the criteria then
this translates to multiple control gates.
"""
- dev = qml.device("default.qubit", wires=4)
+ dev = qml.device("default.qubit", wires=7)
@qml.qnode(dev)
@qml.defer_measurements
@@ -733,7 +761,7 @@ def quantum_control_circuit(rads):
def test_composed_conditions(self):
"""Test that a complex nested expression gets resolved correctly to the corresponding correct control gates."""
- dev = qml.device("default.qubit", wires=4)
+ dev = qml.device("default.qubit", wires=7)
@qml.qnode(dev)
@qml.defer_measurements
@@ -788,7 +816,7 @@ def test_basis_state_prep(self):
basis_state = [0, 1, 1, 0]
- dev = qml.device("default.qubit", wires=5)
+ dev = qml.device("default.qubit", wires=6)
@qml.qnode(dev)
def qnode1():
@@ -804,15 +832,15 @@ def qnode2():
qml.cond(m_0, template)(basis_state, wires=range(1, 5))
return qml.expval(qml.PauliZ(1) @ qml.PauliZ(2) @ qml.PauliZ(3) @ qml.PauliZ(4))
- dev = qml.device("default.qubit", wires=2)
-
assert np.allclose(qnode1(), qnode2())
- assert len(qnode1.qtape.operations) == len(qnode2.qtape.operations)
+ assert len(qnode2.qtape.operations) == len(qnode1.qtape.operations) + 1
assert len(qnode1.qtape.measurements) == len(qnode2.qtape.measurements)
# Check the operations
- for op1, op2 in zip(qnode1.qtape.operations, qnode2.qtape.operations):
+ deferred_ops = qnode2.qtape.operations
+ assert qml.equal(deferred_ops.pop(1), qml.CNOT([0, 5]))
+ for op1, op2 in zip(qnode1.qtape.operations, deferred_ops):
assert isinstance(op1, type(op2))
assert np.allclose(op1.data, op2.data)
@@ -827,7 +855,7 @@ def test_angle_embedding(self):
template = qml.AngleEmbedding
feature_vector = [1, 2, 3]
- dev = qml.device("default.qubit", wires=5)
+ dev = qml.device("default.qubit", wires=6)
@qml.qnode(dev)
def qnode1():
@@ -843,17 +871,18 @@ def qnode2():
qml.cond(m_0, template)(features=feature_vector, wires=range(1, 5), rotation="Z")
return qml.expval(qml.PauliZ(1) @ qml.PauliZ(2) @ qml.PauliZ(3) @ qml.PauliZ(4))
- dev = qml.device("default.qubit", wires=2)
res1 = qnode1()
res2 = qnode2()
assert np.allclose(res1, res2)
- assert len(qnode1.qtape.operations) == len(qnode2.qtape.operations)
+ assert len(qnode2.qtape.operations) == len(qnode1.qtape.operations) + 1
assert len(qnode1.qtape.measurements) == len(qnode2.qtape.measurements)
# Check the operations
- for op1, op2 in zip(qnode1.qtape.operations, qnode2.qtape.operations):
+ deferred_ops = qnode2.qtape.operations
+ assert qml.equal(deferred_ops.pop(1), qml.CNOT([0, 5]))
+ for op1, op2 in zip(qnode1.qtape.operations, deferred_ops):
assert isinstance(op1, type(op2))
assert np.allclose(op1.data, op2.data)
@@ -865,7 +894,7 @@ def qnode2():
@pytest.mark.parametrize("template", [qml.StronglyEntanglingLayers, qml.BasicEntanglerLayers])
def test_layers(self, template):
"""Test layers conditioned on mid-circuit measurement outcomes."""
- dev = qml.device("default.qubit", wires=3)
+ dev = qml.device("default.qubit", wires=4)
num_wires = 2
@@ -888,11 +917,13 @@ def qnode2(parameters):
assert np.allclose(qnode1(weights), qnode2(weights))
- assert len(qnode1.qtape.operations) == len(qnode2.qtape.operations)
+ assert len(qnode2.qtape.operations) == len(qnode1.qtape.operations) + 1
assert len(qnode1.qtape.measurements) == len(qnode2.qtape.measurements)
# Check the operations
- for op1, op2 in zip(qnode1.qtape.operations, qnode2.qtape.operations):
+ deferred_ops = qnode2.qtape.operations
+ assert qml.equal(deferred_ops.pop(1), qml.CNOT([0, 3]))
+ for op1, op2 in zip(qnode1.qtape.operations, deferred_ops):
assert isinstance(op1, type(op2))
assert np.allclose(op1.data, op2.data)
@@ -902,6 +933,110 @@ def qnode2(parameters):
assert np.allclose(op1.data, op2.data)
+class TestQubitReset:
+ """Tests for the qubit reset functionality of `qml.measure`."""
+
+ def test_correct_cnot_for_reset(self):
+ """Test that a CNOT is applied from the wire that stores the measurement
+ to the measured wire after the measurement."""
+ dev = qml.device("default.qubit", wires=3)
+
+ @qml.qnode(dev)
+ def qnode1(x):
+ qml.Hadamard(0)
+ qml.CRX(x, [0, 1])
+ return qml.expval(qml.PauliZ(1))
+
+ @qml.qnode(dev)
+ @qml.defer_measurements
+ def qnode2(x):
+ qml.Hadamard(0)
+ m0 = qml.measure(0, reset=True)
+ qml.cond(m0, qml.RX)(x, 1)
+ return qml.expval(qml.PauliZ(1))
+
+ assert np.allclose(qnode1(0.123), qnode2(0.123))
+
+ expected_circuit = [
+ qml.Hadamard(0),
+ qml.CNOT([0, 2]),
+ qml.CNOT([2, 0]),
+ qml.ops.Controlled(qml.RX(0.123, 1), 2),
+ qml.expval(qml.PauliZ(1)),
+ ]
+
+ assert len(qnode2.qtape.circuit) == len(expected_circuit)
+ assert all(
+ qml.equal(actual, expected)
+ for actual, expected in zip(qnode2.qtape.circuit, expected_circuit)
+ )
+
+ def test_wire_is_reset(self):
+ """Test that a wire is reset to the |0> state without any local phases
+ after measurement if reset is requested."""
+ dev = qml.device("default.qubit", wires=3)
+
+ @qml.qnode(dev)
+ @qml.defer_measurements
+ def qnode(x):
+ qml.Hadamard(0)
+ qml.PhaseShift(np.pi / 4, 0)
+ m = qml.measure(0, reset=True)
+ qml.cond(m, qml.RX)(x, 1)
+ return qml.density_matrix(wires=[0])
+
+ # Expected reduced density matrix on wire 0
+ expected_mat = np.array([[1, 0], [0, 0]])
+ assert np.allclose(qnode(0.123), expected_mat)
+
+ def test_multiple_measurements_mixed_reset(self):
+ """Test that a QNode with multiple mid-circuit measurements with
+ different resets is transformed correctly."""
+ dev = qml.device("default.qubit", wires=6)
+
+ @qml.qnode(dev)
+ def qnode(p, x, y):
+ qml.Hadamard(0)
+ qml.PhaseShift(p, 0)
+ # Set measurement_ids so that the order of wires in combined
+ # measurement values is consistent
+
+ mp0 = qml.measurements.MidMeasureMP(0, reset=True, id=0)
+ m0 = qml.measurements.MeasurementValue([mp0], lambda v: v)
+ qml.cond(~m0, qml.RX)(x, 1)
+ mp1 = qml.measurements.MidMeasureMP(1, reset=True, id=1)
+ m1 = qml.measurements.MeasurementValue([mp1], lambda v: v)
+ qml.cond(m0 & m1, qml.Hadamard)(0)
+ mp2 = qml.measurements.MidMeasureMP(0, id=2)
+ m2 = qml.measurements.MeasurementValue([mp2], lambda v: v)
+ qml.cond(m1 | m2, qml.RY)(y, 2)
+ return qml.expval(qml.PauliZ(2))
+
+ _ = qnode(0.123, 0.456, 0.789)
+
+ expected_circuit = [
+ qml.Hadamard(0),
+ qml.PhaseShift(0.123, 0),
+ qml.CNOT([0, 3]),
+ qml.CNOT([3, 0]),
+ qml.ops.Controlled(qml.RX(0.456, 1), 3, [False]),
+ qml.CNOT([1, 4]),
+ qml.CNOT([4, 1]),
+ qml.ops.Controlled(qml.Hadamard(0), [3, 4]),
+ qml.CNOT([0, 5]),
+ qml.ops.Controlled(qml.RY(0.789, 2), [4, 5], [False, True]),
+ qml.ops.Controlled(qml.RY(0.789, 2), [4, 5], [True, False]),
+ qml.ops.Controlled(qml.RY(0.789, 2), [4, 5], [True, True]),
+ qml.expval(qml.PauliZ(2)),
+ ]
+
+ assert len(qnode.qtape.circuit) == len(expected_circuit)
+ assert all(
+ qml.equal(actual, expected)
+ for actual, expected in zip(qnode.qtape.circuit, expected_circuit)
+ )
+
+
class TestDrawing:
"""Tests drawing circuits with mid-circuit measurements and conditional
operations that have been transformed"""
@@ -910,6 +1045,8 @@ def test_drawing(self):
"""Test that drawing a func with mid-circuit measurements works and
that controlled operations are drawn for conditional operations."""
+ # TODO: Update after drawing for mid-circuit measurements is updated.
+
def qfunc():
m_0 = qml.measure(0)
qml.cond(m_0, qml.RY)(0.312, wires=1)
@@ -924,8 +1061,10 @@ def qfunc():
transformed_qnode = qml.QNode(transformed_qfunc, dev)
expected = (
- "0: ─╭●──────────────────┤ \n"
- "1: ─╰RY(0.31)─╭RY(0.31)─┤ \n"
- "2: ───────────╰●────────┤ "
+ "0: ─╭●────────────────────────┤ \n"
+ "1: ─│──╭RY(0.31)────╭RY(0.31)─┤ \n"
+ "2: ─│──│─────────╭●─│─────────┤ \n"
+ "3: ─╰X─╰●────────│──│─────────┤ \n"
+ "4: ──────────────╰X─╰●────────┤ "
)
assert qml.draw(transformed_qnode)() == expected