Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ControlledQubitUnitary consistency with program capture #6719

Merged
merged 24 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
ba59255
feat: Update ControlledQubitUnitary class
andrijapau Dec 13, 2024
5ab830d
fix: Update _primitive_bind_call to fix jaxpr error
andrijapau Dec 13, 2024
cdb8619
fix: Update controlled_ops.py and add bug tests
andrijapau Dec 13, 2024
fba5517
fix: Add jax pytest mark to enable_disable_plxpr
andrijapau Dec 16, 2024
cfc63d7
Merge branch 'master' into fix-ctrl-qb-unit
andrijapau Dec 16, 2024
0a2d919
fix: Add jax mark to test_controlled_ops.py
andrijapau Dec 16, 2024
06f1185
fix: Improve coverage by adding more tests
andrijapau Dec 16, 2024
785444f
Merge branch 'master' into fix-ctrl-qb-unit
andrijapau Dec 16, 2024
42f9817
Update pennylane/ops/op_math/controlled_ops.py
andrijapau Dec 16, 2024
dfaa2b2
Update pennylane/ops/op_math/controlled_ops.py
andrijapau Dec 16, 2024
ecb0d78
Update pennylane/ops/op_math/controlled_ops.py
andrijapau Dec 16, 2024
b7744f8
Update pennylane/ops/op_math/controlled_ops.py
andrijapau Dec 16, 2024
51624f1
fix: Update test_controlled_ops.py with improvements
andrijapau Dec 16, 2024
8268001
Merge branch 'fix-ctrl-qb-unit' of github.com:PennyLaneAI/pennylane i…
andrijapau Dec 16, 2024
782b046
fix: Update controlled_ops.py for incorrect hasattr signature
andrijapau Dec 16, 2024
3a0b611
Merge branch 'master' into fix-ctrl-qb-unit
andrijapau Dec 16, 2024
c993f4d
doc: Update changelog-dev.md
andrijapau Dec 16, 2024
917e126
doc: Add doc string to test_wires_is_none
andrijapau Dec 16, 2024
bea9fd4
refactor: minor clean-up
andrijapau Dec 16, 2024
286cd84
Merge branch 'master' into fix-ctrl-qb-unit
andrijapau Dec 17, 2024
75fd29c
Merge branch 'master' into fix-ctrl-qb-unit
andrijapau Dec 17, 2024
5547056
Merge branch 'master' into fix-ctrl-qb-unit
andrijapau Dec 18, 2024
38a4356
doc: Update changelog-dev.md
andrijapau Dec 18, 2024
9c12a1a
refactor: Code clean-up
andrijapau Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions pennylane/ops/op_math/controlled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pennylane as qml
from pennylane.operation import AnyWires, Wires
from pennylane.ops.qubit.parametric_ops_single_qubit import stack_last
from pennylane.wires import WiresLike

from .controlled import ControlledOp
from .controlled_decompositions import decompose_mcx
Expand Down Expand Up @@ -119,22 +120,61 @@
data[0], control_wires=metadata[0], control_values=metadata[1], work_wires=metadata[2]
)

# pylint: disable=arguments-differ, too-many-arguments, unused-argument
andrijapau marked this conversation as resolved.
Show resolved Hide resolved
@classmethod
def _primitive_bind_call(

Check notice on line 125 in pennylane/ops/op_math/controlled_ops.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/ops/op_math/controlled_ops.py#L125

Too many positional arguments (7/5) (too-many-positional-arguments)
cls,
base,
control_wires: WiresLike,
wires: WiresLike = (),
control_values=None,
unitary_check=False,
work_wires: WiresLike = (),
):
wires = Wires(()) if wires is None else Wires(wires)
work_wires = Wires(()) if work_wires is None else Wires(work_wires)
andrijapau marked this conversation as resolved.
Show resolved Hide resolved

if getattr(base, "wires", False) and len(wires) != 0:
warnings.warn(
"base operator already has wires; values specified through wires kwarg will be ignored."
)
wires = Wires(())
andrijapau marked this conversation as resolved.
Show resolved Hide resolved

all_wires = control_wires + wires
return cls._primitive.bind(
base, control_wires=all_wires, control_values=control_values, work_wires=work_wires
)

# pylint: disable=too-many-arguments,too-many-positional-arguments
def __init__(
self,
base,
control_wires,
wires=None,
control_wires: WiresLike,
wires: WiresLike = (),
control_values=None,
unitary_check=False,
work_wires=None,
work_wires: WiresLike = (),
):
if getattr(base, "wires", False) and wires is not None:
wires = () if wires is None else Wires(wires)
work_wires = () if work_wires is None else Wires(work_wires)
control_wires = Wires(control_wires)

if getattr(base, "wires", False) and len(wires) != 0:
andrijapau marked this conversation as resolved.
Show resolved Hide resolved
warnings.warn(
"base operator already has wires; values specified through wires kwarg will be ignored."
)
wires = Wires(())

if isinstance(base, Iterable):
if len(wires) == 0:
if len(Wires(control_wires)) > 1:
num_base_wires = int(qml.math.log2(qml.math.shape(base)[-1]))
wires = Wires(control_wires)[-num_base_wires:]
control_wires = Wires(control_wires)[:-num_base_wires]
else:
raise TypeError(
"Must specify a set of wires. None is not a valid `wires` label."
)
andrijapau marked this conversation as resolved.
Show resolved Hide resolved
# We use type.__call__ instead of calling the class directly so that we don't bind the
# operator primitive when new program capture is enabled
base = type.__call__(qml.QubitUnitary, base, wires=wires, unitary_check=unitary_check)
Expand Down
54 changes: 53 additions & 1 deletion tests/ops/op_math/test_controlled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,61 @@
X_broadcasted = np.array([X] * 3)


# pylint: disable=too-many-public-methods
class TestControlledQubitUnitary:
"""Tests specific to the ControlledQubitUnitary operation"""

def test_wires_is_none(self):
andrijapau marked this conversation as resolved.
Show resolved Hide resolved
base_op = [[0, 1], [1, 0]]

with pytest.raises(TypeError, match="Must specify a set of wires"):
andrijapau marked this conversation as resolved.
Show resolved Hide resolved
qml.ControlledQubitUnitary(base_op, control_wires=1, wires=None)

@pytest.mark.jax
@pytest.mark.usefixtures("enable_disable_plxpr")
def test_wires_specified_twice_with_capture(self):
base = qml.QubitUnitary(X, wires=0)
with pytest.warns(
UserWarning,
match="base operator already has wires; values specified through wires kwarg will be ignored.",
):
qml.ControlledQubitUnitary(base, control_wires=[1, 2], wires=3)

@pytest.mark.jax
@pytest.mark.usefixtures("enable_disable_plxpr")
@pytest.mark.parametrize(
"control_wires, wires",
[(0, 1), ([0, 1], [2])],
)
def test_consistency_with_capture(self, control_wires, wires):
base_op = [[0, 1], [1, 0]]

op = qml.ControlledQubitUnitary(base_op, control_wires=control_wires, wires=wires)
assert op.base.wires == Wires(wires)
assert op.control_wires == Wires(control_wires)

@pytest.mark.jax
@pytest.mark.usefixtures("enable_disable_plxpr")
@pytest.mark.parametrize(
"control_wires_1, wires_1, control_wires_2, wires_2",
[
([0, 1], [2], [0, 1, 2], ()), # Pair of configurations to compare
],
)
def test_pairwise_consistency_with_capture(
self, control_wires_1, wires_1, control_wires_2, wires_2
):
base_op = [[0, 1], [1, 0]]

op_1 = qml.ControlledQubitUnitary(base_op, control_wires=control_wires_1, wires=wires_1)
op_2 = qml.ControlledQubitUnitary(base_op, control_wires=control_wires_2, wires=wires_2)

assert op_1.base.wires == Wires(2)
assert op_2.base.wires == Wires(2)

assert op_1.control_wires == Wires([0, 1])
assert op_2.control_wires == Wires([0, 1])
andrijapau marked this conversation as resolved.
Show resolved Hide resolved

def test_initialization_from_matrix_and_operator(self):
base_op = QubitUnitary(X, wires=1)

Expand Down Expand Up @@ -88,7 +140,7 @@ def test_wrong_shape(self):
with pytest.raises(ValueError, match=r"Input unitary must be of shape \(2, 2\)"):
qml.ControlledQubitUnitary(np.eye(4), control_wires=[0, 1], wires=2).matrix()

@pytest.mark.parametrize("target_wire", range(3))
@pytest.mark.parametrize("target_wire", list(range(3)))
andrijapau marked this conversation as resolved.
Show resolved Hide resolved
def test_toffoli(self, target_wire):
"""Test if ControlledQubitUnitary acts like a Toffoli gate when the input unitary is a
single-qubit X. This test allows the target wire to be any of the three wires."""
Expand Down
Loading