[BUG] JAX JIT errors-out when using a tracer-based PRNG key #6054

Open
1 task done
trbromley opened this issue Jul 30, 2024 · 6 comments · May be fixed by #6788
Labels
bug 🐛 Something isn't working

Comments

@trbromley
Contributor

Expected behavior

It is possible to JAX-JIT a function in which qml.device is created with its seed set to a traced PRNG key.

Actual behavior

An error is raised in many cases. I'm not sure exactly when, but it appears to happen when other tracers are active, such as the param variable below.

There is also an error when using qml.qjit and lightning.qubit, which doesn't expect a JAX tracer for the seed argument.

Additional information

There is a draft PR to fix this problem, but it has gone stale, so I decided it would be more useful to track the problem in this issue.

Source code

import jax
import pennylane as qml

def circuit(key, param):
    dev = qml.device("default.qubit", wires=2, shots=10, seed=key)

    @qml.qnode(dev, interface="jax")
    def my_circuit():
        qml.RX(param, wires=0)  # no error if this line is commented-out
        qml.CNOT(wires=[0, 1])
        return qml.sample(qml.PauliZ(0))
    return my_circuit()

key = jax.random.PRNGKey(1967)
jax.jit(circuit)(key, 0.5)

Tracebacks

XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was circuit at /tmp/ipykernel_31794/2671401565.py:4 traced for jit.
------------------------------
The leaked intermediate value was created on line /tmp/ipykernel_31794/2671401565.py:15 (<module>). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/tmp/ipykernel_31794/2671401565.py:15 (<module>)

System information

Working on the dev branch of PennyLane with JAX==0.4.23.

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
@josh146
Member

josh146 commented Jul 30, 2024

@trbromley do you have more details on this?

> There is also an error when using qml.qjit and lightning.qubit, which doesn't expect a JAX tracer for the seed argument.

This PR might be relevant, which adds a seed kwarg to qjit: PennyLaneAI/catalyst#936

@trbromley
Contributor Author

> @trbromley do you have more details on this?
>
> There is also an error when using qml.qjit and lightning.qubit, which doesn't expect a JAX tracer for the seed argument.

When I run this:

import jax
import pennylane as qml

def circuit(key, param):
    dev = qml.device("lightning.qubit", wires=2, shots=10, seed=key)

    @qml.qnode(dev, interface="jax")
    def my_circuit():
        qml.RX(param, wires=0)
        qml.CNOT(wires=[0, 1])
        return qml.sample(qml.PauliZ(0))
    return my_circuit()

key = jax.random.PRNGKey(1967)

I get this error:

>>> qml.qjit(circuit)(key, 0.5)
TypeError: SeedSequence expects int or sequence of ints for entropy not Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=1/0)>
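Judging by the message, lightning.qubit hands its seed to NumPy's SeedSequence, which only accepts concrete ints or int sequences. A minimal sketch of the same failure with plain jax.jit and NumPy (not lightning.qubit itself; seed_numpy is a hypothetical name): outside any transformation a legacy PRNGKey is a concrete uint32[2] array that NumPy accepts, but under tracing it becomes a tracer, which NumPy rejects. qml.qjit traces its arguments in the same way, which is consistent with the Traced<ShapedArray(uint32[2])> in the error above.

```python
import jax
import numpy as np

key = jax.random.PRNGKey(1967)

# Outside any JAX transformation, a legacy PRNG key is a concrete uint32[2]
# array, which NumPy's SeedSequence accepts as a sequence of ints.
rng = np.random.default_rng(np.asarray(key))


@jax.jit
def seed_numpy(key):
    # Under jit, `key` is a tracer, so building a SeedSequence from it fails
    # at trace time with the same kind of TypeError reported above.
    np.random.SeedSequence(key)
    return key


try:
    seed_numpy(key)
except TypeError as err:
    print("TypeError:", err)
```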

> This PR might be relevant, which adds a seed kwarg to qjit: PennyLaneAI/catalyst#936

Oh interesting, so the seed should live with qjit rather than the device? 🤔

@josh146
Member

josh146 commented Jul 30, 2024

> Oh interesting, so the seed should live with qjit rather than the device? 🤔

The PR was motivated by a technical problem: some tests in the Catalyst codebase are flaky, and there was no qjit-compatible way of specifying a seed. Passing it through the qjit decorator is a quick way of adding this support, but we can revisit it from a user POV if we need to.

@josh146
Member

josh146 commented Jul 30, 2024

> I get this error:

Ah yes, I guess this is to be expected, since JAX seeds are not supported by MLIR/LLVM.

@CatalinaAlbornoz
Contributor

I'm reviving this issue.

The latest version of the JAX interface docs (Randomness section) shows an example using seed=key where key=jax.random.PRNGKey(0). This works fine with PennyLane v0.40.0.dev37 and JAX v0.4.33.

See the example here:

import jax
import pennylane as qml

@jax.jit
def sample_circuit(phi, theta, key):

    # Device construction should happen inside a `jax.jit` decorated
    # method when using a PRNGKey.
    dev = qml.device('default.qubit', wires=2, seed=key, shots=100)

    @qml.qnode(dev, interface='jax', diff_method=None)
    def circuit(phi, theta):
        qml.RX(phi[0], wires=0)
        qml.RZ(phi[1], wires=1)
        qml.CNOT(wires=[0, 1])
        qml.RX(theta, wires=0)
        return qml.sample() # Here, we take samples instead.

    return circuit(phi, theta)

# Get the samples from the jitted method.
samples = sample_circuit([0.2, 1.0], 5.2, jax.random.PRNGKey(0))

However, when you remove diff_method=None, it throws the error ERROR:jax._src.callback:jax.pure_callback failed.

Full traceback below:

ERROR:jax._src.callback:jax.pure_callback failed
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/callback.py", line 86, in pure_callback_impl
    return tree_util.tree_map(np.asarray, callback(*args))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/callback.py", line 64, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/pennylane/workflow/interfaces/jax_jit.py", line 168, in pure_callback_wrapper
    return _to_jax(execute_fn(new_tapes))
  File "/usr/local/lib/python3.10/dist-packages/pennylane/workflow/run.py", line 247, in inner_execute
    results = device.execute(transformed_tapes, execution_config=execution_config)
  File "/usr/local/lib/python3.10/dist-packages/pennylane/devices/modifiers/simulator_tracking.py", line 30, in execute
    results = untracked_execute(self, circuits, execution_config)
  File "/usr/local/lib/python3.10/dist-packages/pennylane/devices/modifiers/single_tape_support.py", line 32, in execute
    results = batch_execute(self, circuits, execution_config)
  File "/usr/local/lib/python3.10/dist-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pennylane/devices/default_qubit.py", line 627, in execute
    prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
  File "/usr/local/lib/python3.10/dist-packages/pennylane/devices/default_qubit.py", line 627, in <listcomp>
    prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
  File "/usr/local/lib/python3.10/dist-packages/pennylane/devices/default_qubit.py", line 439, in get_prng_keys
    self._prng_key, *keys = jax_random_split(self._prng_key)
  File "/usr/local/lib/python3.10/dist-packages/pennylane/devices/qubit/sampling.py", line 42, in jax_random_split
    return split(prng_key, num=num)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/random.py", line 292, in split
    typed_key, wrapped = _check_prng_key("split", key)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/random.py", line 79, in _check_prng_key
    wrapped_key = prng.random_wrap(key, impl=default_prng_impl())
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/prng.py", line 699, in random_wrap
    return random_wrap_p.bind(base_arr, impl=impl)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 439, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 1385, in find_top_trace
    top_tracer._assert_live()
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/partial_eval.py", line 1741, in _assert_live
    raise core.escaped_tracer_error(self, None)
jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was sample_circuit at <ipython-input-11-529d78cbecfe>:5 traced for jit.
------------------------------
The leaked intermediate value was created on line <ipython-input-11-529d78cbecfe>:21 (<cell line: 21>). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py:78 (_pseudo_sync_runner)
/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py:3257 (run_cell_async)
/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py:3473 (run_ast_nodes)
/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py:3553 (run_code)
<ipython-input-11-529d78cbecfe>:21 (<cell line: 21>)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
<ipython-input-11-529d78cbecfe> in <cell line: 21>()
     19 
     20 # Get the samples from the jitted method.
---> 21 samples = sample_circuit([0.2, 1.0], 5.2, jax.random.PRNGKey(0))

    [... skipping hidden 10 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py in __call__(self, *args)
   1275           or self.has_host_callbacks):
   1276         input_bufs = self._add_tokens_to_inputs(input_bufs)
-> 1277         results = self.xla_executable.execute_sharded(
   1278             input_bufs, with_tokens=True
   1279         )

XlaRuntimeError: INTERNAL: CpuCallback error: Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
  File "/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py", line 37, in <module>
  File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 992, in launch_instance
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 619, in start
  File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start
  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
  File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
  File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 685, in <lambda>
  File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 738, in _run_callback
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 825, in inner
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 361, in process_one
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
  File "<ipython-input-11-529d78cbecfe>", line 21, in <cell line: 21>
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 332, in cache_miss
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 190, in _python_pjit_helper
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 2782, in bind
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 443, in bind_with_trace
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 949, in process_primitive
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1739, in _pjit_call_impl
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1721, in call_impl_cache_miss
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1675, in _pjit_call_impl_python
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py", line 333, in wrapper
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 1277, in __call__
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py", line 2777, in _wrapped_callback
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/callback.py", line 228, in _callback
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/callback.py", line 89, in pure_callback_impl
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/callback.py", line 64, in __call__
  File "/usr/local/lib/python3.10/dist-packages/pennylane/workflow/interfaces/jax_jit.py", line 168, in pure_callback_wrapper
  File "/usr/local/lib/python3.10/dist-packages/pennylane/workflow/run.py", line 247, in inner_execute
  File "/usr/local/lib/python3.10/dist-packages/pennylane/devices/modifiers/simulator_tracking.py", line 30, in execute
  File "/usr/local/lib/python3.10/dist-packages/pennylane/devices/modifiers/single_tape_support.py", line 32, in execute
  File "/usr/local/lib/python3.10/dist-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
  File "/usr/local/lib/python3.10/dist-packages/pennylane/devices/default_qubit.py", line 627, in execute
  File "/usr/local/lib/python3.10/dist-packages/pennylane/devices/default_qubit.py", line 627, in <listcomp>
  File "/usr/local/lib/python3.10/dist-packages/pennylane/devices/default_qubit.py", line 439, in get_prng_keys
  File "/usr/local/lib/python3.10/dist-packages/pennylane/devices/qubit/sampling.py", line 42, in jax_random_split
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/random.py", line 292, in split
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/random.py", line 79, in _check_prng_key
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/prng.py", line 699, in random_wrap
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 439, in bind
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 1385, in find_top_trace
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/partial_eval.py", line 1741, in _assert_live
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was sample_circuit at <ipython-input-11-529d78cbecfe>:5 traced for jit.
------------------------------
The leaked intermediate value was created on line <ipython-input-11-529d78cbecfe>:21 (<cell line: 21>). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py:78 (_pseudo_sync_runner)
/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py:3257 (run_cell_async)
/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py:3473 (run_ast_nodes)
/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py:3553 (run_code)
<ipython-input-11-529d78cbecfe>:21 (<cell line: 21>)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

This error still occurs when the measurement is qml.expval(qml.PauliZ(0)) instead of qml.sample().

The error can be avoided by using a fixed seed inside the jitted function, or by commenting out all of the gates that take parameters, which is of course not ideal.

The error was brought up in this Forum thread.
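The traceback suggests a likely cause: with trainable parameters, execution goes through jax.pure_callback (pure_callback_wrapper in pennylane/workflow/interfaces/jax_jit.py), and default.qubit then splits the PRNG key it stored at construction time, which is still a uint32[2] tracer from the enclosing jax.jit. A minimal pure-JAX sketch of that pattern, under those assumptions and not PennyLane's actual code (host_sample, leaky and explicit are hypothetical names; the host function seeds NumPy instead of calling jax.random): a callback that closes over an outer tracer fails at run time, while passing the key as an explicit callback operand gives the host a concrete array. Whether the proposed fix in #6788 takes this route is not something this sketch claims.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Shape/dtype of the value the host callback returns.
OUT = jax.ShapeDtypeStruct((), jnp.float32)


def host_sample(key, x):
    # Runs on the host with concrete NumPy inputs. Seeding NumPy from the
    # uint32[2] key requires a concrete array, not a tracer.
    rng = np.random.default_rng(np.asarray(key))
    return np.asarray(x + rng.normal(), dtype=np.float32)


@jax.jit
def leaky(key, x):
    # The lambda closes over `key`, so at run time the callback sees the
    # tracer from this (already finished) trace instead of a real array.
    return jax.pure_callback(lambda x: host_sample(key, x), OUT, x)


@jax.jit
def explicit(key, x):
    # `key` is passed as a callback operand, so the host receives a
    # concrete uint32[2] NumPy array when the callback runs.
    return jax.pure_callback(host_sample, OUT, key, x)


key = jax.random.PRNGKey(0)
x = jnp.float32(0.5)
print(explicit(key, x))  # deterministic for a given key

try:
    jax.block_until_ready(leaky(key, x))
except Exception as err:  # the wrapping error type differs across JAX versions
    print("leaky failed:", type(err).__name__)
```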

@CatalinaAlbornoz
Contributor

There's a way to pass an integer key as an argument by using static_argnums.

import pennylane as qml
import numpy as np
import jax
from functools import partial

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")


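# Argument 2 (key) is static, so the device receives a concrete Python int
# rather than a tracer; each distinct key value triggers a recompilation.
@partial(jax.jit, static_argnums=2)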
def sample_circuit(phi, theta, key):
    dev = qml.device("default.qubit", wires=2, seed=key, shots=100)

    @qml.qnode(dev, interface="jax")
    def circuit(phi, theta):
        qml.RX(phi[0], wires=0)
        qml.RZ(phi[1], wires=1)
        qml.CNOT(wires=[0, 1])
        qml.RX(theta[0], wires=0)
        return qml.expval(qml.PauliZ(0))

    return circuit(phi, theta)


phi = np.array([0.2, 1.0])
theta = np.array([0.2])
key = 10
print(sample_circuit(phi, theta, key))
print(sample_circuit(phi, theta, key))
print(sample_circuit(phi, theta, key))
print(jax.grad(sample_circuit)(phi, theta, key))
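One caveat with static_argnums: JAX compiles a separate executable for every distinct static value, so this suits a handful of fixed seeds rather than a fresh seed on every call. A minimal pure-JAX sketch of that behaviour, with no PennyLane device involved (add_noise and trace_count are hypothetical names used only for illustration):

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0


@partial(jax.jit, static_argnums=1)
def add_noise(x, seed):
    global trace_count
    # Python side effects run only while JAX traces the function, so this
    # counts (re)compilations rather than calls.
    trace_count += 1
    # `seed` is static, so it is a concrete Python int here and can be used
    # to build a key (or be passed to a device constructor) without tracers.
    key = jax.random.PRNGKey(seed)
    return x + jax.random.normal(key, x.shape)


x = jnp.zeros(3)
a = add_noise(x, 10)
b = add_noise(x, 10)  # same static value: the cached executable is reused
c = add_noise(x, 11)  # new static value: traced and compiled again
print(trace_count)
```

The counter ends at 2 after three calls, and the two calls with seed 10 return identical values.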
