Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
dwierichs committed Oct 13, 2023
1 parent 8fcf2e0 commit 51ac2ec
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
24 changes: 18 additions & 6 deletions pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,30 +264,41 @@ def apply_cnot(op: qml.CNOT, state, is_state_batched: bool = False, debugger=Non


@apply_operation.register
def apply_multicontrolledx(op: qml.MultiControlledX, state, is_state_batched: bool = False, debugger=None):
def apply_multicontrolledx(
op: qml.MultiControlledX, state, is_state_batched: bool = False, debugger=None
):
r"""Apply MultiControlledX to state by composing transpositions, rolling of control axes
and the CNOT logic above."""
if len(op.wires) < 8:
return apply_operation_einsum(op, state, is_state_batched)

ctrl_wires = [w + is_state_batched for w in op.hyperparameters["control_wires"]]
# apply x on all "wrong" controls
roll_axes = [w for val, w in zip(op.hyperparameters["control_values"], ctrl_wires) if val=="0"]
roll_axes = [
w for val, w in zip(op.hyperparameters["control_values"], ctrl_wires) if val == "0"
]
state = math.roll(state, 1, roll_axes)

orig_shape = math.shape(state)
# Move the axes into the order [(batch), other, target, controls]
transpose_axes = np.array([w for w in range(len(orig_shape)) if w not in op.wires] + [op.total_wires[-1]] + op.total_wires[:-1].tolist()) + is_state_batched
#transpose_axes = + is_state_batched for w in transpose_axes]
transpose_axes = (
np.array(
[w for w in range(len(orig_shape)) if w not in op.wires]
+ [op.total_wires[-1]]
+ op.total_wires[:-1].tolist()
)
+ is_state_batched
)
# transpose_axes = + is_state_batched for w in transpose_axes]
state = math.transpose(state, transpose_axes)

# Reshape the state into 3-dimensional array with axes [batch+other, target, controls]
state = state.reshape((-1, 2, 2**(len(op.wires)-1)))
state = state.reshape((-1, 2, 2 ** (len(op.wires) - 1)))

# The part of the state to which we want to apply PauliX is now in the last entry along
# the third axis. Extract it, apply the PauliX along the target axis, and append a dummy axis
state_x = math.roll(state[:, :, -1], 1, 1)[:, :, np.newaxis]

# Stack the transformed part of the state with the unmodified rest of the state
state = math.concatenate([state[:, :, :-1], state_x], axis=2)

Expand All @@ -297,6 +308,7 @@ def apply_multicontrolledx(op: qml.MultiControlledX, state, is_state_batched: bo
# revert x on all "wrong" controls
return math.roll(state, 1, roll_axes)


@apply_operation.register
def apply_grover(op: qml.GroverOperator, state, is_state_batched: bool = False, debugger=None):
r"""Apply GroverOperator to state. This method uses that this operator
Expand Down
1 change: 1 addition & 0 deletions tests/devices/qubit/test_apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ def test_double_excitation(self, method):

assert qml.math.allclose(state_v1, state_v2)


@pytest.mark.parametrize("num_wires, einsum_called", [(3, True), (8, False)])
def test_multicontrolledx_dispatching(num_wires, einsum_called, mocker):
"""Test that apply_multicontrolledx dispatches to einsum for small numbers of wires."""
Expand Down

0 comments on commit 51ac2ec

Please sign in to comment.