diff --git a/README.md b/README.md index d7710ae9f..fe4f2b91e 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ In Chinese, Tianshou means the innate talent, not taught by others. Tianshou is Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command: ```bash -pip3 install tianshou +pip3 install tianshou -U ``` You can also install with the newest version through GitHub: diff --git a/docs/index.rst b/docs/index.rst index 1e983ab95..f7cc3dfb4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -8,14 +8,14 @@ Welcome to Tianshou! **Tianshou** (`天授 `_) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed framework and pythonic API for building the deep reinforcement learning agent. The supported interface algorithms include: -* `Policy Gradient (PG) `_ -* `Deep Q-Network (DQN) `_ -* `Double DQN (DDQN) `_ with n-step returns -* `Advantage Actor-Critic (A2C) `_ -* `Deep Deterministic Policy Gradient (DDPG) `_ -* `Proximal Policy Optimization (PPO) `_ -* `Twin Delayed DDPG (TD3) `_ -* `Soft Actor-Critic (SAC) `_ +* :class:`~tianshou.policy.PGPolicy` `Policy Gradient `_ +* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network `_ +* :class:`~tianshou.policy.DQNPolicy` `Double DQN `_ with n-step returns +* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic `_ +* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient `_ +* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization `_ +* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG `_ +* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic `_ Tianshou supports parallel workers for all algorithms as well. All of these algorithms are reformatted as replay-buffer based algorithms. @@ -27,7 +27,7 @@ Installation Tianshou is currently hosted on `PyPI `_. You can simply install Tianshou with the following command: :: - pip3 install tianshou + pip3 install tianshou -U You can also install with the newest version through GitHub: :: diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 40a7bd194..f5dca97ae 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -25,13 +25,14 @@ Data Buffer Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail. +.. _policy_concept: Policy ------ Tianshou aims to modularizing RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`. -A policy class typically has four parts: +A policy class typically has four parts: * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including coping the target network and so on; * :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given observation; @@ -119,7 +120,7 @@ There will be more types of trainers, for instance, multi-agent trainer. 
A High-level Explanation ------------------------ -We give a high-level explanation through the pseudocode used in section Policy: +We give a high-level explanation through the pseudocode used in section :ref:`policy_concept`: :: # pseudocode, cannot work # methods in tianshou diff --git a/docs/tutorials/trick.rst b/docs/tutorials/trick.rst index 6cfd4b6eb..6b7d599ce 100644 --- a/docs/tutorials/trick.rst +++ b/docs/tutorials/trick.rst @@ -68,9 +68,9 @@ Here is about the experience of hyper-parameter tuning on CartPole and Pendulum: Code-level optimization ----------------------- -Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V_s` and :math:`V_{s'}` by the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of computing the value function using twice of network forward. +Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V(s)` and :math:`V(s')` by the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of computing the value function using twice of network forward. -Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference. +.. Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference. Finally diff --git a/test/discrete/net.py b/test/discrete/net.py index f272e0954..d64c32f9c 100644 --- a/test/discrete/net.py +++ b/test/discrete/net.py @@ -5,7 +5,8 @@ class Net(nn.Module): - def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'): + def __init__(self, layer_num, state_shape, action_shape=0, device='cpu', + softmax=False): super().__init__() self.device = device self.model = [ @@ -15,6 +16,8 @@ def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'): self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] if action_shape: self.model += [nn.Linear(128, np.prod(action_shape))] + if softmax: + self.model += [nn.Softmax(dim=-1)] self.model = nn.Sequential(*self.model) def forward(self, s, state=None, info={}): diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index e0754f52a..5bcdeab75 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -113,7 +113,7 @@ def test_pg(args=get_args()): # model net = Net( args.layer_num, args.state_shape, args.action_shape, - device=args.device) + device=args.device, softmax=True) net = net.to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) dist = torch.distributions.Categorical diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index d6a95f1fa..301de9aec 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -104,10 +104,11 @@ def __len__(self): def split(self, size=None, permute=True): """Split whole data into multiple small batch. - :param size: if it is ``None``, it does not split the data batch; + :param int size: if it is ``None``, it does not split the data batch; otherwise it will divide the data batch with the given size. - :param permute: randomly shuffle the entire data batch if it is - ``True``, otherwise remain in the same. + Default to ``None``. + :param bool permute: randomly shuffle the entire data batch if it is + ``True``, otherwise remain in the same. Default to ``True``. 
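As a side note on the ``split`` parameters documented above, here is a minimal usage sketch; the field names and shapes are made up for illustration.

```python
import numpy as np
from tianshou.data import Batch

batch = Batch(obs=np.arange(10).reshape(10, 1), rew=np.ones(10))
for minibatch in batch.split(size=4, permute=True):
    # yields shuffled mini-batches of at most 4 transitions each
    print(len(minibatch), minibatch.obs.shape)
```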
""" length = len(self) if size is None: diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 6607b43e8..3572d6c25 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -10,7 +10,22 @@ class Collector(object): """The :class:`~tianshou.data.Collector` enables the policy to interact - with different types of environments conveniently. Here is the usage: + with different types of environments conveniently. + + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` + class. + :param env: an environment or an instance of the + :class:`~tianshou.env.BaseVectorEnv` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` + class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to + ``None``, it will automatically assign a small-size + :class:`~tianshou.data.ReplayBuffer`. + :param int stat_size: for the moving average of recording speed, defaults + to 100. + :param bool store_obs_next: whether to store the obs_next to replay + buffer, defaults to ``True``. + + Example: :: policy = PGPolicy(...) # or other policies if you wish @@ -55,7 +70,8 @@ class Collector(object): Please make sure the given environment has a time limitation. """ - def __init__(self, policy, env, buffer=None, stat_size=100): + def __init__(self, policy, env, buffer=None, stat_size=100, + store_obs_next=True, **kwargs): super().__init__() self.env = env self.env_num = 1 @@ -90,6 +106,7 @@ def __init__(self, policy, env, buffer=None, stat_size=100): self.state = None self.step_speed = MovAvg(stat_size) self.episode_speed = MovAvg(stat_size) + self._save_s_ = store_obs_next def reset_buffer(self): """Reset the main data buffer.""" @@ -141,11 +158,12 @@ def _make_batch(self, data): def collect(self, n_step=0, n_episode=0, render=0): """Collect a specified number of step or episode. - :param n_step: an int, indicates how many steps you want to collect. - :param n_episode: an int or a list, indicates how many episodes you - want to collect (in each environment). - :param render: a float, the sleep time between rendering consecutive - frames. ``0`` means no rendering. + :param int n_step: how many steps you want to collect. + :param n_episode: how many episodes you want to collect (in each + environment). + :type n_episode: int or list + :param float render: the sleep time between rendering consecutive + frames. No rendering if it is ``0`` (default option). .. note:: @@ -210,7 +228,8 @@ def collect(self, n_step=0, n_episode=0, render=0): data = { 'obs': self._obs[i], 'act': self._act[i], 'rew': self._rew[i], 'done': self._done[i], - 'obs_next': obs_next[i], 'info': self._info[i]} + 'obs_next': obs_next[i] if self._save_s_ else None, + 'info': self._info[i]} if self._cached_buf: warning_count += 1 self._cached_buf[i].add(**data) @@ -255,7 +274,8 @@ def collect(self, n_step=0, n_episode=0, render=0): else: self.buffer.add( self._obs, self._act[0], self._rew, - self._done, obs_next, self._info) + self._done, obs_next if self._save_s_ else None, + self._info) cur_step += 1 if self._done: cur_episode += 1 @@ -296,9 +316,9 @@ def sample(self, batch_size): :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data. - :param batch_size: an int, ``0`` means it will extract all the data - from the buffer, otherwise it will extract the given batch_size of - data. + :param int batch_size: ``0`` means it will extract all the data from + the buffer, otherwise it will extract the data with the given + batch_size. 
""" if self._multi_buf: if batch_size > 0: diff --git a/tianshou/env/vecenv.py b/tianshou/env/vecenv.py index ab43459ff..79abbc38a 100644 --- a/tianshou/env/vecenv.py +++ b/tianshou/env/vecenv.py @@ -60,8 +60,7 @@ def step(self, action): Accept a batch of action and return a tuple (obs, rew, done, info). - :param action: a numpy.ndarray, a batch of action provided by the - agent. + :param numpy.ndarray action: a batch of action provided by the agent. :return: A tuple including four items: diff --git a/tianshou/policy/a2c.py b/tianshou/policy/a2c.py index 6a0db238b..52bd54394 100644 --- a/tianshou/policy/a2c.py +++ b/tianshou/policy/a2c.py @@ -7,12 +7,26 @@ class A2CPolicy(PGPolicy): - """docstring for A2CPolicy""" + """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783 + + :param torch.nn.Module actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.nn.Module critic: the critic network. (s -> V(s)) + :param torch.optim.Optimizer optim: the optimizer for actor and critic + network. + :param torch.distributions.Distribution dist_fn: for computing the action, + defaults to ``torch.distributions.Categorical``. + :param float discount_factor: in [0, 1], defaults to 0.99. + :param float vf_coef: weight for value loss, defaults to 0.5. + :param float ent_coef: weight for entropy loss, defaults to 0.01. + :param float max_grad_norm: clipping gradients in back propagation, + defaults to ``None``. + """ def __init__(self, actor, critic, optim, dist_fn=torch.distributions.Categorical, discount_factor=0.99, vf_coef=.5, ent_coef=.01, - max_grad_norm=None): + max_grad_norm=None, **kwargs): super().__init__(None, optim, dist_fn, discount_factor) self.actor = actor self.critic = critic @@ -20,13 +34,28 @@ def __init__(self, actor, critic, optim, self._w_ent = ent_coef self._grad_norm = max_grad_norm - def __call__(self, batch, state=None): + def __call__(self, batch, state=None, **kwargs): + """Compute action over the given batch data. + + :return: A :class:`~tianshou.data.Batch` which has 4 keys: + + * ``act`` the action. + * ``logits`` the network's raw output. + * ``dist`` the action distribution. + * ``state`` the hidden state. + + More information can be found at + :meth:`~tianshou.policy.BasePolicy.__call__`. + """ logits, h = self.actor(batch.obs, state=state, info=batch.info) - dist = self.dist_fn(logits) + if isinstance(logits, tuple): + dist = self.dist_fn(*logits) + else: + dist = self.dist_fn(logits) act = dist.sample() return Batch(logits=logits, act=act, state=h, dist=dist) - def learn(self, batch, batch_size=None, repeat=1): + def learn(self, batch, batch_size=None, repeat=1, **kwargs): losses, actor_losses, vf_losses, ent_losses = [], [], [], [] for _ in range(repeat): for b in batch.split(batch_size): diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 487d6c8c1..ef34655a5 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -3,23 +3,74 @@ class BasePolicy(ABC, nn.Module): - """docstring for BasePolicy""" + """Tianshou aims to modularizing RL algorithms. It comes into several + classes of policies in Tianshou. All of the policy classes must inherit + :class:`~tianshou.policy.BasePolicy`. 
- def __init__(self): + A policy class typically has four parts: + + * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \ + including coping the target network and so on; + * :meth:`~tianshou.policy.BasePolicy.__call__`: compute action with given \ + observation; + * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \ + the replay buffer (this function can interact with replay buffer); + * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given \ + batch of data. + + Most of the policy needs a neural network to predict the action and an + optimizer to optimize the policy. The rules of self-defined networks are: + + 1. Input: observation ``obs`` (may be a ``numpy.ndarray`` or \ + ``torch.Tensor``), hidden state ``state`` (for RNN usage), and other \ + information ``info`` provided by the environment. + 2. Output: some ``logits`` and the next hidden state ``state``. The logits\ + could be a tuple instead of a ``torch.Tensor``. It depends on how the \ + policy process the network output. For example, in PPO, the return of \ + the network might be ``(mu, sigma), state`` for Gaussian policy. + + Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``, + you can operate :class:`~tianshou.policy.BasePolicy` almost the same as + ``torch.nn.Module``, for instance, load and save the model: + :: + + torch.save(policy.state_dict(), 'policy.pth') + policy.load_state_dict(torch.load('policy.pth')) + """ + + def __init__(self, **kwargs): super().__init__() def process_fn(self, batch, buffer, indice): + """Pre-process the data from the provided replay buffer. Check out + :ref:`policy_concept` for more information. + """ return batch @abstractmethod - def __call__(self, batch, state=None): - # return Batch(logits=..., act=..., state=None, ...) + def __call__(self, batch, state=None, **kwargs): + """Compute action over the given batch data. + + :return: A :class:`~tianshou.data.Batch` which MUST have the following\ + keys: + + * ``act`` an numpy.ndarray or a torch.Tensor, the action over \ + given batch data. + * ``state`` a dict, an numpy.ndarray or a torch.Tensor, the \ + internal state of the policy, ``None`` as default. + + Other keys are user-defined. It depends on the algorithm. For example, + :: + + # some code + return Batch(logits=..., act=..., state=None, dist=...) + """ pass @abstractmethod - def learn(self, batch, batch_size=None): - # return a dict which includes loss and its name - pass + def learn(self, batch, **kwargs): + """Update policy with a given batch of data. - def sync_weight(self): + :return: A dict which includes loss and its corresponding label. + """ pass diff --git a/tianshou/policy/ddpg.py b/tianshou/policy/ddpg.py index 68093b55a..a7a95785c 100644 --- a/tianshou/policy/ddpg.py +++ b/tianshou/policy/ddpg.py @@ -5,18 +5,35 @@ from tianshou.data import Batch from tianshou.policy import BasePolicy - - # from tianshou.exploration import OUNoise class DDPGPolicy(BasePolicy): - """docstring for DDPGPolicy""" + """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971 + + :param torch.nn.Module actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.optim.Optimizer actor_optim: the optimizer for actor network. + :param torch.nn.Module critic: the critic network. (s, a -> Q(s, a)) + :param torch.optim.Optimizer critic_optim: the optimizer for critic + network. 
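For the self-defined network rules listed in the ``BasePolicy`` docstring above, a minimal conforming module might look like the sketch below; the class name and layer sizes are hypothetical (compare ``test/discrete/net.py`` earlier in this diff).

```python
import torch
import torch.nn as nn

class MyActor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 128), nn.ReLU(inplace=True),
            nn.Linear(128, action_dim))

    def forward(self, obs, state=None, info={}):
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        # assumes a leading batch dimension: (batch, ...)
        logits = self.model(obs.reshape(obs.shape[0], -1))
        return logits, state   # state is only meaningful for recurrent policies
```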
+ :param float tau: param for soft update of the target network, defaults to + 0.005. + :param float gamma: discount factor, in [0, 1], defaults to 0.99. + :param float exploration_noise: the noise intensity, add to the action, + defaults to 0.1. + :param action_range: the action range (minimum, maximum). + :type action_range: [float, float] + :param bool reward_normalization: normalize the reward to Normal(0, 1), + defaults to ``False``. + :param bool ignore_done: ignore the done flag while training the policy, + defaults to ``False``. + """ def __init__(self, actor, actor_optim, critic, critic_optim, tau=0.005, gamma=0.99, exploration_noise=0.1, action_range=None, reward_normalization=False, - ignore_done=False): + ignore_done=False, **kwargs): super().__init__() if actor is not None: self.actor, self.actor_old = actor, deepcopy(actor) @@ -26,9 +43,9 @@ def __init__(self, actor, actor_optim, critic, critic_optim, self.critic, self.critic_old = critic, deepcopy(critic) self.critic_old.eval() self.critic_optim = critic_optim - assert 0 < tau <= 1, 'tau should in (0, 1]' + assert 0 <= tau <= 1, 'tau should in [0, 1]' self._tau = tau - assert 0 < gamma <= 1, 'gamma should in (0, 1]' + assert 0 <= gamma <= 1, 'gamma should in [0, 1]' self._gamma = gamma assert 0 <= exploration_noise, 'noise should not be negative' self._eps = exploration_noise @@ -43,19 +60,23 @@ def __init__(self, actor, actor_optim, critic, critic_optim, self.__eps = np.finfo(np.float32).eps.item() def set_eps(self, eps): + """Set the eps for exploration.""" self._eps = eps def train(self): + """Set the module in training mode, except for the target network.""" self.training = True self.actor.train() self.critic.train() def eval(self): + """Set the module in evaluation mode, except for the target network.""" self.training = False self.actor.eval() self.critic.eval() def sync_weight(self): + """Soft-update the weight for the target network.""" for o, n in zip(self.actor_old.parameters(), self.actor.parameters()): o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) for o, n in zip( @@ -73,7 +94,19 @@ def process_fn(self, batch, buffer, indice): return batch def __call__(self, batch, state=None, - model='actor', input='obs', eps=None): + model='actor', input='obs', eps=None, **kwargs): + """Compute action over the given batch data. + + :param float eps: in [0, 1], for exploration use. + + :return: A :class:`~tianshou.data.Batch` which has 2 keys: + + * ``act`` the action. + * ``state`` the hidden state. + + More information can be found at + :meth:`~tianshou.policy.BasePolicy.__call__`. + """ model = getattr(self, model) obs = getattr(batch, input) logits, h = model(obs, state=state, info=batch.info) @@ -89,7 +122,7 @@ def __call__(self, batch, state=None, logits = logits.clamp(self._range[0], self._range[1]) return Batch(act=logits, state=h) - def learn(self, batch, batch_size=None, repeat=1): + def learn(self, batch, **kwargs): with torch.no_grad(): target_q = self.critic_old(batch.obs_next, self( batch, model='actor_old', input='obs_next', eps=0).act) diff --git a/tianshou/policy/dqn.py b/tianshou/policy/dqn.py index 25bf20299..e8980a984 100644 --- a/tianshou/policy/dqn.py +++ b/tianshou/policy/dqn.py @@ -8,11 +8,20 @@ class DQNPolicy(BasePolicy): - """docstring for DQNPolicy""" + """Implementation of Deep Q Network. arXiv:1312.5602 + + :param torch.nn.Module model: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. 
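The soft update performed by ``DDPGPolicy.sync_weight`` above is the usual Polyak averaging; a standalone sketch with hypothetical networks:

```python
import torch.nn as nn

net, net_old, tau = nn.Linear(4, 2), nn.Linear(4, 2), 0.005
for o, n in zip(net_old.parameters(), net.parameters()):
    # theta_old <- (1 - tau) * theta_old + tau * theta
    o.data.copy_(o.data * (1 - tau) + n.data * tau)
```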
(s -> logits) + :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. + :param float discount_factor: in [0, 1]. + :param int estimation_step: greater than 1, the number of steps to look + ahead. + :param int target_update_freq: the target network update frequency (``0`` + if you do not use the target network). + """ def __init__(self, model, optim, discount_factor=0.99, - estimation_step=1, use_target_network=True, - target_update_freq=300): + estimation_step=1, target_update_freq=0, **kwargs): super().__init__() self.model = model self.optim = optim @@ -21,28 +30,44 @@ def __init__(self, model, optim, discount_factor=0.99, self._gamma = discount_factor assert estimation_step > 0, 'estimation_step should greater than 0' self._n_step = estimation_step - self._target = use_target_network + self._target = target_update_freq > 0 self._freq = target_update_freq self._cnt = 0 - if use_target_network: + if self._target: self.model_old = deepcopy(self.model) self.model_old.eval() def set_eps(self, eps): + """Set the eps for epsilon-greedy exploration.""" self.eps = eps def train(self): + """Set the module in training mode, except for the target network.""" self.training = True self.model.train() def eval(self): + """Set the module in evaluation mode, except for the target network.""" self.training = False self.model.eval() def sync_weight(self): + """Synchronize the weight for the target network.""" self.model_old.load_state_dict(self.model.state_dict()) def process_fn(self, batch, buffer, indice): + r"""Compute the n-step return for Q-learning targets: + + .. math:: + G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i + + \gamma^n (1 - d_{t + n}) \max_a Q_{old}(s_{t + n}, \arg\max_a + (Q_{new}(s_{t + n}, a))) + + , where :math:`\gamma` is the discount factor, + :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step + :math:`t`. If there is no target network, the :math:`Q_{old}` is equal + to :math:`Q_{new}`. + """ returns = np.zeros_like(indice) gammas = np.zeros_like(indice) + self._n_step for n in range(self._n_step - 1, -1, -1): @@ -70,7 +95,20 @@ def process_fn(self, batch, buffer, indice): return batch def __call__(self, batch, state=None, - model='model', input='obs', eps=None): + model='model', input='obs', eps=None, **kwargs): + """Compute action over the given batch data. + + :param float eps: in [0, 1], for epsilon-greedy exploration method. + + :return: A :class:`~tianshou.data.Batch` which has 3 keys: + + * ``act`` the action. + * ``logits`` the network's raw output. + * ``state`` the hidden state. + + More information can be found at + :meth:`~tianshou.policy.BasePolicy.__call__`. + """ model = getattr(self, model) obs = getattr(batch, input) q, h = model(obs, state=state, info=batch.info) @@ -83,7 +121,7 @@ def __call__(self, batch, state=None, act[i] = np.random.randint(q.shape[1]) return Batch(logits=q, act=act, state=h) - def learn(self, batch, batch_size=None, repeat=1): + def learn(self, batch, **kwargs): if self._target and self._cnt % self._freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/pg.py b/tianshou/policy/pg.py index 941eaa4e3..ca3173865 100644 --- a/tianshou/policy/pg.py +++ b/tianshou/policy/pg.py @@ -1,37 +1,65 @@ import torch import numpy as np -import torch.nn.functional as F from tianshou.data import Batch from tianshou.policy import BasePolicy class PGPolicy(BasePolicy): - """docstring for PGPolicy""" + """Implementation of Vanilla Policy Gradient. 
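A short aside on ``DQNPolicy.set_eps`` documented above: it is typically driven from the trainer's ``train_fn``/``test_fn`` hooks. A hypothetical annealing schedule, assuming ``policy`` is an already constructed ``DQNPolicy``:

```python
def train_fn(epoch):
    # decay exploration as training progresses; the numbers are arbitrary
    policy.set_eps(max(0.01, 0.1 * (1 - epoch / 100)))

def test_fn(epoch):
    policy.set_eps(0.0)  # act greedily during evaluation
```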
+ + :param torch.nn.Module model: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. + :param torch.distributions.Distribution dist_fn: for computing the action. + :param float discount_factor: in [0, 1]. + """ def __init__(self, model, optim, dist_fn=torch.distributions.Categorical, - discount_factor=0.99): + discount_factor=0.99, **kwargs): super().__init__() self.model = model self.optim = optim self.dist_fn = dist_fn self._eps = np.finfo(np.float32).eps.item() - assert 0 < discount_factor <= 1, 'discount_factor should in (0, 1]' + assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]' self._gamma = discount_factor def process_fn(self, batch, buffer, indice): + r"""Compute the discounted returns for each frame: + + .. math:: + G_t = \sum_{i=t}^T \gamma^{i-t}r_i + + , where :math:`T` is the terminal time step, :math:`\gamma` is the + discount factor, :math:`\gamma \in [0, 1]`. + """ batch.returns = self._vanilla_returns(batch) # batch.returns = self._vectorized_returns(batch) return batch - def __call__(self, batch, state=None): + def __call__(self, batch, state=None, **kwargs): + """Compute action over the given batch data. + + :return: A :class:`~tianshou.data.Batch` which has 4 keys: + + * ``act`` the action. + * ``logits`` the network's raw output. + * ``dist`` the action distribution. + * ``state`` the hidden state. + + More information can be found at + :meth:`~tianshou.policy.BasePolicy.__call__`. + """ logits, h = self.model(batch.obs, state=state, info=batch.info) - logits = F.softmax(logits, dim=1) - dist = self.dist_fn(logits) + if isinstance(logits, tuple): + dist = self.dist_fn(*logits) + else: + dist = self.dist_fn(logits) act = dist.sample() return Batch(logits=logits, act=act, state=h, dist=dist) - def learn(self, batch, batch_size=None, repeat=1): + def learn(self, batch, batch_size=None, repeat=1, **kwargs): losses = [] r = batch.returns batch.returns = (r - r.mean()) / (r.std() + self._eps) @@ -57,7 +85,7 @@ def _vanilla_returns(self, batch): return returns def _vectorized_returns(self, batch): - # according to my tests, it is slower than vanilla + # according to my tests, it is slower than _vanilla_returns # import scipy.signal convolve = np.convolve # convolve = scipy.signal.convolve diff --git a/tianshou/policy/ppo.py b/tianshou/policy/ppo.py index d36472ab3..1bca45250 100644 --- a/tianshou/policy/ppo.py +++ b/tianshou/policy/ppo.py @@ -9,7 +9,24 @@ class PPOPolicy(PGPolicy): - """docstring for PPOPolicy""" + r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347 + + :param torch.nn.Module actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.nn.Module critic: the critic network. (s -> V(s)) + :param torch.optim.Optimizer optim: the optimizer for actor and critic + network. + :param torch.distributions.Distribution dist_fn: for computing the action. + :param float discount_factor: in [0, 1], defaults to 0.99. + :param float max_grad_norm: clipping gradients in back propagation, + defaults to ``None``. + :param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original + paper, defaults to 0.2. + :param float vf_coef: weight for value loss, defaults to 0.5. + :param float ent_coef: weight for entropy loss, defaults to 0.01. + :param action_range: the action range (minimum, maximum). 
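To make the discounted-return formula in ``PGPolicy.process_fn`` above concrete, here is a standalone NumPy sketch with a tiny worked example (the function name is made up):

```python
import numpy as np

def discounted_returns(rew, gamma=0.99):
    returns = np.zeros(len(rew))
    running = 0.0
    for t in reversed(range(len(rew))):
        running = rew[t] + gamma * running   # G_t = r_t + gamma * G_{t+1}
        returns[t] = running
    return returns

print(discounted_returns(np.array([1., 1., 1.]), gamma=0.9))  # G = [2.71, 1.9, 1.0]
```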
+ :type action_range: [float, float] + """ def __init__(self, actor, critic, optim, dist_fn, discount_factor=0.99, @@ -17,7 +34,8 @@ def __init__(self, actor, critic, optim, dist_fn, eps_clip=.2, vf_coef=.5, ent_coef=.0, - action_range=None): + action_range=None, + **kwargs): super().__init__(None, None, dist_fn, discount_factor) self._max_grad_norm = max_grad_norm self._eps_clip = eps_clip @@ -31,16 +49,30 @@ def __init__(self, actor, critic, optim, dist_fn, self.optim = optim def train(self): + """Set the module in training mode, except for the target network.""" self.training = True self.actor.train() self.critic.train() def eval(self): + """Set the module in evaluation mode, except for the target network.""" self.training = False self.actor.eval() self.critic.eval() - def __call__(self, batch, state=None, model='actor'): + def __call__(self, batch, state=None, model='actor', **kwargs): + """Compute action over the given batch data. + + :return: A :class:`~tianshou.data.Batch` which has 4 keys: + + * ``act`` the action. + * ``logits`` the network's raw output. + * ``dist`` the action distribution. + * ``state`` the hidden state. + + More information can be found at + :meth:`~tianshou.policy.BasePolicy.__call__`. + """ model = getattr(self, model) logits, h = model(batch.obs, state=state, info=batch.info) if isinstance(logits, tuple): @@ -53,10 +85,11 @@ def __call__(self, batch, state=None, model='actor'): return Batch(logits=logits, act=act, state=h, dist=dist) def sync_weight(self): + """Synchronize the weight for the target network.""" self.actor_old.load_state_dict(self.actor.state_dict()) self.critic_old.load_state_dict(self.critic.state_dict()) - def learn(self, batch, batch_size=None, repeat=1): + def learn(self, batch, batch_size=None, repeat=1, **kwargs): losses, clip_losses, vf_losses, ent_losses = [], [], [], [] r = batch.returns batch.returns = (r - r.mean()) / (r.std() + self._eps) @@ -79,7 +112,6 @@ def learn(self, batch, batch_size=None, repeat=1): clip_losses.append(clip_loss.item()) vf_loss = F.smooth_l1_loss(self.critic(b.obs), target_v) vf_losses.append(vf_loss.item()) - e_loss = dist.entropy().mean() ent_losses.append(e_loss.item()) loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss diff --git a/tianshou/policy/sac.py b/tianshou/policy/sac.py index bc231dbfd..7a28dc34d 100644 --- a/tianshou/policy/sac.py +++ b/tianshou/policy/sac.py @@ -8,12 +8,37 @@ class SACPolicy(DDPGPolicy): - """docstring for SACPolicy""" + """Implementation of Soft Actor-Critic. arXiv:1812.05905 + + :param torch.nn.Module actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.optim.Optimizer actor_optim: the optimizer for actor network. + :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, + a)) + :param torch.optim.Optimizer critic1_optim: the optimizer for the first + critic network. + :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, + a)) + :param torch.optim.Optimizer critic2_optim: the optimizer for the second + critic network. + :param float tau: param for soft update of the target network, defaults to + 0.005. + :param float gamma: discount factor, in [0, 1], defaults to 0.99. + :param float exploration_noise: the noise intensity, add to the action, + defaults to 0.1. + :param float alpha: entropy regularization coefficient, default to 0.2. + :param action_range: the action range (minimum, maximum). 
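For reference, the clipped surrogate term that ``eps_clip`` controls in the PPO docstring above looks roughly like the sketch below; ``ratio`` (the new-to-old policy probability ratio) and ``adv`` (the advantage estimate) are assumed inputs.

```python
import torch

def clipped_surrogate(ratio, adv, eps_clip=0.2):
    surr1 = ratio * adv
    surr2 = ratio.clamp(1. - eps_clip, 1. + eps_clip) * adv
    return -torch.min(surr1, surr2).mean()   # minimized by the optimizer
```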
+ :type action_range: [float, float] + :param bool reward_normalization: normalize the reward to Normal(0, 1), + defaults to ``False``. + :param bool ignore_done: ignore the done flag while training the policy, + defaults to ``False``. + """ def __init__(self, actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau=0.005, gamma=0.99, alpha=0.2, action_range=None, reward_normalization=False, - ignore_done=False): + ignore_done=False, **kwargs): super().__init__(None, None, None, None, tau, gamma, 0, action_range, reward_normalization, ignore_done) self.actor, self.actor_optim = actor, actor_optim @@ -46,12 +71,11 @@ def sync_weight(self): self.critic2_old.parameters(), self.critic2.parameters()): o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) - def __call__(self, batch, state=None, input='obs'): + def __call__(self, batch, state=None, input='obs', **kwargs): obs = getattr(batch, input) logits, h = self.actor(obs, state=state, info=batch.info) assert isinstance(logits, tuple) dist = torch.distributions.Normal(*logits) - x = dist.rsample() y = torch.tanh(x) act = y * self._action_scale + self._action_bias @@ -61,7 +85,7 @@ def __call__(self, batch, state=None, input='obs'): return Batch( logits=logits, act=act, state=h, dist=dist, log_prob=log_prob) - def learn(self, batch, batch_size=None, repeat=1): + def learn(self, batch, **kwargs): with torch.no_grad(): obs_next_result = self(batch, input='obs_next') a_ = obs_next_result.act diff --git a/tianshou/policy/td3.py b/tianshou/policy/td3.py index d72cb5b67..5df68f6b2 100644 --- a/tianshou/policy/td3.py +++ b/tianshou/policy/td3.py @@ -6,13 +6,44 @@ class TD3Policy(DDPGPolicy): - """docstring for TD3Policy""" + """Implementation of Twin Delayed Deep Deterministic Policy Gradient, + arXiv:1802.09477 + + :param torch.nn.Module actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.optim.Optimizer actor_optim: the optimizer for actor network. + :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, + a)) + :param torch.optim.Optimizer critic1_optim: the optimizer for the first + critic network. + :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, + a)) + :param torch.optim.Optimizer critic2_optim: the optimizer for the second + critic network. + :param float tau: param for soft update of the target network, defaults to + 0.005. + :param float gamma: discount factor, in [0, 1], defaults to 0.99. + :param float exploration_noise: the noise intensity, add to the action, + defaults to 0.1. + :param float policy_noise: the noise used in updating policy network, + default to 0.2. + :param int update_actor_freq: the update frequency of actor network, + default to 2. + :param float noise_clip: the clipping range used in updating policy + network, default to 0.5. + :param action_range: the action range (minimum, maximum). + :type action_range: [float, float] + :param bool reward_normalization: normalize the reward to Normal(0, 1), + defaults to ``False``. + :param bool ignore_done: ignore the done flag while training the policy, + defaults to ``False``. 
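Circling back to ``SACPolicy.__call__`` shown above, the tanh squashing and log-probability correction can be sketched in isolation as follows; the ``(mu, sigma)`` values and the action scale are made up (e.g. a Pendulum-like range of [-2, 2]).

```python
import torch

mu, sigma = torch.zeros(1), torch.ones(1)
action_scale, action_bias = 2.0, 0.0
dist = torch.distributions.Normal(mu, sigma)
x = dist.rsample()                        # reparameterized sample
y = torch.tanh(x)                         # squash into (-1, 1)
act = y * action_scale + action_bias      # rescale to the environment's range
# change-of-variables correction for the squashed log-probability
log_prob = dist.log_prob(x) - torch.log(action_scale * (1 - y.pow(2)) + 1e-6)
```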
+ """ def __init__(self, actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau=0.005, gamma=0.99, exploration_noise=0.1, policy_noise=0.2, update_actor_freq=2, noise_clip=0.5, action_range=None, - reward_normalization=False, ignore_done=False): + reward_normalization=False, ignore_done=False, **kwargs): super().__init__(actor, actor_optim, None, None, tau, gamma, exploration_noise, action_range, reward_normalization, ignore_done) @@ -50,7 +81,7 @@ def sync_weight(self): self.critic2_old.parameters(), self.critic2.parameters()): o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) - def learn(self, batch, batch_size=None, repeat=1): + def learn(self, batch, **kwargs): with torch.no_grad(): a_ = self(batch, model='actor_old', input='obs_next').act dev = a_.device diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 51ed8f8b7..78ea860d9 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -12,34 +12,35 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch, **kwargs): """A wrapper for off-policy trainer procedure. - Parameters - * **policy** – an instance of the :class:`~tianshou.policy.BasePolicy`\ - class. - * **train_collector** – the collector used for training. - * **test_collector** – the collector used for testing. - * **max_epoch** – the maximum of epochs for training. The training \ - process might be finished before reaching the ``max_epoch``. - * **step_per_epoch** – the number of step for updating policy network \ - in one epoch. - * **collect_per_step** – the number of frames the collector would \ - collect before the network update. In other words, collect some \ - frames and do one policy network update. - * **episode_per_test** – the number of episodes for one policy \ - evaluation. - * **batch_size** – the batch size of sample data, which is going to \ - feed in the policy network. - * **train_fn** – a function receives the current number of epoch index\ - and performs some operations at the beginning of training in this \ - epoch. - * **test_fn** – a function receives the current number of epoch index \ - and performs some operations at the beginning of testing in this \ - epoch. - * **stop_fn** – a function receives the average undiscounted returns \ - of the testing result, return a boolean which indicates whether \ - reaching the goal. - * **writer** – a SummaryWriter provided from TensorBoard. - * **log_interval** – an int indicating the log interval of the writer. - * **verbose** – a boolean indicating whether to print the information. + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` + class. + :param train_collector: the collector used for training. + :type train_collector: :class:`~tianshou.data.Collector` + :param test_collector: the collector used for testing. + :type test_collector: :class:`~tianshou.data.Collector` + :param int max_epoch: the maximum of epochs for training. The training + process might be finished before reaching the ``max_epoch``. + :param int step_per_epoch: the number of step for updating policy network + in one epoch. + :param int collect_per_step: the number of frames the collector would + collect before the network update. In other words, collect some frames + and do one policy network update. + :param episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to + feed in the policy network. 
+ :param function train_fn: a function receives the current number of epoch + index and performs some operations at the beginning of training in this + epoch. + :param function test_fn: a function receives the current number of epoch + index and performs some operations at the beginning of testing in this + epoch. + :param function stop_fn: a function receives the average undiscounted + returns of the testing result, return a boolean which indicates whether + reaching the goal. + :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard + SummaryWriter. + :param int log_interval: the log interval of the writer. + :param bool verbose: whether to print the information. :return: See :func:`~tianshou.trainer.gather_info`. """ diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index aef1a8630..e633c5bc1 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -13,37 +13,39 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch, **kwargs): """A wrapper for on-policy trainer procedure. - Parameters - * **policy** – an instance of the :class:`~tianshou.policy.BasePolicy`\ - class. - * **train_collector** – the collector used for training. - * **test_collector** – the collector used for testing. - * **max_epoch** – the maximum of epochs for training. The training \ - process might be finished before reaching the ``max_epoch``. - * **step_per_epoch** – the number of step for updating policy network \ - in one epoch. - * **collect_per_step** – the number of frames the collector would \ - collect before the network update. In other words, collect some \ - frames and do one policy network update. - * **repeat_per_collect** – the number of repeat time for policy \ - learning, for example, set it to 2 means the policy needs to learn\ - each given batch data twice. - * **episode_per_test** – the number of episodes for one policy \ - evaluation. - * **batch_size** – the batch size of sample data, which is going to \ - feed in the policy network. - * **train_fn** – a function receives the current number of epoch index\ - and performs some operations at the beginning of training in this \ - epoch. - * **test_fn** – a function receives the current number of epoch index \ - and performs some operations at the beginning of testing in this \ - epoch. - * **stop_fn** – a function receives the average undiscounted returns \ - of the testing result, return a boolean which indicates whether \ - reaching the goal. - * **writer** – a SummaryWriter provided from TensorBoard. - * **log_interval** – an int indicating the log interval of the writer. - * **verbose** – a boolean indicating whether to print the information. + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` + class. + :param train_collector: the collector used for training. + :type train_collector: :class:`~tianshou.data.Collector` + :param test_collector: the collector used for testing. + :type test_collector: :class:`~tianshou.data.Collector` + :param int max_epoch: the maximum of epochs for training. The training + process might be finished before reaching the ``max_epoch``. + :param int step_per_epoch: the number of step for updating policy network + in one epoch. + :param int collect_per_step: the number of frames the collector would + collect before the network update. In other words, collect some frames + and do one policy network update. 
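The ``train_fn``/``test_fn``/``stop_fn`` hooks described above are plain callables; for instance, a hypothetical stopping criterion for a CartPole-style task (threshold chosen for illustration only):

```python
def stop_fn(mean_rewards):
    return mean_rewards >= 195
```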
+ :param int repeat_per_collect: the number of repeat times for policy + learning, for example, setting it to 2 means the policy needs to learn each + given batch of data twice. + :param episode_per_test: the number of episodes for one policy evaluation. + :type episode_per_test: int or list of ints + :param int batch_size: the batch size of sample data, which is going to + feed in the policy network. + :param function train_fn: a function that receives the current number of epoch + index and performs some operations at the beginning of training in this + epoch. + :param function test_fn: a function that receives the current number of epoch + index and performs some operations at the beginning of testing in this + epoch. + :param function stop_fn: a function that receives the average undiscounted + returns of the testing result and returns a boolean which indicates whether + the goal has been reached. + :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard + SummaryWriter. + :param int log_interval: the log interval of the writer. + :param bool verbose: whether to print the information. :return: See :func:`~tianshou.trainer.gather_info`. """ diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/moving_average.py index ae0e7d6a9..0b5933edb 100644 --- a/tianshou/utils/moving_average.py +++ b/tianshou/utils/moving_average.py @@ -26,7 +26,7 @@ def __init__(self, size=100): def add(self, x): """Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with only one element, a python scalar, or a list of python scalar. It will - exclude the infinity. + automatically exclude infinite values.
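Putting the trainer parameters documented above together, a hypothetical end-to-end call might look like the sketch below; the numbers are arbitrary, and ``policy``, ``train_collector``, ``test_collector`` and the hook functions are assumed to be defined as in the earlier sketches.

```python
from torch.utils.tensorboard import SummaryWriter
from tianshou.trainer import offpolicy_trainer

writer = SummaryWriter('log/dqn')
result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, collect_per_step=10,
    episode_per_test=100, batch_size=64,
    train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, writer=writer)
print(result)  # the summary dict produced by gather_info()
```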