diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 90af0bf7c6b..6a70dd5b1b2 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -87,6 +87,9 @@
wire order.
[(#4781)](https://github.com/PennyLaneAI/pennylane/pull/4781)
+* `transpile` can now handle measurements that are broadcasted onto all wires.
+ [(#4793)](https://github.com/PennyLaneAI/pennylane/pull/4793)
+
Contributors ✍️
This release contains contributions from (in alphabetical order):
diff --git a/pennylane/transforms/transpile.py b/pennylane/transforms/transpile.py
index 2ddd3a49ce1..daccae382d8 100644
--- a/pennylane/transforms/transpile.py
+++ b/pennylane/transforms/transpile.py
@@ -1,10 +1,12 @@
"""
Contains the transpiler transform.
"""
+from functools import partial
from typing import List, Union, Sequence, Callable
import networkx as nx
+import pennylane as qml
from pennylane.transforms import transform
from pennylane import Hamiltonian
from pennylane.operation import Tensor
@@ -14,9 +16,51 @@
from pennylane.tape import QuantumTape
+def state_transposition(results, mps, new_wire_order, original_wire_order):
+ """Transpose the order of any state return.
+
+ Args:
+ results (ResultBatch): the result of executing a batch of length 1
+
+ Keyword Args:
+ mps (List[MeasurementProcess]): A list of measurements processes. At least one is a ``StateMP``
+ new_wire_order (Sequence[Any]): the wire order after transpile has been called
+ original_wire_order (.Wires): the devices wire order
+
+ Returns:
+ Result: The result object with state dimensions transposed.
+
+ """
+ if len(mps) == 1:
+ temp_mp = qml.measurements.StateMP(wires=original_wire_order)
+ return temp_mp.process_state(results[0], wire_order=qml.wires.Wires(new_wire_order))
+ new_results = list(results[0])
+ for i, mp in enumerate(mps):
+ if isinstance(mp, qml.measurements.StateMP):
+ temp_mp = qml.measurements.StateMP(wires=original_wire_order)
+ new_res = temp_mp.process_state(
+ new_results[i], wire_order=qml.wires.Wires(new_wire_order)
+ )
+ new_results[i] = new_res
+ return tuple(new_results)
+
+
+def _process_measurements(expanded_tape, device_wires, is_default_mixed):
+ measurements = expanded_tape.measurements.copy()
+ if device_wires:
+ for i, m in enumerate(measurements):
+ if isinstance(m, qml.measurements.StateMP):
+ if is_default_mixed:
+ measurements[i] = qml.density_matrix(wires=device_wires)
+ elif not m.wires:
+ measurements[i] = type(m)(wires=device_wires)
+
+ return measurements
+
+
@transform
def transpile(
- tape: QuantumTape, coupling_map: Union[List, nx.Graph]
+ tape: QuantumTape, coupling_map: Union[List, nx.Graph], device=None
) -> (Sequence[QuantumTape], Callable):
"""Transpile a circuit according to a desired coupling map
@@ -80,6 +124,12 @@ def circuit():
A swap gate has been applied to wires 2 and 3, and the remaining gates have been adapted accordingly
"""
+ if device:
+ device_wires = device.wires
+ is_default_mixed = getattr(device, "short_name", "") == "default.mixed"
+ else:
+ device_wires = None
+ is_default_mixed = False
# init connectivity graph
coupling_graph = (
nx.Graph(coupling_map) if not isinstance(coupling_map, nx.Graph) else coupling_map
@@ -113,7 +163,9 @@ def stop_at(obj):
# make copy of ops
list_op_copy = expanded_tape.operations.copy()
- measurements = expanded_tape.measurements.copy()
+ wire_order = device_wires or tape.wires
+ measurements = _process_measurements(expanded_tape, device_wires, is_default_mixed)
+
gates = []
while len(list_op_copy) > 0:
@@ -135,7 +187,7 @@ def stop_at(obj):
continue
# since in each iteration, we adjust indices of each op, we reset logical -> phyiscal mapping
- wire_map = {w: w for w in tape.wires}
+ wire_map = {w: w for w in wire_order}
# to make sure two qubit gates which act on non-neighbouring qubits q1, q2 can be applied, we first look
# for the shortest path between the two qubits in the connectivity graph. We then move the q2 into the
@@ -159,13 +211,39 @@ def stop_at(obj):
list_op_copy.pop(0)
list_op_copy = [op.map_wires(wire_map) for op in list_op_copy]
+ wire_order = [wire_map[w] for w in wire_order]
measurements = [m.map_wires(wire_map) for m in measurements]
new_tape = type(tape)(gates, measurements, shots=tape.shots)
- def null_postprocessing(results):
- """A postprocesing function returned by a transform that only converts the batch of results
- into a result for a single ``QuantumTape``.
- """
- return results[0]
+ # note: no need for transposition with density matrix, so type must be `StateMP` but not `DensityMatrixMP`
+ # pylint: disable=unidiomatic-typecheck
+ any_state_mp = any(type(m) is qml.measurements.StateMP for m in measurements)
+ if not any_state_mp or device_wires is None:
+
+ def null_postprocessing(results):
+ """A postprocesing function returned by a transform that only converts the batch of results
+ into a result for a single ``QuantumTape``.
+ """
+ return results[0]
+
+ return (new_tape,), null_postprocessing
+
+ return (new_tape,), partial(
+ state_transposition,
+ mps=measurements,
+ new_wire_order=wire_order,
+ original_wire_order=device_wires,
+ )
+
+
+@transpile.custom_qnode_transform
+def _transpile_qnode(self, qnode, targs, tkwargs):
+ """Custom qnode transform for ``transpile``."""
+ if tkwargs.get("device", None):
+ raise ValueError(
+ "Cannot provide a 'device' value directly to the defer_measurements decorator "
+ "when transforming a QNode."
+ )
- return [new_tape], null_postprocessing
+ tkwargs.setdefault("device", qnode.device)
+ return self.default_qnode_transform(qnode, targs, tkwargs)
diff --git a/tests/transforms/test_transpile.py b/tests/transforms/test_transpile.py
index e9708a9eacb..4bcf6acb700 100644
--- a/tests/transforms/test_transpile.py
+++ b/tests/transforms/test_transpile.py
@@ -310,3 +310,113 @@ def test_transpile_state(self):
assert batch[0][2] == qml.CNOT((0, 1))
assert batch[0][3] == qml.state()
assert batch[0].shots == tape.shots
+
+ def test_transpile_state_with_device(self):
+ """Test that if a device is provided and a state is measured, then the state will be transposed during post processing."""
+
+ dev = qml.device("default.qubit", wires=(0, 1, 2))
+
+ tape = qml.tape.QuantumScript([qml.PauliX(0), qml.CNOT(wires=(0, 2))], [qml.state()])
+ batch, fn = qml.transforms.transpile(tape, coupling_map=[(0, 1), (1, 2)], device=dev)
+
+ original_mat = np.arange(8)
+ new_mat = fn((original_mat,))
+ expected_new_mat = np.swapaxes(np.reshape(original_mat, [2, 2, 2]), 1, 2).flatten()
+ assert qml.math.allclose(new_mat, expected_new_mat)
+
+ assert batch[0][0] == qml.PauliX(0)
+ assert batch[0][1] == qml.SWAP((1, 2))
+ assert batch[0][2] == qml.CNOT((0, 1))
+ assert batch[0][3] == qml.state()
+
+ pre, post = dev.preprocess()[0]((tape,))
+ original_results = post(dev.execute(pre))
+ transformed_results = fn(dev.execute(batch))
+ assert qml.math.allclose(original_results, transformed_results)
+
+ def test_transpile_state_with_device_multiple_measurements(self):
+ """Test that if a device is provided and a state is measured, then the state will be transposed during post processing."""
+
+ dev = qml.device("default.qubit", wires=(0, 1, 2))
+
+ tape = qml.tape.QuantumScript(
+ [qml.PauliX(0), qml.CNOT(wires=(0, 2))], [qml.state(), qml.expval(qml.PauliZ(2))]
+ )
+ batch, fn = qml.transforms.transpile(tape, coupling_map=[(0, 1), (1, 2)], device=dev)
+
+ original_mat = np.arange(8)
+ new_mat, _ = fn(((original_mat, 2.0),))
+ expected_new_mat = np.swapaxes(np.reshape(original_mat, [2, 2, 2]), 1, 2).flatten()
+ assert qml.math.allclose(new_mat, expected_new_mat)
+
+ assert batch[0][0] == qml.PauliX(0)
+ assert batch[0][1] == qml.SWAP((1, 2))
+ assert batch[0][2] == qml.CNOT((0, 1))
+ assert batch[0][3] == qml.state()
+ assert batch[0][4] == qml.expval(qml.PauliZ(1))
+
+ pre, post = dev.preprocess()[0]((tape,))
+ original_results = post(dev.execute(pre))
+ transformed_results = fn(dev.execute(batch))
+ assert qml.math.allclose(original_results[0][0], transformed_results[0])
+ assert qml.math.allclose(original_results[0][1], transformed_results[1])
+
+ def test_transpile_with_state_default_mixed(self):
+ """Test that if the state is default mixed, state measurements are converted in to density measurements with the device wires."""
+
+ dev = qml.device("default.mixed", wires=(0, 1, 2))
+
+ tape = qml.tape.QuantumScript([qml.PauliX(0), qml.CNOT(wires=(0, 2))], [qml.state()])
+ batch, fn = qml.transforms.transpile(tape, coupling_map=[(0, 1), (1, 2)], device=dev)
+
+ assert batch[0][-1] == qml.density_matrix(wires=(0, 2, 1))
+
+ original_results = dev.execute(tape)
+ transformed_results = fn(dev.batch_execute(batch))
+ assert qml.math.allclose(original_results, transformed_results)
+
+ def test_transpile_probs_sample_filled_in_wires(self):
+ """Test that if probs or sample are requested broadcasted over all wires, transpile fills in the device wires."""
+ dev = qml.device("default.qubit", wires=(0, 1, 2))
+
+ tape = qml.tape.QuantumScript(
+ [qml.PauliX(0), qml.CNOT(wires=(0, 2))], [qml.probs(), qml.sample()], shots=100
+ )
+ batch, fn = qml.transforms.transpile(tape, coupling_map=[(0, 1), (1, 2)], device=dev)
+
+ assert batch[0].measurements[0] == qml.probs(wires=(0, 2, 1))
+ assert batch[0].measurements[1] == qml.sample(wires=(0, 2, 1))
+
+ pre, post = dev.preprocess()[0]((tape,))
+ original_results = post(dev.execute(pre))[0]
+ transformed_results = fn(dev.execute(batch))
+ assert qml.math.allclose(original_results[0], transformed_results[0])
+ assert qml.math.allclose(original_results[1], transformed_results[1])
+
+ def test_custom_qnode_transform(self):
+ """Test that applying the transform to a qnode adds the device to the transform kwargs."""
+
+ dev = qml.device("default.qubit", wires=(0, 1, 2))
+
+ def qfunc():
+ return qml.state()
+
+ original_qnode = qml.QNode(qfunc, dev)
+ transformed_qnode = transpile(original_qnode, coupling_map=[(0, 1), (1, 2)])
+
+ assert len(transformed_qnode.transform_program) == 1
+ assert transformed_qnode.transform_program[0].kwargs["device"] is dev
+
+ def test_qnode_transform_raises_if_device_kwarg(self):
+ """Test an error is raised if a device is provided as a keyword argument to a qnode transform."""
+
+ dev = qml.device("default.qubit", wires=[0, 1, 2, 3])
+
+ @qml.qnode(dev)
+ def circuit():
+ return qml.state()
+
+ with pytest.raises(ValueError, match=r"Cannot provide a "):
+ qml.transforms.transpile(
+ circuit, coupling_map=[(0, 1), (1, 3), (3, 2), (2, 0)], device=dev
+ )