A minimal JAX-based reinforcement learning template for rapidly spinning up RL projects!
All training and evaluation is JIT-compiled end-to-end in JAX. The template targets Python 3.8.12 and is built on top of:
- JAX - Autograd and XLA
- Flax - Neural network library
- Optax - Gradient-based optimisation
- Distrax - Probability distributions
- Weights & Biases - Experiment tracking and visualisation
Variants of this template are released as branches of this repository, each with different features:
Branch | Description | Agents | Environments |
---|---|---|---|
main (here) | Basic training and evaluation functionality (e.g. training loop, logging, checkpointing), plus common online RL agents | PPO, SAC, DQN | Gymnax |
offline (TBC) | Adds offline RL functionality (e.g. replay buffer, offline training) | CQL, EDAC | - |
This template is designed to provide only core functionality, giving a solid foundation for RL projects. Whilst it is not intended to be a full-featured RL library, please raise an issue if you think a feature is missing that would be useful for many projects.
- Install Python packages from `requirements-base.txt` and `requirements-cpu.txt` in `setup` with:
cd setup && pip install $(cat requirements-base.txt requirements-cpu.txt)
- Sign into WandB to enable logging:
wandb login
- Build the Docker container with the provided script:
cd setup/docker && ./build.sh
- Add your WandB key to the `setup/docker` folder:
echo <wandb_key> > setup/docker/wandb_key
After installing the Python packages, install the Black pre-commit hook with:
pre-commit install
This will check and fix formatting errors when you commit code.
To train an agent, run:
python train.py <arguments>
For example, to train a PPO agent on the CartPole-v1 environment and log to WandB, run:
python train.py --agent ppo --env_name CartPole-v1 --log --wandb_entity wandb_username --wandb_project project_name
To see all possible arguments, see `experiments/parse_args.py` or run:
python train.py --help
Launch training runs inside your built container with:
./run_docker.sh <gpu_id> python3 train.py <arguments>
For example, to train a DQN agent on the Asterix-MinAtar environment using GPU 3, run:
./run_docker.sh 3 python3 train.py --agent dqn --env_name Asterix-MinAtar
Large parts of the training loop and PPO implementation are based on PureJaxRL, which contains high-performance, single-file implementations of RL agents in JAX.
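To give a rough sense of what end-to-end JIT-compiled training means in this style, the sketch below expresses a full training run as a `jax.lax.scan` over update steps inside a single `jax.jit`. It is an illustration under assumed names (`update_step`, `NUM_UPDATES`, and the carried state are hypothetical), not this repository's or PureJaxRL's actual implementation.

```python
import jax
import jax.numpy as jnp

NUM_UPDATES = 1_000  # hypothetical number of update steps, baked into the compiled program


def update_step(carry, _):
    """One training iteration; a real agent would roll out the env and apply an Optax update here."""
    rng, step = carry
    rng, _step_rng = jax.random.split(rng)  # fresh randomness for this step's rollouts
    metrics = {"step": step}
    return (rng, step + 1), metrics


@jax.jit  # the entire training run becomes one XLA program
def train(rng):
    init_carry = (rng, jnp.array(0))
    final_carry, metrics = jax.lax.scan(update_step, init_carry, None, length=NUM_UPDATES)
    return final_carry, metrics


final_carry, metrics = train(jax.random.PRNGKey(0))
```

Structuring the loop this way avoids Python-level iteration entirely, which is what makes single-file JAX agents in this style so fast on accelerators.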