diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 8b24103812c..e7cd3acc385 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -9,6 +9,9 @@
Improvements π
+* `qml.draw` now supports drawing mid-circuit measurements.
+ [(#4775)](https://github.com/PennyLaneAI/pennylane/pull/4775)
+
* Autograd can now use vjps provided by the device from the new device API. If a device provides
a vector Jacobian product, this can be selected by providing `device_vjp=True` to
`qml.execute`.
diff --git a/pennylane/drawer/utils.py b/pennylane/drawer/utils.py
index 6a0a36dd1e7..603653eb8a2 100644
--- a/pennylane/drawer/utils.py
+++ b/pennylane/drawer/utils.py
@@ -77,7 +77,7 @@ def unwrap_controls(op):
# Get wires and control values of base operation; need to make a copy of
# control values, otherwise it will modify the list in the operation itself.
control_wires = getattr(op, "control_wires", [])
- control_values = op.hyperparameters.get("control_values", None)
+ control_values = getattr(op, "hyperparameters", {}).get("control_values", None)
if isinstance(control_values, list):
control_values = control_values.copy()
diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py
index 825c0283820..ae5734677a5 100644
--- a/pennylane/measurements/mid_measure.py
+++ b/pennylane/measurements/mid_measure.py
@@ -223,6 +223,28 @@ def __init__(
self.reset = reset
self.postselect = postselect
+ def label(self, decimals=None, base_label=None, cache=None): # pylint: disable=unused-argument
+ r"""How the mid-circuit measurement is represented in diagrams and drawings.
+
+ Args:
+ decimals=None (Int): If ``None``, no parameters are included. Else,
+ how to round the parameters.
+ base_label=None (Iterable[str]): overwrite the non-parameter component of the label.
+ Must be same length as ``obs`` attribute.
+ cache=None (dict): dictionary that carries information between label calls
+ in the same drawing
+
+ Returns:
+ str: label to use in drawings
+ """
+ _label = "β€β"
+ if self.postselect is not None:
+ _label += "β" if self.postselect == 1 else "β"
+
+ _label += "β" if not self.reset else "β β0β©"
+
+ return _label
+
@property
def return_type(self):
return MidMeasure
diff --git a/tests/drawer/test_draw.py b/tests/drawer/test_draw.py
index 05333550320..06cc0a646ff 100644
--- a/tests/drawer/test_draw.py
+++ b/tests/drawer/test_draw.py
@@ -300,6 +300,60 @@ def circ():
spy.assert_called_once()
+@pytest.mark.parametrize(
+ "postselect, reset, mid_measure_label",
+ [
+ (None, False, "β€ββ"),
+ (None, True, "β€ββ β0β©"),
+ (0, False, "β€βββ"),
+ (0, True, "β€βββ β0β©"),
+ (1, False, "β€βββ"),
+ (1, True, "β€βββ β0β©"),
+ ],
+)
+def test_draw_mid_circuit_measurement(postselect, reset, mid_measure_label):
+ """Test that mid-circuit measurements are drawn correctly."""
+
+ def func():
+ qml.Hadamard(0)
+ qml.measure(0, reset=reset, postselect=postselect)
+ qml.PauliX(0)
+ return qml.expval(qml.PauliZ(0))
+
+ drawing = qml.draw(func)()
+ expected_drawing = "0: ββHββ" + mid_measure_label + "ββXββ€ "
+
+ assert drawing == expected_drawing
+
+
+def test_draw_mid_circuit_measurement_multiple_wires():
+ """Test that mid-circuit measurements are correctly drawn in circuits
+ with multiple wires."""
+
+ def circ(weights):
+ qml.RX(weights[0], 0)
+ qml.measure(0, reset=True)
+ qml.RX(weights[1], 1)
+ qml.measure(1)
+ qml.CNOT([0, 3])
+ qml.measure(3, postselect=0, reset=True)
+ qml.RY(weights[2], 2)
+ qml.CNOT([1, 2])
+ qml.measure(2, postselect=1)
+ qml.MultiRZ(0.5, [0, 2])
+ return qml.expval(qml.PauliZ(2))
+
+ drawing = qml.draw(circ)(np.array([np.pi, 3.124, 0.456]))
+ expected_drawing = (
+ "0: ββRX(3.14)βββ€ββ β0β©βββββββββββββββββββββββββMultiRZ(0.50)ββ€ \n"
+ "1: ββRX(3.12)βββ€βββββββββββββββββββββββββββββββββββββββββββββββ€ \n"
+ "3: ββββββββββββββββββββββ°Xβββ€βββ β0β©ββββββββββββββββββββββββββ€ \n"
+ "2: ββRY(0.46)ββββββββββββββββββββββββββ°Xβββ€βββββ°MultiRZ(0.50)ββ€ "
+ )
+
+ assert drawing == expected_drawing
+
+
@pytest.mark.parametrize(
"transform",
[
diff --git a/tests/measurements/test_mid_measure.py b/tests/measurements/test_mid_measure.py
index 6aa70a21457..de192de5181 100644
--- a/tests/measurements/test_mid_measure.py
+++ b/tests/measurements/test_mid_measure.py
@@ -53,6 +53,24 @@ def test_hash(self):
assert m1.hash != m3.hash
assert m1.hash == m4.hash
+ @pytest.mark.parametrize(
+ "postselect, reset, expected",
+ [
+ (None, False, "β€ββ"),
+ (None, True, "β€ββ β0β©"),
+ (0, False, "β€βββ"),
+ (0, True, "β€βββ β0β©"),
+ (1, False, "β€βββ"),
+ (1, True, "β€βββ β0β©"),
+ ],
+ )
+ def test_label(self, postselect, reset, expected):
+ """Test that the label for a MidMeasureMP is correct"""
+ mp = MidMeasureMP(0, postselect=postselect, reset=reset)
+
+ label = mp.label()
+ assert label == expected
+
mp1 = MidMeasureMP(Wires(0), id="m0")
mp2 = MidMeasureMP(Wires(1), id="m1")