installing or updating JAX
Hajime Kawahara edited this page Feb 21, 2022 · 5 revisions
For the official installation instructions, see the original JAX page: https://github.com/google/jax#installation
To build from source, see https://jax.readthedocs.io/en/latest/developer.html#building-from-source
Author: Hajime Kawahara, Feb 21st (2022)
- Background: I wanted to use jax.experimental.sparse, which required installing the latest version of JAX.
Some CUDA package dependencies were broken, so I used aptitude instead of apt to install CUDA. cuDNN can be downloaded from the NVIDIA website.
sudo aptitude install cuda-11-5
sudo dpkg -i cudnn-local-repo-ubuntu2004-8.3.1.22_1.0-1_amd64.deb
pip uninstall jax
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
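After the uninstall and reinstall, it can be useful to confirm which versions pip actually left in the environment. The following is a minimal sketch using only the Python standard library (Python 3.8+); the helper name `installed_versions` is hypothetical and not part of JAX.

```python
from importlib import metadata  # standard library, Python 3.8+


def installed_versions(packages=("jax", "jaxlib")):
    """Return {package name: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The package is absent from the current environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

If either package is reported as not installed, the pip step above probably ran in a different environment from the one being inspected.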
After that, the new sparse matrix module could be imported.
manbou ~/jax(main)>python
Python 3.8.5 (default, Sep 4 2020, 07:30:14)
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from jax.experimental.sparse import coo
All of the unit tests in the develop branch passed.
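The import check in the session above only shows that the module loads. A slightly fuller check, sketched below, also lists the devices JAX sees and compares a sparse matrix-vector product with its dense counterpart. It assumes an installed JAX version that provides `BCOO` in `jax.experimental.sparse`; the function name `check_jax_sparse` and the small test matrix are made up for illustration.

```python
def check_jax_sparse():
    """Report the JAX version, visible devices, and whether a sparse product works."""
    try:
        import jax
        import jax.numpy as jnp
        from jax.experimental.sparse import BCOO
    except ImportError as err:
        # JAX (or its sparse module) is not importable in this environment.
        return {"ok": False, "error": str(err)}

    dense = jnp.array([[0.0, 1.0, 0.0],
                       [2.0, 0.0, 3.0]])
    vec = jnp.array([1.0, 2.0, 3.0])

    mat = BCOO.fromdense(dense)  # convert to the batched-COO sparse format
    # The sparse product should agree with the dense one.
    matches = bool(jnp.allclose(mat @ vec, dense @ vec))

    return {
        "ok": matches,
        "version": jax.__version__,
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    print(check_jax_sparse())
```

If the device list shows only CPU devices, the CUDA-enabled jaxlib is probably not the one being used, and the CUDA and cuDNN installation steps are worth rechecking.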