Skip to content

Commit

Permalink
transpose wire ordering for state results after transpile (#4793)
Browse files Browse the repository at this point in the history
**Context:**

The `transpile` transform changes the wires of everything after the
necessary `SWAP` gates. Unfortunately, we have classes of measurements
in PennyLane that are sensitive to the wire order but do not store the
wire order on the measurement process.

```
dev = qml.device("default.qubit", wires = 4)

coupling_map=[(1, 2), (0, 2), (1, 3)]

@qml.transforms.transpile(coupling_map=coupling_map)
@qml.qnode(dev)
def circuit():
  qml.Hadamard(wires = 0)
  qml.CNOT(wires = [0,1])
  return qml.state()
```

**Description of the Change:**

* The `transpile` transform now takes the device as a keyword argument
* If the device is default mixed and the measurement is `StateMP`, we
convert it to `DensityMatrixMP`
* If the measurement process does not have wires and is not a `StateMP`,
we update the measurement to give it the device wires
3) If any of the measurements are `StateMP`, the post-processing
function transposes the state result

**Benefits:**

The output of the qnode will look the same with and without the
transpile transform

**Possible Drawbacks:**

* Transposition of state result could potentially be classically
expensive for larger systems

**Related GitHub Issues:**
  • Loading branch information
albi3ro authored Nov 9, 2023
1 parent deaf387 commit d2ad6af
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 9 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@
wire order.
[(#4781)](https://github.com/PennyLaneAI/pennylane/pull/4781)

* `transpile` can now handle measurements that are broadcasted onto all wires.
[(#4793)](https://github.com/PennyLaneAI/pennylane/pull/4793)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Expand Down
96 changes: 87 additions & 9 deletions pennylane/transforms/transpile.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""
Contains the transpiler transform.
"""
from functools import partial
from typing import List, Union, Sequence, Callable

import networkx as nx

import pennylane as qml
from pennylane.transforms import transform
from pennylane import Hamiltonian
from pennylane.operation import Tensor
Expand All @@ -14,9 +16,51 @@
from pennylane.tape import QuantumTape


def state_transposition(results, mps, new_wire_order, original_wire_order):
"""Transpose the order of any state return.
Args:
results (ResultBatch): the result of executing a batch of length 1
Keyword Args:
mps (List[MeasurementProcess]): A list of measurements processes. At least one is a ``StateMP``
new_wire_order (Sequence[Any]): the wire order after transpile has been called
original_wire_order (.Wires): the devices wire order
Returns:
Result: The result object with state dimensions transposed.
"""
if len(mps) == 1:
temp_mp = qml.measurements.StateMP(wires=original_wire_order)
return temp_mp.process_state(results[0], wire_order=qml.wires.Wires(new_wire_order))
new_results = list(results[0])
for i, mp in enumerate(mps):
if isinstance(mp, qml.measurements.StateMP):
temp_mp = qml.measurements.StateMP(wires=original_wire_order)
new_res = temp_mp.process_state(
new_results[i], wire_order=qml.wires.Wires(new_wire_order)
)
new_results[i] = new_res
return tuple(new_results)


def _process_measurements(expanded_tape, device_wires, is_default_mixed):
measurements = expanded_tape.measurements.copy()
if device_wires:
for i, m in enumerate(measurements):
if isinstance(m, qml.measurements.StateMP):
if is_default_mixed:
measurements[i] = qml.density_matrix(wires=device_wires)
elif not m.wires:
measurements[i] = type(m)(wires=device_wires)

return measurements


@transform
def transpile(
tape: QuantumTape, coupling_map: Union[List, nx.Graph]
tape: QuantumTape, coupling_map: Union[List, nx.Graph], device=None
) -> (Sequence[QuantumTape], Callable):
"""Transpile a circuit according to a desired coupling map
Expand Down Expand Up @@ -80,6 +124,12 @@ def circuit():
A swap gate has been applied to wires 2 and 3, and the remaining gates have been adapted accordingly
"""
if device:
device_wires = device.wires
is_default_mixed = getattr(device, "short_name", "") == "default.mixed"
else:
device_wires = None
is_default_mixed = False
# init connectivity graph
coupling_graph = (
nx.Graph(coupling_map) if not isinstance(coupling_map, nx.Graph) else coupling_map
Expand Down Expand Up @@ -113,7 +163,9 @@ def stop_at(obj):

# make copy of ops
list_op_copy = expanded_tape.operations.copy()
measurements = expanded_tape.measurements.copy()
wire_order = device_wires or tape.wires
measurements = _process_measurements(expanded_tape, device_wires, is_default_mixed)

gates = []

while len(list_op_copy) > 0:
Expand All @@ -135,7 +187,7 @@ def stop_at(obj):
continue

# since in each iteration, we adjust indices of each op, we reset logical -> phyiscal mapping
wire_map = {w: w for w in tape.wires}
wire_map = {w: w for w in wire_order}

# to make sure two qubit gates which act on non-neighbouring qubits q1, q2 can be applied, we first look
# for the shortest path between the two qubits in the connectivity graph. We then move the q2 into the
Expand All @@ -159,13 +211,39 @@ def stop_at(obj):
list_op_copy.pop(0)

list_op_copy = [op.map_wires(wire_map) for op in list_op_copy]
wire_order = [wire_map[w] for w in wire_order]
measurements = [m.map_wires(wire_map) for m in measurements]
new_tape = type(tape)(gates, measurements, shots=tape.shots)

def null_postprocessing(results):
"""A postprocesing function returned by a transform that only converts the batch of results
into a result for a single ``QuantumTape``.
"""
return results[0]
# note: no need for transposition with density matrix, so type must be `StateMP` but not `DensityMatrixMP`
# pylint: disable=unidiomatic-typecheck
any_state_mp = any(type(m) is qml.measurements.StateMP for m in measurements)
if not any_state_mp or device_wires is None:

def null_postprocessing(results):
"""A postprocesing function returned by a transform that only converts the batch of results
into a result for a single ``QuantumTape``.
"""
return results[0]

return (new_tape,), null_postprocessing

return (new_tape,), partial(
state_transposition,
mps=measurements,
new_wire_order=wire_order,
original_wire_order=device_wires,
)


@transpile.custom_qnode_transform
def _transpile_qnode(self, qnode, targs, tkwargs):
"""Custom qnode transform for ``transpile``."""
if tkwargs.get("device", None):
raise ValueError(
"Cannot provide a 'device' value directly to the defer_measurements decorator "
"when transforming a QNode."
)

return [new_tape], null_postprocessing
tkwargs.setdefault("device", qnode.device)
return self.default_qnode_transform(qnode, targs, tkwargs)
110 changes: 110 additions & 0 deletions tests/transforms/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,113 @@ def test_transpile_state(self):
assert batch[0][2] == qml.CNOT((0, 1))
assert batch[0][3] == qml.state()
assert batch[0].shots == tape.shots

def test_transpile_state_with_device(self):
"""Test that if a device is provided and a state is measured, then the state will be transposed during post processing."""

dev = qml.device("default.qubit", wires=(0, 1, 2))

tape = qml.tape.QuantumScript([qml.PauliX(0), qml.CNOT(wires=(0, 2))], [qml.state()])
batch, fn = qml.transforms.transpile(tape, coupling_map=[(0, 1), (1, 2)], device=dev)

original_mat = np.arange(8)
new_mat = fn((original_mat,))
expected_new_mat = np.swapaxes(np.reshape(original_mat, [2, 2, 2]), 1, 2).flatten()
assert qml.math.allclose(new_mat, expected_new_mat)

assert batch[0][0] == qml.PauliX(0)
assert batch[0][1] == qml.SWAP((1, 2))
assert batch[0][2] == qml.CNOT((0, 1))
assert batch[0][3] == qml.state()

pre, post = dev.preprocess()[0]((tape,))
original_results = post(dev.execute(pre))
transformed_results = fn(dev.execute(batch))
assert qml.math.allclose(original_results, transformed_results)

def test_transpile_state_with_device_multiple_measurements(self):
"""Test that if a device is provided and a state is measured, then the state will be transposed during post processing."""

dev = qml.device("default.qubit", wires=(0, 1, 2))

tape = qml.tape.QuantumScript(
[qml.PauliX(0), qml.CNOT(wires=(0, 2))], [qml.state(), qml.expval(qml.PauliZ(2))]
)
batch, fn = qml.transforms.transpile(tape, coupling_map=[(0, 1), (1, 2)], device=dev)

original_mat = np.arange(8)
new_mat, _ = fn(((original_mat, 2.0),))
expected_new_mat = np.swapaxes(np.reshape(original_mat, [2, 2, 2]), 1, 2).flatten()
assert qml.math.allclose(new_mat, expected_new_mat)

assert batch[0][0] == qml.PauliX(0)
assert batch[0][1] == qml.SWAP((1, 2))
assert batch[0][2] == qml.CNOT((0, 1))
assert batch[0][3] == qml.state()
assert batch[0][4] == qml.expval(qml.PauliZ(1))

pre, post = dev.preprocess()[0]((tape,))
original_results = post(dev.execute(pre))
transformed_results = fn(dev.execute(batch))
assert qml.math.allclose(original_results[0][0], transformed_results[0])
assert qml.math.allclose(original_results[0][1], transformed_results[1])

def test_transpile_with_state_default_mixed(self):
"""Test that if the state is default mixed, state measurements are converted in to density measurements with the device wires."""

dev = qml.device("default.mixed", wires=(0, 1, 2))

tape = qml.tape.QuantumScript([qml.PauliX(0), qml.CNOT(wires=(0, 2))], [qml.state()])
batch, fn = qml.transforms.transpile(tape, coupling_map=[(0, 1), (1, 2)], device=dev)

assert batch[0][-1] == qml.density_matrix(wires=(0, 2, 1))

original_results = dev.execute(tape)
transformed_results = fn(dev.batch_execute(batch))
assert qml.math.allclose(original_results, transformed_results)

def test_transpile_probs_sample_filled_in_wires(self):
"""Test that if probs or sample are requested broadcasted over all wires, transpile fills in the device wires."""
dev = qml.device("default.qubit", wires=(0, 1, 2))

tape = qml.tape.QuantumScript(
[qml.PauliX(0), qml.CNOT(wires=(0, 2))], [qml.probs(), qml.sample()], shots=100
)
batch, fn = qml.transforms.transpile(tape, coupling_map=[(0, 1), (1, 2)], device=dev)

assert batch[0].measurements[0] == qml.probs(wires=(0, 2, 1))
assert batch[0].measurements[1] == qml.sample(wires=(0, 2, 1))

pre, post = dev.preprocess()[0]((tape,))
original_results = post(dev.execute(pre))[0]
transformed_results = fn(dev.execute(batch))
assert qml.math.allclose(original_results[0], transformed_results[0])
assert qml.math.allclose(original_results[1], transformed_results[1])

def test_custom_qnode_transform(self):
"""Test that applying the transform to a qnode adds the device to the transform kwargs."""

dev = qml.device("default.qubit", wires=(0, 1, 2))

def qfunc():
return qml.state()

original_qnode = qml.QNode(qfunc, dev)
transformed_qnode = transpile(original_qnode, coupling_map=[(0, 1), (1, 2)])

assert len(transformed_qnode.transform_program) == 1
assert transformed_qnode.transform_program[0].kwargs["device"] is dev

def test_qnode_transform_raises_if_device_kwarg(self):
"""Test an error is raised if a device is provided as a keyword argument to a qnode transform."""

dev = qml.device("default.qubit", wires=[0, 1, 2, 3])

@qml.qnode(dev)
def circuit():
return qml.state()

with pytest.raises(ValueError, match=r"Cannot provide a "):
qml.transforms.transpile(
circuit, coupling_map=[(0, 1), (1, 3), (3, 2), (2, 0)], device=dev
)

0 comments on commit d2ad6af

Please sign in to comment.