[Project website] [Paper] [arXiv]
This project is a PyTorch implementation of Skill-based Model-based Reinforcement Learning (SkiMo), published in CoRL 2022.
- run.py: launches the appropriate trainer for the chosen algorithm
- skill_trainer.py: trainer for skill-based approaches
- skimo_agent.py: model and training code for SkiMo
- skimo_rollout.py: rollout with the SkiMo agent
- spirl_tdmpc_agent.py: model and training code for SPiRL+TD-MPC
- spirl_tdmpc_rollout.py: rollout with SPiRL+TD-MPC
- spirl_dreamer_agent.py: model and training code for SPiRL+Dreamer
- spirl_dreamer_rollout.py: rollout with SPiRL+Dreamer
- spirl_trainer.py: trainer for SPiRL
- spirl_agent.py: model for SPiRL
- config/: default hyperparameters
- calvin/: CALVIN environments
- d4rl/: D4RL environments forked by Karl Pertsch; our only change is to the installation command
- envs/: environment wrappers
- spirl/: SPiRL code
- data/: offline data directory
- rolf/: implementation of RL algorithms from robot-learning by Youngwoon Lee
- log/: training logs, evaluation results, and checkpoints
- Ubuntu 20.04
- Python 3.9
- MuJoCo 2.1
- Clone this repository.
git clone --recursive git@github.com:clvrai/skimo.git
cd skimo
- Create a virtual environment
conda create -n skimo_venv python=3.9
conda activate skimo_venv
- Install MuJoCo 2.1
- Download the MuJoCo version 2.1 binaries for Linux or OSX.
- Extract the downloaded mujoco210 directory into ~/.mujoco/mujoco210.
- Install packages
sh install.sh
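If install.sh or a later run cannot find MuJoCo, checking the extracted directory first can save time. The following is a minimal sketch, not part of this repository: check_mujoco is a hypothetical helper, and it assumes the official MuJoCo 2.1 archive layout, in which the libraries live in a bin/ subdirectory of ~/.mujoco/mujoco210.

```python
from pathlib import Path

DEFAULT_ROOT = Path.home() / ".mujoco" / "mujoco210"


def check_mujoco(root: Path = DEFAULT_ROOT) -> list:
    """Return a list of problems with a MuJoCo 2.1 install; empty means OK.

    Assumption: the official mujoco210 archive layout with a bin/ subdirectory.
    """
    root = Path(root)
    if not root.is_dir():
        return [f"missing directory: {root}"]
    problems = []
    if not (root / "bin").is_dir():
        problems.append(f"missing bin/ subdirectory in {root}")
    return problems


if __name__ == "__main__":
    issues = check_mujoco()
    print("MuJoCo 2.1 directory looks OK" if not issues else "\n".join(issues))
```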
Download the offline datasets:
# Create the data directory and move into it
mkdir data && cd data
# Maze
gdown 1GWo8Vr8Xqj7CfJs7TaDsUA6ELno4grKJ
# Kitchen (and mis-aligned kitchen)
gdown 1Fym9prOt5Cu_I73F20cdd3lXZPhrvEsd
# CALVIN
gdown 1g4ONf_3cNQtrZAo2uFa_t5MOopSr2DNY
cd ..
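The three downloads can also be scripted from Python. This is a sketch under stated assumptions: the gdown package is installed (it provides the gdown CLI used above), and gdown.download accepts a drive.google.com/uc?id=... URL and keeps the remote file name when no output is given. DATASETS and download_datasets are hypothetical names; the file IDs are copied from the commands above.

```python
import os

# Google Drive file IDs taken from the gdown commands above.
DATASETS = {
    "maze": "1GWo8Vr8Xqj7CfJs7TaDsUA6ELno4grKJ",
    "kitchen": "1Fym9prOt5Cu_I73F20cdd3lXZPhrvEsd",  # kitchen and mis-aligned kitchen
    "calvin": "1g4ONf_3cNQtrZAo2uFa_t5MOopSr2DNY",
}


def drive_url(file_id):
    """Direct-download URL for a Google Drive file ID."""
    return f"https://drive.google.com/uc?id={file_id}"


def download_datasets(names=None, data_dir="data", dry_run=False):
    """Download the selected datasets into data_dir and return their URLs.

    With dry_run=True nothing is downloaded; only the URLs are returned.
    """
    names = list(DATASETS) if names is None else list(names)
    urls = [drive_url(DATASETS[name]) for name in names]
    if dry_run:
        return urls
    import gdown  # imported lazily so dry runs work without gdown installed

    os.makedirs(data_dir, exist_ok=True)
    previous = os.getcwd()
    os.chdir(data_dir)  # like `cd data`: files keep their remote names here
    try:
        for url in urls:
            gdown.download(url, quiet=False)
    finally:
        os.chdir(previous)  # like `cd ..`
    return urls
```

For example, download_datasets(["maze"]) fetches only the maze data.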
Below are the commands for SkiMo and all baselines. Results are logged to WandB. Before running any of them, change the wandb entity in run.py#L36 to your own account.
Replace [ENV] with one of maze, kitchen, or calvin. For the mis-aligned kitchen task, append env.task=misaligned to the downstream RL command.
After pre-training, replace [PRETRAINED_CKPT] with the path to the pre-trained checkpoint.
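To avoid typos when substituting [ENV] and [PRETRAINED_CKPT], the SkiMo commands can be built programmatically. A minimal sketch: skimo_command is a hypothetical helper that only reproduces the SkiMo command strings in this README, appending env.task=misaligned to the downstream RL command for the mis-aligned kitchen task. The checkpoint path in the usage line is a placeholder.

```python
import shlex

ENVS = ("maze", "kitchen", "calvin")


def skimo_command(env, phase="pretrain", ckpt=None, misaligned=False, gpu=0):
    """Return the argv list for a SkiMo pre-training or downstream RL run."""
    if env not in ENVS:
        raise ValueError(f"env must be one of {ENVS}, got {env!r}")
    cmd = ["python", "run.py", "--config-name", f"skimo_{env}",
           "run_prefix=test", f"gpu={gpu}", "wandb=true"]
    if phase == "pretrain":
        return cmd
    if phase != "rl":
        raise ValueError("phase must be 'pretrain' or 'rl'")
    if ckpt is None:
        raise ValueError("downstream RL needs the pre-trained checkpoint path")
    cmd += ["rolf.phase=rl", f"rolf.pretrain_ckpt_path={ckpt}"]
    if misaligned:
        # The README documents the mis-aligned task only for kitchen.
        if env != "kitchen":
            raise ValueError("env.task=misaligned is documented only for kitchen")
        cmd.append("env.task=misaligned")
    return cmd


# "path/to/ckpt.pt" is a placeholder for your checkpoint path.
print(shlex.join(skimo_command("kitchen", "rl", ckpt="path/to/ckpt.pt", misaligned=True)))
```

The returned list can be passed directly to subprocess.run(cmd, check=True) from the repository root, which avoids shell quoting issues.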
- SkiMo: pre-training
python run.py --config-name skimo_[ENV] run_prefix=test gpu=0 wandb=true
You can also skip this step by downloading our pre-trained model checkpoints. See instructions in pretrained_models.md.
- SkiMo: downstream RL
python run.py --config-name skimo_[ENV] run_prefix=test gpu=0 wandb=true rolf.phase=rl rolf.pretrain_ckpt_path=[PRETRAINED_CKPT]
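Downstream RL needs the path to a pre-training checkpoint, which is written under log/. The sketch below picks the most recently modified checkpoint file. It assumes checkpoints are saved as *.pt files somewhere below log/, which may not match this codebase's naming, so adjust the pattern argument as needed; latest_checkpoint is a hypothetical helper.

```python
from pathlib import Path


def latest_checkpoint(log_dir="log", pattern="**/*.pt"):
    """Return the most recently modified file under log_dir matching pattern.

    Assumption: checkpoints are *.pt files somewhere below log/.
    Returns None when nothing matches (including a missing log_dir).
    """
    candidates = [p for p in Path(log_dir).glob(pattern) if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

Its result can stand in for [PRETRAINED_CKPT], e.g. f"rolf.pretrain_ckpt_path={latest_checkpoint()}". Check that it points at the run you intended, since the newest file may belong to a different experiment.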
- Dreamer
python run.py --config-name dreamer_config env=[ENV] run_prefix=test gpu=0 wandb=true
- TD-MPC
python run.py --config-name tdmpc_config env=[ENV] run_prefix=test gpu=0 wandb=true
- SPiRL: first pre-train or download the skill prior (see instructions here).
- SPiRL: downstream RL
python run.py --config-name spirl_config env=[ENV] run_prefix=test gpu=0 wandb=true
- SPiRL+Dreamer: downstream RL
python run.py --config-name spirl_dreamer_[ENV] run_prefix=test gpu=0 wandb=true
- SPiRL+TD-MPC: downstream RL
python run.py --config-name spirl_tdmpc_[ENV] run_prefix=test gpu=0 wandb=true
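One detail in the baseline commands is easy to miss: Dreamer, TD-MPC, and SPiRL select the environment with an env=[ENV] override, whereas the SPiRL+Dreamer and SPiRL+TD-MPC config names embed the environment. The sketch below encodes that difference; the BASELINES keys and baseline_command are hypothetical names, and the config names and arguments are copied from the commands above.

```python
# Baseline commands from this README, as (config name, uses env= override).
BASELINES = {
    "dreamer": ("dreamer_config", True),
    "tdmpc": ("tdmpc_config", True),
    "spirl": ("spirl_config", True),
    "spirl_dreamer": ("spirl_dreamer_{env}", False),  # env is in the config name
    "spirl_tdmpc": ("spirl_tdmpc_{env}", False),
}


def baseline_command(name, env, gpu=0, run_prefix="test"):
    """Return the argv list for one baseline run on env."""
    config, uses_env_override = BASELINES[name]
    cmd = ["python", "run.py", "--config-name", config.format(env=env)]
    if uses_env_override:
        cmd.append(f"env={env}")
    cmd += [f"run_prefix={run_prefix}", f"gpu={gpu}", "wandb=true"]
    return cmd


for name in BASELINES:
    print(" ".join(baseline_command(name, "maze")))
```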
- SkiMo+SAC: downstream RL
python run.py --config-name skimo_[ENV] run_prefix=sac gpu=0 wandb=true rolf.phase=rl rolf.use_cem=false rolf.n_skill=1 rolf.prior_reg_critic=true rolf.sac=true rolf.pretrain_ckpt_path=[PRETRAINED_CKPT]
- SkiMo without joint training: pre-training
python run.py --config-name skimo_[ENV] run_prefix=no_joint gpu=0 wandb=true rolf.joint_training=false
- SkiMo without joint training: downstream RL
python run.py --config-name skimo_[ENV] run_prefix=no_joint gpu=0 wandb=true rolf.joint_training=false rolf.phase=rl rolf.pretrain_ckpt_path=[PRETRAINED_CKPT]
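Both ablations are the SkiMo commands plus extra overrides: SkiMo+SAC adds four rolf.* flags to downstream RL only, while rolf.joint_training=false must be passed in both phases. The sketch below composes them; ABLATIONS and ablation_command are hypothetical names, and it assumes run.py treats these key=value overrides independently of their order (the no-joint RL command above places rolf.joint_training=false before rolf.phase=rl).

```python
import shlex

# Extra overrides per ablation and phase, copied from the commands above.
ABLATIONS = {
    "sac": {"rl": ["rolf.use_cem=false", "rolf.n_skill=1",
                   "rolf.prior_reg_critic=true", "rolf.sac=true"]},
    # The no-joint-training flag appears in BOTH pre-training and RL commands.
    "no_joint": {"pretrain": ["rolf.joint_training=false"],
                 "rl": ["rolf.joint_training=false"]},
}


def ablation_command(name, env, phase, ckpt=None, gpu=0):
    """Return the argv list for a SkiMo ablation; run_prefix is the ablation name."""
    extra = ABLATIONS[name].get(phase)
    if extra is None:
        raise ValueError(f"no {phase!r} command for ablation {name!r}")
    cmd = ["python", "run.py", "--config-name", f"skimo_{env}",
           f"run_prefix={name}", f"gpu={gpu}", "wandb=true"]
    if phase == "rl":
        if ckpt is None:
            raise ValueError("downstream RL needs the pre-trained checkpoint path")
        # Assumption: override order is irrelevant because every key is distinct.
        return cmd + ["rolf.phase=rl", *extra, f"rolf.pretrain_ckpt_path={ckpt}"]
    return cmd + extra


print(shlex.join(ablation_command("no_joint", "maze", "pretrain")))
```

SkiMo+SAC has no pre-training entry here because the README only lists a downstream RL command for it, which takes a checkpoint from the regular SkiMo pre-training.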
If installing mpi4py fails while running sh install.sh, install mpi4py with conda instead; this requires a lower Python version:
conda install python==3.8
conda install mpi4py
Then re-run sh install.sh.
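To confirm that the conda workaround took effect in the active environment before re-running install.sh, a small check can help. This is a sketch using only the standard library (importlib.util.find_spec locates a module without importing it); mpi4py_status is a hypothetical helper name.

```python
import importlib.util
import sys


def mpi4py_status():
    """Return (python_version, mpi4py_found) for the current interpreter."""
    version = "{}.{}".format(*sys.version_info[:2])
    found = importlib.util.find_spec("mpi4py") is not None
    return version, found


version, found = mpi4py_status()
print(f"Python {version}, mpi4py {'found' if found else 'NOT found'}")
```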
See this. In my case, I needed to change /usr/local/ to /opt/homebrew/ in all paths.
If you find our code useful for your research, please cite:
@inproceedings{shi2022skimo,
title={Skill-based Model-based Reinforcement Learning},
author={Lucy Xiaoyang Shi and Joseph J. Lim and Youngwoon Lee},
booktitle={Conference on Robot Learning},
year={2022}
}
- This code is based on Youngwoon's robot-learning repo: https://github.com/youngwoon/robot-learning
- SPiRL: https://github.com/clvrai/spirl
- TD-MPC: https://github.com/nicklashansen/tdmpc
- Dreamer: https://github.com/danijar/dreamer
- D4RL: https://github.com/rail-berkeley/d4rl
- CALVIN: https://github.com/mees/calvin