Skip to content

Commit

Permalink
Addressing comments from CR
Browse files Browse the repository at this point in the history
  • Loading branch information
PietropaoloFrisoni committed Apr 25, 2024
1 parent 149d126 commit 9ef5625
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 25 deletions.
25 changes: 7 additions & 18 deletions pennylane_lightning/lightning_tensor/lightning_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,11 @@ class LightningTensor(Device):

# So far we just consider the options for MPS simulator
_device_options = (
"apply_reverse_lightcone",
"backend",
"c_dtype",
"cutoff",
"method",
"max_bond_dim",
"measure_algorithm",
"return_tn",
"rehearse",
)

_new_API = True
Expand All @@ -112,7 +108,7 @@ def __init__(
raise ValueError(f"Unsupported method: {method}")

if shots is not None:
raise ValueError("LightningTensor does not support the `shots` parameter.")
raise ValueError("LightningTensor does not support finite shots.")

super().__init__(wires=wires, shots=shots)

Expand All @@ -124,12 +120,6 @@ def __init__(
# options for MPS
self._max_bond_dim = kwargs.get("max_bond_dim", None)
self._cutoff = kwargs.get("cutoff", 1e-16)
self._measure_algorithm = kwargs.get("measure_algorithm", None)

# common options (MPS and TN)
self._return_tn = kwargs.get("return_tn", False)
self._rehearse = kwargs.get("rehearse", False)
self._apply_reverse_lightcone = kwargs.get("apply_reverse_lightcone", None)

self._interface = None
interface_opts = self._setup_execution_config().device_options
Expand Down Expand Up @@ -253,8 +243,7 @@ def supports_derivatives(
Bool: Whether or not a derivative can be calculated provided the given information.
"""
# TODO: implement during next quarter
return False # pragma: no cover
return False

def compute_derivatives(
self,
Expand All @@ -272,7 +261,7 @@ def compute_derivatives(
"""
raise NotImplementedError(
"The computation of derivatives has yet to be implemented for the lightning.tensor device."
) # pragma: no cover
)

def execute_and_compute_derivatives(
self,
Expand All @@ -290,7 +279,7 @@ def execute_and_compute_derivatives(
"""
raise NotImplementedError(
"The computation of derivatives has yet to be implemented for the lightning.tensor device."
) # pragma: no cover
)

# pylint: disable=unused-argument
def supports_vjp(
Expand All @@ -308,7 +297,7 @@ def supports_vjp(
Bool: Whether or not a derivative can be calculated provided the given information.
"""
# TODO: implement during next quarter
return False # pragma: no cover
return False

def compute_vjp(
self,
Expand All @@ -330,7 +319,7 @@ def compute_vjp(
"""
raise NotImplementedError(
"The computation of vector jacobian product has yet to be implemented for the lightning.tensor device."
) # pragma: no cover
)

def execute_and_compute_vjp(
self,
Expand All @@ -351,4 +340,4 @@ def execute_and_compute_vjp(
"""
raise NotImplementedError(
"The computation of vector jacobian product has yet to be implemented for the lightning.tensor device."
) # pragma: no cover
)
2 changes: 0 additions & 2 deletions pennylane_lightning/lightning_tensor/quimb/_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def __init__(self, num_wires, interf_opts, dtype=np.complex128):

self._wires = Wires(range(num_wires))
self._dtype = dtype
self._return_tn = interf_opts["return_tn"]

self._init_state_ops = {
"binary": "0" * max(1, len(self._wires)),
Expand All @@ -75,7 +74,6 @@ def __init__(self, num_wires, interf_opts, dtype=np.complex128):
"dtype": self._dtype.__name__,
"simplify_sequence": "ADCRS",
"simplify_atol": 0.0,
"rehearse": interf_opts["rehearse"],
}

self._circuitMPS = qtn.CircuitMPS(psi0=self._initial_mps())
Expand Down
46 changes: 43 additions & 3 deletions tests/lightning_tensor/test_lightning_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
if not LightningDevice._new_API:
pytest.skip("Exclusive tests for new API. Skipping.", allow_module_level=True)

if LightningDevice._CPP_BINARY_AVAILABLE:
pytest.skip("Device doesn't have C++ support yet.", allow_module_level=True)
# if LightningDevice._CPP_BINARY_AVAILABLE:
# pytest.skip("Device doesn't have C++ support yet.", allow_module_level=True)


@pytest.mark.parametrize("num_wires", [None, 4])
Expand Down Expand Up @@ -71,5 +71,45 @@ def test_invalid_keyword_arg():

def test_invalid_shots():
"""Test that an error is raised if finite number of shots are requestd."""
with pytest.raises(ValueError, match="LightningTensor does not support the `shots` parameter."):
with pytest.raises(ValueError, match="LightningTensor does not support finite shots."):
LightningTensor(shots=5)


def test_support_derivatives():
"""Test that the device does not support derivatives yet."""
dev = LightningTensor()
assert not dev.supports_derivatives()


def test_compute_derivatives():
"""Test that an error is raised if the `compute_derivatives` method is called."""
dev = LightningTensor()
with pytest.raises(NotImplementedError):
dev.compute_derivatives(circuits=None)


def test_execute_and_compute_derivatives():
"""Test that an error is raised if `execute_and_compute_derivative` method is called."""
dev = LightningTensor()
with pytest.raises(NotImplementedError):
dev.execute_and_compute_derivatives(circuits=None)


def test_supports_vjp():
"""Test that the device does not support VJP yet."""
dev = LightningTensor()
assert not dev.supports_vjp()


def test_compute_vjp():
"""Test that an error is raised if `compute_vjp` method is called."""
dev = LightningTensor()
with pytest.raises(NotImplementedError):
dev.compute_vjp(circuits=None, cotangents=None)


def test_execute_and_compute_vjp():
"""Test that an error is raised if `execute_and_compute_vjp` method is called."""
dev = LightningTensor()
with pytest.raises(NotImplementedError):
dev.execute_and_compute_vjp(circuits=None, cotangents=None)
4 changes: 2 additions & 2 deletions tests/lightning_tensor/test_quimb_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
if not LightningDevice._new_API:
pytest.skip("Exclusive tests for new API. Skipping.", allow_module_level=True)

if LightningDevice._CPP_BINARY_AVAILABLE:
pytest.skip("Device doesn't have C++ support yet.", allow_module_level=True)
# if LightningDevice._CPP_BINARY_AVAILABLE:
# pytest.skip("Device doesn't have C++ support yet.", allow_module_level=True)


@pytest.mark.parametrize("backend", ["quimb"])
Expand Down

0 comments on commit 9ef5625

Please sign in to comment.