diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3412e495bb..31a110e932 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -14,6 +14,9 @@

Bug fixes 🐛

-

Contributors ✍️

+* Resolves a bug where calling `qml.counts()` within complicated and/or nested return expressions did not return the correct `PyTreeDef`. + [(#1219)](https://github.com/PennyLaneAI/catalyst/pull/1219) -This release contains contributions from (in alphabetical order): \ No newline at end of file +

Contributors ✍️

+This release contains contributions from (in alphabetical order): +Arjun Bhamra diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 1bfb118626..b6f42fff91 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -844,6 +844,9 @@ def trace_quantum_measurements( """ shots = get_device_shots(device) out_classical_tracers = [] + # NOTE: Number of qml.counts() we hit, used to update our iteration variable to account + # for additional leaf PyTreeDef nodes. + num_counts = 0 for i, o in enumerate(outputs): if isinstance(o, MeasurementProcess): @@ -914,16 +917,8 @@ def trace_quantum_measurements( results = (jnp.asarray(results[0], jnp.int64), results[1]) out_classical_tracers.extend(results) counts_tree = tree_structure(("keys", "counts")) - meas_return_trees_children = out_tree.children() - if len(meas_return_trees_children): - meas_return_trees_children[i] = counts_tree - out_tree = out_tree.make_from_node_data_and_children( - PyTreeRegistry(), - out_tree.node_data(), - meas_return_trees_children, - ) - else: - out_tree = counts_tree + num_counts += 1 + out_tree = replace_child_tree(out_tree, i + num_counts, counts_tree) elif isinstance(o, StateMP) and not isinstance(o, DensityMatrixMP): assert using_compbasis shape = (2**nqubits,) @@ -941,6 +936,37 @@ def trace_quantum_measurements( return out_classical_tracers, out_tree +def replace_child_tree(tree: PyTreeDef, index: int, subtree: PyTreeDef) -> PyTreeDef: + """ + Replace the index-th leaf node in a left-to-right depth-first tree traversal of a PyTreeDef + with a given subtree. + + Args: + tree (PyTreeDef): The original PyTree. + index (int): The index of the leaf node to replace. + subtree (PyTreeDef): The new subtree to replace the original leaf node with. + + Returns: + PyTreeDef: The modified PyTree with the replaced leaf node. + """ + + def replace_node(node, idx): + if not node.children(): + # Leaf node => update leaf node counter + idx[0] += 1 + if idx[0] == index: + return subtree + return node + + return node.make_from_node_data_and_children( + PyTreeRegistry(), + node.node_data(), + [replace_node(child, idx) for child in node.children()], + ) + + return replace_node(tree, [0]) + + @debug_logger def is_transform_valid_for_batch_transforms(tape, flat_results): """Not all transforms are valid for batch transforms. diff --git a/frontend/test/pytest/test_pytree_args.py b/frontend/test/pytest/test_pytree_args.py index b8f1ab1e61..ff55b45a81 100644 --- a/frontend/test/pytest/test_pytree_args.py +++ b/frontend/test/pytest/test_pytree_args.py @@ -593,5 +593,169 @@ def classical(x): assert result.a == 4 +class TestPyTreesQmlCounts: + """Test QJIT workflows when using qml.counts in a return expression.""" + + def test_pytree_qml_counts_simple(self): + """Test if a single qml.counts() can be used and output correctly.""" + dev = qml.device("lightning.qubit", wires=1, shots=20) + + @qjit + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return {"1": qml.counts()} + + observed = circuit(0.5) + expected = {"1": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))} + + _, expected_shape = tree_flatten(expected) + _, observed_shape = tree_flatten(observed) + assert expected_shape == observed_shape + + def test_pytree_qml_counts_nested(self): + """Test if nested qml.counts() can be used and output correctly.""" + dev = qml.device("lightning.qubit", wires=1, shots=20) + + @qjit + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return {"1": qml.counts()}, {"2": qml.expval(qml.Z(0))} + + observed = circuit(0.5) + expected = ( + {"1": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + {"2": jnp.array(-1, dtype=jnp.float64)}, + ) + + _, expected_shape = tree_flatten(expected) + _, observed_shape = tree_flatten(observed) + assert expected_shape == observed_shape + + @qjit + @qml.qnode(dev) + def circuit2(x): + qml.RX(x, wires=0) + return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], {"3": qml.expval(qml.Z(0))} + + observed = circuit2(0.5) + expected = ( + [ + {"1": jnp.array(-1, dtype=jnp.float64)}, + {"2": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ], + {"3": jnp.array(-1, dtype=jnp.float64)}, + ) + _, expected_shape = tree_flatten(expected) + _, observed_shape = tree_flatten(observed) + assert expected_shape == observed_shape + + def test_pytree_qml_counts_2_nested(self): + """Test if multiple nested qml.counts() can be used and output correctly.""" + dev = qml.device("lightning.qubit", wires=1, shots=20) + + @qjit + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [ + {"3": qml.expval(qml.Z(0))}, + {"4": qml.counts()}, + ] + + observed = circuit(0.5) + expected = ( + [ + {"1": jnp.array(-1, dtype=jnp.float64)}, + {"2": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ], + [ + {"3": jnp.array(-1, dtype=jnp.float64)}, + {"4": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ], + ) + _, expected_shape = tree_flatten(expected) + _, observed_shape = tree_flatten(observed) + assert expected_shape == observed_shape + + @qjit + @qml.qnode(dev) + def circuit2(x): + qml.RX(x, wires=0) + return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], [ + {"3": qml.counts()}, + {"4": qml.expval(qml.Z(0))}, + ] + + observed = circuit2(0.5) + expected = ( + [ + {"1": jnp.array(-1, dtype=jnp.float64)}, + {"2": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ], + [ + {"3": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + {"4": jnp.array(-1, dtype=jnp.float64)}, + ], + ) + _, expected_shape = tree_flatten(expected) + _, observed_shape = tree_flatten(observed) + assert expected_shape == observed_shape + + def test_pytree_qml_counts_longer(self): + """Test if 3 differently nested qml.counts() can be used and output correctly.""" + dev = qml.device("lightning.qubit", wires=1, shots=20) + + @qjit + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return [ + [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], + [{"3": qml.expval(qml.Z(0))}, {"4": qml.counts()}], + {"5": qml.expval(qml.Z(0))}, + {"6": qml.counts()}, + ] + + observed = circuit(0.5) + expected = [ + [ + {"1": jnp.array(-1, dtype=jnp.float64)}, + {"2": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ], + [ + {"3": jnp.array(-1, dtype=jnp.float64)}, + {"4": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ], + {"5": jnp.array(-1, dtype=jnp.float64)}, + {"6": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ] + _, expected_shape = tree_flatten(expected) + _, observed_shape = tree_flatten(observed) + assert expected_shape == observed_shape + + def test_pytree_qml_counts_mcm(self): + """Test qml.counts() with mid-circuit measurement.""" + dev = qml.device("lightning.qubit", wires=1, shots=20) + + @qml.qjit + @qml.qnode(dev, mcm_method="one-shot", postselect_mode=None) + def circuit(x): + qml.RX(x, wires=0) + measure(0, postselect=1) + return {"hi": qml.counts()}, {"bye": qml.expval(qml.Z(0))}, {"hi": qml.counts()} + + observed = circuit(0.5) + expected = ( + {"hi": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + {"bye": jnp.array(-1, dtype=jnp.float64)}, + {"hi": (jnp.array((0, 1), dtype=jnp.int64), jnp.array((0, 3), dtype=jnp.int64))}, + ) + _, expected_shape = tree_flatten(expected) + _, observed_shape = tree_flatten(observed) + assert expected_shape == observed_shape + + if __name__ == "__main__": pytest.main(["-x", __file__])