diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 4a9add0c35c..4c595ff07f9 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -4,6 +4,10 @@

New features since last release

+* `qml.measure` now includes a boolean keyword argument `reset` to reset a wire to the + $|0\rangle$ computational basis state after measurement. + [(#4402)](https://github.com/PennyLaneAI/pennylane/pull/4402/) + * `DefaultQubit2` accepts a `max_workers` argument which controls multiprocessing. A `ProcessPoolExecutor` executes tapes asynchronously using a pool of at most `max_workers` processes. If `max_workers` is `None` @@ -63,6 +67,9 @@ array([False, False])

Improvements 🛠

+* Wires can now be reused after making a mid-circuit measurement on them. + [(#4402)](https://github.com/PennyLaneAI/pennylane/pull/4402/) + * Transform Programs, `qml.transforms.core.TransformProgram`, can now be called on a batch of circuits and return a new batch of circuits and a single post processing function. [(#4364)](https://github.com/PennyLaneAI/pennylane/pull/4364) diff --git a/pennylane/_qubit_device.py b/pennylane/_qubit_device.py index e89b2cb3b04..a6618ac4be3 100644 --- a/pennylane/_qubit_device.py +++ b/pennylane/_qubit_device.py @@ -136,6 +136,7 @@ class QubitDevice(Device): _real = staticmethod(np.real) _size = staticmethod(np.size) _ndim = staticmethod(np.ndim) + _norm = staticmethod(np.linalg.norm) @staticmethod def _scatter(indices, array, new_dimensions): diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 0a7d739a879..7c0e4ae22e5 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -270,6 +270,7 @@ def apply(self, operations, rotations=None, **kwargs): # apply the circuit operations for i, operation in enumerate(operations): + print(self._state.shape) if i > 0 and isinstance(operation, (QubitStateVector, BasisState)): raise DeviceError( f"Operation {operation.name} cannot be used after other Operations have already been applied " @@ -333,12 +334,17 @@ def _apply_operation(self, state, operation): matrix = self._asarray(self._get_unitary_matrix(operation), dtype=self.C_DTYPE) if operation in diagonal_in_z_basis: - return self._apply_diagonal_unitary(state, matrix, wires) - if len(wires) <= 2: + new_state = self._apply_diagonal_unitary(state, matrix, wires) + elif len(wires) <= 2: # Einsum is faster for small gates - return self._apply_unitary_einsum(state, matrix, wires) + new_state = self._apply_unitary_einsum(state, matrix, wires) + else: + new_state = self._apply_unitary(state, matrix, wires) + + if operation.__class__.__name__ in {"Projector", "_BasisStateProjector"}: + new_state = new_state / self._norm(new_state) - return self._apply_unitary(state, matrix, wires) + return new_state def _apply_x(self, state, axes, **kwargs): """Applies a PauliX gate by rolling 1 unit along the axis specified in ``axes``. diff --git a/pennylane/devices/default_qubit_jax.py b/pennylane/devices/default_qubit_jax.py index 411d77fea02..c29590367a1 100644 --- a/pennylane/devices/default_qubit_jax.py +++ b/pennylane/devices/default_qubit_jax.py @@ -162,6 +162,7 @@ def circuit(): _const_mul = staticmethod(jnp.multiply) _size = staticmethod(jnp.size) _ndim = staticmethod(jnp.ndim) + _norm = staticmethod(jnp.linalg.norm) operations = DefaultQubit.operations.union({"ParametrizedEvolution"}) diff --git a/pennylane/devices/default_qubit_tf.py b/pennylane/devices/default_qubit_tf.py index 702bf34af0b..bf12988e0af 100644 --- a/pennylane/devices/default_qubit_tf.py +++ b/pennylane/devices/default_qubit_tf.py @@ -136,6 +136,7 @@ class DefaultQubitTF(DefaultQubit): _stack = staticmethod(tf.stack) _size = staticmethod(tf.size) _ndim = staticmethod(_ndim_tf) + _norm = staticmethod(tf.norm) @staticmethod def _const_mul(constant, array): diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index 71155b2892f..7024aa9ba41 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -190,8 +190,12 @@ def _(op: type_op, state): len(op.wires) < EINSUM_OP_WIRECOUNT_PERF_THRESHOLD and math.ndim(state) < EINSUM_STATE_WIRECOUNT_PERF_THRESHOLD ) or (op.batch_size and is_state_batched): - return apply_operation_einsum(op, state, is_state_batched=is_state_batched) - return apply_operation_tensordot(op, state, is_state_batched=is_state_batched) + new_state = apply_operation_einsum(op, state, is_state_batched=is_state_batched) + else: + new_state = apply_operation_tensordot(op, state, is_state_batched=is_state_batched) + + if op.__class__.__name__ in {"Projector", "_BasisStateProjector"}: + new_state = new_state / math.norm(new_state) @apply_operation.register diff --git a/pennylane/drawer/drawable_layers.py b/pennylane/drawer/drawable_layers.py index 0042b9e3122..9a29b90e320 100644 --- a/pennylane/drawer/drawable_layers.py +++ b/pennylane/drawer/drawable_layers.py @@ -97,11 +97,6 @@ def drawable_layers(ops, wire_map=None): # loop over operations for op in ops: is_mid_measure = is_conditional = False - if set(measured_wires.values()).intersection({wire_map[w] for w in op.wires}): - raise ValueError( - f"Cannot apply operations on {op.wires} as some wires have been measured already." - ) - if isinstance(op, MidMeasureMP): if len(op.wires) > 1: raise ValueError("Cannot draw mid-circuit measurements with more than one wire.") diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py index 2de61e178a8..a4b51bf2a13 100644 --- a/pennylane/measurements/mid_measure.py +++ b/pennylane/measurements/mid_measure.py @@ -24,8 +24,10 @@ from .measurements import MeasurementProcess, MidMeasure -def measure(wires): # TODO: Change name to mid_measure - """Perform a mid-circuit measurement in the computational basis on the +def measure( + wires: Wires, reset: Optional[bool] = False, postselect: Optional[int] = None +): # TODO: Change name to mid_measure + r"""Perform a mid-circuit measurement in the computational basis on the supplied qubit. Measurement outcomes can be obtained and used to conditionally apply @@ -38,7 +40,7 @@ def measure(wires): # TODO: Change name to mid_measure .. code-block:: python3 - dev = qml.device("default.qubit", wires=2) + dev = qml.device("default.qubit", wires=3) @qml.qnode(dev) def func(x, y): @@ -55,16 +57,40 @@ def func(x, y): >>> func(*pars) tensor([0.90165331, 0.09834669], requires_grad=True) + Wires can be reused after measurement. Moreover, measured wires can be reset + to the :math:`|0 \rangle` by setting ``reset=True``. + + .. code-block:: python3 + + dev = qml.device("default.qubit", wires=3) + + @qml.qnode(dev) + def func(): + qml.PauliX(1) + m_0 = qml.measure(1, reset=True) + return qml.probs(wires=[1]) + + Executing this QNode: + + >>> func() + tensor([1., 0.], requires_grad=True) + Mid circuit measurements can be manipulated using the following dunder methods ``+``, ``-``, ``*``, ``/``, ``~`` (not), ``&`` (and), ``|`` (or), ``==``, ``<=``, ``>=``, ``<``, ``>`` with other mid-circuit measurements or scalars. - Note: - python ``not``, ``and``, ``or``, do not work since these do not have dunder - methods. Instead use ``~``, ``&``, ``|``. + .. Note :: + + Python ``not``, ``and``, ``or``, do not work since these do not have dunder methods. + Instead use ``~``, ``&``, ``|``. Args: wires (Wires): The wire of the qubit the measurement process applies to. + reset (Optional[bool]): Whether to reset the wire to the :math:`|0 \rangle` + state after measurement. + postselect (Optional[int]): The measured computational basis state on which to + optionally postselect the circuit. Must be ``0`` or ``1`` if postselection + is requested. Returns: MidMeasureMP: measurement process instance @@ -72,6 +98,7 @@ def func(x, y): Raises: QuantumFunctionError: if multiple wires were specified """ + wire = Wires(wires) if len(wire) > 1: raise qml.QuantumFunctionError( @@ -80,7 +107,7 @@ def func(x, y): # Create a UUID and a map between MP and MV to support serialization measurement_id = str(uuid.uuid4())[:8] - mp = MidMeasureMP(wires=wire, id=measurement_id) + mp = MidMeasureMP(wires=wire, reset=reset, postselect=postselect, id=measurement_id) return MeasurementValue([mp], processing_fn=lambda v: v) @@ -90,17 +117,31 @@ def func(x, y): class MidMeasureMP(MeasurementProcess): """Mid-circuit measurement. + This class additionally stores information about unknown measurement outcomes in the qubit model. + Measurements on a single qubit in the computational basis are assumed. + Please refer to :func:`measure` for detailed documentation. Args: wires (.Wires): The wires the measurement process applies to. This can only be specified if an observable was not provided. - id (str): custom label given to a measurement instance, can be useful for some applications - where the instance has to be identified + reset (bool): Whether to reset the wire after measurement. + postselect (Optional[int]): The measured computational basis state on which to + optionally postselect the circuit. Must be ``0`` or ``1`` if postselection + is requested. + id (str): Custom label given to a measurement instance. """ - def __init__(self, wires: Optional[Wires] = None, id: Optional[str] = None): - super().__init__(wires=wires, id=id) + def __init__( + self, + wires: Optional[Wires] = None, + reset: Optional[bool] = False, + postselect: Optional[int] = None, + id: Optional[str] = None, + ): + super().__init__(wires=Wires(wires), id=id) + self.reset = reset + self.postselect = postselect @property def return_type(self): diff --git a/pennylane/ops/qubit/observables.py b/pennylane/ops/qubit/observables.py index 9246de99dc7..3b48e83b058 100644 --- a/pennylane/ops/qubit/observables.py +++ b/pennylane/ops/qubit/observables.py @@ -23,7 +23,7 @@ from scipy.sparse import csr_matrix import pennylane as qml -from pennylane.operation import AnyWires, Observable +from pennylane.operation import AnyWires, Observable, Operation from pennylane.wires import Wires from .matrix_ops import QubitUnitary @@ -428,7 +428,7 @@ def __copy__(self): return copied_op -class _BasisStateProjector(Observable): +class _BasisStateProjector(Observable, Operation): # The call signature should be the same as Projector.__new__ for the positional # arguments, but with free key word arguments. def __init__(self, state, wires, id=None): diff --git a/pennylane/transforms/condition.py b/pennylane/transforms/condition.py index 1c3530b0eff..d5dcd4ea6e7 100644 --- a/pennylane/transforms/condition.py +++ b/pennylane/transforms/condition.py @@ -67,7 +67,7 @@ def cond(condition, true_fn, false_fn=None): :func:`defer_measurements` transform. Args: - condition (.MeasurementValue[bool]): a conditional expression involving a mid-circuit + condition (.MeasurementValue): a conditional expression involving a mid-circuit measurement value (see :func:`.pennylane.measure`) true_fn (callable): The quantum function of PennyLane operation to apply if ``condition`` is ``True`` @@ -114,6 +114,8 @@ def qnode(x, y): Expressions with boolean logic flow using operators like ``and``, ``or`` and ``not`` are not supported as the ``condition`` argument. + While such statements may not result in errors, they may result in + incorrect behaviour. .. details:: :title: Usage Details diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py index 29c13cdfe27..2cc04f656d8 100644 --- a/pennylane/transforms/defer_measurements.py +++ b/pennylane/transforms/defer_measurements.py @@ -16,12 +16,13 @@ from pennylane.measurements import MidMeasureMP from pennylane.ops.op_math import ctrl from pennylane.queuing import apply +from pennylane.tape import QuantumTape from pennylane.transforms import qfunc_transform from pennylane.wires import Wires @qfunc_transform -def defer_measurements(tape): +def defer_measurements(tape: QuantumTape): """Quantum function transform that substitutes operations conditioned on measurement outcomes to controlled operations. @@ -40,6 +41,13 @@ def defer_measurements(tape): that can be controlled as such depends on the set of operations supported by the chosen device. + .. note:: + + Devices that inherit `QubitDevice` **must** be initialized with an additional + wire for each mid-circuit measurement for `defer_measurements` to transform + the quantum tape correctly. Such devices should also be initialized without + custom wire labels for correct behaviour. + .. note:: This transform does not change the list of terminal measurements returned by @@ -54,7 +62,7 @@ def defer_measurements(tape): post-measurement states are considered. Args: - qfunc (function): a quantum function + tape (.QuantumTape): a quantum tape **Example** @@ -84,8 +92,30 @@ def qfunc(par): >>> qml.grad(qnode)(par) tensor(-0.49622252, requires_grad=True) + + Reusing and reseting measured wires will work as expected with the + ``defer_measurements`` transform: + + .. code-block:: python3 + + dev = qml.device("default.qubit", wires=3) + + @qml.qnode(dev) + def func(x, y): + qml.RY(x, wires=0) + qml.CNOT(wires=[0, 1]) + m_0 = qml.measure(1, reset=True) + + qml.cond(m_0, qml.RY)(y, wires=0) + qml.RX(np.pi/4, wires=1) + return qml.probs(wires=[0, 1]) + + Executing this QNode: + + >>> pars = np.array([0.643, 0.246], requires_grad=True) + >>> func(*pars) + tensor([0.76960924, 0.13204407, 0.08394415, 0.01440254], requires_grad=True) """ - measured_wires = {} cv_types = (qml.operation.CVOperation, qml.operation.CVObservable) ops_cv = any(isinstance(op, cv_types) for op in tape.operations) @@ -93,25 +123,34 @@ def qfunc(par): if ops_cv or obs_cv: raise ValueError("Continuous variable operations and observables are not supported.") - for op in tape: - op_wires_measured = set(wire for wire in op.wires if wire in measured_wires.values()) - if len(op_wires_measured) > 0: - raise ValueError( - f"Cannot apply operations on {op.wires} as the following wires have been measured already: {op_wires_measured}." - ) + # Current wire in which pre-measurement state will be saved if dev_wires not specified + cur_wire = max(tape.wires) + 1 + new_wires = {} + for op in tape: if isinstance(op, MidMeasureMP): - measured_wires[op.id] = op.wires[0] + new_wires[op.id] = cur_wire + qml.CNOT([op.wires[0], cur_wire]) + + if op.postselect is not None: + qml.Projector([op.postselect], wires=op.wires[0]) + + if op.reset: + qml.CNOT([cur_wire, op.wires[0]]) + + cur_wire += 1 elif op.__class__.__name__ == "Conditional": - _add_control_gate(op, measured_wires) + _add_control_gate(op, new_wires) else: apply(op) + return tape._qfunc_output # pylint: disable=protected-access + -def _add_control_gate(op, measured_wires): +def _add_control_gate(op, control_wires): """Helper function to add control gates""" - control = [measured_wires[m.id] for m in op.meas_val.measurements] + control = [control_wires[m.id] for m in op.meas_val.measurements] for branch, value in op.meas_val._items(): # pylint: disable=protected-access if value: ctrl( diff --git a/tests/devices/qubit/test_preprocess.py b/tests/devices/qubit/test_preprocess.py index 9c1ad571a73..281c32be9b2 100644 --- a/tests/devices/qubit/test_preprocess.py +++ b/tests/devices/qubit/test_preprocess.py @@ -230,7 +230,8 @@ def test_expand_fn_defer_measurement(self): expanded_tape = expand_fn(tape) expected = [ qml.Hadamard(0), - qml.ops.Controlled(qml.RX(0.123, wires=1), 0), + qml.CNOT([0, 2]), + qml.ops.Controlled(qml.RX(0.123, wires=1), 2), ] for op, exp in zip(expanded_tape, expected + measurements): diff --git a/tests/drawer/test_drawable_layers.py b/tests/drawer/test_drawable_layers.py index a842a3c915a..5bac4268d6a 100644 --- a/tests/drawer/test_drawable_layers.py +++ b/tests/drawer/test_drawable_layers.py @@ -173,23 +173,7 @@ def test_empty_layers_are_pruned(self): layers = drawable_layers(ops, wire_map={i: i for i in range(3)}) assert layers == [[ops[1]], [ops[2], ops[0]], [ops[3]]] - def test_cannot_reuse_wire_after_conditional(self): - """Tests that a wire cannot be re-used after using a mid-circuit measurement.""" - with AnnotatedQueue() as q: - m0 = qml.measure(0) - qml.cond(m0, qml.PauliX)(1) - qml.Hadamard(0) - - with pytest.raises(ValueError, match="some wires have been measured already"): - drawable_layers(q.queue) - def test_cannot_draw_multi_wire_MidMeasureMP(self): """Tests that MidMeasureMP is only supported with one wire.""" with pytest.raises(ValueError, match="mid-circuit measurements with more than one wire."): drawable_layers([MidMeasureMP([0, 1])]) - - def test_cannot_use_measured_wire(self): - """Tests error is raised when trying to use a measured wire.""" - ops = [MidMeasureMP([0]), qml.PauliX(0)] - with pytest.raises(ValueError, match="some wires have been measured already"): - drawable_layers(ops) diff --git a/tests/measurements/test_mid_measure.py b/tests/measurements/test_mid_measure.py index 06352b0a1de..6aa70a21457 100644 --- a/tests/measurements/test_mid_measure.py +++ b/tests/measurements/test_mid_measure.py @@ -26,7 +26,7 @@ def test_samples_computational_basis(): """Test that samples_computational_basis is always false for mid circuit measurements.""" - m = qml.measurements.MidMeasureMP(Wires(0)) + m = MidMeasureMP(Wires(0)) assert not m.samples_computational_basis @@ -44,19 +44,19 @@ def test_many_wires_error(self): def test_hash(self): """Test that the hash for `MidMeasureMP` is defined correctly.""" - m1 = MidMeasureMP(Wires(0), "m1") - m2 = MidMeasureMP(Wires(0), "m2") - m3 = MidMeasureMP(Wires(1), "m1") - m4 = MidMeasureMP(Wires(0), "m1") + m1 = MidMeasureMP(Wires(0), id="m1") + m2 = MidMeasureMP(Wires(0), id="m2") + m3 = MidMeasureMP(Wires(1), id="m1") + m4 = MidMeasureMP(Wires(0), id="m1") assert m1.hash != m2.hash assert m1.hash != m3.hash assert m1.hash == m4.hash -mp1 = MidMeasureMP(Wires(0), "m0") -mp2 = MidMeasureMP(Wires(1), "m1") -mp3 = MidMeasureMP(Wires(2), "m2") +mp1 = MidMeasureMP(Wires(0), id="m0") +mp2 = MidMeasureMP(Wires(1), id="m1") +mp3 = MidMeasureMP(Wires(2), id="m2") class TestMeasurementValueManipulation: diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 95067f2b2d0..30705d375ff 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -1077,7 +1077,7 @@ def test_defer_meas_if_mcm_unsupported(self, first_par, sec_par, return_type): """Tests that the transform using the deferred measurement principle is applied if the device doesn't support mid-circuit measurements natively.""" - dev = qml.device("default.qubit", wires=2) + dev = qml.device("default.qubit", wires=3) @qml.qnode(dev) def cry_qnode(x, y): @@ -1105,7 +1105,7 @@ def conditional_ry_qnode(x, y): def test_sampling_with_mcm(self, basis_state): """Tests that a QNode with qml.sample and mid-circuit measurements returns the expected results.""" - dev = qml.device("default.qubit", wires=2, shots=1000) + dev = qml.device("default.qubit", wires=3, shots=1000) first_par = np.pi @@ -1135,7 +1135,7 @@ def test_conditional_ops_tensorflow(self, interface): """Test conditional operations with TensorFlow.""" import tensorflow as tf - dev = qml.device("default.qubit", wires=2) + dev = qml.device("default.qubit", wires=3) @qml.qnode(dev, interface=interface, diff_method="parameter-shift") def cry_qnode(x): @@ -1178,7 +1178,7 @@ def test_conditional_ops_torch(self, interface): """Test conditional operations with Torch.""" import torch - dev = qml.device("default.qubit", wires=2) + dev = qml.device("default.qubit", wires=3) @qml.qnode(dev, interface=interface, diff_method="parameter-shift") def cry_qnode(x): @@ -1217,7 +1217,7 @@ def test_conditional_ops_jax(self, jax_interface): import jax jnp = jax.numpy - dev = qml.device("default.qubit", wires=2) + dev = qml.device("default.qubit", wires=3) @qml.qnode(dev, interface=jax_interface, diff_method="parameter-shift") def cry_qnode(x): @@ -1246,20 +1246,6 @@ def conditional_ry_qnode(x): assert np.allclose(r1, r2) assert np.allclose(jax.grad(cry_qnode)(x1), jax.grad(conditional_ry_qnode)(x2)) - def test_already_measured_error_operation(self): - """Test that attempting to apply an operation on a wires that has been - measured raises an error.""" - dev = qml.device("default.qubit", wires=3) - - @qml.qnode(dev) - def circuit(): - qml.measure(1) - qml.PauliX(1) - return qml.expval(qml.PauliZ(0)) - - with pytest.raises(ValueError, match="wires have been measured already: {1}"): - circuit() - def test_qnode_does_not_support_nested_queuing(self): """Test that operators in QNodes are not queued to surrounding contexts.""" dev = qml.device("default.qubit", wires=1) diff --git a/tests/transforms/test_defer_measurements.py b/tests/transforms/test_defer_measurements.py index 6e247e1c8f2..80cf97c2978 100644 --- a/tests/transforms/test_defer_measurements.py +++ b/tests/transforms/test_defer_measurements.py @@ -47,19 +47,29 @@ def qnode2(): assert isinstance(res1, type(res2)) assert res1.shape == res2.shape - assert len(qnode1.qtape.operations) == len(qnode2.qtape.operations) + assert len(qnode2.qtape.operations) == 1 + assert isinstance(qnode2.qtape.operations[0], qml.CNOT) assert len(qnode1.qtape.measurements) == len(qnode2.qtape.measurements) - # Check the operations - for op1, op2 in zip(qnode1.qtape.operations, qnode2.qtape.operations): - assert isinstance(op1, type(op2)) - assert op1.data == op2.data - # Check the measurements for op1, op2 in zip(qnode1.qtape.measurements, qnode2.qtape.measurements): assert isinstance(op1, type(op2)) assert op1.data == op2.data + def test_reuse_wire_after_measurement(self): + """Test that wires can be reused after measurement.""" + dev = qml.device("default.qubit", wires=2) + + @qml.qnode(dev) + @qml.defer_measurements + def qnode(): + qml.Hadamard(0) + qml.measure(0) + qml.PauliZ(0) + return qml.expval(qml.PauliX(0)) + + _ = qnode() + def test_measure_between_ops(self): """Test that a quantum function that contains one operation before and after a mid-circuit measurement yields the correct results and is @@ -87,11 +97,14 @@ def func2(): assert isinstance(res1, type(res2)) assert res1.shape == res2.shape - assert len(qnode1.qtape.operations) == len(qnode2.qtape.operations) + assert len(qnode2.qtape.operations) == len(qnode1.qtape.operations) + 1 assert len(qnode1.qtape.measurements) == len(qnode2.qtape.measurements) # Check the operations - for op1, op2 in zip(qnode1.qtape.operations, qnode2.qtape.operations): + deferred_ops = qnode2.qtape.operations + assert qml.equal(deferred_ops.pop(1), qml.CNOT([1, 2])) + + for op1, op2 in zip(qnode1.qtape.operations, deferred_ops): assert isinstance(op1, type(op2)) assert op1.data == op2.data @@ -107,6 +120,9 @@ def test_measure_with_tensor_obs(self, mid_measure_wire, tp_wires): """Test that the defer_measurements transform works well even with tensor observables in the tape.""" # pylint: disable=protected-access + if isinstance(mid_measure_wire, str): + pytest.skip("defer_measurements does not support custom wire labels.") + with qml.queuing.AnnotatedQueue() as q: qml.measure(mid_measure_wire) qml.expval(qml.operation.Tensor(*[qml.PauliZ(w) for w in tp_wires])) @@ -115,7 +131,10 @@ def test_measure_with_tensor_obs(self, mid_measure_wire, tp_wires): tape = qml.defer_measurements(tape) # Check the operations and measurements in the tape - assert tape._ops == [] + assert len(tape._ops) == 1 + assert qml.equal( + tape._ops[0], qml.CNOT([mid_measure_wire, max([mid_measure_wire] + tp_wires) + 1]) + ) assert len(tape.measurements) == 1 measurement = tape.measurements[0] @@ -128,37 +147,6 @@ def test_measure_with_tensor_obs(self, mid_measure_wire, tp_wires): assert isinstance(ob, qml.PauliZ) assert ob.wires == qml.wires.Wires(tp_wires[idx]) - def test_already_measured_error_operation(self): - """Test that attempting to apply an operation on a wires that has been - measured raises an error.""" - dev = qml.device("default.qubit", wires=3) - - def qfunc(): - qml.measure(1) - qml.PauliX(1) - return qml.expval(qml.PauliZ(0)) - - tape_deferred_func = qml.defer_measurements(qfunc) - qnode = qml.QNode(tape_deferred_func, dev) - - with pytest.raises(ValueError, match="wires have been measured already: {1}"): - qnode() - - def test_already_measured_error_terminal_measurement(self): - """Test that attempting to measure a wire at the end of the circuit - that has been measured in the middle of the circuit raises an error.""" - dev = qml.device("default.qubit", wires=3) - - def qfunc(): - qml.measure(1) - return qml.expval(qml.PauliZ(1)) - - tape_deferred_func = qml.defer_measurements(qfunc) - qnode = qml.QNode(tape_deferred_func, dev) - - with pytest.raises(ValueError, match="Cannot apply operations"): - qnode() - def test_cv_op_error(self): """Test that CV operations are not supported.""" dev = qml.device("default.gaussian", wires=3) @@ -216,15 +204,15 @@ def test_correct_ops_in_tape(self, terminal_measurement): tape = qml.tape.QuantumScript.from_queue(q) tape = qml.defer_measurements(tape) - assert len(tape.operations) == 2 + assert len(tape.operations) == 4 assert len(tape.measurements) == 1 # Check the two underlying Controlled instances - first_ctrl_op = tape.operations[0] + first_ctrl_op = tape.operations[1] assert isinstance(first_ctrl_op, qml.ops.op_math.Controlled) assert qml.equal(first_ctrl_op.base, qml.RY(first_par, 1)) - sec_ctrl_op = tape.operations[1] + sec_ctrl_op = tape.operations[3] assert isinstance(sec_ctrl_op, qml.ops.op_math.Controlled) assert qml.equal(sec_ctrl_op.base, qml.RZ(sec_par, 1)) @@ -245,15 +233,15 @@ def test_correct_ops_in_tape_inversion(self): tape = qml.defer_measurements(tape) # Conditioned on 0 as the control value, PauliX is applied before and after - assert len(tape.operations) == 1 + assert len(tape.operations) == 2 assert len(tape.measurements) == 1 # Check the two underlying Controlled instance - ctrl_op = tape.operations[0] + ctrl_op = tape.operations[1] assert isinstance(ctrl_op, qml.ops.op_math.Controlled) assert qml.equal(ctrl_op.base, qml.RY(first_par, 1)) - assert ctrl_op.wires == qml.wires.Wires([0, 1]) + assert ctrl_op.wires == qml.wires.Wires([2, 1]) def test_correct_ops_in_tape_assert_zero_state(self): """Test that the underlying tape contains the correct operations if a @@ -271,11 +259,11 @@ def test_correct_ops_in_tape_assert_zero_state(self): tape = qml.defer_measurements(tape) # Conditioned on 0 as the control value, PauliX is applied before and after - assert len(tape.operations) == 1 + assert len(tape.operations) == 2 assert len(tape.measurements) == 1 # Check the underlying Controlled instance - ctrl_op = tape.operations[0] + ctrl_op = tape.operations[1] assert isinstance(ctrl_op, qml.ops.op_math.Controlled) assert qml.equal(ctrl_op.base, qml.RY(first_par, 1)) @@ -312,7 +300,9 @@ def test_quantum_teleportation(self, rads): tape = qml.tape.QuantumScript.from_queue(q) tape = qml.defer_measurements(tape) - assert len(tape.operations) == 5 + 2 # 5 regular ops + 2 conditional ops + assert ( + len(tape.operations) == 5 + 2 + 2 + ) # 5 regular ops + 2 measurement ops + 2 conditional ops assert len(tape.measurements) == 1 # Check the each operation @@ -337,15 +327,25 @@ def test_quantum_teleportation(self, rads): assert isinstance(op5, qml.Hadamard) assert op5.wires == qml.wires.Wires([0]) - # Check the two underlying Controlled instances - ctrl_op1 = tape.operations[5] + # Check the two underlying CNOTs for storing measurement state + meas_op1 = tape.operations[5] + assert isinstance(meas_op1, qml.CNOT) + assert meas_op1.wires == qml.wires.Wires([0, 3]) + + meas_op2 = tape.operations[6] + assert isinstance(meas_op2, qml.CNOT) + assert meas_op2.wires == qml.wires.Wires([1, 4]) + + # Check the two underlying Controlled instances + ctrl_op1 = tape.operations[7] assert isinstance(ctrl_op1, qml.ops.op_math.Controlled) assert qml.equal(ctrl_op1.base, qml.RX(math.pi, 2)) + assert ctrl_op1.wires == qml.wires.Wires([4, 2]) - ctrl_op2 = tape.operations[6] + ctrl_op2 = tape.operations[8] assert isinstance(ctrl_op2, qml.ops.op_math.Controlled) assert qml.equal(ctrl_op2.base, qml.RZ(math.pi, 2)) - assert ctrl_op2.wires == qml.wires.Wires([0, 2]) + assert ctrl_op2.wires == qml.wires.Wires([3, 2]) # Check the measurement assert tape.measurements[0] == terminal_measurement @@ -395,11 +395,16 @@ def test_hermitian_queued(self): tape = qml.tape.QuantumScript.from_queue(q) tape = qml.defer_measurements(tape) - assert len(tape.operations) == 1 + assert len(tape.operations) == 2 assert len(tape.measurements) == 1 + # Check the underlying CNOT for storing measurement state + meas_op1 = tape.operations[0] + assert isinstance(meas_op1, qml.CNOT) + assert meas_op1.wires == qml.wires.Wires([0, 5]) + # Check the underlying Controlled instances - first_ctrl_op = tape.operations[0] + first_ctrl_op = tape.operations[1] assert isinstance(first_ctrl_op, qml.ops.op_math.Controlled) assert qml.equal(first_ctrl_op.base, qml.RY(rads, 4)) @@ -426,16 +431,21 @@ def test_hamiltonian_queued(self): tape = qml.tape.QuantumScript.from_queue(q) tape = qml.defer_measurements(tape) - assert len(tape.operations) == 1 + assert len(tape.operations) == 2 assert len(tape.measurements) == 1 + # Check the underlying CNOT for storing measurement state + meas_op1 = tape.operations[0] + assert isinstance(meas_op1, qml.CNOT) + assert meas_op1.wires == qml.wires.Wires([0, 5]) + # Check the underlying Controlled instance - first_ctrl_op = tape.operations[0] + first_ctrl_op = tape.operations[1] assert isinstance(first_ctrl_op, qml.ops.op_math.Controlled) assert qml.equal(first_ctrl_op.base, qml.RY(rads, 4)) assert len(tape.measurements) == 1 assert isinstance(tape.measurements[0], qml.measurements.MeasurementProcess) - assert tape.measurements[0].obs == H + assert qml.equal(tape.measurements[0].obs, H) @pytest.mark.parametrize("device", ["default.qubit", "default.mixed", "lightning.qubit"]) @pytest.mark.parametrize("ops", [(qml.RX, qml.CRX), (qml.RY, qml.CRY), (qml.RZ, qml.CRZ)]) @@ -470,7 +480,7 @@ def quantum_control_circuit(rads): @pytest.mark.parametrize("device", ["default.qubit", "default.mixed", "lightning.qubit"]) def test_conditional_rotations_with_else(self, device): """Test that an else operation can also defined using qml.cond.""" - dev = qml.device(device, wires=2) + dev = qml.device(device, wires=3) r = 2.345 op1, controlled_op1 = qml.RY, qml.CRY @@ -504,7 +514,7 @@ def test_keyword_syntax(self): keyword syntax works.""" op = qml.RY - dev = qml.device("default.qubit", wires=2) + dev = qml.device("default.qubit", wires=3) @qml.qnode(dev) def qnode1(par): @@ -528,7 +538,7 @@ def qnode2(par): def test_condition_using_measurement_outcome(self, control_val, expected): """Apply a conditional bitflip by selecting the measurement outcome.""" - dev = qml.device("default.qubit", wires=2) + dev = qml.device("default.qubit", wires=3) @qml.qnode(dev) def qnode(): @@ -541,7 +551,7 @@ def qnode(): @pytest.mark.parametrize("device", ["default.qubit", "default.mixed", "lightning.qubit"]) def test_cond_qfunc(self, device): """Test that a qfunc can also used with qml.cond.""" - dev = qml.device(device, wires=2) + dev = qml.device(device, wires=3) r = 2.324 @@ -576,7 +586,7 @@ def quantum_control_circuit(r): def test_cond_qfunc_with_else(self, device): """Test that a qfunc can also used with qml.cond even when an else qfunc is provided.""" - dev = qml.device(device, wires=2) + dev = qml.device(device, wires=3) x = 0.3 y = 3.123 @@ -610,7 +620,25 @@ def cond_qnode(x, y): return qml.probs(wires=[0]) assert np.allclose(normal_circuit(x, y), cond_qnode(x, y)) - assert np.allclose(qml.matrix(normal_circuit)(x, y), qml.matrix(cond_qnode)(x, y)) + + def test_cond_on_measured_wire(self): + """Test that applying a conditional operation on the same wire + that is measured works as expected.""" + dev = qml.device("default.qubit", wires=2) + + @qml.qnode(dev) + @qml.defer_measurements + def qnode(): + qml.Hadamard(0) + m = qml.measure(0) + qml.cond(m, qml.PauliX)(0) + return qml.density_matrix(0) + + # Above circuit will cause wire 0 to go back to the |0> computational + # basis state. We can inspect the reduced density matrix to confirm this + # without inspecting the extra wires + expected_dmat = np.array([[1, 0], [0, 0]]) + assert np.allclose(qnode(), expected_dmat) class TestExpressionConditionals: @@ -621,7 +649,7 @@ class TestExpressionConditionals: def test_conditional_rotations(self, r, op): """Test that the quantum conditional operations match the output of controlled rotations. And additionally that summing measurements works as expected.""" - dev = qml.device("default.qubit", wires=3) + dev = qml.device("default.qubit", wires=5) @qml.qnode(dev) def normal_circuit(rads): @@ -648,7 +676,7 @@ def quantum_control_circuit(rads): @pytest.mark.parametrize("r", np.linspace(0.1, 2 * np.pi - 0.1, 4)) def test_triple_measurement_condition_expression(self, r): """Test that combining the results of three mid-circuit measurements works as expected.""" - dev = qml.device("default.qubit", wires=4) + dev = qml.device("default.qubit", wires=7) @qml.qnode(dev) @qml.defer_measurements @@ -690,7 +718,7 @@ def test_multiple_conditions(self): """Test that when multiple "branches" of the mid-circuit measurements all satisfy the criteria then this translates to multiple control gates. """ - dev = qml.device("default.qubit", wires=4) + dev = qml.device("default.qubit", wires=7) @qml.qnode(dev) @qml.defer_measurements @@ -733,7 +761,7 @@ def quantum_control_circuit(rads): def test_composed_conditions(self): """Test that a complex nested expression gets resolved correctly to the corresponding correct control gates.""" - dev = qml.device("default.qubit", wires=4) + dev = qml.device("default.qubit", wires=7) @qml.qnode(dev) @qml.defer_measurements @@ -788,7 +816,7 @@ def test_basis_state_prep(self): basis_state = [0, 1, 1, 0] - dev = qml.device("default.qubit", wires=5) + dev = qml.device("default.qubit", wires=6) @qml.qnode(dev) def qnode1(): @@ -804,15 +832,15 @@ def qnode2(): qml.cond(m_0, template)(basis_state, wires=range(1, 5)) return qml.expval(qml.PauliZ(1) @ qml.PauliZ(2) @ qml.PauliZ(3) @ qml.PauliZ(4)) - dev = qml.device("default.qubit", wires=2) - assert np.allclose(qnode1(), qnode2()) - assert len(qnode1.qtape.operations) == len(qnode2.qtape.operations) + assert len(qnode2.qtape.operations) == len(qnode1.qtape.operations) + 1 assert len(qnode1.qtape.measurements) == len(qnode2.qtape.measurements) # Check the operations - for op1, op2 in zip(qnode1.qtape.operations, qnode2.qtape.operations): + deferred_ops = qnode2.qtape.operations + assert qml.equal(deferred_ops.pop(1), qml.CNOT([0, 5])) + for op1, op2 in zip(qnode1.qtape.operations, deferred_ops): assert isinstance(op1, type(op2)) assert np.allclose(op1.data, op2.data) @@ -827,7 +855,7 @@ def test_angle_embedding(self): template = qml.AngleEmbedding feature_vector = [1, 2, 3] - dev = qml.device("default.qubit", wires=5) + dev = qml.device("default.qubit", wires=6) @qml.qnode(dev) def qnode1(): @@ -843,17 +871,18 @@ def qnode2(): qml.cond(m_0, template)(features=feature_vector, wires=range(1, 5), rotation="Z") return qml.expval(qml.PauliZ(1) @ qml.PauliZ(2) @ qml.PauliZ(3) @ qml.PauliZ(4)) - dev = qml.device("default.qubit", wires=2) res1 = qnode1() res2 = qnode2() assert np.allclose(res1, res2) - assert len(qnode1.qtape.operations) == len(qnode2.qtape.operations) + assert len(qnode2.qtape.operations) == len(qnode1.qtape.operations) + 1 assert len(qnode1.qtape.measurements) == len(qnode2.qtape.measurements) # Check the operations - for op1, op2 in zip(qnode1.qtape.operations, qnode2.qtape.operations): + deferred_ops = qnode2.qtape.operations + assert qml.equal(deferred_ops.pop(1), qml.CNOT([0, 5])) + for op1, op2 in zip(qnode1.qtape.operations, deferred_ops): assert isinstance(op1, type(op2)) assert np.allclose(op1.data, op2.data) @@ -865,7 +894,7 @@ def qnode2(): @pytest.mark.parametrize("template", [qml.StronglyEntanglingLayers, qml.BasicEntanglerLayers]) def test_layers(self, template): """Test layers conditioned on mid-circuit measurement outcomes.""" - dev = qml.device("default.qubit", wires=3) + dev = qml.device("default.qubit", wires=4) num_wires = 2 @@ -888,11 +917,13 @@ def qnode2(parameters): assert np.allclose(qnode1(weights), qnode2(weights)) - assert len(qnode1.qtape.operations) == len(qnode2.qtape.operations) + assert len(qnode2.qtape.operations) == len(qnode1.qtape.operations) + 1 assert len(qnode1.qtape.measurements) == len(qnode2.qtape.measurements) # Check the operations - for op1, op2 in zip(qnode1.qtape.operations, qnode2.qtape.operations): + deferred_ops = qnode2.qtape.operations + assert qml.equal(deferred_ops.pop(1), qml.CNOT([0, 3])) + for op1, op2 in zip(qnode1.qtape.operations, deferred_ops): assert isinstance(op1, type(op2)) assert np.allclose(op1.data, op2.data) @@ -902,6 +933,110 @@ def qnode2(parameters): assert np.allclose(op1.data, op2.data) +class TestQubitReset: + """Tests for the qubit reset functionality of `qml.measure`.""" + + def test_correct_cnot_for_reset(self): + """Test that a CNOT is applied from the wire that stores the measurement + to the measured wire after the measurement.""" + dev = qml.device("default.qubit", wires=3) + + @qml.qnode(dev) + def qnode1(x): + qml.Hadamard(0) + qml.CRX(x, [0, 1]) + return qml.expval(qml.PauliZ(1)) + + @qml.qnode(dev) + @qml.defer_measurements + def qnode2(x): + qml.Hadamard(0) + m0 = qml.measure(0, reset=True) + qml.cond(m0, qml.RX)(x, 1) + return qml.expval(qml.PauliZ(1)) + + assert np.allclose(qnode1(0.123), qnode2(0.123)) + + expected_circuit = [ + qml.Hadamard(0), + qml.CNOT([0, 2]), + qml.CNOT([2, 0]), + qml.ops.Controlled(qml.RX(0.123, 1), 2), + qml.expval(qml.PauliZ(1)), + ] + + assert len(qnode2.qtape.circuit) == len(expected_circuit) + assert all( + qml.equal(actual, expected) + for actual, expected in zip(qnode2.qtape.circuit, expected_circuit) + ) + + def test_wire_is_reset(self): + """Test that a wire is reset to the |0> state without any local phases + after measurement if reset is requested.""" + dev = qml.device("default.qubit", wires=3) + + @qml.qnode(dev) + @qml.defer_measurements + def qnode(x): + qml.Hadamard(0) + qml.PhaseShift(np.pi / 4, 0) + m = qml.measure(0, reset=True) + qml.cond(m, qml.RX)(x, 1) + return qml.density_matrix(wires=[0]) + + # Expected reduced density matrix on wire 0 + expected_mat = np.array([[1, 0], [0, 0]]) + assert np.allclose(qnode(0.123), expected_mat) + + def test_multiple_measurements_mixed_reset(self): + """Test that a QNode with multiple mid-circuit measurements with + different resets is transformed correctly.""" + dev = qml.device("default.qubit", wires=6) + + @qml.qnode(dev) + def qnode(p, x, y): + qml.Hadamard(0) + qml.PhaseShift(p, 0) + # Set measurement_ids so that the order of wires in combined + # measurement values is consistent + + mp0 = qml.measurements.MidMeasureMP(0, reset=True, id=0) + m0 = qml.measurements.MeasurementValue([mp0], lambda v: v) + qml.cond(~m0, qml.RX)(x, 1) + mp1 = qml.measurements.MidMeasureMP(1, reset=True, id=1) + m1 = qml.measurements.MeasurementValue([mp1], lambda v: v) + qml.cond(m0 & m1, qml.Hadamard)(0) + mp2 = qml.measurements.MidMeasureMP(0, id=2) + m2 = qml.measurements.MeasurementValue([mp2], lambda v: v) + qml.cond(m1 | m2, qml.RY)(y, 2) + return qml.expval(qml.PauliZ(2)) + + _ = qnode(0.123, 0.456, 0.789) + + expected_circuit = [ + qml.Hadamard(0), + qml.PhaseShift(0.123, 0), + qml.CNOT([0, 3]), + qml.CNOT([3, 0]), + qml.ops.Controlled(qml.RX(0.456, 1), 3, [False]), + qml.CNOT([1, 4]), + qml.CNOT([4, 1]), + qml.ops.Controlled(qml.Hadamard(0), [3, 4]), + qml.CNOT([0, 5]), + qml.ops.Controlled(qml.RY(0.789, 2), [4, 5], [False, True]), + qml.ops.Controlled(qml.RY(0.789, 2), [4, 5], [True, False]), + qml.ops.Controlled(qml.RY(0.789, 2), [4, 5], [True, True]), + qml.expval(qml.PauliZ(2)), + ] + + assert len(qnode.qtape.circuit) == len(expected_circuit) + assert all( + qml.equal(actual, expected) + for actual, expected in zip(qnode.qtape.circuit, expected_circuit) + ) + + class TestDrawing: """Tests drawing circuits with mid-circuit measurements and conditional operations that have been transformed""" @@ -910,6 +1045,8 @@ def test_drawing(self): """Test that drawing a func with mid-circuit measurements works and that controlled operations are drawn for conditional operations.""" + # TODO: Update after drawing for mid-circuit measurements is updated. + def qfunc(): m_0 = qml.measure(0) qml.cond(m_0, qml.RY)(0.312, wires=1) @@ -924,8 +1061,10 @@ def qfunc(): transformed_qnode = qml.QNode(transformed_qfunc, dev) expected = ( - "0: ─╭●──────────────────┤ \n" - "1: ─╰RY(0.31)─╭RY(0.31)─┤ \n" - "2: ───────────╰●────────┤ " + "0: ─╭●────────────────────────┤ \n" + "1: ─│──╭RY(0.31)────╭RY(0.31)─┤ \n" + "2: ─│──│─────────╭●─│─────────┤ \n" + "3: ─╰X─╰●────────│──│─────────┤ \n" + "4: ──────────────╰X─╰●────────┤ " ) assert qml.draw(transformed_qnode)() == expected