Updated: Continuous Action API #8

Merged · 10 commits · Nov 2, 2023
19 changes: 16 additions & 3 deletions .github/workflows/fastrl-docker.yml
@@ -7,7 +7,7 @@ on:
branches:
- main
- update_nbdev_docs
- feature/ppo
- refactor/bug-fix-api-update-and-stablize

jobs:
build:
@@ -54,14 +54,27 @@ jobs:
env:
BUILD_TYPE: ${{ matrix.build_type }}

- name: Cache Docker layers
if: always()
uses: actions/cache@v2
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-${{ github.sha }}
restore-keys: |
${{ runner.os }}-buildx-

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: build and tag container
run: |
export DOCKER_BUILDKIT=1
# We need to clear the previous docker images
# docker system prune -fa
docker pull ${IMAGE_NAME}:latest || true
# docker pull ${IMAGE_NAME}:latest || true
# docker build --build-arg BUILD=${BUILD_TYPE} \
docker build --cache-from ${IMAGE_NAME}:latest --build-arg BUILD=${BUILD_TYPE} \
docker buildx create --use
docker buildx build --cache-from=type=local,src=/tmp/.buildx-cache --cache-to=type=local,dest=/tmp/.buildx-cache --build-arg BUILD=${BUILD_TYPE} \
-t ${IMAGE_NAME}:latest \
-t ${IMAGE_NAME}:${VERSION} \
-t ${IMAGE_NAME}:$(date +%F) \
4 changes: 3 additions & 1 deletion fastrl/_modidx.py
@@ -645,7 +645,9 @@
'fastrl.loggers.vscode_visualizers.SimpleVSCodeVideoPlayer.show': ( '05_Logging/loggers.vscode_visualizers.html#simplevscodevideoplayer.show',
'fastrl/loggers/vscode_visualizers.py'),
'fastrl.loggers.vscode_visualizers.VSCodeDataPipe': ( '05_Logging/loggers.vscode_visualizers.html#vscodedatapipe',
'fastrl/loggers/vscode_visualizers.py')},
'fastrl/loggers/vscode_visualizers.py'),
'fastrl.loggers.vscode_visualizers.VSCodeDataPipe.__new__': ( '05_Logging/loggers.vscode_visualizers.html#vscodedatapipe.__new__',
'fastrl/loggers/vscode_visualizers.py')},
'fastrl.memory.experience_replay': { 'fastrl.memory.experience_replay.ExperienceReplay': ( '04_Memory/memory.experience_replay.html#experiencereplay',
'fastrl/memory/experience_replay.py'),
'fastrl.memory.experience_replay.ExperienceReplay.__init__': ( '04_Memory/memory.experience_replay.html#experiencereplay.__init__',
53 changes: 26 additions & 27 deletions fastrl/agents/ddpg.py
@@ -7,34 +7,32 @@
'CriticLossProcessor', 'ActorLossProcessor', 'DDPGLearner']

# %% ../../nbs/07_Agents/02_Continuous/12s_agents.ddpg.ipynb 2
# # Python native modules
# import os
# Python native modules
from typing import Tuple,Optional,Callable,Union,Dict,Literal,List
from functools import partial
# from typing_extensions import Literal
from copy import deepcopy
# # Third party libs
# Third party libs
from fastcore.all import add_docs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import traverse_dps,find_dps,DataPipe
# from torchdata.dataloader2.graph import DataPipe,traverse
from torch import nn
# from torch.optim import AdamW,Adam
import torch
# import pandas as pd
# import numpy as np
# # Local modules
# Local modules
from ..core import SimpleStep
from ..pipes.core import find_dp
from ..torch_core import Module
from ..memory.experience_replay import ExperienceReplay
from ..loggers.core import Record,is_record,not_record,_RECORD_CATCH_LIST
from ..learner.core import LearnerBase,LearnerHead,StepBatcher
# from fastrl.pipes.core import *
# from fastrl.data.block import *
# from fastrl.data.dataloader2 import *
from ..loggers.vscode_visualizers import VSCodeDataPipe
from fastrl.loggers.core import (
LogCollector,Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,is_record
ProgressBarLogger,
Record,
BatchCollector,
EpochCollector,
RollingTerminatedRewardCollector,
EpisodeCollector,
not_record,
_RECORD_CATCH_LIST
)
from fastrl.agents.core import (
AgentHead,
@@ -43,9 +41,6 @@
SimpleModelRunner,
NumpyConverter
)
# from fastrl.memory.experience_replay import ExperienceReplay
# from fastrl.learner.core import *
# from fastrl.loggers.core import *

# %% ../../nbs/07_Agents/02_Continuous/12s_agents.ddpg.ipynb 6
def init_xavier_uniform_weights(m:Module,bias=0.01):
@@ -798,7 +793,7 @@ def DDPGLearner(
critic:Critic,
# A list of dls, where index=0 is the training dl.
dls,
logger_bases:Optional[Callable]=None,
do_logging:bool=True,
# The learning rate for the actor. Expected to learn slower than the critic
actor_lr:float=1e-3,
# The optimizer for the actor
@@ -830,11 +825,12 @@ def DDPGLearner(
# Debug mode will output device moves
debug:bool=False
) -> LearnerHead:
learner = LearnerBase(actor,dls[0])
learner = LearnerBase({'actor':actor,'critic':critic},dls[0])
learner = BatchCollector(learner,batches=batches)
learner = EpochCollector(learner)
if logger_bases:
learner = logger_bases(learner)
if do_logging:
learner = learner.dump_records()
learner = ProgressBarLogger(learner)
learner = RollingTerminatedRewardCollector(learner)
learner = EpisodeCollector(learner).catch_records()
learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
@@ -847,12 +843,15 @@
learner = ActorLossProcessor(learner,critic,actor,clip_critic_grad=5)
learner = LossCollector(learner,title='actor-loss').catch_records()
learner = BasicOptStepper(learner,actor,actor_lr,opt=actor_opt,filter=True,do_zero_grad=False)
learner = LearnerHead(learner,(actor,critic))

# for dl in dls:
# pipe_to_device(dl.datapipe,device,debug=debug)

return learner
learner = LearnerHead(learner)

if len(dls)==2:
val_learner = LearnerBase({'actor':actor,'critic':critic},dls[1]).visualize_vscode()
val_learner = BatchCollector(val_learner,batches=batches)
val_learner = EpochCollector(val_learner).catch_records(drop=True)
return LearnerHead((learner,val_learner))
else:
return LearnerHead(learner)

DDPGLearner.__doc__="""DDPG is a continuous action, actor-critic model, first created in
(Lillicrap et al., 2016). The critic estimates a Q value estimate, and the actor
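
The diff above reworks the DDPGLearner API: the `logger_bases` callable is replaced by a `do_logging` flag, `LearnerBase` now receives both models as a dict, and a second dataloader (if supplied) builds a VSCode-visualized validation learner. A minimal usage sketch of the new signature follows; the `Actor`/`Critic` constructor arguments and the `build_dls()` helper are placeholders for illustration and do not come from this PR.

# Hypothetical usage of the updated DDPGLearner (sketch, not from this PR)
from fastrl.agents.ddpg import Actor, Critic, DDPGLearner

actor = Actor(4, 2)            # state and action sizes are illustrative
critic = Critic(4, 2)
dls = build_dls()              # placeholder: dls[0] trains; an optional dls[1] drives validation
learner = DDPGLearner(
    actor, critic, dls,
    do_logging=True,           # replaces the old logger_bases callable
    actor_lr=1e-3,
    bs=128,
    max_sz=100_000,
    batches=1000,
)
learner.fit(10)                # assumes LearnerHead exposes fit(), as elsewhere in fastrl
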
24 changes: 11 additions & 13 deletions fastrl/agents/dqn/basic.py
@@ -10,23 +10,19 @@
from collections import deque
from typing import Callable,Optional,List
# Third party libs
from fastcore.all import ifnone
import torchdata.datapipes as dp
from torchdata.dataloader2 import DataLoader2
from torchdata.dataloader2.graph import traverse_dps,DataPipe
import torch
import torch.nn.functional as F
from torch import optim
from torch import nn
import numpy as np
# Local modules
from ..core import AgentHead,AgentBase
from ...pipes.core import find_dp
from ...memory.experience_replay import ExperienceReplay
from ..core import StepFieldSelector,SimpleModelRunner,NumpyConverter
from ..discrete import EpsilonCollector,PyPrimativeConverter,ArgMaxer,EpsilonSelector
from fastrl.loggers.core import (
LogCollector,Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,is_record
Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,ProgressBarLogger
)
from ...learner.core import LearnerBase,LearnerHead,StepBatcher
from ...torch_core import Module
@@ -156,7 +152,7 @@ def __iter__(self):
def DQNLearner(
model,
dls,
logger_bases:Optional[Callable]=None,
do_logging:bool=True,
loss_func=nn.MSELoss(),
opt=optim.AdamW,
lr=0.005,
@@ -169,25 +165,27 @@ def DQNLearner(
learner = LearnerBase(model,dls[0])
learner = BatchCollector(learner,batches=batches)
learner = EpochCollector(learner)
if logger_bases:
learner = logger_bases(learner)
if do_logging:
learner = learner.dump_records()
learner = ProgressBarLogger(learner)
learner = RollingTerminatedRewardCollector(learner)
learner = EpisodeCollector(learner)
learner = learner.catch_records()
learner = learner.catch_records(drop=not do_logging)

learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz,freeze_memory=True)
learner = StepBatcher(learner,device=device)
learner = QCalc(learner)
learner = TargetCalc(learner,nsteps=nsteps)
learner = LossCalc(learner,loss_func=loss_func)
learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))
if logger_bases:
if do_logging:
learner = LossCollector(learner).catch_records()

if len(dls)==2:
val_learner = LearnerBase(model,dls[1])
val_learner = LearnerBase(model,dls[1]).visualize_vscode()
val_learner = BatchCollector(val_learner,batches=batches)
val_learner = EpochCollector(val_learner).dump_records()
learner = LearnerHead((learner,val_learner),model)
learner = LearnerHead((learner,val_learner))
else:
learner = LearnerHead(learner,model)
learner = LearnerHead(learner)
return learner
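
DQNLearner gets the same treatment, and the categorical and double-DQN learners below follow the identical pattern: `do_logging` replaces `logger_bases`, the record collectors drop their output via `catch_records(drop=not do_logging)`, a second dataloader produces a VSCode-visualized validation learner, and `LearnerHead` no longer takes the model. A hedged calling sketch; the DQN sizes and the `build_dls()` helper are illustrative, not from this PR.

# Hypothetical usage of the updated DQNLearner (sketch, not from this PR)
from torch import nn, optim
from fastrl.agents.dqn.basic import DQN, DQNLearner

model = DQN(4, 2)              # 4 observation dims, 2 actions; sizes are illustrative
dls = build_dls()              # placeholder: dls[0] trains, optional dls[1] validates
learner = DQNLearner(
    model, dls,
    do_logging=True,           # replaces logger_bases; also gates the loss collector
    loss_func=nn.MSELoss(),
    opt=optim.AdamW,
    lr=0.005,
    bs=128,
    max_sz=10_000,
    batches=1000,
)
learner.fit(3)                 # assumes LearnerHead exposes fit(), as elsewhere in fastrl
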
19 changes: 9 additions & 10 deletions fastrl/agents/dqn/categorical.py
@@ -261,7 +261,7 @@ def CategoricalDQNAgent(
def DQNCategoricalLearner(
model,
dls,
logger_bases:Optional[Callable]=None,
do_logging:bool=True,
loss_func=PartialCrossEntropy,
opt=optim.AdamW,
lr=0.005,
@@ -276,29 +276,28 @@ def DQNCategoricalLearner(
learner = LearnerBase(model,dls=dls[0])
learner = BatchCollector(learner,batches=batches)
learner = EpochCollector(learner)
if logger_bases:
learner = logger_bases(learner)
if do_logging:
learner = learner.dump_records()
learner = ProgressBarLogger(learner)
learner = RollingTerminatedRewardCollector(learner)
learner = EpisodeCollector(learner)
learner = learner.catch_records()
learner = EpisodeCollector(learner).catch_records()
learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
learner = StepBatcher(learner,device=device)
learner = CategoricalTargetQCalc(learner,nsteps=nsteps,double_dqn_strategy=double_dqn_strategy).to(device=device)
# learner = TargetCalc(learner,nsteps=nsteps)
learner = LossCalc(learner,loss_func=loss_func)
learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))
learner = TargetModelUpdater(learner,target_sync=target_sync)
if logger_bases:
if do_logging:
learner = LossCollector(learner).catch_records()

if len(dls)==2:
val_learner = LearnerBase(model,dls[1])
val_learner = LearnerBase(model,dls[1]).visualize_vscode()
val_learner = BatchCollector(val_learner,batches=batches)
val_learner = EpochCollector(val_learner).catch_records(drop=True)
val_learner = VSCodeDataPipe(val_learner)
return LearnerHead((learner,val_learner),model)
return LearnerHead((learner,val_learner))
else:
return LearnerHead(learner,model)
return LearnerHead(learner)

# %% ../../../nbs/07_Agents/01_Discrete/12o_agents.dqn.categorical.ipynb 49
def show_q(cat_dist,title='Update Distributions'):
23 changes: 11 additions & 12 deletions fastrl/agents/dqn/double.py
@@ -18,6 +18,7 @@
from ...loggers.core import BatchCollector,EpochCollector
from ...learner.core import LearnerBase,LearnerHead
from ...loggers.vscode_visualizers import VSCodeDataPipe
from ...loggers.core import ProgressBarLogger
from fastrl.agents.dqn.basic import (
LossCollector,
RollingTerminatedRewardCollector,
@@ -36,7 +37,7 @@

# %% ../../../nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb 5
class DoubleQCalc(dp.iter.IterDataPipe):
def __init__(self,source_datapipe=None):
def __init__(self,source_datapipe):
self.source_datapipe = source_datapipe

def __iter__(self):
@@ -53,7 +54,7 @@ def __iter__(self):
def DoubleDQNLearner(
model,
dls,
logger_bases:Optional[Callable]=None,
do_logging:bool=True,
loss_func=nn.MSELoss(),
opt=optim.AdamW,
lr=0.005,
@@ -67,27 +68,25 @@ def DoubleDQNLearner(
learner = LearnerBase(model,dls=dls[0])
learner = BatchCollector(learner,batches=batches)
learner = EpochCollector(learner)
if logger_bases:
learner = logger_bases(learner)
if do_logging:
learner = learner.dump_records()
learner = ProgressBarLogger(learner)
learner = RollingTerminatedRewardCollector(learner)
learner = EpisodeCollector(learner)
learner = learner.catch_records()
learner = EpisodeCollector(learner).catch_records()
learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
learner = StepBatcher(learner,device=device)
# learner = TargetModelQCalc(learner)
learner = DoubleQCalc(learner)
learner = TargetCalc(learner,nsteps=nsteps)
learner = LossCalc(learner,loss_func=loss_func)
learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))
learner = TargetModelUpdater(learner,target_sync=target_sync)
if logger_bases:
if do_logging:
learner = LossCollector(learner).catch_records()

if len(dls)==2:
val_learner = LearnerBase(model,dls[1])
val_learner = LearnerBase(model,dls[1]).visualize_vscode()
val_learner = BatchCollector(val_learner,batches=batches)
val_learner = EpochCollector(val_learner).catch_records(drop=True)
val_learner = VSCodeDataPipe(val_learner)
return LearnerHead((learner,val_learner),model)
return LearnerHead((learner,val_learner))
else:
return LearnerHead(learner,model)
return LearnerHead(learner)
39 changes: 10 additions & 29 deletions fastrl/agents/dqn/dueling.py
@@ -5,48 +5,29 @@

# %% ../../../nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb 2
# Python native modules
from copy import deepcopy
from typing import Optional,Callable,Tuple
# Third party libs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import traverse_dps,DataPipe
import torch
from torch import nn,optim
# Local modulesf
from ...pipes.core import find_dp
from ...memory.experience_replay import ExperienceReplay
from ...loggers.core import BatchCollector,EpochCollector
from ...learner.core import LearnerBase,LearnerHead
from ...loggers.vscode_visualizers import VSCodeDataPipe
from torch import nn
# Local modules
from fastrl.agents.dqn.basic import (
LossCollector,
RollingTerminatedRewardCollector,
EpisodeCollector,
StepBatcher,
TargetCalc,
LossCalc,
ModelLearnCalc,
DQN,
DQNAgent
)
from fastrl.agents.dqn.target import (
TargetModelUpdater,
TargetModelQCalc,
DQNTargetLearner
)
from .target import DQNTargetLearner

# %% ../../../nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb 5
class DuelingHead(nn.Module):
def __init__(self,
hidden:int, # Input into the DuelingHead, likely a hidden layer input
n_actions:int, # Number/dim of actions to output
lin_cls=nn.Linear
def __init__(
self,
hidden: int, # Input into the DuelingHead, likely a hidden layer input
n_actions: int, # Number/dim of actions to output
lin_cls = nn.Linear
):
super().__init__()
self.val = lin_cls(hidden,1)
self.adv = lin_cls(hidden,n_actions)

def forward(self,xi):
val,adv=self.val(xi),self.adv(xi)
xi=val.expand_as(adv)+(adv-adv.mean()).squeeze(0)
val,adv = self.val(xi),self.adv(xi)
xi = val.expand_as(adv)+(adv-adv.mean()).squeeze(0)
return xi
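
DuelingHead implements the standard dueling aggregation Q(s,a) = V(s) + (A(s,a) - mean_a A(s,a)), with V and A produced by the two linear layers above. A small usage sketch; the hidden size and action count are arbitrary.

# Hypothetical usage of DuelingHead (sizes are illustrative)
import torch
from fastrl.agents.dqn.dueling import DuelingHead

head = DuelingHead(hidden=64, n_actions=2)
features = torch.randn(1, 64)   # e.g. the output of a shared body network
q_values = head(features)       # V(s) broadcast over actions plus mean-centred advantages
print(q_values.shape)           # torch.Size([1, 2])
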