Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for Torch and Jax with dynamic_one_shot #5672

Merged
merged 115 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
115 commits
Select commit Hold shift + click to select a range
2d6d336
Added rng and prng_key to get_final_state, apply_operation
mudit2812 Mar 7, 2024
7560cdf
Use rng in apply_operation args; linting
mudit2812 Mar 7, 2024
11290a7
WIP
vincentmr Apr 4, 2024
1313af8
Fix measure_with_samples' handling of mid_measurements.
vincentmr Apr 9, 2024
412374e
Remove comments.
vincentmr Apr 9, 2024
bd32a35
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 9, 2024
19d7f73
update changelog
vincentmr Apr 9, 2024
8aea51e
Fix legacy node native mcm test.
vincentmr Apr 9, 2024
3fb1758
Fix old device API
vincentmr Apr 9, 2024
57386e6
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 9, 2024
b9d8996
Fill out mid_measurements in mock device.
vincentmr Apr 10, 2024
5f7a33f
Refactor using masks.
vincentmr Apr 10, 2024
146baef
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 10, 2024
c975493
Update rng use; fix docs
mudit2812 Apr 10, 2024
6899b2c
[skip ci] Skip CI
mudit2812 Apr 10, 2024
87f34b2
Always compute all results, even if mv = -1.
vincentmr Apr 11, 2024
721c284
[skip ci] testing changes to native MCM tests
mudit2812 Apr 11, 2024
4aabfe5
Make post-processing jax-ready. WIP
vincentmr Apr 11, 2024
2614d53
Update pennylane/devices/qubit/sampling.py
vincentmr Apr 12, 2024
c707d46
Implement Christina's suggestions.
vincentmr Apr 12, 2024
bd800bb
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 12, 2024
f89139a
Add rng and jit mid_measure. WARN: regression
vincentmr Apr 12, 2024
f9c7b10
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 12, 2024
f7d7fd2
WIP
vincentmr Apr 12, 2024
a836a51
Implement apply_mid_meas with norm of a branch @mudit.
vincentmr Apr 12, 2024
5b1b732
Merge remote-tracking branch 'origin/simulate-rng' into feature/dynam…
vincentmr Apr 12, 2024
34b513d
Add jax-jit support in apply_cond/mid_meas.
vincentmr Apr 12, 2024
2562b94
Introduce prng
vincentmr Apr 12, 2024
3db8767
WIP
vincentmr Apr 12, 2024
ceaa339
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 15, 2024
ccafa89
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 15, 2024
e1804a8
Remove jax branch in apply_mid_meas.
vincentmr Apr 15, 2024
f7c65bc
Fix single wire probs.
vincentmr Apr 15, 2024
5aec928
Bug fix MV lists.
vincentmr Apr 15, 2024
11d6d43
Add tests for math.all/any
vincentmr Apr 17, 2024
2753b26
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 17, 2024
62ef926
qml.math.all doesn't need implementation.
vincentmr Apr 17, 2024
6c15204
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 17, 2024
97cf86e
Fix error message @albi3ro
vincentmr Apr 18, 2024
ce9a4a1
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 18, 2024
027c65a
Test measure_final_state raises.
vincentmr Apr 18, 2024
85334e7
Add tests in tests/transforms/test_dynamic_one_shot.py
vincentmr Apr 19, 2024
de47630
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 19, 2024
0c295ac
Fix lint.
vincentmr Apr 19, 2024
eeed3d4
Add error test.
vincentmr Apr 19, 2024
387d898
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 19, 2024
88e9eca
Merge branch 'master' into simulate-rng
mudit2812 Apr 19, 2024
985775e
Merge branch 'master' into simulate-rng
mudit2812 Apr 19, 2024
e1d8e4b
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 22, 2024
7f7be2d
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 22, 2024
d7a91fc
Add test for batched dynamic_one_shot
vincentmr Apr 22, 2024
7125646
Merge branch 'master' into feature/dynamic_samples
vincentmr Apr 22, 2024
c46a46a
Merge remote-tracking branch 'origin/feature/dynamic_samples' into fe…
vincentmr Apr 22, 2024
ae5abe7
Refactor here and there.
vincentmr Apr 22, 2024
b6b5e68
Sort imports
vincentmr Apr 22, 2024
8162704
Revert isort changes.
vincentmr Apr 22, 2024
8f3e558
Merge branch 'master' into simulate-rng
mudit2812 Apr 22, 2024
e091e86
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 22, 2024
b0a1ec9
Correctly propagating PRNGKey and RNG to apply_operation
mudit2812 Apr 22, 2024
6ad8352
Fix _sample_state_jax cond
vincentmr Apr 23, 2024
c641af4
Fix test_parse_native_mid_circuit_measurements_unsupported_meas
vincentmr Apr 23, 2024
42a649d
DQ.execute distributes kwargs with multithreading
mudit2812 Apr 23, 2024
148d051
\Merge remote-tracking branch 'origin/simulate-rng' into feature/dyna…
vincentmr Apr 23, 2024
bec05b7
Add device as kwargs to dyn_one_shot
vincentmr Apr 24, 2024
b301554
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 29, 2024
0565c61
Filter postselected values in post-processing.
vincentmr Apr 29, 2024
bb49932
Add jax.jit tests.
vincentmr Apr 29, 2024
a0b6703
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 29, 2024
fabfe8f
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr Apr 30, 2024
3c05878
Fix flaky test's seed in test_broadcast_expand.
vincentmr Apr 30, 2024
9ea624e
jax.numpy.array(params)
vincentmr Apr 30, 2024
11f369f
Merge branch 'master' into feature/dynamic_samples_jit
vincentmr Apr 30, 2024
5398ce7
Update pennylane/devices/qubit/apply_operation.py
vincentmr Apr 30, 2024
f140b5c
Update pennylane/devices/qubit/apply_operation.py
vincentmr Apr 30, 2024
a858e0e
Update pennylane/devices/qubit/sampling.py
vincentmr Apr 30, 2024
303faaf
Implement Mudit's suggestions.
vincentmr Apr 30, 2024
d948334
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr May 1, 2024
91945be
Use more robust QubitUnitary for the time being.
vincentmr May 1, 2024
457bcae
Merge remote-tracking branch 'origin/master' into feature/dynamic_sam…
vincentmr May 1, 2024
ac42ef0
Make sure we're not using Python callbacks with jaxpr.
vincentmr May 1, 2024
7cc9472
Fix reset matrix.
vincentmr May 1, 2024
b13961e
Remove unused where.
vincentmr May 1, 2024
268e4e6
Merge branch 'master' into feature/dynamic_samples_jit
vincentmr May 1, 2024
2b1db5f
Add dev notes and remove useless importorskip.
vincentmr May 2, 2024
6c3bd0b
Merge branch 'master' into feature/dynamic_samples_jit
vincentmr May 2, 2024
0c3d5d3
Update pennylane/devices/qubit/apply_operation.py
vincentmr May 2, 2024
77bdf4c
Update pennylane/devices/qubit/apply_operation.py
vincentmr May 2, 2024
13fc37b
Merge branch 'master' into feature/dynamic_samples_jit
vincentmr May 6, 2024
9d6e823
Merge branch 'master' into feature/dynamic_samples_jit
vincentmr May 7, 2024
38bb92d
Merge branch 'master' into feature/dynamic_samples_jit
vincentmr May 8, 2024
ef06fdc
Daily rc sync to master (#5650)
github-actions[bot] May 6, 2024
3dc3283
Update legacy opmath tests to temporarily use the rc branch (#5653)
astralcai May 6, 2024
e088513
Daily rc sync to master (#5665)
github-actions[bot] May 7, 2024
3f6d658
Fix transforms that error with non-commuting measurements (#5424)
albi3ro May 7, 2024
39c9ac2
Removing `qml.load` (#5654)
PietropaoloFrisoni May 7, 2024
d125334
Removing `qml.from_qasm_file` from source code and tests (#5659)
PietropaoloFrisoni May 7, 2024
71d2879
Update legacy_op_math.yml to use master (#5670)
astralcai May 8, 2024
73a5b5a
[BUG] Fix a fermi sentence bug in bravyi_kitaev (#5671)
soranjh May 8, 2024
a0909ff
Testing changes needed to get interfaces to work
mudit2812 May 8, 2024
0769622
Merge branch 'feature/dynamic_samples_jit' into dos-interfaces
mudit2812 May 8, 2024
a038190
Apply suggestions from code review
mudit2812 May 8, 2024
d273d92
Fix rebase artifact
mudit2812 May 8, 2024
fbc69ca
Merge branch 'master' into dos-interfaces
mudit2812 May 14, 2024
cae872e
Fixed torch and jax interfaces; added tests
mudit2812 May 15, 2024
8f6a8cd
Small default interface fix; changelog entry
mudit2812 May 15, 2024
148159e
Merge branch 'master' into dos-interfaces
mudit2812 May 15, 2024
a54f703
Added diff_method to tests
mudit2812 May 15, 2024
2ed91dc
Reverted change to qml math
mudit2812 May 15, 2024
75c5031
Fixed tests
mudit2812 May 16, 2024
e19bfc6
Merge branch 'master' into dos-interfaces
mudit2812 May 16, 2024
af760f8
Addressed review
mudit2812 May 16, 2024
9cb6d46
Fixed torch tests
mudit2812 May 17, 2024
ce49af4
Added jax dynamic_one_shot unit tests
mudit2812 May 24, 2024
b45cbcb
Tidy up tests
mudit2812 May 24, 2024
7d1915d
Merge branch 'master' into dos-interfaces
mudit2812 May 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@

<h3>Bug fixes 🐛</h3>

* The `dynamic_one_shot` transform now has expanded support for the `jax` and `torch` interfaces.
[(#5672)](https://github.com/PennyLaneAI/pennylane/pull/5672)
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved

* The decomposition of `StronglyEntanglingLayers` is now compatible with broadcasting.
[(#5716)](https://github.com/PennyLaneAI/pennylane/pull/5716)

Expand Down Expand Up @@ -213,5 +216,6 @@ Korbinian Kottmann,
Christina Lee,
Vincent Michaud-Rioux,
Lee James O'Riordan,
Mudit Pandey,
Kenya Sakka,
David Wierichs.
6 changes: 4 additions & 2 deletions pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,10 @@ def binomial_fn(n, p):
# to reset enables jax.jit and prevents it from using Python callbacks
element = op.reset and sample == 1
matrix = qml.math.array(
[[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]], like=interface
).astype(float)
[[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]],
like=interface,
dtype=float,
)
state = apply_operation(
qml.QubitUnitary(matrix, wire), state, is_state_batched=is_state_batched, debugger=debugger
)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def simulate(
trainable_params=circuit.trainable_params,
)
keys = jax_random_split(prng_key, num=circuit.shots.total_shots)
if qml.math.get_deep_interface(circuit.data) == "jax":
if qml.math.get_deep_interface(circuit.data) == "jax" and prng_key is not None:
# pylint: disable=import-outside-toplevel
import jax

Expand Down
1 change: 1 addition & 0 deletions pennylane/math/single_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def _take_autograd(tensor, indices, axis=None):
ar.autoray._SUBMODULE_ALIASES["tensorflow", "isclose"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "atleast_1d"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "all"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "ravel"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "vstack"] = "tensorflow.experimental.numpy"

tf_fft_functions = [
Expand Down
28 changes: 19 additions & 9 deletions pennylane/transforms/dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def measurement_with_no_shots(measurement):
)

interface = qml.math.get_deep_interface(circuit.data)
interface = "numpy" if interface == "builtins" else interface

all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)]
n_mcms = len(all_mcms)
Expand All @@ -243,10 +244,13 @@ def measurement_with_no_shots(measurement):
mcm_samples = qml.math.array(
[[res] if single_measurement else res[-n_mcms::] for res in results], like=interface
)
has_postselect = qml.math.array([op.postselect is not None for op in all_mcms]).reshape((1, -1))
# Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1
has_postselect = qml.math.array(
[[int(op.postselect is not None) for op in all_mcms]], like=interface
)
postselect = qml.math.array(
[0 if op.postselect is None else op.postselect for op in all_mcms]
).reshape((1, -1))
[[0 if op.postselect is None else op.postselect for op in all_mcms]], like=interface
)
is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1)
has_valid = qml.math.any(is_valid)
mid_meas = [op for op in circuit.operations if is_mcm(op)]
Expand All @@ -268,7 +272,12 @@ def measurement_with_no_shots(measurement):
meas = measurement_with_no_shots(m)
m_count += 1
else:
result = qml.math.array([res[m_count] for res in results], like=interface)
result = [res[m_count] for res in results]
if not isinstance(m, CountsMP):
# We don't need to cast to arrays when using qml.counts. qml.math.array is not viable
# as it assumes all elements of the input are of builtin python types and not belonging
# to any particular interface
result = qml.math.stack(result, like=interface)
meas = gather_non_mcm(m, result, is_valid)
m_count += 1
if isinstance(m, SampleMP):
Expand All @@ -292,7 +301,9 @@ def gather_non_mcm(circuit_measurement, measurement, is_valid):
if isinstance(circuit_measurement, CountsMP):
tmp = Counter()
for i, d in enumerate(measurement):
tmp.update(dict((k, v * is_valid[i]) for k, v in d.items()))
tmp.update(
dict((k if isinstance(k, str) else float(k), v * is_valid[i]) for k, v in d.items())
)
tmp = Counter({k: v for k, v in tmp.items() if v > 0})
return dict(sorted(tmp.items()))
if isinstance(circuit_measurement, ExpectationMP):
Expand Down Expand Up @@ -341,14 +352,13 @@ def gather_mcm(measurement, samples, is_valid):
counts = qml.math.array(counts, like=interface)
return counts / qml.math.sum(counts)
if isinstance(measurement, CountsMP):
mcm_samples = [{"".join(str(v) for v in tuple(s)): 1} for s in mcm_samples]
mcm_samples = [{"".join(str(int(v)) for v in tuple(s)): 1} for s in mcm_samples]
return gather_non_mcm(measurement, mcm_samples, is_valid)
mcm_samples = qml.math.ravel(qml.math.array(mv.concretize(samples), like=interface))
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(measurement, ProbabilityMP):
mcm_samples = qml.math.array(mv.concretize(samples), like=interface).ravel()
counts = [qml.math.sum((mcm_samples == v) * is_valid) for v in list(mv.branches.values())]
counts = qml.math.array(counts, like=interface)
return counts / qml.math.sum(counts)
mcm_samples = qml.math.array([mv.concretize(samples)], like=interface).ravel()
if isinstance(measurement, CountsMP):
mcm_samples = [{s: 1} for s in mcm_samples]
mcm_samples = [{float(s): 1} for s in mcm_samples]
return gather_non_mcm(measurement, mcm_samples, is_valid)
55 changes: 50 additions & 5 deletions tests/devices/default_qubit/test_default_qubit_native_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for default qubit preprocessing."""
from functools import partial, reduce
from functools import reduce
from typing import Iterable, Sequence

import numpy as np
Expand All @@ -24,7 +24,11 @@

pytestmark = pytest.mark.slow

get_device = partial(qml.device, name="default.qubit", seed=8237945)

def get_device(**kwargs):
kwargs.setdefault("shots", None)
kwargs.setdefault("seed", 8237945)
return qml.device("default.qubit", **kwargs)


def validate_counts(shots, results1, results2, batch_size=None):
Expand Down Expand Up @@ -88,7 +92,7 @@ def validate_samples(shots, results1, results2, batch_size=None):
assert results1.ndim == results2.ndim
if results2.ndim > 1:
assert results1.shape[1] == results2.shape[1]
np.allclose(np.sum(results1), np.sum(results2), atol=20, rtol=0.2)
np.allclose(qml.math.sum(results1), qml.math.sum(results2), atol=20, rtol=0.2)


def validate_expval(shots, results1, results2, batch_size=None):
Expand Down Expand Up @@ -611,7 +615,7 @@ def test_sample_with_prng_key(shots, postselect, reset):
# pylint: disable=import-outside-toplevel
from jax.random import PRNGKey

dev = qml.device("default.qubit", shots=shots, seed=PRNGKey(678))
dev = get_device(shots=shots, seed=PRNGKey(678))
param = [np.pi / 4, np.pi / 3]
obs = qml.PauliZ(0) @ qml.PauliZ(1)

Expand Down Expand Up @@ -659,7 +663,7 @@ def test_jax_jit(diff_method, postselect, reset):

shots = 10

dev = qml.device("default.qubit", shots=shots, seed=jax.random.PRNGKey(678))
dev = get_device(shots=shots, seed=jax.random.PRNGKey(678))
params = [np.pi / 2.5, np.pi / 3, -np.pi / 3.5]
obs = qml.PauliY(0)

Expand Down Expand Up @@ -750,3 +754,44 @@ def func(x):
results2 = func2(param)
for r1, r2 in zip(results1.keys(), results2.keys()):
assert r1 == r2


@pytest.mark.torch
@pytest.mark.parametrize("postselect", [None, 1])
@pytest.mark.parametrize("diff_method", [None, "best"])
@pytest.mark.parametrize("measure_f", [qml.probs, qml.sample, qml.expval, qml.var])
@pytest.mark.parametrize("meas_obj", [qml.PauliZ(1), [0, 1], "composite_mcm", "mcm_list"])
def test_torch_integration(postselect, diff_method, measure_f, meas_obj):
"""Test that native MCM circuits are executed correctly with Torch"""
if measure_f in (qml.var, qml.expval) and (
isinstance(meas_obj, list) or meas_obj == "mcm_list"
):
pytest.skip("Can't use wires/mcm lists with var or expval")

import torch

shots = 7000
dev = get_device(shots=shots, seed=123456789)
param = torch.tensor(np.pi / 3, dtype=torch.float64)
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved

@qml.qnode(dev, diff_method=diff_method)
def func(x):
qml.RX(x, 0)
m0 = qml.measure(0)
qml.RX(0.5 * x, 1)
m1 = qml.measure(1, postselect=postselect)
qml.cond((m0 + m1) == 2, qml.RY)(2.0 * x, 0)
m2 = qml.measure(0)

mid_measure = 0.5 * m2 if meas_obj == "composite_mcm" else [m1, m2]
measurement_key = "wires" if isinstance(meas_obj, list) else "op"
measurement_value = mid_measure if isinstance(meas_obj, str) else meas_obj
return measure_f(**{measurement_key: measurement_value})
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved

func1 = func
func2 = qml.defer_measurements(func)

results1 = func1(param)
results2 = func2(param)

validate_measurements(measure_f, shots, results1, results2)
165 changes: 165 additions & 0 deletions tests/transforms/test_dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,168 @@ def test_len_measurements_mcms(measure, aux_measure, n_meas):
assert len(aux_tape.measurements) == n_meas + n_mcms
assert isinstance(aux_tape.measurements[0], aux_measure)
assert all(isinstance(m, SampleMP) for m in aux_tape.measurements[1:])


def assert_results(res, shots, n_mcms):
"""Helper to check that expected raw results of executing the transformed tape are correct"""
assert len(res) == shots
# One for the non-MeasurementValue MP, and the rest of the mid-circuit measurements
assert all(len(r) == n_mcms + 1 for r in res)
# Not validating distribution of results as device sampling unit tests already validate
# that samples are generated correctly.


@pytest.mark.jax
@pytest.mark.parametrize("measure_f", (qml.expval, qml.probs, qml.sample, qml.var))
@pytest.mark.parametrize("shots", [20, [20, 21]])
@pytest.mark.parametrize("n_mcms", [1, 3])
def test_tape_results_jax(shots, n_mcms, measure_f):
"""Test that the simulation results of a tape are correct with jax parameters"""
import jax

dev = qml.device("default.qubit", wires=4, shots=shots, seed=jax.random.PRNGKey(123))
param = jax.numpy.array(np.pi / 2)

mv = qml.measure(0)
mp = mv.measurements[0]

tape = qml.tape.QuantumScript(
[qml.RX(param, 0), mp] + [MidMeasureMP(0, id=str(i)) for i in range(n_mcms - 1)],
[measure_f(op=qml.PauliZ(0)), measure_f(op=mv)],
shots=shots,
)

tapes, _ = qml.dynamic_one_shot(tape)
results = dev.execute(tapes)[0]

# The transformed tape never has a shot vector
if isinstance(shots, list):
shots = sum(shots)

assert_results(results, shots, n_mcms)


@pytest.mark.jax
@pytest.mark.parametrize(
"measure_f, expected1, expected2",
[
(qml.expval, 1.0, 1.0),
(qml.probs, [1, 0], [0, 1]),
(qml.sample, 1, 1),
(qml.var, 0.0, 0.0),
],
)
@pytest.mark.parametrize("shots", [20, [20, 21]])
@pytest.mark.parametrize("n_mcms", [1, 3])
def test_jax_results_processing(shots, n_mcms, measure_f, expected1, expected2):
"""Test that the results of tapes are processed correctly for tapes with jax parameters"""
import jax.numpy as jnp

mv = qml.measure(0)
mp = mv.measurements[0]

tape = qml.tape.QuantumScript(
[qml.RX(1.5, 0), mp] + [MidMeasureMP(0)] * (n_mcms - 1),
[measure_f(op=qml.PauliZ(0)), measure_f(op=mv)],
shots=shots,
)
_, fn = qml.dynamic_one_shot(tape)
all_shots = sum(shots) if isinstance(shots, list) else shots

first_res = jnp.array([1.0, 0.0]) if measure_f == qml.probs else jnp.array(1.0)
rest = jnp.array(1, dtype=int)
single_shot_res = (first_res,) + (rest,) * n_mcms
# Raw results for each shot are (sample_for_first_measurement,) + (sample for 1st MCM, sample for 2nd MCM, ...)
raw_results = (single_shot_res,) * all_shots
raw_results = (raw_results,)
res = fn(raw_results)

if measure_f is qml.sample:
# All samples 1
expected1 = (
[[expected1] * s for s in shots] if isinstance(shots, list) else [expected1] * shots
)
expected2 = (
[[expected2] * s for s in shots] if isinstance(shots, list) else [expected2] * shots
)
else:
expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1
expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2

if isinstance(shots, list):
assert len(res) == len(shots)
for r, e1, e2 in zip(res, expected1, expected2):
# Expected result is 2-list since we have two measurements in the tape
assert qml.math.allclose(r, [e1, e2])
else:
# Expected result is 2-list since we have two measurements in the tape
assert qml.math.allclose(res, [expected1, expected2])


@pytest.mark.jax
@pytest.mark.parametrize(
"measure_f, expected1, expected2",
[
(qml.expval, 1.0, 1.0),
(qml.probs, [1, 0], [0, 1]),
(qml.sample, 1, 1),
(qml.var, 0.0, 0.0),
],
)
@pytest.mark.parametrize("shots", [20, [20, 22]])
def test_jax_results_postselection_processing(shots, measure_f, expected1, expected2):
"""Test that the results of tapes are processed correctly for tapes with jax parameters
when postselecting"""
import jax.numpy as jnp

param = jnp.array(np.pi / 2)
fill_value = np.iinfo(np.int32).min
mv = qml.measure(0, postselect=1)
mp = mv.measurements[0]

tape = qml.tape.QuantumScript(
[qml.RX(param, 0), mp, MidMeasureMP(0)],
[measure_f(op=qml.PauliZ(0)), measure_f(op=mv)],
shots=shots,
)
_, fn = qml.dynamic_one_shot(tape)
all_shots = sum(shots) if isinstance(shots, list) else shots

# Alternating tuple. Only the values at odd indices are valid
first_res_two_shot = (
(jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0]))
if measure_f == qml.probs
else (jnp.array(1.0), jnp.array(0.0))
)
first_res = first_res_two_shot * (all_shots // 2)
# Tuple of alternating 1s and 0s. Zero is invalid as postselecting on 1
postselect_res = (jnp.array(1, dtype=int), jnp.array(0, dtype=int)) * (all_shots // 2)
rest = (jnp.array(1, dtype=int),) * all_shots
# Raw results for each shot are (sample_for_first_measurement, sample for 1st MCM, sample for 2nd MCM)
raw_results = tuple(zip(first_res, postselect_res, rest))
raw_results = (raw_results,)
res = fn(raw_results)

if measure_f is qml.sample:
expected1 = (
[[expected1, fill_value] * (s // 2) for s in shots]
if isinstance(shots, list)
else [expected1, fill_value] * (shots // 2)
)
expected2 = (
[[expected2, fill_value] * (s // 2) for s in shots]
if isinstance(shots, list)
else [expected2, fill_value] * (shots // 2)
)
else:
expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1
expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2

if isinstance(shots, list):
assert len(res) == len(shots)
for r, e1, e2 in zip(res, expected1, expected2):
# Expected result is 2-list since we have two measurements in the tape
assert qml.math.allclose(r, [e1, e2])
else:
# Expected result is 2-list since we have two measurements in the tape
assert qml.math.allclose(res, [expected1, expected2])
Loading