[BUG] Issue calculating qml.qinfo.transforms.quantum_fisher #5197

Closed
1 task done
NickGut0711 opened this issue Feb 13, 2024 · 2 comments
Labels
bug 🐛 Something isn't working

Comments

@NickGut0711

Expected behavior

Hello! I’m currently trying to use qml.qinfo.transforms.quantum_fisher to calculate the quantum Fisher information matrix (QFIM) of a circuit called U1_qfi, but the calculation fails in PennyLane. The error points to a variable r that ends up empty, so np.stack() raises. To make this reproducible on your machine: the perceptron_qfi.H used in the qml.evolve call in U1_qfi is a parametrized Hamiltonian on 5 wires given as

(constant(params_0, t)*((PauliZ(wires=[0]) @ PauliZ(wires=[1])) + (PauliZ(wires=[1]) @ PauliZ(wires=[2])) + (PauliZ(wires=[2]) @ PauliZ(wires=[3])) + (PauliZ(wires=[3]) @ PauliZ(wires=[4]))))
+ (constant(params_1, t)*(Identity(wires=[0]) + Identity(wires=[1]) + Identity(wires=[2]) + Identity(wires=[3]) + Identity(wires=[4])))
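For context, a minimal sketch of how perceptron_qfi.H could be built, assuming the coefficients are qml.pulse.constant (inferred from the printed repr above; the actual QuantumPerceptron implementation may differ):

import pennylane as qml

# Sketch only: reconstructs the parametrized Hamiltonian from its repr.
n_wires = 5
H_zz = qml.sum(*(qml.PauliZ(i) @ qml.PauliZ(i + 1) for i in range(n_wires - 1)))
H_id = qml.sum(*(qml.Identity(i) for i in range(n_wires)))
H_sketch = qml.pulse.constant * H_zz + qml.pulse.constant * H_id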

Also, params_qfi is a JAX array of parameters whose length depends on L_qfi: for L_qfi=1 it has four elements, and for L_qfi=2 it has eight.
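A hypothetical stand-in for how such a parameter vector might be generated (get_random_parameter_vector is the user's own helper; the shape convention below is an assumption):

import jax

# Assumed convention: four parameters per layer, seeded like the source code below.
L_qfi = 2
key = jax.random.PRNGKey(25)  # random_seed = 25
params_qfi = jax.random.uniform(key, shape=(4 * L_qfi,))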

Actual behavior

The QFIM isn't calculated due to the np.stack() error.
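For reference, for a pure state the QFIM is four times the Fubini–Study metric tensor,

$$\mathcal{F}_{ij} = 4\,\mathrm{Re}\left[ \langle \partial_i \psi | \partial_j \psi \rangle - \langle \partial_i \psi | \psi \rangle \langle \psi | \partial_j \psi \rangle \right],$$

which is what the 4 * r post-processing step visible in the traceback below computes; an empty r means no metric-tensor result ever reached that step.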

Additional information

A more minimal example of the problem is

import jax
import pennylane as qml

params = jax.numpy.array([0.5, 1., 0.2])
H = 1. * qml.PauliX(0) @ qml.PauliX(1) - 0.5 * qml.PauliZ(1)

dev = qml.device("default.qubit")

@qml.qnode(dev)
def circ(params):
    qml.RY(params[0], wires=1)
    qml.CNOT(wires=(1, 0))
    qml.RY(params[1], wires=1)
    qml.RZ(params[2], wires=1)
    return qml.expval(H)

qml.gradients.adjoint_metric_tensor(circ)(params)

with the output as

Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

While this ran, it didn't treat any of the parameters as trainable, because JAX tracks trainability differently: JAX variables are only considered trainable once they become tracers. The transform just needs to be updated so it correctly recognizes when a parameter is trainable.
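A small illustration of that difference, assuming qml.math.requires_grad reflects the trainability check (concrete JAX arrays are reported as non-trainable; tracers as trainable):

import jax
import jax.numpy as jnp
import pennylane as qml

x = jnp.array(0.5)
print(qml.math.requires_grad(x))  # False: a concrete JAX array is not a tracer

def cost(y):
    # Inside a JAX transformation, y is a tracer and is seen as trainable.
    print(qml.math.requires_grad(y))  # True
    return y ** 2

jax.grad(cost)(x)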

Source code

import jax
import pennylane as qml

# QuantumPerceptron, dev_qfi2 and the evolution time t are defined elsewhere in
# the notebook; perceptron_qfi.H is the parametrized Hamiltonian shown above.
n_wires = 5
random_seed = 25
perceptron_qfi = QuantumPerceptron(n_wires, L=2)

@qml.qnode(dev_qfi2, interface='jax')
def U1_qfi(params, L=1):

    start_index = 0
    n_params_in_layer = 2 + 2  # two Hamiltonian coefficients + two rotation angles per layer

    for i in range(L):
        params_layer = params[start_index:start_index + n_params_in_layer]
        qml.evolve(perceptron_qfi.H)(params_layer[:2], t)
        for j in range(n_wires): # Single-qubit X rotations
            qml.RX(params_layer[2], wires=j)
        for k in range(n_wires): # Single-qubit Y rotations
            qml.RY(params_layer[3], wires=k)
        start_index += n_params_in_layer
    
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

L_qfi = perceptron_qfi.L
params_qfi = perceptron_qfi.get_random_parameter_vector(random_seed)

qfim = qml.qinfo.quantum_fisher(U1_qfi)(params_qfi, L = L_qfi)

Tracebacks

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/.../.ipynb Cell 25 line 1
      7 params_qfi = perceptron_qfi.get_random_parameter_vector(random_seed)
      9 # qfim = calculate_QFI(U1_qfi, params_qfi, L_qfi)
---> 10 qfim = qml.qinfo.quantum_fisher(U1_qfi)(params_qfi, L = L_qfi)
     11 qfim.shape

File ~/.../site-packages/pennylane/qnode.py:1027, in QNode.__call__(self, *args, **kwargs)
   1022         full_transform_program._set_all_argnums(
   1023             self, args, kwargs, argnums
   1024         )  # pylint: disable=protected-access
   1026 # pylint: disable=unexpected-keyword-arg
-> 1027 res = qml.execute(
   1028     (self._tape,),
   1029     device=self.device,
   1030     gradient_fn=self.gradient_fn,
   1031     interface=self.interface,
   1032     transform_program=full_transform_program,
   1033     config=config,
   1034     gradient_kwargs=self.gradient_kwargs,
   1035     override_shots=override_shots,
   1036     **self.execute_kwargs,
   1037 )
   1039 res = res[0]
   1041 # convert result to the interface in case the qfunc has no parameters

File ~/.../site-packages/pennylane/interfaces/execution.py:612, in execute(tapes, device, gradient_fn, interface, transform_program, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform)
    609         return program_post_processing(program_pre_processing(results))
    611 if transform_program.is_informative:
--> 612     return post_processing(tapes)
    614 # Exiting early if we do not need to deal with an interface boundary
    615 if no_interface_boundary_required:

File ~/.../site-packages/pennylane/transforms/core/transform_program.py:86, in _apply_postprocessing_stack(results, postprocessing_stack)
     63 """Applies the postprocessing and cotransform postprocessing functions in a Last-In-First-Out LIFO manner.
     64 
     65 Args:
   (...)
     83 
     84 """
     85 for postprocessing in reversed(postprocessing_stack):
---> 86     results = postprocessing(results)
     87 return results

File ~/.../site-packages/pennylane/transforms/core/transform_program.py:56, in _batch_postprocessing(results, individual_fns, slices)
     30 def _batch_postprocessing(
     31     results: ResultBatch, individual_fns: List[PostProcessingFn], slices: List[slice]
     32 ) -> ResultBatch:
     33     """Broadcast individual post processing functions onto their respective tapes.
     34 
     35     Args:
   (...)
     54 
     55     """
---> 56     return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices))

File ~/.../site-packages/pennylane/transforms/core/transform_program.py:56, in <genexpr>(.0)
     30 def _batch_postprocessing(
     31     results: ResultBatch, individual_fns: List[PostProcessingFn], slices: List[slice]
     32 ) -> ResultBatch:
     33     """Broadcast individual post processing functions onto their respective tapes.
     34 
     35     Args:
   (...)
     54 
     55     """
---> 56     return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices))

File ~/.../site-packages/pennylane/qinfo/transforms.py:752, in quantum_fisher.<locals>.processing_fn_multiply(r)
    750 def processing_fn_multiply(r):  # pylint: disable=function-redefined
    751     print("Value of r:", r)
--> 752     r = qml.math.stack(r)
    753     return 4 * r

File ~/.../site-packages/pennylane/math/multi_dispatch.py:151, in multi_dispatch.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    148 interface = interface or get_interface(*dispatch_args)
    149 kwargs["like"] = interface
--> 151 return fn(*args, **kwargs)

File ~/.../site-packages/pennylane/math/multi_dispatch.py:488, in stack(values, axis, like)
    459 """Stack a sequence of tensors along the specified axis.
    460 
    461 .. warning::
   (...)
    485        [5.00e+00, 8.00e+00, 1.01e+02]], dtype=float32)>
    486 """
    487 values = np.coerce(values, like=like)
--> 488 return np.stack(values, axis=axis, like=like)

File ~/.../site-packages/autoray/autoray.py:80, in do(fn, like, *args, **kwargs)
     31 """Do function named ``fn`` on ``(*args, **kwargs)``, peforming single
     32 dispatch to retrieve ``fn`` based on whichever library defines the class of
     33 the ``args[0]``, or the ``like`` keyword argument if specified.
   (...)
     77     <tf.Tensor: id=91, shape=(3, 3), dtype=float32>
     78 """
     79 backend = choose_backend(fn, *args, like=like, **kwargs)
---> 80 return get_lib_fn(backend, fn)(*args, **kwargs)

File ~/.../site-packages/numpy/core/shape_base.py:445, in stack(arrays, axis, out, dtype, casting)
    443 arrays = [asanyarray(arr) for arr in arrays]
    444 if not arrays:
--> 445     raise ValueError('need at least one array to stack')
    447 shapes = {arr.shape for arr in arrays}
    448 if len(shapes) != 1:

ValueError: need at least one array to stack

System information

Name: PennyLane
Version: 0.33.1
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /.../python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Lightning

Platform info:           macOS-13.4.1-x86_64-i386-64bit
Python version:          3.11.5
Numpy version:           1.26.2
Scipy version:           1.11.4
Installed devices:
- default.gaussian (PennyLane-0.33.1)
- default.mixed (PennyLane-0.33.1)
- default.qubit (PennyLane-0.33.1)
- default.qubit.autograd (PennyLane-0.33.1)
- default.qubit.jax (PennyLane-0.33.1)
- default.qubit.legacy (PennyLane-0.33.1)
- default.qubit.tf (PennyLane-0.33.1)
- default.qubit.torch (PennyLane-0.33.1)
- default.qutrit (PennyLane-0.33.1)
- null.qubit (PennyLane-0.33.1)
- lightning.qubit (PennyLane-Lightning-0.33.1)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
@NickGut0711 NickGut0711 added the bug 🐛 Something isn't working label Feb 13, 2024
@trbromley
Contributor

Thanks @NickGut0711 for posting, we'll let you know once we have a fix for this.

astralcai added a commit that referenced this issue Feb 28, 2024
**Context:**
The `adjoint_metric_tensor` transform does not work with jax variables
because jax variables are not considered trainable parameters until they
become tracers.

**Description of the Change:**
1. Add a custom expand transform to `adjoint_metric_tensor` that expands
trainable parameters based on argnums.
2. Add an optional `use_argnum_in_expand` argument to transform programs
to determine whether or not to perform `jax_argnums_to_tape_trainable`
on the parameters.

**Benefits:**
BugFix

**Possible Drawbacks:**
Adding yet another keyword argument to `transform` may make code look
messy.

**Related GitHub Issues:**
#5197

**Related Shortcut Stories:**
[sc-56734]
@astralcai
Contributor

Fixed
