add docs of collector and trainer (#20)
Trinkle23897 committed Apr 5, 2020
1 parent 4d4d0da commit 610390c
Showing 13 changed files with 263 additions and 111 deletions.
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -49,8 +49,8 @@ If no error occurs, you have successfully installed Tianshou.

tutorials/dqn
tutorials/concepts
tutorials/trick
tutorials/tabular
tutorials/trick

.. toctree::
:maxdepth: 1
55 changes: 32 additions & 23 deletions docs/tutorials/concepts.rst
@@ -23,7 +23,7 @@ Data Buffer
:members:
:noindex:

Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out the :doc:`/api/tianshou.data` API documentation for more detail.
Tianshou provides other types of data buffers, such as :class:`~tianshou.data.ListReplayBuffer` (based on list) and :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
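
As a quick sketch (using only the ``ReplayBuffer`` methods documented in this commit; the variable names are for illustration), a buffer can be filled step by step and then sampled from:
::

    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=20)
    for i in range(3):
        # one transition (s_t, a_t, r_t, d_t, s_{t+1}) per call
        buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
    # sample() returns the sampled data and its indices inside the buffer;
    # batch_size=0 would return everything currently stored
    batch, indice = buf.sample(batch_size=2)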


Policy
@@ -85,51 +85,60 @@ Thus, we need a time-related interface for calculating the 2-step return. :meth:

This code does not consider the done flag, so it may not work correctly at episode boundaries. It shows two ways to easily get :math:`s_{t + 2}` from the replay buffer in :meth:`~tianshou.policy.BasePolicy.process_fn`.
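
For illustration, here is a sketch of what such a :meth:`~tianshou.policy.BasePolicy.process_fn` could look like. It is not the exact code from the documentation; ``self._gamma`` (the discount factor) and ``target_q`` (the bootstrap value) are hypothetical names:
::

    def process_fn(self, batch, buffer, indice):
        buffer_len = len(buffer)
        # way 1: index the buffer, which returns a Batch (batch_2.obs is s_{t+2})
        batch_2 = buffer[(indice + 2) % buffer_len]
        # way 2: read the stored array directly (the same observations)
        obs_2 = buffer.obs[(indice + 2) % buffer_len]
        # 2-step return, still ignoring the done flag:
        #   G_t = r_t + gamma * r_{t+1} + gamma^2 * max_a Q(s_{t+2}, a)
        batch.returns = batch.rew \
            + self._gamma * buffer.rew[(indice + 1) % buffer_len] \
            + self._gamma ** 2 * target_q(obs_2)
        return batch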

For other method, you can check out the API documentation for more detail. We give a high-level explanation through the same pseudocode:
::

    # pseudocode, cannot work                          # methods in tianshou
    s = env.reset()
    buffer = Buffer(size=10000)                        # buffer = tianshou.data.ReplayBuffer(size=10000)
    agent = DQN()                                      # done in policy.__init__(...)
    for i in range(int(1e6)):                          # done in trainer
        a = agent.compute_action(s)                    # done in policy.__call__(batch, ...)
        s_, r, d, _ = env.step(a)                      # done in collector.collect(...)
        buffer.store(s, a, s_, r, d)                   # done in collector.collect(...)
        s = s_                                         # done in collector.collect(...)
        if i % 1000 == 0:                              # done in trainer
            b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)
                                                       # done in collector.sample(batch_size)
            # compute 2-step returns. How?
            b_ret = compute_2_step_return(buffer, b_r, b_d, ...)
                                                       # done in policy.process_fn(batch, buffer, indice)
            # update DQN policy
            agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret)
                                                       # done in policy.learn(batch, ...)
For other methods, you can check out :doc:`/api/tianshou.policy`. We give a high-level explanation of how to use the policy class in :ref:`pseudocode`.


Collector
---------

The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.
The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.
In short, :class:`~tianshou.data.Collector` has two main methods:

* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of steps ``n_step`` or episodes ``n_episode`` and store the data in the replay buffer;
* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of steps (``n_step``) or episodes (``n_episode``) and store the data in the replay buffer;
* :meth:`~tianshou.data.Collector.sample`: sample a data batch from replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data.

Why do we say **at least** here? For a single environment, the collector finishes exactly ``n_step`` steps or ``n_episode`` episodes. However, for multiple environments, we cannot directly store the collected data into the replay buffer, since doing so would break the principle of storing data chronologically.

The solution is to add some cache buffers inside the collector. Once **a full episode of trajectory** has been collected, the stored data is moved from the cache buffer to the main buffer. To satisfy this condition, the collector may interact with the environments for more than the given number of steps or episodes.
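
A short sketch (assuming the usual import locations and an already constructed ``policy``, as in the :class:`~tianshou.data.Collector` docstring added by this commit) shows why the returned count can exceed the request:
::

    import gym
    from tianshou.env import VectorEnv
    from tianshou.data import Collector, ReplayBuffer

    envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
    collector = Collector(policy, envs, buffer=ReplayBuffer(size=10000))
    result = collector.collect(n_step=100)
    # with 3 environments stepping in parallel, result['n/st'] may exceed 100,
    # because each cache buffer is only flushed into the main buffer once its
    # episode has finished
    print(result['n/st'], result['n/ep'])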

The general explanation is listed in the pseudocode above. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation.
The general explanation is given in :ref:`pseudocode`. Other usages of the collector are listed in the :class:`~tianshou.data.Collector` documentation.


Trainer
-------

Once you have a collector and a policy, you can start writing the training method for your RL agent. The trainer is, in essence, a simple wrapper: it saves you the effort of writing the training loop yourself. You can also construct your own trainer: :ref:`customized_trainer`.

Tianshou has two types of trainer: :meth:`~tianshou.trainer.onpolicy_trainer` and :meth:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out the API documentation for the usage.
Tianshou has two types of trainers: :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for usage details.

There will be more types of trainers in the future, for instance a multi-agent trainer.
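
As an illustration only, a typical off-policy training call might look roughly like the sketch below; the argument names follow the parameter list documented in the DQN tutorial, while ``train_collector``, ``test_collector`` and ``policy.set_eps`` are assumed to be set up as in that tutorial:
::

    from tianshou.trainer import offpolicy_trainer

    result = offpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=10, step_per_epoch=1000, collect_per_step=10,
        episode_per_test=100, batch_size=64,
        train_fn=lambda e: policy.set_eps(0.1),
        test_fn=lambda e: policy.set_eps(0.05))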


.. _pseudocode:

A High-level Explanation
------------------------

We give a high-level explanation through the pseudocode used in the Policy section:
::

    # pseudocode, cannot work                          # methods in tianshou
    s = env.reset()
    buffer = Buffer(size=10000)                        # buffer = tianshou.data.ReplayBuffer(size=10000)
    agent = DQN()                                      # done in policy.__init__(...)
    for i in range(int(1e6)):                          # done in trainer
        a = agent.compute_action(s)                    # done in policy.__call__(batch, ...)
        s_, r, d, _ = env.step(a)                      # done in collector.collect(...)
        buffer.store(s, a, s_, r, d)                   # done in collector.collect(...)
        s = s_                                         # done in collector.collect(...)
        if i % 1000 == 0:                              # done in trainer
            b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)
                                                       # done in collector.sample(batch_size)
            # compute 2-step returns. How?
            b_ret = compute_2_step_return(buffer, b_r, b_d, ...)
                                                       # done in policy.process_fn(batch, buffer, indice)
            # update DQN policy
            agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret)
                                                       # done in policy.learn(batch, ...)
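
The pseudocode above maps onto real Tianshou calls roughly as follows (a sketch, assuming ``policy`` has already been constructed, e.g. as in the DQN tutorial; the trainer normally wraps this loop for you):
::

    import gym
    from tianshou.data import Collector, ReplayBuffer

    env = gym.make('CartPole-v0')
    buf = ReplayBuffer(size=10000)
    collector = Collector(policy, env, buffer=buf)
    for i in range(1000):
        collector.collect(n_step=10)             # env.step + buffer.store + s = s_
        batch = collector.sample(batch_size=64)  # buffer.get + policy.process_fn
        policy.learn(batch)                      # agent.update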


Conclusion
----------

2 changes: 1 addition & 1 deletion docs/tutorials/dqn.rst
@@ -122,7 +122,7 @@ The meaning of each parameter is as follows:
* ``max_epoch``: The maximum number of epochs for training. The training process might finish before reaching ``max_epoch``;
* ``step_per_epoch``: The number of steps for updating the policy network in one epoch;
* ``collect_per_step``: The number of frames the collector collects before each network update. For example, the code above means "collect 10 frames and do one policy network update";
* ``episode_per_test``: The number of episode for one policy evaluation.
* ``episode_per_test``: The number of episodes for one policy evaluation.
* ``batch_size``: The batch size of the sample data, which is going to be fed into the policy network.
* ``train_fn``: A function that receives the current epoch index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
* ``test_fn``: A function that receives the current epoch index and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
24 changes: 11 additions & 13 deletions tianshou/data/batch.py
@@ -3,8 +3,7 @@


class Batch(object):
"""
Tianshou provides :class:`~tianshou.data.Batch` as the internal data
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
structure to pass any kind of data to other methods, for example, a
    collector gives a :class:`~tianshou.data.Batch` to the policy for learning.
Here is the usage:
@@ -25,12 +24,12 @@ class Batch(object):
    current implementation of Tianshou typically uses 6 keys in
:class:`~tianshou.data.Batch`:
* ``obs``: the observation of step :math:`t` ;
* ``act``: the action of step :math:`t` ;
* ``rew``: the reward of step :math:`t` ;
* ``done``: the done flag of step :math:`t` ;
* ``obs_next``: the observation of step :math:`t+1` ;
* ``info``: the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
* ``obs`` the observation of step :math:`t` ;
* ``act`` the action of step :math:`t` ;
* ``rew`` the reward of step :math:`t` ;
* ``done`` the done flag of step :math:`t` ;
* ``obs_next`` the observation of step :math:`t+1` ;
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\
function return 4 arguments, and the last one is ``info``);
:class:`~tianshou.data.Batch` has other methods, including
@@ -75,7 +74,7 @@ def __getitem__(self, index):
return b

def append(self, batch):
"""Append a :class:`~tianshou.data.Batch` object to the end."""
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
for k in batch.__dict__.keys():
if batch.__dict__[k] is None:
@@ -103,12 +102,11 @@ def __len__(self):
if self.__dict__[k] is not None])

def split(self, size=None, permute=True):
"""
Split whole data into multiple small batch.
"""Split whole data into multiple small batch.
:param size: if equals to ``None``, it does not split the data batch; \
:param size: if it is ``None``, it does not split the data batch;
otherwise it will divide the data batch with the given size.
:param permute: randomly shuffle the entire data batch if it equals to\
:param permute: randomly shuffle the entire data batch if it is
            ``True``, otherwise keep the original order.
"""
length = len(self)
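
To make the collapsed usage example above concrete, here is a small sketch of :class:`~tianshou.data.Batch` that only relies on the methods visible in this diff; the keyword-argument constructor is an assumption based on the "Here is the usage" docstring:
::

    import numpy as np
    from tianshou.data import Batch

    batch = Batch(obs=np.zeros((4, 3)), act=np.zeros(4), rew=np.arange(4))
    len(batch)    # 4, the common length of the stored arrays
    batch[0]      # indexing returns a Batch holding only step 0
    for minibatch in batch.split(size=2, permute=True):
        pass      # two shuffled mini-batches of size 2
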
24 changes: 13 additions & 11 deletions tianshou/data/buffer.py
@@ -3,17 +3,18 @@


class ReplayBuffer(object):
"""
:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction
between the policy and environment. It stores basically 6 types of data, as
mentioned in :class:`~tianshou.data.Batch`, based on ``numpy.ndarray``.
Here is the usage:
""":class:`~tianshou.data.ReplayBuffer` stores data generated from
interaction between the policy and environment. It stores basically 6 types
of data, as mentioned in :class:`~tianshou.data.Batch`, based on
``numpy.ndarray``. Here is the usage:
::
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf)
3
>>> buf.obs
# since we set size = 20, len(buf.obs) == 20.
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
@@ -22,11 +23,13 @@ class ReplayBuffer(object):
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
>>> len(buf2)
10
>>> buf2.obs
# since its size = 10, it only stores the last 10 steps' result.
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
>>> # move buf2's result into buf (keep it chronologically meanwhile)
    >>> # move buf2's result into buf (while keeping the chronological order)
>>> buf.update(buf2)
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
0., 0., 0., 0., 0., 0., 0.])
@@ -96,8 +99,8 @@ def reset(self):
self.indice = []

def sample(self, batch_size):
"""
Get a random sample from buffer with size = ``batch_size``
"""Get a random sample from buffer with size equal to batch_size. \
Return all the data in the buffer if batch_size is ``0``.
:return: Sample data and its corresponding index inside the buffer.
"""
@@ -123,9 +126,8 @@ def __getitem__(self, index):


class ListReplayBuffer(ReplayBuffer):
"""
The function of :class:`~tianshou.data.ListReplayBuffer` is almost the same
as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
"""The function of :class:`~tianshou.data.ListReplayBuffer` is almost the
same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that
:class:`~tianshou.data.ListReplayBuffer` is based on ``list``.
"""

86 changes: 84 additions & 2 deletions tianshou/data/collector.py
@@ -9,7 +9,51 @@


class Collector(object):
"""docstring for Collector"""
"""The :class:`~tianshou.data.Collector` enables the policy to interact
with different types of environments conveniently. Here is the usage:
::
policy = PGPolicy(...) # or other policies if you wish
env = gym.make('CartPole-v0')
replay_buffer = ReplayBuffer(size=10000)
# here we set up a collector with a single environment
collector = Collector(policy, env, buffer=replay_buffer)
# the collector supports vectorized environments as well
envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
buffers = [ReplayBuffer(size=5000) for _ in range(3)]
        # you can also pass a list of replay buffers to the collector, for multi-env
# collector = Collector(policy, envs, buffer=buffers)
collector = Collector(policy, envs, buffer=replay_buffer)
# collect at least 3 episodes
collector.collect(n_episode=3)
# collect 1 episode for the first env, 3 for the third env
collector.collect(n_episode=[1, 0, 3])
# collect at least 2 steps
collector.collect(n_step=2)
# collect episodes with visual rendering (the render argument is the
# sleep time between rendering consecutive frames)
collector.collect(n_episode=1, render=0.03)
# sample data with a given number of batch-size:
batch_data = collector.sample(batch_size=64)
# policy.learn(batch_data) # btw, vanilla policy gradient only
# supports on-policy training, so here we pick all data in the buffer
batch_data = collector.sample(batch_size=0)
policy.learn(batch_data)
# on-policy algorithms use the collected data only once, so here we
# clear the buffer
collector.reset_buffer()
    For the scenario of collecting data from multiple environments into a single
    buffer, the cache buffers are turned on automatically. As a result, the
    collected data may exceed the given number of steps or episodes.
.. note::
        Please make sure the given environment has a time limit, i.e. each episode will eventually terminate.
"""

def __init__(self, policy, env, buffer=None, stat_size=100):
super().__init__()
@@ -48,16 +92,21 @@ def __init__(self, policy, env, buffer=None, stat_size=100):
self.episode_speed = MovAvg(stat_size)

def reset_buffer(self):
"""Reset the main data buffer."""
if self._multi_buf:
for b in self.buffer:
b.reset()
else:
self.buffer.reset()

def get_env_num(self):
"""Return the number of environments the collector has."""
return self.env_num

def reset_env(self):
"""Reset all of the environment(s)' states and reset all of the cache
        buffers (if needed).
"""
self._obs = self.env.reset()
self._act = self._rew = self._done = self._info = None
if self._multi_env:
@@ -69,14 +118,17 @@ def reset_env(self):
b.reset()

def seed(self, seed=None):
"""Reset all the seed(s) of the given environment(s)."""
if hasattr(self.env, 'seed'):
return self.env.seed(seed)

def render(self, **kwargs):
"""Render all the environment(s)."""
if hasattr(self.env, 'render'):
return self.env.render(**kwargs)

def close(self):
"""Close the environment(s)."""
if hasattr(self.env, 'close'):
self.env.close()

Expand All @@ -87,12 +139,34 @@ def _make_batch(self, data):
return np.array([data])

def collect(self, n_step=0, n_episode=0, render=0):
"""Collect a specified number of step or episode.
:param n_step: an int, indicates how many steps you want to collect.
:param n_episode: an int or a list, indicates how many episodes you
want to collect (in each environment).
:param render: a float, the sleep time between rendering consecutive
frames. ``0`` means no rendering.
.. note::
One and only one collection number specification is permitted,
either ``n_step`` or ``n_episode``.
:return: A dict including the following keys
* ``n/ep`` the collected number of episodes.
* ``n/st`` the collected number of steps.
* ``v/st`` the speed of steps per second.
        * ``v/ep`` the speed of episodes per second.
* ``rew`` the mean reward over collected episodes.
* ``len`` the mean length over collected episodes.
"""
warning_count = 0
if not self._multi_env:
n_episode = np.sum(n_episode)
start_time = time.time()
assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
"One and only one collection number specification permitted!"
"One and only one collection number specification is permitted!"
cur_step = 0
cur_episode = np.zeros(self.env_num) if self._multi_env else 0
reward_sum = 0
@@ -218,6 +292,14 @@ def collect(self, n_step=0, n_episode=0, render=0):
}

def sample(self, batch_size):
"""Sample a data batch from the internal replay buffer. It will call
:meth:`~tianshou.policy.BasePolicy.process_fn` before returning
the final batch data.
:param batch_size: an int, ``0`` means it will extract all the data
from the buffer, otherwise it will extract the given batch_size of
data.
"""
if self._multi_buf:
if batch_size > 0:
lens = [len(b) for b in self.buffer]
