diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 2029528ae4..088b488492 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -2,6 +2,9 @@ ### New features since last release +* Add shot measurement support to `lightning.tensor`. + [(#852)](https://github.com/PennyLaneAI/pennylane-lightning/pull/852) + * Build and upload Lightning-Tensor wheels (x86_64, AARCH64) to PyPI. [(#862)](https://github.com/PennyLaneAI/pennylane-lightning/pull/862) diff --git a/pennylane_lightning/core/_version.py b/pennylane_lightning/core/_version.py index 79c3d48852..212f402ce0 100644 --- a/pennylane_lightning/core/_version.py +++ b/pennylane_lightning/core/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.39.0-dev17" +__version__ = "0.39.0-dev18" diff --git a/pennylane_lightning/core/src/bindings/Bindings.hpp b/pennylane_lightning/core/src/bindings/Bindings.hpp index 30a0eedde8..fc6e3e60a9 100644 --- a/pennylane_lightning/core/src/bindings/Bindings.hpp +++ b/pennylane_lightning/core/src/bindings/Bindings.hpp @@ -746,7 +746,27 @@ void registerLightningTensorBackendAgnosticMeasurements(PyClass &pyclass) { [](MeasurementsT &M, const std::shared_ptr &ob) { return M.var(*ob); }, - "Variance of an observable object."); + "Variance of an observable object.") + .def("generate_samples", [](MeasurementsT &M, + const std::vector &wires, + const std::size_t num_shots) { + constexpr auto sz = sizeof(std::size_t); + const std::size_t num_wires = wires.size(); + const std::size_t ndim = 2; + const std::vector shape{num_shots, num_wires}; + auto &&result = M.generate_samples(wires, num_shots); + + const std::vector strides{sz * num_wires, sz}; + // return 2-D NumPy array + return py::array(py::buffer_info( + result.data(), /* data as contiguous array */ + sz, /* size of one scalar */ + py::format_descriptor::format(), /* data type */ + ndim, /* number of dimensions */ + shape, /* shape of the matrix */ + strides /* strides for each axis */ + )); + }); } /** diff --git a/pennylane_lightning/core/src/simulators/lightning_tensor/tncuda/measurements/MeasurementsTNCuda.hpp b/pennylane_lightning/core/src/simulators/lightning_tensor/tncuda/measurements/MeasurementsTNCuda.hpp index c64203941d..34c22995f5 100644 --- a/pennylane_lightning/core/src/simulators/lightning_tensor/tncuda/measurements/MeasurementsTNCuda.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_tensor/tncuda/measurements/MeasurementsTNCuda.hpp @@ -20,6 +20,7 @@ #pragma once +#include #include #include #include @@ -164,6 +165,97 @@ template class MeasurementsTNCuda { return h_res; } + /** + * @brief Utility method for samples. + * + * @param wires Wires can be a subset or the full system. + * @param num_samples Number of samples + * @param numHyperSamples Number of hyper samples to use in the calculation + * and is default as 1. + * + * @return std::vector A 1-d array storing the samples. + * Each sample has a length equal to the number of wires. Each sample can + * be accessed using the stride `sample_id * num_wires`, where `sample_id` + * is a number between `0` and `num_samples - 1`. + */ + auto generate_samples(const std::vector &wires, + const std::size_t num_samples, + const int32_t numHyperSamples = 1) + -> std::vector { + std::vector samples(num_samples * wires.size()); + + const std::vector modesToSample = + cuUtil::NormalizeCastIndices( + wires, tensor_network_.getNumQubits()); + + cutensornetStateSampler_t sampler; + + PL_CUTENSORNET_IS_SUCCESS(cutensornetCreateSampler( + /* const cutensornetHandle_t */ tensor_network_.getTNCudaHandle(), + /* cutensornetState_t */ tensor_network_.getQuantumState(), + /* int32_t numModesToSample */ modesToSample.size(), + /* const int32_t *modesToSample */ modesToSample.data(), + /* cutensornetStateSampler_t * */ &sampler)); + + // Configure the quantum circuit sampler + const cutensornetSamplerAttributes_t samplerAttributes = + CUTENSORNET_SAMPLER_CONFIG_NUM_HYPER_SAMPLES; + + PL_CUTENSORNET_IS_SUCCESS(cutensornetSamplerConfigure( + /* const cutensornetHandle_t */ tensor_network_.getTNCudaHandle(), + /* cutensornetStateSampler_t */ sampler, + /* cutensornetSamplerAttributes_t */ samplerAttributes, + /* const void *attributeValue */ &numHyperSamples, + /* size_t attributeSize */ sizeof(numHyperSamples))); + + cutensornetWorkspaceDescriptor_t workDesc; + PL_CUTENSORNET_IS_SUCCESS(cutensornetCreateWorkspaceDescriptor( + /* const cutensornetHandle_t */ tensor_network_.getTNCudaHandle(), + /* cutensornetWorkspaceDescriptor_t * */ &workDesc)); + + const std::size_t scratchSize = cuUtil::getFreeMemorySize() / 2; + + // Prepare the quantum circuit sampler for sampling + PL_CUTENSORNET_IS_SUCCESS(cutensornetSamplerPrepare( + /* const cutensornetHandle_t */ tensor_network_.getTNCudaHandle(), + /* cutensornetStateSampler_t */ sampler, + /* size_t maxWorkspaceSizeDevice */ scratchSize, + /* cutensornetWorkspaceDescriptor_t */ workDesc, + /* cudaStream_t unused as of v24.08 */ 0x0)); + + std::size_t worksize = + getWorkSpaceMemorySize(tensor_network_.getTNCudaHandle(), workDesc); + + PL_ABORT_IF(worksize > scratchSize, + "Insufficient workspace size on Device.\n"); + + const std::size_t d_scratch_length = worksize / sizeof(size_t) + 1; + DataBuffer d_scratch(d_scratch_length, + tensor_network_.getDevTag(), true); + + setWorkSpaceMemory(tensor_network_.getTNCudaHandle(), workDesc, + reinterpret_cast(d_scratch.getData()), + worksize); + + PL_CUTENSORNET_IS_SUCCESS(cutensornetSamplerSample( + /* const cutensornetHandle_t */ tensor_network_.getTNCudaHandle(), + /* cutensornetStateSampler_t */ sampler, + /* int64_t numShots */ num_samples, + /* cutensornetWorkspaceDescriptor_t */ workDesc, + /* int64_t * */ samples.data(), + /* cudaStream_t unused as of v24.08 */ 0x0)); + + PL_CUTENSORNET_IS_SUCCESS( + cutensornetDestroyWorkspaceDescriptor(workDesc)); + PL_CUTENSORNET_IS_SUCCESS(cutensornetDestroySampler(sampler)); + + std::vector samples_size_t(samples.size()); + + std::transform(samples.begin(), samples.end(), samples_size_t.begin(), + [](int64_t x) { return static_cast(x); }); + return samples_size_t; + } + /** * @brief Calculate var value for a general ObservableTNCuda Observable. * diff --git a/pennylane_lightning/core/src/simulators/lightning_tensor/tncuda/measurements/tests/Test_MPSTNCuda_Measure.cpp b/pennylane_lightning/core/src/simulators/lightning_tensor/tncuda/measurements/tests/Test_MPSTNCuda_Measure.cpp index 5e89426d49..74923cf87c 100644 --- a/pennylane_lightning/core/src/simulators/lightning_tensor/tncuda/measurements/tests/Test_MPSTNCuda_Measure.cpp +++ b/pennylane_lightning/core/src/simulators/lightning_tensor/tncuda/measurements/tests/Test_MPSTNCuda_Measure.cpp @@ -26,6 +26,7 @@ #include "MPSTNCuda.hpp" #include "MeasurementsTNCuda.hpp" #include "TNCudaGateCache.hpp" +#include "TestHelpers.hpp" #include "cuda_helpers.hpp" /// @cond DEV @@ -33,6 +34,7 @@ namespace { using namespace Pennylane::LightningTensor::TNCuda::Measures; using namespace Pennylane::LightningTensor::TNCuda::Observables; using namespace Pennylane::LightningTensor::TNCuda; +using namespace Pennylane::Util; } // namespace /// @endcond @@ -92,3 +94,47 @@ TEMPLATE_TEST_CASE("Probabilities", "[Measures]", float, double) { REQUIRE_THROWS_AS(measure.probs({2, 1}), LightningException); } } + +TEMPLATE_TEST_CASE("Samples", "[Measures]", float, double) { + using TensorNetT = MPSTNCuda; + + SECTION("Looping over different wire configurations:") { + // Probabilities calculated with Pennylane default.qubit: + std::vector expected_probabilities = { + 0.67078706, 0.03062806, 0.0870997, 0.00397696, + 0.17564072, 0.00801973, 0.02280642, 0.00104134}; + + // Defining the State Vector that will be measured. + std::size_t bondDim = GENERATE(4, 5); + std::size_t num_qubits = 3; + std::size_t maxBondDim = bondDim; + + TensorNetT mps_state{num_qubits, maxBondDim}; + + mps_state.applyOperations( + {{"RX"}, {"RX"}, {"RY"}, {"RY"}, {"RX"}, {"RY"}}, + {{0}, {0}, {1}, {1}, {2}, {2}}, + {{false}, {false}, {false}, {false}, {false}, {false}}, + {{0.5}, {0.5}, {0.2}, {0.2}, {0.5}, {0.5}}); + mps_state.append_mps_final_state(); + + auto measure = MeasurementsTNCuda(mps_state); + + std::size_t num_samples = 100000; + const std::vector wires = {0, 1, 2}; + auto samples = measure.generate_samples(wires, num_samples); + auto counts = samples_to_decimal(samples, num_qubits, num_samples); + + // compute estimated probabilities from histogram + std::vector probabilities(counts.size()); + for (std::size_t i = 0; i < counts.size(); i++) { + probabilities[i] = counts[i] / static_cast(num_samples); + } + + // compare estimated probabilities to real probabilities + SECTION("No wires provided:") { + REQUIRE_THAT(probabilities, + Catch::Approx(expected_probabilities).margin(.1)); + } + } +} diff --git a/pennylane_lightning/core/src/utils/TestHelpers.hpp b/pennylane_lightning/core/src/utils/TestHelpers.hpp index 556ab24035..c4ef596f17 100644 --- a/pennylane_lightning/core/src/utils/TestHelpers.hpp +++ b/pennylane_lightning/core/src/utils/TestHelpers.hpp @@ -598,6 +598,35 @@ auto randomUnitary(RandomEngine &re, std::size_t num_qubits) return res; } +inline auto samples_to_decimal(const std::vector &samples, + const std::size_t num_qubits, + const std::size_t num_samples) + -> std::vector { + constexpr uint32_t twos[] = { + 1U << 0U, 1U << 1U, 1U << 2U, 1U << 3U, 1U << 4U, 1U << 5U, + 1U << 6U, 1U << 7U, 1U << 8U, 1U << 9U, 1U << 10U, 1U << 11U, + 1U << 12U, 1U << 13U, 1U << 14U, 1U << 15U, 1U << 16U, 1U << 17U, + 1U << 18U, 1U << 19U, 1U << 20U, 1U << 21U, 1U << 22U, 1U << 23U, + 1U << 24U, 1U << 25U, 1U << 26U, 1U << 27U, 1U << 28U, 1U << 29U, + 1U << 30U, 1U << 31U}; + + std::size_t N = std::pow(2, num_qubits); + std::vector counts(N, 0); + std::vector samples_decimal(num_samples, 0); + + // convert samples to decimal and then bin them in counts + for (std::size_t i = 0; i < num_samples; i++) { + for (std::size_t j = 0; j < num_qubits; j++) { + if (samples[i * num_qubits + j] != 0) { + samples_decimal[i] += twos[num_qubits - 1 - j]; + } + } + counts[samples_decimal[i]] += 1; + } + + return counts; +} + #define PL_REQUIRE_THROWS_MATCHES(expr, type, message_match) \ REQUIRE_THROWS_AS(expr, type); \ REQUIRE_THROWS_WITH(expr, Catch::Matchers::Contains(message_match)); diff --git a/pennylane_lightning/core/src/utils/cuda_utils/tests/CMakeLists.txt b/pennylane_lightning/core/src/utils/cuda_utils/tests/CMakeLists.txt index b72365a315..c8353963c3 100644 --- a/pennylane_lightning/core/src/utils/cuda_utils/tests/CMakeLists.txt +++ b/pennylane_lightning/core/src/utils/cuda_utils/tests/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.20) -project(lightning_gpu_utils_tests) +project(cuda_utils_tests) # Default build type for test code is Debug if(NOT CMAKE_BUILD_TYPE) diff --git a/pennylane_lightning/lightning_tensor/_measurements.py b/pennylane_lightning/lightning_tensor/_measurements.py index 93ab53e506..c2d5d93d4d 100644 --- a/pennylane_lightning/lightning_tensor/_measurements.py +++ b/pennylane_lightning/lightning_tensor/_measurements.py @@ -21,17 +21,25 @@ except ImportError: pass -from typing import Callable +from functools import reduce +from typing import Callable, List, Union import numpy as np import pennylane as qml +from pennylane.devices.qubit.sampling import _group_measurements from pennylane.measurements import ( + ClassicalShadowMP, + CountsMP, ExpectationMP, MeasurementProcess, ProbabilityMP, + SampleMeasurement, + ShadowExpvalMP, + Shots, StateMeasurement, VarianceMP, ) +from pennylane.ops import Hamiltonian, SparseHamiltonian, Sum from pennylane.tape import QuantumScript from pennylane.typing import Result, TensorLike from pennylane.wires import Wires @@ -169,9 +177,13 @@ def get_measurement_function( """ if isinstance(measurementprocess, StateMeasurement): if isinstance(measurementprocess, ExpectationMP): + if isinstance(measurementprocess.obs, qml.Identity): + return self.state_diagonalizing_gates return self.expval if isinstance(measurementprocess, VarianceMP): + if isinstance(measurementprocess.obs, qml.Identity): + return self.state_diagonalizing_gates return self.var if isinstance(measurementprocess, ProbabilityMP): @@ -207,9 +219,167 @@ def measure_tensor_network(self, circuit: QuantumScript) -> Result: """ if circuit.shots: - raise NotImplementedError("Shots are not supported for tensor network simulations.") + # finite-shot case + results = self.measure_with_samples( + circuit.measurements, + shots=circuit.shots, + ) + + if len(circuit.measurements) == 1: + if circuit.shots.has_partitioned_shots: + return tuple(res[0] for res in results) + + return results[0] + + return results # analytic case if len(circuit.measurements) == 1: return self.measurement(circuit.measurements[0]) return tuple(self.measurement(mp) for mp in circuit.measurements) + + # pylint:disable = too-many-arguments + def measure_with_samples( + self, + measurements: List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]], + shots: Shots, + ) -> List[TensorLike]: + """ + Returns the samples of the measurement process performed on the given state. + This function assumes that the user-defined wire labels in the measurement process + have already been mapped to integer wires used in the device. + + Args: + measurements (List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]]): + The sample measurements to perform + shots (Shots): The number of samples to take + + Returns: + List[TensorLike[Any]]: Sample measurement results + """ + mps = measurements + groups, indices = _group_measurements(mps) + + all_res = [] + for group in groups: + if isinstance(group[0], (ExpectationMP, VarianceMP)) and isinstance( + group[0].obs, SparseHamiltonian + ): + raise TypeError( + "ExpectationMP/VarianceMP(SparseHamiltonian) cannot be computed with samples." + ) + if isinstance(group[0], VarianceMP) and isinstance(group[0].obs, (Hamiltonian, Sum)): + raise TypeError("VarianceMP(Hamiltonian/Sum) cannot be computed with samples.") + if isinstance(group[0], (ClassicalShadowMP, ShadowExpvalMP)): + raise TypeError( + "ExpectationMP(ClassicalShadowMP, ShadowExpvalMP) cannot be computed with samples." + ) + if isinstance(group[0], ExpectationMP) and isinstance(group[0].obs, Sum): + all_res.extend(self._measure_sum_with_samples(group, shots)) + else: + all_res.extend(self._measure_with_samples_diagonalizing_gates(group, shots)) + + # reorder results + flat_indices = [] + for row in indices: + flat_indices += row + sorted_res = tuple( + res for _, res in sorted(list(enumerate(all_res)), key=lambda r: flat_indices[r[0]]) + ) + + # put the shot vector axis before the measurement axis + if shots.has_partitioned_shots: + sorted_res = tuple(zip(*sorted_res)) + + return sorted_res + + def _apply_diagonalizing_gates(self, mps: List[SampleMeasurement], adjoint: bool = False): + if len(mps) == 1: + diagonalizing_gates = mps[0].diagonalizing_gates() + elif all(mp.obs for mp in mps): + diagonalizing_gates = qml.pauli.diagonalize_qwc_pauli_words([mp.obs for mp in mps])[0] + else: + diagonalizing_gates = [] + + if adjoint: + diagonalizing_gates = [ + qml.adjoint(g, lazy=False) for g in reversed(diagonalizing_gates) + ] + + self._tensornet.apply_operations(diagonalizing_gates) + self._tensornet.appendMPSFinalState() + + def _measure_with_samples_diagonalizing_gates( + self, + mps: List[SampleMeasurement], + shots: Shots, + ) -> TensorLike: + """ + Returns the samples of the measurement process performed on the given state, + by rotating the state into the measurement basis using the diagonalizing gates + given by the measurement process. + + Args: + mps (~.measurements.SampleMeasurement): The sample measurements to perform + shots (~.measurements.Shots): The number of samples to take + + Returns: + TensorLike[Any]: Sample measurement results + """ + # apply diagonalizing gates + self._apply_diagonalizing_gates(mps) + + wires = reduce(sum, (mp.wires for mp in mps)) + + def _process_single_shot(samples): + processed = [] + for mp in mps: + res = mp.process_samples(samples, wires) + if not isinstance(mp, CountsMP): + res = qml.math.squeeze(res) + + processed.append(res) + + return tuple(processed) + + try: + samples = self._measurement_lightning.generate_samples( + list(wires), shots.total_shots + ).astype(int, copy=False) + except ValueError as e: + if str(e) != "probabilities contain NaN": + raise e + samples = qml.math.full((shots.total_shots, len(wires)), 0) + + self._apply_diagonalizing_gates(mps, adjoint=True) + + # if there is a shot vector, use the shots.bins generator to + # split samples w.r.t. the shots + processed_samples = [] + for lower, upper in shots.bins(): + result = _process_single_shot(samples[..., lower:upper, :]) + processed_samples.append(result) + + return ( + tuple(zip(*processed_samples)) if shots.has_partitioned_shots else processed_samples[0] + ) + + def _measure_sum_with_samples( + self, + mp: List[SampleMeasurement], + shots: Shots, + ): + # the list contains only one element based on how we group measurements + mp = mp[0] + + # if the measurement process involves a Sum, measure each + # of the terms separately and sum + def _sum_for_single_shot(s): + results = self.measure_with_samples( + [ExpectationMP(t) for t in mp.obs], + s, + ) + return sum(results) + + unsqueezed_results = tuple(_sum_for_single_shot(type(shots)(s)) for s in shots) + return [unsqueezed_results] if shots.has_partitioned_shots else [unsqueezed_results[0]] diff --git a/pennylane_lightning/lightning_tensor/_tensornet.py b/pennylane_lightning/lightning_tensor/_tensornet.py index 505d0f850d..31a64c9ace 100644 --- a/pennylane_lightning/lightning_tensor/_tensornet.py +++ b/pennylane_lightning/lightning_tensor/_tensornet.py @@ -316,6 +316,7 @@ def set_tensor_network(self, circuit: QuantumScript): """ self.apply_operations(circuit.operations) self.appendMPSFinalState() + return self def appendMPSFinalState(self): """ diff --git a/pennylane_lightning/lightning_tensor/lightning_tensor.py b/pennylane_lightning/lightning_tensor/lightning_tensor.py index 4920d1f1da..01ec65c5cb 100644 --- a/pennylane_lightning/lightning_tensor/lightning_tensor.py +++ b/pennylane_lightning/lightning_tensor/lightning_tensor.py @@ -198,6 +198,9 @@ class LightningTensor(Device): Args: wires (int): The number of wires to initialize the device with. Defaults to ``None`` if not specified. + shots (int): Measurements are performed drawing ``shots`` times from a discrete random variable distribution associated with a state vector and an observable. Defaults to ``None`` if not specified. Setting + to ``None`` results in computing statistics like expectation values and + variances analytically. method (str): Supported method. Currently, only ``mps`` is supported. c_dtype: Datatypes for the tensor representation. Must be one of ``numpy.complex64`` or ``numpy.complex128``. Default is ``numpy.complex128``. @@ -253,6 +256,7 @@ def __init__( self, *, wires=None, + shots=None, method: str = "mps", c_dtype=np.complex128, **kwargs, @@ -269,7 +273,7 @@ def __init__( if wires is None: raise ValueError("The number of wires must be specified.") - super().__init__(wires=wires, shots=None) + super().__init__(wires=wires, shots=shots) if isinstance(wires, int): self._wire_map = None # should just use wires as is @@ -372,7 +376,6 @@ def preprocess( This device currently: - * Does not support finite shots. * Does not support derivatives. * Does not support vector-Jacobian products. """ @@ -387,6 +390,7 @@ def preprocess( program.add_transform( decompose, stopping_condition=stopping_condition, + stopping_condition_shots=stopping_condition, skip_initial_state_prep=True, name=self.name, ) diff --git a/tests/conftest.py b/tests/conftest.py index b5ddf416ce..a648418465 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -152,10 +152,14 @@ def get_device(): from pennylane_lightning.lightning_gpu_ops import LightningException elif device_name == "lightning.tensor": from pennylane_lightning.lightning_tensor import LightningTensor as LightningDevice + from pennylane_lightning.lightning_tensor._measurements import ( + LightningTensorMeasurements as LightningMeasurements, + ) + from pennylane_lightning.lightning_tensor._tensornet import ( + LightningTensorNet as LightningStateVector, + ) LightningAdjointJacobian = None - LightningMeasurements = None - LightningStateVector = None if hasattr(pennylane_lightning, "lightning_tensor_ops"): import pennylane_lightning.lightning_tensor_ops as lightning_ops @@ -178,8 +182,6 @@ def get_device(): ) def qubit_device(request): def _device(wires, shots=None): - if device_name == "lightning.tensor": - return qml.device(device_name, wires=wires, c_dtype=request.param) return qml.device(device_name, wires=wires, shots=shots, c_dtype=request.param) return _device @@ -192,6 +194,8 @@ def _device(wires, shots=None): ) def lightning_sv(request): def _statevector(num_wires): + if device_name == "lightning.tensor": + return LightningStateVector(num_wires=num_wires, c_dtype=request.param) return LightningStateVector(num_wires=num_wires, dtype=request.param) return _statevector diff --git a/tests/lightning_qubit/test_measurements_class.py b/tests/lightning_qubit/test_measurements_class.py index 4bb6eaf9a1..2567366393 100644 --- a/tests/lightning_qubit/test_measurements_class.py +++ b/tests/lightning_qubit/test_measurements_class.py @@ -38,8 +38,6 @@ allow_module_level=True, ) -if device_name == "lightning.tensor": - pytest.skip("Skipping tests for the LightningTensor class.", allow_module_level=True) if not LightningDevice._CPP_BINARY_AVAILABLE: pytest.skip("No binary module found. Skipping.", allow_module_level=True) @@ -61,13 +59,63 @@ def process_state(self, state, wire_order): return 1 +# Observables not supported in lightning.tensor +def obs_not_supported_in_ltensor(obs): + if device_name == "lightning.tensor": + if isinstance(obs, qml.Projector) or isinstance(obs, qml.SparseHamiltonian): + return True + if isinstance(obs, qml.Hamiltonian): + return any([obs_not_supported_in_ltensor(o) for o in obs]) + if isinstance(obs, qml.Hermitian) and len(obs.wires) > 1: + return True + if isinstance(obs, list) and all([isinstance(o, int) for o in obs]): # out of order probs + return obs != sorted(obs) + return False + else: + return False + + +# Ops not supported in lightning.tensor +def ops_not_supported_in_ltensor(ops): + if device_name == "lightning.tensor": + unsupported_ops = [qml.MultiRZ, qml.GlobalPhase] + if any([ops == op for op in unsupported_ops]): + return True + return False + else: + return False + + +def controlled_gate_not_supported_in_ltensor(ops): + if device_name == "lightning.tensor": + if ops.num_wires > 1: + return True + else: + return False + + +def get_final_state(statevector, tape): + if device_name == "lightning.tensor": + return statevector.set_tensor_network(tape) + return statevector.get_final_state(tape) + + +def measure_final_state(m, tape): + if device_name == "lightning.tensor": + return m.measure_tensor_network(tape) + return m.measure_final_state(tape) + + def test_initialization(lightning_sv): """Tests for the initialization of the LightningMeasurements class.""" statevector = lightning_sv(num_wires=5) m = LightningMeasurements(statevector) - assert m.qubit_state is statevector - assert m.dtype == statevector.dtype + if device_name == "lightning.tensor": + assert m.dtype == statevector.dtype + else: + assert m.qubit_state is statevector + assert m.dtype == statevector.dtype class TestGetMeasurementFunction: @@ -96,6 +144,9 @@ def test_only_support_state_measurements(self, lightning_sv): ) def test_state_diagonalizing_gates_measurements(self, lightning_sv, mp): """Test that any non-expval measurement calls the state_diagonalizing_gates method""" + if obs_not_supported_in_ltensor(mp.obs): + pytest.skip("Observable not supported in lightning.tensor.") + statevector = lightning_sv(num_wires=5) m = LightningMeasurements(statevector) @@ -117,6 +168,9 @@ def test_state_diagonalizing_gates_measurements(self, lightning_sv, mp): ) def test_expval_selected(self, lightning_sv, obs): """Test that expval is chosen for a variety of different expectation values.""" + if obs_not_supported_in_ltensor(obs): + pytest.skip("Observable not supported in lightning.tensor.") + statevector = lightning_sv(num_wires=5) m = LightningMeasurements(statevector) mp = qml.expval(obs) @@ -174,6 +228,10 @@ def test_identity_expval(self, lightning_sv, method_name): result = getattr(m, method_name)(qml.expval(qml.I(4))) assert np.allclose(result, 1.0) + @pytest.mark.skipif( + device_name == "lightning.tensor", + reason="lightning.tensor does not support a single-wire circuit.", + ) def test_basis_state_projector_expval(self, lightning_sv, method_name): """Test expectation value for a basis state projector.""" phi = 0.8 @@ -183,6 +241,10 @@ def test_basis_state_projector_expval(self, lightning_sv, method_name): result = getattr(m, method_name)(qml.expval(qml.Projector([0], wires=0))) assert qml.math.allclose(result, np.cos(phi / 2) ** 2) + @pytest.mark.skipif( + device_name == "lightning.tensor", + reason="lightning.tensor does not support a single-wire circuit.", + ) def test_state_vector_projector_expval(self, lightning_sv, method_name): """Test expectation value for a state vector projector.""" phi = -0.6 @@ -212,9 +274,9 @@ def test_identity(self, theta, phi, tol, lightning_sv): tape = qml.tape.QuantumScript(ops, measurements) statevector = lightning_sv(wires) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) expected = np.cos(theta) assert np.allclose(result, expected, tol) @@ -228,9 +290,9 @@ def test_identity_expectation(self, theta, phi, tol, lightning_sv): [qml.expval(qml.Identity(wires=[0])), qml.expval(qml.Identity(wires=[1]))], ) statevector = lightning_sv(wires) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) expected = 1.0 assert np.allclose(result, expected, tol) @@ -243,9 +305,9 @@ def test_multi_wire_identity_expectation(self, theta, phi, tol, lightning_sv): [qml.expval(qml.Identity(wires=[0, 1]))], ) statevector = lightning_sv(wires) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) expected = 1.0 assert np.allclose(result, expected, tol) @@ -291,9 +353,9 @@ def test_single_wire_observables_expectation( [qml.expval(Obs[0]), qml.expval(Obs[1])], ) statevector = lightning_sv(wires) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) expected = expected_fn(theta, phi) assert np.allclose(result, expected, tol) @@ -330,6 +392,10 @@ class TestExpvalHamiltonian: ) def test_expval_hamiltonian(self, obs, coeffs, expected, tol, lightning_sv, method_name): """Test expval with Hamiltonian""" + + if any(isinstance(o, qml.Hermitian) for o in obs) and device_name == "lightning.tensor": + pytest.skip("Hermitian with 1+ wires target not supported in lightning.tensor.") + ham = qml.Hamiltonian(coeffs, obs) statevector = lightning_sv(self.wires) @@ -341,6 +407,9 @@ def test_expval_hamiltonian(self, obs, coeffs, expected, tol, lightning_sv, meth assert np.allclose(result, expected, atol=tol, rtol=0) +@pytest.mark.skipif( + device_name == "lightning.tensor", reason="lightning.tensor does not support sparseH." +) class TestSparseExpval: """Tests for the expval function""" @@ -371,9 +440,9 @@ def test_sparse_Pauli_words(self, ham_terms, expected, tol, lightning_sv): tape = qml.tape.QuantumScript(ops, measurements) statevector = lightning_sv(self.wires) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) assert np.allclose(result, expected, tol) @@ -388,9 +457,7 @@ def calculate_reference(tape, lightning_sv): for m in tape.measurements: # NotImplementedError in DefaultQubit # We therefore validate against `qml.Hermitian` - if isinstance(m, VarianceMP) and isinstance( - m.obs, (qml.Hamiltonian, qml.SparseHamiltonian) - ): + if isinstance(m, VarianceMP) and isinstance(m.obs, (qml.SparseHamiltonian)): use_default = False new_meas.append(m.__class__(qml.Hermitian(qml.matrix(m.obs), wires=m.obs.wires))) continue @@ -404,9 +471,9 @@ def calculate_reference(tape, lightning_sv): tape = qml.tape.QuantumScript(tape.operations, new_meas) statevector = lightning_sv(tape.num_wires) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) - return m.measure_final_state(tape) + return measure_final_state(m, tape) @flaky(max_runs=5) @pytest.mark.parametrize("shots", [None, 500_000, [500_000, 500_000]]) @@ -432,6 +499,9 @@ def calculate_reference(tape, lightning_sv): ), ) def test_single_return_value(self, shots, measurement, observable, lightning_sv, tol): + if obs_not_supported_in_ltensor(observable): + pytest.skip("Observable not supported in lightning.tensor.") + if measurement is qml.probs and isinstance( observable, ( @@ -461,7 +531,8 @@ def test_single_return_value(self, shots, measurement, observable, lightning_sv, np.random.seed(0) weights = np.random.rand(n_layers, n_qubits, 3) ops = [qml.Hadamard(i) for i in range(n_qubits)] - ops += [qml.StronglyEntanglingLayers(weights, wires=range(n_qubits))] + if device_name != "lightning.tensor": + ops += [qml.StronglyEntanglingLayers(weights, wires=range(n_qubits))] measurements = ( [measurement(wires=observable)] if isinstance(observable, list) @@ -470,7 +541,7 @@ def test_single_return_value(self, shots, measurement, observable, lightning_sv, tape = qml.tape.QuantumScript(ops, measurements, shots=shots) statevector = lightning_sv(n_qubits) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) skip_list = ( @@ -484,10 +555,10 @@ def test_single_return_value(self, shots, measurement, observable, lightning_sv, do_skip = do_skip and shots is not None if do_skip: with pytest.raises(TypeError): - _ = m.measure_final_state(tape) + _ = measure_final_state(m, tape) return else: - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) expected = self.calculate_reference(tape, lightning_sv) @@ -538,6 +609,9 @@ def test_single_return_value(self, shots, measurement, observable, lightning_sv, ), ) def test_double_return_value(self, shots, measurement, obs0_, obs1_, lightning_sv, tol): + if obs_not_supported_in_ltensor(obs0_) or obs_not_supported_in_ltensor(obs1_): + pytest.skip("Observable not supported in lightning.tensor.") + skip_list = ( qml.ops.Sum, qml.ops.SProd, @@ -564,12 +638,13 @@ def test_double_return_value(self, shots, measurement, obs0_, obs1_, lightning_s np.random.seed(0) weights = np.random.rand(n_layers, n_qubits, 3) ops = [qml.Hadamard(i) for i in range(n_qubits)] - ops += [qml.StronglyEntanglingLayers(weights, wires=range(n_qubits))] + if device_name != "lightning.tensor": + ops += [qml.StronglyEntanglingLayers(weights, wires=range(n_qubits))] measurements = [measurement(op=obs0_), measurement(op=obs1_)] tape = qml.tape.QuantumScript(ops, measurements, shots=shots) statevector = lightning_sv(n_qubits) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) skip_list = ( @@ -589,10 +664,10 @@ def test_double_return_value(self, shots, measurement, obs0_, obs1_, lightning_s do_skip = do_skip and shots is not None if do_skip: with pytest.raises(TypeError): - _ = m.measure_final_state(tape) + _ = measure_final_state(m, tape) return else: - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) expected = self.calculate_reference(tape, lightning_sv) if len(expected) == 1: @@ -610,6 +685,10 @@ def test_double_return_value(self, shots, measurement, obs0_, obs1_, lightning_s # allclose -> absolute(r - e) <= (atol + rtol * absolute(e)) assert np.allclose(r, e, atol=dtol, rtol=dtol) + @pytest.mark.skipif( + device_name == "lightning.tensor", + reason="lightning.tensor does not support out of order probs.", + ) @pytest.mark.parametrize( "cases", [ @@ -645,6 +724,7 @@ def calculate_reference(tape): results = dev.execute(tapes) return transf_fn(results) + @flaky(max_runs=5) @pytest.mark.parametrize( "operation", [ @@ -678,10 +758,15 @@ def calculate_reference(tape): @pytest.mark.parametrize("n_qubits", list(range(2, 5))) def test_controlled_qubit_gates(self, operation, n_qubits, control_value, tol, lightning_sv): """Test that multi-controlled gates are correctly applied to a state""" - threshold = 250 + threshold = 250 if device_name != "lightning.tensor" else 5 num_wires = max(operation.num_wires, 1) np.random.seed(0) + if ops_not_supported_in_ltensor(operation): + pytest.skip("Controlled operation not supported in lightning.tensor.") + if controlled_gate_not_supported_in_ltensor(operation): + pytest.skip("Controlled operation not supported in lightning.tensor.") + for n_wires in range(num_wires + 1, num_wires + 4): wire_lists = list(itertools.permutations(range(0, n_qubits), n_wires)) n_perms = len(wire_lists) * n_wires @@ -702,9 +787,11 @@ def test_controlled_qubit_gates(self, operation, n_qubits, control_value, tol, l qml.ctrl( operation(target_wires), control_wires, - control_values=[ - control_value or bool(i % 2) for i, _ in enumerate(control_wires) - ], + control_values=( + [control_value or bool(i % 2) for i, _ in enumerate(control_wires)] + if device_name != "lightning.tensor" + else [control_value for _ in control_wires] + ), ), ] else: @@ -712,9 +799,11 @@ def test_controlled_qubit_gates(self, operation, n_qubits, control_value, tol, l qml.ctrl( operation(*tuple([0.1234] * operation.num_params), target_wires), control_wires, - control_values=[ - control_value or bool(i % 2) for i, _ in enumerate(control_wires) - ], + control_values=( + [control_value or bool(i % 2) for i, _ in enumerate(control_wires)] + if device_name != "lightning.tensor" + else [control_value for _ in control_wires] + ), ), ] @@ -722,13 +811,19 @@ def test_controlled_qubit_gates(self, operation, n_qubits, control_value, tol, l tape = qml.tape.QuantumScript(ops, measurements) statevector = lightning_sv(n_qubits) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) expected = self.calculate_reference(tape) + if device_name == "lightning.tensor": + assert np.allclose(result, expected, 1e-4) + else: + assert np.allclose(result, expected, tol * 10) - assert np.allclose(result, expected, tol * 10) - + @pytest.mark.skipif( + device_name != "lightning.qubit", + reason="N-controlled operations only implemented in lightning.qubit.", + ) def test_controlled_qubit_unitary_from_op(self, tol, lightning_sv): n_qubits = 10 par = 0.1234 @@ -743,13 +838,14 @@ def test_controlled_qubit_unitary_from_op(self, tol, lightning_sv): ) statevector = lightning_sv(n_qubits) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) expected = self.calculate_reference(tape) assert np.allclose(result, expected, tol) + @flaky(max_runs=5) @pytest.mark.parametrize("control_wires", range(4)) @pytest.mark.parametrize("target_wires", range(4)) def test_cnot_controlled_qubit_unitary(self, control_wires, target_wires, tol, lightning_sv): @@ -776,12 +872,15 @@ def test_cnot_controlled_qubit_unitary(self, control_wires, target_wires, tol, l ) statevector = lightning_sv(n_qubits) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) expected = self.calculate_reference(tape_cnot) - assert np.allclose(result, expected, tol) + if device_name == "lightning.tensor": + assert np.allclose(result, expected, 1e-4) + else: + assert np.allclose(result, expected, tol) @pytest.mark.parametrize("control_value", [False, True]) @pytest.mark.parametrize("n_qubits", list(range(2, 8))) @@ -789,6 +888,8 @@ def test_controlled_globalphase(self, n_qubits, control_value, tol, lightning_sv """Test that multi-controlled gates are correctly applied to a state""" threshold = 250 operation = qml.GlobalPhase + if ops_not_supported_in_ltensor(operation): + pytest.skip("Operation not supported in lightning.tensor.") num_wires = max(operation.num_wires, 1) for n_wires in range(num_wires + 1, num_wires + 4): wire_lists = list(itertools.permutations(range(0, n_qubits), n_wires)) @@ -815,9 +916,9 @@ def test_controlled_globalphase(self, n_qubits, control_value, tol, lightning_sv [qml.state()], ) statevector = lightning_sv(n_qubits) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) expected = self.calculate_reference(tape) assert np.allclose(result, expected, tol) @@ -836,9 +937,9 @@ def test_sprod(self, phi, lightning_sv, tol): [qml.expval(qml.s_prod(0.5, qml.PauliZ(0)))], ) statevector = lightning_sv(self.wires) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) expected = 0.5 * np.cos(phi) assert np.allclose(result, expected, tol) @@ -850,9 +951,9 @@ def test_prod(self, phi, lightning_sv, tol): [qml.expval(qml.prod(qml.PauliZ(0), qml.PauliX(1)))], ) statevector = lightning_sv(self.wires) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) expected = -np.cos(phi) assert np.allclose(result, expected, tol) @@ -865,9 +966,9 @@ def test_sum(self, phi, theta, lightning_sv, tol): [qml.expval(qml.sum(qml.PauliZ(0), qml.PauliX(1)))], ) statevector = lightning_sv(self.wires) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) expected = np.cos(phi) + np.sin(theta) assert np.allclose(result, expected, tol) @@ -891,9 +992,9 @@ def test_state_vector_2_qubit_subset(tol, op, par, wires, expected, lightning_sv ) statevector = lightning_sv(2) - statevector = statevector.get_final_state(tape) + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) - result = m.measure_final_state(tape) + result = measure_final_state(m, tape) assert np.allclose(result, expected, tol) diff --git a/tests/lightning_tensor/test_gates_and_expval.py b/tests/lightning_tensor/test_gates_and_expval.py index 9d2213d528..e8a73fcb5a 100644 --- a/tests/lightning_tensor/test_gates_and_expval.py +++ b/tests/lightning_tensor/test_gates_and_expval.py @@ -266,32 +266,6 @@ def test_var_hermitian_not_supported(self): ): m.var(q.queue[0]) - def test_measurement_shot_not_supported(self): - """Test shots measurement error for measure_tensor_network.""" - obs = [ - qml.expval(qml.PauliX(0) @ qml.Identity(1)), - ] - - tensornet = LightningTensorNet(4, 10) - tape = qml.tape.QuantumScript(measurements=obs, shots=1000) - m = LightningTensorMeasurements(tensornet) - - with pytest.raises( - NotImplementedError, match="Shots are not supported for tensor network simulations." - ): - m.measure_tensor_network(tape) - - def test_measurement_not_supported(self): - """Test error for measure_tensor_network.""" - obs = [qml.sample(wires=0)] - - tensornet = LightningTensorNet(4, 10) - tape = qml.tape.QuantumScript(measurements=obs) - m = LightningTensorMeasurements(tensornet) - - with pytest.raises(NotImplementedError, match="Unsupported measurement type."): - m.measure_tensor_network(tape) - class QChem: """Integration tests for qchem module by parameter-shift and finite-diff differentiation methods.""" diff --git a/tests/lightning_tensor/test_lightning_tensor.py b/tests/lightning_tensor/test_lightning_tensor.py index 5c82babdfb..44b1ee9229 100644 --- a/tests/lightning_tensor/test_lightning_tensor.py +++ b/tests/lightning_tensor/test_lightning_tensor.py @@ -46,6 +46,8 @@ def test_device_available_as_plugin(): """Test that the device can be instantiated using ``qml.device``.""" dev = qml.device("lightning.tensor", wires=2) assert isinstance(dev, LightningTensor) + assert dev.backend == "cutensornet" + assert dev.method in ["mps"] @pytest.mark.parametrize("backend", ["fake_backend"]) @@ -55,6 +57,12 @@ def test_invalid_backend(backend): LightningTensor(wires=1, backend=backend) +def test_invalid_arg(): + """Test that an error is raised if an invalid argument is provided.""" + with pytest.raises(TypeError): + LightningTensor(wires=2, kwargs="invalid_arg") + + @pytest.mark.parametrize("method", ["fake_method"]) def test_invalid_method(method): """Test an invalid method.""" diff --git a/tests/lightning_tensor/test_measurements_class.py b/tests/lightning_tensor/test_measurements_class.py index 6d9574aa83..cdae207f8a 100644 --- a/tests/lightning_tensor/test_measurements_class.py +++ b/tests/lightning_tensor/test_measurements_class.py @@ -14,10 +14,16 @@ """ Unit tests for measurements class. """ +from typing import Sequence + import numpy as np import pennylane as qml import pytest from conftest import LightningDevice, device_name # tested device +from flaky import flaky +from pennylane.devices import DefaultQubit +from pennylane.measurements import VarianceMP +from scipy.sparse import csr_matrix, random_array if device_name != "lightning.tensor": pytest.skip( @@ -40,7 +46,11 @@ ) def lightning_tn(request): """Fixture for creating a LightningTensorNet object.""" - return LightningTensorNet(num_wires=5, max_bond_dim=128, c_dtype=request.param) + + def _lightning_tn(n_wires): + return LightningTensorNet(num_wires=n_wires, max_bond_dim=128, c_dtype=request.param) + + return _lightning_tn class TestMeasurementFunction: @@ -48,7 +58,7 @@ class TestMeasurementFunction: def test_initialization(self, lightning_tn): """Tests for the initialization of the LightningTensorMeasurements class.""" - tensornetwork = lightning_tn + tensornetwork = lightning_tn(2) m = LightningTensorMeasurements(tensornetwork) assert m.dtype == tensornetwork.dtype @@ -56,24 +66,63 @@ def test_initialization(self, lightning_tn): def test_not_implemented_state_measurements(self, lightning_tn): """Test than a NotImplementedError is raised if the measurement is not a state measurement.""" - tensornetwork = lightning_tn + tensornetwork = lightning_tn(2) m = LightningTensorMeasurements(tensornetwork) mp = qml.counts(wires=(0, 1)) with pytest.raises(NotImplementedError): m.get_measurement_function(mp) - def test_not_measure_tensor_network(self, lightning_tn): - """Test than a NotImplementedError is raised if the measurement is not a state measurement.""" + def test_not_supported_sparseH_shot_measurements(self): + """Test than a TypeError is raised if the measurement is not supported.""" + + tensornetwork = LightningTensorNet(num_wires=3, max_bond_dim=128) - tensornetwork = lightning_tn m = LightningTensorMeasurements(tensornetwork) - tape = qml.tape.QuantumScript( - [qml.RX(0.1, wires=0), qml.Hadamard(1), qml.PauliZ(1)], - [qml.expval(qml.prod(qml.PauliZ(0), qml.PauliX(1)))], - shots=1000, + ops = [qml.PauliX(0), qml.PauliZ(1)] + + obs = qml.SparseHamiltonian( + qml.Hamiltonian([-1.0, 1.5], [qml.Z(1), qml.X(1)]).sparse_matrix(wire_order=[0, 1, 2]), + wires=[0, 1, 2], ) - with pytest.raises(NotImplementedError): - m.measure_tensor_network(tape) + for mp in [qml.var(obs), qml.expval(obs)]: + tape = qml.tape.QuantumScript(ops, [mp], shots=100) + + with pytest.raises(TypeError): + m.measure_tensor_network(tape) + + def test_not_supported_ham_sum_shot_measurements(self): + """Test than a TypeError is raised if the measurement is not supported.""" + + tensornetwork = LightningTensorNet(num_wires=3, max_bond_dim=128) + + m = LightningTensorMeasurements(tensornetwork) + + ops = [qml.PauliX(0), qml.PauliZ(1)] + + obs_ham = qml.Hamiltonian([-1.0, 1.5], [qml.Z(1), qml.X(1)]) + + obs_sum = qml.sum(qml.PauliX(0), qml.PauliX(1)) + + for mp in [qml.var(obs_ham), qml.var(obs_sum)]: + tape = qml.tape.QuantumScript(ops, [mp], shots=100) + + with pytest.raises(TypeError): + m.measure_tensor_network(tape) + + def test_not_supported_shadowmp_shot_measurements(self): + """Test than a TypeError is raised if the measurement is not supported.""" + + tensornetwork = LightningTensorNet(num_wires=3, max_bond_dim=128) + + m = LightningTensorMeasurements(tensornetwork) + + ops = [qml.PauliX(0), qml.PauliZ(1)] + + for mp in [qml.classical_shadow(wires=[0, 1]), qml.shadow_expval(qml.PauliX(0))]: + tape = qml.tape.QuantumScript(ops, [mp], shots=100) + + with pytest.raises(TypeError): + m.measure_tensor_network(tape) diff --git a/tests/test_apply.py b/tests/test_apply.py index 8c06a25af4..9e769b8533 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -564,7 +564,7 @@ def test_expval_single_wire_no_parameters( @pytest.mark.skipif( device_name == "lightning.tensor", - reason="lightning.tensor does not support shot measurement", + reason="lightning.tensor does not support single wire devices", ) def test_expval_estimate(self): """Test that the expectation value is not analytically calculated""" @@ -623,7 +623,7 @@ def test_var_single_wire_no_parameters( @pytest.mark.skipif( device_name == "lightning.tensor", - reason="lightning.tensor does not support shot measurement and single-wire devices", + reason="lightning.tensor does not support single-wire devices", ) def test_var_estimate(self): """Test that the variance is not analytically calculated""" @@ -641,10 +641,6 @@ def circuit(): assert var != 1.0 -@pytest.mark.skipif( - device_name == "lightning.tensor", - reason="lightning.tensor device does not support qml.samples()", -) class TestSample: """Tests that samples are properly calculated.""" @@ -785,7 +781,7 @@ def circuit(x): @pytest.mark.skipif( device_name == "lightning.tensor", - reason="lightning.tensor does not support shot measurements", + reason="lightning.tensor does not support single wire devices", ) def test_nonzero_shots(self, tol_stochastic): """Test that the default qubit plugin provides correct result for high shot number""" @@ -810,7 +806,7 @@ def circuit(x): # This test is ran against the state |0> with one Z expval @pytest.mark.skipif( device_name == "lightning.tensor", - reason="lightning.tensor does not support shot measurements and single-wire devices", + reason="lightning.tensor does not support single-wire devices", ) @pytest.mark.parametrize( "name,expected_output", @@ -956,10 +952,6 @@ def circuit(): assert np.allclose(circuit(), expected_output, atol=tol, rtol=0) # This test is run with two expvals - @pytest.mark.skipif( - device_name == "lightning.tensor", - reason="lightning.tensor does not support QubitStateVector", - ) @pytest.mark.parametrize( "name,par,wires,expected_output", [ @@ -1022,7 +1014,8 @@ def circuit(): # This test is ran on the state |0> with one Z expvals @pytest.mark.skipif( - device_name == "lightning.tensor", reason="lightning.tensor requires num_wires > 1" + device_name == "lightning.tensor", + reason="lightning.tensor does not support single wire devices", ) @pytest.mark.parametrize( "name,par,expected_output", @@ -1173,10 +1166,6 @@ def circuit(): assert np.isclose(circuit(), expected_output, atol=tol, rtol=0) - @pytest.mark.skipif( - device_name == "lightning.tensor", - reason="lightning.tensor does not support shot measurements", - ) def test_multi_samples_return_correlated_results(self, qubit_device): """Tests if the samples returned by the sample function have the correct dimensions @@ -1194,10 +1183,6 @@ def circuit(): assert np.array_equal(outcomes[0], outcomes[1]) - @pytest.mark.skipif( - device_name == "lightning.tensor", - reason="lightning.tensor does not support shot measurements.", - ) @pytest.mark.parametrize("num_wires", [3, 4, 5, 6, 7, 8]) def test_multi_samples_return_correlated_results_more_wires_than_size_of_observable( self, num_wires @@ -1219,10 +1204,6 @@ def circuit(): assert np.array_equal(outcomes[0], outcomes[1]) - @pytest.mark.skipif( - device_name == "lightning.tensor", - reason="lightning.tensor does not support shot measurements", - ) def test_snapshot_is_ignored_without_shot(self): """Tests if the Snapshot operator is ignored correctly""" dev = qml.device(device_name, wires=4) @@ -1239,10 +1220,6 @@ def circuit(): assert np.allclose(outcomes, [0.0]) - @pytest.mark.skipif( - device_name == "lightning.tensor", - reason="lightning.tensor does not support shot measurements", - ) def test_snapshot_is_ignored_with_shots(self): """Tests if the Snapshot operator is ignored correctly""" dev = qml.device(device_name, wires=4, shots=1000) @@ -1343,7 +1320,7 @@ class TestApplyLightningMethod: @pytest.mark.skipif( device_name == "lightning.tensor", - reason="lightning.tensor does not support _apply_state_vector", + reason="lightning.tensor does not support direct access to the state", ) @pytest.mark.skipif(ld._new_API, reason="Old API required") def test_apply_identity_skipped(self, mocker, tol): diff --git a/tests/test_measurements.py b/tests/test_measurements.py index 0fdb3fafa3..d892fd50b9 100644 --- a/tests/test_measurements.py +++ b/tests/test_measurements.py @@ -625,10 +625,6 @@ def circuit2(): assert np.allclose(circuit1(), circuit2(), atol=tol) -@pytest.mark.skipif( - device_name == "lightning.tensor", - reason="lightning.tensor does not support qml.sample()", -) class TestSample: """Tests that samples are properly calculated.""" @@ -744,10 +740,6 @@ def circuit2(): assert np.allclose(circuit1(), circuit2(), atol=tol) -@pytest.mark.skipif( - device_name == "lightning.tensor", - reason="lightning.tensor does not support shots", -) @flaky(max_runs=5) @pytest.mark.parametrize("shots", [None, 10000, [10000, 11111]]) @pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var]) @@ -768,9 +760,9 @@ def test_shots_single_measure_obs(shots, measure_f, obs, mcmc, kernel_name): """Tests that Lightning handles shots in a circuit where a single measurement of a common observable is performed at the end.""" n_qubits = 3 - if (shots is None or device_name in ("lightning.gpu", "lightning.kokkos")) and ( - mcmc or kernel_name != "Local" - ): + if ( + shots is None or device_name in ("lightning.gpu", "lightning.kokkos", "lightning.tensor") + ) and (mcmc or kernel_name != "Local"): pytest.skip(f"Device {device_name} does not have an mcmc option.") if measure_f in (qml.expval, qml.var) and isinstance(obs, Sequence): @@ -779,7 +771,7 @@ def test_shots_single_measure_obs(shots, measure_f, obs, mcmc, kernel_name): if measure_f in (qml.counts, qml.sample) and shots is None: pytest.skip("qml.counts, qml.sample do not work with shots = None.") - if device_name in ("lightning.gpu", "lightning.kokkos"): + if device_name in ("lightning.gpu", "lightning.kokkos", "lightning.tensor"): dev = qml.device(device_name, wires=n_qubits, shots=shots) else: dev = qml.device( @@ -807,7 +799,7 @@ def func(x, y): # TODO: Add LT after extending the support for shots_vector @pytest.mark.skipif( device_name == "lightning.tensor", - reason="lightning.tensor does not support shot vectors.", + reason="lightning.tensor does not support single-wire devices.", ) @pytest.mark.parametrize("shots", ((1, 10), (1, 10, 100), (1, 10, 10, 100, 100, 100))) def test_shots_bins(shots, qubit_device):