Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding unit test for measurement with shots for LT with tn method #1027

Merged
merged 5 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pennylane_lightning/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.40.0-dev37"
__version__ = "0.40.0-dev38"
14 changes: 12 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,22 @@ def _device(wires, shots=None):
# General LightningStateVector fixture, for any number of wires.
@pytest.fixture(
scope="function",
params=[np.complex64, np.complex128],
params=(
[np.complex64, np.complex128]
if device_name != "lightning.tensor"
else [
[c_dtype, method]
for c_dtype in [np.complex64, np.complex128]
for method in ["mps", "tn"]
]
),
)
def lightning_sv(request):
def _statevector(num_wires):
if device_name == "lightning.tensor":
return LightningStateVector(num_wires=num_wires, c_dtype=request.param)
return LightningStateVector(
num_wires=num_wires, c_dtype=request.param[0], method=request.param[1]
)
return LightningStateVector(num_wires=num_wires, dtype=request.param)

return _statevector
Expand Down
27 changes: 26 additions & 1 deletion tests/lightning_qubit/test_measurements_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,12 @@ def test_single_return_value(self, shots, measurement, observable, lightning_sv,

# a few tests may fail in single precision, and hence we increase the tolerance
if shots is None:
assert np.allclose(result, expected, max(tol, 1.0e-4))
assert np.allclose(
result,
expected,
max(tol, 1.0e-4),
1e-6 if statevector.dtype == np.complex64 else 1e-8,
)
else:
# TODO Set better atol and rtol
dtol = max(tol, 1.0e-2)
Expand Down Expand Up @@ -788,6 +793,12 @@ def test_controlled_qubit_gates(self, operation, n_qubits, control_value, tol, l
tape = qml.tape.QuantumScript(ops, measurements)

statevector = lightning_sv(n_qubits)
if device_name == "lightning.tensor":
LuisAlfredoNu marked this conversation as resolved.
Show resolved Hide resolved
if statevector.method == "tn":
LuisAlfredoNu marked this conversation as resolved.
Show resolved Hide resolved
pytest.skip(
"StatePrep not supported in lightning.tensor with the tn method."
)

statevector = get_final_state(statevector, tape)
m = LightningMeasurements(statevector)
result = measure_final_state(m, tape)
Expand Down Expand Up @@ -845,6 +856,10 @@ def test_cnot_controlled_qubit_unitary(self, control_wires, target_wires, tol, l
)

statevector = lightning_sv(n_qubits)
if device_name == "lightning.tensor":
LuisAlfredoNu marked this conversation as resolved.
Show resolved Hide resolved
if statevector.method == "tn":
pytest.skip("StatePrep not supported in lightning.tensor with the tn method.")

statevector = get_final_state(statevector, tape)
m = LightningMeasurements(statevector)
result = measure_final_state(m, tape)
Expand Down Expand Up @@ -889,6 +904,12 @@ def test_controlled_globalphase(self, n_qubits, control_value, tol, lightning_sv
[qml.state()],
)
statevector = lightning_sv(n_qubits)
if device_name == "lightning.tensor":
if statevector.method == "tn":
pytest.skip(
"StatePrep not supported in lightning.tensor with the tn method."
)

statevector = get_final_state(statevector, tape)
m = LightningMeasurements(statevector)
result = measure_final_state(m, tape)
Expand Down Expand Up @@ -967,6 +988,10 @@ def test_state_vector_2_qubit_subset(tol, op, par, wires, expected, lightning_sv
)

statevector = lightning_sv(2)
if device_name == "lightning.tensor":
if statevector.method == "tn":
pytest.skip("StatePrep not supported in lightning.tensor with the tn method.")

statevector = get_final_state(statevector, tape)

m = LightningMeasurements(statevector)
Expand Down
Loading