diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 3412e495bb..31a110e932 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -14,6 +14,9 @@
Bug fixes 🐛
-Contributors ✍️
+* Resolves a bug where calling `qml.counts()` within complicated and/or nested return expressions did not return the correct `PyTreeDef`.
+ [(#1219)](https://github.com/PennyLaneAI/catalyst/pull/1219)
-This release contains contributions from (in alphabetical order):
\ No newline at end of file
+Contributors ✍️
+This release contains contributions from (in alphabetical order):
+Arjun Bhamra
diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py
index 1bfb118626..b6f42fff91 100644
--- a/frontend/catalyst/jax_tracer.py
+++ b/frontend/catalyst/jax_tracer.py
@@ -844,6 +844,9 @@ def trace_quantum_measurements(
"""
shots = get_device_shots(device)
out_classical_tracers = []
+ # NOTE: Number of qml.counts() we hit, used to update our iteration variable to account
+ # for additional leaf PyTreeDef nodes.
+ num_counts = 0
for i, o in enumerate(outputs):
if isinstance(o, MeasurementProcess):
@@ -914,16 +917,8 @@ def trace_quantum_measurements(
results = (jnp.asarray(results[0], jnp.int64), results[1])
out_classical_tracers.extend(results)
counts_tree = tree_structure(("keys", "counts"))
- meas_return_trees_children = out_tree.children()
- if len(meas_return_trees_children):
- meas_return_trees_children[i] = counts_tree
- out_tree = out_tree.make_from_node_data_and_children(
- PyTreeRegistry(),
- out_tree.node_data(),
- meas_return_trees_children,
- )
- else:
- out_tree = counts_tree
+ num_counts += 1
+ out_tree = replace_child_tree(out_tree, i + num_counts, counts_tree)
elif isinstance(o, StateMP) and not isinstance(o, DensityMatrixMP):
assert using_compbasis
shape = (2**nqubits,)
@@ -941,6 +936,37 @@ def trace_quantum_measurements(
return out_classical_tracers, out_tree
+def replace_child_tree(tree: PyTreeDef, index: int, subtree: PyTreeDef) -> PyTreeDef:
+ """
+ Replace the index-th leaf node in a left-to-right depth-first tree traversal of a PyTreeDef
+ with a given subtree.
+
+ Args:
+ tree (PyTreeDef): The original PyTree.
+ index (int): The index of the leaf node to replace.
+ subtree (PyTreeDef): The new subtree to replace the original leaf node with.
+
+ Returns:
+ PyTreeDef: The modified PyTree with the replaced leaf node.
+ """
+
+ def replace_node(node, idx):
+ if not node.children():
+ # Leaf node => update leaf node counter
+ idx[0] += 1
+ if idx[0] == index:
+ return subtree
+ return node
+
+ return node.make_from_node_data_and_children(
+ PyTreeRegistry(),
+ node.node_data(),
+ [replace_node(child, idx) for child in node.children()],
+ )
+
+ return replace_node(tree, [0])
+
+
@debug_logger
def is_transform_valid_for_batch_transforms(tape, flat_results):
"""Not all transforms are valid for batch transforms.
diff --git a/frontend/test/pytest/test_pytree_args.py b/frontend/test/pytest/test_pytree_args.py
index b8f1ab1e61..ff55b45a81 100644
--- a/frontend/test/pytest/test_pytree_args.py
+++ b/frontend/test/pytest/test_pytree_args.py
@@ -593,5 +593,169 @@ def classical(x):
assert result.a == 4
+class TestPyTreesQmlCounts:
+ """Test QJIT workflows when using qml.counts in a return expression."""
+
+ def test_pytree_qml_counts_simple(self):
+ """Test if a single qml.counts() can be used and output correctly."""
+ dev = qml.device("lightning.qubit", wires=1, shots=20)
+
+ @qjit
+ @qml.qnode(dev)
+ def circuit(x):
+ qml.RX(x, wires=0)
+ return {"1": qml.counts()}
+
+ observed = circuit(0.5)
+ expected = {"1": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}
+
+ _, expected_shape = tree_flatten(expected)
+ _, observed_shape = tree_flatten(observed)
+ assert expected_shape == observed_shape
+
+ def test_pytree_qml_counts_nested(self):
+ """Test if nested qml.counts() can be used and output correctly."""
+ dev = qml.device("lightning.qubit", wires=1, shots=20)
+
+ @qjit
+ @qml.qnode(dev)
+ def circuit(x):
+ qml.RX(x, wires=0)
+ return {"1": qml.counts()}, {"2": qml.expval(qml.Z(0))}
+
+ observed = circuit(0.5)
+ expected = (
+ {"1": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))},
+ {"2": jnp.array(-1, dtype=jnp.float64)},
+ )
+
+ _, expected_shape = tree_flatten(expected)
+ _, observed_shape = tree_flatten(observed)
+ assert expected_shape == observed_shape
+
+ @qjit
+ @qml.qnode(dev)
+ def circuit2(x):
+ qml.RX(x, wires=0)
+ return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], {"3": qml.expval(qml.Z(0))}
+
+ observed = circuit2(0.5)
+ expected = (
+ [
+ {"1": jnp.array(-1, dtype=jnp.float64)},
+ {"2": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))},
+ ],
+ {"3": jnp.array(-1, dtype=jnp.float64)},
+ )
+ _, expected_shape = tree_flatten(expected)
+ _, observed_shape = tree_flatten(observed)
+ assert expected_shape == observed_shape
+
+ def test_pytree_qml_counts_2_nested(self):
+ """Test if multiple nested qml.counts() can be used and output correctly."""
+ dev = qml.device("lightning.qubit", wires=1, shots=20)
+
+ @qjit
+ @qml.qnode(dev)
+ def circuit(x):
+ qml.RX(x, wires=0)
+ return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [
+ {"3": qml.expval(qml.Z(0))},
+ {"4": qml.counts()},
+ ]
+
+ observed = circuit(0.5)
+ expected = (
+ [
+ {"1": jnp.array(-1, dtype=jnp.float64)},
+ {"2": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))},
+ ],
+ [
+ {"3": jnp.array(-1, dtype=jnp.float64)},
+ {"4": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))},
+ ],
+ )
+ _, expected_shape = tree_flatten(expected)
+ _, observed_shape = tree_flatten(observed)
+ assert expected_shape == observed_shape
+
+ @qjit
+ @qml.qnode(dev)
+ def circuit2(x):
+ qml.RX(x, wires=0)
+ return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [
+ {"3": qml.counts()},
+ {"4": qml.expval(qml.Z(0))},
+ ]
+
+ observed = circuit2(0.5)
+ expected = (
+ [
+ {"1": jnp.array(-1, dtype=jnp.float64)},
+ {"2": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))},
+ ],
+ [
+ {"3": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))},
+ {"4": jnp.array(-1, dtype=jnp.float64)},
+ ],
+ )
+ _, expected_shape = tree_flatten(expected)
+ _, observed_shape = tree_flatten(observed)
+ assert expected_shape == observed_shape
+
+ def test_pytree_qml_counts_longer(self):
+ """Test if 3 differently nested qml.counts() can be used and output correctly."""
+ dev = qml.device("lightning.qubit", wires=1, shots=20)
+
+ @qjit
+ @qml.qnode(dev)
+ def circuit(x):
+ qml.RX(x, wires=0)
+ return [
+ [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}],
+ [{"3": qml.expval(qml.Z(0))}, {"4": qml.counts()}],
+ {"5": qml.expval(qml.Z(0))},
+ {"6": qml.counts()},
+ ]
+
+ observed = circuit(0.5)
+ expected = [
+ [
+ {"1": jnp.array(-1, dtype=jnp.float64)},
+ {"2": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))},
+ ],
+ [
+ {"3": jnp.array(-1, dtype=jnp.float64)},
+ {"4": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))},
+ ],
+ {"5": jnp.array(-1, dtype=jnp.float64)},
+ {"6": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))},
+ ]
+ _, expected_shape = tree_flatten(expected)
+ _, observed_shape = tree_flatten(observed)
+ assert expected_shape == observed_shape
+
+ def test_pytree_qml_counts_mcm(self):
+ """Test qml.counts() with mid-circuit measurement."""
+ dev = qml.device("lightning.qubit", wires=1, shots=20)
+
+ @qml.qjit
+ @qml.qnode(dev, mcm_method="one-shot", postselect_mode=None)
+ def circuit(x):
+ qml.RX(x, wires=0)
+ measure(0, postselect=1)
+ return {"hi": qml.counts()}, {"bye": qml.expval(qml.Z(0))}, {"hi": qml.counts()}
+
+ observed = circuit(0.5)
+ expected = (
+ {"hi": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))},
+ {"bye": jnp.array(-1, dtype=jnp.float64)},
+ {"hi": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))},
+ )
+ _, expected_shape = tree_flatten(expected)
+ _, observed_shape = tree_flatten(observed)
+ assert expected_shape == observed_shape
+
+
if __name__ == "__main__":
pytest.main(["-x", __file__])