Skip to content

Commit

Permalink
Support PauliSentence in qml.equal (#6703)
Browse files Browse the repository at this point in the history
**Context:**
Currently `qml.equal` can not handle `PauliSentence` instances, and they
are a bit cumbersome to compare manually.

**Description of the Change:**
Support `PauliSentence` comparison in `qml.equal`.

**Benefits:**
Extended feature.

**Possible Drawbacks:**
N/A

**Related GitHub Issues:**
n/a

---------

Co-authored-by: Christina Lee <[email protected]>
  • Loading branch information
dwierichs and albi3ro authored Dec 18, 2024
1 parent f4cf321 commit 97f2cb5
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 4 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ such as `shots`, `rng` and `prng_key`.

<h3>Improvements 🛠</h3>

* `qml.equal` now supports `PauliWord` and `PauliSentence` instances.
[(#6703)](https://github.com/PennyLaneAI/pennylane/pull/6703)

* Remove redundant commutator computations from `qml.lie_closure`.
[(#6724)](https://github.com/PennyLaneAI/pennylane/pull/6724)

Expand Down
77 changes: 73 additions & 4 deletions pennylane/ops/functions/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pennylane.measurements.vn_entropy import VnEntropyMP
from pennylane.operation import Observable, Operator
from pennylane.ops import Adjoint, CompositeOp, Conditional, Controlled, Exp, Pow, SProd
from pennylane.pauli import PauliSentence, PauliWord
from pennylane.pulse.parametrized_evolution import ParametrizedEvolution
from pennylane.tape import QuantumScript
from pennylane.templates.subroutines import ControlledSequence, PrepSelPrep
Expand All @@ -38,8 +39,8 @@


def equal(
op1: Union[Operator, MeasurementProcess, QuantumScript],
op2: Union[Operator, MeasurementProcess, QuantumScript],
op1: Union[Operator, MeasurementProcess, QuantumScript, PauliWord, PauliSentence],
op2: Union[Operator, MeasurementProcess, QuantumScript, PauliWord, PauliSentence],
check_interface=True,
check_trainability=True,
rtol=1e-5,
Expand All @@ -62,8 +63,8 @@ def equal(
and ``check_trainability``.
Args:
op1 (.Operator or .MeasurementProcess or .QuantumTape): First object to compare
op2 (.Operator or .MeasurementProcess or .QuantumTape): Second object to compare
op1 (.Operator or .MeasurementProcess or .QuantumTape or .PauliWord or .PauliSentence): First object to compare
op2 (.Operator or .MeasurementProcess or .QuantumTape or .PauliWord or .PauliSentence): Second object to compare
check_interface (bool, optional): Whether to compare interfaces. Default: ``True``.
check_trainability (bool, optional): Whether to compare trainability status. Default: ``True``.
rtol (float, optional): Relative tolerance for parameters.
Expand Down Expand Up @@ -356,6 +357,74 @@ def _equal_operators(
return True


# pylint: disable=unused-argument
@_equal_dispatch.register
def _equal_pauliword(
op1: PauliWord,
op2: PauliWord,
**kwargs,
):
if op1 != op2:
if set(op1) != set(op2):
err = "Different wires in Pauli words."
diff12 = set(op1).difference(set(op2))
diff21 = set(op2).difference(set(op1))
if diff12:
err += f" op1 has {diff12} not present in op2."
if diff21:
err += f" op2 has {diff21} not present in op1."
return err
pauli_diff = {}
for wire in op1:
if op1[wire] != op2[wire]:
pauli_diff[wire] = f"{op1[wire]} != {op2[wire]}"
return f"Pauli words agree on wires but differ in Paulis: {pauli_diff}"
return True


@_equal_dispatch.register
def _equal_paulisentence(
op1: PauliSentence,
op2: PauliSentence,
check_interface=True,
check_trainability=True,
rtol=1e-5,
atol=1e-9,
):
if set(op1) != set(op2):
err = "Different Pauli words in PauliSentences."
diff12 = set(op1).difference(set(op2))
diff21 = set(op2).difference(set(op1))
if diff12:
err += f" op1 has {diff12} not present in op2."
if diff21:
err += f" op2 has {diff21} not present in op1."
return err
for pw in op1:
param1 = op1[pw]
param2 = op2[pw]
if check_trainability:
param1_train = qml.math.requires_grad(param1)
param2_train = qml.math.requires_grad(param2)
if param1_train != param2_train:
return (
"Parameters have different trainability.\n "
f"{param1} trainability is {param1_train} and {param2} trainability is {param2_train}"
)

if check_interface:
param1_interface = qml.math.get_interface(param1)
param2_interface = qml.math.get_interface(param2)
if param1_interface != param2_interface:
return (
"Parameters have different interfaces.\n "
f"{param1} interface is {param1_interface} and {param2} interface is {param2_interface}"
)
if not qml.math.allclose(param1, param2, rtol=rtol, atol=atol):
return f"The coefficients of the PauliSentences for {pw} differ: {param1}; {param2}"
return True


@_equal_dispatch.register
# pylint: disable=unused-argument, protected-access
def _equal_prod_and_sum(op1: CompositeOp, op2: CompositeOp, **kwargs):
Expand Down
132 changes: 132 additions & 0 deletions tests/ops/functions/test_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,6 +1211,138 @@ def test_not_equal_operator_measurement(self, op1, op2):
assert not qml.equal(op1, op2)


equal_pauli_words = [
({0: "X", 1: "Y"}, {1: "Y", 0: "X"}, True, None),
({0: "X", 1: "Y"}, {0: "X"}, False, "Different wires in Pauli words."),
({0: "X", 1: "Z"}, {1: "Y", 0: "X"}, False, "agree on wires but differ in Paulis."),
({0: "X", 1: "Y"}, {"X": "Y", 0: "X"}, False, "Different wires in Pauli words."),
]


# pylint: disable=too-few-public-methods
class TestPauliWordsEqual:
"""Tests for qml.equal with PauliSentences."""

@pytest.mark.parametrize("pw1, pw2, res, error_match", equal_pauli_words)
def test_equality(self, pw1, pw2, res, error_match):
"""Test basic equalities/inequalities."""
pw1 = qml.pauli.PauliWord(pw1)
pw2 = qml.pauli.PauliWord(pw2)
assert qml.equal(pw1, pw2) is res
assert qml.equal(pw2, pw1) is res

if res:
assert_equal(pw1, pw2)
assert_equal(pw2, pw1)
else:
with pytest.raises(AssertionError, match=error_match):
assert_equal(pw1, pw2)
with pytest.raises(AssertionError, match=error_match):
assert_equal(pw2, pw1)


equal_pauli_sentences = [
(qml.X(0) @ qml.Y(2), 1.0 * qml.Y(2) @ qml.X(0), True, None),
(
qml.X(0) @ qml.Y(2),
1.0 * qml.X(2) @ qml.Y(0),
False,
"Different Pauli words in PauliSentences",
),
(qml.X(0) - qml.Y(2), -1.0 * (qml.Y(2) - qml.X(0)), True, None),
(qml.X(0) @ qml.Y(2), qml.Y(2) + qml.X(0), False, "Different Pauli words in PauliSentences"),
(qml.SISWAP([0, "a"]) @ qml.Z("b"), qml.Z("b") @ qml.SISWAP((0, "a")), True, None),
(qml.SWAP([0, "a"]) @ qml.S("b"), qml.S("b") @ qml.SWAP(("a", 0)), True, None),
]


class TestPauliSentencesEqual:
"""Tests for qml.equal with PauliSentences."""

@pytest.mark.parametrize("ps1, ps2, res, error_match", equal_pauli_sentences)
def test_equality(self, ps1, ps2, res, error_match):
"""Test basic equalities/inequalities."""
ps1 = qml.simplify(ps1).pauli_rep
ps2 = qml.simplify(ps2).pauli_rep

assert qml.equal(ps1, ps2) is res
assert qml.equal(ps1 * 0.6, ps2 * 0.6) is res
assert qml.equal(ps2, ps1) is res

if res:
assert_equal(ps1, ps2)
assert_equal(ps2, ps1)
else:
with pytest.raises(AssertionError, match=error_match):
assert_equal(ps1, ps2)
with pytest.raises(AssertionError, match=error_match):
assert_equal(ps2, ps1)

@pytest.mark.torch
def test_trainability_and_interface(self):
"""Test that trainability and interface are compared correctly."""
import torch

x1 = qml.numpy.array(0.5, requires_grad=True)
x2 = qml.numpy.array(0.5, requires_grad=False)
x3 = torch.tensor(0.5, requires_grad=True)
x4 = torch.tensor(0.5, requires_grad=False)
pws = [qml.pauli.PauliWord({1: "X", 39: "Y"}), qml.pauli.PauliWord({0: "Z", 1: "Y"})]
ps1 = pws[0] * x1 - 0.7 * pws[1]
ps2 = pws[0] * x2 - 0.7 * pws[1]
ps3 = pws[0] * x3 - 0.7 * pws[1]
ps4 = pws[0] * x4 - 0.7 * pws[1]

assert qml.equal(ps1, ps2) is False
with pytest.raises(AssertionError, match="Parameters have different trainability"):
assert_equal(ps1, ps2)
assert qml.equal(ps1, ps3) is False
assert qml.equal(ps1, ps4) is False
assert qml.equal(ps2, ps3) is False
assert qml.equal(ps2, ps4) is False
assert qml.equal(ps3, ps4) is False

assert qml.equal(ps1, ps2, check_trainability=False) is True
assert_equal(ps1, ps2, check_trainability=False)
assert qml.equal(ps1, ps3, check_trainability=False) is False
with pytest.raises(AssertionError, match="Parameters have different interfaces"):
assert_equal(ps1, ps3, check_trainability=False)
assert qml.equal(ps1, ps4, check_trainability=False) is False
assert qml.equal(ps2, ps3, check_trainability=False) is False
assert qml.equal(ps2, ps4, check_trainability=False) is False
assert qml.equal(ps3, ps4, check_trainability=False) is True

assert qml.equal(ps1, ps2, check_interface=False) is False
with pytest.raises(AssertionError, match="Parameters have different trainability"):
assert_equal(ps1, ps2, check_interface=False)
assert qml.equal(ps1, ps3, check_interface=False) is True
assert_equal(ps1, ps3, check_interface=False)
assert qml.equal(ps1, ps4, check_interface=False) is False
assert qml.equal(ps2, ps3, check_interface=False) is False
assert qml.equal(ps2, ps4, check_interface=False) is True
assert qml.equal(ps3, ps4, check_interface=False) is False

assert qml.equal(ps1, ps2, check_trainability=False, check_interface=False) is True
assert_equal(ps1, ps2, check_trainability=False, check_interface=False)
assert qml.equal(ps1, ps3, check_trainability=False, check_interface=False) is True
assert qml.equal(ps1, ps4, check_trainability=False, check_interface=False) is True
assert qml.equal(ps2, ps3, check_trainability=False, check_interface=False) is True
assert qml.equal(ps2, ps4, check_trainability=False, check_interface=False) is True
assert qml.equal(ps3, ps4, check_trainability=False, check_interface=False) is True

@pytest.mark.parametrize(
"atol, rtol, res", [(1e-9, 0.0, False), (1e-7, 0.0, True), (0.0, 1e-9, True)]
)
def test_tolerance(self, atol, rtol, res):
"""Test that tolerances are taken into account correctly."""
x1 = 100
x2 = 100 + 1e-8
pws = [qml.pauli.PauliWord({1: "X", 39: "Y"}), qml.pauli.PauliWord({0: "Z", 1: "Y"})]
ps1 = pws[0] * x1 - 0.7 * pws[1]
ps2 = pws[0] * x2 - 0.7 * pws[1]
assert qml.equal(ps1, ps2, atol=atol, rtol=rtol) is res


class TestMeasurementsEqual:
@pytest.mark.jax
def test_observables_different_interfaces(self):
Expand Down

0 comments on commit 97f2cb5

Please sign in to comment.