Skip to content

Commit

Permalink
Merge branch 'master' into add-pytest-benchmarks-CI
Browse files Browse the repository at this point in the history
  • Loading branch information
AmintorDusko authored Nov 8, 2023
2 parents 36e8fbd + 4eecda9 commit 9dfa94c
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 1 deletion.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

<h3>Improvements 🛠</h3>

* `qml.draw` now supports drawing mid-circuit measurements.
[(#4775)](https://github.com/PennyLaneAI/pennylane/pull/4775)

* Autograd can now use vjps provided by the device from the new device API. If a device provides
a vector Jacobian product, this can be selected by providing `device_vjp=True` to
`qml.execute`.
Expand Down
2 changes: 1 addition & 1 deletion pennylane/drawer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def unwrap_controls(op):
# Get wires and control values of base operation; need to make a copy of
# control values, otherwise it will modify the list in the operation itself.
control_wires = getattr(op, "control_wires", [])
control_values = op.hyperparameters.get("control_values", None)
control_values = getattr(op, "hyperparameters", {}).get("control_values", None)

if isinstance(control_values, list):
control_values = control_values.copy()
Expand Down
22 changes: 22 additions & 0 deletions pennylane/measurements/mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,28 @@ def __init__(
self.reset = reset
self.postselect = postselect

def label(self, decimals=None, base_label=None, cache=None): # pylint: disable=unused-argument
r"""How the mid-circuit measurement is represented in diagrams and drawings.
Args:
decimals=None (Int): If ``None``, no parameters are included. Else,
how to round the parameters.
base_label=None (Iterable[str]): overwrite the non-parameter component of the label.
Must be same length as ``obs`` attribute.
cache=None (dict): dictionary that carries information between label calls
in the same drawing
Returns:
str: label to use in drawings
"""
_label = "┤↗"
if self.postselect is not None:
_label += "₁" if self.postselect == 1 else "₀"

_label += "├" if not self.reset else "│ │0⟩"

return _label

@property
def return_type(self):
return MidMeasure
Expand Down
54 changes: 54 additions & 0 deletions tests/drawer/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,60 @@ def circ():
spy.assert_called_once()


@pytest.mark.parametrize(
"postselect, reset, mid_measure_label",
[
(None, False, "┤↗├"),
(None, True, "┤↗│ │0⟩"),
(0, False, "┤↗₀├"),
(0, True, "┤↗₀│ │0⟩"),
(1, False, "┤↗₁├"),
(1, True, "┤↗₁│ │0⟩"),
],
)
def test_draw_mid_circuit_measurement(postselect, reset, mid_measure_label):
"""Test that mid-circuit measurements are drawn correctly."""

def func():
qml.Hadamard(0)
qml.measure(0, reset=reset, postselect=postselect)
qml.PauliX(0)
return qml.expval(qml.PauliZ(0))

drawing = qml.draw(func)()
expected_drawing = "0: ──H──" + mid_measure_label + "──X─┤ <Z>"

assert drawing == expected_drawing


def test_draw_mid_circuit_measurement_multiple_wires():
"""Test that mid-circuit measurements are correctly drawn in circuits
with multiple wires."""

def circ(weights):
qml.RX(weights[0], 0)
qml.measure(0, reset=True)
qml.RX(weights[1], 1)
qml.measure(1)
qml.CNOT([0, 3])
qml.measure(3, postselect=0, reset=True)
qml.RY(weights[2], 2)
qml.CNOT([1, 2])
qml.measure(2, postselect=1)
qml.MultiRZ(0.5, [0, 2])
return qml.expval(qml.PauliZ(2))

drawing = qml.draw(circ)(np.array([np.pi, 3.124, 0.456]))
expected_drawing = (
"0: ──RX(3.14)──┤↗│ │0⟩─╭●─────────────────────╭MultiRZ(0.50)─┤ \n"
"1: ──RX(3.12)──┤↗├──────│─────────────╭●───────│──────────────┤ \n"
"3: ─────────────────────╰X──┤↗₀│ │0⟩─│────────│──────────────┤ \n"
"2: ──RY(0.46)─────────────────────────╰X──┤↗₁├─╰MultiRZ(0.50)─┤ <Z>"
)

assert drawing == expected_drawing


@pytest.mark.parametrize(
"transform",
[
Expand Down
18 changes: 18 additions & 0 deletions tests/measurements/test_mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,24 @@ def test_hash(self):
assert m1.hash != m3.hash
assert m1.hash == m4.hash

@pytest.mark.parametrize(
"postselect, reset, expected",
[
(None, False, "┤↗├"),
(None, True, "┤↗│ │0⟩"),
(0, False, "┤↗₀├"),
(0, True, "┤↗₀│ │0⟩"),
(1, False, "┤↗₁├"),
(1, True, "┤↗₁│ │0⟩"),
],
)
def test_label(self, postselect, reset, expected):
"""Test that the label for a MidMeasureMP is correct"""
mp = MidMeasureMP(0, postselect=postselect, reset=reset)

label = mp.label()
assert label == expected


mp1 = MidMeasureMP(Wires(0), id="m0")
mp2 = MidMeasureMP(Wires(1), id="m1")
Expand Down

0 comments on commit 9dfa94c

Please sign in to comment.