Skip to content

Commit

Permalink
More update
Browse files Browse the repository at this point in the history
  • Loading branch information
rmoyard committed Oct 17, 2023
1 parent b45f808 commit e6eb319
Show file tree
Hide file tree
Showing 10 changed files with 163 additions and 179 deletions.
2 changes: 1 addition & 1 deletion pennylane/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from pennylane.boolean_fn import BooleanFn
from pennylane.queuing import QueuingManager, apply

import pennylane.fourier
import pennylane.kernels
import pennylane.math
import pennylane.operation
Expand Down Expand Up @@ -123,6 +122,7 @@
from pennylane.shadows import ClassicalShadow
import pennylane.pulse

import pennylane.fourier
import pennylane.gradients # pylint:disable=wrong-import-order
import pennylane.qinfo # pylint:disable=wrong-import-order
from pennylane.interfaces import execute # pylint:disable=wrong-import-order
Expand Down
20 changes: 12 additions & 8 deletions pennylane/fourier/circuit_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@
"""Contains a transform that computes the simple frequency spectrum
of a quantum circuit, that is the frequencies without considering
preprocessing in the QNode."""
from functools import wraps
from typing import Sequence, Callable
from functools import partial
from .utils import get_spectrum, join_spectra
from pennylane.transforms.core import transform
from pennylane.tape import QuantumTape


def circuit_spectrum(qnode, encoding_gates=None, decimals=8):
@partial(transform, is_informative=True)
def circuit_spectrum(
tape: QuantumTape, encoding_gates=None, decimals=8
) -> (Sequence[QuantumTape], Callable):
r"""Compute the frequency spectrum of the Fourier representation of
simple quantum circuits ignoring classical preprocessing.
Expand All @@ -42,7 +48,7 @@ def circuit_spectrum(qnode, encoding_gates=None, decimals=8):
If no input-encoding gates are found, an empty dictionary is returned.
Args:
qnode (pennylane.QNode): a quantum node representing a circuit in which
tape (QuantumTape): a quantum node representing a circuit in which
input-encoding gates are marked by their ``id`` attribute
encoding_gates (list[str]): list of input-encoding gate ``id`` strings
for which to compute the frequency spectra
Expand Down Expand Up @@ -178,10 +184,8 @@ def circuit(x):
"""

@wraps(qnode)
def wrapper(*args, **kwargs):
qnode.construct(args, kwargs)
tape = qnode.qtape
def processing_fn(tapes):
tape = tapes[0]
freqs = {}
for op in tape.operations:
id = op.id
Expand Down Expand Up @@ -221,4 +225,4 @@ def wrapper(*args, **kwargs):

return freqs

return wrapper
return [tape], processing_fn
37 changes: 34 additions & 3 deletions pennylane/gradients/parameter_shift_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@
"""
import itertools as it
import warnings
from functools import partial
from typing import Sequence, Callable
from .hessian_transform import _process_jacs

import pennylane as qml
from pennylane import numpy as np
from pennylane.measurements import ProbabilityMP, StateMP, VarianceMP
from pennylane.transforms import transform

from .general_shift_rules import (
_combine_shift_rules,
generate_multishifted_tapes,
generate_shifted_tapes,
)
from .gradient_transform import gradient_analysis_and_validation
from .hessian_transform import hessian_transform
from .parameter_shift import _get_operation_recipe


Expand Down Expand Up @@ -363,8 +366,36 @@ def processing_fn(results):
return hessian_tapes, processing_fn


@hessian_transform
def param_shift_hessian(tape, argnum=None, diagonal_shifts=None, off_diagonal_shifts=None, f0=None):
# pylint: disable=too-many-return-statements,too-many-branches
def _contract_qjac_with_cjac(qhess, cjac, tape):
"""Contract a quantum Jacobian with a classical preprocessing Jacobian."""
if len(tape.measurements) > 1:
qhess = qhess[0]
has_single_arg = False
if not isinstance(cjac, tuple):
has_single_arg = True
cjac = (cjac,)

# The classical Jacobian for each argument has shape:
# (# gate_args, *qnode_arg_shape)
# The Jacobian needs to be contracted twice with the quantum Hessian of shape:
# (*qnode_output_shape, # gate_args, # gate_args)
# The result should then have the shape:
# (*qnode_output_shape, *qnode_arg_shape, *qnode_arg_shape)
hessians = []

for jac in cjac:
if jac is not None:
hess = _process_jacs(jac, qhess)
hessians.append(hess)

return hessians[0] if has_single_arg else tuple(hessians)


@partial(transform, classical_cotransform=_contract_qjac_with_cjac, final_transform=True)
def param_shift_hessian(
tape: qml.tape.QuantumTape, argnum=None, diagonal_shifts=None, off_diagonal_shifts=None, f0=None
) -> (Sequence[qml.tape.QuantumTape], Callable):
r"""Transform a QNode to compute the parameter-shift Hessian with respect to its trainable
parameters. This is the Hessian transform to replace the old one in the new return types system
Expand Down
58 changes: 25 additions & 33 deletions pennylane/ops/functions/map_wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
"""
This module contains the qml.map_wires function.
"""
from functools import wraps
from typing import Callable, Union
from functools import partial
from typing import Callable, Union, Sequence

import pennylane as qml
from pennylane.measurements import MeasurementProcess
from pennylane.operation import Operator
from pennylane.qnode import QNode
from pennylane.queuing import QueuingManager
from pennylane.tape import QuantumScript, make_qscript, QuantumTape
from pennylane.tape import QuantumScript, QuantumTape
from pennylane.transforms.core import transform


def map_wires(
Expand Down Expand Up @@ -97,35 +98,26 @@ def map_wires(
qml.apply(new_op)
return new_op
return input.map_wires(wire_map=wire_map)

if isinstance(input, QuantumScript):
ops = [qml.map_wires(op, wire_map) for op in input.operations]
measurements = [qml.map_wires(m, wire_map) for m in input.measurements]

out = input.__class__(ops=ops, measurements=measurements, shots=input.shots)
out.trainable_params = input.trainable_params
return out

if callable(input):
func = input.func if isinstance(input, QNode) else input

@wraps(func)
def qfunc(*args, **kwargs):
qscript = make_qscript(func)(*args, **kwargs)
_ = [qml.map_wires(op, wire_map=wire_map, queue=True) for op in qscript.operations]
m = tuple(qml.map_wires(m, wire_map=wire_map, queue=True) for m in qscript.measurements)
return m[0] if len(m) == 1 else m

if isinstance(input, QNode):
return QNode(
func=qfunc,
device=input.device,
interface=input.interface,
diff_method=input.diff_method,
expansion_strategy=input.expansion_strategy,
**input.execute_kwargs,
**input.gradient_kwargs,
)
return qfunc
elif isinstance(input, (QuantumScript, QNode)) or callable(input):
return _map_wires_transform(input, wire_map=wire_map)

raise ValueError(f"Cannot map wires of object {input} of type {type(input)}.")


@partial(transform)
def _map_wires_transform(
tape: qml.tape.QuantumTape, wire_map=None
) -> (Sequence[qml.tape.QuantumTape], Callable):
ops = [map_wires(op, wire_map) for op in tape.operations]
measurements = [map_wires(m, wire_map) for m in tape.measurements]

out = tape.__class__(ops=ops, measurements=measurements, shots=tape.shots)
out.trainable_params = tape.trainable_params
print("inc", out.circuit)
print(wire_map)

def processing_fn(res):
"""Defines how matrix works if applied to a tape containing multiple operations."""
return res[0]

return [out], processing_fn
99 changes: 48 additions & 51 deletions pennylane/shadows/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Classical shadow transforms"""

import warnings
from functools import reduce, wraps
from functools import reduce, wraps, partial
from itertools import product
from typing import Sequence, Callable

Expand Down Expand Up @@ -52,7 +52,8 @@ def processing_fn(res):
return [qscript], processing_fn


def shadow_expval(H, k=1):
@partial(transform, final_transform=True)
def shadow_expval(tape: QuantumTape, H, k=1) -> (Sequence[QuantumTape], Callable):
"""Transform a QNode returning a classical shadow into one that returns
the approximate expectation values in a differentiable manner.
Expand Down Expand Up @@ -88,14 +89,15 @@ def circuit(x):
>>> qml.grad(circuit)(x)
-0.9323999999999998
"""
tapes, _ = _replace_obs(tape, qml.shadow_expval, H, k=k)

def decorator(qnode):
return _replace_obs(qnode, qml.shadow_expval, H, k=k)
def post_processing_fn(res):
return res

return decorator
return tapes, post_processing_fn


def _shadow_state_diffable(wires):
def _shadow_state_diffable(tape, wires):
"""Differentiable version of the shadow state transform"""
wires_list = wires if isinstance(wires[0], list) else [wires]

Expand All @@ -117,63 +119,55 @@ def _shadow_state_diffable(wires):
observables.append(reduce(lambda a, b: a @ b, [ob(wire) for ob, wire in zip(obs, w)]))
all_observables.extend(observables)

def decorator(qnode):
new_qnode = _replace_obs(qnode, qml.shadow_expval, all_observables)

@wraps(qnode)
def wrapper(*args, **kwargs):
# pylint: disable=not-callable
results = new_qnode(*args, **kwargs)

# cast to complex
results = qml.math.cast(results, np.complex64)

states = []
start = 0
for w in wires_list:
# reconstruct the state given the observables and the expectations of
# those observables

obs_matrices = qml.math.stack(
[
qml.math.cast_like(qml.math.convert_like(qml.matrix(obs), results), results)
for obs in all_observables[start : start + 4 ** len(w)]
]
)

s = qml.math.einsum(
"a,abc->bc", results[start : start + 4 ** len(w)], obs_matrices
) / 2 ** len(w)
states.append(s)
tapes, _ = _replace_obs(tape, qml.shadow_expval, all_observables)

def post_processing_fn(results):
"""Post process the classical shadows."""
results = results[0]
# cast to complex
results = qml.math.cast(results, np.complex64)

states = []
start = 0
for w in wires_list:
# reconstruct the state given the observables and the expectations of
# those observables

obs_matrices = qml.math.stack(
[
qml.math.cast_like(qml.math.convert_like(qml.matrix(obs), results), results)
for obs in all_observables[start : start + 4 ** len(w)]
]
)

start += 4 ** len(w)
s = qml.math.einsum(
"a,abc->bc", results[start : start + 4 ** len(w)], obs_matrices
) / 2 ** len(w)
states.append(s)

return states if isinstance(wires[0], list) else states[0]
start += 4 ** len(w)

return wrapper
return states if isinstance(wires[0], list) else states[0]

return decorator
return tapes, post_processing_fn


def _shadow_state_undiffable(wires):
def _shadow_state_undiffable(tape, wires):
"""Non-differentiable version of the shadow state transform"""
wires_list = wires if isinstance(wires[0], list) else [wires]

def decorator(qnode):
@wraps(qnode)
def wrapper(*args, **kwargs):
bits, recipes = qnode(*args, **kwargs)
shadow = qml.shadows.ClassicalShadow(bits, recipes)

states = [qml.math.mean(shadow.global_snapshots(wires=w), 0) for w in wires_list]
return states if isinstance(wires[0], list) else states[0]
def post_processing(results):
bits, recipes = results[0]
shadow = qml.shadows.ClassicalShadow(bits, recipes)

return wrapper
states = [qml.math.mean(shadow.global_snapshots(wires=w), 0) for w in wires_list]
return states if isinstance(wires[0], list) else states[0]

return decorator
return [tape], post_processing


def shadow_state(wires, diffable=False):
@partial(transform, final_transform=True)
def shadow_state(tape: QuantumTape, wires, diffable=False) -> (Sequence[QuantumTape], Callable):
"""Transform a QNode returning a classical shadow into one that returns
the reconstructed state in a differentiable manner.
Expand Down Expand Up @@ -221,4 +215,7 @@ def circuit(x):
[ 0.004275, 0.2358 , 0.244875, -0.002175],
[-0.2358 , -0.004275, -0.002175, -0.235125]])
"""
return _shadow_state_diffable(wires) if diffable else _shadow_state_undiffable(wires)
tapes, fn = (
_shadow_state_diffable(tape, wires) if diffable else _shadow_state_undiffable(tape, wires)
)
return tapes, fn
Loading

0 comments on commit e6eb319

Please sign in to comment.