Skip to content

Commit

Permalink
Updated:
Browse files Browse the repository at this point in the history
- firstlast merger now takes a merging function to support different
step merging operations. Might just rename the FirstLastMerger to a StepMerger,
but idk.
  • Loading branch information
josiahls committed Nov 4, 2023
1 parent d8f8786 commit d09b40a
Show file tree
Hide file tree
Showing 10 changed files with 243 additions and 206 deletions.
6 changes: 5 additions & 1 deletion fastrl/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,9 +422,11 @@
'fastrl.core.StepTypeRegistry.is_registered': ('core.html#steptyperegistry.is_registered', 'fastrl/core.py'),
'fastrl.core.StepTypeRegistry.register': ('core.html#steptyperegistry.register', 'fastrl/core.py'),
'fastrl.core.StepTypeRegistry.types': ('core.html#steptyperegistry.types', 'fastrl/core.py'),
'fastrl.core._fmt_dataflass_fld': ('core.html#_fmt_dataflass_fld', 'fastrl/core.py'),
'fastrl.core._fmt_fld': ('core.html#_fmt_fld', 'fastrl/core.py'),
'fastrl.core._len_check': ('core.html#_len_check', 'fastrl/core.py'),
'fastrl.core._less_than': ('core.html#_less_than', 'fastrl/core.py'),
'fastrl.core.add_dataclass_doc': ('core.html#add_dataclass_doc', 'fastrl/core.py'),
'fastrl.core.add_namedtuple_doc': ('core.html#add_namedtuple_doc', 'fastrl/core.py'),
'fastrl.core.default_logging': ('core.html#default_logging', 'fastrl/core.py'),
'fastrl.core.test_in': ('core.html#test_in', 'fastrl/core.py'),
Expand Down Expand Up @@ -742,7 +744,9 @@
'fastrl.pipes.iter.firstlast.FirstLastMerger.__iter__': ( '01_DataPipes/pipes.iter.firstlast.html#firstlastmerger.__iter__',
'fastrl/pipes/iter/firstlast.py'),
'fastrl.pipes.iter.firstlast.n_first_last_steps_expected': ( '01_DataPipes/pipes.iter.firstlast.html#n_first_last_steps_expected',
'fastrl/pipes/iter/firstlast.py')},
'fastrl/pipes/iter/firstlast.py'),
'fastrl.pipes.iter.firstlast.simple_step_first_last_merge': ( '01_DataPipes/pipes.iter.firstlast.html#simple_step_first_last_merge',
'fastrl/pipes/iter/firstlast.py')},
'fastrl.pipes.iter.nskip': { 'fastrl.pipes.iter.nskip.NSkipper': ( '01_DataPipes/pipes.iter.nskip.html#nskipper',
'fastrl/pipes/iter/nskip.py'),
'fastrl.pipes.iter.nskip.NSkipper.__init__': ( '01_DataPipes/pipes.iter.nskip.html#nskipper.__init__',
Expand Down
15 changes: 9 additions & 6 deletions fastrl/agents/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# %% ../../nbs/07_Agents/12a_agents.core.ipynb 2
# Python native modules
import os
from typing import List
from typing import List,Optional
# Third party libs
from fastcore.all import add_docs,ifnone
import torchdata.datapipes as dp
Expand Down Expand Up @@ -39,7 +39,7 @@ def get_shared_model(name="default"):

class AgentBase(dp.iter.IterDataPipe):
def __init__(self,
model:nn.Module, # The base NN that we getting raw action values out of.
model:Optional[nn.Module], # The base NN that we getting raw action values out of.
action_iterator:list=None, # A reference to an iterator that contains actions to process.
logger_bases=None
):
Expand All @@ -50,22 +50,25 @@ def __init__(self,
self._mem_name = 'agent_model'

def to(self,*args,**kwargs):
self.model.to(**kwargs)
if self.model is not None:
self.model.to(**kwargs)

def __iter__(self):
while self.iterable:
yield self.iterable.pop(0)

def __getstate__(self):
share_model(self.model,self._mem_name)
if self.model is not None:
share_model(self.model,self._mem_name)
# Store the non-model state
state = self.__dict__.copy()
return state

def __setstate__(self, state):
self.__dict__.update(state)
# Assume a globally shared model instance or a reference method to retrieve it
self.model = get_shared_model(self._mem_name)
if self.model is not None:
self.model = get_shared_model(self._mem_name)

add_docs(
AgentBase,
Expand Down Expand Up @@ -98,7 +101,7 @@ def __iter__(self): yield from self.source_datapipe

def augment_actions(self,actions): return actions

def create_step(self,**kwargs): return SimpleStep(**kwargs)
def create_step(self,**kwargs): return SimpleStep(**kwargs,batch_size=[])

add_docs(
AgentHead,
Expand Down
104 changes: 66 additions & 38 deletions fastrl/core.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_core.ipynb.

# %% auto 0
__all__ = ['StepTypes', 'add_namedtuple_doc', 'SimpleStep', 'StepTypeRegistry', 'Record', 'default_logging', 'test_in',
'test_out', 'test_len', 'test_lt']
__all__ = ['StepTypes', 'add_namedtuple_doc', 'add_dataclass_doc', 'SimpleStep', 'StepTypeRegistry', 'Record', 'default_logging',
'test_in', 'test_out', 'test_len', 'test_lt']

# %% ../nbs/00_core.ipynb 1
# Python native modules
import typing
import logging
# Third party libs
from fastcore.test import test,test_fail
from fastcore.basics import in_
from fastcore.imports import in_notebook
from fastcore.all import in_notebook,in_,test,test_fail,add_docs
import torch
from tensordict import tensorclass
# Local modules
Expand All @@ -36,10 +34,34 @@ def add_namedtuple_doc(
for k,v in t.__annotations__.items():
flds.append(_fmt_fld(k,v,t))

s = 'Parameters:\n\n'+'\n'.join(flds)
t.__doc__ = doc + '\n\n' + s

def _fmt_dataflass_fld(name,t:typing.Tuple[str,type],obj):
default_v = ''
if name in obj.__dataclass_fields__:
default_v = f' = `{t.default}`'
return ' - **%s**:`%s` '%(name,t.type)+default_v+getattr(obj,name).__doc__

def add_dataclass_doc(
t:object, # Primary tuple to get docs from
doc:str, # Primary doc for the overall tuple, where the docs for individual fields will be concated.
**fields_docs:dict # Field names with associated docs to be attached in the format: field_a='some documentation'
):
"Add docs to `t` from `doc` along with individual doc fields `fields_docs`"
if not hasattr(t,'__base_doc__'): t.__base_doc__ = doc
for k,v in fields_docs.items():
if k in t.__dataclass_fields__:
getattr(t,k).__doc__ = v
# TODO: can we add optional default fields also?
flds = []
for k,v in t.__dataclass_fields__.items():
flds.append(_fmt_dataflass_fld(k,v,t))

s = 'Parameters:\n\n'+'\n'.join(flds)
t.__doc__ = doc + '\n\n' + s

# %% ../nbs/00_core.ipynb 9
# %% ../nbs/00_core.ipynb 7
@tensorclass
class SimpleStep:
state: torch.FloatTensor = torch.FloatTensor([0])
Expand All @@ -56,39 +78,45 @@ class SimpleStep:
image: torch.FloatTensor = torch.FloatTensor([0])
raw_action: torch.FloatTensor = torch.FloatTensor([0])

def random(self):

@classmethod
def random(cls,batch_size=(1,),**flds):
"Returns `cls` with all fields not defined in `flds` with `batch_size`"
self = cls(batch_size=batch_size,**flds)
d = self._tensordict
for v in d.values():
for k,v in d.items():
if k in flds:
continue
if isinstance(v,torch.BoolTensor):
v.random_(0,1)
else:
v.random_(0,100)
return self

# add_namedtuple_doc(
# SimpleStep,
# 'Represents a single step in an environment.',
# state = 'Both the initial state of the environment and the previous state.',
# next_state = 'Both the next state, and the last state in the environment',
# terminated = """Represents an ending condition for an environment such as reaching a goal or 'living long enough' as
# described by the MDP.
# Good reference is: https://github.com/openai/gym/blob/39b8661cb09f19cb8c8d2f59b57417517de89cb0/gym/core.py#L151-L155""",
# truncated = """Represents an ending condition for an environment that can be seen as an out of bounds condition either
# literally going out of bounds, breaking rules, or exceeding the timelimit allowed by the MDP.
# Good reference is: https://github.com/openai/gym/blob/39b8661cb09f19cb8c8d2f59b57417517de89cb0/gym/core.py#L151-L155'""",
# reward = 'The single reward for this step.',
# total_reward = 'The total accumulated reward for this episode up to this step.',
# action = 'The action that was taken to transition from `state` to `next_state`',
# env_id = 'The environment this step came from (useful for debugging)',
# proc_id = 'The process this step came from (useful for debugging)',
# step_n = 'The step number in a given episode.',
# episode_n = 'The episode this environment is currently running through.',
# image = """Intended for display and logging only. If the intention is to use images for training an
# agent, then use a env wrapper instead.""",
# raw_action="The immediate raw output of the model before any post processing"
# )

# %% ../nbs/00_core.ipynb 19
add_dataclass_doc(
SimpleStep,
'Represents a single step in an environment.',
state = 'Both the initial state of the environment and the previous state.',
next_state = 'Both the next state, and the last state in the environment',
terminated = """Represents an ending condition for an environment such as reaching a goal or 'living long enough' as
described by the MDP.
Good reference is: https://github.com/openai/gym/blob/39b8661cb09f19cb8c8d2f59b57417517de89cb0/gym/core.py#L151-L155""",
truncated = """Represents an ending condition for an environment that can be seen as an out of bounds condition either
literally going out of bounds, breaking rules, or exceeding the timelimit allowed by the MDP.
Good reference is: https://github.com/openai/gym/blob/39b8661cb09f19cb8c8d2f59b57417517de89cb0/gym/core.py#L151-L155'""",
reward = 'The single reward for this step.',
total_reward = 'The total accumulated reward for this episode up to this step.',
action = 'The action that was taken to transition from `state` to `next_state`',
env_id = 'The environment this step came from (useful for debugging)',
proc_id = 'The process this step came from (useful for debugging)',
step_n = 'The step number in a given episode.',
episode_n = 'The episode this environment is currently running through.',
image = """Intended for display and logging only. If the intention is to use images for training an
agent, then use a env wrapper instead.""",
raw_action="The immediate raw output of the model before any post processing"
)

# %% ../nbs/00_core.ipynb 16
class StepTypeRegistry(object):
def __init__(self):
self._registered_types = set()
Expand All @@ -105,12 +133,12 @@ def types(self) -> typing.Tuple: return tuple(self._registered_types)
StepTypes = StepTypeRegistry()
StepTypes.register(SimpleStep)

# %% ../nbs/00_core.ipynb 23
# %% ../nbs/00_core.ipynb 20
class Record(typing.NamedTuple):
name:str
value:typing.Any

# %% ../nbs/00_core.ipynb 24
# %% ../nbs/00_core.ipynb 21
def default_logging(level=logging.WARNING):
"""
Returns default logging settings.
Expand All @@ -128,25 +156,25 @@ def default_logging(level=logging.WARNING):
'format': '%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s: %(message)s'
}

# %% ../nbs/00_core.ipynb 27
# %% ../nbs/00_core.ipynb 24
def test_in(a,b):
"`test` that `a in b`"
test(a,b,in_, ' in ')

# %% ../nbs/00_core.ipynb 29
# %% ../nbs/00_core.ipynb 26
def test_out(a,b):
"`test` that `a is not in b` or `a is outside b`"
test_fail(test,args=(a,b,in_), msg=f'{a} not in {b}')

# %% ../nbs/00_core.ipynb 31
# %% ../nbs/00_core.ipynb 28
def _len_check(a,b):
return len(a)==(len(b) if not isinstance(b,int) else b)

def test_len(a,b,meta_info=''):
"`test` that `len(a) == int(b) or len(a) == len(b)`"
test(a,b,_len_check, f' len == len {meta_info}')

# %% ../nbs/00_core.ipynb 33
# %% ../nbs/00_core.ipynb 30
def _less_than(a,b): return a < b
def test_lt(a,b):
"`test` that `a < b`"
Expand Down
4 changes: 2 additions & 2 deletions fastrl/envs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def env_reset(self,
self._env_ids[env_id] = step
return step

def no_agent_create_step(self,**kwargs): return SimpleStep(**kwargs)
def no_agent_create_step(self,**kwargs): return SimpleStep(**kwargs,batch_size=[])

def __iter__(self) -> SimpleStep:
for env in self.source_datapipe:
Expand Down Expand Up @@ -165,7 +165,7 @@ def __iter__(self) -> SimpleStep:
reset="Resets the env's back to original str types to avoid pickling issues."
)

# %% ../../nbs/03_Environment/05b_envs.gym.ipynb 56
# %% ../../nbs/03_Environment/05b_envs.gym.ipynb 54
def GymDataPipe(
source,
agent:DataPipe=None, # An AgentHead
Expand Down
59 changes: 33 additions & 26 deletions fastrl/pipes/iter/firstlast.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,55 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/01_DataPipes/01f_pipes.iter.firstlast.ipynb.

# %% auto 0
__all__ = ['FirstLastMerger', 'n_first_last_steps_expected']
__all__ = ['simple_step_first_last_merge', 'FirstLastMerger', 'n_first_last_steps_expected']

# %% ../../../nbs/01_DataPipes/01f_pipes.iter.firstlast.ipynb 2
# Python native modules
import warnings
from typing import Callable,List,Union
# Third party libs
from fastcore.all import add_docs
import torchdata.datapipes as dp

import torch
# Local modules
from ...core import StepTypes
from ...core import StepTypes,SimpleStep

# %% ../../../nbs/01_DataPipes/01f_pipes.iter.firstlast.ipynb 4
def simple_step_first_last_merge(steps:List[SimpleStep],gamma):
fstep,lstep = steps[0],steps[-1]

reward = fstep.reward
for step in steps[1:]:
reward *= gamma
reward += step.reward

yield SimpleStep(
state=fstep.state.clone().detach(),
next_state=lstep.next_state.clone().detach(),
action=fstep.action,
episode_n=fstep.episode_n,
image=fstep.image,
reward=reward,
raw_action=fstep.raw_action,
terminated=lstep.terminated,
truncated=lstep.truncated,
total_reward=lstep.total_reward,
env_id=lstep.env_id,
proc_id=lstep.proc_id,
step_n=lstep.step_n,
batch_size=[]
)

class FirstLastMerger(dp.iter.IterDataPipe):
def __init__(self,
source_datapipe,
source_datapipe,
merge_behavior:Callable[[List[Union[StepTypes.types]],float],Union[StepTypes.types]]=simple_step_first_last_merge,
gamma:float=0.99
):
self.source_datapipe = source_datapipe
self.gamma = gamma
self.merge_behavior = merge_behavior

def __iter__(self) -> StepTypes.types:
self.env_buffer = {}
Expand All @@ -33,36 +61,15 @@ def __iter__(self) -> StepTypes.types:
yield steps[0]
continue

fstep,lstep = steps[0],steps[-1]

reward = fstep.reward
for step in steps[1:]:
reward *= self.gamma
reward += step.reward

yield fstep.__class__(
state=fstep.state.clone().detach(),
next_state=lstep.next_state.clone().detach(),
action=fstep.action,
terminated=lstep.terminated,
truncated=lstep.truncated,
reward=reward,
total_reward=lstep.total_reward,
env_id=lstep.env_id,
proc_id=lstep.proc_id,
step_n=lstep.step_n,
episode_n=fstep.episode_n,
image=fstep.image,
raw_action=fstep.raw_action
)
yield from simple_step_first_last_merge(steps,gamma=self.gamma)

add_docs(
FirstLastMerger,
"""Takes multiple steps and converts them into a single step consisting of properties
from the first and last steps. Reward is recalculated to factor in the multiple steps.""",
)

# %% ../../../nbs/01_DataPipes/01f_pipes.iter.firstlast.ipynb 13
# %% ../../../nbs/01_DataPipes/01f_pipes.iter.firstlast.ipynb 15
def n_first_last_steps_expected(
default_steps:int, # The number of steps the episode would run without n_steps
):
Expand Down
Loading

0 comments on commit d09b40a

Please sign in to comment.