Skip to content

Commit

Permalink
Add MCM support initial work.
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentmr committed Mar 19, 2024
1 parent 3b5f937 commit cf2a6dd
Show file tree
Hide file tree
Showing 5 changed files with 485 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) {
}
},
"Copy StateVector data into a Numpy array.")
.def("collapse", &StateVectorT::collapse,
"Collapse the statevector onto the 0 or 1 branch of a given wire.")
.def("normalize", &StateVectorT::normalize,
"Normalizes the statevector to norm 1.")
.def("applyControlledMatrix", &applyControlledMatrix<StateVectorT>,
"Apply controlled operation")
.def("kernel_map", &svKernelMap<StateVectorT>,
Expand Down
4 changes: 1 addition & 3 deletions pennylane_lightning/lightning_qubit/_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@

from pennylane_lightning.core._serialize import QuantumScriptSerializer

from ._state_vector import LightningStateVector


class LightningMeasurements:
"""Lightning Measurements class
Expand All @@ -72,7 +70,7 @@ class LightningMeasurements:

def __init__(
self,
qubit_state: LightningStateVector,
qubit_state,
mcmc: bool = None,
kernel_name: str = None,
num_burnin: int = None,
Expand Down
50 changes: 41 additions & 9 deletions pennylane_lightning/lightning_qubit/_state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@
import numpy as np
import pennylane as qml
from pennylane import BasisState, DeviceError, StatePrep
from pennylane.measurements import MidMeasureMP
from pennylane.ops import Conditional
from pennylane.ops.op_math import Adjoint
from pennylane.tape import QuantumScript
from pennylane.wires import Wires

from ._measurements import LightningMeasurements


class LightningStateVector:
"""Lightning state-vector class.
Expand Down Expand Up @@ -220,10 +224,10 @@ def _apply_lightning_controlled(self, operation):
"""Apply an arbitrary controlled operation to the state tensor.
Args:
operation (~pennylane.operation.Operation): operation to apply
operation (~pennylane.operation.Operation): controlled operation to apply
Returns:
array[complex]: the output state tensor
None
"""
state = self.state_vector

Expand All @@ -246,14 +250,36 @@ def _apply_lightning_controlled(self, operation):
False,
)

def _apply_lightning(self, operations):
def _apply_lightning_midmeasure(self, operation: MidMeasureMP, mid_measurements: dict):
"""Execute a MidMeasureMP operation and return the sample in mid_measurements.
Args:
operation (~pennylane.operation.Operation): mid-circuit measurement
mid_measurements (None, dict): Dictionary of mid-circuit measurements
Returns:
None
"""
wires = self.wires.indices(operation.wires)
wire = list(wires)[0]
circuit = QuantumScript([], [qml.sample(wires=operation.wires)], shots=1)
sample = LightningMeasurements(self).measure_final_state(circuit)
sample = np.squeeze(sample)
if operation.postselect is not None and sample != operation.postselect:
mid_measurements[operation] = -1
return
mid_measurements[operation] = sample
getattr(self.state_vector, "collapse")(wire, bool(sample))
if operation.reset and bool(sample):
self.apply_operations([qml.PauliX(operation.wires)], mid_measurements=mid_measurements)

def _apply_lightning(self, operations, mid_measurements: dict = None):
"""Apply a list of operations to the state tensor.
Args:
operations (list[~pennylane.operation.Operation]): operations to apply
mid_measurements (None, dict): Dictionary of mid-circuit measurements
Returns:
array[complex]: the output state tensor
None
"""
state = self.state_vector

Expand All @@ -271,7 +297,12 @@ def _apply_lightning(self, operations):
method = getattr(state, name, None)
wires = list(operation.wires)

if method is not None: # apply specialized gate
if isinstance(operation, Conditional):
if operation.meas_val.concretize(mid_measurements):
self._apply_lightning([operation.then_op])
elif isinstance(operation, MidMeasureMP):
self._apply_lightning_midmeasure(operation, mid_measurements)
elif method is not None: # apply specialized gate
param = operation.parameters
method(wires, invert_param, param)
elif isinstance(operation, qml.ops.Controlled): # apply n-controlled gate
Expand All @@ -286,7 +317,7 @@ def _apply_lightning(self, operations):
# To support older versions of PL
method(operation.matrix, wires, False)

def apply_operations(self, operations):
def apply_operations(self, operations, mid_measurements: dict = None):
"""Applies operations to the state vector."""
# State preparation is currently done in Python
if operations: # make sure operations[0] exists
Expand All @@ -297,21 +328,22 @@ def apply_operations(self, operations):
self._apply_basis_state(operations[0].parameters[0], operations[0].wires)
operations = operations[1:]

self._apply_lightning(operations)
self._apply_lightning(operations, mid_measurements=mid_measurements)

def get_final_state(self, circuit: QuantumScript):
def get_final_state(self, circuit: QuantumScript, mid_measurements: dict = None):
"""
Get the final state that results from executing the given quantum script.
This is an internal function that will be called by the successor to ``lightning.qubit``.
Args:
circuit (QuantumScript): The single circuit to simulate
mid_measurements (None, dict): Dictionary of mid-circuit measurements
Returns:
LightningStateVector: Lightning final state class.
"""
self.apply_operations(circuit.operations)
self.apply_operations(circuit.operations, mid_measurements=mid_measurements)

return self
16 changes: 15 additions & 1 deletion pennylane_lightning/lightning_qubit/lightning_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
from pennylane.devices.modifiers import simulator_tracking, single_tape_support
from pennylane.devices.preprocess import (
decompose,
mid_circuit_measurements,
no_sampling,
validate_adjoint_trainable_params,
validate_device_wires,
validate_measurements,
validate_observables,
)
from pennylane.measurements import MidMeasureMP
from pennylane.tape import QuantumScript, QuantumTape
from pennylane.transforms.core import TransformProgram
from pennylane.typing import Result, ResultBatch
Expand Down Expand Up @@ -71,6 +73,16 @@ def simulate(circuit: QuantumScript, state: LightningStateVector, mcmc: dict = N
if mcmc is None:
mcmc = {}
state.reset_state()
has_mcm = any(isinstance(op, MidMeasureMP) for op in circuit.operations)
if circuit.shots and has_mcm:
mid_measurements = {}
final_state = state.get_final_state(circuit, mid_measurements=mid_measurements)
if any(v == -1 for v in mid_measurements.values()):
return None, mid_measurements
return (
LightningMeasurements(final_state, **mcmc).measure_final_state(circuit),
mid_measurements,
)
final_state = state.get_final_state(circuit)
return LightningMeasurements(final_state, **mcmc).measure_final_state(circuit)

Expand Down Expand Up @@ -200,6 +212,8 @@ def simulate_and_jacobian(circuit: QuantumTape, state: LightningStateVector, bat
"QFT",
"ECR",
"BlockEncode",
"MidMeasureMP",
"Conditional",
}
)
# The set of supported operations.
Expand Down Expand Up @@ -432,7 +446,7 @@ def preprocess(self, execution_config: ExecutionConfig = DefaultExecutionConfig)
program.add_transform(validate_measurements, name=self.name)
program.add_transform(validate_observables, accepted_observables, name=self.name)
program.add_transform(validate_device_wires, self.wires, name=self.name)
program.add_transform(qml.defer_measurements, device=self)
program.add_transform(mid_circuit_measurements, device=self)
program.add_transform(decompose, stopping_condition=stopping_condition, name=self.name)
program.add_transform(qml.transforms.broadcast_expand)

Expand Down
Loading

0 comments on commit cf2a6dd

Please sign in to comment.