From 41aace7ffb377366fae1a20a4ae8b9f33d246ad1 Mon Sep 17 00:00:00 2001 From: Shuli <08cnbj@gmail.com> Date: Tue, 24 Oct 2023 23:52:09 +0000 Subject: [PATCH] add fp32 tests --- mpitests/test_expval.py | 2 -- .../lightning_gpu/lightning_gpu.py | 16 ++++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/mpitests/test_expval.py b/mpitests/test_expval.py index 70b55d7f35..bc6f6fc079 100644 --- a/mpitests/test_expval.py +++ b/mpitests/test_expval.py @@ -105,8 +105,6 @@ def test_hadamard_expectation(self, theta, phi, tol): """Test that Hadamard expectation value is correct""" dev = qml.device(device_name, mpi=True, wires=3) - if device_name == "lightning.gpu" and dev.R_DTYPE == np.float32: - pytest.skip("Skipped FP32 tests for expval in lightning.gpu") O1 = qml.Hadamard(wires=[0]) O2 = qml.Hadamard(wires=[1]) diff --git a/pennylane_lightning/lightning_gpu/lightning_gpu.py b/pennylane_lightning/lightning_gpu/lightning_gpu.py index 8b36a3fc2f..279b1c2228 100644 --- a/pennylane_lightning/lightning_gpu/lightning_gpu.py +++ b/pennylane_lightning/lightning_gpu/lightning_gpu.py @@ -361,23 +361,23 @@ def state(self): @property def create_ops_list(self): """Returns create_ops_list function of the matching precision.""" - if not self._mpi: - return create_ops_listC64 if self.use_csingle else create_ops_listC128 - return create_ops_listMPIC64 if self.use_csingle else create_ops_listMPIC128 + if self._mpi: + return create_ops_listMPIC64 if self.use_csingle else create_ops_listMPIC128 + return create_ops_listC64 if self.use_csingle else create_ops_listC128 @property def measurements(self): """Returns Measurements constructor of the matching precision.""" - if not self._mpi: + if self._mpi: return ( - MeasurementsC64(self._gpu_state) + MeasurementsMPIC64(self._gpu_state) if self.use_csingle - else MeasurementsC128(self._gpu_state) + else MeasurementsMPIC128(self._gpu_state) ) return ( - MeasurementsMPIC64(self._gpu_state) + MeasurementsC64(self._gpu_state) if self.use_csingle - else MeasurementsMPIC128(self._gpu_state) + else MeasurementsC128(self._gpu_state) ) def syncD2H(self, state_vector, use_async=False):