diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 6a70dd5b1b2..7c5fd3ecebe 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -69,6 +69,12 @@
Bug fixes 🐛
+* `MottonenStatePreparation` now raises an error if decomposing a broadcasted state vector.
+ [(#4767)](https://github.com/PennyLaneAI/pennylane/pull/4767)
+
+* `BasisStatePreparation` now raises an error if decomposing a broadcasted state vector.
+ [(#4767)](https://github.com/PennyLaneAI/pennylane/pull/4767)
+
* Gradient transforms now work with overridden shot vectors and default qubit.
[(#4795)](https://github.com/PennyLaneAI/pennylane/pull/4795)
@@ -99,6 +105,7 @@ Lillian Frederiksen,
Ankit Khandelwal,
Christina Lee,
Anurav Modak,
+Mudit Pandey,
Matthew Silverman,
David Wierichs,
Justin Woodring,
diff --git a/pennylane/templates/state_preparations/basis.py b/pennylane/templates/state_preparations/basis.py
index 550b5c10a6e..0b309e383cc 100644
--- a/pennylane/templates/state_preparations/basis.py
+++ b/pennylane/templates/state_preparations/basis.py
@@ -54,6 +54,7 @@ def circuit(basis_state):
num_params = 1
num_wires = AnyWires
grad_method = None
+ ndim_params = (1,)
def __init__(self, basis_state, wires, id=None):
basis_state = qml.math.stack(basis_state)
@@ -109,6 +110,13 @@ def compute_decomposition(basis_state, wires): # pylint: disable=arguments-diff
[PauliX(wires=['a']),
PauliX(wires=['b'])]
"""
+ if len(qml.math.shape(basis_state)) > 1:
+ raise ValueError(
+ "Broadcasting with BasisStatePreparation is not supported. Please use the "
+ "qml.transforms.broadcast_expand transform to use broadcasting with "
+ "BasisStatePreparation."
+ )
+
if not qml.math.is_abstract(basis_state):
op_list = []
for wire, state in zip(wires, basis_state):
diff --git a/pennylane/templates/state_preparations/mottonen.py b/pennylane/templates/state_preparations/mottonen.py
index 1818bd7bac5..e8440ff83bc 100644
--- a/pennylane/templates/state_preparations/mottonen.py
+++ b/pennylane/templates/state_preparations/mottonen.py
@@ -286,6 +286,7 @@ def circuit(state):
num_wires = AnyWires
grad_method = None
+ ndim_params = (1,)
def __init__(self, state_vector, wires, id=None):
# check if the `state_vector` param is batched
@@ -347,6 +348,12 @@ def compute_decomposition(state_vector, wires): # pylint: disable=arguments-dif
CNOT(wires=['a', 'b']),
CNOT(wires=['a', 'b'])]
"""
+ if len(qml.math.shape(state_vector)) > 1:
+ raise ValueError(
+ "Broadcasting with MottonenStatePreparation is not supported. Please use the "
+ "qml.transforms.broadcast_expand transform to use broadcasting with "
+ "MottonenStatePreparation."
+ )
a = qml.math.abs(state_vector)
omega = qml.math.angle(state_vector)
diff --git a/tests/templates/test_state_preparations/test_basis_state_prep.py b/tests/templates/test_state_preparations/test_basis_state_prep.py
index b63b01a33f7..24b58e906d8 100644
--- a/tests/templates/test_state_preparations/test_basis_state_prep.py
+++ b/tests/templates/test_state_preparations/test_basis_state_prep.py
@@ -153,6 +153,18 @@ def circuit2():
assert np.allclose(res1, res2, atol=tol, rtol=0)
assert np.allclose(state1, state2, atol=tol, rtol=0)
+ def test_batched_decomposition_fails(self):
+ """Test that attempting to decompose a BasisStatePreparation operation with
+ broadcasting raises an error."""
+ state = np.array([[1, 0], [1, 1]])
+
+ op = qml.BasisStatePreparation(state, wires=[0, 1])
+ with pytest.raises(ValueError, match="Broadcasting with BasisStatePreparation"):
+ _ = op.decomposition()
+
+ with pytest.raises(ValueError, match="Broadcasting with BasisStatePreparation"):
+ _ = qml.BasisStatePreparation.compute_decomposition(state, qml.wires.Wires([0, 1]))
+
class TestInputs:
"""Test inputs and pre-processing."""
diff --git a/tests/templates/test_state_preparations/test_mottonen_state_prep.py b/tests/templates/test_state_preparations/test_mottonen_state_prep.py
index f7ac7bb3844..e27977c106c 100644
--- a/tests/templates/test_state_preparations/test_mottonen_state_prep.py
+++ b/tests/templates/test_state_preparations/test_mottonen_state_prep.py
@@ -289,6 +289,18 @@ def circuit2():
assert np.allclose(res1, res2, atol=tol, rtol=0)
assert np.allclose(state1, state2, atol=tol, rtol=0)
+ def test_batched_decomposition_fails(self):
+ """Test that attempting to decompose a MottonenStatePreparation operation with
+ broadcasting raises an error."""
+ state = np.array([[1 / 2, 1 / 2, 1 / 2, 1 / 2], [0.0, 0.0, 0.0, 1.0]])
+
+ op = qml.MottonenStatePreparation(state, wires=[0, 1])
+ with pytest.raises(ValueError, match="Broadcasting with MottonenStatePreparation"):
+ _ = op.decomposition()
+
+ with pytest.raises(ValueError, match="Broadcasting with MottonenStatePreparation"):
+ _ = qml.MottonenStatePreparation.compute_decomposition(state, qml.wires.Wires([0, 1]))
+
class TestInputs:
"""Test inputs and pre-processing."""
diff --git a/tests/transforms/test_batch_input.py b/tests/transforms/test_batch_input.py
index d30a12cc06c..c642cf15ac9 100644
--- a/tests/transforms/test_batch_input.py
+++ b/tests/transforms/test_batch_input.py
@@ -201,6 +201,43 @@ def circuit2(data, weights):
assert np.allclose(res, indiv_res)
+def test_basis_state_preparation(mocker):
+ """Test that batching works for BasisStatePreparation"""
+ dev = qml.device("default.qubit", wires=3)
+
+ @partial(qml.batch_input, argnum=0)
+ @qml.qnode(dev, interface="autograd")
+ def circuit(data, weights):
+ qml.templates.BasisStatePreparation(data, wires=[0, 1, 2])
+ qml.templates.StronglyEntanglingLayers(weights, wires=[0, 1, 2])
+ return qml.probs(wires=[0, 1, 2])
+
+ batch_size = 3
+
+ # create a batched input statevector
+ data = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], requires_grad=False)
+
+ # weights is not batched
+ weights = np.random.random((10, 3, 3), requires_grad=True)
+
+ spy = mocker.spy(circuit.device, "execute")
+ res = circuit(data, weights)
+ assert res.shape == (batch_size, 2**3)
+ assert len(spy.call_args[0][0]) == batch_size
+
+ # check the results against individually executed circuits (no batching)
+ @qml.qnode(dev)
+ def circuit2(data, weights):
+ qml.templates.BasisStatePreparation(data, wires=[0, 1, 2])
+ qml.templates.StronglyEntanglingLayers(weights, wires=[0, 1, 2])
+ return qml.probs(wires=[0, 1, 2])
+
+ indiv_res = []
+ for state in data:
+ indiv_res.append(circuit2(state, weights))
+ assert np.allclose(res, indiv_res)
+
+
def test_qubit_state_prep(mocker):
"""Test that batching works for StatePrep"""
diff --git a/tests/transforms/test_batch_params.py b/tests/transforms/test_batch_params.py
index 36accf7bb25..14e21384678 100644
--- a/tests/transforms/test_batch_params.py
+++ b/tests/transforms/test_batch_params.py
@@ -279,14 +279,14 @@ def test_shot_vector():
@qml.batch_params
@qml.qnode(dev)
def circuit(data, x, weights):
- qml.templates.AmplitudeEmbedding(data, wires=[0, 1, 2], normalize=True)
+ qml.templates.AngleEmbedding(data, wires=[0, 1, 2])
qml.RX(x, wires=0)
qml.RY(0.2, wires=1)
qml.templates.StronglyEntanglingLayers(weights, wires=[0, 1, 2])
return qml.probs(wires=[0, 2])
batch_size = 6
- data = np.random.random((batch_size, 8))
+ data = np.random.random((batch_size, 3))
x = np.linspace(0.1, 0.5, batch_size, requires_grad=True)
weights = np.ones((batch_size, 10, 3, 3), requires_grad=True)
@@ -306,14 +306,14 @@ def test_multi_returns_shot_vector():
@qml.batch_params
@qml.qnode(dev)
def circuit(data, x, weights):
- qml.templates.AmplitudeEmbedding(data, wires=[0, 1, 2], normalize=True)
+ qml.templates.AngleEmbedding(data, wires=[0, 1, 2])
qml.RX(x, wires=0)
qml.RY(0.2, wires=1)
qml.templates.StronglyEntanglingLayers(weights, wires=[0, 1, 2])
return qml.expval(qml.PauliZ(0)), qml.probs(wires=[0, 2])
batch_size = 6
- data = np.random.random((batch_size, 8))
+ data = np.random.random((batch_size, 3))
x = np.linspace(0.1, 0.5, batch_size, requires_grad=True)
weights = np.ones((batch_size, 10, 3, 3), requires_grad=True)