diff --git a/pennylane_lightning/lightning_tensor/_tensornet.py b/pennylane_lightning/lightning_tensor/_tensornet.py index 4972a15344..78260363fc 100644 --- a/pennylane_lightning/lightning_tensor/_tensornet.py +++ b/pennylane_lightning/lightning_tensor/_tensornet.py @@ -225,27 +225,27 @@ def _preprocess_state_vector(self, state, device_wires): local_dev_wires = np.array(device_wires.tolist())[::-1] - # generate basis states on subset of qubits via broadcasting - base = np.tile([0,1], 2**(len(local_dev_wires)-1)).astype(dtype=np.int64) - indexes = np.zeros(2**(len(local_dev_wires)), dtype=np.int64) + # generate basis states on subset of qubits via broadcasting + base = np.tile([0, 1], 2 ** (len(local_dev_wires) - 1)).astype(dtype=np.int64) + indexes = np.zeros(2 ** (len(local_dev_wires)), dtype=np.int64) max_dev_wire = self._num_wires - 1 - + # get basis states to alter on full set of qubits for i, wire in enumerate(local_dev_wires): - + # get indices for which the state is changed to input state vector elements - indexes += base * 2**(max_dev_wire-wire) - - if i == len(local_dev_wires)-1: + indexes += base * 2 ** (max_dev_wire - wire) + + if i == len(local_dev_wires) - 1: continue - - two_n = 2**(i+1) - base = base.reshape(-1, two_n*2) - swaper_A = two_n//2 + + two_n = 2 ** (i + 1) + base = base.reshape(-1, two_n * 2) + swaper_A = two_n // 2 swaper_B = swaper_A + two_n - base[:,swaper_A:swaper_B] = base[:,swaper_A:swaper_B][:,::-1] + base[:, swaper_A:swaper_B] = base[:, swaper_A:swaper_B][:, ::-1] base = base.reshape(-1) # get full state vector to be factorized into MPS