Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite tests for split_non_commuting #6687

Merged
merged 11 commits into from
Dec 13, 2024
5 changes: 3 additions & 2 deletions pennylane/ops/op_math/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,9 @@ def compute_grouping(self, grouping_type="qwc", method="lf"):
_, ops = self.terms()

with qml.QueuingManager.stop_recording():
op_groups = qml.pauli.group_observables(ops, grouping_type=grouping_type, method=method)
self._grouping_indices = tuple(tuple(ops.index(o) for o in group) for group in op_groups)
self._grouping_indices = qml.pauli.compute_partition_indices(
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
ops, grouping_type=grouping_type, method=method
)

@property
def coeffs(self):
Expand Down
3 changes: 3 additions & 0 deletions pennylane/ops/qubit/non_parametric_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class Hadamard(Observable, Operation):

_queue_category = "_ops"

def __init__(self, wires: WiresLike, id: Optional[str] = None):
super().__init__(wires=wires, id=id)

def label(
self,
decimals: Optional[int] = None,
Expand Down
71 changes: 24 additions & 47 deletions pennylane/transforms/split_non_commuting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import pennylane as qml
from pennylane.measurements import ExpectationMP, MeasurementProcess, Shots, StateMP
from pennylane.ops import LinearCombination, Prod, SProd, Sum
from pennylane.ops import Prod, SProd, Sum
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.transforms import transform
from pennylane.typing import PostprocessingFn, Result, ResultBatch, TensorLike, Union
Expand Down Expand Up @@ -279,7 +279,7 @@

if grouping_strategy is None:
measurements = list(single_term_obs_mps.keys())
tapes = [tape.__class__(tape.operations, [m], shots=tape.shots) for m in measurements]
tapes = [tape.copy(measurements=[m]) for m in measurements]
return tapes, partial(
_processing_fn_no_grouping,
single_term_obs_mps=single_term_obs_mps,
Expand All @@ -288,26 +288,15 @@
batch_size=tape.batch_size,
)

if (
grouping_strategy == "wires"
or grouping_strategy == "default"
and any(
isinstance(m, ExpectationMP) and isinstance(m.obs, LinearCombination)
for m in tape.measurements
)
or any(
m.obs is not None and not qml.pauli.is_pauli_word(m.obs) for m in single_term_obs_mps
)
if grouping_strategy == "wires" or any(
m.obs is not None and not qml.pauli.is_pauli_word(m.obs) for m in single_term_obs_mps
):
# This is a loose check to see whether wires grouping or qwc grouping should be used,
astralcai marked this conversation as resolved.
Show resolved Hide resolved
# which does not necessarily make perfect sense but is consistent with the old decision
# logic in `Device.batch_transform`. The premise is that qwc grouping is classically
# expensive but produces fewer tapes, whereas wires grouping is classically faster to
# compute, but inefficient quantum-wise. If this transform is to be added to a device's
# `preprocess`, it will be performed for every circuit execution, which can get very
# expensive if there is a large number of observables. The reasoning here is, large
# Hamiltonians typically come in the form of a `LinearCombination`, so
# if we see one of those, use wires grouping to be safe. Otherwise, use qwc grouping.
# TODO: here we fall back to wire-based grouping if any of the observables in the tape
# is not a pauli word. As a result, adding a single measurement to a circuit could
# significantly increase the number of circuit executions. We should be able to
# separate the logic for pauli-word observables and non-pauli-word observables,
# putting non-pauli-word observables in separate wire-based groups, but using qwc
# based grouping for the rest of the observables. [sc-79686]
return _split_using_wires_grouping(tape, single_term_obs_mps, offsets)

return _split_using_qwc_grouping(tape, single_term_obs_mps, offsets)
Expand Down Expand Up @@ -370,7 +359,7 @@
mp_groups.append(mp_group)
group_sizes.append(group_size)

tapes = [tape.__class__(tape.operations, mps, shots=tape.shots) for mps in mp_groups]
tapes = [tape.copy(measurements=mps) for mps in mp_groups]
return tapes, partial(
_processing_fn_with_grouping,
single_term_obs_mps=single_term_obs_mps,
Expand Down Expand Up @@ -405,7 +394,7 @@
obs_list = [_mp_to_obs(m, tape) for m in measurements]
index_groups = []
if len(obs_list) > 0:
_, index_groups = qml.pauli.group_observables(obs_list, range(len(obs_list)))
index_groups = qml.pauli.compute_partition_indices(obs_list)

# A dictionary for measurements of each unique single-term observable, mapped to the
# indices of the original measurements it belongs to, its coefficients, the index of
Expand Down Expand Up @@ -436,7 +425,7 @@
)
group_sizes.append(1)

tapes = [tape.__class__(tape.operations, mps, shots=tape.shots) for mps in mp_groups]
tapes = [tape.copy(measurements=mps) for mps in mp_groups]
return tapes, partial(
_processing_fn_with_grouping,
single_term_obs_mps=single_term_obs_mps_grouped,
Expand Down Expand Up @@ -507,7 +496,7 @@
single_term_obs_mps_grouped[smp] = (mp_indices, coeffs, num_groups, 0)
num_groups += 1

tapes = [tape.__class__(tape.operations, mps, shots=tape.shots) for mps in mp_groups]
tapes = [tape.copy(measurements=mps) for mps in mp_groups]
return tapes, partial(
_processing_fn_with_grouping,
single_term_obs_mps=single_term_obs_mps_grouped,
Expand Down Expand Up @@ -558,8 +547,10 @@
# Otherwise, add this new measurement to the list of single-term measurements.
else:
single_term_obs_mps[sm] = ([mp_idx], [c])
elif isinstance(obs, qml.Identity):
offset += 1
else:
if isinstance(obs, SProd):
if isinstance(obs, (SProd, Prod)):
obs = obs.simplify()
if isinstance(obs, Sum):
raise RuntimeError(
Expand Down Expand Up @@ -606,30 +597,10 @@
res_batch_for_each_mp[mp_idx].append(res[smp_idx])
coeffs_for_each_mp[mp_idx].append(coeff)

result_shape = _infer_result_shape(shots, batch_size)

# Sum up the results for each original measurement
res_for_each_mp = [
_sum_terms(_sub_res, coeffs, offset, result_shape)
for _sub_res, coeffs, offset in zip(res_batch_for_each_mp, coeffs_for_each_mp, offsets)
]

# res_for_each_mp should have shape (n_mps, [,n_shots] [,batch_size])
if len(res_for_each_mp) == 1:
return res_for_each_mp[0]

if shots.has_partitioned_shots:
# If the shot vector dimension exists, it should be moved to the first axis
# Basically, the shape becomes (n_shots, n_mps, [,batch_size])
res_for_each_mp = [
tuple(res_for_each_mp[j][i] for j in range(len(res_for_each_mp)))
for i in range(shots.num_copies)
]

return tuple(res_for_each_mp)
return _res_for_each_mp(res_batch_for_each_mp, coeffs_for_each_mp, offsets, shots, batch_size)


def _processing_fn_with_grouping(

Check notice on line 603 in pennylane/transforms/split_non_commuting.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/split_non_commuting.py#L603

Too many positional arguments (6/5) (too-many-positional-arguments)
res: ResultBatch,
single_term_obs_mps: dict[
MeasurementProcess, tuple[list[int], list[Union[float, TensorLike]], int, int]
Expand Down Expand Up @@ -678,6 +649,12 @@
res_batch_for_each_mp[mp_idx].append(sub_res)
coeffs_for_each_mp[mp_idx].append(coeff)

return _res_for_each_mp(res_batch_for_each_mp, coeffs_for_each_mp, offsets, shots, batch_size)


def _res_for_each_mp(res_batch_for_each_mp, coeffs_for_each_mp, offsets, shots, batch_size):
"""Helper function that combines a result batch into results for each mp"""

result_shape = _infer_result_shape(shots, batch_size)

# Sum up the results for each original measurement
Expand Down
2 changes: 1 addition & 1 deletion tests/measurements/test_classical_shadow.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def test_multi_measurement_allowed(self, seed):
def circuit():
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
return qml.shadow_expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(0))
return qml.shadow_expval(qml.PauliZ(0), seed=seed), qml.expval(qml.PauliZ(0))

res = circuit()
assert isinstance(res, tuple)
Expand Down
1 change: 1 addition & 0 deletions tests/measurements/test_probs.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def test_observable_is_measurement_value_list(
): # pylint: disable=too-many-arguments
"""Test that probs for mid-circuit measurement values
are correct for a measurement value list."""

dev = qml.device("default.qubit", seed=seed)

@qml.qnode(dev)
Expand Down
Loading
Loading