TRPO converges on mountain car!
josiahls committed Oct 31, 2023
1 parent e2fab85 commit 6b1be2c
Showing 7 changed files with 282 additions and 55 deletions.
6 changes: 6 additions & 0 deletions fastrl/_modidx.py
@@ -333,6 +333,12 @@
'fastrl/agents/trpo.py'),
'fastrl.agents.trpo.AdvantageBuffer.zip_steps': ( '07_Agents/02_Continuous/agents.trpo.html#advantagebuffer.zip_steps',
'fastrl/agents/trpo.py'),
'fastrl.agents.trpo.AdvantageFirstLastMerger': ( '07_Agents/02_Continuous/agents.trpo.html#advantagefirstlastmerger',
'fastrl/agents/trpo.py'),
'fastrl.agents.trpo.AdvantageFirstLastMerger.__init__': ( '07_Agents/02_Continuous/agents.trpo.html#advantagefirstlastmerger.__init__',
'fastrl/agents/trpo.py'),
'fastrl.agents.trpo.AdvantageFirstLastMerger.__iter__': ( '07_Agents/02_Continuous/agents.trpo.html#advantagefirstlastmerger.__iter__',
'fastrl/agents/trpo.py'),
'fastrl.agents.trpo.AdvantageGymDataPipe': ( '07_Agents/02_Continuous/agents.trpo.html#advantagegymdatapipe',
'fastrl/agents/trpo.py'),
'fastrl.agents.trpo.AdvantageStep': ( '07_Agents/02_Continuous/agents.trpo.html#advantagestep',
99 changes: 77 additions & 22 deletions fastrl/agents/trpo.py
@@ -2,10 +2,10 @@

# %% auto 0
__all__ = ['AdvantageStep', 'pipe2device', 'discounted_cumsum_', 'get_flat_params_from', 'set_flat_params_to', 'AdvantageBuffer',
'AdvantageGymDataPipe', 'OptionalClampLinear', 'Actor', 'NormalExploration', 'ProbabilisticStdCollector',
'ProbabilisticMeanCollector', 'TRPOAgent', 'conjugate_gradients', 'backtrack_line_search', 'actor_prob_loss',
'pre_hessian_kl', 'auto_flat', 'forward_pass', 'CriticLossProcessor', 'ActorOptAndLossProcessor',
'TRPOLearner']
'AdvantageFirstLastMerger', 'AdvantageGymDataPipe', 'OptionalClampLinear', 'Actor', 'NormalExploration',
'ProbabilisticStdCollector', 'ProbabilisticMeanCollector', 'TRPOAgent', 'conjugate_gradients',
'backtrack_line_search', 'actor_prob_loss', 'pre_hessian_kl', 'auto_flat', 'forward_pass',
'CriticLossProcessor', 'ActorOptAndLossProcessor', 'TRPOLearner']

# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 2
# Python native modules
@@ -178,13 +178,15 @@ def __init__(self,
# $\lambda$ is unique to GAE and manages the importance given to values when
# they are inaccurate. It is defined in (Schulman et al., 2016) as '... $\lambda$ < 1
# introduces bias only when the value function is inaccurate....'.
gamma:float=0.99
gamma:float=0.99,
nsteps:int = 1
):
self.source_datapipe = source_datapipe
self.bs = bs
self.critic = critic
self.device = None
self.discount = discount
self.nsteps = nsteps
self.gamma = gamma
self.env_advantage_buffer:Dict[Literal['env'],list] = {}

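As background for the `discount`/`gamma`/`nsteps` arguments above: the buffer computes generalized advantage estimates, i.e. a discounted sum of TD residuals. A minimal, illustrative sketch of that calculation (textbook GAE with assumed names and done-handling, not fastrl's `delta_calc`):

```python
import torch

def gae_sketch(rewards, values, next_values, dones, gamma=0.99, lam=0.95):
    "Textbook GAE: a (gamma * lam)-discounted sum of TD residuals."
    not_done = (~dones).float()
    # TD residual: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    delta = rewards + gamma * next_values * not_done - values
    advantage = torch.zeros_like(delta)
    running = torch.tensor(0.0)
    for t in reversed(range(len(delta))):   # accumulate from the last step backwards
        running = delta[t] + gamma * lam * not_done[t] * running
        advantage[t] = running
    return advantage
```

In `AdvantageBuffer`, the product passed to `discounted_cumsum_` plays the role of `gamma * lam` here.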
@@ -207,8 +209,7 @@ def update_advantage_buffer(self,step:StepTypes.types) -> int:
return env_id

def zip_steps(
self,
steps:List[Union[StepTypes.types]]
self,steps:List[Union[StepTypes.types]]
) -> Tuple[torch.FloatTensor,torch.FloatTensor,torch.BoolTensor]:
step_subset = [(o.reward,o.state,o.truncated or o.terminated) for o in steps]
zipped_fields = zip(*step_subset)
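The `zip_steps` change above is only a signature reformat; the zip itself transposes a list of per-step tuples into per-field sequences that can be turned into tensors. A tiny standalone illustration (the values are made up):

```python
import torch

steps = [(1.0, torch.zeros(2), False),
         (0.5, torch.ones(2),  True)]         # (reward, state, done) per step
rewards, states, dones = zip(*steps)          # transpose: one tuple per field
rewards = torch.tensor(rewards)               # FloatTensor, shape [2]
states  = torch.vstack(states)                # FloatTensor, shape [2, 2]
dones   = torch.tensor(dones)                 # BoolTensor,  shape [2]
```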
@@ -231,7 +232,7 @@ def __iter__(self) -> AdvantageStep:
with evaluating(self.critic):
values = self.critic(torch.vstack((states,steps[-1].next_state)))
delta = self.delta_calc(rewards,values[:-1],values[1:],dones)
discounted_cumsum_(delta,self.discount*self.gamma,reverse=True)
discounted_cumsum_(delta,self.discount*self.gamma**self.nsteps,reverse=True)

for _step,gae_advantage,v in zip(*(steps,delta,values)):
yield AdvantageStep(
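`discounted_cumsum_` above now decays by `discount * gamma ** nsteps` rather than `discount * gamma`, presumably so the advantage decay stays consistent once transitions are later merged over `nsteps` environment steps downstream. For readers unfamiliar with the helper, a minimal sketch of a reverse discounted cumulative sum (an out-of-place stand-in, assumed to mirror what `discounted_cumsum_(..., reverse=True)` does in place):

```python
import torch

def discounted_cumsum_sketch(x: torch.Tensor, coef: float) -> torch.Tensor:
    "out[t] = x[t] + coef * out[t + 1], accumulated from the end backwards."
    out = x.clone()
    for t in range(len(out) - 2, -1, -1):
        out[t] = out[t] + coef * out[t + 1]
    return out

discounted_cumsum_sketch(torch.tensor([1.0, 1.0, 1.0]), 0.9)
# tensor([2.7100, 1.9000, 1.0000])
```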
@@ -270,6 +271,60 @@ def __iter__(self) -> AdvantageStep:
)

# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 15
class AdvantageFirstLastMerger(dp.iter.IterDataPipe):
def __init__(self,
source_datapipe,
gamma:float=0.99
):
self.source_datapipe = source_datapipe
self.gamma = gamma

def __iter__(self) -> StepTypes.types:
self.env_buffer = {}
for steps in self.source_datapipe:
if not isinstance(steps,(list,tuple)):
raise ValueError(f'Expected {self.source_datapipe} to return a list/tuple of steps, however got {type(steps)}')

if len(steps)==1:
yield steps[0]
continue

fstep,lstep = steps[0],steps[-1]

reward = fstep.reward
for step in steps[1:]:
reward *= self.gamma
reward += step.reward

advantage = fstep.advantage
for step in steps[1:]:
advantage *= self.gamma
advantage += step.advantage

next_advantage = fstep.next_advantage
for step in steps[1:]:
next_advantage *= self.gamma
next_advantage += step.next_advantage

yield fstep.__class__(
state=fstep.state.clone().detach(),
next_state=lstep.next_state.clone().detach(),
action=fstep.action,
terminated=lstep.terminated,
truncated=lstep.truncated,
reward=reward,
total_reward=lstep.total_reward,
env_id=lstep.env_id,
proc_id=lstep.proc_id,
step_n=lstep.step_n,
episode_n=fstep.episode_n,
image=fstep.image,
raw_action=fstep.raw_action,
advantage=advantage,
next_advantage=next_advantage
)

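As background for the merger above: a first/last merge collapses an n-step window into a single transition whose reward stands in for the discounted return of that window, with the window's last `next_state` carried forward. In textbook form that return is $G = r_0 + \gamma r_1 + \dots + \gamma^{n-1} r_{n-1}$; a minimal illustrative fold, shown for reference and independent of the class's exact accumulation:

```python
def nstep_return(rewards, gamma=0.99):
    "Textbook n-step discounted return, folded from the last reward backwards."
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g

nstep_return([0.0, 0.0, 1.0], gamma=0.99)   # 0.9801
```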
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 16
def AdvantageGymDataPipe(
source,
# AdvantageBuffer: Critic Module to valuate the advantage of state-action pairs
@@ -317,19 +372,19 @@ def AdvantageGymDataPipe(
include_images=include_images,
terminate_on_truncation=terminate_on_truncation,
synchronized_reset=synchronized_reset)
pipe = AdvantageBuffer(pipe,critic=critic,bs=adv_bs,discount=discount,gamma=gamma)
pipe = AdvantageBuffer(pipe,critic=critic,bs=adv_bs,discount=discount,gamma=gamma,nsteps=nsteps)
if nskips!=1: pipe = NSkipper(pipe,n=nskips)
if nsteps!=1:
pipe = NStepper(pipe,n=nsteps)
if firstlast:
pipe = FirstLastMerger(pipe)
pipe = AdvantageFirstLastMerger(pipe)
else:
pipe = NStepFlattener(pipe) # We don't want to flatten if using FirstLastMerger
if n is not None: pipe = pipe.header(limit=n)
pipe = pipe.batch(batch_size=bs)
return pipe

# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 19
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 20
class OptionalClampLinear(Module):
def __init__(self,num_inputs,state_dims,fix_variance:bool=False,
clip_min=0.3,clip_max=10.0):
@@ -382,7 +437,7 @@ def forward(self,x): return Independent(Normal(self.mu(x),self.std),1)
forward="Mean outputs from a parameterized Gaussian distribution."
)

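The `Actor.forward` shown above returns a `torch.distributions` object rather than raw tensors, so downstream pipes can sample actions and query log-probabilities directly. A small standalone illustration of the `Independent(Normal(...), 1)` pattern (shapes are arbitrary):

```python
import torch
from torch.distributions import Independent, Normal

mu  = torch.zeros(4, 2)                   # batch of 4 states, 2 action dims
std = torch.full((4, 2), 0.5)
dist = Independent(Normal(mu, std), 1)    # treat the last dim as one event

action = dist.sample()                    # shape [4, 2]
logp   = dist.log_prob(action)            # shape [4]: one log-prob per sample
```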
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 23
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 24
class NormalExploration(dp.iter.IterDataPipe):
def __init__(
self,
@@ -422,7 +477,7 @@ def __iter__(self):
else:
yield action.mean

# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 25
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 26
class ProbabilisticStdCollector(dp.iter.IterDataPipe):
title:str='std'
def __init__(self,
@@ -453,7 +508,7 @@ def __iter__(self):
yield Record(self.title,self.record_pipe.last_mean.item())
yield action

# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 26
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 27
def TRPOAgent(
model:Actor, # The actor to use for mapping states to actions
# LoggerBases push logs to. If None, logs will be collected and output
@@ -476,7 +531,7 @@ def TRPOAgent(
agent = AgentHead(agent)
return agent

# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 34
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 35
def conjugate_gradients(
# A function that takes the direction `d` and applies it to `A`.
# The simplest example of this found would be:
@@ -549,7 +604,7 @@ def conjugate_gradients(
"""
)

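`conjugate_gradients` is used to solve $H d = g$ for the TRPO step direction without ever materializing the Fisher matrix $H$; it only needs a callable that returns $H v$ for a given $v$. A minimal sketch of the standard loop (assumed to be close to the library function; the tolerance and iteration count are illustrative):

```python
import torch

def conjugate_gradients_sketch(Avp, b, n_iters=10, tol=1e-10):
    "Solve A x = b given only a matrix-vector product Avp(v) = A @ v."
    x = torch.zeros_like(b)
    r = b.clone()                      # residual; x starts at zero, so r = b
    p = b.clone()                      # initial search direction
    rdotr = torch.dot(r, r)
    for _ in range(n_iters):
        Ap = Avp(p)
        alpha = rdotr / torch.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        new_rdotr = torch.dot(r, r)
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x
```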
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 36
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 37
def backtrack_line_search(
# A Tensor of gradients or weights to optimize
x:torch.Tensor,
@@ -590,7 +645,7 @@ def backtrack_line_search(
"""
)

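`backtrack_line_search` guards the update: the full natural-gradient step is shrunk geometrically until the surrogate loss actually improves by an acceptable fraction of the expected first-order improvement. A minimal sketch under those assumptions (the real signature above may differ; `expected_improve_rate` is typically the dot product of the gradient with the full step):

```python
def backtrack_line_search_sketch(x, full_step, loss_fn, expected_improve_rate,
                                 max_backtracks=10, accept_ratio=0.1):
    "Shrink full_step geometrically until loss_fn improves enough, else reject."
    f0 = loss_fn(x)
    for k in range(max_backtracks):
        step_frac = 0.5 ** k
        x_new = x + step_frac * full_step
        actual_improve = f0 - loss_fn(x_new)
        expected_improve = expected_improve_rate * step_frac
        if actual_improve > 0 and actual_improve / expected_improve > accept_ratio:
            return True, x_new
    return False, x
```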
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 38
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 39
def actor_prob_loss(weights,s,a,r,actor,old_log_prob):
if weights is not None:
set_flat_params_to(actor,weights)
@@ -600,7 +655,7 @@ def actor_prob_loss(weights,s,a,r,actor,old_log_prob):
loss = -r.squeeze(1) * torch.exp(log_prob-old_log_prob)
return loss.mean()

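`actor_prob_loss` above is the (negated) TRPO surrogate objective: the advantage weighted by the probability ratio between the current policy and the policy that collected the data,

$$L(\theta) = -\,\mathbb{E}\Big[ A_t \, \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\mathrm{old}}(a_t \mid s_t)} \Big],$$

with the ratio computed in log space as `exp(log_prob - old_log_prob)` for numerical stability. (Reading the surrounding pipeline, the `r` argument appears to carry the advantage rather than a raw reward; that is an inference, not something this hunk states.)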
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 40
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 41
def pre_hessian_kl(
model:Actor, # An Actor or any model that outputs a probability distribution
x:torch.Tensor # Input into the model
@@ -642,7 +697,7 @@ def pre_hessian_kl(
kl = logstd_v - logstd0_v + (std0_v ** 2 + (mu0_v - mu_v) ** 2) / (2.0 * std_v ** 2) - 0.5
return kl.sum(1, keepdim=True)

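The two `kl` lines above are the closed-form KL divergence between diagonal Gaussians, computed per action dimension and then summed. With one distribution $\mathcal{N}(\mu_0, \sigma_0^2)$ (the policy held fixed in TRPO) and the other $\mathcal{N}(\mu, \sigma^2)$:

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu_0,\sigma_0^2)\,\|\,\mathcal{N}(\mu,\sigma^2)\big) = \log\frac{\sigma}{\sigma_0} + \frac{\sigma_0^2 + (\mu_0-\mu)^2}{2\sigma^2} - \frac{1}{2}.$$

This is the scalar that gets differentiated twice to form the Fisher-vector products consumed by the conjugate gradient solver.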
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 42
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 43
def auto_flat(outputs,inputs,contiguous=False,create_graph=False)->torch.Tensor:
"Calculates the gradients and flattens them into a single tensor"
grads = torch.autograd.grad(outputs,inputs,create_graph=create_graph)
@@ -670,7 +725,7 @@ def forward_pass(

return flat_grad_grad_kl + weights * damping

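`forward_pass` above returns a damped Hessian-vector product of the KL, which is exactly the `Avp` callable that the conjugate gradient solver needs. A minimal sketch of the double-backprop trick (illustrative names; `kl` is assumed to already be a scalar such as the summed output of `pre_hessian_kl`):

```python
import torch

def hvp_sketch(kl, params, v, damping=0.1):
    "Return (H @ v) + damping * v, where H is the Hessian of `kl` w.r.t. params."
    grads = torch.autograd.grad(kl, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    grad_v = (flat_grad * v).sum()                 # scalar: (grad kl) . v
    grads2 = torch.autograd.grad(grad_v, params)   # second backprop gives H @ v
    flat_hv = torch.cat([g.reshape(-1) for g in grads2])
    return flat_hv + damping * v                   # damping keeps CG well-conditioned
```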
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 46
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 45
class CriticLossProcessor(dp.iter.IterDataPipe):
debug:bool=False

@@ -711,7 +766,7 @@ def __iter__(self) -> Union[Dict[Literal['loss'],torch.Tensor],SimpleStep]:
yield {'loss':self.loss(pred,batch.next_advantage[m].to(dtype=torch.float32))}
yield batch

# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 47
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 46
class ActorOptAndLossProcessor(dp.iter.IterDataPipe):
debug:bool=False

@@ -782,7 +837,7 @@ def __iter__(self) -> Union[Dict[Literal['loss'],torch.Tensor],SimpleStep]:
yield {'loss':loss}
yield batch

# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 48
# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 47
def TRPOLearner(
# The actor model to use
actor:Actor,
2 changes: 1 addition & 1 deletion fastrl/loggers/core.py
@@ -62,7 +62,7 @@ def __init__(
# Max size of _RECORD_CATCH_LIST before raising an exception.
# Important to avoid memory leaks, and indicates that `dump_records`
# is not being called or used.
buffer_size=1000,
buffer_size=10000,
# If True, instead of appending to _RECORD_CATCH_LIST,
# drop the record so it does not continue through the
# pipeline.
5 changes: 3 additions & 2 deletions fastrl/memory/memory_visualizer.py
@@ -17,11 +17,12 @@

# %% ../../nbs/04_Memory/01_memory_visualizer.ipynb 4
class MemoryBufferViewer:
def __init__(self, memory, agent=None):
def __init__(self, memory, agent=None, ignore_image:bool=False):
# Assuming memory contains SimpleStep instances or None
self.memory = memory
self.agent = agent
self.current_index = 0
self.ignore_image = ignore_image
# Add a label for displaying the number of elements in memory
self.memory_size_label = Label(value=f"Number of Elements in Memory: {len([x for x in memory if x is not None])}")

@@ -136,7 +137,7 @@ def show_current(self):
details_display = VBox(details_list)

# If the image is present, prepare left-side content
if torch.is_tensor(step.image) and step.image.nelement() > 1:
if torch.is_tensor(step.image) and step.image.nelement() > 1 and not self.ignore_image:
pil_image = self.tensor_to_pil(step.image)
img_display = widgets.Image(value=self.pil_image_to_byte_array(pil_image), format='jpeg')
display_content = HBox([img_display, details_display])
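With the new `ignore_image` flag added above, `MemoryBufferViewer` can skip image rendering even when steps carry image tensors, which is handy for large replay buffers. A minimal usage sketch (`memory` and `agent` are assumed to exist already):

```python
# Hypothetical usage; `memory` and `agent` are built elsewhere.
viewer = MemoryBufferViewer(memory, agent=agent, ignore_image=True)
```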
5 changes: 3 additions & 2 deletions nbs/04_Memory/01_memory_visualizer.ipynb
@@ -61,11 +61,12 @@
"source": [
"#|export\n",
"class MemoryBufferViewer:\n",
" def __init__(self, memory, agent=None):\n",
" def __init__(self, memory, agent=None, ignore_image:bool=False):\n",
" # Assuming memory contains SimpleStep instances or None\n",
" self.memory = memory\n",
" self.agent = agent\n",
" self.current_index = 0\n",
" self.ignore_image = ignore_image\n",
" # Add a label for displaying the number of elements in memory\n",
" self.memory_size_label = Label(value=f\"Number of Elements in Memory: {len([x for x in memory if x is not None])}\")\n",
"\n",
@@ -180,7 +181,7 @@
" details_display = VBox(details_list)\n",
"\n",
" # If the image is present, prepare left-side content\n",
" if torch.is_tensor(step.image) and step.image.nelement() > 1:\n",
" if torch.is_tensor(step.image) and step.image.nelement() > 1 and not self.ignore_image:\n",
" pil_image = self.tensor_to_pil(step.image)\n",
" img_display = widgets.Image(value=self.pil_image_to_byte_array(pil_image), format='jpeg')\n",
" display_content = HBox([img_display, details_display])\n",
2 changes: 1 addition & 1 deletion nbs/05_Logging/09a_loggers.core.ipynb
@@ -146,7 +146,7 @@
" # Max size of _RECORD_CATCH_LIST before raising in exception.\n",
" # Important to avoid memory leaks, and indicates that `dump_records`\n",
" # is not being called or used.\n",
" buffer_size=1000,\n",
" buffer_size=10000,\n",
" # If True, instead of appending to _RECORD_CATCH_LIST, \n",
" # drop the record so it does not continue thorugh the \n",
" # pipeline.\n",