Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise error if broadcasting with MottonenStatePreparation and BasisStatePreparation compute_decomposition #4767

Merged
merged 10 commits into from
Nov 9, 2023
7 changes: 7 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@

<h3>Bug fixes 🐛</h3>

* `MottonenStatePreparation` now raises an error if decomposing a broadcasted state vector.
[(#4767)](https://github.com/PennyLaneAI/pennylane/pull/4767)

* `BasisStatePreparation` now raises an error if decomposing a broadcasted state vector.
[(#4767)](https://github.com/PennyLaneAI/pennylane/pull/4767)

* Gradient transforms now work with overridden shot vectors and default qubit.
[(#4795)](https://github.com/PennyLaneAI/pennylane/pull/4795)

Expand Down Expand Up @@ -99,6 +105,7 @@ Lillian Frederiksen,
Ankit Khandelwal,
Christina Lee,
Anurav Modak,
Mudit Pandey,
Matthew Silverman,
David Wierichs,
Justin Woodring,
8 changes: 8 additions & 0 deletions pennylane/templates/state_preparations/basis.py
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def circuit(basis_state):
num_params = 1
num_wires = AnyWires
grad_method = None
ndim_params = (1,)

def __init__(self, basis_state, wires, id=None):
basis_state = qml.math.stack(basis_state)
Expand Down Expand Up @@ -109,6 +110,13 @@ def compute_decomposition(basis_state, wires): # pylint: disable=arguments-diff
[PauliX(wires=['a']),
PauliX(wires=['b'])]
"""
if len(qml.math.shape(basis_state)) > 1:
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"Broadcasting with BasisStatePreparation is not supported. Please use the "
"qml.transforms.broadcast_expand transform to use broadcasting with "
"BasisStatePreparation."
)

if not qml.math.is_abstract(basis_state):
op_list = []
for wire, state in zip(wires, basis_state):
Expand Down
7 changes: 7 additions & 0 deletions pennylane/templates/state_preparations/mottonen.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def circuit(state):

num_wires = AnyWires
grad_method = None
ndim_params = (1,)

def __init__(self, state_vector, wires, id=None):
# check if the `state_vector` param is batched
Expand Down Expand Up @@ -347,6 +348,12 @@ def compute_decomposition(state_vector, wires): # pylint: disable=arguments-dif
CNOT(wires=['a', 'b']),
CNOT(wires=['a', 'b'])]
"""
if len(qml.math.shape(state_vector)) > 1:
raise ValueError(
"Broadcasting with MottonenStatePreparation is not supported. Please use the "
"qml.transforms.broadcast_expand transform to use broadcasting with "
"MottonenStatePreparation."
)

a = qml.math.abs(state_vector)
omega = qml.math.angle(state_vector)
Expand Down
12 changes: 12 additions & 0 deletions tests/templates/test_state_preparations/test_basis_state_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,18 @@ def circuit2():
assert np.allclose(res1, res2, atol=tol, rtol=0)
assert np.allclose(state1, state2, atol=tol, rtol=0)

def test_batched_decomposition_fails(self):
"""Test that attempting to decompose a BasisStatePreparation operation with
broadcasting raises an error."""
state = np.array([[1, 0], [1, 1]])

op = qml.BasisStatePreparation(state, wires=[0, 1])
with pytest.raises(ValueError, match="Broadcasting with BasisStatePreparation"):
_ = op.decomposition()

with pytest.raises(ValueError, match="Broadcasting with BasisStatePreparation"):
_ = qml.BasisStatePreparation.compute_decomposition(state, qml.wires.Wires([0, 1]))


class TestInputs:
"""Test inputs and pre-processing."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,18 @@ def circuit2():
assert np.allclose(res1, res2, atol=tol, rtol=0)
assert np.allclose(state1, state2, atol=tol, rtol=0)

def test_batched_decomposition_fails(self):
"""Test that attempting to decompose a MottonenStatePreparation operation with
broadcasting raises an error."""
state = np.array([[1 / 2, 1 / 2, 1 / 2, 1 / 2], [0.0, 0.0, 0.0, 1.0]])

op = qml.MottonenStatePreparation(state, wires=[0, 1])
with pytest.raises(ValueError, match="Broadcasting with MottonenStatePreparation"):
_ = op.decomposition()

with pytest.raises(ValueError, match="Broadcasting with MottonenStatePreparation"):
_ = qml.MottonenStatePreparation.compute_decomposition(state, qml.wires.Wires([0, 1]))


class TestInputs:
"""Test inputs and pre-processing."""
Expand Down
37 changes: 37 additions & 0 deletions tests/transforms/test_batch_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,43 @@ def circuit2(data, weights):
assert np.allclose(res, indiv_res)


def test_basis_state_preparation(mocker):
"""Test that batching works for BasisStatePreparation"""
dev = qml.device("default.qubit", wires=3)

@partial(qml.batch_input, argnum=0)
@qml.qnode(dev, interface="autograd")
def circuit(data, weights):
qml.templates.BasisStatePreparation(data, wires=[0, 1, 2])
qml.templates.StronglyEntanglingLayers(weights, wires=[0, 1, 2])
return qml.probs(wires=[0, 1, 2])

batch_size = 3

# create a batched input statevector
data = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], requires_grad=False)

# weights is not batched
weights = np.random.random((10, 3, 3), requires_grad=True)

spy = mocker.spy(circuit.device, "execute")
res = circuit(data, weights)
assert res.shape == (batch_size, 2**3)
assert len(spy.call_args[0][0]) == batch_size

# check the results against individually executed circuits (no batching)
@qml.qnode(dev)
def circuit2(data, weights):
qml.templates.BasisStatePreparation(data, wires=[0, 1, 2])
qml.templates.StronglyEntanglingLayers(weights, wires=[0, 1, 2])
return qml.probs(wires=[0, 1, 2])

indiv_res = []
for state in data:
indiv_res.append(circuit2(state, weights))
assert np.allclose(res, indiv_res)


def test_qubit_state_prep(mocker):
"""Test that batching works for StatePrep"""

Expand Down
8 changes: 4 additions & 4 deletions tests/transforms/test_batch_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,14 +279,14 @@ def test_shot_vector():
@qml.batch_params
@qml.qnode(dev)
def circuit(data, x, weights):
qml.templates.AmplitudeEmbedding(data, wires=[0, 1, 2], normalize=True)
qml.templates.AngleEmbedding(data, wires=[0, 1, 2])
qml.RX(x, wires=0)
qml.RY(0.2, wires=1)
qml.templates.StronglyEntanglingLayers(weights, wires=[0, 1, 2])
return qml.probs(wires=[0, 2])

batch_size = 6
data = np.random.random((batch_size, 8))
data = np.random.random((batch_size, 3))
x = np.linspace(0.1, 0.5, batch_size, requires_grad=True)
weights = np.ones((batch_size, 10, 3, 3), requires_grad=True)

Expand All @@ -306,14 +306,14 @@ def test_multi_returns_shot_vector():
@qml.batch_params
@qml.qnode(dev)
def circuit(data, x, weights):
qml.templates.AmplitudeEmbedding(data, wires=[0, 1, 2], normalize=True)
qml.templates.AngleEmbedding(data, wires=[0, 1, 2])
qml.RX(x, wires=0)
qml.RY(0.2, wires=1)
qml.templates.StronglyEntanglingLayers(weights, wires=[0, 1, 2])
return qml.expval(qml.PauliZ(0)), qml.probs(wires=[0, 2])

batch_size = 6
data = np.random.random((batch_size, 8))
data = np.random.random((batch_size, 3))
x = np.linspace(0.1, 0.5, batch_size, requires_grad=True)
weights = np.ones((batch_size, 10, 3, 3), requires_grad=True)

Expand Down
Loading