diff --git a/fastrl/_modidx.py b/fastrl/_modidx.py index 5717a15..0468a37 100644 --- a/fastrl/_modidx.py +++ b/fastrl/_modidx.py @@ -333,6 +333,12 @@ 'fastrl/agents/trpo.py'), 'fastrl.agents.trpo.AdvantageBuffer.zip_steps': ( '07_Agents/02_Continuous/agents.trpo.html#advantagebuffer.zip_steps', 'fastrl/agents/trpo.py'), + 'fastrl.agents.trpo.AdvantageFirstLastMerger': ( '07_Agents/02_Continuous/agents.trpo.html#advantagefirstlastmerger', + 'fastrl/agents/trpo.py'), + 'fastrl.agents.trpo.AdvantageFirstLastMerger.__init__': ( '07_Agents/02_Continuous/agents.trpo.html#advantagefirstlastmerger.__init__', + 'fastrl/agents/trpo.py'), + 'fastrl.agents.trpo.AdvantageFirstLastMerger.__iter__': ( '07_Agents/02_Continuous/agents.trpo.html#advantagefirstlastmerger.__iter__', + 'fastrl/agents/trpo.py'), 'fastrl.agents.trpo.AdvantageGymDataPipe': ( '07_Agents/02_Continuous/agents.trpo.html#advantagegymdatapipe', 'fastrl/agents/trpo.py'), 'fastrl.agents.trpo.AdvantageStep': ( '07_Agents/02_Continuous/agents.trpo.html#advantagestep', diff --git a/fastrl/agents/trpo.py b/fastrl/agents/trpo.py index 6fdc884..be738fd 100644 --- a/fastrl/agents/trpo.py +++ b/fastrl/agents/trpo.py @@ -2,10 +2,10 @@ # %% auto 0 __all__ = ['AdvantageStep', 'pipe2device', 'discounted_cumsum_', 'get_flat_params_from', 'set_flat_params_to', 'AdvantageBuffer', - 'AdvantageGymDataPipe', 'OptionalClampLinear', 'Actor', 'NormalExploration', 'ProbabilisticStdCollector', - 'ProbabilisticMeanCollector', 'TRPOAgent', 'conjugate_gradients', 'backtrack_line_search', 'actor_prob_loss', - 'pre_hessian_kl', 'auto_flat', 'forward_pass', 'CriticLossProcessor', 'ActorOptAndLossProcessor', - 'TRPOLearner'] + 'AdvantageFirstLastMerger', 'AdvantageGymDataPipe', 'OptionalClampLinear', 'Actor', 'NormalExploration', + 'ProbabilisticStdCollector', 'ProbabilisticMeanCollector', 'TRPOAgent', 'conjugate_gradients', + 'backtrack_line_search', 'actor_prob_loss', 'pre_hessian_kl', 'auto_flat', 'forward_pass', + 'CriticLossProcessor', 'ActorOptAndLossProcessor', 'TRPOLearner'] # %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 2 # Python native modules @@ -178,13 +178,15 @@ def __init__(self, # $\lambda$ is unqiue to GAE and manages importance to values when # they are in accurate is defined in (Shulman et al., 2016) as '... $\lambda$ < 1 # introduces bias only when the value function is inaccurate....'. 
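# --- Hedged sketch (not part of the patch) of the computation the AdvantageBuffer
# change above feeds into. In this module `discount` is the per-step discount and
# `gamma` appears to play the GAE lambda role; the new `nsteps` argument raises that
# factor to gamma**nsteps, presumably so the recurrence matches transitions that span
# several environment steps. The function below is illustrative only; fastrl's own
# `discounted_cumsum_` works in place on a tensor of TD deltas.
import torch

def gae_advantages(rewards: torch.Tensor,      # (T,) rewards
                   values: torch.Tensor,       # (T,) V(s_t)
                   next_values: torch.Tensor,  # (T,) V(s_{t+1})
                   dones: torch.Tensor,        # (T,) bool terminal flags
                   discount: float = 0.99,
                   gamma: float = 0.95,
                   nsteps: int = 1) -> torch.Tensor:
    "Reverse discounted cumulative sum of TD deltas (generalized advantage estimation)."
    deltas = rewards + discount * next_values * (~dones) - values
    factor = discount * gamma ** nsteps
    advantages = torch.zeros_like(deltas)
    running = torch.zeros(())
    for t in reversed(range(deltas.shape[0])):
        running = deltas[t] + factor * running
        advantages[t] = running
    return advantages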
- gamma:float=0.99 + gamma:float=0.99, + nsteps:int = 1 ): self.source_datapipe = source_datapipe self.bs = bs self.critic = critic self.device = None self.discount = discount + self.nsteps = nsteps self.gamma = gamma self.env_advantage_buffer:Dict[Literal['env'],list] = {} @@ -207,8 +209,7 @@ def update_advantage_buffer(self,step:StepTypes.types) -> int: return env_id def zip_steps( - self, - steps:List[Union[StepTypes.types]] + self,steps:List[Union[StepTypes.types]] ) -> Tuple[torch.FloatTensor,torch.FloatTensor,torch.BoolTensor]: step_subset = [(o.reward,o.state,o.truncated or o.terminated) for o in steps] zipped_fields = zip(*step_subset) @@ -231,7 +232,7 @@ def __iter__(self) -> AdvantageStep: with evaluating(self.critic): values = self.critic(torch.vstack((states,steps[-1].next_state))) delta = self.delta_calc(rewards,values[:-1],values[1:],dones) - discounted_cumsum_(delta,self.discount*self.gamma,reverse=True) + discounted_cumsum_(delta,self.discount*self.gamma**self.nsteps,reverse=True) for _step,gae_advantage,v in zip(*(steps,delta,values)): yield AdvantageStep( @@ -270,6 +271,60 @@ def __iter__(self) -> AdvantageStep: ) # %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 15 +class AdvantageFirstLastMerger(dp.iter.IterDataPipe): + def __init__(self, + source_datapipe, + gamma:float=0.99 + ): + self.source_datapipe = source_datapipe + self.gamma = gamma + + def __iter__(self) -> StepTypes.types: + self.env_buffer = {} + for steps in self.source_datapipe: + if not isinstance(steps,(list,tuple)): + raise ValueError(f'Expected {self.source_datapipe} to return a list/tuple of steps, however got {type(steps)}') + + if len(steps)==1: + yield steps[0] + continue + + fstep,lstep = steps[0],steps[-1] + + reward = fstep.reward + for step in steps[1:]: + reward *= self.gamma + reward += step.reward + + advantage = fstep.advantage + for step in steps[1:]: + advantage *= self.gamma + advantage += step.advantage + + next_advantage = fstep.next_advantage + for step in steps[1:]: + next_advantage *= self.gamma + next_advantage += step.next_advantage + + yield fstep.__class__( + state=fstep.state.clone().detach(), + next_state=lstep.next_state.clone().detach(), + action=fstep.action, + terminated=lstep.terminated, + truncated=lstep.truncated, + reward=reward, + total_reward=lstep.total_reward, + env_id=lstep.env_id, + proc_id=lstep.proc_id, + step_n=lstep.step_n, + episode_n=fstep.episode_n, + image=fstep.image, + raw_action=fstep.raw_action, + advantage=advantage, + next_advantage=next_advantage + ) + +# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 16 def AdvantageGymDataPipe( source, # AdvantageBuffer: Critic Module to valuate the advantage of state-action pairs @@ -317,19 +372,19 @@ def AdvantageGymDataPipe( include_images=include_images, terminate_on_truncation=terminate_on_truncation, synchronized_reset=synchronized_reset) - pipe = AdvantageBuffer(pipe,critic=critic,bs=adv_bs,discount=discount,gamma=gamma) + pipe = AdvantageBuffer(pipe,critic=critic,bs=adv_bs,discount=discount,gamma=gamma,nsteps=nsteps) if nskips!=1: pipe = NSkipper(pipe,n=nskips) if nsteps!=1: pipe = NStepper(pipe,n=nsteps) if firstlast: - pipe = FirstLastMerger(pipe) + pipe = AdvantageFirstLastMerger(pipe) else: pipe = NStepFlattener(pipe) # We dont want to flatten if using FirstLastMerger if n is not None: pipe = pipe.header(limit=n) pipe = pipe.batch(batch_size=bs) return pipe -# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 19 +# %% 
../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 20 class OptionalClampLinear(Module): def __init__(self,num_inputs,state_dims,fix_variance:bool=False, clip_min=0.3,clip_max=10.0): @@ -382,7 +437,7 @@ def forward(self,x): return Independent(Normal(self.mu(x),self.std),1) forward="Mean outputs from a parameterized Gaussian distribution." ) -# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 23 +# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 24 class NormalExploration(dp.iter.IterDataPipe): def __init__( self, @@ -422,7 +477,7 @@ def __iter__(self): else: yield action.mean -# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 25 +# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 26 class ProbabilisticStdCollector(dp.iter.IterDataPipe): title:str='std' def __init__(self, @@ -453,7 +508,7 @@ def __iter__(self): yield Record(self.title,self.record_pipe.last_mean.item()) yield action -# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 26 +# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 27 def TRPOAgent( model:Actor, # The actor to use for mapping states to actions # LoggerBases push logs to. If None, logs will be collected and output @@ -476,7 +531,7 @@ def TRPOAgent( agent = AgentHead(agent) return agent -# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 34 +# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 35 def conjugate_gradients( # A function that takes the direction `d` and applies it to `A`. # The simplest example of this found would be: @@ -549,7 +604,7 @@ def conjugate_gradients( """ ) -# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 36 +# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 37 def backtrack_line_search( # A Tensor of gradients or weights to optimize x:torch.Tensor, @@ -590,7 +645,7 @@ def backtrack_line_search( """ ) -# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 38 +# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 39 def actor_prob_loss(weights,s,a,r,actor,old_log_prob): if weights is not None: set_flat_params_to(actor,weights) @@ -600,7 +655,7 @@ def actor_prob_loss(weights,s,a,r,actor,old_log_prob): loss = -r.squeeze(1) * torch.exp(log_prob-old_log_prob) return loss.mean() -# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 40 +# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 41 def pre_hessian_kl( model:Actor, # An Actor or any model that outputs a probability distribution x:torch.Tensor # Input into the model @@ -642,7 +697,7 @@ def pre_hessian_kl( kl = logstd_v - logstd0_v + (std0_v ** 2 + (mu0_v - mu_v) ** 2) / (2.0 * std_v ** 2) - 0.5 return kl.sum(1, keepdim=True) -# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 42 +# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 43 def auto_flat(outputs,inputs,contiguous=False,create_graph=False)->torch.Tensor: "Calculates the gradients and flattens them into a single tensor" grads = torch.autograd.grad(outputs,inputs,create_graph=create_graph) @@ -670,7 +725,7 @@ def forward_pass( return flat_grad_grad_kl + weights * damping -# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 46 +# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 45 class CriticLossProcessor(dp.iter.IterDataPipe): debug:bool=False @@ -711,7 +766,7 @@ def __iter__(self) -> Union[Dict[Literal['loss'],torch.Tensor],SimpleStep]: yield {'loss':self.loss(pred,batch.next_advantage[m].to(dtype=torch.float32))} yield batch -# %% 
../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 47 +# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 46 class ActorOptAndLossProcessor(dp.iter.IterDataPipe): debug:bool=False @@ -782,7 +837,7 @@ def __iter__(self) -> Union[Dict[Literal['loss'],torch.Tensor],SimpleStep]: yield {'loss':loss} yield batch -# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 48 +# %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 47 def TRPOLearner( # The actor model to use actor:Actor, diff --git a/fastrl/loggers/core.py b/fastrl/loggers/core.py index d49806a..316c62e 100644 --- a/fastrl/loggers/core.py +++ b/fastrl/loggers/core.py @@ -62,7 +62,7 @@ def __init__( # Max size of _RECORD_CATCH_LIST before raising in exception. # Important to avoid memory leaks, and indicates that `dump_records` # is not being called or used. - buffer_size=1000, + buffer_size=10000, # If True, instead of appending to _RECORD_CATCH_LIST, # drop the record so it does not continue thorugh the # pipeline. diff --git a/fastrl/memory/memory_visualizer.py b/fastrl/memory/memory_visualizer.py index ade6d7e..bd985bb 100644 --- a/fastrl/memory/memory_visualizer.py +++ b/fastrl/memory/memory_visualizer.py @@ -17,11 +17,12 @@ # %% ../../nbs/04_Memory/01_memory_visualizer.ipynb 4 class MemoryBufferViewer: - def __init__(self, memory, agent=None): + def __init__(self, memory, agent=None, ignore_image:bool=False): # Assuming memory contains SimpleStep instances or None self.memory = memory self.agent = agent self.current_index = 0 + self.ignore_image = ignore_image # Add a label for displaying the number of elements in memory self.memory_size_label = Label(value=f"Number of Elements in Memory: {len([x for x in memory if x is not None])}") @@ -136,7 +137,7 @@ def show_current(self): details_display = VBox(details_list) # If the image is present, prepare left-side content - if torch.is_tensor(step.image) and step.image.nelement() > 1: + if torch.is_tensor(step.image) and step.image.nelement() > 1 and not self.ignore_image: pil_image = self.tensor_to_pil(step.image) img_display = widgets.Image(value=self.pil_image_to_byte_array(pil_image), format='jpeg') display_content = HBox([img_display, details_display]) diff --git a/nbs/04_Memory/01_memory_visualizer.ipynb b/nbs/04_Memory/01_memory_visualizer.ipynb index b9fe3ab..5d089e5 100644 --- a/nbs/04_Memory/01_memory_visualizer.ipynb +++ b/nbs/04_Memory/01_memory_visualizer.ipynb @@ -61,11 +61,12 @@ "source": [ "#|export\n", "class MemoryBufferViewer:\n", - " def __init__(self, memory, agent=None):\n", + " def __init__(self, memory, agent=None, ignore_image:bool=False):\n", " # Assuming memory contains SimpleStep instances or None\n", " self.memory = memory\n", " self.agent = agent\n", " self.current_index = 0\n", + " self.ignore_image = ignore_image\n", " # Add a label for displaying the number of elements in memory\n", " self.memory_size_label = Label(value=f\"Number of Elements in Memory: {len([x for x in memory if x is not None])}\")\n", "\n", @@ -180,7 +181,7 @@ " details_display = VBox(details_list)\n", "\n", " # If the image is present, prepare left-side content\n", - " if torch.is_tensor(step.image) and step.image.nelement() > 1:\n", + " if torch.is_tensor(step.image) and step.image.nelement() > 1 and not self.ignore_image:\n", " pil_image = self.tensor_to_pil(step.image)\n", " img_display = widgets.Image(value=self.pil_image_to_byte_array(pil_image), format='jpeg')\n", " display_content = HBox([img_display, details_display])\n", diff --git 
a/nbs/05_Logging/09a_loggers.core.ipynb b/nbs/05_Logging/09a_loggers.core.ipynb index 32ca0f6..edca7c4 100644 --- a/nbs/05_Logging/09a_loggers.core.ipynb +++ b/nbs/05_Logging/09a_loggers.core.ipynb @@ -146,7 +146,7 @@ " # Max size of _RECORD_CATCH_LIST before raising in exception.\n", " # Important to avoid memory leaks, and indicates that `dump_records`\n", " # is not being called or used.\n", - " buffer_size=1000,\n", + " buffer_size=10000,\n", " # If True, instead of appending to _RECORD_CATCH_LIST, \n", " # drop the record so it does not continue thorugh the \n", " # pipeline.\n", diff --git a/nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb b/nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb index d4dec23..382a547 100644 --- a/nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb +++ b/nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb @@ -313,13 +313,15 @@ " # $\\lambda$ is unqiue to GAE and manages importance to values when \n", " # they are in accurate is defined in (Shulman et al., 2016) as '... $\\lambda$ < 1\n", " # introduces bias only when the value function is inaccurate....'.\n", - " gamma:float=0.99\n", + " gamma:float=0.99,\n", + " nsteps:int = 1\n", " ):\n", " self.source_datapipe = source_datapipe\n", " self.bs = bs\n", " self.critic = critic\n", " self.device = None\n", " self.discount = discount\n", + " self.nsteps = nsteps\n", " self.gamma = gamma\n", " self.env_advantage_buffer:Dict[Literal['env'],list] = {}\n", "\n", @@ -342,8 +344,7 @@ " return env_id\n", " \n", " def zip_steps(\n", - " self,\n", - " steps:List[Union[StepTypes.types]]\n", + " self,steps:List[Union[StepTypes.types]]\n", " ) -> Tuple[torch.FloatTensor,torch.FloatTensor,torch.BoolTensor]:\n", " step_subset = [(o.reward,o.state,o.truncated or o.terminated) for o in steps]\n", " zipped_fields = zip(*step_subset)\n", @@ -366,7 +367,7 @@ " with evaluating(self.critic):\n", " values = self.critic(torch.vstack((states,steps[-1].next_state)))\n", " delta = self.delta_calc(rewards,values[:-1],values[1:],dones)\n", - " discounted_cumsum_(delta,self.discount*self.gamma,reverse=True)\n", + " discounted_cumsum_(delta,self.discount*self.gamma**self.nsteps,reverse=True)\n", "\n", " for _step,gae_advantage,v in zip(*(steps,delta,values)):\n", " yield AdvantageStep(\n", @@ -405,6 +406,68 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "7971aca9", + "metadata": {}, + "outputs": [], + "source": [ + "#|export\n", + "class AdvantageFirstLastMerger(dp.iter.IterDataPipe):\n", + " def __init__(self, \n", + " source_datapipe, \n", + " gamma:float=0.99\n", + " ):\n", + " self.source_datapipe = source_datapipe\n", + " self.gamma = gamma\n", + " \n", + " def __iter__(self) -> StepTypes.types:\n", + " self.env_buffer = {}\n", + " for steps in self.source_datapipe:\n", + " if not isinstance(steps,(list,tuple)):\n", + " raise ValueError(f'Expected {self.source_datapipe} to return a list/tuple of steps, however got {type(steps)}')\n", + " \n", + " if len(steps)==1:\n", + " yield steps[0]\n", + " continue\n", + " \n", + " fstep,lstep = steps[0],steps[-1]\n", + " \n", + " reward = fstep.reward\n", + " for step in steps[1:]:\n", + " reward *= self.gamma\n", + " reward += step.reward\n", + "\n", + " advantage = fstep.advantage\n", + " for step in steps[1:]:\n", + " advantage *= self.gamma\n", + " advantage += step.advantage\n", + "\n", + " next_advantage = fstep.next_advantage\n", + " for step in steps[1:]:\n", + " next_advantage *= self.gamma\n", + " next_advantage += step.next_advantage\n", + " \n", + " yield 
fstep.__class__(\n", + " state=fstep.state.clone().detach(),\n", + " next_state=lstep.next_state.clone().detach(),\n", + " action=fstep.action,\n", + " terminated=lstep.terminated,\n", + " truncated=lstep.truncated,\n", + " reward=reward,\n", + " total_reward=lstep.total_reward,\n", + " env_id=lstep.env_id,\n", + " proc_id=lstep.proc_id,\n", + " step_n=lstep.step_n,\n", + " episode_n=fstep.episode_n,\n", + " image=fstep.image,\n", + " raw_action=fstep.raw_action,\n", + " advantage=advantage,\n", + " next_advantage=next_advantage\n", + " )" + ] + }, { "cell_type": "code", "execution_count": null, @@ -460,12 +523,12 @@ " include_images=include_images,\n", " terminate_on_truncation=terminate_on_truncation,\n", " synchronized_reset=synchronized_reset)\n", - " pipe = AdvantageBuffer(pipe,critic=critic,bs=adv_bs,discount=discount,gamma=gamma)\n", + " pipe = AdvantageBuffer(pipe,critic=critic,bs=adv_bs,discount=discount,gamma=gamma,nsteps=nsteps)\n", " if nskips!=1: pipe = NSkipper(pipe,n=nskips)\n", " if nsteps!=1:\n", " pipe = NStepper(pipe,n=nsteps)\n", " if firstlast:\n", - " pipe = FirstLastMerger(pipe)\n", + " pipe = AdvantageFirstLastMerger(pipe)\n", " else:\n", " pipe = NStepFlattener(pipe) # We dont want to flatten if using FirstLastMerger\n", " if n is not None: pipe = pipe.header(limit=n)\n", @@ -1118,22 +1181,6 @@ ")" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "260e2ccf", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "83668f11", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, @@ -1350,16 +1397,19 @@ "\n", "# Setup up the core NN\n", "torch.manual_seed(0)\n", - "actor = Actor(3,1)\n", - "critic = Critic(3)\n", + "actor = Actor(2,1)\n", + "critic = Critic(2)\n", "\n", "# Setup the Agent\n", - "agent = TRPOAgent(actor,do_logging=True,clip_min=-2,clip_max=2)\n", + "# agent = TRPOAgent(actor,do_logging=True,clip_min=-2,clip_max=2)\n", + "agent = TRPOAgent(actor,do_logging=True,clip_min=-1,clip_max=1)\n", "\n", "# Setup the Dataloaders\n", "dls = dataloaders((\n", - " AdvantageGymDataPipe(['Pendulum-v1']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=200,gamma=0.95,discount=0.99),\n", - " AdvantageGymDataPipe(['Pendulum-v1']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=1,gamma=0.95,discount=0.99)\n", + " # AdvantageGymDataPipe(['Pendulum-v1']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=200,gamma=0.99,discount=0.99),MountainCarContinuous-v0\n", + " # AdvantageGymDataPipe(['Pendulum-v1']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=1,gamma=0.99,discount=0.99)\n", + " AdvantageGymDataPipe(['MountainCarContinuous-v0']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=200,gamma=0.99,discount=0.99),\n", + " AdvantageGymDataPipe(['MountainCarContinuous-v0']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=1,gamma=0.99,discount=0.99)\n", "))\n", "# Setup the Learner\n", "learner = TRPOLearner(actor,critic,dls,logger_bases=logger_bases,\n", @@ -1368,6 +1418,120 @@ "learner.fit(10)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "604dddf6", + "metadata": {}, + "outputs": [], + "source": [ + "val_agent = TRPOAgent(actor,do_logging=True,clip_min=-1,clip_max=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4ab25a1", + "metadata": {}, + "outputs": [], + "source": [ + "# valid_pipe = 
AdvantageGymDataPipe(['Pendulum-v1']*1,agent=val_agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=1,gamma=0.99,discount=0.99,n=300,include_images=True)\n", + "valid_pipe = AdvantageGymDataPipe(['MountainCarContinuous-v0']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=1,gamma=0.99,discount=0.99,n=1000,include_images=True)\n", + "valid_pipe = VSCodeDataPipe(valid_pipe)\n", + "# sample_run = [o[0] for o in valid_pipe.dump_records().catch_records(drop=True)];\n", + "list(valid_pipe);\n", + "valid_pipe.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51e85344", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "def normalize(data):\n", + " mean = sum(data) / len(data)\n", + " variance = sum([(x - mean) ** 2 for x in data]) / len(data)\n", + " stddev = variance ** 0.5\n", + " return [(x - mean) / stddev for x in data]\n", + "\n", + "\n", + "def visualize_advantage_steps(steps: [AdvantageStep]):\n", + " # Extract relevant data from steps\n", + " rewards = [step.reward.item() for step in steps]\n", + " advantages = [step.advantage.item() for step in steps]\n", + " next_advantages = [step.next_advantage.item() for step in steps]\n", + " action = [step.action.item() for step in steps]\n", + " critic_values = [na - a for na, a in zip(next_advantages, advantages)]\n", + "\n", + " # Normalize the data\n", + " rewards = normalize(rewards)\n", + " advantages = normalize(advantages)\n", + " action = normalize(action)\n", + " next_advantages = normalize(next_advantages)\n", + " critic_values = normalize(critic_values)\n", + "\n", + "\n", + " # Plot the data\n", + " fig, ax = plt.subplots(figsize=(12, 6))\n", + "\n", + " # ax.plot(rewards, label=\"Rewards\", color=\"blue\")\n", + " ax.plot(advantages, label=\"Advantages\", color=\"green\")\n", + " ax.plot(next_advantages, label=\"Next Advantages\", color=\"red\")\n", + " ax.plot(critic_values, label=\"Critic Value Estimates\", color=\"purple\")\n", + " # ax.plot(action, label=\"Action Value Estimates\", color=\"orange\")\n", + " \n", + " ax.set_xlabel(\"Steps\")\n", + " ax.set_ylabel(\"Value\")\n", + " ax.set_title(\"Visualization of Advantage Steps\")\n", + " ax.legend()\n", + "\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5175cd7", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_advantage_steps(sample_run)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fae186a4", + "metadata": {}, + "outputs": [], + "source": [ + "from fastrl.memory.memory_visualizer import MemoryBufferViewer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e7360cd", + "metadata": {}, + "outputs": [], + "source": [ + "sample_run[0].advantage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07c59d61", + "metadata": {}, + "outputs": [], + "source": [ + "MemoryBufferViewer(sample_run,val_agent)" + ] + }, { "cell_type": "code", "execution_count": null,
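# --- Hedged sketch (not part of the patch) of the reduction the new
# AdvantageFirstLastMerger performs on each n-step window from NStepper: it keeps the
# first step's state/action, the last step's next_state and termination flags, and folds
# reward, advantage and next_advantage with the recurrence acc = acc*gamma + v.
# `fold_with_gamma` is a hypothetical helper name used only for illustration.
def fold_with_gamma(values, gamma: float = 0.99):
    "Left-to-right fold used for the merge: acc <- acc*gamma + v."
    acc = values[0]
    for v in values[1:]:
        acc = acc * gamma + v
    return acc

# For a 3-step window this yields (v0*gamma + v1)*gamma + v2, matching the three
# accumulation loops in the AdvantageFirstLastMerger class added by this patch.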