diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 8b24103812c..e7cd3acc385 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -9,6 +9,9 @@

Improvements πŸ› 

+* `qml.draw` now supports drawing mid-circuit measurements. + [(#4775)](https://github.com/PennyLaneAI/pennylane/pull/4775) + * Autograd can now use vjps provided by the device from the new device API. If a device provides a vector Jacobian product, this can be selected by providing `device_vjp=True` to `qml.execute`. diff --git a/pennylane/drawer/utils.py b/pennylane/drawer/utils.py index 6a0a36dd1e7..603653eb8a2 100644 --- a/pennylane/drawer/utils.py +++ b/pennylane/drawer/utils.py @@ -77,7 +77,7 @@ def unwrap_controls(op): # Get wires and control values of base operation; need to make a copy of # control values, otherwise it will modify the list in the operation itself. control_wires = getattr(op, "control_wires", []) - control_values = op.hyperparameters.get("control_values", None) + control_values = getattr(op, "hyperparameters", {}).get("control_values", None) if isinstance(control_values, list): control_values = control_values.copy() diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py index 825c0283820..ae5734677a5 100644 --- a/pennylane/measurements/mid_measure.py +++ b/pennylane/measurements/mid_measure.py @@ -223,6 +223,28 @@ def __init__( self.reset = reset self.postselect = postselect + def label(self, decimals=None, base_label=None, cache=None): # pylint: disable=unused-argument + r"""How the mid-circuit measurement is represented in diagrams and drawings. + + Args: + decimals=None (Int): If ``None``, no parameters are included. Else, + how to round the parameters. + base_label=None (Iterable[str]): overwrite the non-parameter component of the label. + Must be same length as ``obs`` attribute. + cache=None (dict): dictionary that carries information between label calls + in the same drawing + + Returns: + str: label to use in drawings + """ + _label = "─↗" + if self.postselect is not None: + _label += "₁" if self.postselect == 1 else "β‚€" + + _label += "β”œ" if not self.reset else "β”‚ β”‚0⟩" + + return _label + @property def return_type(self): return MidMeasure diff --git a/tests/drawer/test_draw.py b/tests/drawer/test_draw.py index 05333550320..06cc0a646ff 100644 --- a/tests/drawer/test_draw.py +++ b/tests/drawer/test_draw.py @@ -300,6 +300,60 @@ def circ(): spy.assert_called_once() +@pytest.mark.parametrize( + "postselect, reset, mid_measure_label", + [ + (None, False, "β”€β†—β”œ"), + (None, True, "─↗│ β”‚0⟩"), + (0, False, "β”€β†—β‚€β”œ"), + (0, True, "─↗₀│ β”‚0⟩"), + (1, False, "β”€β†—β‚β”œ"), + (1, True, "─↗₁│ β”‚0⟩"), + ], +) +def test_draw_mid_circuit_measurement(postselect, reset, mid_measure_label): + """Test that mid-circuit measurements are drawn correctly.""" + + def func(): + qml.Hadamard(0) + qml.measure(0, reset=reset, postselect=postselect) + qml.PauliX(0) + return qml.expval(qml.PauliZ(0)) + + drawing = qml.draw(func)() + expected_drawing = "0: ──H──" + mid_measure_label + "──X── " + + assert drawing == expected_drawing + + +def test_draw_mid_circuit_measurement_multiple_wires(): + """Test that mid-circuit measurements are correctly drawn in circuits + with multiple wires.""" + + def circ(weights): + qml.RX(weights[0], 0) + qml.measure(0, reset=True) + qml.RX(weights[1], 1) + qml.measure(1) + qml.CNOT([0, 3]) + qml.measure(3, postselect=0, reset=True) + qml.RY(weights[2], 2) + qml.CNOT([1, 2]) + qml.measure(2, postselect=1) + qml.MultiRZ(0.5, [0, 2]) + return qml.expval(qml.PauliZ(2)) + + drawing = qml.draw(circ)(np.array([np.pi, 3.124, 0.456])) + expected_drawing = ( + "0: ──RX(3.14)───↗│ β”‚0βŸ©β”€β•­β—β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β•­MultiRZ(0.50)── \n" + "1: ──RX(3.12)β”€β”€β”€β†—β”œβ”€β”€β”€β”€β”€β”€β”‚β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β•­β—β”€β”€β”€β”€β”€β”€β”€β”‚β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ \n" + "3: ─────────────────────╰X───↗₀│ β”‚0βŸ©β”€β”‚β”€β”€β”€β”€β”€β”€β”€β”€β”‚β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ \n" + "2: ──RY(0.46)─────────────────────────╰Xβ”€β”€β”€β†—β‚β”œβ”€β•°MultiRZ(0.50)── " + ) + + assert drawing == expected_drawing + + @pytest.mark.parametrize( "transform", [ diff --git a/tests/measurements/test_mid_measure.py b/tests/measurements/test_mid_measure.py index 6aa70a21457..de192de5181 100644 --- a/tests/measurements/test_mid_measure.py +++ b/tests/measurements/test_mid_measure.py @@ -53,6 +53,24 @@ def test_hash(self): assert m1.hash != m3.hash assert m1.hash == m4.hash + @pytest.mark.parametrize( + "postselect, reset, expected", + [ + (None, False, "β”€β†—β”œ"), + (None, True, "─↗│ β”‚0⟩"), + (0, False, "β”€β†—β‚€β”œ"), + (0, True, "─↗₀│ β”‚0⟩"), + (1, False, "β”€β†—β‚β”œ"), + (1, True, "─↗₁│ β”‚0⟩"), + ], + ) + def test_label(self, postselect, reset, expected): + """Test that the label for a MidMeasureMP is correct""" + mp = MidMeasureMP(0, postselect=postselect, reset=reset) + + label = mp.label() + assert label == expected + mp1 = MidMeasureMP(Wires(0), id="m0") mp2 = MidMeasureMP(Wires(1), id="m1")