
Commit

Merge pull request #6 from josiahls/version_0_7_0
DDPG / Testing Init
josiahls authored Oct 14, 2019
2 parents ed2d54f + 13e1227 commit ebadf44
Showing 19 changed files with 446 additions and 349 deletions.
10 changes: 8 additions & 2 deletions README.md
@@ -1,3 +1,9 @@
[![Build Status](https://dev.azure.com/jokellum/jokellum/_apis/build/status/josiahls.fast-reinforcement-learning?branchName=master)](https://dev.azure.com/jokellum/jokellum/_build/latest?definitionId=1&branchName=master)
[![pypi fast_rl version](https://img.shields.io/pypi/v/fast_rl)](https://pypi.python.org/pypi/fast_rl)
[![github_master version](https://img.shields.io/github/v/release/josiahls/fast-reinforcement-learning?include_prereleases)](https://github.com/josiahls/fast-reinforcement-learning/releases)

**Note: Test passing will not be a useful stability indicator until version 1.0+**

# Fast Reinforcement Learning
This repo is not affiliated with Jeremy Howard or his course, which can be found [here](https://www.fast.ai/about/).
We will, however, be using components from the Fastai library for building and training our reinforcement learning (RL)
@@ -221,8 +227,8 @@ learn.fit(5)
```


- [ ] **Working On** 0.7.0 Full test suite using multi-processing. Connect to CI.
- [ ] 0.8.0 Comprehensive model eval **debug/verify**. Each model should succeed at at least a few known environments.
- [X] 0.7.0 Full test suite using multi-processing. Connect to CI.
- [ ] **Working On** 0.8.0 Comprehensive model eval **debug/verify**. Each model should succeed at at least a few known environments.
- [ ] 0.9.0 Notebook demonstrations of basic model usage
- [ ] **1.0.0** Base version is completed with working model visualizations proving performance / expected failure. At
this point, all models should have guaranteed environments they should succeed in.
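
For context on the roadmap above, the checklist items build on the fastai-style training loop already shown earlier in the README (the `learn.fit(5)` line in this hunk's context). A minimal sketch only: `MDPDataBunch.from_env` taking a Gym environment id and `AgentLearner(data, model)` being the constructor are assumptions here, and `'Pendulum-v0'` is just an illustrative continuous-control environment.

```python
from fast_rl.agents.DDPG import DDPG
from fast_rl.core.Learner import AgentLearner
from fast_rl.core.MarkovDecisionProcess import MDPDataBunch

# Assumed API: build a data bunch from a Gym id, wrap the agent, then fit.
data = MDPDataBunch.from_env('Pendulum-v0')   # illustrative continuous-action env
model = DDPG(data)                            # DDPG(data, ...) matches the signature in this commit
learn = AgentLearner(data, model)
learn.fit(5)
```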
39 changes: 33 additions & 6 deletions azure-pipelines.yml
@@ -7,13 +7,40 @@ trigger:
- master

pool:
vmImage: 'ubuntu-latest'
vmImage: 'ubuntu-16.04'

steps:
- script: echo Hello, world!
displayName: 'Run a one-line script'

- bash: "sudo apt-get install -y ffmpeg xvfb freeglut3-dev python-opengl"
displayName: 'Install ffmpeg, freeglut3-dev, and xvfb'

- task: UsePythonVersion@0
inputs:
versionSpec: '3.6'

- script: sh ./build/azure_pipeline_helper.sh
displayName: 'Complex Installs'

- script: |
echo Add other tasks to build, test, and deploy your project.
echo See https://aka.ms/yaml
displayName: 'Run a multi-line script'
pip install Bottleneck
python setup.py install
pip install pytest
pip install pytest-cov
pip install pytest-xdist
displayName: 'Install Python Packages'

- script: |
xvfb-run -s "-screen 0 1400x900x24" pytest -n 8 fast_rl/tests --doctest-modules --junitxml=junit/test-results.xml --cov=./ --cov-report=xml --cov-report=html
displayName: 'Test with pytest'

- task: PublishTestResults@2
condition: succeededOrFailed()
inputs:
testResultsFiles: '**/test-*.xml'
testRunTitle: 'Publish test results for Python $(python.version)'

- task: PublishCodeCoverageResults@1
inputs:
codeCoverageTool: Cobertura
summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml'
reportDirectory: '$(System.DefaultWorkingDirectory)/**/htmlcov'
14 changes: 14 additions & 0 deletions build/azure_pipeline_helper.sh
@@ -0,0 +1,14 @@
#!/usr/bin/env bash

# Install pybullet
git clone https://github.com/benelot/pybullet-gym.git
cd pybullet-gym
pip install -e .
cd ../

# Install gym_maze
git clone https://github.com/MattChanTK/gym-maze.git
cd gym-maze
python setup.py install
cd ../
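
A quick way to smoke-test that the two source installs above actually registered their environments. The import names `pybulletgym` and `gym_maze` follow each project's README and are assumptions here, not something verified by this commit.

```python
import gym
import pybulletgym  # assumed import name; registers the PyBullet envs with Gym on import
import gym_maze     # assumed import name; registers the maze envs with Gym on import

# Listing a few of the newly registered ids is a cheap sanity check.
env_ids = [spec.id for spec in gym.envs.registry.all()]
print([i for i in env_ids if 'Bullet' in i or 'maze' in i.lower()][:5])
```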

49 changes: 30 additions & 19 deletions fast_rl/agents/BaseAgent.py
@@ -45,13 +45,12 @@ def pick_action(self, x):
with torch.no_grad():
if len(x.shape) > 2: raise ValueError('The agent is outputting actions with more than 1 dimension...')

action, x, perturbed = self.exploration_strategy.perturb(x, x, self.data.train_ds.env.action_space)
x = np.clip(x, -1.0, 1.0)
if isinstance(self.data.train_ds.env.action_space, Discrete): action = x.argmax().numpy().item()
elif isinstance(self.data.train_ds.env.action_space, Box) and len(x.shape) != 1: action = x.squeeze(0).numpy()

if isinstance(self.data.train_ds.env.action_space, Discrete) and not perturbed: action = x.argmax().numpy().item()
elif isinstance(self.data.train_ds.env.action_space, Box): action = x.squeeze(0).numpy()
action = self.exploration_strategy.perturb(action, self.data.train_ds.env.action_space)

return action, x
return action

def interpret_q(self, items):
raise NotImplementedError
@@ -68,13 +67,20 @@ def forward(self, x):
return x.long()


class SwapImageChannel(nn.Module):
def forward(self, x):
return x.transpose(1, 3)


class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)


def create_nn_model(layer_list: list, action_size, state_size, use_bn=False, use_embed=True,
activation_function=None, final_activation_function=None):


def create_nn_model(layer_list: list, action_size, state_size, use_bn=False, use_embed=False,
activation_function=None, final_activation_function=None, action_val_to_dim=True):
"""Generates an nn module.
Notes:
@@ -84,7 +90,8 @@ def create_nn_model(layer_list: list, action_size, state_size, use_bn=False, use
"""
act = nn.LeakyReLU if activation_function is None else activation_function
action_size = action_size[0] # For now the dimension of the action does not make a difference.
# For now the dimension of the action does not make a difference.
action_size = action_size[0] if not action_val_to_dim else action_size[1]
# For now keep drop out as 0, test including dropout later
ps = [0] * len(layer_list)
sizes = [state_size] + layer_list + [action_size]
@@ -126,24 +133,25 @@ def get_conv(input_tuple, act, kernel_size, stride, n_conv_layers, layers):
\times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
:param input_tuple:
:param act:
:param kernel_size:
:param stride:
:param n_conv_layers:
:param layers:
:return:
Args:
input_tuple:
act:
kernel_size:
stride:
n_conv_layers:
layers:
"""
h, w = input_tuple[0], input_tuple[1]
conv_layers = []
conv_layers = [SwapImageChannel()]
for i in range(n_conv_layers):
h, w = get_next_conv_shape(h, w, stride, kernel_size)
conv_layers.append(torch.nn.Conv2d(input_tuple[2], 3, kernel_size=kernel_size, stride=stride))
conv_layers.append(act)
return layers + conv_layers, 3 * (h + 1) * (w + 1)


def create_cnn_model(layer_list: list, action_size, state_size, use_bn=False, kernel_size=5, stride=3, n_conv_layers=3):
def create_cnn_model(layer_list: list, action_size, state_size, use_bn=False, kernel_size=5, stride=3, n_conv_layers=3,
activation_function=None, final_activation_function=None, action_val_to_dim=True):
"""Generates an nn module.
Notes:
@@ -152,15 +160,18 @@ def create_cnn_model(layer_list: list, action_size, state_size, use_bn=False, ke
Returns:
"""
act = nn.LeakyReLU if activation_function is None else activation_function
# For now keep drop out as 0, test including dropout later
ps = [0] * len(layer_list)
sizes = [state_size] + layer_list + [action_size]
actns = [nn.ReLU() for _ in range(n_conv_layers + len(sizes) - 2)] + [None]
action_size = action_size[0] if not action_val_to_dim else action_size[1]
sizes = [state_size[0]] + layer_list + [action_size]
actns = [act() for _ in range(n_conv_layers + len(sizes) - 2)] + [None]
layers = []
for i, (n_in, n_out, dp, act) in enumerate(zip(sizes[:-1], sizes[1:], [0.] + ps, actns)):
if type(n_in) == tuple:
layers, n_in = get_conv(n_in, act, kernel_size, n_conv_layers=n_conv_layers, layers=layers, stride=stride)
layers += [Flatten()]

layers += bn_drop_lin(n_in, n_out, bn=use_bn and i != 0, p=dp, actn=act)
if final_activation_function is not None: layers += [final_activation_function()]
return nn.Sequential(*layers)
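
For reference, the size rule partially quoted in the `get_conv` docstring is the standard `Conv2d` output formula. A small sketch follows, using an 84x84 input purely as an example (this commit does not fix an input size here):

```python
import math

def conv_out(size, kernel_size, stride, padding=0, dilation=1):
    # Standard Conv2d output-size rule: floor((size + 2p - d*(k-1) - 1) / stride + 1)
    return math.floor((size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

# Three conv layers with the create_cnn_model defaults (kernel_size=5, stride=3):
h = w = 84  # example input only
for _ in range(3):
    h, w = conv_out(h, 5, 3), conv_out(w, 5, 3)
print(h, w)  # 2 2 for this example
```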
68 changes: 51 additions & 17 deletions fast_rl/agents/DDPG.py
Expand Up @@ -8,7 +8,8 @@
from torch.nn import MSELoss
from torch.optim import Adam

from fast_rl.agents.BaseAgent import BaseAgent, create_nn_model
from fast_rl.agents.BaseAgent import BaseAgent, create_nn_model, create_cnn_model, get_next_conv_shape, get_conv, \
Flatten
from fast_rl.core.Learner import AgentLearner
from fast_rl.core.MarkovDecisionProcess import MDPDataBunch
from fast_rl.core.agent_core import GreedyEpsilon, ExperienceReplay
@@ -27,6 +28,8 @@ def on_train_begin(self, n_epochs, **kwargs: Any):

def on_epoch_begin(self, epoch, **kwargs: Any):
self.episode = epoch
# if self.learn.model.training and self.iteration != 0:
# self.learn.model.memory.update(item=self.learn.data.x.items[-1])
self.iteration = 0

def on_loss_begin(self, **kwargs: Any):
@@ -47,7 +50,7 @@ def on_loss_begin(self, **kwargs: Any):
# self.learn.model.target_copy_over()


class Critic(nn.Module):
class NNCritic(nn.Module):
def __init__(self, layer_list: list, action_size, state_size, use_bn=False, use_embed=True,
activation_function=None):
super().__init__()
@@ -59,7 +62,7 @@ def __init__(self, layer_list: list, action_size, state_size, use_bn=False, use_
self.fc3 = nn.Linear(layer_list[1], 1)

def forward(self, x):
action, x = x[:, self.state_size:], x[:, :self.state_size]
x, action = x

x = nn.LeakyReLU()(self.fc1(x))
x = nn.LeakyReLU()(self.fc2(torch.cat((x, action), 1)))
@@ -68,17 +71,41 @@ def forward(self, x):
return x


class CNNCritic(nn.Module):
def __init__(self, layer_list: list, action_size, state_size, activation_function=None):
super().__init__()
self.action_size = action_size[0]
self.state_size = state_size[0]

layers = []
layers, input_size = get_conv(self.state_size, nn.LeakyReLU(), 8, 2, 3, layers)
layers += [Flatten()]
self.conv_layers = nn.Sequential(*layers)

self.fc1 = nn.Linear(input_size + self.action_size, 200)
self.fc2 = nn.Linear(200, 1)

def forward(self, x):
x, action = x

x = nn.LeakyReLU()(self.conv_layers(x))
x = nn.LeakyReLU()(self.fc1(torch.cat((x, action), 1)))
x = nn.LeakyReLU()(self.fc2(x))

return x
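
Note the interface change in both critics: `forward` now unpacks a `(state, action)` tuple instead of slicing a pre-concatenated tensor. A minimal stand-in (placeholder sizes, not this repo's classes) showing the call shape:

```python
import torch
import torch.nn as nn

class TupleCritic(nn.Module):
    """Stand-in mirroring the (state, action) tuple interface of NNCritic."""
    def __init__(self, state_size=8, action_size=2, hidden=64):
        super().__init__()
        self.fc1 = nn.Linear(state_size, hidden)
        self.fc2 = nn.Linear(hidden + action_size, hidden)
        self.fc3 = nn.Linear(hidden, 1)

    def forward(self, x):
        s, a = x  # unpack the tuple, as in NNCritic.forward above
        h = nn.LeakyReLU()(self.fc1(s))
        h = nn.LeakyReLU()(self.fc2(torch.cat((h, a), 1)))
        return self.fc3(h)

q = TupleCritic()((torch.randn(64, 8), torch.randn(64, 2)))  # -> shape (64, 1)
```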


class DDPG(BaseAgent):

def __init__(self, data: MDPDataBunch, memory=None, tau=1e-3, batch=64, discount=0.99,
lr=1e-3, actor_lr=1e-4, exploration_strategy=None, env_was_discrete=False):
lr=1e-3, actor_lr=1e-4, exploration_strategy=None):
"""
Implementation of a continuous control algorithm using an actor/critic architecture.
Notes:
Uses 4 networks, 2 actors, 2 critics.
All models use batch norm for feature invariance.
Critic simply predicts Q while the Actor proposes the actions to take given a state s.
NNCritic simply predicts Q while the Actor proposes the actions to take given a state s.
References:
[1] Lillicrap, Timothy P., et al. "Continuous control with deep reinforcement learning."
@@ -93,7 +120,6 @@ def __init__(self, data: MDPDataBunch, memory=None, tau=1e-3, batch=64, discount
lr: Rate that the opt will learn parameter gradients.
"""
super().__init__(data)
self.env_was_discrete = env_was_discrete
self.name = 'DDPG'
self.lr = lr
self.discount = discount
@@ -122,21 +148,30 @@ def __init__(self, data: MDPDataBunch, memory=None, tau=1e-3, batch=64, discount
do_exploration=self.training))

def initialize_action_model(self, layers, data):
return create_nn_model(layers, *data.get_action_state_size(), False, use_embed=data.train_ds.embeddable,
final_activation_function=nn.Tanh)
actions, state = data.get_action_state_size()
if type(state[0]) is tuple and len(state[0]) == 3:
# actions, state = actions[0], state[0]
# If the shape has 3 dimensions, we will try using cnn's instead.
return create_cnn_model([200, 200], actions, state, False, kernel_size=8,
final_activation_function=nn.Tanh, action_val_to_dim=False)
else:
return create_nn_model(layers, *data.get_action_state_size(), False, use_embed=data.train_ds.embeddable,
final_activation_function=nn.Tanh, action_val_to_dim=False)

def initialize_critic_model(self, layers, data):
""" Instead of state -> action, we are going state + action -> single expected reward. """
return Critic(layers, *data.get_action_state_size())
actions, state = data.get_action_state_size()
if type(state[0]) is tuple and len(state[0]) == 3:
return CNNCritic(layers, *data.get_action_state_size())
else:
return NNCritic(layers, *data.get_action_state_size())

def pick_action(self, x):
if self.training: self.action_model.eval()
with torch.no_grad():
action, x = super(DDPG, self).pick_action(x)
action = super(DDPG, self).pick_action(x)
if self.training: self.action_model.train()

if not self.env_was_discrete: action = np.clip(action, -1, 1)
return action, np.clip(x, -1, 1)
return np.clip(action, -1, 1)

def optimize(self):
"""
@@ -160,12 +195,11 @@ def optimize(self):
s_prime = torch.from_numpy(np.array([item.result_state for item in sampled])).float()
s = torch.from_numpy(np.array([item.current_state for item in sampled])).float()
a = torch.from_numpy(np.array([item.actions for item in sampled]).astype(float)).float()
if self.env_was_discrete: a = torch.from_numpy(np.array([item.raw_action for item in sampled]).astype(float)).float()

with torch.no_grad():
y = r + self.discount * self.t_critic_model(torch.cat((s_prime, self.t_action_model(s_prime)), 1))
y = r + self.discount * self.t_critic_model((s_prime, self.t_action_model(s_prime)))

y_hat = self.critic_model(torch.cat((s, a), 1))
y_hat = self.critic_model((s, a))

critic_loss = self.loss_func(y_hat, y)

@@ -175,7 +209,7 @@ def optimize(self):
critic_loss.backward()
self.critic_optimizer.step()

actor_loss = -self.critic_model(torch.cat((s, self.action_model(s)), 1)).mean()
actor_loss = -self.critic_model((s, self.action_model(s))).mean()

self.loss = critic_loss.cpu().detach()
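
The `optimize` hunk above follows the targets from the cited Lillicrap et al. paper: y = r + γ·Q′(s′, μ′(s′)) for the critic and −Q(s, μ(s)) as the actor loss. The τ soft update implied by the `tau=1e-3` argument is not visible in these hunks (the repo's `target_copy_over` appears only in a commented-out line in the callback near the top of the file). A generic sketch of that step, not necessarily this repo's implementation:

```python
import torch

def soft_update(target: torch.nn.Module, source: torch.nn.Module, tau: float) -> None:
    """Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target."""
    with torch.no_grad():
        for t_param, param in zip(target.parameters(), source.parameters()):
            t_param.mul_(1.0 - tau).add_(tau * param)
```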
