Skip to content

Commit

Permalink
add function in conftest that generates a test matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
astralcai committed Dec 13, 2024
1 parent 08c1d3a commit 624c8f6
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 37 deletions.
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ packaging
autoray>=0.6.11
matplotlib~=3.5
opt_einsum~=3.3
requests~=2.31.0
requests~=2.32.0
typing_extensions>=4.6.0
tomli~=2.0.0 # Drop once minimum Python version is 3.11
tach~=0.13.1
diastatic-malt
86 changes: 86 additions & 0 deletions tests/workflow/interfaces/qnode/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fixtures and functions for QNode interface tests."""

import warnings
import itertools
from typing import Callable

import pytest
import pennylane as qml
from param_shift_dev import ParamShiftDerivativesDevice


@pytest.fixture(autouse=True)
def suppress_tape_property_deprecation_warning():
warnings.filterwarnings(
"ignore", "The tape/qtape property is deprecated", category=qml.PennyLaneDeprecationWarning
)


def get_device(device_name, wires, seed=None):
if device_name == "param_shift.qubit":
return ParamShiftDerivativesDevice(seed=seed)
if device_name == "lightning.qubit":
return qml.device("lightning.qubit", wires=wires)
return qml.device(device_name, seed=seed)


_device_names = ("default.qubit", "param_shift.qubit", "lightning.qubit", "reference.qubit")
_diff_methods = ("backprop", "adjoint", "finite-diff", "parameter-shift", "hadamard", "spsa")
_grad_on_execution = (True, False)
_device_vjp = (True, False)


def generate_test_matrix(xfail_condition: Callable, skip_condition: Callable):
"""Generates the test matrix for different combinations."""

def _test_matrix_iter():
"""Yields tuples of (device_name, diff_method, grad_on_execution, device_vjp)."""

all_combinations = itertools.product(_device_names, _diff_methods, _grad_on_execution)
for device_name, diff_method, grad_on_execution, device_vjp in all_combinations:

if diff_method == "adjoint" and device_name not in ("default.qubit", "lightning.qubit"):
continue # adjoint diff is only supported on DQ and LQ

if device_name == "param_shift.qubit" and diff_method != "parameter-shift":
continue # param_shift.qubit is not intended to be used with anything else

xfail_reason = xfail_condition(device_name, diff_method, grad_on_execution, device_vjp)
skip_reason = skip_condition(device_name, diff_method, grad_on_execution, device_vjp)

if xfail_reason:
yield pytest.param(
device_name,
diff_method,
grad_on_execution,
device_vjp,
marks=pytest.mark.xfail(reason=xfail_reason),
)

elif skip_reason:
yield pytest.param(
device_name,
diff_method,
grad_on_execution,
device_vjp,
marks=pytest.mark.skip(reason=skip_reason),
)

else:
yield device_name, diff_method, grad_on_execution, device_vjp

return [_params for _params in _test_matrix_iter()]
7 changes: 0 additions & 7 deletions tests/workflow/interfaces/qnode/test_autograd_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,6 @@
from pennylane.devices import DefaultQubit


@pytest.fixture(autouse=True)
def suppress_tape_property_deprecation_warning():
warnings.filterwarnings(
"ignore", "The tape/qtape property is deprecated", category=qml.PennyLaneDeprecationWarning
)


# dev, diff_method, grad_on_execution, device_vjp
qubit_device_and_diff_method = [
[qml.device("default.qubit"), "finite-diff", False, False],
Expand Down
7 changes: 0 additions & 7 deletions tests/workflow/interfaces/qnode/test_jax_jit_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,6 @@
from pennylane.devices import DefaultQubit


@pytest.fixture(autouse=True)
def suppress_tape_property_deprecation_warning():
warnings.filterwarnings(
"ignore", "The tape/qtape property is deprecated", category=qml.PennyLaneDeprecationWarning
)


def get_device(device_name, wires, seed):
if device_name == "param_shift.qubit":
return ParamShiftDerivativesDevice(seed=seed)
Expand Down
7 changes: 0 additions & 7 deletions tests/workflow/interfaces/qnode/test_jax_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,6 @@
from pennylane.devices import DefaultQubit


@pytest.fixture(autouse=True)
def suppress_tape_property_deprecation_warning():
warnings.filterwarnings(
"ignore", "The tape/qtape property is deprecated", category=qml.PennyLaneDeprecationWarning
)


def get_device(device_name, wires, seed):
if device_name == "lightning.qubit":
return qml.device("lightning.qubit", wires=wires)
Expand Down
7 changes: 0 additions & 7 deletions tests/workflow/interfaces/qnode/test_tensorflow_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,6 @@
from pennylane.devices import DefaultQubit


@pytest.fixture(autouse=True)
def suppress_tape_property_deprecation_warning():
warnings.filterwarnings(
"ignore", "The tape/qtape property is deprecated", category=qml.PennyLaneDeprecationWarning
)


pytestmark = pytest.mark.tf
tf = pytest.importorskip("tensorflow")

Expand Down
7 changes: 0 additions & 7 deletions tests/workflow/interfaces/qnode/test_torch_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,6 @@
from pennylane.devices import DefaultQubit


@pytest.fixture(autouse=True)
def suppress_tape_property_deprecation_warning():
warnings.filterwarnings(
"ignore", "The tape/qtape property is deprecated", category=qml.PennyLaneDeprecationWarning
)


pytestmark = pytest.mark.torch

torch = pytest.importorskip("torch", minversion="1.3")
Expand Down

0 comments on commit 624c8f6

Please sign in to comment.