add policy docs (#21)
Trinkle23897 committed Apr 6, 2020
1 parent 610390c commit e0809ff
Showing 20 changed files with 436 additions and 143 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -36,7 +36,7 @@ In Chinese, Tianshou means the innate talent, not taught by others. Tianshou is
Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command:

```bash
pip3 install tianshou
pip3 install tianshou -U
```

You can also install the newest version through GitHub:
18 changes: 9 additions & 9 deletions docs/index.rst
@@ -8,14 +8,14 @@ Welcome to Tianshou!

**Tianshou** (`天授 <https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88>`_) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly APIs, or run slowly, Tianshou provides a fast framework and a pythonic API for building deep reinforcement learning agents. The supported interface algorithms include:

* `Policy Gradient (PG) <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
* `Deep Q-Network (DQN) <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* `Double DQN (DDQN) <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
* `Advantage Actor-Critic (A2C) <https://openai.com/blog/baselines-acktr-a2c/>`_
* `Deep Deterministic Policy Gradient (DDPG) <https://arxiv.org/pdf/1509.02971.pdf>`_
* `Proximal Policy Optimization (PPO) <https://arxiv.org/pdf/1707.06347.pdf>`_
* `Twin Delayed DDPG (TD3) <https://arxiv.org/pdf/1802.09477.pdf>`_
* `Soft Actor-Critic (SAC) <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_


Tianshou also supports parallel workers for all algorithms. All of these algorithms are formulated as replay-buffer based algorithms.
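
To make the policy interface concrete, here is a minimal, hedged sketch that builds a :class:`~tianshou.policy.PGPolicy`; it assumes the small ``Net`` helper from ``test/discrete/net.py`` (modified later in this commit) and a CartPole-sized observation/action space.

```python
# hedged sketch: building a PGPolicy, mirroring test/discrete/test_pg.py from this commit
import torch
from tianshou.policy import PGPolicy
from test.discrete.net import Net  # repo-local helper; assumes the script runs from the repo root

net = Net(layer_num=3, state_shape=(4,), action_shape=2, softmax=True)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = PGPolicy(net, optim, torch.distributions.Categorical, discount_factor=0.99)
```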
@@ -27,7 +27,7 @@ Installation
Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command:
::

pip3 install tianshou
pip3 install tianshou -U

You can also install the newest version through GitHub:
::
5 changes: 3 additions & 2 deletions docs/tutorials/concepts.rst
@@ -25,13 +25,14 @@ Data Buffer

Tianshou provides other types of data buffers, such as :class:`~tianshou.data.ListReplayBuffer` (based on a list) and :class:`~tianshou.data.PrioritizedReplayBuffer` (based on a Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more details.
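
As a quick, hedged illustration of the buffer interface (the keyword set mirrors the ``add`` call used by the Collector later in this commit):

```python
# hedged sketch of basic ReplayBuffer usage
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=20)
for i in range(3):
    # store one transition; the keys mirror those used by the Collector
    buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})
print(len(buf))                            # 3 transitions stored so far
batch, indice = buf.sample(batch_size=2)   # uniformly sample 2 stored transitions
```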

.. _policy_concept:

Policy
------

Tianshou aims to modularize RL algorithms. It provides several classes of policies. All of the policy classes must inherit from :class:`~tianshou.policy.BasePolicy`.

A policy class typically has four parts:

* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including copying the target network and so on;
* :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given observation;
@@ -119,7 +120,7 @@ There will be more types of trainers, for instance, multi-agent trainer.
A High-level Explanation
------------------------

We give a high-level explanation through the pseudocode used in section Policy:
We give a high-level explanation through the pseudocode used in section :ref:`policy_concept`:
::

# pseudocode, cannot work # methods in tianshou
4 changes: 2 additions & 2 deletions docs/tutorials/trick.rst
@@ -68,9 +68,9 @@ Here is about the experience of hyper-parameter tuning on CartPole and Pendulum:
Code-level optimization
-----------------------

Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V_s` and :math:`V_{s'}` with the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of running the network forward twice.
Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V(s)` and :math:`V(s')` with the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of running the network forward twice.
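
A hedged sketch of the trick, with a toy ``critic`` module standing in for the value network:

```python
# hedged sketch of the "one forward pass instead of two" optimization described above
import torch
import torch.nn as nn

critic = nn.Linear(4, 1)                         # toy value network V(.)
s = torch.randn(64, 4)                           # batch of states s
s_next = torch.randn(64, 4)                      # batch of next states s'
v_all = critic(torch.cat([s, s_next], dim=0))    # single forward pass over [s; s']
v_s, v_s_next = v_all.split(s.shape[0], dim=0)   # split the result back into V(s) and V(s')
```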

Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference.
.. Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference.
Finally
5 changes: 4 additions & 1 deletion test/discrete/net.py
@@ -5,7 +5,8 @@


class Net(nn.Module):
def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
softmax=False):
super().__init__()
self.device = device
self.model = [
@@ -15,6 +16,8 @@ def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
if action_shape:
self.model += [nn.Linear(128, np.prod(action_shape))]
if softmax:
self.model += [nn.Softmax(dim=-1)]
self.model = nn.Sequential(*self.model)

def forward(self, s, state=None, info={}):
2 changes: 1 addition & 1 deletion test/discrete/test_pg.py
@@ -113,7 +113,7 @@ def test_pg(args=get_args()):
# model
net = Net(
args.layer_num, args.state_shape, args.action_shape,
device=args.device)
device=args.device, softmax=True)
net = net.to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
dist = torch.distributions.Categorical
7 changes: 4 additions & 3 deletions tianshou/data/batch.py
@@ -104,10 +104,11 @@ def __len__(self):
def split(self, size=None, permute=True):
"""Split whole data into multiple small batch.
:param size: if it is ``None``, it does not split the data batch;
:param int size: if it is ``None``, it does not split the data batch;
otherwise it will divide the data batch with the given size.
:param permute: randomly shuffle the entire data batch if it is
``True``, otherwise remain in the same.
Default to ``None``.
:param bool permute: randomly shuffle the entire data batch if it is
``True``, otherwise keep the original order. Default to ``True``.
"""
length = len(self)
if size is None:
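A hedged usage sketch of ``split()`` as documented above (the field names and sizes are illustrative):

```python
# hedged sketch: iterate over shuffled minibatches produced by Batch.split()
import numpy as np
from tianshou.data import Batch

batch = Batch(obs=np.arange(8), act=np.zeros(8), rew=np.ones(8), done=np.zeros(8))
for minibatch in batch.split(size=3, permute=True):
    print(len(minibatch))   # chunks of at most 3 samples, in shuffled order
```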
44 changes: 32 additions & 12 deletions tianshou/data/collector.py
@@ -10,7 +10,22 @@

class Collector(object):
"""The :class:`~tianshou.data.Collector` enables the policy to interact
with different types of environments conveniently. Here is the usage:
with different types of environments conveniently.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
class.
:param env: an environment or an instance of the
:class:`~tianshou.env.BaseVectorEnv` class.
:param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
``None``, it will automatically assign a small-size
:class:`~tianshou.data.ReplayBuffer`.
:param int stat_size: for the moving average of recording speed, defaults
to 100.
:param bool store_obs_next: whether to store the obs_next to replay
buffer, defaults to ``True``.
Example:
::
policy = PGPolicy(...) # or other policies if you wish
@@ -55,7 +70,8 @@ class Collector(object):
Please make sure the given environment has a time limitation.
"""

def __init__(self, policy, env, buffer=None, stat_size=100):
def __init__(self, policy, env, buffer=None, stat_size=100,
store_obs_next=True, **kwargs):
super().__init__()
self.env = env
self.env_num = 1
@@ -90,6 +106,7 @@ def __init__(self, policy, env, buffer=None, stat_size=100):
self.state = None
self.step_speed = MovAvg(stat_size)
self.episode_speed = MovAvg(stat_size)
self._save_s_ = store_obs_next

def reset_buffer(self):
"""Reset the main data buffer."""
@@ -141,11 +158,12 @@ def _make_batch(self, data):
def collect(self, n_step=0, n_episode=0, render=0):
"""Collect a specified number of step or episode.
:param n_step: an int, indicates how many steps you want to collect.
:param n_episode: an int or a list, indicates how many episodes you
want to collect (in each environment).
:param render: a float, the sleep time between rendering consecutive
frames. ``0`` means no rendering.
:param int n_step: how many steps you want to collect.
:param n_episode: how many episodes you want to collect (in each
environment).
:type n_episode: int or list
:param float render: the sleep time between rendering consecutive
frames. No rendering if it is ``0`` (default option).
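
For illustration, a hedged sketch of the collection modes described by these parameters (it assumes a collector built on four vectorized environments):

```python
# hedged sketch of the collect() modes documented above
collector.collect(n_step=1000)                   # stop after roughly 1000 steps in total
collector.collect(n_episode=9)                   # stop after 9 finished episodes in total
collector.collect(n_episode=[1, 0, 3, 2])        # per-environment episode quota (4 workers)
collector.collect(n_episode=1, render=1 / 35)    # watch one episode at roughly 35 fps
```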
.. note::
@@ -210,7 +228,8 @@ def collect(self, n_step=0, n_episode=0, render=0):
data = {
'obs': self._obs[i], 'act': self._act[i],
'rew': self._rew[i], 'done': self._done[i],
'obs_next': obs_next[i], 'info': self._info[i]}
'obs_next': obs_next[i] if self._save_s_ else None,
'info': self._info[i]}
if self._cached_buf:
warning_count += 1
self._cached_buf[i].add(**data)
@@ -255,7 +274,8 @@ def collect(self, n_step=0, n_episode=0, render=0):
else:
self.buffer.add(
self._obs, self._act[0], self._rew,
self._done, obs_next, self._info)
self._done, obs_next if self._save_s_ else None,
self._info)
cur_step += 1
if self._done:
cur_episode += 1
@@ -296,9 +316,9 @@ def sample(self, batch_size):
:meth:`~tianshou.policy.BasePolicy.process_fn` before returning
the final batch data.
:param batch_size: an int, ``0`` means it will extract all the data
from the buffer, otherwise it will extract the given batch_size of
data.
:param int batch_size: ``0`` means it will extract all the data from
the buffer, otherwise it will extract the data with the given
batch_size.
"""
if self._multi_buf:
if batch_size > 0:
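A hedged two-line sketch of ``sample()`` as documented above, reusing the ``collector`` from the earlier sketch:

```python
# hedged sketch of Collector.sample()
minibatch = collector.sample(batch_size=64)   # 64 random transitions, passed through process_fn
everything = collector.sample(batch_size=0)   # 0 means: return all data currently in the buffer
```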
3 changes: 1 addition & 2 deletions tianshou/env/vecenv.py
@@ -60,8 +60,7 @@ def step(self, action):
Accept a batch of actions and return a tuple (obs, rew, done, info).
:param action: a numpy.ndarray, a batch of action provided by the
agent.
:param numpy.ndarray action: a batch of actions provided by the agent.
:return: A tuple including four items:
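A hedged sketch of the batched ``step()`` call, assuming CartPole and the ``VectorEnv`` wrapper with four workers:

```python
# hedged sketch: stepping a vectorized environment with a batch of actions
import gym
import numpy as np
from tianshou.env import VectorEnv

envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
obs = envs.reset()                                         # stacked observations, one row per worker
obs, rew, done, info = envs.step(np.array([0, 1, 0, 1]))   # one action per worker, four stacked results
envs.close()
```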
39 changes: 34 additions & 5 deletions tianshou/policy/a2c.py
@@ -7,26 +7,55 @@


class A2CPolicy(PGPolicy):
"""docstring for A2CPolicy"""
"""Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.nn.Module critic: the critic network. (s -> V(s))
:param torch.optim.Optimizer optim: the optimizer for actor and critic
network.
:param torch.distributions.Distribution dist_fn: for computing the action,
defaults to ``torch.distributions.Categorical``.
:param float discount_factor: in [0, 1], defaults to 0.99.
:param float vf_coef: weight for value loss, defaults to 0.5.
:param float ent_coef: weight for entropy loss, defaults to 0.01.
:param float max_grad_norm: clipping gradients in back propagation,
defaults to ``None``.
"""

def __init__(self, actor, critic, optim,
dist_fn=torch.distributions.Categorical,
discount_factor=0.99, vf_coef=.5, ent_coef=.01,
max_grad_norm=None):
max_grad_norm=None, **kwargs):
super().__init__(None, optim, dist_fn, discount_factor)
self.actor = actor
self.critic = critic
self._w_vf = vf_coef
self._w_ent = ent_coef
self._grad_norm = max_grad_norm

def __call__(self, batch, state=None):
def __call__(self, batch, state=None, **kwargs):
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
* ``act`` the action.
* ``logits`` the network's raw output.
* ``dist`` the action distribution.
* ``state`` the hidden state.
More information can be found at
:meth:`~tianshou.policy.BasePolicy.__call__`.
"""
logits, h = self.actor(batch.obs, state=state, info=batch.info)
dist = self.dist_fn(logits)
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
dist = self.dist_fn(logits)
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)

def learn(self, batch, batch_size=None, repeat=1):
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
for _ in range(repeat):
for b in batch.split(batch_size):
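A hedged construction sketch matching the parameter list above; ``MyActor`` and ``MyCritic`` are toy placeholder networks that follow the stated s -> logits and s -> V(s) conventions.

```python
# hedged sketch: constructing the A2CPolicy documented above
import torch
import torch.nn as nn
from tianshou.policy import A2CPolicy

class MyActor(nn.Module):
    """Toy actor: returns logits and the (unchanged) hidden state."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
    def forward(self, s, state=None, info={}):
        return self.fc(torch.as_tensor(s, dtype=torch.float32)), state

class MyCritic(nn.Module):
    """Toy critic: returns V(s)."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)
    def forward(self, s, **kwargs):
        return self.fc(torch.as_tensor(s, dtype=torch.float32))

actor, critic = MyActor(), MyCritic()
optim = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()), lr=3e-4)
policy = A2CPolicy(actor, critic, optim, dist_fn=torch.distributions.Categorical,
                   discount_factor=0.99, vf_coef=0.5, ent_coef=0.01, max_grad_norm=0.5)
```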
67 changes: 59 additions & 8 deletions tianshou/policy/base.py
@@ -3,23 +3,74 @@


class BasePolicy(ABC, nn.Module):
"""docstring for BasePolicy"""
"""Tianshou aims to modularizing RL algorithms. It comes into several
classes of policies in Tianshou. All of the policy classes must inherit
:class:`~tianshou.policy.BasePolicy`.
def __init__(self):
A policy class typically has four parts:
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \
including coping the target network and so on;
* :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given \
observation;
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \
the replay buffer (this function can interact with replay buffer);
* :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given \
batch of data.
Most policies need a neural network to predict the action and an
optimizer to optimize the policy. The rules for self-defined networks are:
1. Input: observation ``obs`` (may be a ``numpy.ndarray`` or \
``torch.Tensor``), hidden state ``state`` (for RNN usage), and other \
information ``info`` provided by the environment.
2. Output: some ``logits`` and the next hidden state ``state``. The logits \
could be a tuple instead of a ``torch.Tensor``. It depends on how the \
policy processes the network output. For example, in PPO, the return of \
the network might be ``(mu, sigma), state`` for a Gaussian policy.
Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``,
you can operate :class:`~tianshou.policy.BasePolicy` almost the same as
``torch.nn.Module``, for instance, load and save the model:
::
torch.save(policy.state_dict(), 'policy.pth')
policy.load_state_dict(torch.load('policy.pth'))
"""

def __init__(self, **kwargs):
super().__init__()

def process_fn(self, batch, buffer, indice):
"""Pre-process the data from the provided replay buffer. Check out
:ref:`policy_concept` for more information.
"""
return batch

@abstractmethod
def __call__(self, batch, state=None):
# return Batch(logits=..., act=..., state=None, ...)
def __call__(self, batch, state=None, **kwargs):
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which MUST have the following\
keys:
* ``act`` a numpy.ndarray or a torch.Tensor, the action over the \
given batch data.
* ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \
internal state of the policy, ``None`` as default.
Other keys are user-defined. It depends on the algorithm. For example,
::
# some code
return Batch(logits=..., act=..., state=None, dist=...)
"""
pass

@abstractmethod
def learn(self, batch, batch_size=None):
# return a dict which includes loss and its name
pass
def learn(self, batch, **kwargs):
"""Update policy with a given batch of data.

:return: A dict which includes loss and its corresponding label.
"""
pass

def sync_weight(self):