diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 3690638820b..3b03629e105 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -156,6 +156,9 @@
Bug fixes 🐛
+* The `dynamic_one_shot` transform now has expanded support for the `jax` and `torch` interfaces.
+ [(#5672)](https://github.com/PennyLaneAI/pennylane/pull/5672)
+
* The decomposition of `StronglyEntanglingLayers` is now compatible with broadcasting.
[(#5716)](https://github.com/PennyLaneAI/pennylane/pull/5716)
@@ -213,5 +216,6 @@ Korbinian Kottmann,
Christina Lee,
Vincent Michaud-Rioux,
Lee James O'Riordan,
+Mudit Pandey,
Kenya Sakka,
David Wierichs.
diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py
index 9b42e942c74..ae0c9117c38 100644
--- a/pennylane/devices/qubit/apply_operation.py
+++ b/pennylane/devices/qubit/apply_operation.py
@@ -330,8 +330,10 @@ def binomial_fn(n, p):
# to reset enables jax.jit and prevents it from using Python callbacks
element = op.reset and sample == 1
matrix = qml.math.array(
- [[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]], like=interface
- ).astype(float)
+ [[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]],
+ like=interface,
+ dtype=float,
+ )
state = apply_operation(
qml.QubitUnitary(matrix, wire), state, is_state_batched=is_state_batched, debugger=debugger
)
diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py
index f1929cd83db..7218449c625 100644
--- a/pennylane/devices/qubit/simulate.py
+++ b/pennylane/devices/qubit/simulate.py
@@ -287,7 +287,7 @@ def simulate(
trainable_params=circuit.trainable_params,
)
keys = jax_random_split(prng_key, num=circuit.shots.total_shots)
- if qml.math.get_deep_interface(circuit.data) == "jax":
+ if qml.math.get_deep_interface(circuit.data) == "jax" and prng_key is not None:
# pylint: disable=import-outside-toplevel
import jax
diff --git a/pennylane/math/single_dispatch.py b/pennylane/math/single_dispatch.py
index 900d90d98dc..7f69ee82b11 100644
--- a/pennylane/math/single_dispatch.py
+++ b/pennylane/math/single_dispatch.py
@@ -242,6 +242,7 @@ def _take_autograd(tensor, indices, axis=None):
ar.autoray._SUBMODULE_ALIASES["tensorflow", "isclose"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "atleast_1d"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "all"] = "tensorflow.experimental.numpy"
+ar.autoray._SUBMODULE_ALIASES["tensorflow", "ravel"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "vstack"] = "tensorflow.experimental.numpy"
tf_fft_functions = [
diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py
index 109f43db30b..998cefa93c5 100644
--- a/pennylane/transforms/dynamic_one_shot.py
+++ b/pennylane/transforms/dynamic_one_shot.py
@@ -228,6 +228,7 @@ def measurement_with_no_shots(measurement):
)
interface = qml.math.get_deep_interface(circuit.data)
+ interface = "numpy" if interface == "builtins" else interface
all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)]
n_mcms = len(all_mcms)
@@ -243,10 +244,13 @@ def measurement_with_no_shots(measurement):
mcm_samples = qml.math.array(
[[res] if single_measurement else res[-n_mcms::] for res in results], like=interface
)
- has_postselect = qml.math.array([op.postselect is not None for op in all_mcms]).reshape((1, -1))
+ # Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1
+ has_postselect = qml.math.array(
+ [[int(op.postselect is not None) for op in all_mcms]], like=interface
+ )
postselect = qml.math.array(
- [0 if op.postselect is None else op.postselect for op in all_mcms]
- ).reshape((1, -1))
+ [[0 if op.postselect is None else op.postselect for op in all_mcms]], like=interface
+ )
is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1)
has_valid = qml.math.any(is_valid)
mid_meas = [op for op in circuit.operations if is_mcm(op)]
@@ -268,7 +272,12 @@ def measurement_with_no_shots(measurement):
meas = measurement_with_no_shots(m)
m_count += 1
else:
- result = qml.math.array([res[m_count] for res in results], like=interface)
+ result = [res[m_count] for res in results]
+ if not isinstance(m, CountsMP):
+ # We don't need to cast to arrays when using qml.counts. qml.math.array is not viable
+ # as it assumes all elements of the input are of builtin python types and not belonging
+ # to any particular interface
+ result = qml.math.stack(result, like=interface)
meas = gather_non_mcm(m, result, is_valid)
m_count += 1
if isinstance(m, SampleMP):
@@ -292,7 +301,9 @@ def gather_non_mcm(circuit_measurement, measurement, is_valid):
if isinstance(circuit_measurement, CountsMP):
tmp = Counter()
for i, d in enumerate(measurement):
- tmp.update(dict((k, v * is_valid[i]) for k, v in d.items()))
+ tmp.update(
+ dict((k if isinstance(k, str) else float(k), v * is_valid[i]) for k, v in d.items())
+ )
tmp = Counter({k: v for k, v in tmp.items() if v > 0})
return dict(sorted(tmp.items()))
if isinstance(circuit_measurement, ExpectationMP):
@@ -341,14 +352,13 @@ def gather_mcm(measurement, samples, is_valid):
counts = qml.math.array(counts, like=interface)
return counts / qml.math.sum(counts)
if isinstance(measurement, CountsMP):
- mcm_samples = [{"".join(str(v) for v in tuple(s)): 1} for s in mcm_samples]
+ mcm_samples = [{"".join(str(int(v)) for v in tuple(s)): 1} for s in mcm_samples]
return gather_non_mcm(measurement, mcm_samples, is_valid)
+ mcm_samples = qml.math.ravel(qml.math.array(mv.concretize(samples), like=interface))
if isinstance(measurement, ProbabilityMP):
- mcm_samples = qml.math.array(mv.concretize(samples), like=interface).ravel()
counts = [qml.math.sum((mcm_samples == v) * is_valid) for v in list(mv.branches.values())]
counts = qml.math.array(counts, like=interface)
return counts / qml.math.sum(counts)
- mcm_samples = qml.math.array([mv.concretize(samples)], like=interface).ravel()
if isinstance(measurement, CountsMP):
- mcm_samples = [{s: 1} for s in mcm_samples]
+ mcm_samples = [{float(s): 1} for s in mcm_samples]
return gather_non_mcm(measurement, mcm_samples, is_valid)
diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
index 4848ea4f6e3..e2184874cec 100644
--- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py
+++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for default qubit preprocessing."""
-from functools import partial, reduce
+from functools import reduce
from typing import Iterable, Sequence
import numpy as np
@@ -24,7 +24,11 @@
pytestmark = pytest.mark.slow
-get_device = partial(qml.device, name="default.qubit", seed=8237945)
+
+def get_device(**kwargs):
+ kwargs.setdefault("shots", None)
+ kwargs.setdefault("seed", 8237945)
+ return qml.device("default.qubit", **kwargs)
def validate_counts(shots, results1, results2, batch_size=None):
@@ -88,7 +92,7 @@ def validate_samples(shots, results1, results2, batch_size=None):
assert results1.ndim == results2.ndim
if results2.ndim > 1:
assert results1.shape[1] == results2.shape[1]
- np.allclose(np.sum(results1), np.sum(results2), atol=20, rtol=0.2)
+ np.allclose(qml.math.sum(results1), qml.math.sum(results2), atol=20, rtol=0.2)
def validate_expval(shots, results1, results2, batch_size=None):
@@ -611,7 +615,7 @@ def test_sample_with_prng_key(shots, postselect, reset):
# pylint: disable=import-outside-toplevel
from jax.random import PRNGKey
- dev = qml.device("default.qubit", shots=shots, seed=PRNGKey(678))
+ dev = get_device(shots=shots, seed=PRNGKey(678))
param = [np.pi / 4, np.pi / 3]
obs = qml.PauliZ(0) @ qml.PauliZ(1)
@@ -659,7 +663,7 @@ def test_jax_jit(diff_method, postselect, reset):
shots = 10
- dev = qml.device("default.qubit", shots=shots, seed=jax.random.PRNGKey(678))
+ dev = get_device(shots=shots, seed=jax.random.PRNGKey(678))
params = [np.pi / 2.5, np.pi / 3, -np.pi / 3.5]
obs = qml.PauliY(0)
@@ -750,3 +754,44 @@ def func(x):
results2 = func2(param)
for r1, r2 in zip(results1.keys(), results2.keys()):
assert r1 == r2
+
+
+@pytest.mark.torch
+@pytest.mark.parametrize("postselect", [None, 1])
+@pytest.mark.parametrize("diff_method", [None, "best"])
+@pytest.mark.parametrize("measure_f", [qml.probs, qml.sample, qml.expval, qml.var])
+@pytest.mark.parametrize("meas_obj", [qml.PauliZ(1), [0, 1], "composite_mcm", "mcm_list"])
+def test_torch_integration(postselect, diff_method, measure_f, meas_obj):
+ """Test that native MCM circuits are executed correctly with Torch"""
+ if measure_f in (qml.var, qml.expval) and (
+ isinstance(meas_obj, list) or meas_obj == "mcm_list"
+ ):
+ pytest.skip("Can't use wires/mcm lists with var or expval")
+
+ import torch
+
+ shots = 7000
+ dev = get_device(shots=shots, seed=123456789)
+ param = torch.tensor(np.pi / 3, dtype=torch.float64)
+
+ @qml.qnode(dev, diff_method=diff_method)
+ def func(x):
+ qml.RX(x, 0)
+ m0 = qml.measure(0)
+ qml.RX(0.5 * x, 1)
+ m1 = qml.measure(1, postselect=postselect)
+ qml.cond((m0 + m1) == 2, qml.RY)(2.0 * x, 0)
+ m2 = qml.measure(0)
+
+ mid_measure = 0.5 * m2 if meas_obj == "composite_mcm" else [m1, m2]
+ measurement_key = "wires" if isinstance(meas_obj, list) else "op"
+ measurement_value = mid_measure if isinstance(meas_obj, str) else meas_obj
+ return measure_f(**{measurement_key: measurement_value})
+
+ func1 = func
+ func2 = qml.defer_measurements(func)
+
+ results1 = func1(param)
+ results2 = func2(param)
+
+ validate_measurements(measure_f, shots, results1, results2)
diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py
index 3ca186171fe..ce9d67a429b 100644
--- a/tests/transforms/test_dynamic_one_shot.py
+++ b/tests/transforms/test_dynamic_one_shot.py
@@ -155,3 +155,168 @@ def test_len_measurements_mcms(measure, aux_measure, n_meas):
assert len(aux_tape.measurements) == n_meas + n_mcms
assert isinstance(aux_tape.measurements[0], aux_measure)
assert all(isinstance(m, SampleMP) for m in aux_tape.measurements[1:])
+
+
+def assert_results(res, shots, n_mcms):
+ """Helper to check that expected raw results of executing the transformed tape are correct"""
+ assert len(res) == shots
+ # One for the non-MeasurementValue MP, and the rest of the mid-circuit measurements
+ assert all(len(r) == n_mcms + 1 for r in res)
+ # Not validating distribution of results as device sampling unit tests already validate
+ # that samples are generated correctly.
+
+
+@pytest.mark.jax
+@pytest.mark.parametrize("measure_f", (qml.expval, qml.probs, qml.sample, qml.var))
+@pytest.mark.parametrize("shots", [20, [20, 21]])
+@pytest.mark.parametrize("n_mcms", [1, 3])
+def test_tape_results_jax(shots, n_mcms, measure_f):
+ """Test that the simulation results of a tape are correct with jax parameters"""
+ import jax
+
+ dev = qml.device("default.qubit", wires=4, shots=shots, seed=jax.random.PRNGKey(123))
+ param = jax.numpy.array(np.pi / 2)
+
+ mv = qml.measure(0)
+ mp = mv.measurements[0]
+
+ tape = qml.tape.QuantumScript(
+ [qml.RX(param, 0), mp] + [MidMeasureMP(0, id=str(i)) for i in range(n_mcms - 1)],
+ [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)],
+ shots=shots,
+ )
+
+ tapes, _ = qml.dynamic_one_shot(tape)
+ results = dev.execute(tapes)[0]
+
+ # The transformed tape never has a shot vector
+ if isinstance(shots, list):
+ shots = sum(shots)
+
+ assert_results(results, shots, n_mcms)
+
+
+@pytest.mark.jax
+@pytest.mark.parametrize(
+ "measure_f, expected1, expected2",
+ [
+ (qml.expval, 1.0, 1.0),
+ (qml.probs, [1, 0], [0, 1]),
+ (qml.sample, 1, 1),
+ (qml.var, 0.0, 0.0),
+ ],
+)
+@pytest.mark.parametrize("shots", [20, [20, 21]])
+@pytest.mark.parametrize("n_mcms", [1, 3])
+def test_jax_results_processing(shots, n_mcms, measure_f, expected1, expected2):
+ """Test that the results of tapes are processed correctly for tapes with jax parameters"""
+ import jax.numpy as jnp
+
+ mv = qml.measure(0)
+ mp = mv.measurements[0]
+
+ tape = qml.tape.QuantumScript(
+ [qml.RX(1.5, 0), mp] + [MidMeasureMP(0)] * (n_mcms - 1),
+ [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)],
+ shots=shots,
+ )
+ _, fn = qml.dynamic_one_shot(tape)
+ all_shots = sum(shots) if isinstance(shots, list) else shots
+
+ first_res = jnp.array([1.0, 0.0]) if measure_f == qml.probs else jnp.array(1.0)
+ rest = jnp.array(1, dtype=int)
+ single_shot_res = (first_res,) + (rest,) * n_mcms
+ # Raw results for each shot are (sample_for_first_measurement,) + (sample for 1st MCM, sample for 2nd MCM, ...)
+ raw_results = (single_shot_res,) * all_shots
+ raw_results = (raw_results,)
+ res = fn(raw_results)
+
+ if measure_f is qml.sample:
+ # All samples 1
+ expected1 = (
+ [[expected1] * s for s in shots] if isinstance(shots, list) else [expected1] * shots
+ )
+ expected2 = (
+ [[expected2] * s for s in shots] if isinstance(shots, list) else [expected2] * shots
+ )
+ else:
+ expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1
+ expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2
+
+ if isinstance(shots, list):
+ assert len(res) == len(shots)
+ for r, e1, e2 in zip(res, expected1, expected2):
+ # Expected result is 2-list since we have two measurements in the tape
+ assert qml.math.allclose(r, [e1, e2])
+ else:
+ # Expected result is 2-list since we have two measurements in the tape
+ assert qml.math.allclose(res, [expected1, expected2])
+
+
+@pytest.mark.jax
+@pytest.mark.parametrize(
+ "measure_f, expected1, expected2",
+ [
+ (qml.expval, 1.0, 1.0),
+ (qml.probs, [1, 0], [0, 1]),
+ (qml.sample, 1, 1),
+ (qml.var, 0.0, 0.0),
+ ],
+)
+@pytest.mark.parametrize("shots", [20, [20, 22]])
+def test_jax_results_postselection_processing(shots, measure_f, expected1, expected2):
+ """Test that the results of tapes are processed correctly for tapes with jax parameters
+ when postselecting"""
+ import jax.numpy as jnp
+
+ param = jnp.array(np.pi / 2)
+ fill_value = np.iinfo(np.int32).min
+ mv = qml.measure(0, postselect=1)
+ mp = mv.measurements[0]
+
+ tape = qml.tape.QuantumScript(
+ [qml.RX(param, 0), mp, MidMeasureMP(0)],
+ [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)],
+ shots=shots,
+ )
+ _, fn = qml.dynamic_one_shot(tape)
+ all_shots = sum(shots) if isinstance(shots, list) else shots
+
+ # Alternating tuple. Only the values at odd indices are valid
+ first_res_two_shot = (
+ (jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0]))
+ if measure_f == qml.probs
+ else (jnp.array(1.0), jnp.array(0.0))
+ )
+ first_res = first_res_two_shot * (all_shots // 2)
+ # Tuple of alternating 1s and 0s. Zero is invalid as postselecting on 1
+ postselect_res = (jnp.array(1, dtype=int), jnp.array(0, dtype=int)) * (all_shots // 2)
+ rest = (jnp.array(1, dtype=int),) * all_shots
+ # Raw results for each shot are (sample_for_first_measurement, sample for 1st MCM, sample for 2nd MCM)
+ raw_results = tuple(zip(first_res, postselect_res, rest))
+ raw_results = (raw_results,)
+ res = fn(raw_results)
+
+ if measure_f is qml.sample:
+ expected1 = (
+ [[expected1, fill_value] * (s // 2) for s in shots]
+ if isinstance(shots, list)
+ else [expected1, fill_value] * (shots // 2)
+ )
+ expected2 = (
+ [[expected2, fill_value] * (s // 2) for s in shots]
+ if isinstance(shots, list)
+ else [expected2, fill_value] * (shots // 2)
+ )
+ else:
+ expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1
+ expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2
+
+ if isinstance(shots, list):
+ assert len(res) == len(shots)
+ for r, e1, e2 in zip(res, expected1, expected2):
+ # Expected result is 2-list since we have two measurements in the tape
+ assert qml.math.allclose(r, [e1, e2])
+ else:
+ # Expected result is 2-list since we have two measurements in the tape
+ assert qml.math.allclose(res, [expected1, expected2])