From e6eb319fd3b3833c1db1d28ed43e9b6cdded5067 Mon Sep 17 00:00:00 2001 From: rmoyard Date: Tue, 17 Oct 2023 00:06:21 -0400 Subject: [PATCH] More update --- pennylane/__init__.py | 2 +- pennylane/fourier/circuit_spectrum.py | 20 ++-- .../gradients/parameter_shift_hessian.py | 37 ++++++- pennylane/ops/functions/map_wires.py | 58 +++++------ pennylane/shadows/transforms.py | 99 +++++++++---------- pennylane/transforms/commutation_dag.py | 42 +++----- .../test_parameter_shift_hessian.py | 33 +++---- tests/ops/functions/test_map_wires.py | 6 +- tests/shadow/test_shadow_transforms.py | 28 +++--- tests/transforms/test_commutation_dag.py | 17 +--- 10 files changed, 163 insertions(+), 179 deletions(-) diff --git a/pennylane/__init__.py b/pennylane/__init__.py index 4b33c0cb011..252542704eb 100644 --- a/pennylane/__init__.py +++ b/pennylane/__init__.py @@ -27,7 +27,6 @@ from pennylane.boolean_fn import BooleanFn from pennylane.queuing import QueuingManager, apply -import pennylane.fourier import pennylane.kernels import pennylane.math import pennylane.operation @@ -123,6 +122,7 @@ from pennylane.shadows import ClassicalShadow import pennylane.pulse +import pennylane.fourier import pennylane.gradients # pylint:disable=wrong-import-order import pennylane.qinfo # pylint:disable=wrong-import-order from pennylane.interfaces import execute # pylint:disable=wrong-import-order diff --git a/pennylane/fourier/circuit_spectrum.py b/pennylane/fourier/circuit_spectrum.py index 79f8c5e1a70..ba440ce9f4f 100644 --- a/pennylane/fourier/circuit_spectrum.py +++ b/pennylane/fourier/circuit_spectrum.py @@ -14,11 +14,17 @@ """Contains a transform that computes the simple frequency spectrum of a quantum circuit, that is the frequencies without considering preprocessing in the QNode.""" -from functools import wraps +from typing import Sequence, Callable +from functools import partial from .utils import get_spectrum, join_spectra +from pennylane.transforms.core import transform +from pennylane.tape import QuantumTape -def circuit_spectrum(qnode, encoding_gates=None, decimals=8): +@partial(transform, is_informative=True) +def circuit_spectrum( + tape: QuantumTape, encoding_gates=None, decimals=8 +) -> (Sequence[QuantumTape], Callable): r"""Compute the frequency spectrum of the Fourier representation of simple quantum circuits ignoring classical preprocessing. @@ -42,7 +48,7 @@ def circuit_spectrum(qnode, encoding_gates=None, decimals=8): If no input-encoding gates are found, an empty dictionary is returned. Args: - qnode (pennylane.QNode): a quantum node representing a circuit in which + tape (QuantumTape): a quantum node representing a circuit in which input-encoding gates are marked by their ``id`` attribute encoding_gates (list[str]): list of input-encoding gate ``id`` strings for which to compute the frequency spectra @@ -178,10 +184,8 @@ def circuit(x): """ - @wraps(qnode) - def wrapper(*args, **kwargs): - qnode.construct(args, kwargs) - tape = qnode.qtape + def processing_fn(tapes): + tape = tapes[0] freqs = {} for op in tape.operations: id = op.id @@ -221,4 +225,4 @@ def wrapper(*args, **kwargs): return freqs - return wrapper + return [tape], processing_fn diff --git a/pennylane/gradients/parameter_shift_hessian.py b/pennylane/gradients/parameter_shift_hessian.py index 4fa48b9ff74..d805f8130a5 100644 --- a/pennylane/gradients/parameter_shift_hessian.py +++ b/pennylane/gradients/parameter_shift_hessian.py @@ -17,10 +17,14 @@ """ import itertools as it import warnings +from functools import partial +from typing import Sequence, Callable +from .hessian_transform import _process_jacs import pennylane as qml from pennylane import numpy as np from pennylane.measurements import ProbabilityMP, StateMP, VarianceMP +from pennylane.transforms import transform from .general_shift_rules import ( _combine_shift_rules, @@ -28,7 +32,6 @@ generate_shifted_tapes, ) from .gradient_transform import gradient_analysis_and_validation -from .hessian_transform import hessian_transform from .parameter_shift import _get_operation_recipe @@ -363,8 +366,36 @@ def processing_fn(results): return hessian_tapes, processing_fn -@hessian_transform -def param_shift_hessian(tape, argnum=None, diagonal_shifts=None, off_diagonal_shifts=None, f0=None): +# pylint: disable=too-many-return-statements,too-many-branches +def _contract_qjac_with_cjac(qhess, cjac, tape): + """Contract a quantum Jacobian with a classical preprocessing Jacobian.""" + if len(tape.measurements) > 1: + qhess = qhess[0] + has_single_arg = False + if not isinstance(cjac, tuple): + has_single_arg = True + cjac = (cjac,) + + # The classical Jacobian for each argument has shape: + # (# gate_args, *qnode_arg_shape) + # The Jacobian needs to be contracted twice with the quantum Hessian of shape: + # (*qnode_output_shape, # gate_args, # gate_args) + # The result should then have the shape: + # (*qnode_output_shape, *qnode_arg_shape, *qnode_arg_shape) + hessians = [] + + for jac in cjac: + if jac is not None: + hess = _process_jacs(jac, qhess) + hessians.append(hess) + + return hessians[0] if has_single_arg else tuple(hessians) + + +@partial(transform, classical_cotransform=_contract_qjac_with_cjac, final_transform=True) +def param_shift_hessian( + tape: qml.tape.QuantumTape, argnum=None, diagonal_shifts=None, off_diagonal_shifts=None, f0=None +) -> (Sequence[qml.tape.QuantumTape], Callable): r"""Transform a QNode to compute the parameter-shift Hessian with respect to its trainable parameters. This is the Hessian transform to replace the old one in the new return types system diff --git a/pennylane/ops/functions/map_wires.py b/pennylane/ops/functions/map_wires.py index dd431e599a9..365e566f530 100644 --- a/pennylane/ops/functions/map_wires.py +++ b/pennylane/ops/functions/map_wires.py @@ -14,15 +14,16 @@ """ This module contains the qml.map_wires function. """ -from functools import wraps -from typing import Callable, Union +from functools import partial +from typing import Callable, Union, Sequence import pennylane as qml from pennylane.measurements import MeasurementProcess from pennylane.operation import Operator from pennylane.qnode import QNode from pennylane.queuing import QueuingManager -from pennylane.tape import QuantumScript, make_qscript, QuantumTape +from pennylane.tape import QuantumScript, QuantumTape +from pennylane.transforms.core import transform def map_wires( @@ -97,35 +98,26 @@ def map_wires( qml.apply(new_op) return new_op return input.map_wires(wire_map=wire_map) - - if isinstance(input, QuantumScript): - ops = [qml.map_wires(op, wire_map) for op in input.operations] - measurements = [qml.map_wires(m, wire_map) for m in input.measurements] - - out = input.__class__(ops=ops, measurements=measurements, shots=input.shots) - out.trainable_params = input.trainable_params - return out - - if callable(input): - func = input.func if isinstance(input, QNode) else input - - @wraps(func) - def qfunc(*args, **kwargs): - qscript = make_qscript(func)(*args, **kwargs) - _ = [qml.map_wires(op, wire_map=wire_map, queue=True) for op in qscript.operations] - m = tuple(qml.map_wires(m, wire_map=wire_map, queue=True) for m in qscript.measurements) - return m[0] if len(m) == 1 else m - - if isinstance(input, QNode): - return QNode( - func=qfunc, - device=input.device, - interface=input.interface, - diff_method=input.diff_method, - expansion_strategy=input.expansion_strategy, - **input.execute_kwargs, - **input.gradient_kwargs, - ) - return qfunc + elif isinstance(input, (QuantumScript, QNode)) or callable(input): + return _map_wires_transform(input, wire_map=wire_map) raise ValueError(f"Cannot map wires of object {input} of type {type(input)}.") + + +@partial(transform) +def _map_wires_transform( + tape: qml.tape.QuantumTape, wire_map=None +) -> (Sequence[qml.tape.QuantumTape], Callable): + ops = [map_wires(op, wire_map) for op in tape.operations] + measurements = [map_wires(m, wire_map) for m in tape.measurements] + + out = tape.__class__(ops=ops, measurements=measurements, shots=tape.shots) + out.trainable_params = tape.trainable_params + print("inc", out.circuit) + print(wire_map) + + def processing_fn(res): + """Defines how matrix works if applied to a tape containing multiple operations.""" + return res[0] + + return [out], processing_fn diff --git a/pennylane/shadows/transforms.py b/pennylane/shadows/transforms.py index 43d6d535eb9..7e4c00d8845 100644 --- a/pennylane/shadows/transforms.py +++ b/pennylane/shadows/transforms.py @@ -14,7 +14,7 @@ """Classical shadow transforms""" import warnings -from functools import reduce, wraps +from functools import reduce, wraps, partial from itertools import product from typing import Sequence, Callable @@ -52,7 +52,8 @@ def processing_fn(res): return [qscript], processing_fn -def shadow_expval(H, k=1): +@partial(transform, final_transform=True) +def shadow_expval(tape: QuantumTape, H, k=1) -> (Sequence[QuantumTape], Callable): """Transform a QNode returning a classical shadow into one that returns the approximate expectation values in a differentiable manner. @@ -88,14 +89,15 @@ def circuit(x): >>> qml.grad(circuit)(x) -0.9323999999999998 """ + tapes, _ = _replace_obs(tape, qml.shadow_expval, H, k=k) - def decorator(qnode): - return _replace_obs(qnode, qml.shadow_expval, H, k=k) + def post_processing_fn(res): + return res - return decorator + return tapes, post_processing_fn -def _shadow_state_diffable(wires): +def _shadow_state_diffable(tape, wires): """Differentiable version of the shadow state transform""" wires_list = wires if isinstance(wires[0], list) else [wires] @@ -117,63 +119,55 @@ def _shadow_state_diffable(wires): observables.append(reduce(lambda a, b: a @ b, [ob(wire) for ob, wire in zip(obs, w)])) all_observables.extend(observables) - def decorator(qnode): - new_qnode = _replace_obs(qnode, qml.shadow_expval, all_observables) - - @wraps(qnode) - def wrapper(*args, **kwargs): - # pylint: disable=not-callable - results = new_qnode(*args, **kwargs) - - # cast to complex - results = qml.math.cast(results, np.complex64) - - states = [] - start = 0 - for w in wires_list: - # reconstruct the state given the observables and the expectations of - # those observables - - obs_matrices = qml.math.stack( - [ - qml.math.cast_like(qml.math.convert_like(qml.matrix(obs), results), results) - for obs in all_observables[start : start + 4 ** len(w)] - ] - ) - - s = qml.math.einsum( - "a,abc->bc", results[start : start + 4 ** len(w)], obs_matrices - ) / 2 ** len(w) - states.append(s) + tapes, _ = _replace_obs(tape, qml.shadow_expval, all_observables) + + def post_processing_fn(results): + """Post process the classical shadows.""" + results = results[0] + # cast to complex + results = qml.math.cast(results, np.complex64) + + states = [] + start = 0 + for w in wires_list: + # reconstruct the state given the observables and the expectations of + # those observables + + obs_matrices = qml.math.stack( + [ + qml.math.cast_like(qml.math.convert_like(qml.matrix(obs), results), results) + for obs in all_observables[start : start + 4 ** len(w)] + ] + ) - start += 4 ** len(w) + s = qml.math.einsum( + "a,abc->bc", results[start : start + 4 ** len(w)], obs_matrices + ) / 2 ** len(w) + states.append(s) - return states if isinstance(wires[0], list) else states[0] + start += 4 ** len(w) - return wrapper + return states if isinstance(wires[0], list) else states[0] - return decorator + return tapes, post_processing_fn -def _shadow_state_undiffable(wires): +def _shadow_state_undiffable(tape, wires): """Non-differentiable version of the shadow state transform""" wires_list = wires if isinstance(wires[0], list) else [wires] - def decorator(qnode): - @wraps(qnode) - def wrapper(*args, **kwargs): - bits, recipes = qnode(*args, **kwargs) - shadow = qml.shadows.ClassicalShadow(bits, recipes) - - states = [qml.math.mean(shadow.global_snapshots(wires=w), 0) for w in wires_list] - return states if isinstance(wires[0], list) else states[0] + def post_processing(results): + bits, recipes = results[0] + shadow = qml.shadows.ClassicalShadow(bits, recipes) - return wrapper + states = [qml.math.mean(shadow.global_snapshots(wires=w), 0) for w in wires_list] + return states if isinstance(wires[0], list) else states[0] - return decorator + return [tape], post_processing -def shadow_state(wires, diffable=False): +@partial(transform, final_transform=True) +def shadow_state(tape: QuantumTape, wires, diffable=False) -> (Sequence[QuantumTape], Callable): """Transform a QNode returning a classical shadow into one that returns the reconstructed state in a differentiable manner. @@ -221,4 +215,7 @@ def circuit(x): [ 0.004275, 0.2358 , 0.244875, -0.002175], [-0.2358 , -0.004275, -0.002175, -0.235125]]) """ - return _shadow_state_diffable(wires) if diffable else _shadow_state_undiffable(wires) + tapes, fn = ( + _shadow_state_diffable(tape, wires) if diffable else _shadow_state_undiffable(tape, wires) + ) + return tapes, fn diff --git a/pennylane/transforms/commutation_dag.py b/pennylane/transforms/commutation_dag.py index 014601ebec7..383df0497b8 100644 --- a/pennylane/transforms/commutation_dag.py +++ b/pennylane/transforms/commutation_dag.py @@ -16,17 +16,20 @@ """ import heapq from collections import OrderedDict -from functools import wraps +from functools import partial +from typing import Sequence, Callable import networkx as nx from networkx.drawing.nx_pydot import to_pydot import pennylane as qml -from pennylane.tape import QuantumScript, make_qscript, QuantumTape +from pennylane.tape import QuantumTape from pennylane.wires import Wires +from pennylane.transforms.core import transform -def commutation_dag(circuit): +@partial(transform, is_informative=True) +def commutation_dag(tape: QuantumTape) -> (Sequence[QuantumTape], Callable): r"""Construct the pairwise-commutation DAG (directed acyclic graph) representation of a quantum circuit. In the DAG, each node represents a quantum operation, and edges represent @@ -36,8 +39,7 @@ def commutation_dag(circuit): operations can be moved next to each other by pairwise commutation. Args: - circuit (pennylane.QNode, .QuantumTape, or Callable): A quantum node, tape, - or function that applies quantum operations. + tape ( .QuantumTape): The quantum circuit. Returns: function: Function which accepts the same arguments as the :class:`qml.QNode`, :class:`qml.tape.QuantumTape` @@ -91,35 +93,13 @@ def circuit(x, y, z): """ - # pylint: disable=protected-access - - @wraps(circuit) - def wrapper(*args, **kwargs): - if isinstance(circuit, qml.QNode): - # user passed a QNode, get the tape - circuit.construct(args, kwargs) - tape = circuit.qtape - - elif isinstance(circuit, QuantumScript): - # user passed a tape - tape = circuit - - elif callable(circuit): - # user passed something that is callable but not a tape or qnode. - tape = make_qscript(circuit)(*args, **kwargs) - # raise exception if it is not a quantum function - if len(tape.operations) == 0: - raise ValueError("Function contains no quantum operation") - - else: - raise ValueError("Input is not a tape, QNode, or quantum function") - + def processing_fn(res): + """Processing function that returns the circuit as a commutation DAG.""" # Initialize DAG - dag = CommutationDAG(tape) - + dag = CommutationDAG(res[0]) return dag - return wrapper + return [tape], processing_fn def _merge_no_duplicates(*iterables): diff --git a/tests/gradients/parameter_shift/test_parameter_shift_hessian.py b/tests/gradients/parameter_shift/test_parameter_shift_hessian.py index d54dd921fb5..8dce4d5de56 100644 --- a/tests/gradients/parameter_shift/test_parameter_shift_hessian.py +++ b/tests/gradients/parameter_shift/test_parameter_shift_hessian.py @@ -987,9 +987,9 @@ def circuit(x): return qml.probs(wires=[0, 1]) x = np.array([0.1, 0.2, 0.3], requires_grad=True) - shape = (6, 6, 4) # (num_gate_args, num_gate_args, num_output_vals) + shape = (3, 3, 4) # (num_args, num_args, num_output_vals) - hessian = qml.gradients.param_shift_hessian(circuit, hybrid=False)(x) + hessian = qml.gradients.param_shift_hessian(circuit)(x) assert qml.math.shape(hessian) == shape @@ -1436,10 +1436,8 @@ def circuit(weights): return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)) weights = [0.1, 0.2] - with pytest.warns(UserWarning, match="Hessian of a QNode with no trainable parameters"): - res = qml.gradients.param_shift_hessian(circuit)(weights) - - assert res == () + with pytest.raises(qml.QuantumFunctionError, match="No trainable parameters."): + qml.gradients.param_shift_hessian(circuit)(weights) @pytest.mark.torch def test_no_trainable_params_qnode_torch(self): @@ -1455,10 +1453,8 @@ def circuit(weights): return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)) weights = [0.1, 0.2] - with pytest.warns(UserWarning, match="Hessian of a QNode with no trainable parameters"): - res = qml.gradients.param_shift_hessian(circuit)(weights) - - assert res == () + with pytest.raises(qml.QuantumFunctionError, match="No trainable parameters."): + qml.gradients.param_shift_hessian(circuit)(weights) @pytest.mark.tf def test_no_trainable_params_qnode_tf(self): @@ -1474,10 +1470,8 @@ def circuit(weights): return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)) weights = [0.1, 0.2] - with pytest.warns(UserWarning, match="Hessian of a QNode with no trainable parameters"): - res = qml.gradients.param_shift_hessian(circuit)(weights) - - assert res == () + with pytest.raises(qml.QuantumFunctionError, match="No trainable parameters."): + qml.gradients.param_shift_hessian(circuit)(weights) @pytest.mark.jax def test_no_trainable_params_qnode_jax(self): @@ -1493,10 +1487,8 @@ def circuit(weights): return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)) weights = [0.1, 0.2] - with pytest.warns(UserWarning, match="Hessian of a QNode with no trainable parameters"): - res = qml.gradients.param_shift_hessian(circuit)(weights) - - assert res == () + with pytest.raises(qml.QuantumFunctionError, match="No trainable parameters."): + qml.gradients.param_shift_hessian(circuit)(weights) def test_all_zero_diff_methods(self): """Test that the transform works correctly when the diff method for every parameter is @@ -1509,11 +1501,12 @@ def circuit(params): return qml.probs([2, 3]) params = np.array([0.5, 0.5, 0.5], requires_grad=True) + circuit(params) result = qml.gradients.param_shift_hessian(circuit)(params) assert np.allclose(result, np.zeros((3, 3, 4)), atol=0, rtol=0) - tapes, _ = qml.gradients.param_shift_hessian(circuit.tape) + tapes, _ = qml.gradients.param_shift_hessian(circuit.qtape) assert tapes == [] @pytest.mark.xfail(reason="Update tracker for new return types") @@ -1573,7 +1566,7 @@ def cost6(x): circuits = [qml.QNode(cost, dev) for cost in (cost1, cost2, cost3, cost4, cost5, cost6)] transform = [qml.math.shape(qml.gradients.param_shift_hessian(c)(x)) for c in circuits] - expected = [(3, 3), (3, 3), (2, 3, 3), (3, 3, 4), (3, 3, 4), (2, 3, 3, 4)] + expected = [(3, 3), (1, 3, 3), (2, 3, 3), (3, 3, 4), (1, 3, 3, 4), (2, 3, 3, 4)] assert all(t == e for t, e in zip(transform, expected)) diff --git a/tests/ops/functions/test_map_wires.py b/tests/ops/functions/test_map_wires.py index e32e689e945..08e361de414 100644 --- a/tests/ops/functions/test_map_wires.py +++ b/tests/ops/functions/test_map_wires.py @@ -108,7 +108,8 @@ def test_map_wires_tape(self, shots): tape.trainable_params = [0, 2] # TODO: Use qml.equal when supported - s_tape = qml.map_wires(tape, wire_map=wire_map) + s_tapes, _ = qml.map_wires(tape, wire_map=wire_map) + s_tape = s_tapes[0] assert len(s_tape) == 1 assert s_tape.trainable_params == [0, 2] s_op = s_tape[0] @@ -129,7 +130,8 @@ def test_execute_mapped_tape(self, shots): tape = QuantumScript.from_queue(q_tape, shots=shots) # TODO: Use qml.equal when supported - m_tape = qml.map_wires(tape, wire_map=wire_map) + m_tapes, _ = qml.map_wires(tape, wire_map=wire_map) + m_tape = m_tapes[0] m_op = m_tape.operations[0] m_obs = m_tape.observables[0] assert qml.equal(m_op, mapped_op) diff --git a/tests/shadow/test_shadow_transforms.py b/tests/shadow/test_shadow_transforms.py index ef91d5956cf..52fa3c950dc 100644 --- a/tests/shadow/test_shadow_transforms.py +++ b/tests/shadow/test_shadow_transforms.py @@ -131,7 +131,7 @@ def test_hadamard_state(self, wires, diffable): """Test that the state reconstruction is correct for a uniform superposition of qubits""" circuit = hadamard_circuit(wires) - circuit = qml.shadows.shadow_state(wires=range(wires), diffable=diffable)(circuit) + circuit = qml.shadows.shadow_state(circuit, wires=range(wires), diffable=diffable) actual = circuit() expected = np.ones((2**wires, 2**wires)) / (2**wires) @@ -174,17 +174,12 @@ def test_large_state_warning(self, monkeypatch): """Test that a warning is raised when the system to get the state of is large""" circuit = hadamard_circuit(8, shots=1) + circuit.construct([], {}) - with monkeypatch.context() as m: - # monkeypatch the range function so we don't run the state reconstruction - m.setattr(builtins, "range", lambda *args: [0]) - - msg = "Differentiable state reconstruction for more than 8 qubits is not recommended" - with pytest.warns(UserWarning, match=msg): - # full hard-coded list for wires instead of range(8) since we monkeypatched it - circuit = qml.shadows.shadow_state(wires=[0, 1, 2, 3, 4, 5, 6, 7], diffable=True)( - circuit - ) + msg = "Differentiable state reconstruction for more than 8 qubits is not recommended" + with pytest.warns(UserWarning, match=msg): + # full hard-coded list for wires instead of range(8) since we monkeypatched it + qml.shadows.shadow_state(circuit.qtape, wires=[0, 1, 2, 3, 4, 5, 6, 7], diffable=True) def test_multi_measurement_error(self): """Test that an error is raised when classical shadows is returned @@ -345,7 +340,8 @@ def test_hadamard_forward(self): expected = [1, 1, 1, 0, 0, 0, 0] circuit = hadamard_circuit(3, shots=100000) - circuit = qml.shadows.shadow_expval(obs)(circuit) + circuit = qml.shadows.shadow_expval(circuit, obs) + actual = circuit() assert qml.math.allclose(actual, expected, atol=1e-1) @@ -364,16 +360,20 @@ def test_basic_entangler_backward(self): ] shadow_circuit = basic_entangler_circuit(3, shots=20000, interface="autograd") - shadow_circuit = qml.shadows.shadow_expval(obs)(shadow_circuit) + shadow_circuit = qml.shadows.shadow_expval(shadow_circuit, obs) exact_circuit = basic_entangler_circuit_exact_expval(3, "autograd") x = np.random.uniform(0.8, 2, size=qml.BasicEntanglerLayers.shape(n_layers=1, n_wires=3)) + def shadow_cost(x): + res = shadow_circuit(x) + return qml.math.stack(res) + def exact_cost(x, obs): res = exact_circuit(x, obs) return qml.math.stack(res) - actual = qml.jacobian(shadow_circuit)(x) + actual = qml.jacobian(shadow_cost)(x) expected = qml.jacobian(exact_cost)(x, obs) assert qml.math.allclose(actual, expected, atol=1e-1) diff --git a/tests/transforms/test_commutation_dag.py b/tests/transforms/test_commutation_dag.py index 81391d6d63f..3716825b238 100644 --- a/tests/transforms/test_commutation_dag.py +++ b/tests/transforms/test_commutation_dag.py @@ -33,21 +33,6 @@ def circuit(): assert len(dag) != 0 - def test_dag_invalid_argument(self): - """Assert error raised when input is neither a tape, QNode, nor quantum function""" - - with pytest.raises(ValueError, match="Input is not a tape, QNode, or quantum function"): - qml.transforms.commutation_dag(qml.PauliZ(0))() - - def test_dag_wrong_function(self): - """Assert error raised when input function is not a quantum function""" - - def test_function(x): - return x - - with pytest.raises(ValueError, match="Function contains no quantum operation"): - qml.transforms.commutation_dag(test_function)(1) - def test_dag_transform_simple_dag_function(self): """Test a simple DAG on 1 wire with a quantum function.""" @@ -80,7 +65,7 @@ def test_dag_transform_simple_dag_tape(self): qml.PauliX(wires=0) tape = qml.tape.QuantumScript.from_queue(q) - dag = qml.transforms.commutation_dag(tape)() + dag = qml.transforms.commutation_dag(tape) a = qml.PauliZ(wires=0) b = qml.PauliX(wires=0)