diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index b7aab6ea9e..700adc8e5b 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -2,6 +2,9 @@ ### New features since last release +* Add Python class for the `lightning.tensor` device which uses the new device API and the interface for `quimb` based on the MPS method. + [(#671)](https://github.com/PennyLaneAI/pennylane-lightning/pull/671) + * Add compile-time support for AVX2/512 streaming operations in `lightning.qubit`. [(#664)](https://github.com/PennyLaneAI/pennylane-lightning/pull/664) @@ -129,11 +132,14 @@ * Update the version of `codecov-action` to v4 and fix the CodeCov issue with the PL-Lightning check-compatibility actions. [(#682)](https://github.com/PennyLaneAI/pennylane-lightning/pull/682) +* Increase tolerance for a flaky test. + [(#703)](https://github.com/PennyLaneAI/pennylane-lightning/pull/703) + ### Contributors This release contains contributions from (in alphabetical order): -Ali Asadi, Amintor Dusko, Christina Lee, Vincent Michaud-Rioux, Lee James O'Riordan, Mudit Pandey, Shuli Shu +Ali Asadi, Amintor Dusko, Pietropaolo Frisoni, Christina Lee, Vincent Michaud-Rioux, Lee James O'Riordan, Mudit Pandey, Shuli Shu --- diff --git a/.github/workflows/tests_linux_python.yml b/.github/workflows/tests_linux_python.yml index 9e7e6ee574..dbb6ee1037 100644 --- a/.github/workflows/tests_linux_python.yml +++ b/.github/workflows/tests_linux_python.yml @@ -43,7 +43,7 @@ jobs: strategy: matrix: pl_backend: ["lightning_qubit"] - timeout-minutes: 60 + timeout-minutes: 75 name: Python tests runs-on: ${{ needs.determine_runner.outputs.runner_group }} @@ -150,7 +150,7 @@ jobs: strategy: matrix: pl_backend: ["lightning_qubit"] - timeout-minutes: 60 + timeout-minutes: 75 name: Python tests with OpenBLAS runs-on: ${{ needs.determine_runner.outputs.runner_group }} @@ -265,7 +265,7 @@ jobs: exclude: - pl_backend: ["all"] exec_model: OPENMP - timeout-minutes: 60 + timeout-minutes: 75 name: Python tests with Kokkos runs-on: ${{ needs.determine_runner.outputs.runner_group }} diff --git a/pennylane_lightning/core/_version.py b/pennylane_lightning/core/_version.py index b324311e95..0198b9a3ca 100644 --- a/pennylane_lightning/core/_version.py +++ b/pennylane_lightning/core/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.36.0-dev44" +__version__ = "0.36.0-dev45" diff --git a/pennylane_lightning/lightning_tensor/__init__.py b/pennylane_lightning/lightning_tensor/__init__.py new file mode 100644 index 0000000000..48cc140c46 --- /dev/null +++ b/pennylane_lightning/lightning_tensor/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PennyLane lightning_tensor package.""" + +from pennylane_lightning.core import __version__ + +from .lightning_tensor import LightningTensor diff --git a/pennylane_lightning/lightning_tensor/backends/quimb/_mps.py b/pennylane_lightning/lightning_tensor/backends/quimb/_mps.py new file mode 100644 index 0000000000..c5f53b15ec --- /dev/null +++ b/pennylane_lightning/lightning_tensor/backends/quimb/_mps.py @@ -0,0 +1,100 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Class implementation for the Quimb MPS interface for simulating quantum circuits while keeping the state always in MPS form. +""" + +import numpy as np +import pennylane as qml +import quimb.tensor as qtn +from pennylane.wires import Wires + +_operations = frozenset({}) # pragma: no cover +# The set of supported operations. + +_observables = frozenset({}) # pragma: no cover +# The set of supported observables. + + +def stopping_condition(op: qml.operation.Operator) -> bool: + """A function that determines if an operation is supported by ``lightning.tensor`` for this interface.""" + return op.name in _operations # pragma: no cover + + +def accepted_observables(obs: qml.operation.Operator) -> bool: + """A function that determines if an observable is supported by ``lightning.tensor`` for this interface.""" + return obs.name in _observables # pragma: no cover + + +class QuimbMPS: + """Quimb MPS class. + + Used internally by the `LightningTensor` device. + Interfaces with `quimb` for MPS manipulation, and provides methods to execute quantum circuits. + + Args: + num_wires (int): the number of wires in the circuit. + interf_opts (dict): dictionary containing the interface options. + dtype (np.dtype): the complex type used for the MPS. + """ + + def __init__(self, num_wires, interf_opts, dtype=np.complex128): + + if dtype not in [np.complex64, np.complex128]: # pragma: no cover + raise TypeError(f"Unsupported complex type: {dtype}") + + self._wires = Wires(range(num_wires)) + self._dtype = dtype + + self._init_state_ops = { + "binary": "0" * max(1, len(self._wires)), + "dtype": self._dtype.__name__, + "tags": [str(l) for l in self._wires.labels], + } + + self._gate_opts = { + "contract": "swap+split", + "parametrize": None, + "cutoff": interf_opts["cutoff"], + "max_bond": interf_opts["max_bond_dim"], + } + + self._expval_opts = { + "dtype": self._dtype.__name__, + "simplify_sequence": "ADCRS", + "simplify_atol": 0.0, + } + + self._circuitMPS = qtn.CircuitMPS(psi0=self._initial_mps()) + + @property + def state(self): + """Current MPS handled by the interface.""" + return self._circuitMPS.psi + + def state_to_array(self) -> np.ndarray: + """Contract the MPS into a dense array.""" + return self._circuitMPS.to_dense() + + def _initial_mps(self) -> qtn.MatrixProductState: + r""" + Returns an initial state to :math:`\ket{0}`. + + Internally, it uses `quimb`'s `MPS_computational_state` method. + + Returns: + MatrixProductState: The initial MPS of a circuit. + """ + + return qtn.MPS_computational_state(**self._init_state_ops) diff --git a/pennylane_lightning/lightning_tensor/lightning_tensor.py b/pennylane_lightning/lightning_tensor/lightning_tensor.py new file mode 100644 index 0000000000..fc8974cda2 --- /dev/null +++ b/pennylane_lightning/lightning_tensor/lightning_tensor.py @@ -0,0 +1,348 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module contains the LightningTensor class that inherits from the new device interface. +It is a device to perform tensor network simulation of a quantum circuit. +""" +from dataclasses import replace +from numbers import Number +from typing import Callable, Optional, Sequence, Tuple, Union + +import numpy as np +import pennylane as qml +from pennylane.devices import DefaultExecutionConfig, Device, ExecutionConfig +from pennylane.devices.modifiers import simulator_tracking, single_tape_support +from pennylane.tape import QuantumTape +from pennylane.transforms.core import TransformProgram +from pennylane.typing import Result, ResultBatch + +from .backends.quimb._mps import QuimbMPS + +Result_or_ResultBatch = Union[Result, ResultBatch] +QuantumTapeBatch = Sequence[QuantumTape] +QuantumTape_or_Batch = Union[QuantumTape, QuantumTapeBatch] +PostprocessingFn = Callable[[ResultBatch], Result_or_ResultBatch] + + +_backends = frozenset({"quimb"}) +# The set of supported backends. + +_methods = frozenset({"mps"}) +# The set of supported methods. + + +def accepted_backends(backend: str) -> bool: + """A function that determines whether or not a backend is supported by ``lightning.tensor``.""" + return backend in _backends + + +def accepted_methods(method: str) -> bool: + """A function that determines whether or not a method is supported by ``lightning.tensor``.""" + return method in _methods + + +@simulator_tracking +@single_tape_support +class LightningTensor(Device): + """PennyLane Lightning Tensor device. + + A device to perform tensor network operations on a quantum circuit. + + Args: + wires (int): The number of wires to initialize the device with. + Defaults to ``None`` if not specified. + backend (str): Supported backend. Currently, only ``quimb`` is supported. + method (str): Supported method. Currently, only ``mps`` is supported. + shots (int): How many times the circuit should be evaluated (or sampled) to estimate + the expectation values. Currently, it can only be ``None``, so that computation of + statistics like expectation values and variances is performed analytically. + c_dtype: Datatypes for the tensor representation. Must be one of + ``np.complex64`` or ``np.complex128``. + **kwargs: keyword arguments. The following options are currently supported: + + ``max_bond_dim`` (int): Maximum bond dimension for the MPS simulator. + It corresponds to the number of Schmidt coefficients retained at the end of the SVD algorithm when applying gates. Default is ``None``. + ``cutoff`` (float): Truncation threshold for the Schmidt coefficients in a MPS simulator. Default is ``1e-16``. + """ + + # pylint: disable=too-many-instance-attributes + + # So far we just consider the options for MPS simulator + _device_options = ( + "backend", + "c_dtype", + "cutoff", + "method", + "max_bond_dim", + ) + + _new_API = True + + # pylint: disable=too-many-arguments + def __init__( + self, + *, + wires=None, + backend="quimb", + method="mps", + shots=None, + c_dtype=np.complex128, + **kwargs, + ): + + if not accepted_backends(backend): + raise ValueError(f"Unsupported backend: {backend}") + + if not accepted_methods(method): + raise ValueError(f"Unsupported method: {method}") + + if shots is not None: + raise ValueError("LightningTensor does not support finite shots.") + + super().__init__(wires=wires, shots=shots) + + self._num_wires = len(self.wires) if self.wires else 0 + self._backend = backend + self._method = method + self._c_dtype = c_dtype + + # options for MPS + self._max_bond_dim = kwargs.get("max_bond_dim", None) + self._cutoff = kwargs.get("cutoff", np.finfo(self._c_dtype).eps) + + self._interface = None + interface_opts = self._setup_execution_config().device_options + + if self.backend == "quimb" and self.method == "mps": + self._interface = QuimbMPS( + self._num_wires, + interface_opts, + self._c_dtype, + ) + + else: + raise ValueError( + f"Unsupported backend: {self.backend} or method: {self.method}" + ) # pragma: no cover + + for arg in kwargs: + if arg not in self._device_options: + raise TypeError( + f"Unexpected argument: {arg} during initialization of the LightningTensor device." + ) + + @property + def name(self): + """The name of the device.""" + return "lightning.tensor" + + @property + def num_wires(self): + """Number of wires addressed on this device.""" + return self._num_wires + + @property + def backend(self): + """Supported backend.""" + return self._backend + + @property + def method(self): + """Supported method.""" + return self._method + + @property + def c_dtype(self): + """Tensor complex data type.""" + return self._c_dtype + + dtype = c_dtype + + def _setup_execution_config( + self, config: Optional[ExecutionConfig] = DefaultExecutionConfig + ) -> ExecutionConfig: + """ + Update the execution config with choices for how the device should be used and the device options. + """ + # TODO: add options for gradients next quarter + + updated_values = {} + + new_device_options = dict(config.device_options) + for option in self._device_options: + if option not in new_device_options: + new_device_options[option] = getattr(self, f"_{option}", None) + + return replace(config, **updated_values, device_options=new_device_options) + + def preprocess( + self, + execution_config: ExecutionConfig = DefaultExecutionConfig, + ): + """This function defines the device transform program to be applied and an updated device configuration. + + Args: + execution_config (Union[ExecutionConfig, Sequence[ExecutionConfig]]): A data structure describing the + parameters needed to fully describe the execution. + + Returns: + TransformProgram, ExecutionConfig: A transform program that when called returns :class:`~.QuantumTape`'s that the + device can natively execute as well as a postprocessing function to be called after execution, and a configuration + with unset specifications filled in. + + This device: + + * Supports any qubit operations that provide a matrix. + * Currently does not support finite shots. + """ + + config = self._setup_execution_config(execution_config) + + program = TransformProgram() + + # more in the next PR + + return program, config + + def execute( + self, + circuits: QuantumTape_or_Batch, + execution_config: ExecutionConfig = DefaultExecutionConfig, + ) -> Result_or_ResultBatch: + """Execute a circuit or a batch of circuits and turn it into results. + + Args: + circuits (Union[QuantumTape, Sequence[QuantumTape]]): the quantum circuits to be executed. + execution_config (ExecutionConfig): a datastructure with additional information required for execution. + + Returns: + TensorLike, tuple[TensorLike], tuple[tuple[TensorLike]]: A numeric result of the computation. + """ + # comment is removed in the next PR + # return self._interface.execute(circuits, execution_config) + + # pylint: disable=unused-argument + def supports_derivatives( + self, + execution_config: Optional[ExecutionConfig] = None, + circuit: Optional[qml.tape.QuantumTape] = None, + ) -> bool: + """Check whether or not derivatives are available for a given configuration and circuit. + + Args: + execution_config (ExecutionConfig): The configuration of the desired derivative calculation. + circuit (QuantumTape): An optional circuit to check derivatives support for. + + Returns: + Bool: Whether or not a derivative can be calculated provided the given information. + + """ + return False + + def compute_derivatives( + self, + circuits: QuantumTape_or_Batch, + execution_config: ExecutionConfig = DefaultExecutionConfig, + ): + """Calculate the jacobian of either a single or a batch of circuits on the device. + + Args: + circuits (Union[QuantumTape, Sequence[QuantumTape]]): the circuits to calculate derivatives for. + execution_config (ExecutionConfig): a datastructure with all additional information required for execution. + + Returns: + Tuple: The jacobian for each trainable parameter. + """ + raise NotImplementedError( + "The computation of derivatives has yet to be implemented for the lightning.tensor device." + ) + + def execute_and_compute_derivatives( + self, + circuits: QuantumTape_or_Batch, + execution_config: ExecutionConfig = DefaultExecutionConfig, + ): + """Compute the results and jacobians of circuits at the same time. + + Args: + circuits (Union[QuantumTape, Sequence[QuantumTape]]): the circuits or batch of circuits. + execution_config (ExecutionConfig): a datastructure with all additional information required for execution. + + Returns: + tuple: A numeric result of the computation and the gradient. + """ + raise NotImplementedError( + "The computation of derivatives has yet to be implemented for the lightning.tensor device." + ) + + # pylint: disable=unused-argument + def supports_vjp( + self, + execution_config: Optional[ExecutionConfig] = None, + circuit: Optional[QuantumTape] = None, + ) -> bool: + """Whether or not this device defines a custom vector jacobian product. + + Args: + execution_config (ExecutionConfig): The configuration of the desired derivative calculation. + circuit (QuantumTape): An optional circuit to check derivatives support for. + + Returns: + Bool: Whether or not a derivative can be calculated provided the given information. + """ + # TODO: implement during next quarter + return False + + def compute_vjp( + self, + circuits: QuantumTape_or_Batch, + cotangents: Tuple[Number], + execution_config: ExecutionConfig = DefaultExecutionConfig, + ): + r"""The vector jacobian product used in reverse-mode differentiation. + + Args: + circuits (Union[QuantumTape, Sequence[QuantumTape]]): the circuit or batch of circuits. + cotangents (Tuple[Number, Tuple[Number]]): Gradient-output vector. Must have shape matching the output shape of the + corresponding circuit. If the circuit has a single output, ``cotangents`` may be a single number, not an iterable + of numbers. + execution_config (ExecutionConfig): a datastructure with all additional information required for execution. + + Returns: + tensor-like: A numeric result of computing the vector jacobian product. + """ + raise NotImplementedError( + "The computation of vector jacobian product has yet to be implemented for the lightning.tensor device." + ) + + def execute_and_compute_vjp( + self, + circuits: QuantumTape_or_Batch, + cotangents: Tuple[Number], + execution_config: ExecutionConfig = DefaultExecutionConfig, + ): + """Calculate both the results and the vector jacobian product used in reverse-mode differentiation. + + Args: + circuits (Union[QuantumTape, Sequence[QuantumTape]]): the circuit or batch of circuits to be executed. + cotangents (Tuple[Number, Tuple[Number]]): Gradient-output vector. Must have shape matching the output shape of the + corresponding circuit. + execution_config (ExecutionConfig): a datastructure with all additional information required for execution. + + Returns: + Tuple, Tuple: the result of executing the scripts and the numeric result of computing the vector jacobian product + """ + raise NotImplementedError( + "The computation of vector jacobian product has yet to be implemented for the lightning.tensor device." + ) diff --git a/requirements-dev.txt b/requirements-dev.txt index c87c9154a4..9d0e09b02c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,4 +17,5 @@ click==8.0.4 cmake custatevec-cu12 pylint -scipy +scipy~=1.12.0 +quimb diff --git a/tests/lightning_qubit/test_measurements_class.py b/tests/lightning_qubit/test_measurements_class.py index bd070629f8..c282e69358 100644 --- a/tests/lightning_qubit/test_measurements_class.py +++ b/tests/lightning_qubit/test_measurements_class.py @@ -474,7 +474,7 @@ def test_single_return_value(self, measurement, observable, lightning_sv, tol): result = m.measure_final_state(tape) # a few tests may fail in single precision, and hence we increase the tolerance - assert np.allclose(result, expected, max(tol, 1.0e-5)) + assert np.allclose(result, expected, max(tol, 1.0e-4)) @flaky(max_runs=5) @pytest.mark.parametrize("shots", [None, 1000000]) diff --git a/tests/lightning_tensor/__init__.py b/tests/lightning_tensor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/lightning_tensor/test_lightning_tensor.py b/tests/lightning_tensor/test_lightning_tensor.py new file mode 100644 index 0000000000..3d7d1033b8 --- /dev/null +++ b/tests/lightning_tensor/test_lightning_tensor.py @@ -0,0 +1,115 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit tests for the generic lightning tensor class. +""" + + +import numpy as np +import pennylane as qml +import pytest +from conftest import LightningDevice # tested device +from pennylane.wires import Wires + +from pennylane_lightning.lightning_tensor import LightningTensor + +if not LightningDevice._new_API: + pytest.skip("Exclusive tests for new API. Skipping.", allow_module_level=True) + +if LightningDevice._CPP_BINARY_AVAILABLE: + pytest.skip("Device doesn't have C++ support yet.", allow_module_level=True) + + +@pytest.mark.parametrize("num_wires", [None, 4]) +@pytest.mark.parametrize("c_dtype", [np.complex64, np.complex128]) +def test_device_name_and_init(num_wires, c_dtype): + """Test the class initialization and returned properties.""" + wires = Wires(range(num_wires)) if num_wires else None + dev = LightningTensor(wires=wires, c_dtype=c_dtype) + assert dev.name == "lightning.tensor" + assert dev.c_dtype == c_dtype + assert dev.wires == wires + if num_wires is None: + assert dev.num_wires == 0 + else: + assert dev.num_wires == num_wires + + +@pytest.mark.parametrize("backend", ["fake_backend"]) +def test_invalid_backend(backend): + """Test an invalid backend.""" + with pytest.raises(ValueError, match=f"Unsupported backend: {backend}"): + LightningTensor(backend=backend) + + +@pytest.mark.parametrize("method", ["fake_method"]) +def test_invalid_method(method): + """Test an invalid method.""" + with pytest.raises(ValueError, match=f"Unsupported method: {method}"): + LightningTensor(method=method) + + +def test_invalid_keyword_arg(): + """Test an invalid keyword argument.""" + with pytest.raises( + TypeError, + match=f"Unexpected argument: fake_arg during initialization of the LightningTensor device.", + ): + LightningTensor(fake_arg=None) + + +def test_invalid_shots(): + """Test that an error is raised if finite number of shots are requestd.""" + with pytest.raises(ValueError, match="LightningTensor does not support finite shots."): + LightningTensor(shots=5) + + +def test_support_derivatives(): + """Test that the device does not support derivatives yet.""" + dev = LightningTensor() + assert not dev.supports_derivatives() + + +def test_compute_derivatives(): + """Test that an error is raised if the `compute_derivatives` method is called.""" + dev = LightningTensor() + with pytest.raises(NotImplementedError): + dev.compute_derivatives(circuits=None) + + +def test_execute_and_compute_derivatives(): + """Test that an error is raised if `execute_and_compute_derivative` method is called.""" + dev = LightningTensor() + with pytest.raises(NotImplementedError): + dev.execute_and_compute_derivatives(circuits=None) + + +def test_supports_vjp(): + """Test that the device does not support VJP yet.""" + dev = LightningTensor() + assert not dev.supports_vjp() + + +def test_compute_vjp(): + """Test that an error is raised if `compute_vjp` method is called.""" + dev = LightningTensor() + with pytest.raises(NotImplementedError): + dev.compute_vjp(circuits=None, cotangents=None) + + +def test_execute_and_compute_vjp(): + """Test that an error is raised if `execute_and_compute_vjp` method is called.""" + dev = LightningTensor() + with pytest.raises(NotImplementedError): + dev.execute_and_compute_vjp(circuits=None, cotangents=None) diff --git a/tests/lightning_tensor/test_quimb_mps.py b/tests/lightning_tensor/test_quimb_mps.py new file mode 100644 index 0000000000..8f3e376e56 --- /dev/null +++ b/tests/lightning_tensor/test_quimb_mps.py @@ -0,0 +1,51 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit tests for the ``quimb`` interface. +""" + + +import numpy as np +import pytest +import quimb.tensor as qtn +from conftest import LightningDevice # tested device +from pennylane.wires import Wires + +from pennylane_lightning.lightning_tensor import LightningTensor + +if not LightningDevice._new_API: + pytest.skip("Exclusive tests for new API. Skipping.", allow_module_level=True) + +if LightningDevice._CPP_BINARY_AVAILABLE: + pytest.skip("Device doesn't have C++ support yet.", allow_module_level=True) + + +@pytest.mark.parametrize("backend", ["quimb"]) +@pytest.mark.parametrize("method", ["mps"]) +class TestQuimbMPS: + """Tests for the MPS method.""" + + @pytest.mark.parametrize("num_wires", [None, 4]) + @pytest.mark.parametrize("c_dtype", [np.complex64, np.complex128]) + def test_device_init(self, num_wires, c_dtype, backend, method): + """Test the class initialization and returned properties.""" + + wires = Wires(range(num_wires)) if num_wires else None + dev = LightningTensor(wires=wires, backend=backend, method=method, c_dtype=c_dtype) + assert isinstance(dev._interface.state, qtn.MatrixProductState) + assert isinstance(dev._interface.state_to_array(), np.ndarray) + + _, config = dev.preprocess() + assert config.device_options["backend"] == backend + assert config.device_options["method"] == method