From f016a31f69d1b8a84bc9612af1bc64f0575506e9 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Thu, 21 Mar 2024 14:47:59 -0400 Subject: [PATCH] Transform: measurements to counts (#608) **Context:** Some hardware only return counts, and we currently does not support any other measurements for those devices. **Description of the Change:** We add a transforms that is applied in the pre-processing step of the new device API. This transform replace all measurements by a single count measurement (the transform assumes that measurements are commuting). Then post processing functions are applied in order to get the original measurements like expval, var, and probs. **Benefits:** We can use more measurements on HW supporting counts only. --------- Co-authored-by: Sergei Mironov Co-authored-by: David Ittah --- frontend/catalyst/jax_tracer.py | 5 +- frontend/catalyst/preprocess.py | 111 +++++++++++++++++++++++- frontend/test/pytest/test_preprocess.py | 99 ++++++++++++++++++++- 3 files changed, 211 insertions(+), 4 deletions(-) diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 2daecf35f4..4ef37081db 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -612,9 +612,11 @@ def trace_quantum_measurements( m_wires = o.wires if o.wires else range(device.num_wires) else: m_wires = o.wires if o.wires else range(len(tape.wires)) + obs_tracers, nqubits = trace_observables(o.obs, qrp, m_wires) using_compbasis = obs_tracers.primitive == compbasis_p + if o.return_type.value == "sample": shape = (shots, nqubits) if using_compbasis else (shots,) out_classical_tracers.append(sample_p.bind(obs_tracers, shots=shots, shape=shape)) @@ -702,7 +704,6 @@ def is_midcircuit_measurement(op): def apply_transform(transform_program, tape, flat_results): """Apply transform.""" - # Some transforms use trainability as a basis for transforming. # See batch_params params = tape.get_parameters(trainable_only=False) @@ -827,7 +828,7 @@ def trace_quantum_function( with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: # (1) - Classical tracing - quantum_tape = QuantumTape() + quantum_tape = QuantumTape(shots=device.shots) with EvaluationContext.frame_tracing_context(ctx) as trace: wffa, in_avals, keep_inputs, out_tree_promise = deduce_avals(f, args, kwargs) in_classical_tracers = _input_type_to_tracers(trace.new_arg, in_avals) diff --git a/frontend/catalyst/preprocess.py b/frontend/catalyst/preprocess.py index 95b1e43069..580af052ac 100644 --- a/frontend/catalyst/preprocess.py +++ b/frontend/catalyst/preprocess.py @@ -13,9 +13,14 @@ # limitations under the License. """This module contains the preprocessing functions. """ - +import jax import pennylane as qml from pennylane import transform +from pennylane.measurements import CountsMP, ExpectationMP, ProbabilityMP, VarianceMP +from pennylane.tape.tape import ( + _validate_computational_basis_sampling, + rotations_and_diagonal_measurements, +) import catalyst from catalyst.utils.exceptions import CompileError @@ -60,3 +65,107 @@ def null_postprocessing(results): def catalyst_acceptance(op: qml.operation.Operator, operations) -> bool: """Specify whether or not an Operator is supported.""" return op.name in operations + + +@transform +def measurements_from_counts(tape): + r"""Replace all measurements from a tape with a single count measurement, it adds postprocessing + functions for each original measurement. + + Args: + tape (QNode or QuantumTape or Callable): A quantum circuit. + + Returns: + qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]: The + transformed circuit as described in :func:`qml.transform `. + + .. note:: + + Samples are not supported. + """ + if tape.samples_computational_basis and len(tape.measurements) > 1: + _validate_computational_basis_sampling(tape.measurements) + diagonalizing_gates, diagonal_measurements = rotations_and_diagonal_measurements(tape) + for i, m in enumerate(diagonal_measurements): + if m.obs is not None: + diagonalizing_gates.extend(m.obs.diagonalizing_gates()) + diagonal_measurements[i] = type(m)(eigvals=m.eigvals(), wires=m.wires) + # Add diagonalizing gates + news_operations = tape.operations + news_operations.extend(diagonalizing_gates) + # Transform tape + measured_wires = set() + for m in diagonal_measurements: + measured_wires.update(m.wires.tolist()) + + new_measurements = [qml.counts(wires=list(measured_wires))] + new_tape = type(tape)(news_operations, new_measurements, shots=tape.shots) + + def postprocessing_counts_to_expval(results): + """A processing function to get expecation values from counts.""" + states = results[0][0] + counts_outcomes = results[0][1] + results_processed = [] + for m in tape.measurements: + mapped_counts_outcome = _map_counts( + counts_outcomes, m.wires, qml.wires.Wires(list(measured_wires)) + ) + if isinstance(m, ExpectationMP): + probs = _get_probs(mapped_counts_outcome) + results_processed.append(_get_expval(eigvals=m.eigvals(), prob_vector=probs)) + elif isinstance(m, VarianceMP): + probs = _get_probs(mapped_counts_outcome) + results_processed.append(_get_var(eigvals=m.eigvals(), prob_vector=probs)) + elif isinstance(m, ProbabilityMP): + probs = _get_probs(mapped_counts_outcome) + results_processed.append(probs) + elif isinstance(m, CountsMP): + results_processed.append( + tuple([states[0 : 2 ** len(m.wires)], mapped_counts_outcome]) + ) + if len(tape.measurements) == 1: + results_processed = results_processed[0] + else: + results_processed = tuple(results_processed) + return results_processed + + return [new_tape], postprocessing_counts_to_expval + + +def _get_probs(counts_outcome): + """From the counts outcome, calculate the probability vector.""" + prob_vector = [] + num_shots = jax.numpy.sum(counts_outcome) + for count in counts_outcome: + prob = count / num_shots + prob_vector.append(prob) + return jax.numpy.array(prob_vector) + + +def _get_expval(eigvals, prob_vector): + """From the observable eigenvalues and the probability vector + it calculates the expectation value.""" + expval = jax.numpy.dot(jax.numpy.array(eigvals), prob_vector) + return expval + + +def _get_var(eigvals, prob_vector): + """From the observable eigenvalues and the probability vector + it calculates the variance.""" + var = jax.numpy.dot(prob_vector, (eigvals**2)) - jax.numpy.dot(prob_vector, eigvals) + return var + + +def _map_counts(counts, sub_wires, wire_order): + """Map the count outcome given a wires and wire order.""" + wire_map = dict(zip(wire_order, range(len(wire_order)))) + mapped_wires = [wire_map[w] for w in sub_wires] + + mapped_counts = {} + num_wires = len(wire_order) + for outcome, occurrence in enumerate(counts): + binary_outcome = format(outcome, f"0{num_wires}b") + mapped_outcome = "".join(binary_outcome[i] for i in mapped_wires) + mapped_counts[mapped_outcome] = mapped_counts.get(mapped_outcome, 0) + occurrence + + return jax.numpy.array(list(mapped_counts.values())) diff --git a/frontend/test/pytest/test_preprocess.py b/frontend/test/pytest/test_preprocess.py index f88eb4f8e8..aee97a3eee 100644 --- a/frontend/test/pytest/test_preprocess.py +++ b/frontend/test/pytest/test_preprocess.py @@ -13,6 +13,7 @@ # limitations under the License. """Test for the device preprocessing. """ +# pylint: disable=unused-argument import pathlib import numpy as np @@ -25,7 +26,7 @@ from catalyst import CompileError, ctrl from catalyst.compiler import get_lib_path -from catalyst.preprocess import decompose_ops_to_unitary +from catalyst.preprocess import decompose_ops_to_unitary, measurements_from_counts class DummyDevice(Device): @@ -141,6 +142,102 @@ def f(): with pytest.raises(CompileError, match="could not be decomposed, it might be unsupported"): qml.qjit(f, target="jaxpr") + @pytest.mark.skipif( + not pathlib.Path( + get_lib_path("runtime", "RUNTIME_LIB_DIR") + "/libdummy_device.so" + ).is_file(), + reason="lib_dummydevice.so was not found.", + ) + def test_measurement_from_counts_integration_multiple_measurements(self): + """Test the measurment from counts transform as part of the Catalyst pipeline.""" + dev = DummyDevice(wires=4, shots=1000) + + @qml.qjit + @measurements_from_counts + @qml.qnode(dev) + def circuit(theta: float): + qml.X(0) + qml.X(1) + qml.X(2) + qml.X(3) + return ( + qml.expval(qml.PauliX(wires=0) @ qml.PauliX(wires=1)), + qml.var(qml.PauliX(wires=0) @ qml.PauliX(wires=2)), + qml.counts(qml.PauliX(wires=0) @ qml.PauliX(wires=1) @ qml.PauliX(wires=2)), + ) + + mlir = qml.qjit(circuit, target="mlir").mlir + assert "expval" not in mlir + assert "var" not in mlir + assert "counts" in mlir + + @pytest.mark.skipif( + not pathlib.Path( + get_lib_path("runtime", "RUNTIME_LIB_DIR") + "/libdummy_device.so" + ).is_file(), + reason="lib_dummydevice.so was not found.", + ) + def test_measurement_from_counts_integration_single_measurement(self): + """Test the measurment from counts transform with a single measurements as part of + the Catalyst pipeline.""" + dev = DummyDevice(wires=4, shots=1000) + + @qml.qjit + @measurements_from_counts + @qml.qnode(dev) + def circuit(theta: float): + qml.X(0) + qml.X(1) + qml.X(2) + qml.X(3) + return qml.expval(qml.PauliX(wires=0) @ qml.PauliX(wires=1)) + + mlir = qml.qjit(circuit, target="mlir").mlir + assert "expval" not in mlir + assert "counts" in mlir + + +class TestTransform: + """Test the transforms implemented in Catalyst.""" + + def test_measurements_from_counts(self): + """Test the transfom measurements_from_counts.""" + device = qml.device("lightning.qubit", wires=4, shots=1000) + + @qml.qjit + @measurements_from_counts + @qml.qnode(device=device) + def circuit(a: float): + qml.X(0) + qml.X(1) + qml.X(2) + qml.X(3) + return ( + qml.expval(qml.PauliX(wires=0) @ qml.PauliX(wires=1)), + qml.var(qml.PauliX(wires=0) @ qml.PauliX(wires=2)), + qml.probs(wires=[3]), + qml.counts(qml.PauliX(wires=0) @ qml.PauliX(wires=1) @ qml.PauliX(wires=2)), + ) + + res = circuit(0.2) + results = res[0] + + assert isinstance(results, tuple) + assert len(results) == 4 + + expval = results[0] + var = results[1] + probs = results[2] + counts = results[3] + + assert expval.shape == () + assert var.shape == () + assert probs.shape == (2,) + assert isinstance(counts, tuple) + assert len(counts) == 2 + assert counts[0].shape == (8,) + assert counts[1].shape == (8,) + if __name__ == "__main__": pytest.main(["-x", __file__])