Skip to content

Commit

Permalink
Merge branch 'master' into add_mps_capture
Browse files Browse the repository at this point in the history
  • Loading branch information
multiphaseCFD authored Dec 18, 2024
2 parents 034c866 + 4ddcdf2 commit a410064
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
* Replace the `dummy_tensor_update` method with the `cutensornetStateCaptureMPS`API to ensure that further gates apply is allowed after the `cutensornetStateCompute` call.
[(#1028)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1028/)

* Add unit test for measurement with shots for Lightning Tensor with `tn` method.
[(#1027)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1027)

* Update the python layer UI of Lightning Tensor.
[(#1022)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1022/)

Expand Down
14 changes: 12 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,22 @@ def _device(wires, shots=None):
# General LightningStateVector fixture, for any number of wires.
@pytest.fixture(
scope="function",
params=[np.complex64, np.complex128],
params=(
[np.complex64, np.complex128]
if device_name != "lightning.tensor"
else [
[c_dtype, method]
for c_dtype in [np.complex64, np.complex128]
for method in ["mps", "tn"]
]
),
)
def lightning_sv(request):
def _statevector(num_wires):
if device_name == "lightning.tensor":
return LightningStateVector(num_wires=num_wires, c_dtype=request.param)
return LightningStateVector(
num_wires=num_wires, c_dtype=request.param[0], method=request.param[1]
)
return LightningStateVector(num_wires=num_wires, dtype=request.param)

return _statevector
Expand Down
19 changes: 18 additions & 1 deletion tests/lightning_qubit/test_measurements_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,12 @@ def test_single_return_value(self, shots, measurement, observable, lightning_sv,

# a few tests may fail in single precision, and hence we increase the tolerance
if shots is None:
assert np.allclose(result, expected, max(tol, 1.0e-4))
assert np.allclose(
result,
expected,
max(tol, 1.0e-4),
1e-6 if statevector.dtype == np.complex64 else 1e-8,
)
else:
# TODO Set better atol and rtol
dtol = max(tol, 1.0e-2)
Expand Down Expand Up @@ -788,6 +793,9 @@ def test_controlled_qubit_gates(self, operation, n_qubits, control_value, tol, l
tape = qml.tape.QuantumScript(ops, measurements)

statevector = lightning_sv(n_qubits)
if device_name == "lightning.tensor" and statevector.method == "tn":
pytest.skip("StatePrep not supported in lightning.tensor with the tn method.")

statevector = get_final_state(statevector, tape)
m = LightningMeasurements(statevector)
result = measure_final_state(m, tape)
Expand Down Expand Up @@ -845,6 +853,9 @@ def test_cnot_controlled_qubit_unitary(self, control_wires, target_wires, tol, l
)

statevector = lightning_sv(n_qubits)
if device_name == "lightning.tensor" and statevector.method == "tn":
pytest.skip("StatePrep not supported in lightning.tensor with the tn method.")

statevector = get_final_state(statevector, tape)
m = LightningMeasurements(statevector)
result = measure_final_state(m, tape)
Expand Down Expand Up @@ -889,6 +900,9 @@ def test_controlled_globalphase(self, n_qubits, control_value, tol, lightning_sv
[qml.state()],
)
statevector = lightning_sv(n_qubits)
if device_name == "lightning.tensor" and statevector.method == "tn":
pytest.skip("StatePrep not supported in lightning.tensor with the tn method.")

statevector = get_final_state(statevector, tape)
m = LightningMeasurements(statevector)
result = measure_final_state(m, tape)
Expand Down Expand Up @@ -967,6 +981,9 @@ def test_state_vector_2_qubit_subset(tol, op, par, wires, expected, lightning_sv
)

statevector = lightning_sv(2)
if device_name == "lightning.tensor" and statevector.method == "tn":
pytest.skip("StatePrep not supported in lightning.tensor with the tn method.")

statevector = get_final_state(statevector, tape)

m = LightningMeasurements(statevector)
Expand Down

0 comments on commit a410064

Please sign in to comment.