diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 450625b31..a664c6634 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -32,6 +32,9 @@ * Replace the `dummy_tensor_update` method with the `cutensornetStateCaptureMPS`API to ensure that further gates apply is allowed after the `cutensornetStateCompute` call. [(#1028)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1028/) +* Add unit test for measurement with shots for Lightning Tensor with `tn` method. + [(#1027)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1027) + * Update the python layer UI of Lightning Tensor. [(#1022)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1022/) diff --git a/tests/conftest.py b/tests/conftest.py index 178982bbe..16bf95b13 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -195,12 +195,22 @@ def _device(wires, shots=None): # General LightningStateVector fixture, for any number of wires. @pytest.fixture( scope="function", - params=[np.complex64, np.complex128], + params=( + [np.complex64, np.complex128] + if device_name != "lightning.tensor" + else [ + [c_dtype, method] + for c_dtype in [np.complex64, np.complex128] + for method in ["mps", "tn"] + ] + ), ) def lightning_sv(request): def _statevector(num_wires): if device_name == "lightning.tensor": - return LightningStateVector(num_wires=num_wires, c_dtype=request.param) + return LightningStateVector( + num_wires=num_wires, c_dtype=request.param[0], method=request.param[1] + ) return LightningStateVector(num_wires=num_wires, dtype=request.param) return _statevector diff --git a/tests/lightning_qubit/test_measurements_class.py b/tests/lightning_qubit/test_measurements_class.py index 8e73a5b13..ced0c0e1f 100644 --- a/tests/lightning_qubit/test_measurements_class.py +++ b/tests/lightning_qubit/test_measurements_class.py @@ -545,7 +545,12 @@ def test_single_return_value(self, shots, measurement, observable, lightning_sv, # a few tests may fail in single precision, and hence we increase the tolerance if shots is None: - assert np.allclose(result, expected, max(tol, 1.0e-4)) + assert np.allclose( + result, + expected, + max(tol, 1.0e-4), + 1e-6 if statevector.dtype == np.complex64 else 1e-8, + ) else: # TODO Set better atol and rtol dtol = max(tol, 1.0e-2) @@ -788,6 +793,9 @@ def test_controlled_qubit_gates(self, operation, n_qubits, control_value, tol, l tape = qml.tape.QuantumScript(ops, measurements) statevector = lightning_sv(n_qubits) + if device_name == "lightning.tensor" and statevector.method == "tn": + pytest.skip("StatePrep not supported in lightning.tensor with the tn method.") + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) result = measure_final_state(m, tape) @@ -845,6 +853,9 @@ def test_cnot_controlled_qubit_unitary(self, control_wires, target_wires, tol, l ) statevector = lightning_sv(n_qubits) + if device_name == "lightning.tensor" and statevector.method == "tn": + pytest.skip("StatePrep not supported in lightning.tensor with the tn method.") + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) result = measure_final_state(m, tape) @@ -889,6 +900,9 @@ def test_controlled_globalphase(self, n_qubits, control_value, tol, lightning_sv [qml.state()], ) statevector = lightning_sv(n_qubits) + if device_name == "lightning.tensor" and statevector.method == "tn": + pytest.skip("StatePrep not supported in lightning.tensor with the tn method.") + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector) result = measure_final_state(m, tape) @@ -967,6 +981,9 @@ def test_state_vector_2_qubit_subset(tol, op, par, wires, expected, lightning_sv ) statevector = lightning_sv(2) + if device_name == "lightning.tensor" and statevector.method == "tn": + pytest.skip("StatePrep not supported in lightning.tensor with the tn method.") + statevector = get_final_state(statevector, tape) m = LightningMeasurements(statevector)