diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3690638820b..3b03629e105 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -156,6 +156,9 @@

Bug fixes 🐛

+* The `dynamic_one_shot` transform now has expanded support for the `jax` and `torch` interfaces. + [(#5672)](https://github.com/PennyLaneAI/pennylane/pull/5672) + * The decomposition of `StronglyEntanglingLayers` is now compatible with broadcasting. [(#5716)](https://github.com/PennyLaneAI/pennylane/pull/5716) @@ -213,5 +216,6 @@ Korbinian Kottmann, Christina Lee, Vincent Michaud-Rioux, Lee James O'Riordan, +Mudit Pandey, Kenya Sakka, David Wierichs. diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index 9b42e942c74..ae0c9117c38 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -330,8 +330,10 @@ def binomial_fn(n, p): # to reset enables jax.jit and prevents it from using Python callbacks element = op.reset and sample == 1 matrix = qml.math.array( - [[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]], like=interface - ).astype(float) + [[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]], + like=interface, + dtype=float, + ) state = apply_operation( qml.QubitUnitary(matrix, wire), state, is_state_batched=is_state_batched, debugger=debugger ) diff --git a/pennylane/devices/qubit/simulate.py b/pennylane/devices/qubit/simulate.py index f1929cd83db..7218449c625 100644 --- a/pennylane/devices/qubit/simulate.py +++ b/pennylane/devices/qubit/simulate.py @@ -287,7 +287,7 @@ def simulate( trainable_params=circuit.trainable_params, ) keys = jax_random_split(prng_key, num=circuit.shots.total_shots) - if qml.math.get_deep_interface(circuit.data) == "jax": + if qml.math.get_deep_interface(circuit.data) == "jax" and prng_key is not None: # pylint: disable=import-outside-toplevel import jax diff --git a/pennylane/math/single_dispatch.py b/pennylane/math/single_dispatch.py index 900d90d98dc..7f69ee82b11 100644 --- a/pennylane/math/single_dispatch.py +++ b/pennylane/math/single_dispatch.py @@ -242,6 +242,7 @@ def _take_autograd(tensor, indices, axis=None): ar.autoray._SUBMODULE_ALIASES["tensorflow", "isclose"] = "tensorflow.experimental.numpy" ar.autoray._SUBMODULE_ALIASES["tensorflow", "atleast_1d"] = "tensorflow.experimental.numpy" ar.autoray._SUBMODULE_ALIASES["tensorflow", "all"] = "tensorflow.experimental.numpy" +ar.autoray._SUBMODULE_ALIASES["tensorflow", "ravel"] = "tensorflow.experimental.numpy" ar.autoray._SUBMODULE_ALIASES["tensorflow", "vstack"] = "tensorflow.experimental.numpy" tf_fft_functions = [ diff --git a/pennylane/transforms/dynamic_one_shot.py b/pennylane/transforms/dynamic_one_shot.py index 109f43db30b..998cefa93c5 100644 --- a/pennylane/transforms/dynamic_one_shot.py +++ b/pennylane/transforms/dynamic_one_shot.py @@ -228,6 +228,7 @@ def measurement_with_no_shots(measurement): ) interface = qml.math.get_deep_interface(circuit.data) + interface = "numpy" if interface == "builtins" else interface all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)] n_mcms = len(all_mcms) @@ -243,10 +244,13 @@ def measurement_with_no_shots(measurement): mcm_samples = qml.math.array( [[res] if single_measurement else res[-n_mcms::] for res in results], like=interface ) - has_postselect = qml.math.array([op.postselect is not None for op in all_mcms]).reshape((1, -1)) + # Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1 + has_postselect = qml.math.array( + [[int(op.postselect is not None) for op in all_mcms]], like=interface + ) postselect = qml.math.array( - [0 if op.postselect is None else op.postselect for op in all_mcms] - ).reshape((1, -1)) + [[0 if op.postselect is None else op.postselect for op in all_mcms]], like=interface + ) is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1) has_valid = qml.math.any(is_valid) mid_meas = [op for op in circuit.operations if is_mcm(op)] @@ -268,7 +272,12 @@ def measurement_with_no_shots(measurement): meas = measurement_with_no_shots(m) m_count += 1 else: - result = qml.math.array([res[m_count] for res in results], like=interface) + result = [res[m_count] for res in results] + if not isinstance(m, CountsMP): + # We don't need to cast to arrays when using qml.counts. qml.math.array is not viable + # as it assumes all elements of the input are of builtin python types and not belonging + # to any particular interface + result = qml.math.stack(result, like=interface) meas = gather_non_mcm(m, result, is_valid) m_count += 1 if isinstance(m, SampleMP): @@ -292,7 +301,9 @@ def gather_non_mcm(circuit_measurement, measurement, is_valid): if isinstance(circuit_measurement, CountsMP): tmp = Counter() for i, d in enumerate(measurement): - tmp.update(dict((k, v * is_valid[i]) for k, v in d.items())) + tmp.update( + dict((k if isinstance(k, str) else float(k), v * is_valid[i]) for k, v in d.items()) + ) tmp = Counter({k: v for k, v in tmp.items() if v > 0}) return dict(sorted(tmp.items())) if isinstance(circuit_measurement, ExpectationMP): @@ -341,14 +352,13 @@ def gather_mcm(measurement, samples, is_valid): counts = qml.math.array(counts, like=interface) return counts / qml.math.sum(counts) if isinstance(measurement, CountsMP): - mcm_samples = [{"".join(str(v) for v in tuple(s)): 1} for s in mcm_samples] + mcm_samples = [{"".join(str(int(v)) for v in tuple(s)): 1} for s in mcm_samples] return gather_non_mcm(measurement, mcm_samples, is_valid) + mcm_samples = qml.math.ravel(qml.math.array(mv.concretize(samples), like=interface)) if isinstance(measurement, ProbabilityMP): - mcm_samples = qml.math.array(mv.concretize(samples), like=interface).ravel() counts = [qml.math.sum((mcm_samples == v) * is_valid) for v in list(mv.branches.values())] counts = qml.math.array(counts, like=interface) return counts / qml.math.sum(counts) - mcm_samples = qml.math.array([mv.concretize(samples)], like=interface).ravel() if isinstance(measurement, CountsMP): - mcm_samples = [{s: 1} for s in mcm_samples] + mcm_samples = [{float(s): 1} for s in mcm_samples] return gather_non_mcm(measurement, mcm_samples, is_valid) diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index 4848ea4f6e3..e2184874cec 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for default qubit preprocessing.""" -from functools import partial, reduce +from functools import reduce from typing import Iterable, Sequence import numpy as np @@ -24,7 +24,11 @@ pytestmark = pytest.mark.slow -get_device = partial(qml.device, name="default.qubit", seed=8237945) + +def get_device(**kwargs): + kwargs.setdefault("shots", None) + kwargs.setdefault("seed", 8237945) + return qml.device("default.qubit", **kwargs) def validate_counts(shots, results1, results2, batch_size=None): @@ -88,7 +92,7 @@ def validate_samples(shots, results1, results2, batch_size=None): assert results1.ndim == results2.ndim if results2.ndim > 1: assert results1.shape[1] == results2.shape[1] - np.allclose(np.sum(results1), np.sum(results2), atol=20, rtol=0.2) + np.allclose(qml.math.sum(results1), qml.math.sum(results2), atol=20, rtol=0.2) def validate_expval(shots, results1, results2, batch_size=None): @@ -611,7 +615,7 @@ def test_sample_with_prng_key(shots, postselect, reset): # pylint: disable=import-outside-toplevel from jax.random import PRNGKey - dev = qml.device("default.qubit", shots=shots, seed=PRNGKey(678)) + dev = get_device(shots=shots, seed=PRNGKey(678)) param = [np.pi / 4, np.pi / 3] obs = qml.PauliZ(0) @ qml.PauliZ(1) @@ -659,7 +663,7 @@ def test_jax_jit(diff_method, postselect, reset): shots = 10 - dev = qml.device("default.qubit", shots=shots, seed=jax.random.PRNGKey(678)) + dev = get_device(shots=shots, seed=jax.random.PRNGKey(678)) params = [np.pi / 2.5, np.pi / 3, -np.pi / 3.5] obs = qml.PauliY(0) @@ -750,3 +754,44 @@ def func(x): results2 = func2(param) for r1, r2 in zip(results1.keys(), results2.keys()): assert r1 == r2 + + +@pytest.mark.torch +@pytest.mark.parametrize("postselect", [None, 1]) +@pytest.mark.parametrize("diff_method", [None, "best"]) +@pytest.mark.parametrize("measure_f", [qml.probs, qml.sample, qml.expval, qml.var]) +@pytest.mark.parametrize("meas_obj", [qml.PauliZ(1), [0, 1], "composite_mcm", "mcm_list"]) +def test_torch_integration(postselect, diff_method, measure_f, meas_obj): + """Test that native MCM circuits are executed correctly with Torch""" + if measure_f in (qml.var, qml.expval) and ( + isinstance(meas_obj, list) or meas_obj == "mcm_list" + ): + pytest.skip("Can't use wires/mcm lists with var or expval") + + import torch + + shots = 7000 + dev = get_device(shots=shots, seed=123456789) + param = torch.tensor(np.pi / 3, dtype=torch.float64) + + @qml.qnode(dev, diff_method=diff_method) + def func(x): + qml.RX(x, 0) + m0 = qml.measure(0) + qml.RX(0.5 * x, 1) + m1 = qml.measure(1, postselect=postselect) + qml.cond((m0 + m1) == 2, qml.RY)(2.0 * x, 0) + m2 = qml.measure(0) + + mid_measure = 0.5 * m2 if meas_obj == "composite_mcm" else [m1, m2] + measurement_key = "wires" if isinstance(meas_obj, list) else "op" + measurement_value = mid_measure if isinstance(meas_obj, str) else meas_obj + return measure_f(**{measurement_key: measurement_value}) + + func1 = func + func2 = qml.defer_measurements(func) + + results1 = func1(param) + results2 = func2(param) + + validate_measurements(measure_f, shots, results1, results2) diff --git a/tests/transforms/test_dynamic_one_shot.py b/tests/transforms/test_dynamic_one_shot.py index 3ca186171fe..ce9d67a429b 100644 --- a/tests/transforms/test_dynamic_one_shot.py +++ b/tests/transforms/test_dynamic_one_shot.py @@ -155,3 +155,168 @@ def test_len_measurements_mcms(measure, aux_measure, n_meas): assert len(aux_tape.measurements) == n_meas + n_mcms assert isinstance(aux_tape.measurements[0], aux_measure) assert all(isinstance(m, SampleMP) for m in aux_tape.measurements[1:]) + + +def assert_results(res, shots, n_mcms): + """Helper to check that expected raw results of executing the transformed tape are correct""" + assert len(res) == shots + # One for the non-MeasurementValue MP, and the rest of the mid-circuit measurements + assert all(len(r) == n_mcms + 1 for r in res) + # Not validating distribution of results as device sampling unit tests already validate + # that samples are generated correctly. + + +@pytest.mark.jax +@pytest.mark.parametrize("measure_f", (qml.expval, qml.probs, qml.sample, qml.var)) +@pytest.mark.parametrize("shots", [20, [20, 21]]) +@pytest.mark.parametrize("n_mcms", [1, 3]) +def test_tape_results_jax(shots, n_mcms, measure_f): + """Test that the simulation results of a tape are correct with jax parameters""" + import jax + + dev = qml.device("default.qubit", wires=4, shots=shots, seed=jax.random.PRNGKey(123)) + param = jax.numpy.array(np.pi / 2) + + mv = qml.measure(0) + mp = mv.measurements[0] + + tape = qml.tape.QuantumScript( + [qml.RX(param, 0), mp] + [MidMeasureMP(0, id=str(i)) for i in range(n_mcms - 1)], + [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)], + shots=shots, + ) + + tapes, _ = qml.dynamic_one_shot(tape) + results = dev.execute(tapes)[0] + + # The transformed tape never has a shot vector + if isinstance(shots, list): + shots = sum(shots) + + assert_results(results, shots, n_mcms) + + +@pytest.mark.jax +@pytest.mark.parametrize( + "measure_f, expected1, expected2", + [ + (qml.expval, 1.0, 1.0), + (qml.probs, [1, 0], [0, 1]), + (qml.sample, 1, 1), + (qml.var, 0.0, 0.0), + ], +) +@pytest.mark.parametrize("shots", [20, [20, 21]]) +@pytest.mark.parametrize("n_mcms", [1, 3]) +def test_jax_results_processing(shots, n_mcms, measure_f, expected1, expected2): + """Test that the results of tapes are processed correctly for tapes with jax parameters""" + import jax.numpy as jnp + + mv = qml.measure(0) + mp = mv.measurements[0] + + tape = qml.tape.QuantumScript( + [qml.RX(1.5, 0), mp] + [MidMeasureMP(0)] * (n_mcms - 1), + [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)], + shots=shots, + ) + _, fn = qml.dynamic_one_shot(tape) + all_shots = sum(shots) if isinstance(shots, list) else shots + + first_res = jnp.array([1.0, 0.0]) if measure_f == qml.probs else jnp.array(1.0) + rest = jnp.array(1, dtype=int) + single_shot_res = (first_res,) + (rest,) * n_mcms + # Raw results for each shot are (sample_for_first_measurement,) + (sample for 1st MCM, sample for 2nd MCM, ...) + raw_results = (single_shot_res,) * all_shots + raw_results = (raw_results,) + res = fn(raw_results) + + if measure_f is qml.sample: + # All samples 1 + expected1 = ( + [[expected1] * s for s in shots] if isinstance(shots, list) else [expected1] * shots + ) + expected2 = ( + [[expected2] * s for s in shots] if isinstance(shots, list) else [expected2] * shots + ) + else: + expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1 + expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2 + + if isinstance(shots, list): + assert len(res) == len(shots) + for r, e1, e2 in zip(res, expected1, expected2): + # Expected result is 2-list since we have two measurements in the tape + assert qml.math.allclose(r, [e1, e2]) + else: + # Expected result is 2-list since we have two measurements in the tape + assert qml.math.allclose(res, [expected1, expected2]) + + +@pytest.mark.jax +@pytest.mark.parametrize( + "measure_f, expected1, expected2", + [ + (qml.expval, 1.0, 1.0), + (qml.probs, [1, 0], [0, 1]), + (qml.sample, 1, 1), + (qml.var, 0.0, 0.0), + ], +) +@pytest.mark.parametrize("shots", [20, [20, 22]]) +def test_jax_results_postselection_processing(shots, measure_f, expected1, expected2): + """Test that the results of tapes are processed correctly for tapes with jax parameters + when postselecting""" + import jax.numpy as jnp + + param = jnp.array(np.pi / 2) + fill_value = np.iinfo(np.int32).min + mv = qml.measure(0, postselect=1) + mp = mv.measurements[0] + + tape = qml.tape.QuantumScript( + [qml.RX(param, 0), mp, MidMeasureMP(0)], + [measure_f(op=qml.PauliZ(0)), measure_f(op=mv)], + shots=shots, + ) + _, fn = qml.dynamic_one_shot(tape) + all_shots = sum(shots) if isinstance(shots, list) else shots + + # Alternating tuple. Only the values at odd indices are valid + first_res_two_shot = ( + (jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])) + if measure_f == qml.probs + else (jnp.array(1.0), jnp.array(0.0)) + ) + first_res = first_res_two_shot * (all_shots // 2) + # Tuple of alternating 1s and 0s. Zero is invalid as postselecting on 1 + postselect_res = (jnp.array(1, dtype=int), jnp.array(0, dtype=int)) * (all_shots // 2) + rest = (jnp.array(1, dtype=int),) * all_shots + # Raw results for each shot are (sample_for_first_measurement, sample for 1st MCM, sample for 2nd MCM) + raw_results = tuple(zip(first_res, postselect_res, rest)) + raw_results = (raw_results,) + res = fn(raw_results) + + if measure_f is qml.sample: + expected1 = ( + [[expected1, fill_value] * (s // 2) for s in shots] + if isinstance(shots, list) + else [expected1, fill_value] * (shots // 2) + ) + expected2 = ( + [[expected2, fill_value] * (s // 2) for s in shots] + if isinstance(shots, list) + else [expected2, fill_value] * (shots // 2) + ) + else: + expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1 + expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2 + + if isinstance(shots, list): + assert len(res) == len(shots) + for r, e1, e2 in zip(res, expected1, expected2): + # Expected result is 2-list since we have two measurements in the tape + assert qml.math.allclose(r, [e1, e2]) + else: + # Expected result is 2-list since we have two measurements in the tape + assert qml.math.allclose(res, [expected1, expected2])