diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 90af0bf7c6b..6a70dd5b1b2 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -87,6 +87,9 @@ wire order. [(#4781)](https://github.com/PennyLaneAI/pennylane/pull/4781) +* `transpile` can now handle measurements that are broadcasted onto all wires. + [(#4793)](https://github.com/PennyLaneAI/pennylane/pull/4793) +

Contributors ✍️

This release contains contributions from (in alphabetical order): diff --git a/pennylane/transforms/transpile.py b/pennylane/transforms/transpile.py index 2ddd3a49ce1..daccae382d8 100644 --- a/pennylane/transforms/transpile.py +++ b/pennylane/transforms/transpile.py @@ -1,10 +1,12 @@ """ Contains the transpiler transform. """ +from functools import partial from typing import List, Union, Sequence, Callable import networkx as nx +import pennylane as qml from pennylane.transforms import transform from pennylane import Hamiltonian from pennylane.operation import Tensor @@ -14,9 +16,51 @@ from pennylane.tape import QuantumTape +def state_transposition(results, mps, new_wire_order, original_wire_order): + """Transpose the order of any state return. + + Args: + results (ResultBatch): the result of executing a batch of length 1 + + Keyword Args: + mps (List[MeasurementProcess]): A list of measurements processes. At least one is a ``StateMP`` + new_wire_order (Sequence[Any]): the wire order after transpile has been called + original_wire_order (.Wires): the devices wire order + + Returns: + Result: The result object with state dimensions transposed. + + """ + if len(mps) == 1: + temp_mp = qml.measurements.StateMP(wires=original_wire_order) + return temp_mp.process_state(results[0], wire_order=qml.wires.Wires(new_wire_order)) + new_results = list(results[0]) + for i, mp in enumerate(mps): + if isinstance(mp, qml.measurements.StateMP): + temp_mp = qml.measurements.StateMP(wires=original_wire_order) + new_res = temp_mp.process_state( + new_results[i], wire_order=qml.wires.Wires(new_wire_order) + ) + new_results[i] = new_res + return tuple(new_results) + + +def _process_measurements(expanded_tape, device_wires, is_default_mixed): + measurements = expanded_tape.measurements.copy() + if device_wires: + for i, m in enumerate(measurements): + if isinstance(m, qml.measurements.StateMP): + if is_default_mixed: + measurements[i] = qml.density_matrix(wires=device_wires) + elif not m.wires: + measurements[i] = type(m)(wires=device_wires) + + return measurements + + @transform def transpile( - tape: QuantumTape, coupling_map: Union[List, nx.Graph] + tape: QuantumTape, coupling_map: Union[List, nx.Graph], device=None ) -> (Sequence[QuantumTape], Callable): """Transpile a circuit according to a desired coupling map @@ -80,6 +124,12 @@ def circuit(): A swap gate has been applied to wires 2 and 3, and the remaining gates have been adapted accordingly """ + if device: + device_wires = device.wires + is_default_mixed = getattr(device, "short_name", "") == "default.mixed" + else: + device_wires = None + is_default_mixed = False # init connectivity graph coupling_graph = ( nx.Graph(coupling_map) if not isinstance(coupling_map, nx.Graph) else coupling_map @@ -113,7 +163,9 @@ def stop_at(obj): # make copy of ops list_op_copy = expanded_tape.operations.copy() - measurements = expanded_tape.measurements.copy() + wire_order = device_wires or tape.wires + measurements = _process_measurements(expanded_tape, device_wires, is_default_mixed) + gates = [] while len(list_op_copy) > 0: @@ -135,7 +187,7 @@ def stop_at(obj): continue # since in each iteration, we adjust indices of each op, we reset logical -> phyiscal mapping - wire_map = {w: w for w in tape.wires} + wire_map = {w: w for w in wire_order} # to make sure two qubit gates which act on non-neighbouring qubits q1, q2 can be applied, we first look # for the shortest path between the two qubits in the connectivity graph. We then move the q2 into the @@ -159,13 +211,39 @@ def stop_at(obj): list_op_copy.pop(0) list_op_copy = [op.map_wires(wire_map) for op in list_op_copy] + wire_order = [wire_map[w] for w in wire_order] measurements = [m.map_wires(wire_map) for m in measurements] new_tape = type(tape)(gates, measurements, shots=tape.shots) - def null_postprocessing(results): - """A postprocesing function returned by a transform that only converts the batch of results - into a result for a single ``QuantumTape``. - """ - return results[0] + # note: no need for transposition with density matrix, so type must be `StateMP` but not `DensityMatrixMP` + # pylint: disable=unidiomatic-typecheck + any_state_mp = any(type(m) is qml.measurements.StateMP for m in measurements) + if not any_state_mp or device_wires is None: + + def null_postprocessing(results): + """A postprocesing function returned by a transform that only converts the batch of results + into a result for a single ``QuantumTape``. + """ + return results[0] + + return (new_tape,), null_postprocessing + + return (new_tape,), partial( + state_transposition, + mps=measurements, + new_wire_order=wire_order, + original_wire_order=device_wires, + ) + + +@transpile.custom_qnode_transform +def _transpile_qnode(self, qnode, targs, tkwargs): + """Custom qnode transform for ``transpile``.""" + if tkwargs.get("device", None): + raise ValueError( + "Cannot provide a 'device' value directly to the defer_measurements decorator " + "when transforming a QNode." + ) - return [new_tape], null_postprocessing + tkwargs.setdefault("device", qnode.device) + return self.default_qnode_transform(qnode, targs, tkwargs) diff --git a/tests/transforms/test_transpile.py b/tests/transforms/test_transpile.py index e9708a9eacb..4bcf6acb700 100644 --- a/tests/transforms/test_transpile.py +++ b/tests/transforms/test_transpile.py @@ -310,3 +310,113 @@ def test_transpile_state(self): assert batch[0][2] == qml.CNOT((0, 1)) assert batch[0][3] == qml.state() assert batch[0].shots == tape.shots + + def test_transpile_state_with_device(self): + """Test that if a device is provided and a state is measured, then the state will be transposed during post processing.""" + + dev = qml.device("default.qubit", wires=(0, 1, 2)) + + tape = qml.tape.QuantumScript([qml.PauliX(0), qml.CNOT(wires=(0, 2))], [qml.state()]) + batch, fn = qml.transforms.transpile(tape, coupling_map=[(0, 1), (1, 2)], device=dev) + + original_mat = np.arange(8) + new_mat = fn((original_mat,)) + expected_new_mat = np.swapaxes(np.reshape(original_mat, [2, 2, 2]), 1, 2).flatten() + assert qml.math.allclose(new_mat, expected_new_mat) + + assert batch[0][0] == qml.PauliX(0) + assert batch[0][1] == qml.SWAP((1, 2)) + assert batch[0][2] == qml.CNOT((0, 1)) + assert batch[0][3] == qml.state() + + pre, post = dev.preprocess()[0]((tape,)) + original_results = post(dev.execute(pre)) + transformed_results = fn(dev.execute(batch)) + assert qml.math.allclose(original_results, transformed_results) + + def test_transpile_state_with_device_multiple_measurements(self): + """Test that if a device is provided and a state is measured, then the state will be transposed during post processing.""" + + dev = qml.device("default.qubit", wires=(0, 1, 2)) + + tape = qml.tape.QuantumScript( + [qml.PauliX(0), qml.CNOT(wires=(0, 2))], [qml.state(), qml.expval(qml.PauliZ(2))] + ) + batch, fn = qml.transforms.transpile(tape, coupling_map=[(0, 1), (1, 2)], device=dev) + + original_mat = np.arange(8) + new_mat, _ = fn(((original_mat, 2.0),)) + expected_new_mat = np.swapaxes(np.reshape(original_mat, [2, 2, 2]), 1, 2).flatten() + assert qml.math.allclose(new_mat, expected_new_mat) + + assert batch[0][0] == qml.PauliX(0) + assert batch[0][1] == qml.SWAP((1, 2)) + assert batch[0][2] == qml.CNOT((0, 1)) + assert batch[0][3] == qml.state() + assert batch[0][4] == qml.expval(qml.PauliZ(1)) + + pre, post = dev.preprocess()[0]((tape,)) + original_results = post(dev.execute(pre)) + transformed_results = fn(dev.execute(batch)) + assert qml.math.allclose(original_results[0][0], transformed_results[0]) + assert qml.math.allclose(original_results[0][1], transformed_results[1]) + + def test_transpile_with_state_default_mixed(self): + """Test that if the state is default mixed, state measurements are converted in to density measurements with the device wires.""" + + dev = qml.device("default.mixed", wires=(0, 1, 2)) + + tape = qml.tape.QuantumScript([qml.PauliX(0), qml.CNOT(wires=(0, 2))], [qml.state()]) + batch, fn = qml.transforms.transpile(tape, coupling_map=[(0, 1), (1, 2)], device=dev) + + assert batch[0][-1] == qml.density_matrix(wires=(0, 2, 1)) + + original_results = dev.execute(tape) + transformed_results = fn(dev.batch_execute(batch)) + assert qml.math.allclose(original_results, transformed_results) + + def test_transpile_probs_sample_filled_in_wires(self): + """Test that if probs or sample are requested broadcasted over all wires, transpile fills in the device wires.""" + dev = qml.device("default.qubit", wires=(0, 1, 2)) + + tape = qml.tape.QuantumScript( + [qml.PauliX(0), qml.CNOT(wires=(0, 2))], [qml.probs(), qml.sample()], shots=100 + ) + batch, fn = qml.transforms.transpile(tape, coupling_map=[(0, 1), (1, 2)], device=dev) + + assert batch[0].measurements[0] == qml.probs(wires=(0, 2, 1)) + assert batch[0].measurements[1] == qml.sample(wires=(0, 2, 1)) + + pre, post = dev.preprocess()[0]((tape,)) + original_results = post(dev.execute(pre))[0] + transformed_results = fn(dev.execute(batch)) + assert qml.math.allclose(original_results[0], transformed_results[0]) + assert qml.math.allclose(original_results[1], transformed_results[1]) + + def test_custom_qnode_transform(self): + """Test that applying the transform to a qnode adds the device to the transform kwargs.""" + + dev = qml.device("default.qubit", wires=(0, 1, 2)) + + def qfunc(): + return qml.state() + + original_qnode = qml.QNode(qfunc, dev) + transformed_qnode = transpile(original_qnode, coupling_map=[(0, 1), (1, 2)]) + + assert len(transformed_qnode.transform_program) == 1 + assert transformed_qnode.transform_program[0].kwargs["device"] is dev + + def test_qnode_transform_raises_if_device_kwarg(self): + """Test an error is raised if a device is provided as a keyword argument to a qnode transform.""" + + dev = qml.device("default.qubit", wires=[0, 1, 2, 3]) + + @qml.qnode(dev) + def circuit(): + return qml.state() + + with pytest.raises(ValueError, match=r"Cannot provide a "): + qml.transforms.transpile( + circuit, coupling_map=[(0, 1), (1, 3), (3, 2), (2, 0)], device=dev + )