Skip to content

Commit

Permalink
Transform: measurements to counts (#608)
Browse files Browse the repository at this point in the history
**Context:**
Some hardware only return counts, and we currently does not support any
other measurements for those devices.

**Description of the Change:**
We add a transforms that is applied in the pre-processing step of the
new device API. This transform replace all measurements by a single
count measurement (the transform assumes that measurements are
commuting). Then post processing functions are applied in order to get
the original measurements like expval, var, and probs.

**Benefits:**
We can use more measurements on HW supporting counts only.

---------

Co-authored-by: Sergei Mironov <[email protected]>
Co-authored-by: David Ittah <[email protected]>
  • Loading branch information
3 people authored Mar 21, 2024
1 parent 21e0f5f commit f016a31
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 4 deletions.
5 changes: 3 additions & 2 deletions frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,9 +612,11 @@ def trace_quantum_measurements(
m_wires = o.wires if o.wires else range(device.num_wires)
else:
m_wires = o.wires if o.wires else range(len(tape.wires))

obs_tracers, nqubits = trace_observables(o.obs, qrp, m_wires)

using_compbasis = obs_tracers.primitive == compbasis_p

if o.return_type.value == "sample":
shape = (shots, nqubits) if using_compbasis else (shots,)
out_classical_tracers.append(sample_p.bind(obs_tracers, shots=shots, shape=shape))
Expand Down Expand Up @@ -702,7 +704,6 @@ def is_midcircuit_measurement(op):

def apply_transform(transform_program, tape, flat_results):
"""Apply transform."""

# Some transforms use trainability as a basis for transforming.
# See batch_params
params = tape.get_parameters(trainable_only=False)
Expand Down Expand Up @@ -827,7 +828,7 @@ def trace_quantum_function(

with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx:
# (1) - Classical tracing
quantum_tape = QuantumTape()
quantum_tape = QuantumTape(shots=device.shots)
with EvaluationContext.frame_tracing_context(ctx) as trace:
wffa, in_avals, keep_inputs, out_tree_promise = deduce_avals(f, args, kwargs)
in_classical_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
Expand Down
111 changes: 110 additions & 1 deletion frontend/catalyst/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,14 @@
# limitations under the License.
"""This module contains the preprocessing functions.
"""

import jax
import pennylane as qml
from pennylane import transform
from pennylane.measurements import CountsMP, ExpectationMP, ProbabilityMP, VarianceMP
from pennylane.tape.tape import (
_validate_computational_basis_sampling,
rotations_and_diagonal_measurements,
)

import catalyst
from catalyst.utils.exceptions import CompileError
Expand Down Expand Up @@ -60,3 +65,107 @@ def null_postprocessing(results):
def catalyst_acceptance(op: qml.operation.Operator, operations) -> bool:
"""Specify whether or not an Operator is supported."""
return op.name in operations


@transform
def measurements_from_counts(tape):
r"""Replace all measurements from a tape with a single count measurement, it adds postprocessing
functions for each original measurement.
Args:
tape (QNode or QuantumTape or Callable): A quantum circuit.
Returns:
qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]: The
transformed circuit as described in :func:`qml.transform <pennylane.transform>`.
.. note::
Samples are not supported.
"""
if tape.samples_computational_basis and len(tape.measurements) > 1:
_validate_computational_basis_sampling(tape.measurements)
diagonalizing_gates, diagonal_measurements = rotations_and_diagonal_measurements(tape)
for i, m in enumerate(diagonal_measurements):
if m.obs is not None:
diagonalizing_gates.extend(m.obs.diagonalizing_gates())
diagonal_measurements[i] = type(m)(eigvals=m.eigvals(), wires=m.wires)
# Add diagonalizing gates
news_operations = tape.operations
news_operations.extend(diagonalizing_gates)
# Transform tape
measured_wires = set()
for m in diagonal_measurements:
measured_wires.update(m.wires.tolist())

new_measurements = [qml.counts(wires=list(measured_wires))]
new_tape = type(tape)(news_operations, new_measurements, shots=tape.shots)

def postprocessing_counts_to_expval(results):
"""A processing function to get expecation values from counts."""
states = results[0][0]
counts_outcomes = results[0][1]
results_processed = []
for m in tape.measurements:
mapped_counts_outcome = _map_counts(
counts_outcomes, m.wires, qml.wires.Wires(list(measured_wires))
)
if isinstance(m, ExpectationMP):
probs = _get_probs(mapped_counts_outcome)
results_processed.append(_get_expval(eigvals=m.eigvals(), prob_vector=probs))
elif isinstance(m, VarianceMP):
probs = _get_probs(mapped_counts_outcome)
results_processed.append(_get_var(eigvals=m.eigvals(), prob_vector=probs))
elif isinstance(m, ProbabilityMP):
probs = _get_probs(mapped_counts_outcome)
results_processed.append(probs)
elif isinstance(m, CountsMP):
results_processed.append(
tuple([states[0 : 2 ** len(m.wires)], mapped_counts_outcome])
)
if len(tape.measurements) == 1:
results_processed = results_processed[0]
else:
results_processed = tuple(results_processed)
return results_processed

return [new_tape], postprocessing_counts_to_expval


def _get_probs(counts_outcome):
"""From the counts outcome, calculate the probability vector."""
prob_vector = []
num_shots = jax.numpy.sum(counts_outcome)
for count in counts_outcome:
prob = count / num_shots
prob_vector.append(prob)
return jax.numpy.array(prob_vector)


def _get_expval(eigvals, prob_vector):
"""From the observable eigenvalues and the probability vector
it calculates the expectation value."""
expval = jax.numpy.dot(jax.numpy.array(eigvals), prob_vector)
return expval


def _get_var(eigvals, prob_vector):
"""From the observable eigenvalues and the probability vector
it calculates the variance."""
var = jax.numpy.dot(prob_vector, (eigvals**2)) - jax.numpy.dot(prob_vector, eigvals)
return var


def _map_counts(counts, sub_wires, wire_order):
"""Map the count outcome given a wires and wire order."""
wire_map = dict(zip(wire_order, range(len(wire_order))))
mapped_wires = [wire_map[w] for w in sub_wires]

mapped_counts = {}
num_wires = len(wire_order)
for outcome, occurrence in enumerate(counts):
binary_outcome = format(outcome, f"0{num_wires}b")
mapped_outcome = "".join(binary_outcome[i] for i in mapped_wires)
mapped_counts[mapped_outcome] = mapped_counts.get(mapped_outcome, 0) + occurrence

return jax.numpy.array(list(mapped_counts.values()))
99 changes: 98 additions & 1 deletion frontend/test/pytest/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Test for the device preprocessing.
"""
# pylint: disable=unused-argument
import pathlib

import numpy as np
Expand All @@ -25,7 +26,7 @@

from catalyst import CompileError, ctrl
from catalyst.compiler import get_lib_path
from catalyst.preprocess import decompose_ops_to_unitary
from catalyst.preprocess import decompose_ops_to_unitary, measurements_from_counts


class DummyDevice(Device):
Expand Down Expand Up @@ -141,6 +142,102 @@ def f():
with pytest.raises(CompileError, match="could not be decomposed, it might be unsupported"):
qml.qjit(f, target="jaxpr")

@pytest.mark.skipif(
not pathlib.Path(
get_lib_path("runtime", "RUNTIME_LIB_DIR") + "/libdummy_device.so"
).is_file(),
reason="lib_dummydevice.so was not found.",
)
def test_measurement_from_counts_integration_multiple_measurements(self):
"""Test the measurment from counts transform as part of the Catalyst pipeline."""
dev = DummyDevice(wires=4, shots=1000)

@qml.qjit
@measurements_from_counts
@qml.qnode(dev)
def circuit(theta: float):
qml.X(0)
qml.X(1)
qml.X(2)
qml.X(3)
return (
qml.expval(qml.PauliX(wires=0) @ qml.PauliX(wires=1)),
qml.var(qml.PauliX(wires=0) @ qml.PauliX(wires=2)),
qml.counts(qml.PauliX(wires=0) @ qml.PauliX(wires=1) @ qml.PauliX(wires=2)),
)

mlir = qml.qjit(circuit, target="mlir").mlir
assert "expval" not in mlir
assert "var" not in mlir
assert "counts" in mlir

@pytest.mark.skipif(
not pathlib.Path(
get_lib_path("runtime", "RUNTIME_LIB_DIR") + "/libdummy_device.so"
).is_file(),
reason="lib_dummydevice.so was not found.",
)
def test_measurement_from_counts_integration_single_measurement(self):
"""Test the measurment from counts transform with a single measurements as part of
the Catalyst pipeline."""
dev = DummyDevice(wires=4, shots=1000)

@qml.qjit
@measurements_from_counts
@qml.qnode(dev)
def circuit(theta: float):
qml.X(0)
qml.X(1)
qml.X(2)
qml.X(3)
return qml.expval(qml.PauliX(wires=0) @ qml.PauliX(wires=1))

mlir = qml.qjit(circuit, target="mlir").mlir
assert "expval" not in mlir
assert "counts" in mlir


class TestTransform:
"""Test the transforms implemented in Catalyst."""

def test_measurements_from_counts(self):
"""Test the transfom measurements_from_counts."""
device = qml.device("lightning.qubit", wires=4, shots=1000)

@qml.qjit
@measurements_from_counts
@qml.qnode(device=device)
def circuit(a: float):
qml.X(0)
qml.X(1)
qml.X(2)
qml.X(3)
return (
qml.expval(qml.PauliX(wires=0) @ qml.PauliX(wires=1)),
qml.var(qml.PauliX(wires=0) @ qml.PauliX(wires=2)),
qml.probs(wires=[3]),
qml.counts(qml.PauliX(wires=0) @ qml.PauliX(wires=1) @ qml.PauliX(wires=2)),
)

res = circuit(0.2)
results = res[0]

assert isinstance(results, tuple)
assert len(results) == 4

expval = results[0]
var = results[1]
probs = results[2]
counts = results[3]

assert expval.shape == ()
assert var.shape == ()
assert probs.shape == (2,)
assert isinstance(counts, tuple)
assert len(counts) == 2
assert counts[0].shape == (8,)
assert counts[1].shape == (8,)


if __name__ == "__main__":
pytest.main(["-x", __file__])

0 comments on commit f016a31

Please sign in to comment.