From a7036a9eadf2e7c5e1b558afcc808100014eb28e Mon Sep 17 00:00:00 2001 From: KetpuntoG <65235481+KetpuntoG@users.noreply.github.com> Date: Tue, 17 Oct 2023 13:57:01 -0400 Subject: [PATCH] fixing tests --- pennylane/templates/state_preparations/cosine_window.py | 6 ++++-- .../templates/test_state_preparations/test_cosine_window.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pennylane/templates/state_preparations/cosine_window.py b/pennylane/templates/state_preparations/cosine_window.py index d017fda7911..97274bb5976 100644 --- a/pennylane/templates/state_preparations/cosine_window.py +++ b/pennylane/templates/state_preparations/cosine_window.py @@ -95,7 +95,7 @@ def compute_decomposition(wires): def state_vector(self, wire_order=None): num_op_wires = len(self.wires) - op_vector_shape = (-1,) + (2,) * num_op_wires + op_vector_shape = (-1,) + (2,) * num_op_wires if self.batch_size else (2,) * num_op_wires vector = np.array( [ np.sqrt(2) @@ -118,6 +118,9 @@ def state_vector(self, wire_order=None): [Ellipsis] + [slice(None)] * num_op_wires + [0] * (num_total_wires - num_op_wires) ) ket_shape = [2] * num_total_wires + if self.batch_size: + # Add broadcasted dimension to the shape of the state vector + ket_shape = [self.batch_size] + ket_shape ket = np.zeros(ket_shape, dtype=np.complex128) ket[indices] = op_vector @@ -130,7 +133,6 @@ def state_vector(self, wire_order=None): # If the operation is broadcasted, the desired order must include the batch dimension # as the first dimension. desired_order = [0] + [d + 1 for d in desired_order] - ket = ket.transpose(desired_order) return math.convert_like(ket, op_vector) diff --git a/tests/templates/test_state_preparations/test_cosine_window.py b/tests/templates/test_state_preparations/test_cosine_window.py index 37015529e49..f91321624c4 100644 --- a/tests/templates/test_state_preparations/test_cosine_window.py +++ b/tests/templates/test_state_preparations/test_cosine_window.py @@ -20,6 +20,7 @@ import pennylane as qml from pennylane.wires import WireError + class TestDecomposition: """Tests that the template defines the correct decomposition."""