Updated: Continuous Action API (#8)
- all learners for all agents now have fit and validation capabilities (see the sketch after this list)
- cleaned up a ton of stuff
- TRPO and PPO still don't work correctly; I'll probably tackle this in another PR
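
As a rough illustration of the first bullet, the self-contained sketch below mirrors the pattern this commit introduces: LearnerHead now wraps a training pipeline plus an optional validation pipeline built from dls[1]. The names here (MiniLearnerHead, train_pipe, val_pipe) are hypothetical stand-ins, not fastrl code; they only show the shape of the fit/validation split visible in the diffs further down.

# Hypothetical, self-contained sketch (plain Python, not from this commit) of the
# train/validation dispatch pattern seen below in LearnerHead((learner, val_learner)).
from typing import Iterable, Optional

class MiniLearnerHead:
    def __init__(self, train_pipe: Iterable, val_pipe: Optional[Iterable] = None):
        self.train_pipe = train_pipe   # analogous to the pipeline built from dls[0]
        self.val_pipe = val_pipe       # analogous to the new val_learner built from dls[1]

    def fit(self, epochs: int) -> None:
        # In fastrl the loss and optimizer steps live inside the datapipe graph,
        # so "fit" amounts to driving the training pipeline for a number of epochs.
        for _ in range(epochs):
            for _batch in self.train_pipe:
                pass

    def validate(self) -> None:
        if self.val_pipe is None:
            raise ValueError("no validation dataloader (dls[1]) was provided")
        for _batch in self.val_pipe:   # validation path: no optimizer steps
            pass

# Toy usage with plain lists standing in for dataloaders:
head = MiniLearnerHead(train_pipe=[1, 2, 3], val_pipe=[4, 5])
head.fit(epochs=2)
head.validate()
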
josiahls authored Nov 2, 2023
1 parent b66b8f6 commit f0712b6
Showing 24 changed files with 507 additions and 1,120 deletions.
19 changes: 16 additions & 3 deletions .github/workflows/fastrl-docker.yml
@@ -7,7 +7,7 @@ on:
branches:
- main
- update_nbdev_docs
- feature/ppo
- refactor/bug-fix-api-update-and-stablize

jobs:
build:
@@ -54,14 +54,27 @@ jobs:
env:
BUILD_TYPE: ${{ matrix.build_type }}

- name: Cache Docker layers
if: always()
uses: actions/cache@v2
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-${{ github.sha }}
restore-keys: |
${{ runner.os }}-buildx-
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: build and tag container
run: |
export DOCKER_BUILDKIT=1
# We need to clear the previous docker images
# docker system prune -fa
docker pull ${IMAGE_NAME}:latest || true
# docker pull ${IMAGE_NAME}:latest || true
# docker build --build-arg BUILD=${BUILD_TYPE} \
docker build --cache-from ${IMAGE_NAME}:latest --build-arg BUILD=${BUILD_TYPE} \
docker buildx create --use
docker buildx build --cache-from=type=local,src=/tmp/.buildx-cache --cache-to=type=local,dest=/tmp/.buildx-cache --build-arg BUILD=${BUILD_TYPE} \
-t ${IMAGE_NAME}:latest \
-t ${IMAGE_NAME}:${VERSION} \
-t ${IMAGE_NAME}:$(date +%F) \
4 changes: 3 additions & 1 deletion fastrl/_modidx.py
@@ -645,7 +645,9 @@
'fastrl.loggers.vscode_visualizers.SimpleVSCodeVideoPlayer.show': ( '05_Logging/loggers.vscode_visualizers.html#simplevscodevideoplayer.show',
'fastrl/loggers/vscode_visualizers.py'),
'fastrl.loggers.vscode_visualizers.VSCodeDataPipe': ( '05_Logging/loggers.vscode_visualizers.html#vscodedatapipe',
'fastrl/loggers/vscode_visualizers.py')},
'fastrl/loggers/vscode_visualizers.py'),
'fastrl.loggers.vscode_visualizers.VSCodeDataPipe.__new__': ( '05_Logging/loggers.vscode_visualizers.html#vscodedatapipe.__new__',
'fastrl/loggers/vscode_visualizers.py')},
'fastrl.memory.experience_replay': { 'fastrl.memory.experience_replay.ExperienceReplay': ( '04_Memory/memory.experience_replay.html#experiencereplay',
'fastrl/memory/experience_replay.py'),
'fastrl.memory.experience_replay.ExperienceReplay.__init__': ( '04_Memory/memory.experience_replay.html#experiencereplay.__init__',
53 changes: 26 additions & 27 deletions fastrl/agents/ddpg.py
@@ -7,34 +7,32 @@
'CriticLossProcessor', 'ActorLossProcessor', 'DDPGLearner']

# %% ../../nbs/07_Agents/02_Continuous/12s_agents.ddpg.ipynb 2
# # Python native modules
# import os
# Python native modules
from typing import Tuple,Optional,Callable,Union,Dict,Literal,List
from functools import partial
# from typing_extensions import Literal
from copy import deepcopy
# # Third party libs
# Third party libs
from fastcore.all import add_docs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import traverse_dps,find_dps,DataPipe
# from torchdata.dataloader2.graph import DataPipe,traverse
from torch import nn
# from torch.optim import AdamW,Adam
import torch
# import pandas as pd
# import numpy as np
# # Local modules
# Local modules
from ..core import SimpleStep
from ..pipes.core import find_dp
from ..torch_core import Module
from ..memory.experience_replay import ExperienceReplay
from ..loggers.core import Record,is_record,not_record,_RECORD_CATCH_LIST
from ..learner.core import LearnerBase,LearnerHead,StepBatcher
# from fastrl.pipes.core import *
# from fastrl.data.block import *
# from fastrl.data.dataloader2 import *
from ..loggers.vscode_visualizers import VSCodeDataPipe
from fastrl.loggers.core import (
LogCollector,Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,is_record
ProgressBarLogger,
Record,
BatchCollector,
EpochCollector,
RollingTerminatedRewardCollector,
EpisodeCollector,
not_record,
_RECORD_CATCH_LIST
)
from fastrl.agents.core import (
AgentHead,
@@ -43,9 +41,6 @@
SimpleModelRunner,
NumpyConverter
)
# from fastrl.memory.experience_replay import ExperienceReplay
# from fastrl.learner.core import *
# from fastrl.loggers.core import *

# %% ../../nbs/07_Agents/02_Continuous/12s_agents.ddpg.ipynb 6
def init_xavier_uniform_weights(m:Module,bias=0.01):
@@ -798,7 +793,7 @@ def DDPGLearner(
critic:Critic,
# A list of dls, where index=0 is the training dl.
dls,
logger_bases:Optional[Callable]=None,
do_logging:bool=True,
# The learning rate for the actor. Expected to learn slower than the critic
actor_lr:float=1e-3,
# The optimizer for the actor
@@ -830,11 +825,12 @@ def DDPGLearner(
# Debug mode will output device moves
debug:bool=False
) -> LearnerHead:
learner = LearnerBase(actor,dls[0])
learner = LearnerBase({'actor':actor,'critic':critic},dls[0])
learner = BatchCollector(learner,batches=batches)
learner = EpochCollector(learner)
if logger_bases:
learner = logger_bases(learner)
if do_logging:
learner = learner.dump_records()
learner = ProgressBarLogger(learner)
learner = RollingTerminatedRewardCollector(learner)
learner = EpisodeCollector(learner).catch_records()
learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
@@ -847,12 +843,15 @@ def DDPGLearner(
learner = ActorLossProcessor(learner,critic,actor,clip_critic_grad=5)
learner = LossCollector(learner,title='actor-loss').catch_records()
learner = BasicOptStepper(learner,actor,actor_lr,opt=actor_opt,filter=True,do_zero_grad=False)
learner = LearnerHead(learner,(actor,critic))

# for dl in dls:
# pipe_to_device(dl.datapipe,device,debug=debug)

return learner
learner = LearnerHead(learner)

if len(dls)==2:
val_learner = LearnerBase({'actor':actor,'critic':critic},dls[1]).visualize_vscode()
val_learner = BatchCollector(val_learner,batches=batches)
val_learner = EpochCollector(val_learner).catch_records(drop=True)
return LearnerHead((learner,val_learner))
else:
return LearnerHead(learner)

DDPGLearner.__doc__="""DDPG is a continuous action, actor-critic model, first created in
(Lillicrap et al., 2016). The critic estimates a Q value estimate, and the actor
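
To make the DDPG description in the docstring above concrete, here is a hedged, generic PyTorch sketch of the TD target the critic regresses toward: the target actor proposes the next action and the target critic evaluates it. This is illustrative code for the algorithm (Lillicrap et al., 2016), not fastrl's CriticLossProcessor, and the argument names and shapes are assumptions.

# Hedged, generic sketch of the DDPG critic target; not fastrl's CriticLossProcessor.
# All names and tensor shapes here are illustrative assumptions.
import torch

def ddpg_critic_target(reward, done, next_state, actor_target, critic_target, gamma=0.99):
    # y = r + gamma * (1 - done) * Q_target(s', mu_target(s'))
    with torch.no_grad():
        next_action = actor_target(next_state)            # deterministic policy mu'(s')
        next_q = critic_target(next_state, next_action)   # Q'(s', mu'(s'))
        return reward + gamma * (1.0 - done.float()) * next_q
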
24 changes: 11 additions & 13 deletions fastrl/agents/dqn/basic.py
@@ -10,23 +10,19 @@
from collections import deque
from typing import Callable,Optional,List
# Third party libs
from fastcore.all import ifnone
import torchdata.datapipes as dp
from torchdata.dataloader2 import DataLoader2
from torchdata.dataloader2.graph import traverse_dps,DataPipe
import torch
import torch.nn.functional as F
from torch import optim
from torch import nn
import numpy as np
# Local modules
from ..core import AgentHead,AgentBase
from ...pipes.core import find_dp
from ...memory.experience_replay import ExperienceReplay
from ..core import StepFieldSelector,SimpleModelRunner,NumpyConverter
from ..discrete import EpsilonCollector,PyPrimativeConverter,ArgMaxer,EpsilonSelector
from fastrl.loggers.core import (
LogCollector,Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,is_record
Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,ProgressBarLogger
)
from ...learner.core import LearnerBase,LearnerHead,StepBatcher
from ...torch_core import Module
@@ -156,7 +152,7 @@ def __iter__(self):
def DQNLearner(
model,
dls,
logger_bases:Optional[Callable]=None,
do_logging:bool=True,
loss_func=nn.MSELoss(),
opt=optim.AdamW,
lr=0.005,
@@ -169,25 +165,27 @@ def DQNLearner(
learner = LearnerBase(model,dls[0])
learner = BatchCollector(learner,batches=batches)
learner = EpochCollector(learner)
if logger_bases:
learner = logger_bases(learner)
if do_logging:
learner = learner.dump_records()
learner = ProgressBarLogger(learner)
learner = RollingTerminatedRewardCollector(learner)
learner = EpisodeCollector(learner)
learner = learner.catch_records()
learner = learner.catch_records(drop=not do_logging)

learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz,freeze_memory=True)
learner = StepBatcher(learner,device=device)
learner = QCalc(learner)
learner = TargetCalc(learner,nsteps=nsteps)
learner = LossCalc(learner,loss_func=loss_func)
learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))
if logger_bases:
if do_logging:
learner = LossCollector(learner).catch_records()

if len(dls)==2:
val_learner = LearnerBase(model,dls[1])
val_learner = LearnerBase(model,dls[1]).visualize_vscode()
val_learner = BatchCollector(val_learner,batches=batches)
val_learner = EpochCollector(val_learner).dump_records()
learner = LearnerHead((learner,val_learner),model)
learner = LearnerHead((learner,val_learner))
else:
learner = LearnerHead(learner,model)
learner = LearnerHead(learner)
return learner
19 changes: 9 additions & 10 deletions fastrl/agents/dqn/categorical.py
@@ -261,7 +261,7 @@ def CategoricalDQNAgent(
def DQNCategoricalLearner(
model,
dls,
logger_bases:Optional[Callable]=None,
do_logging:bool=True,
loss_func=PartialCrossEntropy,
opt=optim.AdamW,
lr=0.005,
@@ -276,29 +276,28 @@ def DQNCategoricalLearner(
learner = LearnerBase(model,dls=dls[0])
learner = BatchCollector(learner,batches=batches)
learner = EpochCollector(learner)
if logger_bases:
learner = logger_bases(learner)
if do_logging:
learner = learner.dump_records()
learner = ProgressBarLogger(learner)
learner = RollingTerminatedRewardCollector(learner)
learner = EpisodeCollector(learner)
learner = learner.catch_records()
learner = EpisodeCollector(learner).catch_records()
learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
learner = StepBatcher(learner,device=device)
learner = CategoricalTargetQCalc(learner,nsteps=nsteps,double_dqn_strategy=double_dqn_strategy).to(device=device)
# learner = TargetCalc(learner,nsteps=nsteps)
learner = LossCalc(learner,loss_func=loss_func)
learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))
learner = TargetModelUpdater(learner,target_sync=target_sync)
if logger_bases:
if do_logging:
learner = LossCollector(learner).catch_records()

if len(dls)==2:
val_learner = LearnerBase(model,dls[1])
val_learner = LearnerBase(model,dls[1]).visualize_vscode()
val_learner = BatchCollector(val_learner,batches=batches)
val_learner = EpochCollector(val_learner).catch_records(drop=True)
val_learner = VSCodeDataPipe(val_learner)
return LearnerHead((learner,val_learner),model)
return LearnerHead((learner,val_learner))
else:
return LearnerHead(learner,model)
return LearnerHead(learner)

# %% ../../../nbs/07_Agents/01_Discrete/12o_agents.dqn.categorical.ipynb 49
def show_q(cat_dist,title='Update Distributions'):
23 changes: 11 additions & 12 deletions fastrl/agents/dqn/double.py
@@ -18,6 +18,7 @@
from ...loggers.core import BatchCollector,EpochCollector
from ...learner.core import LearnerBase,LearnerHead
from ...loggers.vscode_visualizers import VSCodeDataPipe
from ...loggers.core import ProgressBarLogger
from fastrl.agents.dqn.basic import (
LossCollector,
RollingTerminatedRewardCollector,
Expand All @@ -36,7 +37,7 @@

# %% ../../../nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb 5
class DoubleQCalc(dp.iter.IterDataPipe):
def __init__(self,source_datapipe=None):
def __init__(self,source_datapipe):
self.source_datapipe = source_datapipe

def __iter__(self):
@@ -53,7 +54,7 @@ def __iter__(self):
def DoubleDQNLearner(
model,
dls,
logger_bases:Optional[Callable]=None,
do_logging:bool=True,
loss_func=nn.MSELoss(),
opt=optim.AdamW,
lr=0.005,
@@ -67,27 +68,25 @@ def DoubleDQNLearner(
learner = LearnerBase(model,dls=dls[0])
learner = BatchCollector(learner,batches=batches)
learner = EpochCollector(learner)
if logger_bases:
learner = logger_bases(learner)
if do_logging:
learner = learner.dump_records()
learner = ProgressBarLogger(learner)
learner = RollingTerminatedRewardCollector(learner)
learner = EpisodeCollector(learner)
learner = learner.catch_records()
learner = EpisodeCollector(learner).catch_records()
learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
learner = StepBatcher(learner,device=device)
# learner = TargetModelQCalc(learner)
learner = DoubleQCalc(learner)
learner = TargetCalc(learner,nsteps=nsteps)
learner = LossCalc(learner,loss_func=loss_func)
learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))
learner = TargetModelUpdater(learner,target_sync=target_sync)
if logger_bases:
if do_logging:
learner = LossCollector(learner).catch_records()

if len(dls)==2:
val_learner = LearnerBase(model,dls[1])
val_learner = LearnerBase(model,dls[1]).visualize_vscode()
val_learner = BatchCollector(val_learner,batches=batches)
val_learner = EpochCollector(val_learner).catch_records(drop=True)
val_learner = VSCodeDataPipe(val_learner)
return LearnerHead((learner,val_learner),model)
return LearnerHead((learner,val_learner))
else:
return LearnerHead(learner,model)
return LearnerHead(learner)
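
For reference, the idea behind DoubleQCalc above can be written as a small generic PyTorch sketch: the online network selects the greedy next action and the target network evaluates it, which is what decouples action selection from evaluation in double DQN. This is an illustrative sketch, not the DoubleQCalc datapipe, and the argument names and shapes are assumptions.

# Hedged, generic sketch of the double-DQN target behind DoubleQCalc; not code from
# this commit, and the argument names and shapes are illustrative assumptions.
import torch

def double_dqn_target(reward, done, next_state, online_net, target_net, gamma=0.99):
    with torch.no_grad():
        best_action = online_net(next_state).argmax(dim=1, keepdim=True)  # selection (online net)
        next_q = target_net(next_state).gather(1, best_action)            # evaluation (target net)
        return reward + gamma * (1.0 - done.float()) * next_q
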
39 changes: 10 additions & 29 deletions fastrl/agents/dqn/dueling.py
@@ -5,48 +5,29 @@

# %% ../../../nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb 2
# Python native modules
from copy import deepcopy
from typing import Optional,Callable,Tuple
# Third party libs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import traverse_dps,DataPipe
import torch
from torch import nn,optim
# Local modulesf
from ...pipes.core import find_dp
from ...memory.experience_replay import ExperienceReplay
from ...loggers.core import BatchCollector,EpochCollector
from ...learner.core import LearnerBase,LearnerHead
from ...loggers.vscode_visualizers import VSCodeDataPipe
from torch import nn
# Local modules
from fastrl.agents.dqn.basic import (
LossCollector,
RollingTerminatedRewardCollector,
EpisodeCollector,
StepBatcher,
TargetCalc,
LossCalc,
ModelLearnCalc,
DQN,
DQNAgent
)
from fastrl.agents.dqn.target import (
TargetModelUpdater,
TargetModelQCalc,
DQNTargetLearner
)
from .target import DQNTargetLearner

# %% ../../../nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb 5
class DuelingHead(nn.Module):
def __init__(self,
hidden:int, # Input into the DuelingHead, likely a hidden layer input
n_actions:int, # Number/dim of actions to output
lin_cls=nn.Linear
def __init__(
self,
hidden: int, # Input into the DuelingHead, likely a hidden layer input
n_actions: int, # Number/dim of actions to output
lin_cls = nn.Linear
):
super().__init__()
self.val = lin_cls(hidden,1)
self.adv = lin_cls(hidden,n_actions)

def forward(self,xi):
val,adv=self.val(xi),self.adv(xi)
xi=val.expand_as(adv)+(adv-adv.mean()).squeeze(0)
val,adv = self.val(xi),self.adv(xi)
xi = val.expand_as(adv)+(adv-adv.mean()).squeeze(0)
return xi
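
A quick, hedged demo of the aggregation in the forward pass above: the value stream is broadcast across actions and the mean-centered advantage stream is added to it. It assumes fastrl is installed so DuelingHead can be imported from the module shown in this diff.

# Hedged demo of the dueling aggregation above; assumes fastrl is installed and
# exposes DuelingHead from the module this diff touches.
import torch
from fastrl.agents.dqn.dueling import DuelingHead

head = DuelingHead(hidden=16, n_actions=4)
x = torch.randn(1, 16)       # a single hidden-layer activation
q_values = head(x)           # V(s) broadcast over actions + mean-centered A(s, a)
print(q_values.shape)        # torch.Size([1, 4])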
