Improve performance (#18)
* improve performance

set a single thread for the NN
replace detach() ops with torch.no_grad()

* fix PEP 8 errors
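For context (not part of the commit itself): detach() cuts a tensor out of the autograd graph after the forward pass has already recorded it, while torch.no_grad() prevents the graph from being recorded in the first place, which saves both time and memory when computing bootstrap targets that never need gradients. A minimal sketch with a stand-in critic, using only standard PyTorch:

    import torch
    import torch.nn as nn

    critic_old = nn.Linear(4, 1)   # stand-in for a frozen target critic
    obs_next = torch.randn(32, 4)
    rew = torch.randn(32, 1)
    done = torch.zeros(32, 1)
    gamma = 0.99

    # before: the forward pass records an autograd graph, then detach() drops it
    target_q = (rew + (1. - done) * gamma * critic_old(obs_next)).detach()

    # after: nothing is recorded inside the context, so there is no graph to drop
    with torch.no_grad():
        target_q = rew + (1. - done) * gamma * critic_old(obs_next)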
fengredrum authored Apr 5, 2020
1 parent b6c9db6 commit 4d4d0da
Showing 8 changed files with 70 additions and 44 deletions.
10 changes: 7 additions & 3 deletions test/continuous/test_ddpg.py
@@ -1,3 +1,4 @@
import os
import gym
import torch
import pprint
@@ -19,14 +20,15 @@
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--run-id', type=str, default='test')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--actor-lr', type=float, default=1e-4)
parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--exploration-noise', type=float, default=0.1)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=4)
parser.add_argument('--batch-size', type=int, default=128)
@@ -43,6 +45,7 @@ def get_args():


def test_ddpg(args=get_args()):
torch.set_num_threads(1)  # we need only one thread for NN
env = gym.make(args.task)
if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -250
@@ -81,7 +84,8 @@ def test_ddpg(args=get_args()):
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
# log
writer = SummaryWriter(args.logdir + '/' + 'ddpg')
log_path = os.path.join(args.logdir, args.task, 'ddpg', args.run_id)
writer = SummaryWriter(log_path)

def stop_fn(x):
return x >= env.spec.reward_threshold
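A note on the torch.set_num_threads(1) line added above: the test also runs several vectorized environment workers, so the small network likely gains more from avoiding CPU oversubscription than from PyTorch's default intra-op parallelism. That motivation is an assumption; the commit message only says one thread is enough for the NN. A minimal sketch of the call:

    import torch

    torch.set_num_threads(1)        # restrict intra-op parallelism for the NN
    print(torch.get_num_threads())  # prints 1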
10 changes: 7 additions & 3 deletions test/continuous/test_ppo.py
@@ -1,3 +1,4 @@
import os
import gym
import torch
import pprint
@@ -19,14 +20,15 @@
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--run-id', type=str, default='test')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--repeat-per-collect', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--training-num', type=int, default=16)
@@ -47,6 +49,7 @@ def get_args():

def _test_ppo(args=get_args()):
# just a demo, I have not made it work :(
torch.set_num_threads(1)  # we need only one thread for NN
env = gym.make(args.task)
if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -250
@@ -89,7 +92,8 @@ def _test_ppo(args=get_args()):
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.step_per_epoch)
# log
writer = SummaryWriter(args.logdir + '/' + 'ppo')
log_path = os.path.join(args.logdir, args.task, 'ppo', args.run_id)
writer = SummaryWriter(log_path)

def stop_fn(x):
return x >= env.spec.reward_threshold
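The logging change is identical across the four test scripts: the SummaryWriter path is now built with os.path.join, grouping runs by task, algorithm, and run id. A small sketch of the resulting layout, assuming the default --logdir is 'log' (that default is not visible in this diff):

    import os

    logdir, task, algo, run_id = 'log', 'Pendulum-v0', 'ppo', 'test'
    log_path = os.path.join(logdir, task, algo, run_id)
    print(log_path)  # log/Pendulum-v0/ppo/test on POSIX systems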
10 changes: 7 additions & 3 deletions test/continuous/test_sac.py
@@ -1,3 +1,4 @@
import os
import gym
import torch
import pprint
@@ -19,14 +20,15 @@
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--run-id', type=str, default='test')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--actor-lr', type=float, default=3e-4)
parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--alpha', type=float, default=0.2)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=128)
@@ -43,6 +45,7 @@ def get_args():


def test_sac(args=get_args()):
torch.set_num_threads(1)  # we need only one thread for NN
env = gym.make(args.task)
if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -250
@@ -86,7 +89,8 @@ def test_sac(args=get_args()):
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
writer = SummaryWriter(args.logdir + '/' + 'sac')
log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
writer = SummaryWriter(log_path)

def stop_fn(x):
return x >= env.spec.reward_threshold
10 changes: 7 additions & 3 deletions test/continuous/test_td3.py
@@ -1,3 +1,4 @@
import os
import gym
import torch
import pprint
@@ -19,7 +20,8 @@
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--run-id', type=str, default='test')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--actor-lr', type=float, default=3e-4)
parser.add_argument('--critic-lr', type=float, default=1e-3)
@@ -29,7 +31,7 @@ def get_args():
parser.add_argument('--policy-noise', type=float, default=0.2)
parser.add_argument('--noise-clip', type=float, default=0.5)
parser.add_argument('--update-actor-freq', type=int, default=2)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=128)
@@ -46,6 +48,7 @@ def get_args():


def test_td3(args=get_args()):
torch.set_num_threads(1)  # we need only one thread for NN
env = gym.make(args.task)
if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -250
@@ -90,7 +93,8 @@ def test_td3(args=get_args()):
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
writer = SummaryWriter(args.logdir + '/' + 'td3')
log_path = os.path.join(args.logdir, args.task, 'td3', args.run_id)
writer = SummaryWriter(log_path)

def stop_fn(x):
return x >= env.spec.reward_threshold
5 changes: 3 additions & 2 deletions tianshou/data/collector.py
@@ -3,7 +3,7 @@
import warnings
import numpy as np
from tianshou.env import BaseVectorEnv
from tianshou.data import Batch, ReplayBuffer,\
from tianshou.data import Batch, ReplayBuffer, \
ListReplayBuffer
from tianshou.utils import MovAvg

@@ -115,7 +115,8 @@ def collect(self, n_step=0, n_episode=0, render=0):
done=self._make_batch(self._done),
obs_next=None,
info=self._make_batch(self._info))
result = self.policy(batch_data, self.state)
with torch.no_grad():
result = self.policy(batch_data, self.state)
self.state = result.state if hasattr(result, 'state') else None
if isinstance(result.act, torch.Tensor):
self._act = result.act.detach().cpu().numpy()
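The collector change applies the same idea on the inference path: actions gathered during rollout never need gradients, so wrapping the policy call in torch.no_grad() skips autograd bookkeeping on every environment step. A rough sketch of the pattern, with a hypothetical policy_net standing in for the Tianshou policy object:

    import torch

    def select_actions(policy_net, obs):
        # rollout-time inference only: no autograd graph is recorded
        with torch.no_grad():
            act = policy_net(obs)
        return act.cpu().numpy()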
15 changes: 9 additions & 6 deletions tianshou/policy/ddpg.py
@@ -90,12 +90,15 @@ def __call__(self, batch, state=None,
return Batch(act=logits, state=h)

def learn(self, batch, batch_size=None, repeat=1):
target_q = self.critic_old(batch.obs_next, self(
batch, model='actor_old', input='obs_next', eps=0).act)
dev = target_q.device
rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
target_q = (rew + (1. - done) * self._gamma * target_q).detach()
with torch.no_grad():
target_q = self.critic_old(batch.obs_next, self(
batch, model='actor_old', input='obs_next', eps=0).act)
dev = target_q.device
rew = torch.tensor(batch.rew,
dtype=torch.float, device=dev)[:, None]
done = torch.tensor(batch.done,
dtype=torch.float, device=dev)[:, None]
target_q = (rew + (1. - done) * self._gamma * target_q)
current_q = self.critic(batch.obs, batch.act)
critic_loss = F.mse_loss(current_q, target_q)
self.critic_optim.zero_grad()
25 changes: 14 additions & 11 deletions tianshou/policy/sac.py
@@ -62,17 +62,20 @@ def __call__(self, batch, state=None, input='obs'):
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)

def learn(self, batch, batch_size=None, repeat=1):
obs_next_result = self(batch, input='obs_next')
a_ = obs_next_result.act
dev = a_.device
batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
target_q = torch.min(
self.critic1_old(batch.obs_next, a_),
self.critic2_old(batch.obs_next, a_),
) - self._alpha * obs_next_result.log_prob
rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
target_q = (rew + (1. - done) * self._gamma * target_q).detach()
with torch.no_grad():
obs_next_result = self(batch, input='obs_next')
a_ = obs_next_result.act
dev = a_.device
batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
target_q = torch.min(
self.critic1_old(batch.obs_next, a_),
self.critic2_old(batch.obs_next, a_),
) - self._alpha * obs_next_result.log_prob
rew = torch.tensor(batch.rew,
dtype=torch.float, device=dev)[:, None]
done = torch.tensor(batch.done,
dtype=torch.float, device=dev)[:, None]
target_q = (rew + (1. - done) * self._gamma * target_q)
obs_result = self(batch)
a = obs_result.act
current_q1, current_q1a = self.critic1(
29 changes: 16 additions & 13 deletions tianshou/policy/td3.py
@@ -51,19 +51,22 @@ def sync_weight(self):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)

def learn(self, batch, batch_size=None, repeat=1):
a_ = self(batch, model='actor_old', input='obs_next').act
dev = a_.device
noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
if self._noise_clip >= 0:
noise = noise.clamp(-self._noise_clip, self._noise_clip)
a_ += noise
a_ = a_.clamp(self._range[0], self._range[1])
target_q = torch.min(
self.critic1_old(batch.obs_next, a_),
self.critic2_old(batch.obs_next, a_))
rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
target_q = (rew + (1. - done) * self._gamma * target_q).detach()
with torch.no_grad():
a_ = self(batch, model='actor_old', input='obs_next').act
dev = a_.device
noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
if self._noise_clip >= 0:
noise = noise.clamp(-self._noise_clip, self._noise_clip)
a_ += noise
a_ = a_.clamp(self._range[0], self._range[1])
target_q = torch.min(
self.critic1_old(batch.obs_next, a_),
self.critic2_old(batch.obs_next, a_))
rew = torch.tensor(batch.rew,
dtype=torch.float, device=dev)[:, None]
done = torch.tensor(batch.done,
dtype=torch.float, device=dev)[:, None]
target_q = (rew + (1. - done) * self._gamma * target_q)
# critic 1
current_q1 = self.critic1(batch.obs, batch.act)
critic1_loss = F.mse_loss(current_q1, target_q)
