Skip to content

Commit

Permalink
Refactor the usage of custom devices
Browse files Browse the repository at this point in the history
  • Loading branch information
rauletorresc committed Dec 24, 2024
1 parent e986ba9 commit f9e8986
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 147 deletions.
13 changes: 13 additions & 0 deletions tests/custom_devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""A collection of Custom Devices"""

import pennylane as qml


class BaseCustomDeviceReturnsZero(qml.devices.Device):
def execute(self, circuits, execution_config=None):
return 0


class BaseCustomDeviceReturnsTuple(qml.devices.Device):
def execute(self, circuits, execution_config=None):
return (0,)
21 changes: 10 additions & 11 deletions tests/devices/modifiers/test_all_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig)
assert dev.tracker.history["shots"] == [50]


class BaseDummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return 0.0


@pytest.mark.parametrize("modifier", (simulator_tracking, single_tape_support))
class TestModifierDefaultBeahviour:
"""Test generic behavior for device modifiers."""
Expand All @@ -66,29 +71,23 @@ def test_adds_to_applied_modifiers_private_property(self, modifier):
"""Test that the modifier is added to the `_applied_modifiers` property."""

@modifier
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return 0.0
class DummyDev(BaseDummyDev):
pass

assert DummyDev._applied_modifiers == [modifier]

@modifier
class DummyDev2(qml.devices.Device):

class DummyDev2(BaseDummyDev):
_applied_modifiers = [None] # some existing value

def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return 0.0

assert DummyDev2._applied_modifiers == [None, modifier]

def test_leaves_undefined_methods_untouched(self, modifier):
"""Test that undefined methods are left the same as the Device class methods."""

@modifier
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return 0.0
class DummyDev(BaseDummyDev):
pass

assert DummyDev.compute_derivatives == Device.compute_derivatives
assert DummyDev.execute_and_compute_derivatives == Device.execute_and_compute_derivatives
Expand Down
35 changes: 11 additions & 24 deletions tests/devices/modifiers/test_simulator_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,16 @@ def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig)
assert dev.tracker.history["shots"] == [100]


class BaseDummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return 0.0


def test_tracking_compute_derivatives():
"""Test the compute_derivatives tracking behavior."""

@simulator_tracking
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return 0.0

class DummyDev(BaseDummyDev):
def compute_derivatives(
self, circuits, execution_config=qml.devices.DefaultExecutionConfig
):
Expand All @@ -80,10 +82,7 @@ def test_tracking_execute_and_compute_derivatives():
"""Test tracking the execute_and_compute_derivatives method."""

@simulator_tracking
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return 0.0

class DummyDev(BaseDummyDev):
def execute_and_compute_derivatives(
self, circuits, execution_config=qml.devices.DefaultExecutionConfig
):
Expand All @@ -108,10 +107,7 @@ def test_tracking_compute_jvp():
"""Test the compute_jvp tracking behavior."""

@simulator_tracking
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return 0.0

class DummyDev(BaseDummyDev):
def compute_jvp(
self, circuits, tangents, execution_config=qml.devices.DefaultExecutionConfig
):
Expand All @@ -132,10 +128,7 @@ def test_tracking_execute_and_compute_jvp():
"""Test tracking the execute_and_compute_jvp method."""

@simulator_tracking
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return 0.0

class DummyDev(BaseDummyDev):
def execute_and_compute_jvp(
self, circuits, tangents, execution_config=qml.devices.DefaultExecutionConfig
):
Expand All @@ -161,10 +154,7 @@ def test_tracking_compute_vjp():
"""Test the compute_vjp tracking behavior."""

@simulator_tracking
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return 0.0

class DummyDev(BaseDummyDev):
def compute_vjp(
self, circuits, cotangents, execution_config=qml.devices.DefaultExecutionConfig
):
Expand All @@ -185,10 +175,7 @@ def test_tracking_execute_and_compute_vjp():
"""Test tracking the execute_and_compute_derivatives method."""

@simulator_tracking
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return 0.0

class DummyDev(BaseDummyDev):
def execute_and_compute_vjp(
self, circuits, cotangents, execution_config=qml.devices.DefaultExecutionConfig
):
Expand Down
40 changes: 13 additions & 27 deletions tests/devices/modifiers/test_single_tape_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@
from pennylane.devices.modifiers import single_tape_support


class BaseDummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return tuple(0.0 for _ in circuits)


def test_wraps_execute():
"""Test that execute now accepts a single circuit."""

@single_tape_support
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return tuple(0.0 for _ in circuits)
class DummyDev(BaseDummyDev):
pass

t = qml.tape.QuantumScript()
dev = DummyDev()
Expand All @@ -37,10 +41,7 @@ def test_wraps_compute_derivatives():
"""Test that compute_derivatives now accepts a single circuit."""

@single_tape_support
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return tuple(0.0 for _ in circuits)

class DummyDev(BaseDummyDev):
def compute_derivatives(
self, circuits, execution_config=qml.devices.DefaultExecutionConfig
):
Expand All @@ -55,10 +56,7 @@ def test_wraps_execute_and_compute_derivatives():
"""Test that execute_and_compute_derivatives now accepts a single circuit."""

@single_tape_support
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return tuple(0.0 for _ in circuits)

class DummyDev(BaseDummyDev):
def execute_and_compute_derivatives(
self, circuits, execution_config=qml.devices.DefaultExecutionConfig
):
Expand All @@ -74,10 +72,7 @@ def test_wraps_compute_jvp():
"""Test that compute_jvp now accepts a single circuit."""

@single_tape_support
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return tuple(0.0 for _ in circuits)

class DummyDev(BaseDummyDev):
def compute_jvp(
self, circuits, tangents, execution_config=qml.devices.DefaultExecutionConfig
):
Expand All @@ -94,10 +89,7 @@ def test_wraps_execute_and_compute_jvp():
"""Test that execute_and_compute_jvp now accepts a single circuit."""

@single_tape_support
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return tuple(0.0 for _ in circuits)

class DummyDev(BaseDummyDev):
def execute_and_compute_jvp(
self, circuits, tangents, execution_config=qml.devices.DefaultExecutionConfig
):
Expand All @@ -114,10 +106,7 @@ def test_wraps_compute_vjp():
"""Test that compute_vjp now accepts a single circuit."""

@single_tape_support
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return tuple(0.0 for _ in circuits)

class DummyDev(BaseDummyDev):
def compute_vjp(
self, circuits, cotangents, execution_config=qml.devices.DefaultExecutionConfig
):
Expand All @@ -134,10 +123,7 @@ def test_wraps_execute_and_compute_vjp():
"""Test that execute_and_compute_vjp now accepts a single circuit."""

@single_tape_support
class DummyDev(qml.devices.Device):
def execute(self, circuits, execution_config=qml.devices.DefaultExecutionConfig):
return tuple(0.0 for _ in circuits)

class DummyDev(BaseDummyDev):
def execute_and_compute_vjp(
self, circuits, cotangents, execution_config=qml.devices.DefaultExecutionConfig
):
Expand Down
Loading

0 comments on commit f9e8986

Please sign in to comment.