Skip to content

Commit

Permalink
Merge branch 'VM/torch_compile' of https://github.com/Denys88/rl_games
Browse files Browse the repository at this point in the history
…into VM/torch_compile
  • Loading branch information
ViktorM committed Dec 16, 2024
2 parents 3009d02 + 5fe2dc1 commit 6819a1d
Show file tree
Hide file tree
Showing 11 changed files with 1,156 additions and 575 deletions.
52 changes: 35 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,22 @@ Explore RL Games quick and easily in colab notebooks:

For maximum training performance a preliminary installation of Pytorch 2.2 or newer with CUDA 12.1 or newer is highly recommended:

```conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia``` or:
```pip install pip3 install torch torchvision```
```bash
pip3 install torch torchvision
```

Then:

```pip install rl-games```
```bash
pip install rl-games
```

Or clone the repo and install the latest version from source :
```bash
pip install -e .
```

To run CPU-based environments either Ray or envpool are required ```pip install envpool``` or ```pip install ray```
To run CPU-based environments either envpool if supported or Ray are required ```pip install envpool``` or ```pip install ray```
To run Mujoco, Atari games or Box2d based environments training they need to be additionally installed with ```pip install gym[mujoco]```, ```pip install gym[atari]``` or ```pip install gym[box2d]``` respectively.

To run Atari also ```pip install opencv-python``` is required. In addition installation of envpool for maximum simulation and training perfromance of Mujoco and Atari environments is highly recommended: ```pip install envpool```
Expand Down Expand Up @@ -114,13 +122,17 @@ And IsaacGymEnvs: https://github.com/NVIDIA-Omniverse/IsaacGymEnvs

*Ant*

```python train.py task=Ant headless=True```
```python train.py task=Ant test=True checkpoint=nn/Ant.pth num_envs=100```
```bash
python train.py task=Ant headless=True
python train.py task=Ant test=True checkpoint=nn/Ant.pth num_envs=100
```

*Humanoid*

```python train.py task=Humanoid headless=True```
```python train.py task=Humanoid test=True checkpoint=nn/Humanoid.pth num_envs=100```
```bash
python train.py task=Humanoid headless=True
python train.py task=Humanoid test=True checkpoint=nn/Humanoid.pth num_envs=100
```

*Shadow Hand block orientation task*

Expand All @@ -131,6 +143,13 @@ And IsaacGymEnvs: https://github.com/NVIDIA-Omniverse/IsaacGymEnvs

*Atari Pong*

```bash
python runner.py --train --file rl_games/configs/atari/ppo_pong.yaml
python runner.py --play --file rl_games/configs/atari/ppo_pong.yaml --checkpoint nn/PongNoFrameskip.pth
```

Or with poetry:

```bash
poetry install -E atari
poetry run python runner.py --train --file rl_games/configs/atari/ppo_pong.yaml
Expand All @@ -140,22 +159,21 @@ poetry run python runner.py --play --file rl_games/configs/atari/ppo_pong.yaml -
*Brax Ant*

```bash
poetry install -E brax
poetry run pip install --upgrade "jax[cuda]==0.3.13" -f https://storage.googleapis.com/jax-releases/jax_releases.html
poetry run python runner.py --train --file rl_games/configs/brax/ppo_ant.yaml
poetry run python runner.py --play --file rl_games/configs/brax/ppo_ant.yaml --checkpoint runs/Ant_brax/nn/Ant_brax.pth
pip install -U "jax[cuda12]"
pip install brax
python runner.py --train --file rl_games/configs/brax/ppo_ant.yaml
python runner.py --play --file rl_games/configs/brax/ppo_ant.yaml --checkpoint runs/Ant_brax/nn/Ant_brax.pth
```

## Experiment tracking

rl_games support experiment tracking with [Weights and Biases](https://wandb.ai).

```bash
poetry install -E atari
poetry run python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --track
WANDB_API_KEY=xxxx poetry run python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --track
poetry run python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --wandb-project-name rl-games-special-test --track
poetry run python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --wandb-project-name rl-games-special-test -wandb-entity openrlbenchmark --track
python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --track
WANDB_API_KEY=xxxx python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --track
python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --wandb-project-name rl-games-special-test --track
python runner.py --train --file rl_games/configs/atari/ppo_breakout_torch.yaml --wandb-project-name rl-games-special-test -wandb-entity openrlbenchmark --track
```


Expand Down
1,637 changes: 1,099 additions & 538 deletions notebooks/train_and_export_onnx_example_lstm_continuous.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion rl_games/algos_torch/a2c_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def restore_central_value_function(self, fn):
self.set_central_value_function_weights(checkpoint)

def get_masked_action_values(self, obs, action_masks):
assert False
raise NotImplementedError("Masked action values are not implemented for continuous actions")

def calc_gradients(self, input_dict):
"""Compute gradients needed to step the networks of the algorithm.
Expand Down
3 changes: 1 addition & 2 deletions rl_games/algos_torch/moving_mean_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import rl_games.algos_torch.torch_ext as torch_ext


'''
updates moving statistics with momentum
'''
Expand Down Expand Up @@ -76,7 +77,6 @@ def _get_stats(self):
else:
raise NotImplementedError(self.impl)


def _update_stats(self, x):
m = self.decay
if self.impl == 'off':
Expand Down Expand Up @@ -108,7 +108,6 @@ def forward(self, input, mask=None, denorm=False):
self._update_stats(input)

offset, invscale = self._get_stats()

if denorm:
y = input * invscale + offset
else:
Expand Down
8 changes: 5 additions & 3 deletions rl_games/algos_torch/sac_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ def set_full_state_weights(self, weights, set_epoch=True):
self.vec_env.set_env_state(env_state)

def restore(self, fn, set_epoch=True):
print("SAC restore")
if not os.path.exists(fn):
raise FileNotFoundError(f"Checkpoint file not found: {fn}")
checkpoint = torch_ext.load_checkpoint(fn)
self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

Expand All @@ -268,7 +269,7 @@ def set_param(self, param_name, param_value):
pass

def get_masked_action_values(self, obs, action_masks):
assert False
raise NotImplementedError("Masked action values are not supported in SAC agent")

def set_eval(self):
self.model.eval()
Expand Down Expand Up @@ -425,7 +426,8 @@ def act(self, obs, action_dim, sample=False):

actions = dist.sample() if sample else dist.mean
actions = actions.clamp(*self.action_range)
assert actions.ndim == 2
if actions.ndim != 2:
raise ValueError(f"Actions tensor must be 2-dimensional, got shape {actions.shape}")

return actions

Expand Down
3 changes: 2 additions & 1 deletion rl_games/common/a2c_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ def init_tensors(self):

total_agents = self.num_agents * self.num_actors
num_seqs = self.horizon_length // self.seq_length
assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0)
if not ((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0):
raise ValueError(f"Horizon length ({self.horizon_length}) times total agents ({total_agents}) divided by num minibatches ({self.num_minibatches}) must be divisible by sequence length ({self.seq_length})")
self.mb_rnn_states = [torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]

def init_current_rewards(self, batch_size, current_rewards_shape):
Expand Down
7 changes: 2 additions & 5 deletions rl_games/common/diagnostics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import rl_games.algos_torch.torch_ext as torch_ext


class DefaultDiagnostics(object):
def __init__(self):
pass
Expand All @@ -24,7 +25,7 @@ def __init__(self):
def send_info(self, writter):
if writter is None:
return
for k,v in self.diag_dict.items():
for k, v in self.diag_dict.items():
writter.add_scalar(k, v.cpu().numpy(), self.current_epoch)

def epoch(self, agent, current_epoch):
Expand Down Expand Up @@ -58,7 +59,3 @@ def mini_batch(self, agent, batch, e_clip, minibatch):
clip_frac = torch_ext.policy_clip_fraction(new_neglogp, old_neglogp, e_clip, masks)
self.exp_vars.append(exp_var)
self.clip_fracs.append(clip_frac)




2 changes: 1 addition & 1 deletion rl_games/common/experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def update_data(self, name, index, val):

def update_data_rnn(self, name, indices, play_mask, val):
if type(val) is dict:
for k,v in val:
for k, v in val:
self.tensor_dict[name][k][indices, play_mask] = v
else:
self.tensor_dict[name][indices, play_mask] = val
Expand Down
2 changes: 1 addition & 1 deletion rl_games/common/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def maybe_load_new_checkpoint(self):
load_error = False
try:
torch.load(self.checkpoint_to_load)
except Exception as e:
except (OSError, IOError, torch.TorchError) as e:
print(f"Evaluation: checkpoint file is likely corrupted {self.checkpoint_to_load}: {e}")
load_error = True

Expand Down
10 changes: 6 additions & 4 deletions rl_games/envs/test/test_asymmetric_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
import numpy as np
from rl_games.common.wrappers import MaskVelocityWrapper


class TestAsymmetricCritic(gym.Env):
def __init__(self, wrapped_env_name, **kwargs):
gym.Env.__init__(self)
self.apply_mask = kwargs.pop('apply_mask', True)
self.use_central_value = kwargs.pop('use_central_value', True)
self.env = gym.make(wrapped_env_name)

if self.apply_mask:
if wrapped_env_name not in ["CartPole-v1", "Pendulum-v0", "LunarLander-v2", "LunarLanderContinuous-v2"]:
raise 'unsupported env'
supported_envs = ["CartPole-v1", "Pendulum-v0", "LunarLander-v2", "LunarLanderContinuous-v2"]
if wrapped_env_name not in supported_envs:
raise ValueError(f"Environment {wrapped_env_name} not supported. Supported environments: {supported_envs}")
self.mask = MaskVelocityWrapper(self.env, wrapped_env_name).mask
else:
self.mask = 1
Expand Down Expand Up @@ -47,6 +49,6 @@ def step(self, actions):
else:
obses = obs_dict["obs"].astype(np.float32)
return obses, rewards, dones, info

def has_action_mask(self):
return False
5 changes: 3 additions & 2 deletions rl_games/torch_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
def _restore(agent, args):
if 'checkpoint' in args and args['checkpoint'] is not None and args['checkpoint'] !='':
if args['train'] and args.get('load_critic_only', False):
assert agent.has_central_value, 'This should only work for asymmetric actor critic'
if not getattr(agent, 'has_central_value', False):
raise ValueError('Loading critic only works only for asymmetric actor critic')
agent.restore_central_value_function(args['checkpoint'])
return
agent.restore(args['checkpoint'])
Expand All @@ -31,7 +32,7 @@ def _override_sigma(agent, args):
with torch.no_grad():
net.sigma.fill_(float(args['sigma']))
else:
print('Print cannot set new sigma because fixed_sigma is False')
print('Cannot set new sigma because fixed_sigma is False')


class Runner:
Expand Down

0 comments on commit 6819a1d

Please sign in to comment.