Skip to content

Commit

Permalink
Rewrite tests for split_non_commuting
Browse files Browse the repository at this point in the history
  • Loading branch information
astralcai committed Dec 9, 2024
1 parent c5fd5bc commit 29f6be8
Show file tree
Hide file tree
Showing 5 changed files with 554 additions and 797 deletions.
5 changes: 3 additions & 2 deletions pennylane/ops/op_math/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,9 @@ def compute_grouping(self, grouping_type="qwc", method="lf"):
_, ops = self.terms()

with qml.QueuingManager.stop_recording():
op_groups = qml.pauli.group_observables(ops, grouping_type=grouping_type, method=method)
self._grouping_indices = tuple(tuple(ops.index(o) for o in group) for group in op_groups)
self._grouping_indices = qml.pauli.compute_partition_indices(
ops, grouping_type=grouping_type, method=method
)

@property
def coeffs(self):
Expand Down
3 changes: 3 additions & 0 deletions pennylane/ops/qubit/non_parametric_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class Hadamard(Observable, Operation):

_queue_category = "_ops"

def __init__(self, wires: Optional[WiresLike] = None, id: Optional[str] = None):
super().__init__(wires=wires, id=id)

def label(
self,
decimals: Optional[int] = None,
Expand Down
71 changes: 24 additions & 47 deletions pennylane/transforms/split_non_commuting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import pennylane as qml
from pennylane.measurements import ExpectationMP, MeasurementProcess, Shots, StateMP
from pennylane.ops import LinearCombination, Prod, SProd, Sum
from pennylane.ops import Prod, SProd, Sum
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.transforms import transform
from pennylane.typing import PostprocessingFn, Result, ResultBatch, TensorLike, Union
Expand Down Expand Up @@ -279,7 +279,7 @@ def circuit(x):

if grouping_strategy is None:
measurements = list(single_term_obs_mps.keys())
tapes = [tape.__class__(tape.operations, [m], shots=tape.shots) for m in measurements]
tapes = [tape.copy(measurements=[m]) for m in measurements]
return tapes, partial(
_processing_fn_no_grouping,
single_term_obs_mps=single_term_obs_mps,
Expand All @@ -288,26 +288,15 @@ def circuit(x):
batch_size=tape.batch_size,
)

if (
grouping_strategy == "wires"
or grouping_strategy == "default"
and any(
isinstance(m, ExpectationMP) and isinstance(m.obs, LinearCombination)
for m in tape.measurements
)
or any(
m.obs is not None and not qml.pauli.is_pauli_word(m.obs) for m in single_term_obs_mps
)
if grouping_strategy == "wires" or any(
m.obs is not None and not qml.pauli.is_pauli_word(m.obs) for m in single_term_obs_mps
):
# This is a loose check to see whether wires grouping or qwc grouping should be used,
# which does not necessarily make perfect sense but is consistent with the old decision
# logic in `Device.batch_transform`. The premise is that qwc grouping is classically
# expensive but produces fewer tapes, whereas wires grouping is classically faster to
# compute, but inefficient quantum-wise. If this transform is to be added to a device's
# `preprocess`, it will be performed for every circuit execution, which can get very
# expensive if there is a large number of observables. The reasoning here is, large
# Hamiltonians typically come in the form of a `LinearCombination`, so
# if we see one of those, use wires grouping to be safe. Otherwise, use qwc grouping.
# TODO: here we fall back to wire-based grouping if any of the observables in the tape
# is not a pauli word. As a result, adding a single measurement to a circuit could
# significantly increase the number of circuit executions. We should be able to
# separate the logic for pauli-word observables and non-pauli-word observables,
# putting non-pauli-word observables in separate wire-based groups, but using qwc
# based grouping for the rest of the observables. [sc-79686]
return _split_using_wires_grouping(tape, single_term_obs_mps, offsets)

return _split_using_qwc_grouping(tape, single_term_obs_mps, offsets)
Expand Down Expand Up @@ -370,7 +359,7 @@ def _split_ham_with_grouping(tape: qml.tape.QuantumScript):
mp_groups.append(mp_group)
group_sizes.append(group_size)

tapes = [tape.__class__(tape.operations, mps, shots=tape.shots) for mps in mp_groups]
tapes = [tape.copy(measurements=mps) for mps in mp_groups]
return tapes, partial(
_processing_fn_with_grouping,
single_term_obs_mps=single_term_obs_mps,
Expand Down Expand Up @@ -405,7 +394,7 @@ def _split_using_qwc_grouping(
obs_list = [_mp_to_obs(m, tape) for m in measurements]
index_groups = []
if len(obs_list) > 0:
_, index_groups = qml.pauli.group_observables(obs_list, range(len(obs_list)))
index_groups = qml.pauli.compute_partition_indices(obs_list)

# A dictionary for measurements of each unique single-term observable, mapped to the
# indices of the original measurements it belongs to, its coefficients, the index of
Expand Down Expand Up @@ -436,7 +425,7 @@ def _split_using_qwc_grouping(
)
group_sizes.append(1)

tapes = [tape.__class__(tape.operations, mps, shots=tape.shots) for mps in mp_groups]
tapes = [tape.copy(measurements=mps) for mps in mp_groups]
return tapes, partial(
_processing_fn_with_grouping,
single_term_obs_mps=single_term_obs_mps_grouped,
Expand Down Expand Up @@ -507,7 +496,7 @@ def _split_using_wires_grouping(
single_term_obs_mps_grouped[smp] = (mp_indices, coeffs, num_groups, 0)
num_groups += 1

tapes = [tape.__class__(tape.operations, mps, shots=tape.shots) for mps in mp_groups]
tapes = [tape.copy(measurements=mps) for mps in mp_groups]
return tapes, partial(
_processing_fn_with_grouping,
single_term_obs_mps=single_term_obs_mps_grouped,
Expand Down Expand Up @@ -558,8 +547,10 @@ def _split_all_multi_term_obs_mps(tape: qml.tape.QuantumScript):
# Otherwise, add this new measurement to the list of single-term measurements.
else:
single_term_obs_mps[sm] = ([mp_idx], [c])
elif isinstance(obs, qml.Identity):
offset += 1
else:
if isinstance(obs, SProd):
if isinstance(obs, (SProd, Prod)):
obs = obs.simplify()
if isinstance(obs, Sum):
raise RuntimeError(
Expand Down Expand Up @@ -606,27 +597,7 @@ def _processing_fn_no_grouping(
res_batch_for_each_mp[mp_idx].append(res[smp_idx])
coeffs_for_each_mp[mp_idx].append(coeff)

result_shape = _infer_result_shape(shots, batch_size)

# Sum up the results for each original measurement
res_for_each_mp = [
_sum_terms(_sub_res, coeffs, offset, result_shape)
for _sub_res, coeffs, offset in zip(res_batch_for_each_mp, coeffs_for_each_mp, offsets)
]

# res_for_each_mp should have shape (n_mps, [,n_shots] [,batch_size])
if len(res_for_each_mp) == 1:
return res_for_each_mp[0]

if shots.has_partitioned_shots:
# If the shot vector dimension exists, it should be moved to the first axis
# Basically, the shape becomes (n_shots, n_mps, [,batch_size])
res_for_each_mp = [
tuple(res_for_each_mp[j][i] for j in range(len(res_for_each_mp)))
for i in range(shots.num_copies)
]

return tuple(res_for_each_mp)
return _res_for_each_mp(res_batch_for_each_mp, coeffs_for_each_mp, offsets, shots, batch_size)


def _processing_fn_with_grouping(

Check notice on line 603 in pennylane/transforms/split_non_commuting.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/split_non_commuting.py#L603

Too many positional arguments (6/5) (too-many-positional-arguments)
Expand Down Expand Up @@ -678,6 +649,12 @@ def _processing_fn_with_grouping(
res_batch_for_each_mp[mp_idx].append(sub_res)
coeffs_for_each_mp[mp_idx].append(coeff)

return _res_for_each_mp(res_batch_for_each_mp, coeffs_for_each_mp, offsets, shots, batch_size)


def _res_for_each_mp(res_batch_for_each_mp, coeffs_for_each_mp, offsets, shots, batch_size):
"""Helper function that combines a result batch into results for each mp"""

result_shape = _infer_result_shape(shots, batch_size)

# Sum up the results for each original measurement
Expand Down
1 change: 1 addition & 0 deletions tests/measurements/test_probs.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def test_observable_is_measurement_value_list(
): # pylint: disable=too-many-arguments
"""Test that probs for mid-circuit measurement values
are correct for a measurement value list."""

dev = qml.device("default.qubit", seed=seed)

@qml.qnode(dev)
Expand Down
Loading

0 comments on commit 29f6be8

Please sign in to comment.