diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index c9d57b62fed..7ff7d86ba42 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -264,7 +264,9 @@ def apply_cnot(op: qml.CNOT, state, is_state_batched: bool = False, debugger=Non @apply_operation.register -def apply_multicontrolledx(op: qml.MultiControlledX, state, is_state_batched: bool = False, debugger=None): +def apply_multicontrolledx( + op: qml.MultiControlledX, state, is_state_batched: bool = False, debugger=None +): r"""Apply MultiControlledX to state by composing transpositions, rolling of control axes and the CNOT logic above.""" if len(op.wires) < 8: @@ -272,22 +274,31 @@ def apply_multicontrolledx(op: qml.MultiControlledX, state, is_state_batched: bo ctrl_wires = [w + is_state_batched for w in op.hyperparameters["control_wires"]] # apply x on all "wrong" controls - roll_axes = [w for val, w in zip(op.hyperparameters["control_values"], ctrl_wires) if val=="0"] + roll_axes = [ + w for val, w in zip(op.hyperparameters["control_values"], ctrl_wires) if val == "0" + ] state = math.roll(state, 1, roll_axes) orig_shape = math.shape(state) # Move the axes into the order [(batch), other, target, controls] - transpose_axes = np.array([w for w in range(len(orig_shape)) if w not in op.wires] + [op.total_wires[-1]] + op.total_wires[:-1].tolist()) + is_state_batched - #transpose_axes = + is_state_batched for w in transpose_axes] + transpose_axes = ( + np.array( + [w for w in range(len(orig_shape)) if w not in op.wires] + + [op.total_wires[-1]] + + op.total_wires[:-1].tolist() + ) + + is_state_batched + ) + # transpose_axes = + is_state_batched for w in transpose_axes] state = math.transpose(state, transpose_axes) # Reshape the state into 3-dimensional array with axes [batch+other, target, controls] - state = state.reshape((-1, 2, 2**(len(op.wires)-1))) + state = state.reshape((-1, 2, 2 ** (len(op.wires) - 1))) # The part of the state to which we want to apply PauliX is now in the last entry along # the third axis. Extract it, apply the PauliX along the target axis, and append a dummy axis state_x = math.roll(state[:, :, -1], 1, 1)[:, :, np.newaxis] - + # Stack the transformed part of the state with the unmodified rest of the state state = math.concatenate([state[:, :, :-1], state_x], axis=2) @@ -297,6 +308,7 @@ def apply_multicontrolledx(op: qml.MultiControlledX, state, is_state_batched: bo # revert x on all "wrong" controls return math.roll(state, 1, roll_axes) + @apply_operation.register def apply_grover(op: qml.GroverOperator, state, is_state_batched: bool = False, debugger=None): r"""Apply GroverOperator to state. This method uses that this operator diff --git a/tests/devices/qubit/test_apply_operation.py b/tests/devices/qubit/test_apply_operation.py index e92b47e52d3..81b52ba1f79 100644 --- a/tests/devices/qubit/test_apply_operation.py +++ b/tests/devices/qubit/test_apply_operation.py @@ -831,6 +831,7 @@ def test_double_excitation(self, method): assert qml.math.allclose(state_v1, state_v2) + @pytest.mark.parametrize("num_wires, einsum_called", [(3, True), (8, False)]) def test_multicontrolledx_dispatching(num_wires, einsum_called, mocker): """Test that apply_multicontrolledx dispatches to einsum for small numbers of wires."""