diff --git a/doc/introduction/measurements.rst b/doc/introduction/measurements.rst index d8966495903..9353c8a0c7e 100644 --- a/doc/introduction/measurements.rst +++ b/doc/introduction/measurements.rst @@ -260,6 +260,9 @@ outcome of such mid-circuit measurements: qml.cond(m_0, qml.RY)(y, wires=0) return qml.probs(wires=[0]) +Deferred measurements +********************* + A quantum function with mid-circuit measurements (defined using :func:`~.pennylane.measure`) and conditional operations (defined using :func:`~.pennylane.cond`) can be executed by applying the `deferred measurement @@ -269,8 +272,12 @@ measurement on qubit 1 yielded ``1`` as an outcome, otherwise doing nothing for the ``0`` measurement outcome. PennyLane implements the deferred measurement principle to transform -conditional operations with the :func:`~.defer_measurements` quantum -function transform. +conditional operations with the :func:`~.pennylane.defer_measurements` quantum +function transform. The deferred measurement principle provides a natural method +to simulate the application of mid-circuit measurements and conditional operations +in a differentiable and device-independent way. Performing true mid-circuit +measurements and conditional operations is dependent on the quantum hardware and +PennyLane device capabilities. .. code-block:: python @@ -290,7 +297,35 @@ The decorator syntax applies equally well: def qnode(x, y): (...) -Note that we can also specify an outcome when defining a conditional operation: +Resetting wires +*************** + +Wires can be reused as normal after making mid-circuit measurements. Moreover, a measured wire can also be +reset to the :math:`|0 \rangle` state by setting the ``reset`` keyword argument of :func:`~.pennylane.measure` +to ``True``. + +.. code-block:: python3 + + dev = qml.device("default.qubit", wires=3) + + @qml.qnode(dev) + def func(): + qml.PauliX(1) + m_0 = qml.measure(1, reset=True) + qml.PauliX(1) + return qml.probs(wires=[1]) + +Executing this QNode: + +>>> func() +tensor([0., 1.], requires_grad=True) + +Conditional operators +********************* + +Users can create conditional operators controlled on mid-circuit measurements using +:func:`~.pennylane.cond`. We can also specify an outcome when defining a conditional +operation: .. code-block:: python @@ -309,30 +344,50 @@ Note that we can also specify an outcome when defining a conditional operation: >>> qnode_conditional_op_on_zero(*pars) tensor([0.88660045, 0.11339955], requires_grad=True) -Wires can be reused as normal after making mid-circuit measurements. Moreover, a measured wire can also be -reset to the :math:`|0 \rangle` state by setting the ``reset`` keyword argument of :func:`~.pennylane.measure` -to ``True``. +For more examples on applying quantum functions conditionally, refer to the +:func:`~.pennylane.cond` documentation. + +Postselecting mid-circuit measurements +************************************** + +PennyLane also supports postselecting on mid-circuit measurement outcomes by specifying the ``postselect`` +keyword argument of :func:`~.pennylane.measure`. Postselection discards outcomes that do not meet the +criteria provided by the ``postselect`` argument. For example, specifying ``postselect=1`` on wire 0 would +be equivalent to projecting the state vector onto the :math:`|1\rangle` state on wire 0: .. code-block:: python3 - dev = qml.device("default.qubit", wires=3) + dev = qml.device("default.qubit") @qml.qnode(dev) - def func(): - qml.PauliX(1) - m_0 = qml.measure(1, reset=True) - qml.PauliX(1) - return qml.probs(wires=[1]) + def func(x): + qml.RX(x, wires=0) + m0 = qml.measure(0, postselect=1) + qml.cond(m0, qml.PauliX)(wires=1) + return qml.sample(wires=1) -Executing this QNode: +By postselecting on ``1``, we only consider the ``1`` measurement outcome on wire 0. So, the probability of +measuring ``1`` on wire 1 after postselection should also be 1. Executing this QNode with 10 shots: ->>> func() -tensor([0., 1.], requires_grad=True) +>>> func(np.pi / 2, shots=10) +array([1, 1, 1, 1, 1, 1, 1]) + +Note that only 7 samples are returned. This is because samples that do not meet the postselection criteria are +discarded. + +.. note:: -Statistics can also be collected on mid-circuit measurements along with terminal measurement statistics. + Currently, postselection support is only available on :class:`~.pennylane.devices.DefaultQubit`. Using + postselection on other devices will raise an error. + +Mid-circuit measurement statistics +********************************** + +Statistics can be collected on mid-circuit measurements along with terminal measurement statistics. Currently, ``qml.probs``, ``qml.sample``, ``qml.expval``, ``qml.var``, and ``qml.counts`` are supported, and can be requested along with other measurements. The devices that currently support collecting such -statistics are ``"default.qubit"``, ``"default.mixed"``, and ``"default.qubit.legacy"``. +statistics are :class:`~.pennylane.devices.DefaultQubit`, :class:`~.pennylane.devices.DefaultMixed`, and +:class:`~.pennylane.devices.DefaultQubitLegacy`. .. code-block:: python3 @@ -351,19 +406,11 @@ Executing this QNode: (tensor([0.9267767, 0.0732233], requires_grad=True), tensor([0.5, 0.5], requires_grad=True)) -Currently, statistics can only be collected for single mid-circuit measurement values. Moreover, any -measurement values manipulated using boolean or arithmetic operators cannot be used. These can lead to -unexpected/incorrect behaviour. - -The deferred measurement principle provides a natural method to simulate the -application of mid-circuit measurements and conditional operations in a -differentiable and device-independent way. Performing true mid-circuit -measurements and conditional operations is dependent on the -quantum hardware and PennyLane device capabilities. - -For more examples on applying quantum functions conditionally, refer to the -:func:`~.pennylane.cond` transform. +.. warning:: + Currently, statistics can only be collected for single mid-circuit measurement values. Moreover, any + measurement values manipulated using boolean or arithmetic operators cannot be used. These can lead to + unexpected/incorrect behaviour. Changing the number of shots ---------------------------- diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index cd0d2c9a3c3..812e35faafd 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -117,6 +117,30 @@ (array(0.6), array([1, 1, 1, 0, 1])) ``` +* Users can now request postselection after making mid-circuit measurements. They can do so + by specifying the `postselect` keyword argument for `qml.measure` as either `0` or `1`, + corresponding to the basis states. + [(#4604)](https://github.com/PennyLaneAI/pennylane/pull/4604) + + ```python + dev = qml.device("default.qubit", wires=3) + + @qml.qnode(dev) + def circuit(phi): + qml.RX(phi, wires=0) + m = qml.measure(0, postselect=1) + qml.cond(m, qml.PauliX)(wires=1) + return qml.probs(wires=1) + ``` + ```pycon + >>> circuit(np.pi) + tensor([0., 1.], requires_grad=True) + ``` + + Here, we measure a probability of one on wire 1 as we postselect on the $|1\rangle$ state on wire + 0, thus resulting in the circuit being projected onto the state corresponding to the measurement + outcome $|1\rangle$ on wire 0. + * Operator transforms `qml.matrix`, `qml.eigvals`, `qml.generator`, and `qml.transforms.to_zx` are updated to the new transform program system. [(#4573)](https://github.com/PennyLaneAI/pennylane/pull/4573) diff --git a/pennylane/_device.py b/pennylane/_device.py index 7658b06b49f..724ae14f8c1 100644 --- a/pennylane/_device.py +++ b/pennylane/_device.py @@ -977,6 +977,9 @@ def check_validity(self, queue, observables): "simulate the application of mid-circuit measurements on this device." ) + if isinstance(o, qml.Projector): + raise ValueError(f"Postselection is not supported on the {self.name} device.") + if not self.stopping_condition(o): raise DeviceError( f"Gate {operation_name} not supported on device {self.short_name}" diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 485598ceb6d..52bcd90470b 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -387,7 +387,7 @@ def preprocess( config = self._setup_execution_config(execution_config) transform_program = TransformProgram() - transform_program.add_transform(qml.defer_measurements) + transform_program.add_transform(qml.defer_measurements, device=self) transform_program.add_transform(validate_device_wires, self.wires, name=self.name) transform_program.add_transform( decompose, stopping_condition=stopping_condition, name=self.name diff --git a/pennylane/devices/qubit/sampling.py b/pennylane/devices/qubit/sampling.py index cbcb51fc2d6..856d0cf592b 100644 --- a/pennylane/devices/qubit/sampling.py +++ b/pennylane/devices/qubit/sampling.py @@ -179,6 +179,7 @@ def measure_with_samples( Returns: List[TensorLike[Any]]: Sample measurement results """ + groups, indices = _group_measurements(mps) all_res = [] @@ -264,27 +265,37 @@ def _process_single_shot(samples): # currently we call sample_state for each shot entry, but it may be # better to call sample_state just once with total_shots, then use # the shot_range keyword argument - samples = sample_state( - state, - shots=s, - is_state_batched=is_state_batched, - wires=wires, - rng=rng, - prng_key=prng_key, - ) + try: + samples = sample_state( + state, + shots=s, + is_state_batched=is_state_batched, + wires=wires, + rng=rng, + prng_key=prng_key, + ) + except ValueError as e: + if str(e) != "probabilities contain NaN": + raise e + samples = qml.math.full((s, len(wires)), 0) processed_samples.append(_process_single_shot(samples)) return tuple(zip(*processed_samples)) - samples = sample_state( - state, - shots=shots.total_shots, - is_state_batched=is_state_batched, - wires=wires, - rng=rng, - prng_key=prng_key, - ) + try: + samples = sample_state( + state, + shots=shots.total_shots, + is_state_batched=is_state_batched, + wires=wires, + rng=rng, + prng_key=prng_key, + ) + except ValueError as e: + if str(e) != "probabilities contain NaN": + raise e + samples = qml.math.full((shots.total_shots, len(wires)), 0) return _process_single_shot(samples) @@ -352,7 +363,7 @@ def _sum_for_single_shot(s): ) return sum(c * res for c, res in zip(mp.obs.terms()[0], results)) - unsqueezed_results = tuple(_sum_for_single_shot(Shots(s)) for s in shots) + unsqueezed_results = tuple(_sum_for_single_shot(type(shots)(s)) for s in shots) return [unsqueezed_results] if shots.has_partitioned_shots else [unsqueezed_results[0]] @@ -380,7 +391,7 @@ def _sum_for_single_shot(s): ) return sum(results) - unsqueezed_results = tuple(_sum_for_single_shot(Shots(s)) for s in shots) + unsqueezed_results = tuple(_sum_for_single_shot(type(shots)(s)) for s in shots) return [unsqueezed_results] if shots.has_partitioned_shots else [unsqueezed_results[0]] diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index c3fc63afa38..2716712b863 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -14,6 +14,7 @@ """Simulate a quantum script.""" # pylint: disable=protected-access from numpy.random import default_rng +import numpy as np import pennylane as qml from pennylane.typing import Result @@ -46,6 +47,20 @@ } +class _FlexShots(qml.measurements.Shots): + """Shots class that allows zero shots.""" + + # pylint: disable=super-init-not-called + def __init__(self, shots=None): + if isinstance(shots, int): + self.total_shots = shots + self.shot_vector = (qml.measurements.ShotCopies(shots, 1),) + else: + self.__all_tuple_init__([s if isinstance(s, tuple) else (s, 1) for s in shots]) + + self._frozen = True + + def expand_state_over_wires(state, state_wires, all_wires, is_state_batched): """ Expand and re-order a state given some initial and target wire orders, setting @@ -83,6 +98,39 @@ def expand_state_over_wires(state, state_wires, all_wires, is_state_batched): return qml.math.transpose(state, desired_axes) +def _postselection_postprocess(state, is_state_batched, shots): + """Update state after projector is applied.""" + if is_state_batched: + raise ValueError( + "Cannot postselect on circuits with broadcasting. Use the " + "qml.transforms.broadcast_expand transform to split a broadcasted " + "tape into multiple non-broadcasted tapes before executing if " + "postselection is used." + ) + + # The floor function is being used here so that a norm very close to zero becomes exactly + # equal to zero so that the state can become invalid. This way, execution can continue, and + # bad postselection gives results that are invalid rather than results that look valid but + # are incorrect. + norm = qml.math.floor(qml.math.real(qml.math.norm(state)) * 1e15) * 1e-15 + + if shots: + # Clip the number of shots using a binomial distribution using the probability of + # measuring the postselected state. + postselected_shots = ( + [np.random.binomial(s, float(norm)) for s in shots] + if not qml.math.is_abstract(norm) + else shots + ) + + # _FlexShots is used here since the binomial distribution could result in zero + # valid samples + shots = _FlexShots(postselected_shots) + + state = state / qml.math.cast_like(norm, state) + return state, shots + + def get_final_state(circuit, debugger=None, interface=None): """ Get the final state that results from executing the given quantum script. @@ -112,6 +160,12 @@ def get_final_state(circuit, debugger=None, interface=None): for op in circuit.operations[bool(prep) :]: state = apply_operation(op, state, is_state_batched=is_state_batched, debugger=debugger) + # Handle postselection on mid-circuit measurements + if isinstance(op, qml.Projector): + state, circuit._shots = _postselection_postprocess( + state, is_state_batched, circuit.shots + ) + # new state is batched if i) the old state is batched, or ii) the new op adds a batch dim is_state_batched = is_state_batched or op.batch_size is not None @@ -147,6 +201,7 @@ def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=Non Returns: Tuple[TensorLike]: The measurement results """ + circuit = circuit.map_to_standard_wires() if not circuit.shots: diff --git a/pennylane/drawer/draw.py b/pennylane/drawer/draw.py index b077212897a..08379a72f98 100644 --- a/pennylane/drawer/draw.py +++ b/pennylane/drawer/draw.py @@ -267,7 +267,15 @@ def wrapper(*args, **kwargs): tapes = qnode.construct(args, kwargs) if isinstance(qnode.device, qml.devices.Device): program = qnode.transform_program - tapes = program([qnode.tape]) + if any( + isinstance(op, qml.measurements.MidMeasureMP) + for op in qnode.tape.operations + ): + tapes, _ = qml.defer_measurements(qnode.tape, device=qnode.device) + else: + tapes = [qnode.tape] + + tapes = program(tapes) finally: qnode.expansion_strategy = original_expansion_strategy diff --git a/pennylane/math/single_dispatch.py b/pennylane/math/single_dispatch.py index 96a24711a18..b2dd2ea3660 100644 --- a/pennylane/math/single_dispatch.py +++ b/pennylane/math/single_dispatch.py @@ -274,6 +274,22 @@ def _take_autograd(tensor, indices, axis=None): ar.register_function("tensorflow", "flatten", lambda x: _i("tf").reshape(x, [-1])) ar.register_function("tensorflow", "shape", lambda x: tuple(x.shape)) +ar.register_function( + "tensorflow", + "full", + lambda shape, fill_value, **kwargs: _i("tf").fill( + shape if isinstance(shape, (tuple, list)) else (shape), fill_value, **kwargs + ), +) +ar.register_function( + "tensorflow", + "isnan", + lambda tensor, **kwargs: _i("tf").math.is_nan(_i("tf").math.real(tensor), **kwargs) + | _i("tf").math.is_nan(_i("tf").math.imag(tensor), **kwargs), +) +ar.register_function( + "tensorflow", "any", lambda tensor, **kwargs: _i("tf").reduce_any(tensor, **kwargs) +) ar.register_function( "tensorflow", "sqrt", diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py index dc1cb2e4fe1..b36bdb1c285 100644 --- a/pennylane/measurements/mid_measure.py +++ b/pennylane/measurements/mid_measure.py @@ -24,7 +24,7 @@ from .measurements import MeasurementProcess, MidMeasure -def measure(wires: Wires, reset: Optional[bool] = False): +def measure(wires: Wires, reset: Optional[bool] = False, postselect: Optional[int] = None): r"""Perform a mid-circuit measurement in the computational basis on the supplied qubit. @@ -86,12 +86,93 @@ def func(): wires (Wires): The wire of the qubit the measurement process applies to. reset (Optional[bool]): Whether to reset the wire to the :math:`|0 \rangle` state after measurement. + postselect (Optional[int]): Which basis state to postselect after a mid-circuit + measurement. None by default. If postselection is requested, only the post-measurement + state that is used for postselection will be considered in the remaining circuit. Returns: MidMeasureMP: measurement process instance Raises: QuantumFunctionError: if multiple wires were specified + + .. details:: + :title: Postselection + + Postselection discards outcomes that do not meet the criteria provided by the ``postselect`` + argument. For example, specifying ``postselect=1`` on wire 0 would be equivalent to projecting + the state vector onto the :math:`|1\rangle` state on wire 0: + + .. code-block:: python3 + + dev = qml.device("default.qubit") + + @qml.qnode(dev) + def func(x): + qml.RX(x, wires=0) + m0 = qml.measure(0, postselect=1) + qml.cond(m0, qml.PauliX)(wires=1) + return qml.sample(wires=1) + + By postselecting on ``1``, we only consider the ``1`` measurement outcome on wire 0. So, the probability of + measuring ``1`` on wire 1 after postselection should also be 1. Executing this QNode with 10 shots: + + >>> func(np.pi / 2, shots=10) + array([1, 1, 1, 1, 1, 1, 1]) + + Note that only 7 samples are returned. This is because samples that do not meet the postselection criteria are + thrown away. + + If postselection is requested on a state with zero probability of being measured, the result may contain ``NaN`` + or ``Inf`` values: + + .. code-block:: python3 + + dev = qml.device("default.qubit") + + @qml.qnode(dev) + def func(x): + qml.RX(x, wires=0) + m0 = qml.measure(0, postselect=1) + qml.cond(m0, qml.PauliX)(wires=1) + return qml.probs(wires=1) + + >>> func(0.0) + tensor([nan, nan], requires_grad=True) + + In the case of ``qml.sample``, an empty array will be returned: + + .. code-block:: python3 + + dev = qml.device("default.qubit") + + @qml.qnode(dev) + def func(x): + qml.RX(x, wires=0) + m0 = qml.measure(0, postselect=1) + qml.cond(m0, qml.PauliX)(wires=1) + return qml.sample() + + >>> func(0.0, shots=[10, 10]) + (array([], dtype=float64), array([], dtype=float64)) + + .. note:: + + Currently, postselection support is only available on ``"default.qubit"``. Using postselection + on other devices will raise an error. + + .. warning:: + + All measurements are supported when using postselection. However, postselection on a zero probability + state can cause some measurements to break. + + With finite shots, one must be careful when measuring ``qml.probs`` or ``qml.counts``, as these + measurements will raise errors if there are no valid samples after postselection. This will occur + with postselection states that have zero or close to zero probability. + + With analytic execution, ``qml.mutual_info`` will raise errors when using any interfaces except + ``jax``, and ``qml.vn_entropy`` will raise an error with the ``tensorflow`` interface when the + postselection state has zero probability. """ wire = Wires(wires) @@ -102,7 +183,7 @@ def func(): # Create a UUID and a map between MP and MV to support serialization measurement_id = str(uuid.uuid4())[:8] - mp = MidMeasureMP(wires=wire, reset=reset, id=measurement_id) + mp = MidMeasureMP(wires=wire, reset=reset, postselect=postselect, id=measurement_id) return MeasurementValue([mp], processing_fn=lambda v: v) @@ -121,6 +202,9 @@ class MidMeasureMP(MeasurementProcess): wires (.Wires): The wires the measurement process applies to. This can only be specified if an observable was not provided. reset (bool): Whether to reset the wire after measurement. + postselect (Optional[int]): Which basis state to postselect after a mid-circuit + measurement. None by default. If postselection is requested, only the post-measurement + state that is used for postselection will be considered in the remaining circuit. id (str): Custom label given to a measurement instance. """ @@ -129,10 +213,15 @@ def _flatten(self): return (None, None), metadata def __init__( - self, wires: Optional[Wires] = None, reset: Optional[bool] = False, id: Optional[str] = None + self, + wires: Optional[Wires] = None, + reset: Optional[bool] = False, + postselect: Optional[int] = None, + id: Optional[str] = None, ): super().__init__(wires=Wires(wires), id=id) self.reset = reset + self.postselect = postselect @property def return_type(self): diff --git a/pennylane/ops/qubit/observables.py b/pennylane/ops/qubit/observables.py index 823e2250755..2dbc515d6e1 100644 --- a/pennylane/ops/qubit/observables.py +++ b/pennylane/ops/qubit/observables.py @@ -23,7 +23,7 @@ from scipy.sparse import csr_matrix import pennylane as qml -from pennylane.operation import AnyWires, Observable +from pennylane.operation import AnyWires, Observable, Operation from pennylane.wires import Wires from .matrix_ops import QubitUnitary @@ -412,7 +412,7 @@ def pow(self, z): return [copy(self)] if (isinstance(z, int) and z > 0) else super().pow(z) -class BasisStateProjector(Projector): +class BasisStateProjector(Projector, Operation): r"""Observable corresponding to the state projector :math:`P=\ket{\phi}\bra{\phi}`, where :math:`\phi` denotes a basis state.""" diff --git a/pennylane/qnode.py b/pennylane/qnode.py index e77d2ae6b26..90555e3d90b 100644 --- a/pennylane/qnode.py +++ b/pennylane/qnode.py @@ -898,13 +898,17 @@ def construct(self, args, kwargs): # pylint: disable=too-many-branches ) # Apply the deferred measurement principle if the device doesn't - # support mid-circuit measurements natively - expand_mid_measure = any(isinstance(op, MidMeasureMP) for op in self.tape.operations) and ( - isinstance(self.device, qml.devices.Device) - or not self.device.capabilities().get("supports_mid_measure", False) + # support mid-circuit measurements natively. + # Only apply transform with old device API as postselection with + # broadcasting will split tapes. + expand_mid_measure = ( + any(isinstance(op, MidMeasureMP) for op in self.tape.operations) + and not isinstance(self.device, qml.devices.Device) + and not self.device.capabilities().get("supports_mid_measure", False) ) if expand_mid_measure: - tapes, _ = qml.defer_measurements(self._tape) + # Assume that tapes are not split if old device is used since postselection is not supported. + tapes, _ = qml.defer_measurements(self._tape, device=self.device) self._tape = tapes[0] if self.expansion_strategy == "device": diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py index 668ce74c03a..1acf9ec29ec 100644 --- a/pennylane/transforms/defer_measurements.py +++ b/pennylane/transforms/defer_measurements.py @@ -23,7 +23,7 @@ from pennylane.wires import Wires from pennylane.queuing import QueuingManager -# pylint: disable=too-many-branches +# pylint: disable=too-many-branches, too-many-statements def null_postprocessing(results): @@ -34,7 +34,7 @@ def null_postprocessing(results): @transform -def defer_measurements(tape: QuantumTape) -> (Sequence[QuantumTape], Callable): +def defer_measurements(tape: QuantumTape, **kwargs) -> (Sequence[QuantumTape], Callable): """Quantum function transform that substitutes operations conditioned on measurement outcomes to controlled operations. @@ -150,16 +150,21 @@ def func(x, y): if ops_cv or obs_cv: raise ValueError("Continuous variable operations and observables are not supported.") + device = kwargs.get("device", None) + new_operations = [] # Find wires that are reused after measurement measured_wires = [] reused_measurement_wires = set() repeated_measurement_wire = False + is_postselecting = False for op in tape.operations: if isinstance(op, MidMeasureMP): - if op.reset is True: + if op.postselect is not None: + is_postselecting = True + if op.reset: reused_measurement_wires.add(op.wires[0]) if op.wires[0] in measured_wires: @@ -171,6 +176,9 @@ def func(x, y): set(measured_wires).intersection(op.wires.toset()) ) + if is_postselecting and device is not None and not isinstance(device, qml.devices.DefaultQubit): + raise ValueError(f"Postselection is not supported on the {device} device.") + # Apply controlled operations to store measurement outcomes and replace # classically controlled operations control_wires = {} @@ -182,6 +190,10 @@ def func(x, y): if isinstance(op, MidMeasureMP): _ = measured_wires.pop(0) + if op.postselect is not None: + with QueuingManager.stop_recording(): + new_operations.append(qml.Projector([op.postselect], wires=op.wires[0])) + # Store measurement outcome in new wire if wire gets reused if op.wires[0] in reused_measurement_wires or op.wires[0] in measured_wires: control_wires[op.id] = cur_wire @@ -191,7 +203,13 @@ def func(x, y): if op.reset: with QueuingManager.stop_recording(): - new_operations.append(qml.CNOT([cur_wire, op.wires[0]])) + # No need to manually reset if postselecting on |0> + if op.postselect is None: + new_operations.append(qml.CNOT([cur_wire, op.wires[0]])) + elif op.postselect == 1: + # We know that the measured wire will be in the |1> state if postselected + # |1>. So we can just apply a PauliX instead of a CNOT to reset + new_operations.append(qml.PauliX(op.wires[0])) cur_wire += 1 else: @@ -207,15 +225,33 @@ def func(x, y): for mp in tape.measurements: if mp.mv is not None: + # Update measurement value wires wire_map = {m.wires[0]: control_wires[m.id] for m in mp.mv.measurements} mp = qml.map_wires(mp, wire_map=wire_map) new_measurements.append(mp) new_tape = type(tape)(new_operations, new_measurements, shots=tape.shots) + if is_postselecting and new_tape.batch_size is not None: + # Split tapes if broadcasting with postselection + return qml.transforms.broadcast_expand(new_tape) + return [new_tape], null_postprocessing +@defer_measurements.custom_qnode_transform +def _defer_measurements_qnode(self, qnode, targs, tkwargs): + """Custom qnode transform for ``defer_measurements``.""" + if tkwargs.get("device", None): + raise ValueError( + "Cannot provide a 'device' value directly to the defer_measurements decorator " + "when transforming a QNode." + ) + + tkwargs.setdefault("device", qnode.device) + return self.default_qnode_transform(qnode, targs, tkwargs) + + def _add_control_gate(op, control_wires): """Helper function to add control gates""" control = [control_wires[m.id] for m in op.meas_val.measurements] diff --git a/tests/devices/default_qubit/test_default_qubit.py b/tests/devices/default_qubit/test_default_qubit.py index 1867834bf6d..c56966d4942 100644 --- a/tests/devices/default_qubit/test_default_qubit.py +++ b/tests/devices/default_qubit/test_default_qubit.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for default qubit.""" -# pylint: disable=import-outside-toplevel, no-member +# pylint: disable=import-outside-toplevel, no-member, too-many-arguments +from unittest import mock import pytest import numpy as np @@ -1637,6 +1638,268 @@ def test_projector_dynamic_type(max_workers, n_wires): assert np.isclose(res, 1 / 2**n_wires) +@pytest.mark.all_interfaces +@pytest.mark.parametrize("interface", ["numpy", "autograd", "torch", "jax", "tensorflow"]) +@pytest.mark.parametrize("use_jit", [True, False]) +class TestPostselection: + """Various integration tests for postselection of mid-circuit measurements.""" + + @pytest.mark.parametrize( + "mp", + [ + qml.expval(qml.PauliZ(0)), + qml.var(qml.PauliZ(0)), + qml.probs(wires=[0, 1]), + qml.state(), + qml.density_matrix(wires=0), + qml.purity(0), + qml.vn_entropy(0), + qml.mutual_info(0, 1), + ], + ) + @pytest.mark.parametrize("param", np.linspace(np.pi / 4, 3 * np.pi / 4, 3)) + def test_postselection_valid_analytic(self, param, mp, interface, use_jit): + """Test that the results of a circuit with postselection is expected + with analytic execution.""" + if use_jit and interface != "jax": + pytest.skip("Cannot JIT in non-JAX interfaces.") + + dev = qml.device("default.qubit") + param = qml.math.asarray(param, like=interface) + + @qml.qnode(dev, interface=interface) + def circ_postselect(theta): + qml.RX(theta, 0) + qml.CNOT([0, 1]) + qml.measure(0, postselect=1) + return qml.apply(mp) + + @qml.qnode(dev, interface=interface) + def circ_expected(): + qml.RX(np.pi, 0) + qml.CNOT([0, 1]) + return qml.apply(mp) + + if use_jit: + import jax + + circ_postselect = jax.jit(circ_postselect) + + res = circ_postselect(param) + expected = circ_expected() + + assert qml.math.allclose(res, expected) + assert qml.math.get_interface(res) == qml.math.get_interface(expected) + + @pytest.mark.parametrize( + "mp", + [ + qml.expval(qml.PauliZ(0)), + qml.var(qml.PauliZ(0)), + qml.probs(wires=[0, 1]), + qml.shadow_expval(qml.Hamiltonian([1.0, -1.0], [qml.PauliZ(0), qml.PauliX(0)])), + # qml.sample, qml.classical_shadow, qml.counts are not included because their + # shape/values are dependent on the number of shots, which will be changed + # randomly per the binomial distribution and the probability of the postselected + # state + ], + ) + @pytest.mark.parametrize("param", np.linspace(np.pi / 4, 3 * np.pi / 4, 3)) + @pytest.mark.parametrize("shots", [50000, (50000, 50000)]) + def test_postselection_valid_finite_shots( + self, param, mp, shots, interface, use_jit, tol_stochastic + ): + """Test that the results of a circuit with postselection is expected with + finite shots.""" + if use_jit and (interface != "jax" or isinstance(shots, tuple)): + pytest.skip("Cannot JIT in non-JAX interfaces, or with shot vectors.") + + dev = qml.device("default.qubit") + param = qml.math.asarray(param, like=interface) + + @qml.qnode(dev, interface=interface) + def circ_postselect(theta): + qml.RX(theta, 0) + qml.CNOT([0, 1]) + qml.measure(0, postselect=1) + return qml.apply(mp) + + @qml.qnode(dev, interface=interface) + def circ_expected(): + qml.RX(np.pi, 0) + qml.CNOT([0, 1]) + return qml.apply(mp) + + if use_jit: + import jax + + circ_postselect = jax.jit(circ_postselect, static_argnames=["shots"]) + + res = circ_postselect(param, shots=shots) + expected = circ_expected(shots=shots) + + if not isinstance(shots, tuple): + assert qml.math.allclose(res, expected, atol=tol_stochastic, rtol=0) + assert qml.math.get_interface(res) == qml.math.get_interface(expected) + + else: + assert isinstance(res, tuple) + for r, e in zip(res, expected): + assert qml.math.allclose(r, e, atol=tol_stochastic, rtol=0) + assert qml.math.get_interface(r) == qml.math.get_interface(e) + + @pytest.mark.parametrize( + "mp, expected_shape", + [(qml.sample(wires=[0]), (5,)), (qml.classical_shadow(wires=[0]), (2, 5, 1))], + ) + @pytest.mark.parametrize("param", np.linspace(np.pi / 4, 3 * np.pi / 4, 3)) + @pytest.mark.parametrize("shots", [10, (10, 10)]) + def test_postselection_valid_finite_shots_varied_shape( + self, mp, param, expected_shape, shots, interface, use_jit + ): + """Test that qml.sample and qml.classical_shadow work correctly. + Separate test because their shape is non-deterministic.""" + + if use_jit: + pytest.skip("Cannot JIT while mocking function.") + + dev = qml.device("default.qubit", seed=42) + param = qml.math.asarray(param, like=interface) + + with mock.patch("numpy.random.binomial", lambda *args, **kwargs: 5): + + @qml.qnode(dev, interface=interface) + def circ_postselect(theta): + qml.RX(theta, 0) + qml.CNOT([0, 1]) + qml.measure(0, postselect=1) + return qml.apply(mp) + + if use_jit: + import jax + + circ_postselect = jax.jit(circ_postselect, static_argnames=["shots"]) + + res = circ_postselect(param, shots=shots) + + if not isinstance(shots, tuple): + assert qml.math.get_interface(res) == interface if interface != "autograd" else "numpy" + assert qml.math.shape(res) == expected_shape + + else: + assert isinstance(res, tuple) + for r in res: + assert ( + qml.math.get_interface(r) == interface if interface != "autograd" else "numpy" + ) + assert qml.math.shape(r) == expected_shape + + @pytest.mark.parametrize( + "mp, autograd_interface, is_nan", + [ + (qml.expval(qml.PauliZ(0)), "autograd", True), + (qml.var(qml.PauliZ(0)), "autograd", True), + (qml.probs(wires=[0, 1]), "autograd", True), + (qml.state(), "autograd", True), + (qml.density_matrix(wires=0), "autograd", True), + (qml.purity(0), "numpy", True), + (qml.vn_entropy(0), "numpy", False), + (qml.mutual_info(0, 1), "numpy", False), + ], + ) + def test_postselection_invalid_analytic( + self, mp, autograd_interface, is_nan, interface, use_jit + ): + """Test that the results of a qnode are nan values of the correct shape if the state + that we are postselecting has a zero probability of occurring.""" + + if (isinstance(mp, qml.measurements.MutualInfoMP) and interface != "jax") or ( + isinstance(mp, qml.measurements.VnEntropyMP) and interface == "tensorflow" + ): + pytest.skip("Unsupported measurements and interfaces.") + + if use_jit and interface != "jax": + pytest.skip("Can't jit with non-jax interfaces.") + + # Wires are specified so that the shape for measurements can be determined correctly + dev = qml.device("default.qubit", wires=2) + + @qml.qnode(dev, interface=interface) + def circ(): + qml.RX(np.pi, 0) + qml.CNOT([0, 1]) + qml.measure(0, postselect=0) + return qml.apply(mp) + + if use_jit: + import jax + + circ = jax.jit(circ) + + res = circ() + if interface == "autograd": + assert qml.math.get_interface(res) == autograd_interface + else: + assert qml.math.get_interface(res) == interface + assert qml.math.shape(res) == mp.shape(dev, qml.measurements.Shots(None)) + if is_nan: + assert qml.math.all(qml.math.isnan(res)) + else: + assert qml.math.allclose(res, 0.0) + + @pytest.mark.parametrize( + "mp, expected_shape", + [ + (qml.expval(qml.PauliZ(0)), ()), + (qml.var(qml.PauliZ(0)), ()), + (qml.sample(qml.PauliZ(0)), (0,)), + (qml.classical_shadow(wires=0), (2, 0, 1)), + (qml.shadow_expval(qml.Hamiltonian([1, 1], [qml.PauliZ(0), qml.PauliX(0)])), ()), + # qml.probs and qml.counts are not tested because they fail in this case + ], + ) + @pytest.mark.parametrize("shots", [10, (10, 10)]) + def test_postselection_invalid_finite_shots( + self, mp, expected_shape, shots, interface, use_jit + ): + """Test that the results of a qnode are nan values of the correct shape if the state + that we are postselecting has a zero probability of occurring with finite shots.""" + + if use_jit and interface != "jax": + pytest.skip("Can't jit with non-jax interfaces.") + + dev = qml.device("default.qubit") + + @qml.qnode(dev, interface=interface) + def circ(): + qml.RX(np.pi, 0) + qml.CNOT([0, 1]) + qml.measure(0, postselect=0) + return qml.apply(mp) + + if use_jit: + import jax + + circ = jax.jit(circ, static_argnames=["shots"]) + + res = circ(shots=shots) + + if not isinstance(shots, tuple): + assert qml.math.shape(res) == expected_shape + assert qml.math.get_interface(res) == interface if interface != "autograd" else "numpy" + if not 0 in expected_shape: # No nan values if array is empty + assert qml.math.all(qml.math.isnan(res)) + else: + assert isinstance(res, tuple) + for r in res: + assert qml.math.shape(r) == expected_shape + assert ( + qml.math.get_interface(r) == interface if interface != "autograd" else "numpy" + ) + if not 0 in expected_shape: # No nan values if array is empty + assert qml.math.all(qml.math.isnan(r)) + + class TestIntegration: """Various integration tests""" diff --git a/tests/devices/qubit/test_measure.py b/tests/devices/qubit/test_measure.py index 9794ab16469..3f9573d1831 100644 --- a/tests/devices/qubit/test_measure.py +++ b/tests/devices/qubit/test_measure.py @@ -236,6 +236,125 @@ def test_sparse_hamiltonian(self): assert np.allclose(res, expected) +class TestNaNMeasurements: + """Tests for state vectors containing nan values.""" + + @pytest.mark.all_interfaces + @pytest.mark.parametrize( + "mp", + [ + qml.expval(qml.PauliZ(0)), + qml.expval( + qml.Hamiltonian( + [1.0, 2.0, 3.0, 4.0], + [qml.PauliZ(0) @ qml.PauliX(1), qml.PauliX(1), qml.PauliZ(1), qml.PauliY(1)], + ) + ), + qml.expval( + qml.dot( + [1.0, 2.0, 3.0, 4.0], + [qml.PauliZ(0) @ qml.PauliX(1), qml.PauliX(1), qml.PauliZ(1), qml.PauliY(1)], + ) + ), + qml.var(qml.PauliZ(0)), + qml.var( + qml.dot( + [1.0, 2.0, 3.0, 4.0], + [qml.PauliZ(0) @ qml.PauliX(1), qml.PauliX(1), qml.PauliZ(1), qml.PauliY(1)], + ) + ), + ], + ) + @pytest.mark.parametrize("interface", ["numpy", "autograd", "torch", "tensorflow"]) + def test_nan_float_result(self, mp, interface): + """Test that the result of circuits with 0 probability postselections is NaN with the + expected shape.""" + state = qml.math.full((2, 2), np.NaN, like=interface) + res = measure(mp, state, is_state_batched=False) + + assert qml.math.ndim(res) == 0 + assert qml.math.isnan(res) + assert qml.math.get_interface(res) == interface + + @pytest.mark.jax + @pytest.mark.parametrize( + "mp", + [ + qml.expval(qml.PauliZ(0)), + qml.expval( + qml.Hamiltonian( + [1.0, 2.0, 3.0, 4.0], + [qml.PauliZ(0) @ qml.PauliX(1), qml.PauliX(1), qml.PauliZ(1), qml.PauliY(1)], + ) + ), + qml.expval( + qml.dot( + [1.0, 2.0, 3.0, 4.0], + [qml.PauliZ(0) @ qml.PauliX(1), qml.PauliX(1), qml.PauliZ(1), qml.PauliY(1)], + ) + ), + qml.var(qml.PauliZ(0)), + qml.var( + qml.dot( + [1.0, 2.0, 3.0, 4.0], + [qml.PauliZ(0) @ qml.PauliX(1), qml.PauliX(1), qml.PauliZ(1), qml.PauliY(1)], + ) + ), + ], + ) + @pytest.mark.parametrize("use_jit", [True, False]) + def test_nan_float_result_jax(self, mp, use_jit): + """Test that the result of circuits with 0 probability postselections is NaN with the + expected shape.""" + state = qml.math.full((2, 2), np.NaN, like="jax") + if use_jit: + import jax + + res = jax.jit(measure, static_argnums=[0, 2])(mp, state, is_state_batched=False) + else: + res = measure(mp, state, is_state_batched=False) + + assert qml.math.ndim(res) == 0 + + assert qml.math.isnan(res) + assert qml.math.get_interface(res) == "jax" + + @pytest.mark.all_interfaces + @pytest.mark.parametrize( + "mp", [qml.probs(wires=0), qml.probs(op=qml.PauliZ(0)), qml.probs(wires=[0, 1])] + ) + @pytest.mark.parametrize("interface", ["numpy", "autograd", "torch", "tensorflow"]) + def test_nan_probs(self, mp, interface): + """Test that the result of circuits with 0 probability postselections is NaN with the + expected shape.""" + state = qml.math.full((2, 2), np.NaN, like=interface) + res = measure(mp, state, is_state_batched=False) + + assert qml.math.shape(res) == (2 ** len(mp.wires),) + assert qml.math.all(qml.math.isnan(res)) + assert qml.math.get_interface(res) == interface + + @pytest.mark.jax + @pytest.mark.parametrize( + "mp", [qml.probs(wires=0), qml.probs(op=qml.PauliZ(0)), qml.probs(wires=[0, 1])] + ) + @pytest.mark.parametrize("use_jit", [True, False]) + def test_nan_probs_jax(self, mp, use_jit): + """Test that the result of circuits with 0 probability postselections is NaN with the + expected shape.""" + state = qml.math.full((2, 2), np.NaN, like="jax") + if use_jit: + import jax + + res = jax.jit(measure, static_argnums=[0, 2])(mp, state, is_state_batched=False) + else: + res = measure(mp, state, is_state_batched=False) + + assert qml.math.shape(res) == (2 ** len(mp.wires),) + assert qml.math.all(qml.math.isnan(res)) + assert qml.math.get_interface(res) == "jax" + + class TestSumOfTermsDifferentiability: @staticmethod def f(scale, coeffs, n_wires=10, offset=0.1, convert_to_hamiltonian=False): diff --git a/tests/devices/qubit/test_sampling.py b/tests/devices/qubit/test_sampling.py index 8ce558d2b0f..281d590a450 100644 --- a/tests/devices/qubit/test_sampling.py +++ b/tests/devices/qubit/test_sampling.py @@ -20,8 +20,10 @@ import pennylane as qml from pennylane import numpy as np from pennylane.devices.qubit import simulate +from pennylane.devices.qubit.simulate import _FlexShots from pennylane.devices.qubit import sample_state, measure_with_samples from pennylane.devices.qubit.sampling import _sample_state_jax +from pennylane.measurements import Shots two_qubit_state = np.array([[0, 1j], [-1, 0]], dtype=np.complex128) / np.sqrt(2) APPROX_ATOL = 0.01 @@ -40,6 +42,22 @@ def _init_state(n): return _init_state +def _valid_flex_int(s): + """Returns True if s is a non-negative integer.""" + return isinstance(s, int) and s >= 0 + + +def _valid_flex_tuple(s): + """Returns True if s is a tuple of the form (shots, copies).""" + return ( + isinstance(s, tuple) + and len(s) == 2 + and _valid_flex_int(s[0]) + and isinstance(s[1], int) + and s[1] > 0 + ) + + def samples_to_probs(samples, num_wires): """Converts samples to probs""" samples_decimal = [np.ravel_multi_index(sample, [2] * num_wires) for sample in samples] @@ -490,6 +508,145 @@ def test_measure_with_samples_one_shot_one_wire(self): assert result == -1.0 +class TestInvalidStateSamples: + """Tests for state vectors containing nan values or shot vectors with zero shots.""" + + @pytest.mark.parametrize("shots", [10, [10, 10]]) + def test_only_catch_nan_errors(self, shots): + """Test that errors are only caught if they are raised due to nan values in the state.""" + state = np.zeros((2, 2)).astype(np.complex128) + mp = qml.expval(qml.PauliZ(0)) + _shots = Shots(shots) + + with pytest.raises(ValueError, match="probabilities do not sum to 1"): + _ = measure_with_samples([mp], state, _shots) + + @pytest.mark.all_interfaces + @pytest.mark.parametrize( + "mp", + [ + qml.expval(qml.PauliZ(0)), + qml.expval( + qml.Hamiltonian( + [1.0, 2.0, 3.0, 4.0], + [qml.PauliZ(0) @ qml.PauliX(1), qml.PauliX(1), qml.PauliZ(1), qml.PauliY(1)], + ) + ), + qml.expval( + qml.dot( + [1.0, 2.0, 3.0, 4.0], + [qml.PauliZ(0) @ qml.PauliX(1), qml.PauliX(1), qml.PauliZ(1), qml.PauliY(1)], + ) + ), + qml.var(qml.PauliZ(0)), + ], + ) + @pytest.mark.parametrize("interface", ["numpy", "autograd", "torch", "tensorflow", "jax"]) + @pytest.mark.parametrize("shots", [0, [0, 0]]) + def test_nan_float_result(self, mp, interface, shots): + """Test that the result of circuits with 0 probability postselections is NaN with the + expected shape.""" + state = qml.math.full((2, 2), np.NaN, like=interface) + res = measure_with_samples((mp,), state, _FlexShots(shots), is_state_batched=False) + + if not isinstance(shots, list): + assert isinstance(res, tuple) + res = res[0] + assert qml.math.ndim(res) == 0 + assert qml.math.isnan(res) + + else: + assert isinstance(res, tuple) + assert len(res) == 2 + for r in res: + assert isinstance(r, tuple) + r = r[0] + assert qml.math.ndim(r) == 0 + assert qml.math.isnan(r) + + @pytest.mark.all_interfaces + @pytest.mark.parametrize( + "mp", [qml.sample(wires=0), qml.sample(op=qml.PauliZ(0)), qml.sample(wires=[0, 1])] + ) + @pytest.mark.parametrize("interface", ["numpy", "autograd", "torch", "tensorflow", "jax"]) + @pytest.mark.parametrize("shots", [0, [0, 0]]) + def test_nan_samples(self, mp, interface, shots): + """Test that the result of circuits with 0 probability postselections is NaN with the + expected shape.""" + state = qml.math.full((2, 2), np.NaN, like=interface) + res = measure_with_samples((mp,), state, _FlexShots(shots), is_state_batched=False) + + if not isinstance(shots, list): + assert isinstance(res, tuple) + res = res[0] + assert qml.math.shape(res) == (shots,) if len(mp.wires) == 1 else (shots, len(mp.wires)) + + else: + assert isinstance(res, tuple) + assert len(res) == 2 + for i, r in enumerate(res): + assert isinstance(r, tuple) + r = r[0] + assert ( + qml.math.shape(r) == (shots[i],) + if len(mp.wires) == 1 + else (shots[i], len(mp.wires)) + ) + + @pytest.mark.all_interfaces + @pytest.mark.parametrize("interface", ["numpy", "autograd", "torch", "tensorflow", "jax"]) + @pytest.mark.parametrize("shots", [0, [0, 0]]) + def test_nan_classical_shadows(self, interface, shots): + """Test that classical_shadows returns an empty array when the state has + NaN values""" + state = qml.math.full((2, 2), np.NaN, like=interface) + res = measure_with_samples( + (qml.classical_shadow([0]),), state, _FlexShots(shots), is_state_batched=False + ) + + if not isinstance(shots, list): + assert isinstance(res, tuple) + res = res[0] + assert qml.math.shape(res) == (2, 0, 1) + assert qml.math.size(res) == 0 + + else: + assert isinstance(res, tuple) + assert len(res) == 2 + for r in res: + assert isinstance(r, tuple) + r = r[0] + assert qml.math.shape(r) == (2, 0, 1) + assert qml.math.size(r) == 0 + + @pytest.mark.all_interfaces + @pytest.mark.parametrize("H", [qml.PauliZ(0), [qml.PauliZ(0), qml.PauliX(1)]]) + @pytest.mark.parametrize("interface", ["numpy", "autograd", "torch", "tensorflow", "jax"]) + @pytest.mark.parametrize("shots", [0, [0, 0]]) + def test_nan_shadow_expval(self, H, interface, shots): + """Test that shadow_expval returns an empty array when the state has + NaN values""" + state = qml.math.full((2, 2), np.NaN, like=interface) + res = measure_with_samples( + (qml.shadow_expval(H),), state, _FlexShots(shots), is_state_batched=False + ) + + if not isinstance(shots, list): + assert isinstance(res, tuple) + res = res[0] + assert qml.math.shape(res) == qml.math.shape(H) + assert qml.math.all(qml.math.isnan(res)) + + else: + assert isinstance(res, tuple) + assert len(res) == 2 + for r in res: + assert isinstance(r, tuple) + r = r[0] + assert qml.math.shape(r) == qml.math.shape(H) + assert qml.math.all(qml.math.isnan(r)) + + class TestBroadcasting: """Test that measurements work when the state has a batch dim""" diff --git a/tests/devices/qubit/test_simulate.py b/tests/devices/qubit/test_simulate.py index eb9851b8e87..bcddbc212b9 100644 --- a/tests/devices/qubit/test_simulate.py +++ b/tests/devices/qubit/test_simulate.py @@ -351,6 +351,43 @@ def test_broadcasting_with_extra_measurement_wires(self, mocker): assert spy.call_args_list[0].args == (qs, {2: 0, 1: 1, 0: 2}) +class TestPostselection: + """Tests for applying projectors as operations.""" + + def test_projector_norm(self): + """Test that the norm of the state is maintained after applying a projector""" + tape = qml.tape.QuantumScript( + [qml.PauliX(0), qml.RX(0.123, 1), qml.Projector([0], wires=1)], [qml.state()] + ) + res = simulate(tape) + assert np.isclose(np.linalg.norm(res), 1.0) + + @pytest.mark.parametrize("shots", [None, 10, [10, 10]]) + def test_broadcasting_with_projector(self, shots): + """Test that postselecting a broadcasted state raises an error""" + tape = qml.tape.QuantumScript( + [ + qml.RX([0.1, 0.2], 0), + qml.Projector([0], wires=0), + ], + [qml.state()], + shots=shots, + ) + + with pytest.raises(ValueError, match="Cannot postselect on circuits with broadcasting"): + _ = simulate(tape) + + @pytest.mark.all_interfaces + @pytest.mark.parametrize("interface", ["numpy", "torch", "jax", "tensorflow", "autograd"]) + def test_nan_state(self, interface): + """Test that a state with nan values is returned if the probability of a postselection state + is 0.""" + tape = qml.tape.QuantumScript([qml.PauliX(0), qml.Projector([0], 0)]) + + res, _ = get_final_state(tape, interface=interface) + assert qml.math.all(qml.math.isnan(res)) + + class TestDebugger: """Tests that the debugger works for a simple circuit""" diff --git a/tests/drawer/test_draw.py b/tests/drawer/test_draw.py index a574ceac7c1..05333550320 100644 --- a/tests/drawer/test_draw.py +++ b/tests/drawer/test_draw.py @@ -99,7 +99,7 @@ def test_decimals_higher_value(self): def test_decimals_multiparameters(self): """Test decimals also displays parameters when the operation has multiple parameters.""" - @qml.qnode(qml.device("default.qubit", wires=(0))) + @qml.qnode(qml.device("default.qubit", wires=[0])) def circ(x): qml.Rot(*x, wires=0) return qml.expval(qml.PauliZ(0)) @@ -281,6 +281,25 @@ def circ(): assert draw(circ)() == expected +@pytest.mark.parametrize("device_name", ["default.qubit"]) +def test_mid_circuit_measurement_device_api(device_name, mocker): + """Test that a circuit containing mid-circuit measurements is transformed by the drawer + to use deferred measurements if the device uses the new device API.""" + dev = qml.device(device_name) + + @qml.qnode(dev) + def circ(): + qml.PauliX(0) + qml.measure(0) + return qml.probs(wires=0) + + draw_qnode = qml.draw(circ) + spy = mocker.spy(qml.defer_measurements, "_transform") + + _ = draw_qnode() + spy.assert_called_once() + + @pytest.mark.parametrize( "transform", [ diff --git a/tests/test_device.py b/tests/test_device.py index b04db9813ec..2ce1126f2f2 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -396,6 +396,16 @@ def test_check_validity_on_invalid_observable(self, mock_device_supporting_pauli with pytest.raises(DeviceError, match="Observable Hadamard not supported on device"): dev.check_validity(queue, observables) + def test_check_validity_on_projector_as_operation(self, mock_device_with_operations): + """Test that an error is raised if the operation queue contains qml.Projector""" + dev = mock_device_with_operations(wires=1) + + queue = [qml.PauliX(0), qml.Projector([0], wires=0), qml.PauliZ(0)] + observables = [] + + with pytest.raises(ValueError, match="Postselection is not supported"): + dev.check_validity(queue, observables) + def test_args(self, mock_device): """Test that the device requires correct arguments""" with pytest.raises( diff --git a/tests/test_qnode.py b/tests/test_qnode.py index a528d54f43f..fc4ca54bc8a 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -848,7 +848,11 @@ def circuit(x, y): assert np.allclose(res, expected, atol=tol, rtol=0) @pytest.mark.parametrize( - "dev", [qml.device("default.qubit", wires=3), qml.device("default.qubit", wires=3)] + "dev, call_count", + [ + (qml.device("default.qubit", wires=3), 2), + (qml.device("default.qubit.legacy", wires=3), 1), + ], ) @pytest.mark.parametrize("first_par", np.linspace(0.15, np.pi - 0.3, 3)) @pytest.mark.parametrize("sec_par", np.linspace(0.15, np.pi - 0.3, 3)) @@ -864,7 +868,7 @@ def circuit(x, y): ], ) def test_defer_meas_if_mcm_unsupported( - self, dev, first_par, sec_par, return_type, mv_return, mv_res, mocker + self, dev, call_count, first_par, sec_par, return_type, mv_return, mv_res, mocker ): # pylint: disable=too-many-arguments """Tests that the transform using the deferred measurement principle is applied if the device doesn't support mid-circuit measurements @@ -894,7 +898,7 @@ def conditional_ry_qnode(x, y): assert np.allclose(r1, r2[0]) assert np.allclose(r2[1], mv_res(first_par)) - assert spy.call_count == 3 # once for each preprocessing, once for conditional qnode + assert spy.call_count == call_count # once for each preprocessing def test_drawing_has_deferred_measurements(self): """Test that `qml.draw` with qnodes uses defer_measurements @@ -941,7 +945,7 @@ def conditional_ry_qnode(x): r1 = cry_qnode(first_par) r2 = conditional_ry_qnode(first_par) assert np.allclose(r1, r2) - assert spy.call_count == 3 # once per device preprocessing, once for conditional qnode + assert spy.call_count == 2 # once per device preprocessing @pytest.mark.tf @pytest.mark.parametrize("interface", ["tf", "auto"]) diff --git a/tests/transforms/test_defer_measurements.py b/tests/transforms/test_defer_measurements.py index 9c24826147b..f779c5a94cf 100644 --- a/tests/transforms/test_defer_measurements.py +++ b/tests/transforms/test_defer_measurements.py @@ -14,18 +14,90 @@ """ Tests for the transform implementing the deferred measurement principle. """ -# pylint: disable=too-few-public-methods +# pylint: disable=too-few-public-methods, too-many-arguments import math import pytest import pennylane as qml +from pennylane.measurements import MidMeasureMP, MeasurementValue import pennylane.numpy as np from pennylane.devices import DefaultQubit +def test_broadcasted_postselection(mocker): + """Test that broadcast_expand is used iff broadcasting with postselection.""" + spy = mocker.spy(qml.transforms, "broadcast_expand") + + # Broadcasting with postselection + tape1 = qml.tape.QuantumScript( + [qml.RX([0.1, 0.2], 0), MidMeasureMP(0, postselect=1), qml.CNOT([0, 1])], + [qml.probs(wires=[0])], + ) + _, _ = qml.defer_measurements(tape1) + + assert spy.call_count == 1 + + # Broadcasting without postselection + tape2 = qml.tape.QuantumScript( + [qml.RX([0.1, 0.2], 0), MidMeasureMP(0), qml.CNOT([0, 1])], + [qml.probs(wires=[0])], + ) + _, _ = qml.defer_measurements(tape2) + + assert spy.call_count == 1 + + # Postselection without broadcasting + tape3 = qml.tape.QuantumScript( + [qml.RX(0.1, 0), MidMeasureMP(0, postselect=1), qml.CNOT([0, 1])], + [qml.probs(wires=[0])], + ) + _, _ = qml.defer_measurements(tape3) + + assert spy.call_count == 1 + + # No postselection, no broadcasting + tape4 = qml.tape.QuantumScript( + [qml.RX(0.1, 0), MidMeasureMP(0), qml.CNOT([0, 1])], + [qml.probs(wires=[0])], + ) + _, _ = qml.defer_measurements(tape4) + + assert spy.call_count == 1 + + +def test_postselection_error_with_wrong_device(): + """Test that an error is raised when postselection is used with a device + other than `default.qubit`.""" + dev = qml.device("default.mixed", wires=2) + + @qml.defer_measurements + @qml.qnode(dev) + def circ(): + qml.measure(0, postselect=1) + return qml.probs(wires=[0]) + + with pytest.raises(ValueError, match="Postselection is not supported"): + _ = circ() + + class TestQNode: """Test that the transform integrates well with QNodes.""" + def test_custom_qnode_transform_error(self): + """Test that an error is raised if a user tries to give a device argument to the + transform when transformingn a qnode.""" + + dev = qml.device("default.qubit") + + @qml.qnode(dev) + def circ(): + qml.PauliX(0) + qml.measure(0) + return qml.probs() + + with pytest.raises(ValueError, match="Cannot provide a 'device'"): + _ = qml.defer_measurements(circ, device=dev) + def test_only_mcm(self): """Test that a quantum function that only contains one mid-circuit measurement yields the correct results and is transformed correctly.""" @@ -36,8 +108,8 @@ def test_only_mcm(self): def qnode1(): return qml.expval(qml.PauliZ(0)) - @qml.defer_measurements @qml.qnode(dev) + @qml.defer_measurements def qnode2(): qml.measure(1) return qml.expval(qml.PauliZ(0)) @@ -99,7 +171,7 @@ def qnode2(phi): # Outputs should match assert np.isclose(qnode1(np.pi / 4), qnode2(np.pi / 4)) - assert spy.call_count == 3 # once per device preprocessing, one for qnode + assert spy.call_count == 2 # once per device preprocessing deferred_tapes, _ = qml.defer_measurements(qnode1.qtape) deferred_tape = deferred_tapes[0] @@ -139,7 +211,7 @@ def qnode2(phi, theta): res2 = qnode2(np.pi / 4, 3 * np.pi / 4) - assert spy.call_count == 4 + assert spy.call_count == 2 deferred_tapes1, _ = qml.defer_measurements(qnode1.qtape) deferred_tape1 = deferred_tapes1[0] @@ -153,6 +225,102 @@ def qnode2(phi, theta): assert len(deferred_tape2.wires) == 3 assert len(deferred_tape2.operations) == 4 + @pytest.mark.parametrize("shots", [None, 1000]) + @pytest.mark.parametrize("phi", np.linspace(np.pi / 2, 7 * np.pi / 2, 6)) + def test_postselection_qnode(self, phi, shots): + """Test that a Projector is queued when postselection is requested + on a mid-circuit measurement""" + dev = DefaultQubit() + + @qml.qnode(dev) + @qml.defer_measurements + def circ1(phi): + qml.RX(phi, wires=0) + # Postselecting on |1> on wire 0 means that the probability of measuring + # |1> on wire 0 is 1 + m = qml.measure(0, postselect=1) + qml.cond(m, qml.PauliX)(wires=1) + # Probability of measuring |1> on wire 1 should be 1 + return qml.probs(wires=1) + + assert np.allclose(circ1(phi, shots=shots), [0, 1]) + + expected_circuit = [ + qml.RX(phi, 0), + qml.Projector([1], wires=0), + qml.CNOT([0, 1]), + qml.probs(wires=1), + ] + + for op, expected_op in zip(circ1.qtape, expected_circuit): + assert qml.equal(op, expected_op) + + @pytest.mark.parametrize("shots", [None, 1000]) + @pytest.mark.parametrize("phi", np.linspace(np.pi / 4, 4 * np.pi, 4)) + @pytest.mark.parametrize("theta", np.linspace(np.pi / 3, 3 * np.pi, 4)) + def test_multiple_postselection_qnode(self, phi, theta, shots, tol, tol_stochastic): + """Test that a qnode with mid-circuit measurements with postselection + is transformed correctly by defer_measurements""" + dev = DefaultQubit() + + # Initializing mid circuit measurements here so that id can be controlled (affects + # wire ordering for qml.cond) + mp0 = MidMeasureMP(wires=0, postselect=0, id=0) + mv0 = MeasurementValue([mp0], lambda v: v) + mp1 = MidMeasureMP(wires=1, postselect=0, id=1) + mv1 = MeasurementValue([mp1], lambda v: v) + mp2 = MidMeasureMP(wires=2, reset=True, postselect=1, id=2) + mv2 = MeasurementValue([mp2], lambda v: v) + + @qml.qnode(dev) + @qml.defer_measurements + def circ1(phi, theta): + qml.RX(phi, 0) + qml.apply(mp0) + qml.CNOT([0, 1]) + qml.apply(mp1) + qml.cond(~(mv0 & mv1), qml.RY)(theta, wires=2) + qml.apply(mp2) + qml.cond(mv2, qml.PauliX)(1) + return qml.probs(wires=[0, 1, 2]) + + @qml.qnode(dev) + def circ2(): + # To add wire 0 to tape + qml.Identity(0) + qml.PauliX(1) + qml.Identity(2) + return qml.probs(wires=[0, 1, 2]) + + atol = tol if shots is None else tol_stochastic + assert np.allclose(circ1(phi, theta, shots=shots), circ2(), atol=atol, rtol=0) + + expected_circuit = [ + qml.RX(phi, wires=0), + qml.Projector([0], wires=0), + qml.CNOT([0, 3]), + qml.CNOT([0, 1]), + qml.Projector([0], wires=1), + qml.CNOT([1, 4]), + qml.ops.Controlled( + qml.RY(theta, wires=[2]), control_wires=[3, 4], control_values=[False, False] + ), + qml.ops.Controlled( + qml.RY(theta, wires=[2]), control_wires=[3, 4], control_values=[False, True] + ), + qml.ops.Controlled( + qml.RY(theta, wires=[2]), control_wires=[3, 4], control_values=[True, False] + ), + qml.Projector([1], wires=2), + qml.CNOT([2, 5]), + qml.PauliX(2), + qml.CNOT([5, 1]), + qml.probs(wires=[0, 1, 2]), + ] + + for op, expected_op in zip(circ1.qtape, expected_circuit): + assert qml.equal(op, expected_op) + @pytest.mark.parametrize("shots", [None, 1000, [1000, 1000]]) def test_measurement_statistics_single_wire(self, shots): """Test that users can collect measurement statistics on @@ -175,13 +343,14 @@ def circ2(x): param = 1.5 assert np.allclose(circ1(param, shots=shots), circ2(param, shots=shots)) - @pytest.mark.parametrize("shots", [None, 1000, [1000, 1000]]) + @pytest.mark.parametrize("shots", [None, 2000, [2000, 2000]]) def test_measured_value_wires_mapped(self, shots, tol, tol_stochastic): """Test that collecting statistics on a measurement value works correctly when the measured wire is reused.""" dev = DefaultQubit() @qml.qnode(dev) + @qml.defer_measurements def circ1(x): qml.RX(x, 0) m0 = qml.measure(0) @@ -1105,6 +1274,7 @@ def test_new_wire_for_multiple_measurements(self): dev = qml.device("default.qubit", wires=4) @qml.qnode(dev) + @qml.defer_measurements def circ(x, y): qml.RX(x, 0) qml.measure(0) @@ -1204,7 +1374,7 @@ def qnode(p, x, y): spy = mocker.spy(qml.defer_measurements, "_transform") _ = qnode(0.123, 0.456, 0.789) - assert spy.call_count == 2 + assert spy.call_count == 1 expected_circuit = [ qml.Hadamard(0),