Skip to content

Commit

Permalink
Fixes setting the device from CLI in the RL training scripts (#1013)
Browse files Browse the repository at this point in the history
This pull request fixes the issue where the device (`CPU` or `CUDA`) is
not set correctly when using the `--device` argument in Hydra-configured
scripts like `rsl_rl/train.py` and `skrl/train.py`. The bug caused the
scripts to always default to `cuda:0`, even when `cpu` or a specific
CUDA device (e.g., `cuda:1`) was selected.

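The fix adds the following line to ensure that the selected device is
properly set in `env_cfg` before initializing the environment with
`gym.make()`, falling back to the configured device when `--device` is
not given:

```python
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
```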

Fixes #1012

- Bug fix (non-breaking change which fixes an issue)

Before:
- `skrl/train.py`: when running the script with `--device cpu`, it
defaults to `cuda:0`.
- `rsl_rl/train.py`: the script freezes at `[INFO]: Starting the
simulation. This may take a few seconds. Please wait....`

After:
- Both scripts run correctly on the specified device (e.g., cpu or
cuda:1) without defaulting to cuda:0 or freezing.

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
  • Loading branch information
amrmousa144 authored and Mayankm96 committed Sep 24, 2024
1 parent 59fd1f7 commit 90b1150
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 7 deletions.
3 changes: 2 additions & 1 deletion CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ Guidelines for modifications:

## Contributors

* Anton Bjørndahl Mortensen
* Alice Zhou
* Amr Mousa
* Andrej Orsula
* Anton Bjørndahl Mortensen
* Antonio Serrano-Muñoz
* Arjun Bhardwaj
* Brayden Zhang
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ def __init__(self, cfg: DirectMARLEnvCfg, render_mode: str | None = None, **kwar
# initialize internal variables
self._is_closed = False

# set the seed for the environment
if self.cfg.seed is not None:
self.seed(self.cfg.seed)
else:
carb.log_warn("Seed not set for the environment. The environment creation may not be deterministic.")

# create a simulation context to control the simulator
if SimulationContext.instance() is None:
self.sim: SimulationContext = SimulationContext(self.cfg.sim)
Expand All @@ -88,6 +94,7 @@ def __init__(self, cfg: DirectMARLEnvCfg, render_mode: str | None = None, **kwar
# print useful information
print("[INFO]: Base environment:")
print(f"\tEnvironment device : {self.device}")
print(f"\tEnvironment seed : {self.cfg.seed}")
print(f"\tPhysics step-size : {self.physics_dt}")
print(f"\tRendering step-size : {self.physics_dt * self.cfg.sim.render_interval}")
print(f"\tEnvironment step-size : {self.step_dt}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ class DirectMARLEnvCfg:
"""

# general settings
seed: int | None = None
"""The seed for the random number generator. Defaults to None, in which case the seed is not set.
Note:
The seed is set at the beginning of the environment initialization. This ensures that the environment
creation is deterministic and behaves similarly across different runs.
"""

decimation: int = MISSING
"""Number of control action updates @ sim dt per policy dt.
Expand Down
2 changes: 2 additions & 0 deletions source/standalone/workflows/rl_games/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
"""Train with RL-Games agent."""
# override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device

agent_cfg["params"]["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["params"]["seed"]
agent_cfg["params"]["config"]["max_epochs"] = (
args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg["params"]["config"]["max_epochs"]
Expand Down
1 change: 1 addition & 0 deletions source/standalone/workflows/rsl_rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here
env_cfg.seed = agent_cfg.seed
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device

# specify directory for logging experiments
log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
Expand Down
1 change: 1 addition & 0 deletions source/standalone/workflows/sb3/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here
env_cfg.seed = agent_cfg["seed"]
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device

# directory for logging into
log_dir = os.path.join("logs", "sb3", args_cli.task, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
Expand Down
9 changes: 3 additions & 6 deletions source/standalone/workflows/skrl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
"""Train with skrl agent."""
# override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device

# multi-gpu training config
if args_cli.distributed:
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
Expand All @@ -118,7 +120,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"

# set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here
# note: certain randomization occur in the environment initialization so we set the seed here
env_cfg.seed = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"]

# specify directory for logging experiments
Expand All @@ -135,11 +137,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# update log_dir
log_dir = os.path.join(log_root_path, log_dir)

# multi-gpu training config
if args_cli.distributed:
# update env config device
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"

# dump the configuration into log-directory
dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg)
dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), agent_cfg)
Expand Down

0 comments on commit 90b1150

Please sign in to comment.