Skip to content

Commit

Permalink
Implement __eq__ comparison for TransformContainer and TransformProgr…
Browse files Browse the repository at this point in the history
…am (#4858)

Towards #4824.

I've completed the first draft of the __eq__ codes for
TransformContainer and TransformProgram. I haven't written the testing
code yet. Once you confirm these implementations, I'll proceed with the
testing phase. Let me know your thoughts.

---------

Signed-off-by: Anurav Modak <[email protected]>
Co-authored-by: Christina Lee <[email protected]>
Co-authored-by: Christina Lee <[email protected]>
Co-authored-by: Romain Moyard <[email protected]>
  • Loading branch information
4 people authored Nov 23, 2023
1 parent dcff928 commit 2bdcf45
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 1 deletion.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@
* Added `ops.functions.assert_valid` for checking if an `Operator` class is defined correctly.
[(#4764)](https://github.com/PennyLaneAI/pennylane/pull/4764)

* Added `__eq__` method for TransformProgram and TransformContainers allowing the comparison of respective objects using `==` and `!=` operators.
[(#4858)](https://github.com/PennyLaneAI/pennylane/pull/4858)

* `GlobalPhase` now decomposes to nothing, in case devices do not support global phases.
[(#4855)](https://github.com/PennyLaneAI/pennylane/pull/4855)

Expand Down
12 changes: 12 additions & 0 deletions pennylane/transforms/core/transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,18 @@ def __iter__(self):
)
)

def __eq__(self, other: object) -> bool:
if not isinstance(other, TransformContainer):
return False
return (
self.args == other.args
and self.transform == other.transform
and self.kwargs == other.kwargs
and self.classical_cotransform == other.classical_cotransform
and self.is_informative == other.is_informative
and self.final_transform == other.final_transform
)

@property
def transform(self):
"""The stored quantum transform."""
Expand Down
6 changes: 6 additions & 0 deletions pennylane/transforms/core/transform_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ def __repr__(self):
contents = ", ".join(f"{transform_c.transform.__name__}" for transform_c in self)
return f"TransformProgram({contents})"

def __eq__(self, other):
if not isinstance(other, TransformProgram):
return False

return self._transform_program == other._transform_program

def push_back(self, transform_container: TransformContainer):
"""Add a transform (container) to the end of the program.
Expand Down
36 changes: 35 additions & 1 deletion tests/transforms/test_experimental/test_transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import pytest
import pennylane as qml
from pennylane.transforms.core import transform, TransformError
from pennylane.transforms.core import transform, TransformError, TransformContainer

dev = qml.device("default.qubit", wires=2)

Expand Down Expand Up @@ -282,6 +282,40 @@ def qnode_circuit(a):
qnode_circuit.transform_program.pop_front(), qml.transforms.core.TransformContainer
)

def test_equality(self):
"""Tests that we can compare TransformContainer objects with the '==' and '!=' operators."""

t1 = TransformContainer(
qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1}
)
t2 = TransformContainer(
qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1}
)
t3 = TransformContainer(
qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (1, 2)]}
)
t4 = TransformContainer(
qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 2}
)

t5 = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-6,))
t6 = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-7,))

# test for equality of identical transformers
assert t1 == t2

# test for inequality of different transformers
assert t1 != t3
assert t2 != t3
assert t1 != 2
assert t1 != t4
assert t5 != t6
assert t5 != t1

# Test equality with the same args
t5_copy = TransformContainer(qml.transforms.merge_rotations.transform, args=(1e-6,))
assert t5 == t5_copy

def test_queuing_qfunc_transform(self):
"""Test that queuing works with the transformed quantum function."""

Expand Down
29 changes: 29 additions & 0 deletions tests/transforms/test_experimental/test_transform_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,35 @@ def test_valid_transforms(self):
):
transform_program.push_back(transform2)

def test_equality(self):
"""Tests that we can compare TransformProgram objects with the '==' and '!=' operators."""
t1 = TransformContainer(
qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1}
)
t2 = TransformContainer(
qml.transforms.compile.transform, kwargs={"num_passes": 2, "expand_depth": 1}
)
t3 = TransformContainer(
qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (1, 2)]}
)

p1 = TransformProgram([t1, t3])
p2 = TransformProgram([t2, t3])
p3 = TransformProgram([t3, t2])

# test for equality of identical objects
assert p1 == p2
# test for inequality of different objects
assert p1 != p3
assert p1 != t1

# Test inequality with different transforms
t4 = TransformContainer(
qml.transforms.transpile.transform, kwargs={"coupling_map": [(0, 1), (2, 3)]}
)
p4 = TransformProgram([t1, t4])
assert p1 != p4


class TestTransformProgramCall:
"""Tests for calling a TransformProgram on a batch of quantum tapes."""
Expand Down

0 comments on commit 2bdcf45

Please sign in to comment.