code refactor for venv (#179)
- Refactor code to remove duplicated code

- Enable async simulation for all vector envs

- Remove `collector.close` and rename `VectorEnv` to `DummyVectorEnv`

The vector env abstraction has changed.

Prior to this PR, each vector env implementation was almost independent.

After this PR, each env is wrapped into a worker, and vector envs differ only in their worker type. In fact, users can simply use `BaseVectorEnv` with different workers; I keep `SubprocVectorEnv` and `ShmemVectorEnv` for backward compatibility.
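
For illustration, here is a minimal sketch of how the new abstraction can be used directly; the `worker_fn` factory argument and the `SubprocEnvWorker` class name are assumptions made for this sketch, and `SubprocVectorEnv` remains the convenient entry point:

```python
import gym

from tianshou.env import BaseVectorEnv, SubprocVectorEnv
from tianshou.env.worker import SubprocEnvWorker  # assumed worker class name

env_fns = [lambda: gym.make("CartPole-v0") for _ in range(4)]

# Convenient subclass: picks the subprocess-based worker for you.
venv = SubprocVectorEnv(env_fns)

# Assumed low-level equivalent: BaseVectorEnv plus an explicit worker factory.
venv = BaseVectorEnv(env_fns, lambda fn: SubprocEnvWorker(fn))
```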

Co-authored-by: n+e <[email protected]>
Co-authored-by: magicly <[email protected]>
3 people authored Aug 19, 2020
1 parent 311a2be commit a9f9940
Showing 61 changed files with 1,141 additions and 987 deletions.
9 changes: 4 additions & 5 deletions README.md
@@ -34,7 +34,7 @@
Here are Tianshou's other features:

- Elegant framework, using only ~2000 lines of code
- Support parallel environment sampling for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
- Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
@@ -152,7 +152,7 @@ Within this API, we can interact with different policies conveniently.

### Elegant and Flexible

Currently, the overall code of Tianshou platform is less than 1500 lines without environment wrappers for Atari and Mujoco. Most of the implemented algorithms are less than 100 lines of python code. It is quite easy to go through the framework and understand how it works. We provide many flexible API as you wish, for instance, if you want to use your policy to interact with the environment with (at least) `n` steps:
Currently, the overall code of the Tianshou platform is less than 2500 lines. Most of the implemented algorithms are less than 100 lines of Python code. It is quite easy to go through the framework and understand how it works. We also provide flexible APIs; for instance, if you want to use your policy to interact with the environment for (at least) `n` steps:

```python
result = collector.collect(n_step=n)
@@ -201,8 +201,8 @@ Make environments:

```python
# you can also try with SubprocVectorEnv
train_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(test_num)])
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
```
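
If environment stepping is expensive, the same list of factory functions can be passed to the subprocess-based variant instead (a sketch; for CartPole the `DummyVectorEnv` above is already fast enough):

```python
# Same factory-function list, executed in worker processes instead of a for-loop.
train_envs = ts.env.SubprocVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.SubprocVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
```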

Define the network:
@@ -249,7 +249,6 @@ Watch the performance with 35 FPS:
```python
collector = ts.data.Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)
collector.close()
```

Look at the result saved in tensorboard: (with bash script in your terminal)
Binary file added docs/_static/images/async.png
5 changes: 5 additions & 0 deletions docs/api/tianshou.env.rst
@@ -5,3 +5,8 @@ tianshou.env
:members:
:undoc-members:
:show-inheritance:

.. automodule:: tianshou.env.worker
:members:
:undoc-members:
:show-inheritance:
54 changes: 41 additions & 13 deletions docs/tutorials/cheatsheet.rst
@@ -31,22 +31,50 @@ See :ref:`customized_trainer`.
Parallel Sampling
-----------------

Use :class:`~tianshou.env.VectorEnv`, :class:`~tianshou.env.SubprocVectorEnv` or :class:`~tianshou.env.ShmemVectorEnv`.
::
Tianshou provides the following classes for parallel environment simulation:

- :class:`~tianshou.env.DummyVectorEnv` is for pseudo-parallel simulation (implemented with a for-loop, useful for debugging).

- :class:`~tianshou.env.SubprocVectorEnv` uses multiple processes for parallel simulation. This is the most common choice for parallel simulation.

- :class:`~tianshou.env.ShmemVectorEnv` has a similar implementation to :class:`~tianshou.env.SubprocVectorEnv`, but is optimized (in terms of both memory footprint and simulation speed) for environments with large observations such as images.

env_fns = [
lambda: MyTestEnv(size=2),
lambda: MyTestEnv(size=3),
lambda: MyTestEnv(size=4),
lambda: MyTestEnv(size=5),
]
venv = SubprocVectorEnv(env_fns)
- :class:`~tianshou.env.RayVectorEnv` is currently the only choice for parallel simulation in a cluster with multiple machines.

Although these classes are optimized for different scenarios, they share exactly the same API because they are all subclasses of :class:`~tianshou.env.BaseVectorEnv`. Just provide a list of functions that return environments when called, and you are all set.

where ``env_fns`` is a list of callable env hooker. The above code can be written in for-loop as well:
::

env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]]
venv = SubprocVectorEnv(env_fns)
venv = SubprocVectorEnv(env_fns) # DummyVectorEnv, ShmemVectorEnv, or RayVectorEnv, whichever you like.
venv.reset() # returns the initial observations of each environment
venv.step(actions) # provide actions for each environment and get their results

.. sidebar:: An example of sync/async VectorEnv (steps with the same color end up in one batch that the policy processes at the same time).

.. figure:: ../_static/images/async.png

By default, parallel environment simulation is synchronous: a vectorized step finishes only after every environment has finished its own step. Synchronous simulation works well when each environment step takes roughly the same amount of time.

When the time cost of environment steps varies a lot (e.g. 90% of steps take 1s but 10% take 10s), slow environments hold the fast ones back. In that case asynchronous simulation can be used (related to `Issue 103 <https://github.com/thu-ml/tianshou/issues/103>`_): the idea is to keep stepping the environments that have finished, without waiting for the slow ones.

Asynchronous simulation is a built-in functionality of :class:`~tianshou.env.BaseVectorEnv`. Just provide ``wait_num`` or ``timeout`` (or both) and async simulation works.

::

env_fns = [lambda x=i: MyTestEnv(size=x, sleep=x) for i in [2, 3, 4, 5]]
# DummyVectorEnv, ShmemVectorEnv, or RayVectorEnv, whichever you like.
venv = SubprocVectorEnv(env_fns, wait_num=3, timeout=0.2)
venv.reset() # returns the initial observations of each environment
# returns ``wait_num`` steps or finished steps after ``timeout`` seconds,
# whichever occurs first.
venv.step(actions, ready_id)

If we have 4 envs and set ``wait_num = 3``, each step returns results from only 3 of these 4 envs.

You can treat the ``timeout`` parameter as a dynamic ``wait_num``: each vectorized step returns only the environments that finished within the given time. If no environment has finished, it waits until any of them does.
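
A rough sketch of a manual asynchronous stepping loop is shown below, continuing with the ``venv`` created above. It assumes that each returned ``info`` dict carries an ``env_id`` entry identifying its environment, and ``policy_for`` is a hypothetical helper that produces one action per ready env; a higher-level collector would normally do this bookkeeping for you.

::

    venv.reset()
    ready_id = [0, 1, 2, 3]             # after reset, every env is ready
    for _ in range(10):
        act = policy_for(ready_id)      # hypothetical: one action per ready env
        obs, rew, done, info = venv.step(act, ready_id)
        # only envs that finished within wait_num/timeout come back; read their
        # ids (assumed ``env_id`` entry in info) to know which envs to feed next
        ready_id = [i['env_id'] for i in info]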

The figure on the right gives an intuitive comparison between synchronous and asynchronous simulation.

.. warning::

@@ -139,9 +167,9 @@ First of all, your self-defined environment must follow the Gym's API, some of t

- step(action) -> state, reward, done, info

- seed(s) -> None
- seed(s) -> List[int]

- render(mode) -> None
- render(mode) -> Any

- close() -> None
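
For reference, a bare-bones environment satisfying this interface might look like the sketch below (illustrative only, not part of Tianshou):

::

    import gym
    import numpy as np


    class MyMinimalEnv(gym.Env):
        """A tiny environment implementing the interface listed above."""

        def __init__(self):
            self.observation_space = gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32)
            self.action_space = gym.spaces.Discrete(2)
            self._count = 0

        def reset(self):
            self._count = 0
            return np.zeros(4, dtype=np.float32)    # -> state

        def step(self, action):
            self._count += 1
            obs = np.random.uniform(-1, 1, 4).astype(np.float32)
            return obs, 1.0, self._count >= 10, {}  # -> state, reward, done, info

        def seed(self, s=None):
            np.random.seed(s)
            return [s]                              # -> List[int]

        def render(self, mode='human'):
            return None                             # -> Any

        def close(self):
            pass                                    # -> None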

7 changes: 3 additions & 4 deletions docs/tutorials/dqn.rst
@@ -30,11 +30,11 @@ It is available if you want the original ``gym.Env``:
train_envs = gym.make('CartPole-v0')
test_envs = gym.make('CartPole-v0')

Tianshou supports parallel sampling for all algorithms. It provides four types of vectorized environment wrapper: :class:`~tianshou.env.VectorEnv`, :class:`~tianshou.env.SubprocVectorEnv`, :class:`~tianshou.env.ShmemVectorEnv`, and :class:`~tianshou.env.RayVectorEnv`. It can be used as follows:
Tianshou supports parallel sampling for all algorithms. It provides four types of vectorized environment wrappers: :class:`~tianshou.env.DummyVectorEnv`, :class:`~tianshou.env.SubprocVectorEnv`, :class:`~tianshou.env.ShmemVectorEnv`, and :class:`~tianshou.env.RayVectorEnv`. They can be used as follows (more explanation can be found at :ref:`parallel_sampling`):
::

train_envs = ts.env.VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
test_envs = ts.env.VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])
train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])

Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_envs``.

@@ -178,7 +178,6 @@ Watch the Agent's Performance

collector = ts.data.Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)
collector.close()

.. _customized_trainer:

14 changes: 4 additions & 10 deletions docs/tutorials/tictactoe.rst
@@ -158,7 +158,6 @@ Tianshou already provides some builtin classes for multi-agent learning. You can
===x _ o x _ _===
===x _ _ _ x x===
=================
>>> collector.close()

Random agents perform badly. In the above game, although agent 2 wins in the end, it is clear that a smart agent 1 would have placed an ``x`` at row 4 col 4 to win directly.

@@ -175,7 +174,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter

from tianshou.env import VectorEnv
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
@@ -220,8 +219,7 @@ The explanation of each Tianshou class/function will be deferred to their first
help='the path of opponent agent pth file for resuming from a pre-trained agent')
parser.add_argument('--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args
return parser.parse_args()

.. sidebar:: The relationship between MultiAgentPolicyManager (Manager) and BasePolicy (Agent)

@@ -290,15 +288,14 @@ With the above preparation, we are close to the first learned agent. The followi
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()
if args.watch:
watch(args)
exit(0)

# ======== environment setup =========
env_func = lambda: TicTacToeEnv(args.board_size, args.win_size)
train_envs = VectorEnv([env_func for _ in range(args.training_num)])
test_envs = VectorEnv([env_func for _ in range(args.test_num)])
train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)])
test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -351,9 +348,6 @@ With the above preparation, we are close to the first learned agent. The followi
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
test_in_train=False)

train_collector.close()
test_collector.close()

agent = policy.policies[args.agent_id - 1]
# let's watch the match!
watch(args, agent)
File renamed without changes.
6 changes: 1 addition & 5 deletions examples/pong_a2c.py → examples/atari/pong_a2c.py
@@ -40,8 +40,7 @@ def get_args():
parser.add_argument('--ent-coef', type=float, default=0.001)
parser.add_argument('--max-grad-norm', type=float, default=None)
parser.add_argument('--max_episode_steps', type=int, default=2000)
args = parser.parse_known_args()[0]
return args
return parser.parse_args()


def test_a2c(args=get_args()):
@@ -90,16 +89,13 @@ def stop_fn(x):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
train_collector.close()
test_collector.close()
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = create_atari_environment(args.task)
collector = Collector(policy, env, preprocess_fn=preprocess_fn)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()


if __name__ == '__main__':
6 changes: 1 addition & 5 deletions examples/pong_dqn.py → examples/atari/pong_dqn.py
@@ -36,8 +36,7 @@ def get_args():
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args
return parser.parse_args()


def test_dqn(args=get_args()):
@@ -96,16 +95,13 @@ def test_fn(x):
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, writer=writer)

train_collector.close()
test_collector.close()
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = create_atari_environment(args.task)
collector = Collector(policy, env, preprocess_fn=preprocess_fn)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()


if __name__ == '__main__':
6 changes: 1 addition & 5 deletions examples/pong_ppo.py → examples/atari/pong_ppo.py
@@ -40,8 +40,7 @@ def get_args():
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--max_episode_steps', type=int, default=2000)
args = parser.parse_known_args()[0]
return args
return parser.parse_args()


def test_ppo(args=get_args()):
@@ -94,16 +93,13 @@ def stop_fn(x):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
train_collector.close()
test_collector.close()
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = create_atari_environment(args.task)
collector = Collector(policy, env, preprocess_fn=preprocess_fn)
result = collector.collect(n_step=2000, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()


if __name__ == '__main__':
@@ -6,7 +6,7 @@
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.env import VectorEnv
from tianshou.env import DummyVectorEnv
from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
@@ -36,8 +36,7 @@ def get_args():
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args
return parser.parse_args()


def test_dqn(args=get_args()):
@@ -46,10 +45,10 @@ def test_dqn(args=get_args()):
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = VectorEnv(
train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = VectorEnv(
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
@@ -100,16 +99,13 @@ def test_fn(x):
stop_fn=stop_fn, save_fn=save_fn, writer=writer)

assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()


if __name__ == '__main__':
@@ -39,8 +39,7 @@ def get_args():
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args
return parser.parse_args()


class EnvWrapper(object):
@@ -136,15 +135,13 @@ def save_fn(policy):
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=IsStop, save_fn=save_fn, writer=writer)

test_collector.close()
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = EnvWrapper(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=16, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()


if __name__ == '__main__':