From cde97539e829e89b946901fb3db677c1005cf056 Mon Sep 17 00:00:00 2001 From: josiah Date: Wed, 1 Nov 2023 15:20:30 +0000 Subject: [PATCH 1/9] Updated: - rainbow to allow validation --- .../01_Discrete/12r_agents.dqn.rainbow.ipynb | 7 +- .../02_Continuous/12t_agents.trpo.ipynb | 4 - .../02_Continuous/12u_agents.ppo.ipynb | 188 +----------------- 3 files changed, 15 insertions(+), 184 deletions(-) diff --git a/nbs/07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb b/nbs/07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb index 39cf9f5..d75287c 100644 --- a/nbs/07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb +++ b/nbs/07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb @@ -178,7 +178,12 @@ " firstlast=True,\n", " bs=1\n", ")\n", - "dls = dataloaders(train_pipe)\n", + "validation_pipe = GymDataPipe(\n", + " ['CartPole-v1']*1,\n", + " agent=agent,\n", + " include_images=True\n", + ")\n", + "dls = dataloaders((train_pipe,validation_pipe))\n", "# Setup the Learner\n", "learner = DQNRainbowLearner(\n", " model,\n", diff --git a/nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb b/nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb index 3bf94bf..382a547 100644 --- a/nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb +++ b/nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb @@ -1568,10 +1568,6 @@ "display_name": "python3", "language": "python", "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.8.0" } }, "nbformat": 4, diff --git a/nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb b/nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb index 7cd4fde..0c7a366 100644 --- a/nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb +++ b/nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "durable-dialogue", "metadata": {}, "outputs": [], @@ -14,7 +14,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "assisted-contract", "metadata": {}, "outputs": [], @@ -83,7 +83,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "95d688b6", "metadata": {}, "outputs": [], @@ -183,7 +183,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "5d8cae3d", "metadata": {}, "outputs": [], @@ -246,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "734c996b", "metadata": {}, "outputs": [], @@ -258,155 +258,10 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "45814d94", "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "53963b72c8ad40e7ba672e89b1e5e646", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Epochs: 0%| | 0/10 [00:00\n", - " \n", - " \n", - " episode\n", - " rolling_reward\n", - " epoch\n", - " batch\n", - " std\n", - " mean\n", - " actor-loss\n", - " \n", - " \n", - " \n", - " \n", - " 4\n", - " -24.429207\n", - " 0\n", - " 9\n", - " 0.494225\n", - " 0.066967\n", - " -2.6739354\n", - " \n", - " \n", - " 8\n", - " -23.218408\n", - " 1\n", - " 9\n", - " 0.477826\n", - " 0.028452\n", - " -2.8338091\n", - " \n", - " \n", - " 12\n", - " -22.785492\n", - " 2\n", - " 9\n", - " 0.466163\n", - " 0.214959\n", - " -1.7720228\n", - " \n", - " \n", - " 16\n", - " -22.401961\n", - " 3\n", - " 9\n", - " 0.457776\n", - " -0.010530\n", - " -1.7371688\n", - " \n", - " \n", - " 20\n", - " -22.117896\n", - " 4\n", - " 9\n", - " 0.447333\n", - " -0.062435\n", - 
" -3.4784434\n", - " \n", - " \n", - " 24\n", - " -21.779138\n", - " 5\n", - " 9\n", - " 0.435938\n", - " -0.172362\n", - " -2.7109303\n", - " \n", - " \n", - " 28\n", - " -21.439993\n", - " 6\n", - " 9\n", - " 0.423310\n", - " -0.036928\n", - " -1.5233021\n", - " \n", - " \n", - " 32\n", - " -21.141646\n", - " 7\n", - " 9\n", - " 0.410206\n", - " -0.013177\n", - " -2.0156078\n", - " \n", - " \n", - " 36\n", - " -20.610226\n", - " 8\n", - " 9\n", - " 0.392541\n", - " 0.088212\n", - " -2.9441102\n", - " \n", - " \n", - " 40\n", - " -20.021371\n", - " 9\n", - " 9\n", - " 0.376230\n", - " 0.103705\n", - " -3.461331\n", - " \n", - " \n", - "" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "#|eval:false\n", "torch.manual_seed(0)\n", @@ -438,23 +293,10 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "9ddc7107", "metadata": {}, - "outputs": [ - { - "ename": "IndexError", - "evalue": "tuple index out of range", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/home/fastrl_user/fastrl/nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb Cell 8\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m learner\u001b[39m.\u001b[39;49mvalidate()\n", - "File \u001b[0;32m~/fastrl/fastrl/learner/core.py:107\u001b[0m, in \u001b[0;36mLearnerHead.validate\u001b[0;34m(self, epochs, batches, show)\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mvalidate\u001b[39m(\u001b[39mself\u001b[39m,epochs\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m,batches\u001b[39m=\u001b[39m\u001b[39m100\u001b[39m,show\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m DataPipe:\n\u001b[1;32m 106\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdp_idx \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[0;32m--> 107\u001b[0m epocher \u001b[39m=\u001b[39m find_dp(traverse_dps(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msource_datapipes[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdp_idx]),EpochCollector)\n\u001b[1;32m 108\u001b[0m epocher\u001b[39m.\u001b[39mepochs \u001b[39m=\u001b[39m epochs\n\u001b[1;32m 109\u001b[0m batcher \u001b[39m=\u001b[39m find_dp(traverse_dps(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39msource_datapipes[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdp_idx]),BatchCollector)\n", - "\u001b[0;31mIndexError\u001b[0m: tuple index out of range" - ] - } - ], + "outputs": [], "source": [ "learner.validate()" ] @@ -485,18 +327,6 @@ "display_name": "python3", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.0" } }, "nbformat": 4, From ed25d5ab0bec622578a191d3d2dcf703b1f1944a Mon Sep 17 00:00:00 2001 From: josiahls Date: Wed, 1 Nov 2023 13:06:57 -0400 Subject: [PATCH 2/9] Update fastrl-docker.yml --- .github/workflows/fastrl-docker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fastrl-docker.yml b/.github/workflows/fastrl-docker.yml index 969e0e4..ce03b5a 100644 --- a/.github/workflows/fastrl-docker.yml +++ b/.github/workflows/fastrl-docker.yml @@ -7,7 +7,7 @@ on: branches: - 
main - update_nbdev_docs - - feature/ppo + - refactor/bug-fix-api-update-and-stablize jobs: build: From 1c3deb4a4d2ff1513a393f22d3e8cee35988747e Mon Sep 17 00:00:00 2001 From: josiahls Date: Wed, 1 Nov 2023 13:17:44 -0400 Subject: [PATCH 3/9] Update fastrl-docker.yml --- .github/workflows/fastrl-docker.yml | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/.github/workflows/fastrl-docker.yml b/.github/workflows/fastrl-docker.yml index ce03b5a..09211e1 100644 --- a/.github/workflows/fastrl-docker.yml +++ b/.github/workflows/fastrl-docker.yml @@ -54,6 +54,15 @@ jobs: env: BUILD_TYPE: ${{ matrix.build_type }} + - name: Cache Docker layers + if: always() + uses: actions/cache@v2 + with: + path: /tmp/.buildx-cache + key: ${{ runner.os }}-buildx-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-buildx- + - name: build and tag container run: | export DOCKER_BUILDKIT=1 @@ -61,7 +70,8 @@ jobs: # docker system prune -fa docker pull ${IMAGE_NAME}:latest || true # docker build --build-arg BUILD=${BUILD_TYPE} \ - docker build --cache-from ${IMAGE_NAME}:latest --build-arg BUILD=${BUILD_TYPE} \ + docker buildx create --use + docker buildx --cache-from=type=local,src=/tmp/.buildx-cache --cache-to=type=local,dest=/tmp/.buildx-cache --cache-from ${IMAGE_NAME}:latest --build-arg BUILD=${BUILD_TYPE} \ -t ${IMAGE_NAME}:latest \ -t ${IMAGE_NAME}:${VERSION} \ -t ${IMAGE_NAME}:$(date +%F) \ From c392a95a3f52b685482090e0f91272b19e4e4a68 Mon Sep 17 00:00:00 2001 From: josiahls Date: Wed, 1 Nov 2023 13:19:52 -0400 Subject: [PATCH 4/9] Update fastrl-docker.yml --- .github/workflows/fastrl-docker.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/fastrl-docker.yml b/.github/workflows/fastrl-docker.yml index 09211e1..08e23e7 100644 --- a/.github/workflows/fastrl-docker.yml +++ b/.github/workflows/fastrl-docker.yml @@ -63,6 +63,9 @@ jobs: restore-keys: | ${{ runner.os }}-buildx- + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 + - name: build and tag container run: | export DOCKER_BUILDKIT=1 From 2f51a433a1bac2611a351735b1958027f16b5762 Mon Sep 17 00:00:00 2001 From: josiahls Date: Wed, 1 Nov 2023 13:29:26 -0400 Subject: [PATCH 5/9] Update fastrl-docker.yml --- .github/workflows/fastrl-docker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fastrl-docker.yml b/.github/workflows/fastrl-docker.yml index 08e23e7..242ea6a 100644 --- a/.github/workflows/fastrl-docker.yml +++ b/.github/workflows/fastrl-docker.yml @@ -74,7 +74,7 @@ jobs: docker pull ${IMAGE_NAME}:latest || true # docker build --build-arg BUILD=${BUILD_TYPE} \ docker buildx create --use - docker buildx --cache-from=type=local,src=/tmp/.buildx-cache --cache-to=type=local,dest=/tmp/.buildx-cache --cache-from ${IMAGE_NAME}:latest --build-arg BUILD=${BUILD_TYPE} \ + docker buildx --cache-from=type=local,src=/tmp/.buildx-cache --cache-to=type=local,dest=/tmp/.buildx-cache --build-arg BUILD=${BUILD_TYPE} \ -t ${IMAGE_NAME}:latest \ -t ${IMAGE_NAME}:${VERSION} \ -t ${IMAGE_NAME}:$(date +%F) \ From e15125646c441f0961cdb019544fbb8a5fa0f985 Mon Sep 17 00:00:00 2001 From: josiahls Date: Wed, 1 Nov 2023 13:32:11 -0400 Subject: [PATCH 6/9] Update fastrl-docker.yml --- .github/workflows/fastrl-docker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fastrl-docker.yml b/.github/workflows/fastrl-docker.yml index 242ea6a..9c9a4b6 100644 --- a/.github/workflows/fastrl-docker.yml +++ b/.github/workflows/fastrl-docker.yml @@ 
-64,7 +64,7 @@ jobs: ${{ runner.os }}-buildx- - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: build and tag container run: | From 771c71a9392da2b93b7bd758849b508d4c2593ac Mon Sep 17 00:00:00 2001 From: josiahls Date: Wed, 1 Nov 2023 15:12:12 -0400 Subject: [PATCH 7/9] Update fastrl-docker.yml --- .github/workflows/fastrl-docker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fastrl-docker.yml b/.github/workflows/fastrl-docker.yml index 9c9a4b6..4608243 100644 --- a/.github/workflows/fastrl-docker.yml +++ b/.github/workflows/fastrl-docker.yml @@ -74,7 +74,7 @@ jobs: docker pull ${IMAGE_NAME}:latest || true # docker build --build-arg BUILD=${BUILD_TYPE} \ docker buildx create --use - docker buildx --cache-from=type=local,src=/tmp/.buildx-cache --cache-to=type=local,dest=/tmp/.buildx-cache --build-arg BUILD=${BUILD_TYPE} \ + docker buildx build --cache-from=type=local,src=/tmp/.buildx-cache --cache-to=type=local,dest=/tmp/.buildx-cache --build-arg BUILD=${BUILD_TYPE} \ -t ${IMAGE_NAME}:latest \ -t ${IMAGE_NAME}:${VERSION} \ -t ${IMAGE_NAME}:$(date +%F) \ From 64d8c0aed1cb8b2612853811dde2daa64700a48b Mon Sep 17 00:00:00 2001 From: josiahls Date: Wed, 1 Nov 2023 15:16:05 -0400 Subject: [PATCH 8/9] Update fastrl-docker.yml --- .github/workflows/fastrl-docker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fastrl-docker.yml b/.github/workflows/fastrl-docker.yml index 4608243..98978c6 100644 --- a/.github/workflows/fastrl-docker.yml +++ b/.github/workflows/fastrl-docker.yml @@ -71,7 +71,7 @@ jobs: export DOCKER_BUILDKIT=1 # We need to clear the previous docker images # docker system prune -fa - docker pull ${IMAGE_NAME}:latest || true + # docker pull ${IMAGE_NAME}:latest || true # docker build --build-arg BUILD=${BUILD_TYPE} \ docker buildx create --use docker buildx build --cache-from=type=local,src=/tmp/.buildx-cache --cache-to=type=local,dest=/tmp/.buildx-cache --build-arg BUILD=${BUILD_TYPE} \ From 3aad10eac9ac67d11e75be09453471faeee68f3a Mon Sep 17 00:00:00 2001 From: josiah Date: Thu, 2 Nov 2023 16:15:09 +0000 Subject: [PATCH 9/9] Changed: - all learners for all agents now have fit and validation capabilities - cleaned up a ton of stuff - TRPO and PPO still dont work right. 
I'll probably try to tacking this is another pr --- .github/workflows/fastrl-docker.yml | 2 +- fastrl/_modidx.py | 4 +- fastrl/agents/ddpg.py | 53 +++-- fastrl/agents/dqn/basic.py | 24 +-- fastrl/agents/dqn/categorical.py | 19 +- fastrl/agents/dqn/double.py | 23 +-- fastrl/agents/dqn/dueling.py | 39 +--- fastrl/agents/dqn/rainbow.py | 22 +- fastrl/agents/dqn/target.py | 30 ++- fastrl/agents/ppo.py | 137 +++++++------ fastrl/agents/trpo.py | 47 ++--- fastrl/learner/core.py | 26 ++- fastrl/loggers/vscode_visualizers.py | 17 +- .../09f_loggers.vscode_visualizers.ipynb | 37 +++- nbs/06_Learning/10a_learner.core.ipynb | 30 +-- .../01_Discrete/12g_agents.dqn.basic.ipynb | 194 +++--------------- .../01_Discrete/12h_agents.dqn.target.ipynb | 150 +++----------- .../01_Discrete/12m_agents.dqn.double.ipynb | 89 ++------ .../01_Discrete/12n_agents.dqn.dueling.ipynb | 59 ++---- .../12o_agents.dqn.categorical.ipynb | 75 ++----- .../01_Discrete/12r_agents.dqn.rainbow.ipynb | 46 ++--- .../02_Continuous/12s_agents.ddpg.ipynb | 80 ++++---- .../02_Continuous/12t_agents.trpo.ipynb | 136 +++--------- .../02_Continuous/12u_agents.ppo.ipynb | 84 ++++---- 24 files changed, 483 insertions(+), 940 deletions(-) diff --git a/.github/workflows/fastrl-docker.yml b/.github/workflows/fastrl-docker.yml index 969e0e4..ce03b5a 100644 --- a/.github/workflows/fastrl-docker.yml +++ b/.github/workflows/fastrl-docker.yml @@ -7,7 +7,7 @@ on: branches: - main - update_nbdev_docs - - feature/ppo + - refactor/bug-fix-api-update-and-stablize jobs: build: diff --git a/fastrl/_modidx.py b/fastrl/_modidx.py index 0468a37..8988c52 100644 --- a/fastrl/_modidx.py +++ b/fastrl/_modidx.py @@ -645,7 +645,9 @@ 'fastrl.loggers.vscode_visualizers.SimpleVSCodeVideoPlayer.show': ( '05_Logging/loggers.vscode_visualizers.html#simplevscodevideoplayer.show', 'fastrl/loggers/vscode_visualizers.py'), 'fastrl.loggers.vscode_visualizers.VSCodeDataPipe': ( '05_Logging/loggers.vscode_visualizers.html#vscodedatapipe', - 'fastrl/loggers/vscode_visualizers.py')}, + 'fastrl/loggers/vscode_visualizers.py'), + 'fastrl.loggers.vscode_visualizers.VSCodeDataPipe.__new__': ( '05_Logging/loggers.vscode_visualizers.html#vscodedatapipe.__new__', + 'fastrl/loggers/vscode_visualizers.py')}, 'fastrl.memory.experience_replay': { 'fastrl.memory.experience_replay.ExperienceReplay': ( '04_Memory/memory.experience_replay.html#experiencereplay', 'fastrl/memory/experience_replay.py'), 'fastrl.memory.experience_replay.ExperienceReplay.__init__': ( '04_Memory/memory.experience_replay.html#experiencereplay.__init__', diff --git a/fastrl/agents/ddpg.py b/fastrl/agents/ddpg.py index 9a93b9d..908faef 100644 --- a/fastrl/agents/ddpg.py +++ b/fastrl/agents/ddpg.py @@ -7,34 +7,32 @@ 'CriticLossProcessor', 'ActorLossProcessor', 'DDPGLearner'] # %% ../../nbs/07_Agents/02_Continuous/12s_agents.ddpg.ipynb 2 -# # Python native modules -# import os +# Python native modules from typing import Tuple,Optional,Callable,Union,Dict,Literal,List from functools import partial -# from typing_extensions import Literal from copy import deepcopy -# # Third party libs +# Third party libs from fastcore.all import add_docs import torchdata.datapipes as dp from torchdata.dataloader2.graph import traverse_dps,find_dps,DataPipe -# from torchdata.dataloader2.graph import DataPipe,traverse from torch import nn -# from torch.optim import AdamW,Adam import torch -# import pandas as pd -# import numpy as np -# # Local modules +# Local modules from ..core import SimpleStep from ..pipes.core import find_dp from 
..torch_core import Module from ..memory.experience_replay import ExperienceReplay -from ..loggers.core import Record,is_record,not_record,_RECORD_CATCH_LIST from ..learner.core import LearnerBase,LearnerHead,StepBatcher -# from fastrl.pipes.core import * -# from fastrl.data.block import * -# from fastrl.data.dataloader2 import * +from ..loggers.vscode_visualizers import VSCodeDataPipe from fastrl.loggers.core import ( - LogCollector,Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,is_record + ProgressBarLogger, + Record, + BatchCollector, + EpochCollector, + RollingTerminatedRewardCollector, + EpisodeCollector, + not_record, + _RECORD_CATCH_LIST ) from fastrl.agents.core import ( AgentHead, @@ -43,9 +41,6 @@ SimpleModelRunner, NumpyConverter ) -# from fastrl.memory.experience_replay import ExperienceReplay -# from fastrl.learner.core import * -# from fastrl.loggers.core import * # %% ../../nbs/07_Agents/02_Continuous/12s_agents.ddpg.ipynb 6 def init_xavier_uniform_weights(m:Module,bias=0.01): @@ -798,7 +793,7 @@ def DDPGLearner( critic:Critic, # A list of dls, where index=0 is the training dl. dls, - logger_bases:Optional[Callable]=None, + do_logging:bool=True, # The learning rate for the actor. Expected to learn slower than the critic actor_lr:float=1e-3, # The optimizer for the actor @@ -830,11 +825,12 @@ def DDPGLearner( # Debug mode will output device moves debug:bool=False ) -> LearnerHead: - learner = LearnerBase(actor,dls[0]) + learner = LearnerBase({'actor':actor,'critic':critic},dls[0]) learner = BatchCollector(learner,batches=batches) learner = EpochCollector(learner) - if logger_bases: - learner = logger_bases(learner) + if do_logging: + learner = learner.dump_records() + learner = ProgressBarLogger(learner) learner = RollingTerminatedRewardCollector(learner) learner = EpisodeCollector(learner).catch_records() learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz) @@ -847,12 +843,15 @@ def DDPGLearner( learner = ActorLossProcessor(learner,critic,actor,clip_critic_grad=5) learner = LossCollector(learner,title='actor-loss').catch_records() learner = BasicOptStepper(learner,actor,actor_lr,opt=actor_opt,filter=True,do_zero_grad=False) - learner = LearnerHead(learner,(actor,critic)) - - # for dl in dls: - # pipe_to_device(dl.datapipe,device,debug=debug) - - return learner + learner = LearnerHead(learner) + + if len(dls)==2: + val_learner = LearnerBase({'actor':actor,'critic':critic},dls[1]).visualize_vscode() + val_learner = BatchCollector(val_learner,batches=batches) + val_learner = EpochCollector(val_learner).catch_records(drop=True) + return LearnerHead((learner,val_learner)) + else: + return LearnerHead(learner) DDPGLearner.__doc__="""DDPG is a continuous action, actor-critic model, first created in (Lillicrap et al., 2016). 
The critic estimates a Q value estimate, and the actor diff --git a/fastrl/agents/dqn/basic.py b/fastrl/agents/dqn/basic.py index 7ccb797..abbca76 100644 --- a/fastrl/agents/dqn/basic.py +++ b/fastrl/agents/dqn/basic.py @@ -10,15 +10,11 @@ from collections import deque from typing import Callable,Optional,List # Third party libs -from fastcore.all import ifnone import torchdata.datapipes as dp -from torchdata.dataloader2 import DataLoader2 from torchdata.dataloader2.graph import traverse_dps,DataPipe import torch -import torch.nn.functional as F from torch import optim from torch import nn -import numpy as np # Local modules from ..core import AgentHead,AgentBase from ...pipes.core import find_dp @@ -26,7 +22,7 @@ from ..core import StepFieldSelector,SimpleModelRunner,NumpyConverter from ..discrete import EpsilonCollector,PyPrimativeConverter,ArgMaxer,EpsilonSelector from fastrl.loggers.core import ( - LogCollector,Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,is_record + Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,ProgressBarLogger ) from ...learner.core import LearnerBase,LearnerHead,StepBatcher from ...torch_core import Module @@ -156,7 +152,7 @@ def __iter__(self): def DQNLearner( model, dls, - logger_bases:Optional[Callable]=None, + do_logging:bool=True, loss_func=nn.MSELoss(), opt=optim.AdamW, lr=0.005, @@ -169,25 +165,27 @@ def DQNLearner( learner = LearnerBase(model,dls[0]) learner = BatchCollector(learner,batches=batches) learner = EpochCollector(learner) - if logger_bases: - learner = logger_bases(learner) + if do_logging: + learner = learner.dump_records() + learner = ProgressBarLogger(learner) learner = RollingTerminatedRewardCollector(learner) learner = EpisodeCollector(learner) - learner = learner.catch_records() + learner = learner.catch_records(drop=not do_logging) + learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz,freeze_memory=True) learner = StepBatcher(learner,device=device) learner = QCalc(learner) learner = TargetCalc(learner,nsteps=nsteps) learner = LossCalc(learner,loss_func=loss_func) learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr)) - if logger_bases: + if do_logging: learner = LossCollector(learner).catch_records() if len(dls)==2: - val_learner = LearnerBase(model,dls[1]) + val_learner = LearnerBase(model,dls[1]).visualize_vscode() val_learner = BatchCollector(val_learner,batches=batches) val_learner = EpochCollector(val_learner).dump_records() - learner = LearnerHead((learner,val_learner),model) + learner = LearnerHead((learner,val_learner)) else: - learner = LearnerHead(learner,model) + learner = LearnerHead(learner) return learner diff --git a/fastrl/agents/dqn/categorical.py b/fastrl/agents/dqn/categorical.py index 5329953..ced4656 100644 --- a/fastrl/agents/dqn/categorical.py +++ b/fastrl/agents/dqn/categorical.py @@ -261,7 +261,7 @@ def CategoricalDQNAgent( def DQNCategoricalLearner( model, dls, - logger_bases:Optional[Callable]=None, + do_logging:bool=True, loss_func=PartialCrossEntropy, opt=optim.AdamW, lr=0.005, @@ -276,11 +276,11 @@ def DQNCategoricalLearner( learner = LearnerBase(model,dls=dls[0]) learner = BatchCollector(learner,batches=batches) learner = EpochCollector(learner) - if logger_bases: - learner = logger_bases(learner) + if do_logging: + learner = learner.dump_records() + learner = ProgressBarLogger(learner) learner = RollingTerminatedRewardCollector(learner) - learner = EpisodeCollector(learner) - learner = learner.catch_records() + 
learner = EpisodeCollector(learner).catch_records() learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz) learner = StepBatcher(learner,device=device) learner = CategoricalTargetQCalc(learner,nsteps=nsteps,double_dqn_strategy=double_dqn_strategy).to(device=device) @@ -288,17 +288,16 @@ def DQNCategoricalLearner( learner = LossCalc(learner,loss_func=loss_func) learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr)) learner = TargetModelUpdater(learner,target_sync=target_sync) - if logger_bases: + if do_logging: learner = LossCollector(learner).catch_records() if len(dls)==2: - val_learner = LearnerBase(model,dls[1]) + val_learner = LearnerBase(model,dls[1]).visualize_vscode() val_learner = BatchCollector(val_learner,batches=batches) val_learner = EpochCollector(val_learner).catch_records(drop=True) - val_learner = VSCodeDataPipe(val_learner) - return LearnerHead((learner,val_learner),model) + return LearnerHead((learner,val_learner)) else: - return LearnerHead(learner,model) + return LearnerHead(learner) # %% ../../../nbs/07_Agents/01_Discrete/12o_agents.dqn.categorical.ipynb 49 def show_q(cat_dist,title='Update Distributions'): diff --git a/fastrl/agents/dqn/double.py b/fastrl/agents/dqn/double.py index ba86428..37e62e5 100644 --- a/fastrl/agents/dqn/double.py +++ b/fastrl/agents/dqn/double.py @@ -18,6 +18,7 @@ from ...loggers.core import BatchCollector,EpochCollector from ...learner.core import LearnerBase,LearnerHead from ...loggers.vscode_visualizers import VSCodeDataPipe +from ...loggers.core import ProgressBarLogger from fastrl.agents.dqn.basic import ( LossCollector, RollingTerminatedRewardCollector, @@ -36,7 +37,7 @@ # %% ../../../nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb 5 class DoubleQCalc(dp.iter.IterDataPipe): - def __init__(self,source_datapipe=None): + def __init__(self,source_datapipe): self.source_datapipe = source_datapipe def __iter__(self): @@ -53,7 +54,7 @@ def __iter__(self): def DoubleDQNLearner( model, dls, - logger_bases:Optional[Callable]=None, + do_logging:bool=True, loss_func=nn.MSELoss(), opt=optim.AdamW, lr=0.005, @@ -67,27 +68,25 @@ def DoubleDQNLearner( learner = LearnerBase(model,dls=dls[0]) learner = BatchCollector(learner,batches=batches) learner = EpochCollector(learner) - if logger_bases: - learner = logger_bases(learner) + if do_logging: + learner = learner.dump_records() + learner = ProgressBarLogger(learner) learner = RollingTerminatedRewardCollector(learner) - learner = EpisodeCollector(learner) - learner = learner.catch_records() + learner = EpisodeCollector(learner).catch_records() learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz) learner = StepBatcher(learner,device=device) - # learner = TargetModelQCalc(learner) learner = DoubleQCalc(learner) learner = TargetCalc(learner,nsteps=nsteps) learner = LossCalc(learner,loss_func=loss_func) learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr)) learner = TargetModelUpdater(learner,target_sync=target_sync) - if logger_bases: + if do_logging: learner = LossCollector(learner).catch_records() if len(dls)==2: - val_learner = LearnerBase(model,dls[1]) + val_learner = LearnerBase(model,dls[1]).visualize_vscode() val_learner = BatchCollector(val_learner,batches=batches) val_learner = EpochCollector(val_learner).catch_records(drop=True) - val_learner = VSCodeDataPipe(val_learner) - return LearnerHead((learner,val_learner),model) + return LearnerHead((learner,val_learner)) else: - return LearnerHead(learner,model) + return LearnerHead(learner) diff --git 
a/fastrl/agents/dqn/dueling.py b/fastrl/agents/dqn/dueling.py index 764448a..7a80aa6 100644 --- a/fastrl/agents/dqn/dueling.py +++ b/fastrl/agents/dqn/dueling.py @@ -5,48 +5,29 @@ # %% ../../../nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb 2 # Python native modules -from copy import deepcopy -from typing import Optional,Callable,Tuple # Third party libs -import torchdata.datapipes as dp -from torchdata.dataloader2.graph import traverse_dps,DataPipe import torch -from torch import nn,optim -# Local modulesf -from ...pipes.core import find_dp -from ...memory.experience_replay import ExperienceReplay -from ...loggers.core import BatchCollector,EpochCollector -from ...learner.core import LearnerBase,LearnerHead -from ...loggers.vscode_visualizers import VSCodeDataPipe +from torch import nn +# Local modules from fastrl.agents.dqn.basic import ( - LossCollector, - RollingTerminatedRewardCollector, - EpisodeCollector, - StepBatcher, - TargetCalc, - LossCalc, - ModelLearnCalc, DQN, DQNAgent ) -from fastrl.agents.dqn.target import ( - TargetModelUpdater, - TargetModelQCalc, - DQNTargetLearner -) +from .target import DQNTargetLearner # %% ../../../nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb 5 class DuelingHead(nn.Module): - def __init__(self, - hidden:int, # Input into the DuelingHead, likely a hidden layer input - n_actions:int, # Number/dim of actions to output - lin_cls=nn.Linear + def __init__( + self, + hidden: int, # Input into the DuelingHead, likely a hidden layer input + n_actions: int, # Number/dim of actions to output + lin_cls = nn.Linear ): super().__init__() self.val = lin_cls(hidden,1) self.adv = lin_cls(hidden,n_actions) def forward(self,xi): - val,adv=self.val(xi),self.adv(xi) - xi=val.expand_as(adv)+(adv-adv.mean()).squeeze(0) + val,adv = self.val(xi),self.adv(xi) + xi = val.expand_as(adv)+(adv-adv.mean()).squeeze(0) return xi diff --git a/fastrl/agents/dqn/rainbow.py b/fastrl/agents/dqn/rainbow.py index 762d41a..229cca4 100644 --- a/fastrl/agents/dqn/rainbow.py +++ b/fastrl/agents/dqn/rainbow.py @@ -23,8 +23,9 @@ from ...memory.experience_replay import ExperienceReplay from ...loggers.core import BatchCollector,EpochCollector from ...learner.core import LearnerBase,LearnerHead -from ...loggers.vscode_visualizers import VSCodeDataPipe from ..core import AgentHead,AgentBase +from ...loggers.vscode_visualizers import VSCodeDataPipe +from ...loggers.core import ProgressBarLogger from fastrl.agents.dqn.basic import ( LossCollector, RollingTerminatedRewardCollector, @@ -52,7 +53,7 @@ def DQNRainbowLearner( model, dls, - logger_bases:Optional[Callable]=None, + do_logging:bool=True, loss_func=PartialCrossEntropy, opt=optim.AdamW, lr=0.005, @@ -68,25 +69,24 @@ def DQNRainbowLearner( learner = LearnerBase(model,dls=dls[0]) learner = BatchCollector(learner,batches=batches) learner = EpochCollector(learner) - if logger_bases: - learner = logger_bases(learner) + if do_logging: + learner = learner.dump_records() + learner = ProgressBarLogger(learner) learner = RollingTerminatedRewardCollector(learner) - learner = EpisodeCollector(learner) - learner = learner.catch_records() + learner = EpisodeCollector(learner).catch_records() learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz) learner = StepBatcher(learner,device=device) learner = CategoricalTargetQCalc(learner,nsteps=nsteps,double_dqn_strategy=double_dqn_strategy).to(device=device) learner = LossCalc(learner,loss_func=loss_func) learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr)) learner = 
TargetModelUpdater(learner,target_sync=target_sync) - if logger_bases: + if do_logging: learner = LossCollector(learner).catch_records() if len(dls)==2: - val_learner = LearnerBase(model,dls[1]) + val_learner = LearnerBase(model,dls[1]).visualize_vscode() val_learner = BatchCollector(val_learner,batches=batches) val_learner = EpochCollector(val_learner).catch_records(drop=True) - val_learner = VSCodeDataPipe(val_learner) - return LearnerHead((learner,val_learner),model) + return LearnerHead((learner,val_learner)) else: - return LearnerHead(learner,model) + return LearnerHead(learner) diff --git a/fastrl/agents/dqn/target.py b/fastrl/agents/dqn/target.py index 4ea9384..60f8cb1 100644 --- a/fastrl/agents/dqn/target.py +++ b/fastrl/agents/dqn/target.py @@ -18,6 +18,7 @@ from ...loggers.core import BatchCollector,EpochCollector from ...learner.core import LearnerBase,LearnerHead from ...loggers.vscode_visualizers import VSCodeDataPipe +from ...loggers.core import ProgressBarLogger from fastrl.agents.dqn.basic import ( LossCollector, RollingTerminatedRewardCollector, @@ -32,17 +33,15 @@ # %% ../../../nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb 7 class TargetModelUpdater(dp.iter.IterDataPipe): - def __init__(self,source_datapipe=None,target_sync=300): + def __init__(self,source_datapipe,target_sync=300): self.source_datapipe = source_datapipe - if source_datapipe is not None: - self.learner = find_dp(traverse_dps(self),LearnerBase) - with torch.no_grad(): - self.learner.target_model = deepcopy(self.learner.model) self.target_sync = target_sync self.n_batch = 0 + self.learner = find_dp(traverse_dps(self),LearnerBase) + with torch.no_grad(): + self.learner.target_model = deepcopy(self.learner.model) def reset(self): - print('resetting') self.learner = find_dp(traverse_dps(self),LearnerBase) with torch.no_grad(): self.learner.target_model = deepcopy(self.learner.model) @@ -76,7 +75,7 @@ def __iter__(self): def DQNTargetLearner( model, dls, - logger_bases:Optional[Callable]=None, + do_logging:bool=True, loss_func=nn.MSELoss(), opt=optim.AdamW, lr=0.005, @@ -90,11 +89,11 @@ def DQNTargetLearner( learner = LearnerBase(model,dls=dls[0]) learner = BatchCollector(learner,batches=batches) learner = EpochCollector(learner) - if logger_bases: - learner = logger_bases(learner) + if do_logging: + learner = learner.dump_records() + learner = ProgressBarLogger(learner) learner = RollingTerminatedRewardCollector(learner) - learner = EpisodeCollector(learner) - learner = learner.catch_records() + learner = EpisodeCollector(learner).catch_records() learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz) learner = StepBatcher(learner,device=device) learner = TargetModelQCalc(learner) @@ -102,14 +101,13 @@ def DQNTargetLearner( learner = LossCalc(learner,loss_func=loss_func) learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr)) learner = TargetModelUpdater(learner,target_sync=target_sync) - if logger_bases: + if do_logging: learner = LossCollector(learner).catch_records() if len(dls)==2: - val_learner = LearnerBase(model,dls[1]) + val_learner = LearnerBase(model,dls[1]).visualize_vscode() val_learner = BatchCollector(val_learner,batches=batches) val_learner = EpochCollector(val_learner).catch_records(drop=True) - val_learner = VSCodeDataPipe(val_learner) - return LearnerHead((learner,val_learner),model) + return LearnerHead((learner,val_learner)) else: - return LearnerHead(learner,model) + return LearnerHead(learner) diff --git a/fastrl/agents/ppo.py b/fastrl/agents/ppo.py index 
826688f..bb62fe7 100644 --- a/fastrl/agents/ppo.py +++ b/fastrl/agents/ppo.py @@ -3,55 +3,59 @@ # %% auto 0 __all__ = ['PPOActorOptAndLossProcessor', 'PPOLearner'] -# %% ../../nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb 3 +# %% ../../nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb 2 # Python native modules -from typing import * -from typing_extensions import Literal -import typing -from warnings import warn -# Third party libs -import numpy as np +from typing import Union,Dict,Literal,List,Callable,Optional +# from typing_extensions import Literal +# import typing +# from warnings import warn +# # Third party libs +# import numpy as np import torch from torch import nn -from torch.distributions import * +# from torch.distributions import * import torchdata.datapipes as dp -from torchdata.dataloader2.graph import DataPipe,traverse,replace_dp -from fastcore.all import test_eq,test_ne,ifnone,L +from torchdata.dataloader2.graph import DataPipe,traverse_dps +# from fastcore.all import test_eq,test_ne,ifnone,L from torch.optim import AdamW,Adam -# Local modules -from ..core import * -from ..pipes.core import * -from ..torch_core import * -from ..layers import * -from ..data.block import * -from ..envs.gym import * -from .trpo import * -from ..loggers.vscode_visualizers import VSCodeTransformBlock -from ..loggers.jupyter_visualizers import ProgressBarLogger -from .discrete import EpsilonCollector -from .core import AgentHead,StepFieldSelector,AgentBase -from .ddpg import ActionClip,ActionUnbatcher,NumpyConverter,OrnsteinUhlenbeck,SimpleModelRunner -from ..loggers.core import LoggerBase,CacheLoggerBase -from ..dataloader2_ext import InputInjester -from ..loggers.core import LoggerBasePassThrough,BatchCollector,EpocherCollector,RollingTerminatedRewardCollector,EpisodeCollector +# # Local modules +from ..core import SimpleStep +# from fastrl.pipes.core import * +# from fastrl.torch_core import * +from ..layers import Critic +# from fastrl.data.block import * +# from fastrl.envs.gym import * +from .trpo import Actor +# from fastrl.loggers.vscode_visualizers import VSCodeTransformBlock +# from fastrl.loggers.jupyter_visualizers import ProgressBarLogger +# from fastrl.agents.discrete import EpsilonCollector +# from fastrl.agents.core import AgentHead,StepFieldSelector,AgentBase +# from fastrl.agents.ddpg import ActionClip,ActionUnbatcher,NumpyConverter,OrnsteinUhlenbeck,SimpleModelRunner +# from fastrl.loggers.core import LoggerBase,CacheLoggerBase +# from fastrl.dataloader2_ext import InputInjester +# from fastrl.loggers.core import LoggerBasePassThrough,BatchCollector,EpocherCollector,RollingTerminatedRewardCollector,EpisodeCollector from ..learner.core import LearnerBase,LearnerHead -from ..pipes.core import * -from ..pipes.iter.nskip import * -from ..pipes.iter.nstep import * -from ..pipes.iter.firstlast import * -from ..pipes.iter.transforms import * -from ..pipes.map.transforms import * -from ..data.block import * -from ..torch_core import * -from ..layers import * -from ..data.block import * -from ..envs.gym import * +from ..loggers.core import BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector +import fastrl.pipes.iter.cacheholder from .ddpg import LossCollector,BasicOptStepper,StepBatcher -from ..loggers.core import LogCollector -from .discrete import EpsilonCollector - - -# %% ../../nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb 5 +from .trpo import CriticLossProcessor +# from fastrl.pipes.core import * +# from fastrl.pipes.iter.nskip import * +# from 
fastrl.pipes.iter.nstep import * +# from fastrl.pipes.iter.firstlast import * +# from fastrl.pipes.iter.transforms import * +# from fastrl.pipes.map.transforms import * +# from fastrl.data.block import * +# from fastrl.torch_core import * +# from fastrl.layers import * +# from fastrl.data.block import * +# from fastrl.envs.gym import * +# from fastrl.agents.ddpg import LossCollector,BasicOptStepper,StepBatcher +# from fastrl.loggers.core import LogCollector +# from fastrl.agents.discrete import EpsilonCollector + + +# %% ../../nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb 4 class PPOActorOptAndLossProcessor(dp.iter.IterDataPipe): debug:bool=False @@ -67,9 +71,11 @@ def __init__(self, actor_opt:torch.optim.Optimizer=AdamW, # The optimizer to use critic_opt:torch.optim.Optimizer=AdamW, + ppo_epochs = 10, + ppo_batch_sz = 64, + ppo_eps = 0.2, # kwargs to be passed to the `opt` **opt_kwargs - ): self.source_datapipe = source_datapipe self.actor = actor @@ -84,9 +90,9 @@ def __init__(self, self.critic_loss = nn.MSELoss() self._critic_opt = self.critic_opt(self.critic.parameters(),lr=self.critic_lr,**self.opt_kwargs) self._actor_opt = self.actor_opt(self.actor.parameters(),lr=self.actor_lr,**self.opt_kwargs) - self.ppo_epochs = 10 - self.ppo_batch_sz = 64 - self.ppo_eps = 0.2 + self.ppo_epochs = ppo_epochs + self.ppo_batch_sz = ppo_batch_sz + self.ppo_eps = ppo_eps def to(self,*args,**kwargs): self.actor.to(**kwargs) @@ -141,16 +147,16 @@ def __iter__(self) -> Union[Dict[Literal['loss'],torch.Tensor],SimpleStep]: yield {'loss':loss} yield batch -# %% ../../nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb 6 +# %% ../../nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb 5 def PPOLearner( # The actor model to use actor:Actor, # The critic model to use critic:Critic, # A list of dls, where index=0 is the training dl. - dls:List[DataPipeOrDataLoader], + dls:List[object], # Optional logger bases to log training/validation data to. - logger_bases:Optional[List[LoggerBase]]=None, + logger_bases:Optional[Callable]=None, # The learning rate for the actor. 
Expected to learn slower than the critic actor_lr:float=1e-4, # The optimizer for the actor @@ -169,33 +175,28 @@ def PPOLearner( device:torch.device=None, # Number of batches per epoch batches:int=None, - # Any augmentations to the learner - dp_augmentation_fns:Optional[List[DataPipeAugmentationFn]]=None, # Debug mode will output device moves - debug:bool=False + debug:bool=False, + ppo_epochs = 10, + ppo_batch_sz = 64, + ppo_eps = 0.2, ) -> LearnerHead: - warn("") - - learner = LearnerBase(actor,dls,batches=batches) - learner = LoggerBasePassThrough(learner,logger_bases) - learner = BatchCollector(learner,batch_on_pipe=LearnerBase) - learner = EpocherCollector(learner) - for logger_base in L(logger_bases): learner = logger_base.connect_source_datapipe(learner) + learner = LearnerBase(actor,dls[0]) + learner = BatchCollector(learner,batches=batches) + learner = EpochCollector(learner) if logger_bases: + learner = logger_bases(learner) learner = RollingTerminatedRewardCollector(learner) - learner = EpisodeCollector(learner) + learner = EpisodeCollector(learner).catch_records() learner = StepBatcher(learner) # learner = CriticLossProcessor(learner,critic=critic) - # learner = LossCollector(learner,header='critic-loss') + # learner = LossCollector(learner,title='critic-loss').catch_records() # learner = BasicOptStepper(learner,critic,critic_lr,opt=critic_opt,filter=True,do_zero_grad=False) learner = PPOActorOptAndLossProcessor(learner,actor=actor,actor_lr=actor_lr, - critic=critic,critic_lr=critic_lr) - learner = LossCollector(learner,header='actor-loss',filter=True) - learner = LearnerHead(learner) - - learner = apply_dp_augmentation_fns(learner,dp_augmentation_fns) - pipe2device(learner,device,debug=debug) - for dl in dls: pipe2device(dl.datapipe,device,debug=debug) + critic=critic,critic_lr=critic_lr,ppo_epochs=ppo_epochs, + ppo_batch_sz=ppo_batch_sz,ppo_eps=ppo_eps) + learner = LossCollector(learner,title='actor-loss').catch_records() + learner = LearnerHead(learner,(actor,critic)) return learner diff --git a/fastrl/agents/trpo.py b/fastrl/agents/trpo.py index be738fd..1c17fc8 100644 --- a/fastrl/agents/trpo.py +++ b/fastrl/agents/trpo.py @@ -11,49 +11,37 @@ # Python native modules from typing import NamedTuple,List,Tuple,Optional,Dict,Literal,Callable,Union from functools import partial -# from typing_extensions import Literal # import typing from warnings import warn # Third party libs import numpy as np import torch from torch import nn -from torch.optim import AdamW,Adam +from torch.optim import Adam from torch.distributions import Independent,Normal import torchdata.datapipes as dp from torchdata.dataloader2.graph import DataPipe,traverse_dps -from fastcore.all import add_docs,store_attr,ifnone,L +from fastcore.all import add_docs,store_attr,L import gymnasium as gym from torchdata.dataloader2.graph import find_dps,traverse_dps -# from fastrl.data.dataloader2 import * -# from torchdata.dataloader2 import DataLoader2,DataLoader2Iterator -# from torchdata.dataloader2.graph import find_dps,traverse,DataPipe,IterDataPipe,MapDataPipe # Local modules from ..core import add_namedtuple_doc,SimpleStep,StepTypes from ..pipes.core import find_dp -from ..loggers.core import Record,is_record,not_record,_RECORD_CATCH_LIST +from ..loggers.core import Record,not_record,_RECORD_CATCH_LIST from ..torch_core import Module,evaluating from ..layers import Critic -# from fastrl.data.block import * from ..envs.gym import GymStepper -from ..pipes.iter.firstlast import FirstLastMerger from 
..pipes.iter.nskip import NSkipper from ..pipes.iter.nstep import NStepper,NStepFlattener import fastrl.pipes.iter.cacheholder from .ddpg import LossCollector,BasicOptStepper,StepBatcher -# from fastrl.loggers.core import LogCollector -# from fastrl.agents.discrete import EpsilonCollector -# from copy import deepcopy from ..learner.core import LearnerBase,LearnerHead from ..loggers.core import BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector -# from fastrl.agents.ddpg import BasicOptStepper -# from fastrl.loggers.vscode_visualizers import VSCodeTransformBlock -# from fastrl.loggers.jupyter_visualizers import ProgressBarLogger -# from fastrl.layers import Critic -# from fastrl.agents.discrete import EpsilonCollector +from ..loggers.core import ProgressBarLogger,EpochCollector,BatchCollector +from ..loggers.vscode_visualizers import VSCodeDataPipe from .core import AgentHead,StepFieldSelector,AgentBase -from .ddpg import ActionClip,ActionUnbatcher,NumpyConverter,OrnsteinUhlenbeck,SimpleModelRunner +from .ddpg import ActionClip,ActionUnbatcher,NumpyConverter,SimpleModelRunner # %% ../../nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb 7 class AdvantageStep(NamedTuple): @@ -846,11 +834,7 @@ def TRPOLearner( # A list of dls, where index=0 is the training dl. dls:List[object], # Optional logger bases to log training/validation data to. - logger_bases:Optional[Callable]=None, - # The learning rate for the actor. Expected to learn slower than the critic - actor_lr:float=1e-3, - # The optimizer for the actor - actor_opt:torch.optim.Optimizer=Adam, + do_logging:bool=True, # The learning rate for the critic. Expected to learn faster than the actor critic_lr:float=1e-2, # The optimizer for the critic @@ -870,11 +854,12 @@ def TRPOLearner( ) -> LearnerHead: warn("TRPO only kind of converges. 
There is a likely a bug, however I am unable to identify until after PPO implimentation") - learner = LearnerBase(actor,dls[0]) + learner = LearnerBase({'actor':actor,'critic':critic},dls[0]) learner = BatchCollector(learner,batches=batches) learner = EpochCollector(learner) - if logger_bases: - learner = logger_bases(learner) + if do_logging: + learner = learner.dump_records() + learner = ProgressBarLogger(learner) learner = RollingTerminatedRewardCollector(learner) learner = EpisodeCollector(learner).catch_records() learner = StepBatcher(learner) @@ -883,7 +868,13 @@ def TRPOLearner( learner = BasicOptStepper(learner,critic,critic_lr,opt=critic_opt,filter=True,do_zero_grad=False) learner = ActorOptAndLossProcessor(learner,actor) learner = LossCollector(learner,title='actor-loss').catch_records() - learner = LearnerHead(learner,(actor,critic)) - return learner + + if len(dls)==2: + val_learner = LearnerBase({'actor':actor,'critic':critic},dls[1]).visualize_vscode() + val_learner = BatchCollector(val_learner,batches=batches) + val_learner = EpochCollector(val_learner).catch_records(drop=True) + return LearnerHead((learner,val_learner)) + else: + return LearnerHead(learner) TRPOLearner.__doc__="""""" diff --git a/fastrl/learner/core.py b/fastrl/learner/core.py index 525729c..31276e7 100644 --- a/fastrl/learner/core.py +++ b/fastrl/learner/core.py @@ -79,37 +79,42 @@ def __iter__(self): class LearnerHead(dp.iter.IterDataPipe): def __init__( self, - source_datapipes:Tuple[dp.iter.IterDataPipe], - model + source_datapipes:Tuple[dp.iter.IterDataPipe] ): if not isinstance(source_datapipes,tuple): self.source_datapipes = (source_datapipes,) else: self.source_datapipes = source_datapipes self.dp_idx = 0 - self.model = model def __iter__(self): yield from self.source_datapipes[self.dp_idx] def fit(self,epochs): self.dp_idx = 0 epocher = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),EpochCollector) + learner = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),LearnerBase) epocher.epochs = epochs - if isinstance(self.model,tuple): - for m in self.model: + if isinstance(learner.model,dict): + for m in learner.model.values(): m.train() else: - self.model.train() + learner.model.train() for _ in self: pass - def validate(self,epochs=1,batches=100,show=True) -> DataPipe: + def validate(self,epochs=1,batches=100,show=True,return_outputs=False) -> DataPipe: self.dp_idx = 1 epocher = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),EpochCollector) epocher.epochs = epochs batcher = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),BatchCollector) batcher.batches = batches - with evaluating(self.model): - for _ in self: pass + learner = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),LearnerBase) + model = learner.model + model = tuple(model.values()) if isinstance(model,dict) else model + with evaluating(model): + if return_outputs: + return list(self) + else: + for _ in self: pass if show: pipes = list_dps(traverse_dps(self.source_datapipes[self.dp_idx])) for pipe in pipes: @@ -118,7 +123,8 @@ def validate(self,epochs=1,batches=100,show=True) -> DataPipe: add_docs( LearnerHead, -""" +"""LearnerHead can connect to multiple `LearnerBase`s and handles training +and validation execution. 
""", fit="Runs the `LearnerHead` pipeline for `epochs`", validate="""If there is more than 1 dl, then run 1 epoch of that dl based on diff --git a/fastrl/loggers/vscode_visualizers.py b/fastrl/loggers/vscode_visualizers.py index debea57..cc2bddc 100644 --- a/fastrl/loggers/vscode_visualizers.py +++ b/fastrl/loggers/vscode_visualizers.py @@ -13,6 +13,7 @@ from fastcore.all import add_docs,ifnone import matplotlib.pyplot as plt import torchdata.datapipes as dp +from torchdata.datapipes import functional_datapipe from IPython.core.display import Video,Image from torchdata.dataloader2 import DataLoader2,MultiProcessingReadingService # Local modules @@ -20,7 +21,7 @@ from ..pipes.core import DataPipeAugmentationFn,apply_dp_augmentation_fns from .jupyter_visualizers import ImageCollector -# %% ../../nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb 4 +# %% ../../nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb 5 class SimpleVSCodeVideoPlayer(dp.iter.IterDataPipe): def __init__(self, source_datapipe=None, @@ -69,10 +70,12 @@ def __iter__(self) -> Tuple[NamedTuple]: reset="Will reset the bytes object that is used to store file data." ) -# %% ../../nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb 5 -def VSCodeDataPipe(source:Iterable): - "This is the function that is actually run by `DataBlock`" - pipe = ImageCollector(source).dump_records() - pipe = SimpleVSCodeVideoPlayer(pipe) - return pipe +# %% ../../nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb 6 +@functional_datapipe('visualize_vscode') +class VSCodeDataPipe(dp.iter.IterDataPipe): + def __new__(self,source:Iterable): + "This is the function that is actually run by `DataBlock`" + pipe = ImageCollector(source).dump_records() + pipe = SimpleVSCodeVideoPlayer(pipe) + return pipe diff --git a/nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb b/nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb index 7d4ddac..0d1aa90 100644 --- a/nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb +++ b/nbs/05_Logging/09f_loggers.vscode_visualizers.ipynb @@ -29,6 +29,7 @@ "from fastcore.all import add_docs,ifnone\n", "import matplotlib.pyplot as plt\n", "import torchdata.datapipes as dp\n", + "from torchdata.datapipes import functional_datapipe\n", "from IPython.core.display import Video,Image\n", "from torchdata.dataloader2 import DataLoader2,MultiProcessingReadingService\n", "# Local modules\n", @@ -60,6 +61,16 @@ "based outputs. For vscode, we can generate a gif instead." 
] }, + { + "cell_type": "code", + "execution_count": null, + "id": "fde1fa58", + "metadata": {}, + "outputs": [], + "source": [ + "dp.iter.Repeater" + ] + }, { "cell_type": "code", "execution_count": null, @@ -125,11 +136,13 @@ "outputs": [], "source": [ "#|export\n", - "def VSCodeDataPipe(source:Iterable):\n", - " \"This is the function that is actually run by `DataBlock`\"\n", - " pipe = ImageCollector(source).dump_records()\n", - " pipe = SimpleVSCodeVideoPlayer(pipe)\n", - " return pipe \n", + "@functional_datapipe('visualize_vscode')\n", + "class VSCodeDataPipe(dp.iter.IterDataPipe):\n", + " def __new__(self,source:Iterable):\n", + " \"This is the function that is actually run by `DataBlock`\"\n", + " pipe = ImageCollector(source).dump_records()\n", + " pipe = SimpleVSCodeVideoPlayer(pipe)\n", + " return pipe \n", " " ] }, @@ -143,6 +156,20 @@ "from fastrl.envs.gym import GymDataPipe" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b58a84d", + "metadata": {}, + "outputs": [], + "source": [ + "#|hide\n", + "pipe = GymDataPipe(['CartPole-v1'],None,n=100,seed=0,include_images=True).visualize_vscode()\n", + "\n", + "list(pipe);\n", + "pipe.show()" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/06_Learning/10a_learner.core.ipynb b/nbs/06_Learning/10a_learner.core.ipynb index a44672b..60556de 100644 --- a/nbs/06_Learning/10a_learner.core.ipynb +++ b/nbs/06_Learning/10a_learner.core.ipynb @@ -131,37 +131,42 @@ "class LearnerHead(dp.iter.IterDataPipe):\n", " def __init__(\n", " self,\n", - " source_datapipes:Tuple[dp.iter.IterDataPipe],\n", - " model\n", + " source_datapipes:Tuple[dp.iter.IterDataPipe]\n", " ):\n", " if not isinstance(source_datapipes,tuple):\n", " self.source_datapipes = (source_datapipes,)\n", " else:\n", " self.source_datapipes = source_datapipes\n", " self.dp_idx = 0\n", - " self.model = model\n", "\n", " def __iter__(self): yield from self.source_datapipes[self.dp_idx]\n", " \n", " def fit(self,epochs):\n", " self.dp_idx = 0\n", " epocher = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),EpochCollector)\n", + " learner = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),LearnerBase)\n", " epocher.epochs = epochs\n", - " if isinstance(self.model,tuple):\n", - " for m in self.model: \n", + " if isinstance(learner.model,dict):\n", + " for m in learner.model.values(): \n", " m.train()\n", " else:\n", - " self.model.train()\n", + " learner.model.train()\n", " for _ in self: pass\n", "\n", - " def validate(self,epochs=1,batches=100,show=True) -> DataPipe:\n", + " def validate(self,epochs=1,batches=100,show=True,return_outputs=False) -> DataPipe:\n", " self.dp_idx = 1\n", " epocher = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),EpochCollector)\n", " epocher.epochs = epochs\n", " batcher = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),BatchCollector)\n", " batcher.batches = batches\n", - " with evaluating(self.model):\n", - " for _ in self: pass\n", + " learner = find_dp(traverse_dps(self.source_datapipes[self.dp_idx]),LearnerBase)\n", + " model = learner.model\n", + " model = tuple(model.values()) if isinstance(model,dict) else model\n", + " with evaluating(model):\n", + " if return_outputs:\n", + " return list(self)\n", + " else:\n", + " for _ in self: pass\n", " if show:\n", " pipes = list_dps(traverse_dps(self.source_datapipes[self.dp_idx]))\n", " for pipe in pipes:\n", @@ -170,7 +175,8 @@ " \n", "add_docs(\n", "LearnerHead,\n", - "\"\"\"\n", + "\"\"\"LearnerHead can connect to multiple 
`LearnerBase`s and handles training\n", + "and validation execution.\n", "\"\"\",\n", "fit=\"Runs the `LearnerHead` pipeline for `epochs`\",\n", "validate=\"\"\"If there is more than 1 dl, then run 1 epoch of that dl based on \n", @@ -217,7 +223,7 @@ " val_learner = BatchCollector(val_learner,batches=1000)\n", " val_learner = EpochCollector(val_learner)\n", "\n", - " learner = LearnerHead((learner,val_learner),model)\n", + " learner = LearnerHead((learner,val_learner))\n", " return learner\n", "\n", "dls = dataloaders((\n", @@ -341,7 +347,7 @@ " val_learner = BatchCollector(val_learner,batches=100)\n", " val_learner = EpochCollector(val_learner)\n", "\n", - " learner = LearnerHead((learner,val_learner),model)\n", + " learner = LearnerHead((learner,val_learner))\n", " return learner" ] }, diff --git a/nbs/07_Agents/01_Discrete/12g_agents.dqn.basic.ipynb b/nbs/07_Agents/01_Discrete/12g_agents.dqn.basic.ipynb index 263b557..5cb2e54 100644 --- a/nbs/07_Agents/01_Discrete/12g_agents.dqn.basic.ipynb +++ b/nbs/07_Agents/01_Discrete/12g_agents.dqn.basic.ipynb @@ -35,15 +35,11 @@ "from collections import deque\n", "from typing import Callable,Optional,List\n", "# Third party libs\n", - "from fastcore.all import ifnone\n", "import torchdata.datapipes as dp\n", - "from torchdata.dataloader2 import DataLoader2\n", "from torchdata.dataloader2.graph import traverse_dps,DataPipe\n", "import torch\n", - "import torch.nn.functional as F\n", "from torch import optim\n", "from torch import nn\n", - "import numpy as np\n", "# Local modules\n", "from fastrl.agents.core import AgentHead,AgentBase\n", "from fastrl.pipes.core import find_dp\n", @@ -51,7 +47,7 @@ "from fastrl.agents.core import StepFieldSelector,SimpleModelRunner,NumpyConverter\n", "from fastrl.agents.discrete import EpsilonCollector,PyPrimativeConverter,ArgMaxer,EpsilonSelector\n", "from fastrl.loggers.core import (\n", - " LogCollector,Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,is_record\n", + " Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,ProgressBarLogger\n", ")\n", "from fastrl.learner.core import LearnerBase,LearnerHead,StepBatcher\n", "from fastrl.torch_core import Module" @@ -193,8 +189,7 @@ "metadata": {}, "outputs": [], "source": [ - "from fastrl.envs.gym import GymDataPipe\n", - "from fastrl.loggers.core import ProgressBarLogger,EpochCollector,BatchCollector" + "from fastrl.envs.gym import GymDataPipe" ] }, { @@ -214,14 +209,10 @@ "pipe = GymDataPipe(['CartPole-v1']*1,agent=agent,n=10)\n", "pipe = BatchCollector(pipe,batches=5)\n", "pipe = EpochCollector(pipe,epochs=10).dump_records()\n", - "# dls = L(block.dataloaders(['CartPole-v1']*1,n=10,bs=1))\n", - "# pipes = list(block(['CartPole-v1']*1))\n", "# Setup Logger\n", "pipe = ProgressBarLogger(pipe)\n", "\n", - "# list(dls[0])\n", - "list(pipe);\n", - "# traverse_dps(agent)" + "list(pipe);" ] }, { @@ -360,7 +351,7 @@ "def DQNLearner(\n", " model,\n", " dls,\n", - " logger_bases:Optional[Callable]=None,\n", + " do_logging:bool=True,\n", " loss_func=nn.MSELoss(),\n", " opt=optim.AdamW,\n", " lr=0.005,\n", @@ -373,27 +364,29 @@ " learner = LearnerBase(model,dls[0])\n", " learner = BatchCollector(learner,batches=batches)\n", " learner = EpochCollector(learner)\n", - " if logger_bases: \n", - " learner = logger_bases(learner) \n", + " if do_logging: \n", + " learner = learner.dump_records()\n", + " learner = ProgressBarLogger(learner)\n", " learner = RollingTerminatedRewardCollector(learner)\n", " learner 
= EpisodeCollector(learner)\n", - " learner = learner.catch_records()\n", + " learner = learner.catch_records(drop=not do_logging)\n", + "\n", " learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz,freeze_memory=True)\n", " learner = StepBatcher(learner,device=device)\n", " learner = QCalc(learner)\n", " learner = TargetCalc(learner,nsteps=nsteps)\n", " learner = LossCalc(learner,loss_func=loss_func)\n", " learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))\n", - " if logger_bases: \n", + " if do_logging: \n", " learner = LossCollector(learner).catch_records()\n", "\n", " if len(dls)==2:\n", - " val_learner = LearnerBase(model,dls[1])\n", + " val_learner = LearnerBase(model,dls[1]).visualize_vscode()\n", " val_learner = BatchCollector(val_learner,batches=batches)\n", " val_learner = EpochCollector(val_learner).dump_records()\n", - " learner = LearnerHead((learner,val_learner),model)\n", + " learner = LearnerHead((learner,val_learner))\n", " else:\n", - " learner = LearnerHead(learner,model)\n", + " learner = LearnerHead(learner)\n", " return learner" ] }, @@ -413,7 +406,8 @@ "metadata": {}, "outputs": [], "source": [ - "from fastrl.dataloading.core import dataloaders" + "from fastrl.dataloading.core import dataloaders\n", + "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe" ] }, { @@ -424,31 +418,28 @@ "outputs": [], "source": [ "#|eval:false\n", - "# Setup Loggers\n", - "def logger_bases(pipe):\n", - " pipe = pipe.dump_records()\n", - " pipe = ProgressBarLogger(pipe)\n", - " return pipe\n", - "\n", "# Setup up the core NN\n", "torch.manual_seed(0)\n", "model = DQN(4,2).cuda()\n", "# Setup the Agent\n", "agent = DQNAgent(model,do_logging=True,max_steps=4000,device='cuda')\n", "# Setup the DataBlock\n", + "params = dict(source=['CartPole-v1']*1,agent=agent,nsteps=1,nskips=1,firstlast=False,bs=1)\n", "dls = dataloaders((\n", - " GymDataPipe(['CartPole-v1']*1,agent=agent,nsteps=1,nskips=1,firstlast=False,bs=1),\n", - " GymDataPipe(['CartPole-v1']*1,agent=agent,nsteps=1,nskips=1,firstlast=False,bs=1)\n", + " GymDataPipe(**params), GymDataPipe(**params,include_images=True)\n", "))\n", "\n", "# Setup the Learner\n", - "learner = DQNLearner(model,dls,batches=1000,\n", - " logger_bases=logger_bases,\n", - " bs=128,\n", - " max_sz=1000,device='cuda')\n", - "# learner.fit(3)\n", - "learner.fit(5)\n", - "\n" + "learner = DQNLearner(\n", + " model,\n", + " dls,\n", + " batches=1000,\n", + " bs=128,\n", + " max_sz=1000,\n", + " device='cuda'\n", + ")\n", + "# learner.fit(1)\n", + "learner.fit(5)" ] }, { @@ -467,136 +458,7 @@ "id": "d5e0ed73-7bbf-415b-9ee7-9a95de31d638", "metadata": {}, "source": [ - "If we try a regular DQN with nsteps/nskips it doesnt really converge after 130. We cant expect stability at all, and im pretty sure that nsteps (correctly) tries to reduce to number of duplicated states so that the agent can sample more unique state transitions. The problem with this is that the base dqn is not stable, so giving it lots of \"new\" stuff, im not sure helps. 
In otherwords, its going to forget the old stuff very quickly, and having duplicate states helps \"remind it\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c920552b", - "metadata": {}, - "outputs": [], - "source": [ - "import logging\n", - "from fastrl.core import default_logging" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c95d510e-38c1-458c-9830-df5a68e6a53c", - "metadata": {}, - "outputs": [], - "source": [ - "# Setup Loggers\n", - "def logger_bases(pipe):\n", - " pipe = pipe.dump_records()\n", - " pipe = ProgressBarLogger(pipe)\n", - " return pipe\n", - "# Setup up the core NN\n", - "torch.manual_seed(0)\n", - "model = DQN(4,2)\n", - "# Setup the Agent\n", - "agent = DQNAgent(model,do_logging=True,max_steps=10000)\n", - "# Setup the DataBlock\n", - "dls = dataloaders(\n", - " GymDataPipe(['CartPole-v1']*1,agent=agent,nsteps=1,nskips=1,firstlast=False,bs=1)\n", - ")\n", - "\n", - "# Setup the Learner\n", - "learner = DQNLearner(model,dls,batches=10,logger_bases=logger_bases,bs=128,max_sz=20_000,lr=0.001)\n", - "\n", - "# del agent\n", - "# learner.fit(3)\n", - "learner.fit(2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9791de5b", - "metadata": {}, - "outputs": [], - "source": [ - "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c8784650-6f5c-42b7-9a72-68a0d37d8983", - "metadata": {}, - "outputs": [], - "source": [ - "#|hide\n", - "#|eval: false\n", - "model.eval()\n", - "\n", - "pipe = GymDataPipe(['CartPole-v1']*1,agent=agent,n=1000,seed=0,include_images=True)\n", - "pipe = VSCodeDataPipe(pipe)\n", - "\n", - "list(pipe);\n", - "pipe.show()\n", - "# list(pipe)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dccaa785-605b-4e75-bff7-bae8c5603817", - "metadata": {}, - "outputs": [], - "source": [ - "# Setup Loggers\n", - "def logger_bases(pipe):\n", - " pipe = pipe.dump_records()\n", - " pipe = ProgressBarLogger(\n", - " pipe,\n", - " epoch_on_pipe=EpochCollector,\n", - " batch_on_pipe=BatchCollector\n", - " )\n", - " return pipe\n", - "\n", - "# Setup up the core NN\n", - "torch.manual_seed(0)\n", - "model = DQN(8,4)\n", - "# Setup the Agent\n", - "agent = DQNAgent(model,do_logging=True)\n", - "# Setup the DataBlock\n", - "dls = dataloaders(\n", - " GymDataPipe(['LunarLander-v2']*1,agent=agent,n=1000,bs=1)\n", - ")\n", - "\n", - "# Setup the Learner\n", - "learner = DQNLearner(model,dls,logger_bases=logger_bases,batches=1000)\n", - "learner.fit(3)\n", - "# learner.fit(30)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e5a8064c", - "metadata": {}, - "outputs": [], - "source": [ - "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "723a8a98-5091-4e31-9cca-9220c64ecdb7", - "metadata": {}, - "outputs": [], - "source": [ - "#|hide\n", - "#|eval: false\n", - "pipe = GymDataPipe(['LunarLander-v2']*1,agent=agent,n=1000,bs=1,include_images=True)\n", - "pipe = VSCodeDataPipe(pipe)\n", - "\n", - "list(pipe);\n", - "pipe.show(step=2)" + "If we try a regular DQN with nsteps/nskips it doesnt really converge after 130. We cannot expect stability at all, and nsteps (correctly) tries to reduce to number of duplicated states so that the agent can sample more unique state transitions. The problem with this is the base dqn is not stable, so giving it lots of \"new\" unique state transitions do not help. 
In otherwords, its going to forget the old stuff very quickly, and having duplicate states helps \"remind it\"" ] }, { diff --git a/nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb b/nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb index 426b125..2fdd354 100644 --- a/nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb +++ b/nbs/07_Agents/01_Discrete/12h_agents.dqn.target.ipynb @@ -44,6 +44,7 @@ "from fastrl.loggers.core import BatchCollector,EpochCollector\n", "from fastrl.learner.core import LearnerBase,LearnerHead\n", "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe\n", + "from fastrl.loggers.core import ProgressBarLogger\n", "from fastrl.agents.dqn.basic import (\n", " LossCollector,\n", " RollingTerminatedRewardCollector,\n", @@ -108,17 +109,15 @@ "source": [ "#|export\n", "class TargetModelUpdater(dp.iter.IterDataPipe):\n", - " def __init__(self,source_datapipe=None,target_sync=300):\n", + " def __init__(self,source_datapipe,target_sync=300):\n", " self.source_datapipe = source_datapipe\n", - " if source_datapipe is not None:\n", - " self.learner = find_dp(traverse_dps(self),LearnerBase)\n", - " with torch.no_grad():\n", - " self.learner.target_model = deepcopy(self.learner.model)\n", " self.target_sync = target_sync\n", " self.n_batch = 0\n", + " self.learner = find_dp(traverse_dps(self),LearnerBase)\n", + " with torch.no_grad():\n", + " self.learner.target_model = deepcopy(self.learner.model)\n", " \n", " def reset(self):\n", - " print('resetting')\n", " self.learner = find_dp(traverse_dps(self),LearnerBase)\n", " with torch.no_grad():\n", " self.learner.target_model = deepcopy(self.learner.model)\n", @@ -168,7 +167,7 @@ "def DQNTargetLearner(\n", " model,\n", " dls,\n", - " logger_bases:Optional[Callable]=None,\n", + " do_logging:bool=True,\n", " loss_func=nn.MSELoss(),\n", " opt=optim.AdamW,\n", " lr=0.005,\n", @@ -182,11 +181,11 @@ " learner = LearnerBase(model,dls=dls[0])\n", " learner = BatchCollector(learner,batches=batches)\n", " learner = EpochCollector(learner)\n", - " if logger_bases: \n", - " learner = logger_bases(learner)\n", + " if do_logging: \n", + " learner = learner.dump_records()\n", + " learner = ProgressBarLogger(learner)\n", " learner = RollingTerminatedRewardCollector(learner)\n", - " learner = EpisodeCollector(learner)\n", - " learner = learner.catch_records()\n", + " learner = EpisodeCollector(learner).catch_records()\n", " learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)\n", " learner = StepBatcher(learner,device=device)\n", " learner = TargetModelQCalc(learner)\n", @@ -194,25 +193,16 @@ " learner = LossCalc(learner,loss_func=loss_func)\n", " learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))\n", " learner = TargetModelUpdater(learner,target_sync=target_sync)\n", - " if logger_bases: \n", + " if do_logging: \n", " learner = LossCollector(learner).catch_records()\n", "\n", " if len(dls)==2:\n", - " val_learner = LearnerBase(model,dls[1])\n", + " val_learner = LearnerBase(model,dls[1]).visualize_vscode()\n", " val_learner = BatchCollector(val_learner,batches=batches)\n", " val_learner = EpochCollector(val_learner).catch_records(drop=True)\n", - " val_learner = VSCodeDataPipe(val_learner)\n", - " return LearnerHead((learner,val_learner),model)\n", + " return LearnerHead((learner,val_learner))\n", " else:\n", - " return LearnerHead(learner,model)" - ] - }, - { - "cell_type": "markdown", - "id": "2b8f9ed8-fb05-40a1-ac0d-d4cafee8fa07", - "metadata": {}, - "source": [ - "Try training with basic defaults..." 
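
Note on the target-sync step: the `TargetModelUpdater` pipe above hard-copies the online model into `target_model` every `target_sync` batches. A minimal standalone sketch of that same pattern, assuming a toy `q_net` module and a bare training loop (neither is fastrl API):

from copy import deepcopy
import torch
from torch import nn

# Toy Q-network standing in for the notebook's DQN; any nn.Module behaves the same.
q_net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
target_net = deepcopy(q_net)   # frozen copy used to compute bootstrapped targets
target_sync = 300              # hard-sync interval, in batches

for n_batch in range(1, 1001):
    # ... normally: compute the TD loss against target_net and step the optimizer on q_net ...
    if n_batch % target_sync == 0:
        with torch.no_grad():
            target_net = deepcopy(q_net)   # the same hard sync TargetModelUpdater performs
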
+ " return LearnerHead(learner)" ] }, { @@ -222,20 +212,10 @@ "metadata": {}, "outputs": [], "source": [ - "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe\n", "from fastrl.envs.gym import GymDataPipe\n", - "from fastrl.loggers.core import ProgressBarLogger\n", "from fastrl.dataloading.core import dataloaders" ] }, - { - "cell_type": "markdown", - "id": "d5e0ed73-7bbf-415b-9ee7-9a95de31d638", - "metadata": {}, - "source": [ - "The DQN learns, but I wonder if we can get it to learn faster..." - ] - }, { "cell_type": "code", "execution_count": null, @@ -244,31 +224,24 @@ "outputs": [], "source": [ "#|eval:false\n", - "# Setup Loggers\n", - "def logger_bases(pipe):\n", - " pipe = pipe.dump_records()\n", - " pipe = ProgressBarLogger(pipe)\n", - " return pipe\n", "# Setup up the core NN\n", "torch.manual_seed(0)\n", "model = DQN(4,2)\n", "# Setup the Agent\n", "agent = DQNAgent(model,do_logging=True,min_epsilon=0.02,max_epsilon=1,max_steps=5000)\n", - "# Setup the DataBlock\n", - "train_pipe = GymDataPipe(\n", - " ['CartPole-v1']*1,\n", + "# Setup the DataLoader\n", + "params = dict(\n", + " source=['CartPole-v1']*1,\n", " agent=agent,\n", " nsteps=2,\n", " nskips=2,\n", - " firstlast=True,\n", - " bs=1\n", + " firstlast=True\n", ")\n", - "dls = dataloaders(train_pipe)\n", + "dls = dataloaders((GymDataPipe(**params),GymDataPipe(**params,include_images=True).unbatch()))\n", "# Setup the Learner\n", "learner = DQNTargetLearner(\n", " model,\n", " dls,\n", - " logger_bases=logger_bases,\n", " bs=128,\n", " max_sz=100_000,\n", " nsteps=2,\n", @@ -276,72 +249,28 @@ " batches=1000,\n", " target_sync=300\n", ")\n", - "learner.fit(7)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5c6eefa2", - "metadata": {}, - "outputs": [], - "source": [ - "# exp_replay = find_dp(traverse_dps(learner),ExperienceReplay)\n", - "# exp_replay.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8c4cbfaa", - "metadata": {}, - "outputs": [], - "source": [ - "from fastrl.agents.core import AgentHead,AgentBase\n", - "from fastrl.agents.core import SimpleModelRunner,NumpyConverter\n", - "from fastrl.agents.discrete import ArgMaxer\n", - "from fastrl.memory.memory_visualizer import MemoryBufferViewer" + "# learner.fit(7)\n", + "learner.fit(1)" ] }, { "cell_type": "code", "execution_count": null, - "id": "a5f88a93", + "id": "a6888519", "metadata": {}, "outputs": [], "source": [ - "val_agent = DQNAgent(model,min_epsilon=0,max_epsilon=0)\n", - "valid_pipe = GymDataPipe(\n", - " ['CartPole-v1']*1,\n", - " agent=val_agent,\n", - " nsteps=2,\n", - " nskips=2,\n", - " firstlast=True,\n", - " bs=1,\n", - " n=100,\n", - " include_images=False\n", - ")\n", - "# valid_pipe = VSCodeDataPipe(valid_pipe)\n", - "sample_run = [o[0] for o in valid_pipe.catch_records(drop=True)];" + "learner.validate(2)" ] }, { "cell_type": "code", "execution_count": null, - "id": "4ff8e556", + "id": "87f2a552", "metadata": {}, "outputs": [], "source": [ - "#|hide\n", - "#|eval: false\n", - "model.eval()\n", - "\n", - "pipe = GymDataPipe(['CartPole-v1']*1,agent=agent,n=1000,seed=0,include_images=True)\n", - "pipe = VSCodeDataPipe(pipe)\n", - "\n", - "list(pipe);\n", - "pipe.show()\n", - "# list(pipe)" + "sample_run = learner.validate(2,show=False,return_outputs=True)" ] }, { @@ -351,6 +280,9 @@ "metadata": {}, "outputs": [], "source": [ + "from fastrl.agents.core import AgentHead,AgentBase\n", + "from fastrl.agents.core import SimpleModelRunner\n", + "from fastrl.memory.memory_visualizer import 
MemoryBufferViewer\n", "from fastrl.agents.core import StepFieldSelector" ] }, @@ -368,37 +300,13 @@ " agent_base = AgentBase(model)\n", " agent = StepFieldSelector(agent_base,field='state')\n", " agent = SimpleModelRunner(agent).to(device=device)\n", - " # agent = ArgMaxer(agent,only_idx=True)\n", - " # agent = NumpyConverter(agent)\n", - " # agent = PyPrimativeConverter(agent)\n", " agent = AgentHead(agent)\n", " return agent\n", "\n", - "val_agent = DQNValAgent(model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1ab57ad7", - "metadata": {}, - "outputs": [], - "source": [ + "val_agent = DQNValAgent(model)\n", "MemoryBufferViewer(sample_run,val_agent)" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "f5b57c4d", - "metadata": {}, - "outputs": [], - "source": [ - "#|hide\n", - "#|eval:false\n", - "# learner.validate(epochs=1,batches=200)" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb b/nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb index 95a3c2b..cefc8f2 100644 --- a/nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb +++ b/nbs/07_Agents/01_Discrete/12m_agents.dqn.double.ipynb @@ -44,6 +44,7 @@ "from fastrl.loggers.core import BatchCollector,EpochCollector\n", "from fastrl.learner.core import LearnerBase,LearnerHead\n", "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe\n", + "from fastrl.loggers.core import ProgressBarLogger\n", "from fastrl.agents.dqn.basic import (\n", " LossCollector,\n", " RollingTerminatedRewardCollector,\n", @@ -89,7 +90,7 @@ "source": [ "#|export\n", "class DoubleQCalc(dp.iter.IterDataPipe):\n", - " def __init__(self,source_datapipe=None):\n", + " def __init__(self,source_datapipe):\n", " self.source_datapipe = source_datapipe\n", " \n", " def __iter__(self):\n", @@ -114,7 +115,7 @@ "def DoubleDQNLearner(\n", " model,\n", " dls,\n", - " logger_bases:Optional[Callable]=None,\n", + " do_logging:bool=True,\n", " loss_func=nn.MSELoss(),\n", " opt=optim.AdamW,\n", " lr=0.005,\n", @@ -128,49 +129,38 @@ " learner = LearnerBase(model,dls=dls[0])\n", " learner = BatchCollector(learner,batches=batches)\n", " learner = EpochCollector(learner)\n", - " if logger_bases: \n", - " learner = logger_bases(learner)\n", + " if do_logging: \n", + " learner = learner.dump_records()\n", + " learner = ProgressBarLogger(learner)\n", " learner = RollingTerminatedRewardCollector(learner)\n", - " learner = EpisodeCollector(learner)\n", - " learner = learner.catch_records()\n", + " learner = EpisodeCollector(learner).catch_records()\n", " learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)\n", " learner = StepBatcher(learner,device=device)\n", - " # learner = TargetModelQCalc(learner)\n", " learner = DoubleQCalc(learner)\n", " learner = TargetCalc(learner,nsteps=nsteps)\n", " learner = LossCalc(learner,loss_func=loss_func)\n", " learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))\n", " learner = TargetModelUpdater(learner,target_sync=target_sync)\n", - " if logger_bases: \n", + " if do_logging: \n", " learner = LossCollector(learner).catch_records()\n", "\n", " if len(dls)==2:\n", - " val_learner = LearnerBase(model,dls[1])\n", + " val_learner = LearnerBase(model,dls[1]).visualize_vscode()\n", " val_learner = BatchCollector(val_learner,batches=batches)\n", " val_learner = EpochCollector(val_learner).catch_records(drop=True)\n", - " val_learner = VSCodeDataPipe(val_learner)\n", - " return LearnerHead((learner,val_learner),model)\n", 
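
Note on the Double DQN target: independent of `DoubleQCalc`'s internals (not shown in this hunk), the idea is that the online network selects the greedy next action while the target network scores it. A minimal sketch under assumed tensor shapes (batch-first `next_state`, column-vector `reward`, boolean `terminated`); this is illustrative, not the fastrl implementation:

import torch

def double_dqn_target(model, target_model, reward, next_state, terminated, gamma=0.99):
    with torch.no_grad():
        # Online network picks the greedy next action...
        next_action = model(next_state).argmax(dim=1, keepdim=True)
        # ...the slower-moving target network scores it, which curbs over-estimation.
        next_q = target_model(next_state).gather(1, next_action)
    return reward + gamma * next_q * (~terminated).float()
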
+ " return LearnerHead((learner,val_learner))\n", " else:\n", - " return LearnerHead(learner,model)" - ] - }, - { - "cell_type": "markdown", - "id": "2b8f9ed8-fb05-40a1-ac0d-d4cafee8fa07", - "metadata": {}, - "source": [ - "Try training with basic defaults..." + " return LearnerHead(learner)" ] }, { "cell_type": "code", "execution_count": null, - "id": "90f5ce18", + "id": "480451a4", "metadata": {}, "outputs": [], "source": [ "from fastrl.envs.gym import GymDataPipe\n", - "from fastrl.loggers.core import ProgressBarLogger\n", "from fastrl.dataloading.core import dataloaders" ] }, @@ -182,31 +172,24 @@ "outputs": [], "source": [ "#|eval:false\n", - "# Setup Loggers\n", - "def logger_bases(pipe):\n", - " pipe = pipe.dump_records()\n", - " pipe = ProgressBarLogger(pipe)\n", - " return pipe\n", "# Setup up the core NN\n", "torch.manual_seed(0)\n", "model = DQN(4,2)\n", "# Setup the Agent\n", "agent = DQNAgent(model,do_logging=True,min_epsilon=0.02,max_epsilon=1,max_steps=5000)\n", - "# Setup the DataBlock\n", - "train_pipe = GymDataPipe(\n", - " ['CartPole-v1']*1,\n", + "# Setup the Dataloaders\n", + "params = dict(\n", + " source=['CartPole-v1']*1,\n", " agent=agent,\n", " nsteps=2,\n", " nskips=2,\n", - " firstlast=True,\n", - " bs=1\n", + " firstlast=True\n", ")\n", - "dls = dataloaders(train_pipe)\n", + "dls = dataloaders((GymDataPipe(**params),GymDataPipe(**params,include_images=True).unbatch()))\n", "# Setup the Learner\n", "learner = DoubleDQNLearner(\n", " model,\n", " dls,\n", - " logger_bases=logger_bases,\n", " bs=128,\n", " max_sz=100_000,\n", " nsteps=2,\n", @@ -217,44 +200,6 @@ "learner.fit(7)" ] }, - { - "cell_type": "markdown", - "id": "d5e0ed73-7bbf-415b-9ee7-9a95de31d638", - "metadata": {}, - "source": [ - "The DQN learners, but I wonder if we can get it to learn faster..." 
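
Note on `nsteps=2, nskips=2, firstlast=True`: these settings merge pairs of environment steps into single first-to-last transitions, so the replayed reward is the discounted sum over the merged steps. A rough standalone sketch of that 2-step reward (a hypothetical helper, not fastrl's own merger):

def two_step_reward(r1, r2, gamma=0.99):
    # Reward credited to the merged (first -> last) transition.
    return r1 + gamma * r2

# Two CartPole steps, each worth 1.0, collapse into a single transition worth 1.99.
assert abs(two_step_reward(1.0, 1.0) - 1.99) < 1e-9
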
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1fdef5f2", - "metadata": {}, - "outputs": [], - "source": [ - "train_pipe = GymDataPipe(\n", - " ['CartPole-v1']*1,\n", - " agent=agent,\n", - " nsteps=2,\n", - " nskips=2,\n", - " firstlast=True,\n", - " bs=1\n", - ")\n", - "dls = dataloaders([train_pipe,train_pipe])\n", - "# Setup the Learner\n", - "learner = DoubleDQNLearner(\n", - " model,\n", - " dls,\n", - " logger_bases=logger_bases,\n", - " bs=128,\n", - " max_sz=100_000,\n", - " nsteps=2,\n", - " lr=0.01,\n", - " batches=1000,\n", - " target_sync=300\n", - ")" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb b/nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb index 00c0746..9086db5 100644 --- a/nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb +++ b/nbs/07_Agents/01_Discrete/12n_agents.dqn.dueling.ipynb @@ -31,35 +31,15 @@ "source": [ "#|export\n", "# Python native modules\n", - "from copy import deepcopy\n", - "from typing import Optional,Callable,Tuple\n", "# Third party libs\n", - "import torchdata.datapipes as dp\n", - "from torchdata.dataloader2.graph import traverse_dps,DataPipe\n", "import torch\n", - "from torch import nn,optim\n", - "# Local modulesf\n", - "from fastrl.pipes.core import find_dp\n", - "from fastrl.memory.experience_replay import ExperienceReplay\n", - "from fastrl.loggers.core import BatchCollector,EpochCollector\n", - "from fastrl.learner.core import LearnerBase,LearnerHead\n", - "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe\n", + "from torch import nn\n", + "# Local modules\n", "from fastrl.agents.dqn.basic import (\n", - " LossCollector,\n", - " RollingTerminatedRewardCollector,\n", - " EpisodeCollector,\n", - " StepBatcher,\n", - " TargetCalc,\n", - " LossCalc,\n", - " ModelLearnCalc,\n", " DQN,\n", " DQNAgent\n", ")\n", - "from fastrl.agents.dqn.target import (\n", - " TargetModelUpdater,\n", - " TargetModelQCalc,\n", - " DQNTargetLearner\n", - ")" + "from fastrl.agents.dqn.target import DQNTargetLearner" ] }, { @@ -90,18 +70,19 @@ "source": [ "#|export\n", "class DuelingHead(nn.Module):\n", - " def __init__(self,\n", - " hidden:int, # Input into the DuelingHead, likely a hidden layer input\n", - " n_actions:int, # Number/dim of actions to output\n", - " lin_cls=nn.Linear\n", + " def __init__(\n", + " self,\n", + " hidden: int, # Input into the DuelingHead, likely a hidden layer input\n", + " n_actions: int, # Number/dim of actions to output\n", + " lin_cls = nn.Linear\n", " ):\n", " super().__init__()\n", " self.val = lin_cls(hidden,1)\n", " self.adv = lin_cls(hidden,n_actions)\n", "\n", " def forward(self,xi):\n", - " val,adv=self.val(xi),self.adv(xi)\n", - " xi=val.expand_as(adv)+(adv-adv.mean()).squeeze(0)\n", + " val,adv = self.val(xi),self.adv(xi)\n", + " xi = val.expand_as(adv)+(adv-adv.mean()).squeeze(0)\n", " return xi" ] }, @@ -120,9 +101,7 @@ "metadata": {}, "outputs": [], "source": [ - "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe\n", "from fastrl.envs.gym import GymDataPipe\n", - "from fastrl.loggers.core import ProgressBarLogger\n", "from fastrl.dataloading.core import dataloaders" ] }, @@ -134,33 +113,25 @@ "outputs": [], "source": [ "#|eval:false\n", - "# Setup Loggers\n", - "def logger_bases(pipe):\n", - " pipe = pipe.dump_records()\n", - " pipe = ProgressBarLogger(pipe)\n", - " return pipe\n", "# Setup up the core NN\n", "torch.manual_seed(0)\n", "model = DQN(4,2,head_layer=DuelingHead)\n", "# 
Setup the Agent\n", "model.train()\n", - "model = model.share_memory()\n", "agent = DQNAgent(model,do_logging=True,min_epsilon=0.02,max_epsilon=1,max_steps=5000)\n", - "# Setup the DataBlock\n", - "train_pipe = GymDataPipe(\n", - " ['CartPole-v1']*1,\n", + "# Setup the Dataloaders\n", + "params = dict(\n", + " source=['CartPole-v1']*1,\n", " agent=agent,\n", " nsteps=2,\n", " nskips=2,\n", - " firstlast=True,\n", - " bs=1\n", + " firstlast=True\n", ")\n", - "dls = dataloaders(train_pipe)\n", + "dls = dataloaders((GymDataPipe(**params),GymDataPipe(**params,include_images=True).unbatch()))\n", "# Setup the Learner\n", "learner = DQNTargetLearner(\n", " model,\n", " dls,\n", - " logger_bases=logger_bases,\n", " bs=128,\n", " max_sz=100_000,\n", " nsteps=2,\n", diff --git a/nbs/07_Agents/01_Discrete/12o_agents.dqn.categorical.ipynb b/nbs/07_Agents/01_Discrete/12o_agents.dqn.categorical.ipynb index 7198428..0c5810f 100644 --- a/nbs/07_Agents/01_Discrete/12o_agents.dqn.categorical.ipynb +++ b/nbs/07_Agents/01_Discrete/12o_agents.dqn.categorical.ipynb @@ -226,7 +226,7 @@ "outputs": [], "source": [ "def categorical_update(v_min,v_max,n_atoms,support,delta_z,model,reward,gamma,action,next_state):\n", - " t_q=(support*Softmax(model(next_state).gather(action))).sum()\n", + " t_q=(support*nn.Softmax(model(next_state).gather(action))).sum()\n", " a_star=torch.argmax(t_q)\n", " \n", " m=torch.zeros((N,)) # m_i = 0 where i in 1,...,N-1\n", @@ -782,7 +782,7 @@ "def DQNCategoricalLearner(\n", " model,\n", " dls,\n", - " logger_bases:Optional[Callable]=None,\n", + " do_logging:bool=True,\n", " loss_func=PartialCrossEntropy,\n", " opt=optim.AdamW,\n", " lr=0.005,\n", @@ -797,11 +797,11 @@ " learner = LearnerBase(model,dls=dls[0])\n", " learner = BatchCollector(learner,batches=batches)\n", " learner = EpochCollector(learner)\n", - " if logger_bases: \n", - " learner = logger_bases(learner)\n", + " if do_logging: \n", + " learner = learner.dump_records()\n", + " learner = ProgressBarLogger(learner)\n", " learner = RollingTerminatedRewardCollector(learner)\n", - " learner = EpisodeCollector(learner)\n", - " learner = learner.catch_records()\n", + " learner = EpisodeCollector(learner).catch_records()\n", " learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)\n", " learner = StepBatcher(learner,device=device)\n", " learner = CategoricalTargetQCalc(learner,nsteps=nsteps,double_dqn_strategy=double_dqn_strategy).to(device=device)\n", @@ -809,17 +809,16 @@ " learner = LossCalc(learner,loss_func=loss_func)\n", " learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))\n", " learner = TargetModelUpdater(learner,target_sync=target_sync)\n", - " if logger_bases: \n", + " if do_logging: \n", " learner = LossCollector(learner).catch_records()\n", "\n", " if len(dls)==2:\n", - " val_learner = LearnerBase(model,dls[1])\n", + " val_learner = LearnerBase(model,dls[1]).visualize_vscode()\n", " val_learner = BatchCollector(val_learner,batches=batches)\n", " val_learner = EpochCollector(val_learner).catch_records(drop=True)\n", - " val_learner = VSCodeDataPipe(val_learner)\n", - " return LearnerHead((learner,val_learner),model)\n", + " return LearnerHead((learner,val_learner))\n", " else:\n", - " return LearnerHead(learner,model)" + " return LearnerHead(learner)" ] }, { @@ -830,31 +829,24 @@ "outputs": [], "source": [ "#|eval:false\n", - "# Setup Loggers\n", - "def logger_bases(pipe):\n", - " pipe = pipe.dump_records()\n", - " pipe = ProgressBarLogger(pipe)\n", - " return pipe\n", "# Setup up the core NN\n", 
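
Note on the categorical (C51-style) learner: it predicts a distribution over `n_atoms` fixed support points per action rather than a scalar Q value, and greedy actions come from collapsing each distribution back to its expectation. A minimal sketch of that reduction with toy shapes (the real `CategoricalDQN` internals are not shown here):

import torch

n_atoms, v_min, v_max = 51, -10.0, 10.0
support = torch.linspace(v_min, v_max, n_atoms)   # fixed atom values z_i

# Fake network output for a batch of 1 state and 2 actions: logits over the atoms.
logits = torch.randn(1, 2, n_atoms)
probs = logits.softmax(dim=-1)                    # p_i(s, a)

q_values = (probs * support).sum(dim=-1)          # Q(s, a) = sum_i z_i * p_i(s, a)
greedy_action = q_values.argmax(dim=-1)
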
"torch.manual_seed(0)\n", "model = CategoricalDQN(4,2)\n", "# Setup the Agent\n", "agent = CategoricalDQNAgent(model,do_logging=True,min_epsilon=0.02,max_epsilon=1,max_steps=5000)\n", - "# Setup the DataBlock\n", - "train_pipe = GymDataPipe(\n", - " ['CartPole-v1']*1,\n", + "# Setup the DataLoader\n", + "params = dict(\n", + " source=['CartPole-v1']*1,\n", " agent=agent,\n", " nsteps=2,\n", " nskips=2,\n", - " firstlast=True,\n", - " bs=1\n", + " firstlast=True\n", ")\n", - "dls = dataloaders(train_pipe)\n", + "dls = dataloaders((GymDataPipe(**params),GymDataPipe(**params,include_images=True).unbatch()))\n", "# Setup the Learner\n", "learner = DQNCategoricalLearner(\n", " model,\n", " dls,\n", - " logger_bases=logger_bases,\n", " bs=128,\n", " max_sz=100_000,\n", " nsteps=2,\n", @@ -862,42 +854,7 @@ " batches=1000,\n", " target_sync=300\n", ")\n", - "learner.fit(7)\n", - "\n", - "\n", - "#|eval: false\n", - "# # Setup Loggers\n", - "# logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector,\n", - "# batch_on_pipe=BatchCollector)\n", - "\n", - "# # Setup up the core NN\n", - "# torch.manual_seed(0)\n", - "# model = CategoricalDQN(4,2).to(device='cuda')\n", - "# # Setup the Agent\n", - "# agent = DQNAgent(model,[logger_base],max_steps=4000,device='cuda',\n", - "# dp_augmentation_fns=[\n", - "# MultiModelRunner.replace_dp(device='cuda')\n", - "# ])\n", - "# # Setup the DataBlock\n", - "# block = DataBlock(\n", - "# GymTransformBlock(agent=agent,nsteps=2,nskips=2,firstlast=True), # We basically merge 2 steps into 1 and skip. \n", - "# (GymTransformBlock(agent=agent,nsteps=2,nskips=2,firstlast=True,n=100,include_images=True),VSCodeTransformBlock())\n", - "# )\n", - "# # pipes = L(block.datapipes(['CartPole-v1']*1,n=10))\n", - "# dls = L(block.dataloaders(['CartPole-v1']*1))\n", - "# # Setup the Learner\n", - "# learner = DQNLearner(model,dls,logger_bases=[logger_base],bs=128,\n", - "# batches=1000,\n", - "# loss_func = PartialCrossEntropy,\n", - "# device='cuda',\n", - "# max_sz=100_000,\n", - "# lr=0.001,\n", - "# dp_augmentation_fns=[\n", - "# TargetModelUpdater.insert_dp(),\n", - "# CategoricalTargetQCalc.replace_remove_dp(device='cuda',nsteps=2,double_dqn_strategy=True)\n", - "# ])\n", - "# learner.fit(1)\n", - "# learner.fit(7)" + "learner.fit(7)" ] }, { diff --git a/nbs/07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb b/nbs/07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb index d75287c..e4aedf6 100644 --- a/nbs/07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb +++ b/nbs/07_Agents/01_Discrete/12r_agents.dqn.rainbow.ipynb @@ -49,8 +49,9 @@ "from fastrl.memory.experience_replay import ExperienceReplay\n", "from fastrl.loggers.core import BatchCollector,EpochCollector\n", "from fastrl.learner.core import LearnerBase,LearnerHead\n", - "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe\n", "from fastrl.agents.core import AgentHead,AgentBase\n", + "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe\n", + "from fastrl.loggers.core import ProgressBarLogger\n", "from fastrl.agents.dqn.basic import (\n", " LossCollector,\n", " RollingTerminatedRewardCollector,\n", @@ -97,7 +98,7 @@ "def DQNRainbowLearner(\n", " model,\n", " dls,\n", - " logger_bases:Optional[Callable]=None,\n", + " do_logging:bool=True,\n", " loss_func=PartialCrossEntropy,\n", " opt=optim.AdamW,\n", " lr=0.005,\n", @@ -113,28 +114,27 @@ " learner = LearnerBase(model,dls=dls[0])\n", " learner = BatchCollector(learner,batches=batches)\n", " learner = EpochCollector(learner)\n", - " if logger_bases: 
\n", - " learner = logger_bases(learner)\n", + " if do_logging: \n", + " learner = learner.dump_records()\n", + " learner = ProgressBarLogger(learner)\n", " learner = RollingTerminatedRewardCollector(learner)\n", - " learner = EpisodeCollector(learner)\n", - " learner = learner.catch_records()\n", + " learner = EpisodeCollector(learner).catch_records()\n", " learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)\n", " learner = StepBatcher(learner,device=device)\n", " learner = CategoricalTargetQCalc(learner,nsteps=nsteps,double_dqn_strategy=double_dqn_strategy).to(device=device)\n", " learner = LossCalc(learner,loss_func=loss_func)\n", " learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))\n", " learner = TargetModelUpdater(learner,target_sync=target_sync)\n", - " if logger_bases: \n", + " if do_logging: \n", " learner = LossCollector(learner).catch_records()\n", "\n", " if len(dls)==2:\n", - " val_learner = LearnerBase(model,dls[1])\n", + " val_learner = LearnerBase(model,dls[1]).visualize_vscode()\n", " val_learner = BatchCollector(val_learner,batches=batches)\n", " val_learner = EpochCollector(val_learner).catch_records(drop=True)\n", - " val_learner = VSCodeDataPipe(val_learner)\n", - " return LearnerHead((learner,val_learner),model)\n", + " return LearnerHead((learner,val_learner))\n", " else:\n", - " return LearnerHead(learner,model)" + " return LearnerHead(learner)" ] }, { @@ -144,9 +144,7 @@ "metadata": {}, "outputs": [], "source": [ - "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe\n", "from fastrl.envs.gym import GymDataPipe\n", - "from fastrl.loggers.core import ProgressBarLogger\n", "from fastrl.dataloading.core import dataloaders" ] }, @@ -158,37 +156,25 @@ "outputs": [], "source": [ "#|eval:false\n", - "# Setup Loggers\n", - "def logger_bases(pipe):\n", - " pipe = pipe.dump_records()\n", - " pipe = ProgressBarLogger(pipe)\n", - " return pipe\n", "# Setup up the core NN\n", "torch.manual_seed(0)\n", "# Rainbow uses a CategoricalDQN with a DuelingHead (DuealingDQN)\n", "model = CategoricalDQN(4,2,head_layer=DuelingHead)\n", "# Setup the Agent\n", "agent = CategoricalDQNAgent(model,do_logging=True,min_epsilon=0.02,max_epsilon=1,max_steps=5000)\n", - "# Setup the DataBlock\n", - "train_pipe = GymDataPipe(\n", - " ['CartPole-v1']*1,\n", + "# Setup the Dataloaders\n", + "params = dict(\n", + " source=['CartPole-v1']*1,\n", " agent=agent,\n", " nsteps=2,\n", " nskips=2,\n", - " firstlast=True,\n", - " bs=1\n", - ")\n", - "validation_pipe = GymDataPipe(\n", - " ['CartPole-v1']*1,\n", - " agent=agent,\n", - " include_images=True\n", + " firstlast=True\n", ")\n", - "dls = dataloaders((train_pipe,validation_pipe))\n", + "dls = dataloaders((GymDataPipe(**params),GymDataPipe(**params,include_images=True).unbatch()))\n", "# Setup the Learner\n", "learner = DQNRainbowLearner(\n", " model,\n", " dls,\n", - " logger_bases=logger_bases,\n", " bs=128,\n", " max_sz=100_000,\n", " nsteps=2,\n", diff --git a/nbs/07_Agents/02_Continuous/12s_agents.ddpg.ipynb b/nbs/07_Agents/02_Continuous/12s_agents.ddpg.ipynb index ff63763..6b18449 100644 --- a/nbs/07_Agents/02_Continuous/12s_agents.ddpg.ipynb +++ b/nbs/07_Agents/02_Continuous/12s_agents.ddpg.ipynb @@ -30,34 +30,32 @@ "outputs": [], "source": [ "#|export\n", - "# # Python native modules\n", - "# import os\n", + "# Python native modules\n", "from typing import Tuple,Optional,Callable,Union,Dict,Literal,List\n", "from functools import partial\n", - "# from typing_extensions import Literal\n", "from copy import 
deepcopy\n", - "# # Third party libs\n", + "# Third party libs\n", "from fastcore.all import add_docs\n", "import torchdata.datapipes as dp\n", "from torchdata.dataloader2.graph import traverse_dps,find_dps,DataPipe\n", - "# from torchdata.dataloader2.graph import DataPipe,traverse\n", "from torch import nn\n", - "# from torch.optim import AdamW,Adam\n", "import torch\n", - "# import pandas as pd\n", - "# import numpy as np\n", - "# # Local modules\n", + "# Local modules\n", "from fastrl.core import SimpleStep\n", "from fastrl.pipes.core import find_dp\n", "from fastrl.torch_core import Module\n", "from fastrl.memory.experience_replay import ExperienceReplay\n", - "from fastrl.loggers.core import Record,is_record,not_record,_RECORD_CATCH_LIST\n", "from fastrl.learner.core import LearnerBase,LearnerHead,StepBatcher\n", - "# from fastrl.pipes.core import *\n", - "# from fastrl.data.block import *\n", - "# from fastrl.data.dataloader2 import *\n", + "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe\n", "from fastrl.loggers.core import (\n", - " LogCollector,Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,is_record\n", + " ProgressBarLogger,\n", + " Record,\n", + " BatchCollector,\n", + " EpochCollector,\n", + " RollingTerminatedRewardCollector,\n", + " EpisodeCollector,\n", + " not_record,\n", + " _RECORD_CATCH_LIST\n", ")\n", "from fastrl.agents.core import (\n", " AgentHead,\n", @@ -65,10 +63,7 @@ " StepFieldSelector,\n", " SimpleModelRunner,\n", " NumpyConverter\n", - ")\n", - "# from fastrl.memory.experience_replay import ExperienceReplay\n", - "# from fastrl.learner.core import *\n", - "# from fastrl.loggers.core import *" + ")" ] }, { @@ -899,9 +894,7 @@ "metadata": {}, "outputs": [], "source": [ - "from fastrl.envs.gym import GymDataPipe\n", - "from fastrl.loggers.core import ProgressBarLogger,EpochCollector,BatchCollector\n", - "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe" + "from fastrl.envs.gym import GymDataPipe" ] }, { @@ -913,19 +906,14 @@ "source": [ "#|hide\n", "torch.manual_seed(0)\n", - "\n", "actor = Actor(3,1)\n", "\n", "# Setup the Agent\n", "agent = DDPGAgent(actor,max_steps=10000)\n", "\n", - "\n", "pipe = GymDataPipe(['Pendulum-v1']*1,agent=agent,n=100,seed=None,include_images=True)\n", "pipe = VSCodeDataPipe(pipe)\n", "\n", - "# pipe = GymTransformBlock(agent=agent,n=100,seed=None,include_images=True)(['Pendulum-v1'])\n", - "# pipe = VSCodeTransformBlock()(pipe)\n", - "\n", "pipe_to_device(pipe,default_device(),debug=True)\n", "\n", "list(pipe);\n", @@ -1397,7 +1385,7 @@ " critic:Critic,\n", " # A list of dls, where index=0 is the training dl.\n", " dls,\n", - " logger_bases:Optional[Callable]=None,\n", + " do_logging:bool=True,\n", " # The learning rate for the actor. 
Expected to learn slower than the critic\n", " actor_lr:float=1e-3,\n", " # The optimizer for the actor\n", @@ -1429,11 +1417,12 @@ " # Debug mode will output device moves\n", " debug:bool=False\n", ") -> LearnerHead:\n", - " learner = LearnerBase(actor,dls[0])\n", + " learner = LearnerBase({'actor':actor,'critic':critic},dls[0])\n", " learner = BatchCollector(learner,batches=batches)\n", " learner = EpochCollector(learner)\n", - " if logger_bases: \n", - " learner = logger_bases(learner) \n", + " if do_logging: \n", + " learner = learner.dump_records()\n", + " learner = ProgressBarLogger(learner)\n", " learner = RollingTerminatedRewardCollector(learner)\n", " learner = EpisodeCollector(learner).catch_records()\n", " learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)\n", @@ -1446,12 +1435,15 @@ " learner = ActorLossProcessor(learner,critic,actor,clip_critic_grad=5)\n", " learner = LossCollector(learner,title='actor-loss').catch_records()\n", " learner = BasicOptStepper(learner,actor,actor_lr,opt=actor_opt,filter=True,do_zero_grad=False)\n", - " learner = LearnerHead(learner,(actor,critic))\n", - " \n", - " # for dl in dls: \n", - " # pipe_to_device(dl.datapipe,device,debug=debug)\n", - " \n", - " return learner\n", + " learner = LearnerHead(learner)\n", + "\n", + " if len(dls)==2:\n", + " val_learner = LearnerBase({'actor':actor,'critic':critic},dls[1]).visualize_vscode()\n", + " val_learner = BatchCollector(val_learner,batches=batches)\n", + " val_learner = EpochCollector(val_learner).catch_records(drop=True)\n", + " return LearnerHead((learner,val_learner))\n", + " else:\n", + " return LearnerHead(learner)\n", "\n", "DDPGLearner.__doc__=\"\"\"DDPG is a continuous action, actor-critic model, first created in\n", "(Lillicrap et al., 2016). The critic estimates a Q value estimate, and the actor\n", @@ -1476,11 +1468,6 @@ "outputs": [], "source": [ "#|eval:false\n", - "# Setup Loggers\n", - "def logger_bases(pipe):\n", - " pipe = pipe.dump_records()\n", - " pipe = ProgressBarLogger(pipe)\n", - " return pipe\n", "# Setup up the core NN\n", "torch.manual_seed(0)\n", "actor = Actor(3,1)\n", @@ -1490,15 +1477,20 @@ "agent = DDPGAgent(actor,do_logging=True,max_steps=5000,min_epsilon=0.1)\n", "\n", "# Setup the Dataloaders\n", - "dls = dataloaders(\n", - " GymDataPipe(['Pendulum-v1']*1,agent=agent,nsteps=2,nskips=2,firstlast=True,bs=1)\n", + "params = dict(\n", + " source=['Pendulum-v1']*1,\n", + " agent=agent,\n", + " nsteps=2,\n", + " nskips=2,\n", + " firstlast=True\n", ")\n", + "dls = dataloaders((GymDataPipe(**params),GymDataPipe(**params,include_images=True).unbatch()))\n", + "\n", "# Setup the Learner\n", "learner = DDPGLearner(\n", " actor,\n", " critic,\n", " dls,\n", - " logger_bases=logger_bases,\n", " bs=128,\n", " max_sz=20_000,\n", " nsteps=2,\n", diff --git a/nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb b/nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb index 382a547..f82afbf 100644 --- a/nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb +++ b/nbs/07_Agents/02_Continuous/12t_agents.trpo.ipynb @@ -33,49 +33,37 @@ "# Python native modules\n", "from typing import NamedTuple,List,Tuple,Optional,Dict,Literal,Callable,Union\n", "from functools import partial\n", - "# from typing_extensions import Literal\n", "# import typing \n", "from warnings import warn\n", "# Third party libs\n", "import numpy as np\n", "import torch\n", "from torch import nn\n", - "from torch.optim import AdamW,Adam\n", + "from torch.optim import Adam\n", "from torch.distributions import 
Independent,Normal\n", "import torchdata.datapipes as dp \n", "from torchdata.dataloader2.graph import DataPipe,traverse_dps\n", - "from fastcore.all import add_docs,store_attr,ifnone,L\n", + "from fastcore.all import add_docs,store_attr,L\n", "import gymnasium as gym\n", "from torchdata.dataloader2.graph import find_dps,traverse_dps\n", - "# from fastrl.data.dataloader2 import *\n", - "# from torchdata.dataloader2 import DataLoader2,DataLoader2Iterator\n", - "# from torchdata.dataloader2.graph import find_dps,traverse,DataPipe,IterDataPipe,MapDataPipe\n", "# Local modules\n", "from fastrl.core import add_namedtuple_doc,SimpleStep,StepTypes\n", "from fastrl.pipes.core import find_dp\n", - "from fastrl.loggers.core import Record,is_record,not_record,_RECORD_CATCH_LIST\n", + "from fastrl.loggers.core import Record,not_record,_RECORD_CATCH_LIST\n", "from fastrl.torch_core import Module,evaluating\n", "from fastrl.layers import Critic\n", - "# from fastrl.data.block import *\n", "from fastrl.envs.gym import GymStepper\n", - "from fastrl.pipes.iter.firstlast import FirstLastMerger\n", "from fastrl.pipes.iter.nskip import NSkipper\n", "from fastrl.pipes.iter.nstep import NStepper,NStepFlattener\n", "import fastrl.pipes.iter.cacheholder\n", "from fastrl.agents.ddpg import LossCollector,BasicOptStepper,StepBatcher\n", - "# from fastrl.loggers.core import LogCollector\n", - "# from fastrl.agents.discrete import EpsilonCollector\n", - "# from copy import deepcopy\n", "from fastrl.learner.core import LearnerBase,LearnerHead\n", "from fastrl.loggers.core import BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector\n", "\n", - "# from fastrl.agents.ddpg import BasicOptStepper\n", - "# from fastrl.loggers.vscode_visualizers import VSCodeTransformBlock\n", - "# from fastrl.loggers.jupyter_visualizers import ProgressBarLogger\n", - "# from fastrl.layers import Critic\n", - "# from fastrl.agents.discrete import EpsilonCollector\n", + "from fastrl.loggers.core import ProgressBarLogger,EpochCollector,BatchCollector\n", + "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe\n", "from fastrl.agents.core import AgentHead,StepFieldSelector,AgentBase\n", - "from fastrl.agents.ddpg import ActionClip,ActionUnbatcher,NumpyConverter,OrnsteinUhlenbeck,SimpleModelRunner" + "from fastrl.agents.ddpg import ActionClip,ActionUnbatcher,NumpyConverter,SimpleModelRunner" ] }, { @@ -1325,11 +1313,7 @@ " # A list of dls, where index=0 is the training dl.\n", " dls:List[object],\n", " # Optional logger bases to log training/validation data to.\n", - " logger_bases:Optional[Callable]=None,\n", - " # The learning rate for the actor. Expected to learn slower than the critic\n", - " actor_lr:float=1e-3,\n", - " # The optimizer for the actor\n", - " actor_opt:torch.optim.Optimizer=Adam,\n", + " do_logging:bool=True,\n", " # The learning rate for the critic. Expected to learn faster than the actor\n", " critic_lr:float=1e-2,\n", " # The optimizer for the critic\n", @@ -1349,11 +1333,12 @@ ") -> LearnerHead:\n", " warn(\"TRPO only kind of converges. 
There is a likely a bug, however I am unable to identify until after PPO implimentation\")\n", "\n", - " learner = LearnerBase(actor,dls[0])\n", + " learner = LearnerBase({'actor':actor,'critic':critic},dls[0])\n", " learner = BatchCollector(learner,batches=batches)\n", " learner = EpochCollector(learner)\n", - " if logger_bases: \n", - " learner = logger_bases(learner)\n", + " if do_logging: \n", + " learner = learner.dump_records()\n", + " learner = ProgressBarLogger(learner)\n", " learner = RollingTerminatedRewardCollector(learner)\n", " learner = EpisodeCollector(learner).catch_records()\n", " learner = StepBatcher(learner)\n", @@ -1362,8 +1347,14 @@ " learner = BasicOptStepper(learner,critic,critic_lr,opt=critic_opt,filter=True,do_zero_grad=False)\n", " learner = ActorOptAndLossProcessor(learner,actor)\n", " learner = LossCollector(learner,title='actor-loss').catch_records()\n", - " learner = LearnerHead(learner,(actor,critic)) \n", - " return learner\n", + "\n", + " if len(dls)==2:\n", + " val_learner = LearnerBase({'actor':actor,'critic':critic},dls[1]).visualize_vscode()\n", + " val_learner = BatchCollector(val_learner,batches=batches)\n", + " val_learner = EpochCollector(val_learner).catch_records(drop=True)\n", + " return LearnerHead((learner,val_learner))\n", + " else:\n", + " return LearnerHead(learner)\n", "\n", "TRPOLearner.__doc__=\"\"\"\"\"\"" ] @@ -1375,8 +1366,7 @@ "metadata": {}, "outputs": [], "source": [ - "from fastrl.dataloading.core import dataloaders\n", - "from fastrl.loggers.core import ProgressBarLogger" + "from fastrl.dataloading.core import dataloaders" ] }, { @@ -1387,14 +1377,6 @@ "outputs": [], "source": [ "#|eval:false\n", - "# Setup Loggers\n", - "def logger_bases(pipe):\n", - " pipe = pipe.dump_records()\n", - " pipe = ProgressBarLogger(pipe)\n", - " return pipe\n", - "\n", - "# env='HalfCheetahBulletEnv-v0'\n", - "\n", "# Setup up the core NN\n", "torch.manual_seed(0)\n", "actor = Actor(2,1)\n", @@ -1405,42 +1387,38 @@ "agent = TRPOAgent(actor,do_logging=True,clip_min=-1,clip_max=1)\n", "\n", "# Setup the Dataloaders\n", + "params = dict(source=['MountainCarContinuous-v0']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,gamma=0.99,discount=0.99)\n", + "\n", "dls = dataloaders((\n", - " # AdvantageGymDataPipe(['Pendulum-v1']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=200,gamma=0.99,discount=0.99),MountainCarContinuous-v0\n", + " # AdvantageGymDataPipe(['Pendulum-v1']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=200,gamma=0.99,discount=0.99),\n", " # AdvantageGymDataPipe(['Pendulum-v1']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=1,gamma=0.99,discount=0.99)\n", - " AdvantageGymDataPipe(['MountainCarContinuous-v0']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=200,gamma=0.99,discount=0.99),\n", - " AdvantageGymDataPipe(['MountainCarContinuous-v0']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=1,gamma=0.99,discount=0.99)\n", + " AdvantageGymDataPipe(**params,bs=200),\n", + " AdvantageGymDataPipe(**params,bs=1,include_images=True)\n", "))\n", "# Setup the Learner\n", - "learner = TRPOLearner(actor,critic,dls,logger_bases=logger_bases,\n", - " batches=10,critic_lr=0.01)\n", + "learner = TRPOLearner(actor,critic,dls,batches=10,critic_lr=0.01)\n", "# learner.fit(1)\n", - "learner.fit(10)" + "learner.fit(15)" ] }, { "cell_type": "code", "execution_count": null, - "id": "604dddf6", + "id": "b6cb30fd", "metadata": {}, "outputs": [], "source": [ - "val_agent 
= TRPOAgent(actor,do_logging=True,clip_min=-1,clip_max=1)" + "learner.validate(2)" ] }, { "cell_type": "code", "execution_count": null, - "id": "b4ab25a1", + "id": "c8c1699a", "metadata": {}, "outputs": [], "source": [ - "# valid_pipe = AdvantageGymDataPipe(['Pendulum-v1']*1,agent=val_agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=1,gamma=0.99,discount=0.99,n=300,include_images=True)\n", - "valid_pipe = AdvantageGymDataPipe(['MountainCarContinuous-v0']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=1,gamma=0.99,discount=0.99,n=1000,include_images=True)\n", - "valid_pipe = VSCodeDataPipe(valid_pipe)\n", - "# sample_run = [o[0] for o in valid_pipe.dump_records().catch_records(drop=True)];\n", - "list(valid_pipe);\n", - "valid_pipe.show()" + "import matplotlib.pyplot as plt" ] }, { @@ -1450,8 +1428,6 @@ "metadata": {}, "outputs": [], "source": [ - "import matplotlib.pyplot as plt\n", - "\n", "def normalize(data):\n", " mean = sum(data) / len(data)\n", " variance = sum([(x - mean) ** 2 for x in data]) / len(data)\n", @@ -1492,56 +1468,6 @@ " plt.show()" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "f5175cd7", - "metadata": {}, - "outputs": [], - "source": [ - "visualize_advantage_steps(sample_run)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fae186a4", - "metadata": {}, - "outputs": [], - "source": [ - "from fastrl.memory.memory_visualizer import MemoryBufferViewer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1e7360cd", - "metadata": {}, - "outputs": [], - "source": [ - "sample_run[0].advantage" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "07c59d61", - "metadata": {}, - "outputs": [], - "source": [ - "MemoryBufferViewer(sample_run,val_agent)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c2c87913", - "metadata": {}, - "outputs": [], - "source": [ - "learner.validate()" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb b/nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb index 0c7a366..714c906 100644 --- a/nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb +++ b/nbs/07_Agents/02_Continuous/12u_agents.ppo.ipynb @@ -12,6 +12,16 @@ "initialize_notebook()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "31ad979a", + "metadata": {}, + "outputs": [], + "source": [ + "#|default_exp agents.ppo" + ] + }, { "cell_type": "code", "execution_count": null, @@ -21,54 +31,24 @@ "source": [ "#|export\n", "# Python native modules\n", - "from typing import Union,Dict,Literal,List,Callable,Optional\n", - "# from typing_extensions import Literal\n", - "# import typing \n", - "# from warnings import warn\n", - "# # Third party libs\n", - "# import numpy as np\n", + "from typing import Union,Dict,Literal,List\n", + "# Third party libs\n", "import torch\n", "from torch import nn\n", - "# from torch.distributions import *\n", "import torchdata.datapipes as dp \n", - "from torchdata.dataloader2.graph import DataPipe,traverse_dps\n", - "# from fastcore.all import test_eq,test_ne,ifnone,L\n", + "from torchdata.dataloader2.graph import DataPipe\n", "from torch.optim import AdamW,Adam\n", "# # Local modules\n", "from fastrl.core import SimpleStep\n", - "# from fastrl.pipes.core import *\n", - "# from fastrl.torch_core import *\n", "from fastrl.layers import Critic\n", - "# from fastrl.data.block import *\n", - "# from fastrl.envs.gym import *\n", "from fastrl.agents.trpo import 
Actor\n", - "# from fastrl.loggers.vscode_visualizers import VSCodeTransformBlock\n", - "# from fastrl.loggers.jupyter_visualizers import ProgressBarLogger\n", - "# from fastrl.agents.discrete import EpsilonCollector\n", - "# from fastrl.agents.core import AgentHead,StepFieldSelector,AgentBase \n", - "# from fastrl.agents.ddpg import ActionClip,ActionUnbatcher,NumpyConverter,OrnsteinUhlenbeck,SimpleModelRunner\n", - "# from fastrl.loggers.core import LoggerBase,CacheLoggerBase\n", - "# from fastrl.dataloader2_ext import InputInjester\n", - "# from fastrl.loggers.core import LoggerBasePassThrough,BatchCollector,EpocherCollector,RollingTerminatedRewardCollector,EpisodeCollector\n", + "from fastrl.loggers.core import ProgressBarLogger\n", + "from fastrl.loggers.vscode_visualizers import VSCodeDataPipe\n", "from fastrl.learner.core import LearnerBase,LearnerHead\n", "from fastrl.loggers.core import BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector\n", "import fastrl.pipes.iter.cacheholder\n", "from fastrl.agents.ddpg import LossCollector,BasicOptStepper,StepBatcher\n", - "from fastrl.agents.trpo import CriticLossProcessor\n", - "# from fastrl.pipes.core import *\n", - "# from fastrl.pipes.iter.nskip import *\n", - "# from fastrl.pipes.iter.nstep import *\n", - "# from fastrl.pipes.iter.firstlast import *\n", - "# from fastrl.pipes.iter.transforms import *\n", - "# from fastrl.pipes.map.transforms import *\n", - "# from fastrl.data.block import *\n", - "# from fastrl.torch_core import *\n", - "# from fastrl.layers import *\n", - "# from fastrl.data.block import *\n", - "# from fastrl.envs.gym import *\n", - "# from fastrl.agents.ddpg import LossCollector,BasicOptStepper,StepBatcher\n", - "# from fastrl.loggers.core import LogCollector\n", - "# from fastrl.agents.discrete import EpsilonCollector\n" + "from fastrl.agents.trpo import CriticLossProcessor" ] }, { @@ -197,7 +177,7 @@ " # A list of dls, where index=0 is the training dl.\n", " dls:List[object],\n", " # Optional logger bases to log training/validation data to.\n", - " logger_bases:Optional[Callable]=None,\n", + " do_logging:bool=True,\n", " # The learning rate for the actor. 
Expected to learn slower than the critic\n", " actor_lr:float=1e-4,\n", " # The optimizer for the actor\n", @@ -222,22 +202,27 @@ " ppo_batch_sz = 64,\n", " ppo_eps = 0.2,\n", ") -> LearnerHead:\n", - " learner = LearnerBase(actor,dls[0])\n", + " learner = LearnerBase({'actor':actor,'critic':critic},dls[0])\n", " learner = BatchCollector(learner,batches=batches)\n", " learner = EpochCollector(learner)\n", - " if logger_bases: \n", - " learner = logger_bases(learner)\n", + " if do_logging: \n", + " learner = learner.dump_records()\n", + " learner = ProgressBarLogger(learner)\n", " learner = RollingTerminatedRewardCollector(learner)\n", " learner = EpisodeCollector(learner).catch_records()\n", " learner = StepBatcher(learner)\n", - " # learner = CriticLossProcessor(learner,critic=critic)\n", - " # learner = LossCollector(learner,title='critic-loss').catch_records()\n", - " # learner = BasicOptStepper(learner,critic,critic_lr,opt=critic_opt,filter=True,do_zero_grad=False)\n", " learner = PPOActorOptAndLossProcessor(learner,actor=actor,actor_lr=actor_lr,\n", " critic=critic,critic_lr=critic_lr,ppo_epochs=ppo_epochs,\n", " ppo_batch_sz=ppo_batch_sz,ppo_eps=ppo_eps)\n", " learner = LossCollector(learner,title='actor-loss').catch_records()\n", - " learner = LearnerHead(learner,(actor,critic))\n", + "\n", + " if len(dls)==2:\n", + " val_learner = LearnerBase({'actor':actor,'critic':critic},dls[1]).visualize_vscode()\n", + " val_learner = BatchCollector(val_learner,batches=batches)\n", + " val_learner = EpochCollector(val_learner).catch_records(drop=True)\n", + " return LearnerHead((learner,val_learner))\n", + " else:\n", + " return LearnerHead(learner)\n", " \n", " return learner\n", "\n", @@ -252,7 +237,6 @@ "outputs": [], "source": [ "from fastrl.dataloading.core import dataloaders\n", - "from fastrl.loggers.core import ProgressBarLogger\n", "from fastrl.agents.trpo import TRPOAgent,AdvantageGymDataPipe" ] }, @@ -279,14 +263,16 @@ "agent = TRPOAgent(actor,do_logging=True,clip_min=-1,clip_max=1)\n", "\n", "# Setup the Dataloaders\n", + "params = dict(source=['MountainCarContinuous-v0']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,gamma=0.99,discount=0.99)\n", + "\n", "dls = dataloaders((\n", - " # AdvantageGymDataPipe(['Pendulum-v1']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=200,gamma=0.99,discount=0.99),MountainCarContinuous-v0\n", + " # AdvantageGymDataPipe(['Pendulum-v1']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=200,gamma=0.99,discount=0.99),\n", " # AdvantageGymDataPipe(['Pendulum-v1']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=1,gamma=0.99,discount=0.99)\n", - " AdvantageGymDataPipe(['MountainCarContinuous-v0']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=200,gamma=0.99,discount=0.99),\n", - " AdvantageGymDataPipe(['MountainCarContinuous-v0']*1,agent=agent,critic=critic,nsteps=2,nskips=2,firstlast=True,bs=1,gamma=0.99,discount=0.99)\n", + " AdvantageGymDataPipe(**params,bs=200),\n", + " AdvantageGymDataPipe(**params,bs=1,include_images=True)\n", "))\n", "# Setup the Learner\n", - "learner = PPOLearner(actor,critic,dls,logger_bases=logger_bases,batches=10,ppo_batch_sz = 64*2)\n", + "learner = PPOLearner(actor,critic,dls,batches=10,ppo_batch_sz = 64*2)\n", "# learner.fit(1)\n", "learner.fit(10)" ]
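
Note on `ppo_eps`: it is the clipping radius of the standard PPO clipped surrogate objective that the actor update optimizes. A minimal sketch of that loss on standalone tensors (the actual `PPOActorOptAndLossProcessor` implementation is not shown in this hunk):

import torch

def ppo_clip_loss(new_log_prob, old_log_prob, advantage, eps=0.2):
    # Probability ratio between the updated policy and the policy that collected the data.
    ratio = (new_log_prob - old_log_prob).exp()
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantage
    # Maximize the surrogate, i.e. minimize its negative; clipping keeps each update conservative.
    return -torch.min(unclipped, clipped).mean()
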