diff --git a/pennylane_lightning/lightning_qubit/_measurements.py b/pennylane_lightning/lightning_qubit/_measurements.py index 264a4b2bda..52d3c154b4 100644 --- a/pennylane_lightning/lightning_qubit/_measurements.py +++ b/pennylane_lightning/lightning_qubit/_measurements.py @@ -248,7 +248,7 @@ def measurement(self, measurementprocess: MeasurementProcess) -> TensorLike: """ return self.get_measurement_function(measurementprocess)(measurementprocess) - def measure_final_state(self, circuit: QuantumScript) -> Result: + def measure_final_state(self, circuit: QuantumScript, mid_measurements=None) -> Result: """ Perform the measurements required by the circuit on the provided state. @@ -256,6 +256,7 @@ def measure_final_state(self, circuit: QuantumScript) -> Result: Args: circuit (QuantumScript): The single circuit to simulate + mid_measurements (None, dict): Dictionary of mid-circuit measurements Returns: Tuple[TensorLike]: The measurement results @@ -272,6 +273,7 @@ def measure_final_state(self, circuit: QuantumScript) -> Result: results = self.measure_with_samples( circuit.measurements, shots=circuit.shots, + mid_measurements=mid_measurements, ) if len(circuit.measurements) == 1: @@ -285,8 +287,9 @@ def measure_final_state(self, circuit: QuantumScript) -> Result: # pylint:disable = too-many-arguments def measure_with_samples( self, - mps: List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]], + measurements: List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]], shots: Shots, + mid_measurements=None, ) -> List[TensorLike]: """ Returns the samples of the measurement process performed on the given state. @@ -294,18 +297,27 @@ def measure_with_samples( have already been mapped to integer wires used in the device. Args: - mps (List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]]): + measurements (List[Union[SampleMeasurement, ClassicalShadowMP, ShadowExpvalMP]]): The sample measurements to perform shots (Shots): The number of samples to take + mid_measurements (None, dict): Dictionary of mid-circuit measurements Returns: List[TensorLike[Any]]: Sample measurement results """ + # last N measurements are sampling MCMs in ``dynamic_one_shot`` execution mode + mps = measurements[0 : -len(mid_measurements)] if mid_measurements else measurements + skip_measure = ( + any(v == -1 for v in mid_measurements.values()) if mid_measurements else False + ) groups, indices = _group_measurements(mps) all_res = [] for group in groups: + if skip_measure: + all_res.extend([None] * len(group)) + continue if isinstance(group[0], (ExpectationMP, VarianceMP)) and isinstance( group[0].obs, SparseHamiltonian ): @@ -333,6 +345,10 @@ def measure_with_samples( res for _, res in sorted(list(enumerate(all_res)), key=lambda r: flat_indices[r[0]]) ) + # append MCM samples + if mid_measurements: + sorted_res += tuple(mid_measurements.values()) + # put the shot vector axis before the measurement axis if shots.has_partitioned_shots: sorted_res = tuple(zip(*sorted_res)) diff --git a/pennylane_lightning/lightning_qubit/lightning_qubit.py b/pennylane_lightning/lightning_qubit/lightning_qubit.py index bfefacb4bd..8350804c18 100644 --- a/pennylane_lightning/lightning_qubit/lightning_qubit.py +++ b/pennylane_lightning/lightning_qubit/lightning_qubit.py @@ -80,11 +80,8 @@ def simulate(circuit: QuantumScript, state: LightningStateVector, mcmc: dict = N if circuit.shots and has_mcm: mid_measurements = {} final_state = state.get_final_state(circuit, mid_measurements=mid_measurements) - if any(v == -1 for v in mid_measurements.values()): - return None, mid_measurements - return ( - LightningMeasurements(final_state, **mcmc).measure_final_state(circuit), - mid_measurements, + return LightningMeasurements(final_state, **mcmc).measure_final_state( + circuit, mid_measurements=mid_measurements ) final_state = state.get_final_state(circuit) return LightningMeasurements(final_state, **mcmc).measure_final_state(circuit)