From 5764a83495362a5c1d43254e248ddb97a3215583 Mon Sep 17 00:00:00 2001 From: Geson-anko <59220704+Geson-anko@users.noreply.github.com> Date: Tue, 24 Jan 2023 03:57:17 +0000 Subject: [PATCH 01/28] ADD registering names for replay buffer --- src/datamodules/buffer_names.py | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 src/datamodules/buffer_names.py diff --git a/src/datamodules/buffer_names.py b/src/datamodules/buffer_names.py new file mode 100644 index 0000000..501111b --- /dev/null +++ b/src/datamodules/buffer_names.py @@ -0,0 +1,5 @@ +ACTION = "action" +VOC_STATE = "voc_state" +GENERATED_SOUND = "generated_sound" +TARGET_SOUND = "target_sound" +DONE = "done" From 2faf49f5762d1112b86b52974c5f7f743ab8d3cd Mon Sep 17 00:00:00 2001 From: Geson-anko <59220704+Geson-anko@users.noreply.github.com> Date: Tue, 24 Jan 2023 03:59:27 +0000 Subject: [PATCH 02/28] ADD Dreamer (wip) --- src/models/dreamer.py | 272 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 272 insertions(+) create mode 100644 src/models/dreamer.py diff --git a/src/models/dreamer.py b/src/models/dreamer.py new file mode 100644 index 0000000..4338bc6 --- /dev/null +++ b/src/models/dreamer.py @@ -0,0 +1,272 @@ +from collections import OrderedDict +from functools import partial +from typing import Any + +import gym +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributions import kl_divergence +from torch.optim import Optimizer + +from ..datamodules import buffer_names +from ..datamodules.replay_buffer import ReplayBuffer +from ..env.array_voc_state import VocStateObsNames as ObsNames +from .abc.agent import Agent +from .abc.controller import Controller +from .abc.observation_auto_encoder import ObservationDecoder, ObservationEncoder +from .abc.prior import Prior +from .abc.transition import Transition +from .abc.world import World + + +class Dreamer(nn.Module): + """Dreamer model class.""" + + # Added attribute from Trainer + current_step: int + current_episode: int + device: torch.device + dtype: torch.dtype + + def __init__( + self, + transition: Transition, + prior: Prior, + obs_encoder: ObservationEncoder, + obs_decoder: ObservationDecoder, + controller: Controller, + world: partial[World], + agent: partial[Agent], + world_optimizer: partial[Optimizer], + controller_optimizer: partial[Optimizer], + free_nats: float = 3.0, + num_collect_experience_steps: int = 100, + imagination_horizon: int = 32, + ) -> None: + """ + Args: + transition (Transition): Instance of ransition model class. + prior (Prior): Instance of prior model class. + obs_encoder (ObservationEncoder): Instance of ObservationEncoder model class. + obs_decoder (ObservationDecoder): Instance of ObservationDecoder model class. + controller (Controller): Instance of Controller model class. + world (partial[World]): Partial instance of World interface class. + agent (partial[Agent]): Partial instance of Agent interface class. + world_optimizer (partial[Optimizer]): Partial instance of Optimizer class. + controller_optimizer (partial[Optimizer]): Partial instance of Optimizer class. + + free_nats (float): Ignore kl div loss when it is less then this value. 
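Note on the constructor documented above: `world`, `agent`, and the two optimizers are passed as `functools.partial` objects rather than finished instances — `__init__` completes the world/agent from the sub-models, and `configure_optimizers` later binds the parameter lists. A rough construction sketch; the `my_*` sub-model instances and the AdamW learning rates are placeholders for illustration, not part of this patch:

    from functools import partial
    from torch.optim import AdamW

    # The sub-models are stand-ins: any concrete implementations of the
    # Transition / Prior / ObservationEncoder / ObservationDecoder / Controller
    # interfaces would take their place (the tests use dummy classes).
    model = Dreamer(
        transition=my_transition,
        prior=my_prior,
        obs_encoder=my_obs_encoder,
        obs_decoder=my_obs_decoder,
        controller=my_controller,
        world=partial(World),                      # completed inside __init__ with the sub-models
        agent=partial(Agent),                      # completed inside __init__ with controller, transition, obs_encoder
        world_optimizer=partial(AdamW, lr=1e-3),   # params are bound later in configure_optimizers()
        controller_optimizer=partial(AdamW, lr=1e-3),
    )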
+ """ + + self.transition = transition + self.prior = prior + self.obs_encoder = obs_encoder + self.obs_decoder = obs_decoder + self.controller = controller + + self.world = world( + transition=transition, prior=prior, obs_encoder=obs_encoder, obs_decoder=obs_decoder + ) + + self.agent = agent( + controller=controller, + transition=transition, + obs_encoder=obs_encoder, + ) + + self.world_optimizer = world_optimizer + self.controller_optimizer = controller_optimizer + + self.free_nats = free_nats + self.num_collect_experience_steps = num_collect_experience_steps + self.imagination_horizon = imagination_horizon + + def configure_optimizers(self) -> tuple[Optimizer, Optimizer]: + """Configure world optimizer and controller optimizer. + + Returns: + world_optimizer (Optimizer): Updating Transition, Prior, ObservationEncoder and Decoder. + controller_optimizer (Optimizer): Updating Controller. + """ + world_params = ( + list(self.transition.parameters()) + + list(self.prior.parameters()) + + list(self.obs_encoder.parameters()) + + list(self.obs_decoder.parameters()) + ) + + world_optim = self.world_optimizer(params=world_params) + con_optim = self.controller_optimizer(params=self.controller.parameters()) + + return [world_optim, con_optim] + + def collect_experiences(self, env: gym.Env, replay_buffer: ReplayBuffer) -> ReplayBuffer: + """Explorer in env and collect experiences to replay buffer. + + Args: + env (gym.Env): PynkTrombone environment or its wrapper class. + replay_buffer (ReplayBuffer): Storing experiences. + num_steps (int): How much experiences to store. + + Returns: + replay_buffer(ReplayBuffer): Same pointer of input replay_buffer. + """ + device = self.agent.hidden.device + dtype = self.agent.hidden.dtype + + obs = env.reset() + voc_state_np = obs[ObsNames.VOC_STATE] + generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] + target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM] + + voc_state = torch.as_tensor(voc_state_np, dtype, device).squeeze(0) + generated = torch.as_tensor(generated_np, dtype, device).squeeze(0) + target = torch.as_tensor(target_np, dtype, device).squeeze(0) + + for _ in range(self.num_collect_experience_steps): + action = self.agent.explore(obs=(voc_state, generated), target=target) + action = action.cpu().unsqueeze(0).numpy() + obs, _, done, _ = env.step(action) + + voc_state_np = obs[ObsNames.VOC_STATE] + generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] + target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM] + + sample = { + buffer_names.ACTION: action, + buffer_names.VOC_STATE: voc_state_np, + buffer_names.GENERATED_SOUND: generated_np, + buffer_names.DONE: done, + buffer_names.TARGET_SOUND: target_np, + } + + replay_buffer.push(sample) + + if done: + obs = env.reset() + voc_state_np = obs[ObsNames.VOC_STATE] + generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] + target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM] + + voc_state = torch.as_tensor(voc_state_np, dtype, device).squeeze(0) + generated = torch.as_tensor(generated_np, dtype, device).squeeze(0) + target = torch.as_tensor(target_np, dtype, device).squeeze(0) + + def world_training_step( + self, experiences: dict[str, np.ndarray] + ) -> tuple[dict[str, Any], dict[str, np.ndarray]]: + """Compute loss for training world model, and add all hiddens and states to `experiences` + for controller training. + + Args: + experiences (dict[str, np.ndarray]): Collected experiences. + + Returns: + loss_dict (dict[str, Any]): loss and some other metric values. 
+ experiences (dict[str, np.ndarray]): Added `all_hiddens` and `all_states`. + """ + device = self.agent.hidden.device + dtype = self.agent.hidden.dtype + + actions = experiences[buffer_names.ACTION] + voc_states = experiences[buffer_names.VOC_STATE] + generated_sounds = experiences[buffer_names.GENERATED_SOUND] + dones = experiences[buffer_names.DONE] + + chunk_size, batch_size = actions.shape[:2] + + hidden = torch.zeros( + (batch_size, *self.transition.hidden_shape), dtype=dtype, device=device + ) + state = torch.zeros((batch_size, *self.prior.state_shape), dtype=dtype, device=device) + + all_hiddens = torch.empty((chunk_size, *hidden.shape), dtype=dtype, device="cpu") + all_states = torch.empty((chunk_size, *state.shape), dtype=dtype, device="cpu") + + rec_voc_state_loss = 0.0 + rec_generated_sound_loss = 0.0 + all_kl_div_loss = 0.0 + + for idx in range(chunk_size): + action = torch.as_tensor(actions[idx], dtype, device) + + voc_stat = torch.as_tensor(voc_states[idx], dtype, device) + gened_sound = torch.as_tensor(generated_sounds[idx], dtype, device) + next_obs = (voc_stat, gened_sound) + + next_hidden = self.transition.forward(hidden, state, action) + next_state_prior = self.prior.forward(next_hidden) + next_state_posterior = self.obs_encoder.forward(next_hidden, next_obs) + next_state = next_state_posterior.rsample() + + all_states[idx] = next_state.detach() + all_hiddens[idx] = next_hidden.detach() + + rec_voc_stat, rec_gened_sound = self.obs_decoder.forward(next_hidden, next_state) + + # compute losses + kl_div_loss = kl_divergence(next_state_posterior, next_state_prior).view( + batch_size, -1 + ) + all_kl_div_loss += kl_div_loss.sum(-1).mean() + + rec_voc_state_loss += F.mse_loss(voc_stat, rec_voc_stat) + rec_generated_sound_loss += F.mse_loss(gened_sound, rec_gened_sound) + + # next step + next_state[dones[idx]] = 0.0 # Initialize with zero. + next_hidden[dones[idx]] = 0.0 # Initialize with zero. + + state = next_state + hidden = next_hidden + + rec_voc_state_loss /= chunk_size + rec_generated_sound_loss /= chunk_size + kl_div_loss /= chunk_size + rec_loss = rec_voc_state_loss + rec_generated_sound_loss + + loss = rec_loss + (not kl_div_loss.item() < self.free_nats) * kl_div_loss + + loss_dict = { + "loss": loss, + "rec_loss": rec_loss, + "rec_voc_state_loss": rec_voc_state_loss, + "rec_generated_sound_loss": rec_generated_sound_loss, + "kl_div_loss": kl_div_loss, + "over_free_nat": not kl_div_loss.item() < self.free_nats, + } + + experiences["hiddens"] = all_hiddens + experiences["states"] = all_states + + return loss_dict, experiences + + def controller_training_step( + self, experiences: dict[str, np.ndarray] + ) -> tuple[dict[str, Any], dict[str, np.ndarray]]: + """Compute loss for training controller model. + Args: + experiences (dict[str, np.ndarray]): Collected experiences. + + Returns: + loss_dict (dict[str, Any]): loss and some other metric values. + experiences (dict[str, np.ndarray]): Collected experiences (No modification.) 
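As a side note on the `free_nats` term in `world_training_step` above: the KL contribution is multiplied by a boolean flag, so the whole KL term (and its gradient) is dropped while the averaged KL stays under the budget, and added back unchanged once it crosses it. A toy illustration with made-up numbers:

    import torch

    free_nats = 3.0
    kl_term = torch.tensor(1.7, requires_grad=True)   # averaged KL value (made-up number)
    over_free_nats = not kl_term.item() < free_nats   # False: still inside the free-nats budget
    gated = over_free_nats * kl_term                  # bool * tensor -> 0.0, so no KL gradient flows
    # With kl_term = 4.2 the flag flips to True and the full KL value is added to the loss.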
+ """ + + device = self.device + dtype = self.dtype + + actions = experiences[buffer_names.ACTION] + voc_states = experiences[buffer_names.VOC_STATE] + generated_sounds = experiences[buffer_names.GENERATED_SOUND] + dones = experiences[buffer_names.DONE] + target_sounds = experiences[buffer_names.TARGET_SOUND] + old_hiddens = experiences["hiddens"] + old_states = experiences["states"] + + chunk_size, batch_size = actions.shape[:2] + + start_idx = np.random.randint(0, chunk_size - self.imagination_horizon, (chunk_size,)) From cae808f72cfc13f048fbdca86c531f03b662c4df Mon Sep 17 00:00:00 2001 From: Geson-anko <59220704+Geson-anko@users.noreply.github.com> Date: Tue, 24 Jan 2023 03:59:51 +0000 Subject: [PATCH 03/28] ADD trainer (wip) --- src/trainer.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 src/trainer.py diff --git a/src/trainer.py b/src/trainer.py new file mode 100644 index 0000000..42a5dbd --- /dev/null +++ b/src/trainer.py @@ -0,0 +1,72 @@ +import logging + +import gym +import torch +from torch.nn.utils import clip_grad_norm_ + +from .datamodules import buffer_names +from .datamodules.replay_buffer import ReplayBuffer +from .env.array_voc_state import VocStateObsNames as ObsNames +from .models.dreamer import Dreamer + +logger = logging.getLogger(__name__) + + +class Trainer: + """Trainer class for Dreamer.""" + + def __init__( + self, + num_episode: int = 1, + collect_experience_interval: int = 100, + batch_size: int = 8, + chunk_size: int = 64, + gradient_clip_value: float = 100.0, + ) -> None: + """ + Args: + + """ + self.__dict__.update(locals()) # Add all input args to class attribute. + + def fit(self, env: gym.Env, replay_buffer: ReplayBuffer, model: Dreamer) -> None: + """Fit + Args: + + """ + + world_optimizer, controller_optimizer = model.configure_optimizers() + + current_step = 0 + + logger.info("Fit started.") + for episode in range(self.num_episode): + logger.info(f"Episode {episode} is started.") + + logger.debug("Collecting experiences...") + replay_buffer = model.collect_experiences( + env, replay_buffer, self.num_collect_experience_steps + ) + logger.debug("Collected experiences.") + + for collect_interval in range(self.collect_experience_interval): + logger.debug(f"Collect interval: {collect_interval}") + + # Training World Model. + experiences_dict = replay_buffer.sample( + self.batch_size, self.chunk_size, chunk_first=True + ) + loss_dict, experiences_dict = model.world_training_step(experiences_dict) + + loss: torch.Tensor = loss_dict["loss"] + world_optimizer.zero_grad() + loss.backward() + params = [] + for p in world_optimizer.param_groups: + params += p["params"] + clip_grad_norm_(params, self.gradient_clip_value) + world_optimizer.step() + + # -- logging -- + + # Training Controller model. From 63fd8c9e54f50d6295e42e3f1eca115660177af6 Mon Sep 17 00:00:00 2001 From: Geson-anko <59220704+Geson-anko@users.noreply.github.com> Date: Tue, 24 Jan 2023 13:09:16 +0000 Subject: [PATCH 04/28] ADD eval and train --- src/models/abc/world.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/models/abc/world.py b/src/models/abc/world.py index 46fd458..226ae0c 100644 --- a/src/models/abc/world.py +++ b/src/models/abc/world.py @@ -45,7 +45,7 @@ def forward( next_obs: _tensor_or_any, *args: Any, **kwds: Any, - ) -> tuple[_dist_or_any, _dist_or_any, _tensor_or_any]: + ) -> tuple[_dist_or_any, _dist_or_any, _tensor_or_any, _tensor_or_any]: """Make world model transition. 
Args: @@ -65,3 +65,17 @@ def forward( next_state_posterior = self.obs_encoder.forward(next_hidden, next_obs) return next_state_prior, next_state_posterior, next_hidden + + def eval(self): + """Set models to evaluation mode.""" + self.transition.eval() + self.prior.eval() + self.obs_encoder.eval() + self.obs_decoder.eval() + + def train(self): + """Set models to training mode.""" + self.transition.train() + self.prior.train() + self.obs_encoder.train() + self.obs_decoder.train() From 9894368c955d8e2f17ef196b4852255b3c600c04 Mon Sep 17 00:00:00 2001 From: Geson-anko <59220704+Geson-anko@users.noreply.github.com> Date: Tue, 24 Jan 2023 13:11:03 +0000 Subject: [PATCH 05/28] =?UTF-8?q?=E3=81=A8=E3=82=8A=E3=81=82=E3=81=88?= =?UTF-8?q?=E3=81=9A=E5=8B=95=E4=BD=9C=E7=A2=BA=E8=AA=8D=E3=81=AF=E3=81=AA?= =?UTF-8?q?=E3=81=97=E3=81=A7=E4=BD=9C=E3=82=8B=E3=81=A0=E3=81=91=E3=81=A4?= =?UTF-8?q?=E3=81=8F=E3=81=A3=E3=81=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/models/dreamer.py | 151 +++++++++++++++++++++++++++++++++++++----- src/trainer.py | 81 +++++++++++++++++++++- 2 files changed, 214 insertions(+), 18 deletions(-) diff --git a/src/models/dreamer.py b/src/models/dreamer.py index 4338bc6..b7438bc 100644 --- a/src/models/dreamer.py +++ b/src/models/dreamer.py @@ -26,8 +26,8 @@ class Dreamer(nn.Module): """Dreamer model class.""" # Added attribute from Trainer - current_step: int - current_episode: int + current_step: int = 0 + current_episode: int = 0 device: torch.device dtype: torch.dtype @@ -45,6 +45,8 @@ def __init__( free_nats: float = 3.0, num_collect_experience_steps: int = 100, imagination_horizon: int = 32, + evaluation_steps: int = 44 * 60, + evaluation_blank_length: int = 22050, ) -> None: """ Args: @@ -59,6 +61,7 @@ def __init__( controller_optimizer (partial[Optimizer]): Partial instance of Optimizer class. free_nats (float): Ignore kl div loss when it is less then this value. + evaluation_blank_length (int): """ self.transition = transition @@ -83,6 +86,8 @@ def __init__( self.free_nats = free_nats self.num_collect_experience_steps = num_collect_experience_steps self.imagination_horizon = imagination_horizon + self.evaluation_steps = evaluation_steps + self.evaluation_blank_length = evaluation_blank_length def configure_optimizers(self) -> tuple[Optimizer, Optimizer]: """Configure world optimizer and controller optimizer. @@ -103,6 +108,7 @@ def configure_optimizers(self) -> tuple[Optimizer, Optimizer]: return [world_optim, con_optim] + @torch.no_grad() def collect_experiences(self, env: gym.Env, replay_buffer: ReplayBuffer) -> ReplayBuffer: """Explorer in env and collect experiences to replay buffer. @@ -114,21 +120,21 @@ def collect_experiences(self, env: gym.Env, replay_buffer: ReplayBuffer) -> Repl Returns: replay_buffer(ReplayBuffer): Same pointer of input replay_buffer. 
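Patch 05 also flips the batch-dimension handling in `collect_experiences` (see the hunk just below): observations coming from the env carry no batch axis and get one added before the agent sees them, while the agent's batched action is squeezed back before `env.step`. A shape sketch with made-up sizes:

    import numpy as np
    import torch

    obs_np = np.zeros(44)                                             # unbatched observation from the env
    obs = torch.as_tensor(obs_np, dtype=torch.float32).unsqueeze(0)   # (44,) -> (1, 44): batch of one for the agent
    action = torch.zeros(1, 15)                                       # agent returns a batched action
    action_np = action.cpu().squeeze(0).numpy()                       # (1, 15) -> (15,): what env.step expects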
""" - device = self.agent.hidden.device - dtype = self.agent.hidden.dtype + device = self.device + dtype = self.dtype obs = env.reset() voc_state_np = obs[ObsNames.VOC_STATE] generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM] - voc_state = torch.as_tensor(voc_state_np, dtype, device).squeeze(0) - generated = torch.as_tensor(generated_np, dtype, device).squeeze(0) - target = torch.as_tensor(target_np, dtype, device).squeeze(0) + voc_state = torch.as_tensor(voc_state_np, dtype, device).unsqueeze(0) + generated = torch.as_tensor(generated_np, dtype, device).unsqueeze(0) + target = torch.as_tensor(target_np, dtype, device).unsqueeze(0) for _ in range(self.num_collect_experience_steps): action = self.agent.explore(obs=(voc_state, generated), target=target) - action = action.cpu().unsqueeze(0).numpy() + action = action.cpu().squeeze(0).numpy() obs, _, done, _ = env.step(action) voc_state_np = obs[ObsNames.VOC_STATE] @@ -171,6 +177,8 @@ def world_training_step( device = self.agent.hidden.device dtype = self.agent.hidden.dtype + self.world.train() + actions = experiences[buffer_names.ACTION] voc_states = experiences[buffer_names.VOC_STATE] generated_sounds = experiences[buffer_names.GENERATED_SOUND] @@ -197,16 +205,16 @@ def world_training_step( gened_sound = torch.as_tensor(generated_sounds[idx], dtype, device) next_obs = (voc_stat, gened_sound) - next_hidden = self.transition.forward(hidden, state, action) - next_state_prior = self.prior.forward(next_hidden) - next_state_posterior = self.obs_encoder.forward(next_hidden, next_obs) + next_state_prior, next_state_posterior, next_hidden = self.world.forward( + hidden, state, action, next_obs + ) + next_state = next_state_posterior.rsample() + rec_voc_stat, rec_gened_sound = self.obs_decoder.forward(next_hidden, next_state) all_states[idx] = next_state.detach() all_hiddens[idx] = next_hidden.detach() - rec_voc_stat, rec_gened_sound = self.obs_decoder.forward(next_hidden, next_state) - # compute losses kl_div_loss = kl_divergence(next_state_posterior, next_state_prior).view( batch_size, -1 @@ -256,17 +264,126 @@ def controller_training_step( experiences (dict[str, np.ndarray]): Collected experiences (No modification.) 
""" + self.controller.train() + self.world.eval() + device = self.device dtype = self.dtype actions = experiences[buffer_names.ACTION] - voc_states = experiences[buffer_names.VOC_STATE] - generated_sounds = experiences[buffer_names.GENERATED_SOUND] dones = experiences[buffer_names.DONE] target_sounds = experiences[buffer_names.TARGET_SOUND] old_hiddens = experiences["hiddens"] - old_states = experiences["states"] chunk_size, batch_size = actions.shape[:2] - start_idx = np.random.randint(0, chunk_size - self.imagination_horizon, (chunk_size,)) + start_indices = np.random.randint(0, chunk_size - self.imagination_horizon, (batch_size,)) + batch_arange = np.arange(batch_size) + hidden = torch.as_tensor(old_hiddens[start_indices, batch_arange], dtype, device) + controller_hidden = torch.zeros( + batch_size, *self.controller.controller_hidden_shape, dtype=dtype, device=device + ) + state = self.prior.forward(hidden).sample() + + loss = 0.0 + for i in range(self.imagination_horizon): + indices = start_indices + i + target = torch.as_tensor(target_sounds[indices, batch_arange], dtype, device) + action, controller_hidden = self.controller.forward( + hidden, state, target, controller_hidden + ) + next_hidden = self.transition.forward(hidden, state, action) + next_state = self.prior.forward(next_hidden).sample() + rec_next_obs = self.obs_decoder.forward(next_hidden, next_state) + _, rec_gened_sound = rec_next_obs + + loss += F.mse_loss(target, rec_gened_sound) + + hidden = next_hidden + + hidden[dones[indices, batch_arange]] = old_hiddens[indices + 1, batch_arange] + state = self.prior.forward(hidden) + + loss /= self.imagination_horizon + + loss_dict = {"loss": loss} + + return loss_dict, experiences + + @torch.no_grad() + def evaluation_step(self, env: gym.Env) -> dict[str, Any]: + """Evaluation step. + Args: + env (gym.Env): PynkTrombone environment or its wrapper class. + + Returns: + loss_dict (dict[str, Any]): Returned metric values. 
+ """ + self.world.eval() + self.controller.eval() + + device = self.device + dtype = self.dtype + + self.agent.reset() + + obs = env.reset() + voc_state_np = obs[ObsNames.VOC_STATE] + generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] + target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM] + + voc_state = torch.as_tensor(voc_state_np, dtype, device).unsqueeze(0) + generated = torch.as_tensor(generated_np, dtype, device).unsqueeze(0) + target = torch.as_tensor(target_np, dtype, device).unsqueeze(0) + + generated_sound_waves = [] + target_sound_waves = [] + + blank = np.zeros(self.evaluation_blank_length) + + target_generated_mse = 0.0 + target_generated_mae = 0.0 + + for i in range(self.evaluation_steps): + target_sound_waves.append(obs[ObsNames.TARGET_SOUND_WAVE]) + + action = self.agent.act(obs=(voc_state, generated), target=target, probabilistic=False) + action = action.cpu().squeeze(0).numpy() + obs, _, done, _ = env.step(action) + + generated_sound_waves.append(obs[ObsNames.GENERATED_SOUND_WAVE]) + generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] + + target_generated_mse += np.mean((target_np - generated_np) ** 2) + target_generated_mae += np.mean(np.abs(target_np - generated_np)) + + if done: + obs = env.reset() + generated_sound_waves.append(blank) + target_sound_waves.append(blank) + self.agent.reset() + + voc_state_np = obs[ObsNames.VOC_STATE] + generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] + target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM] + + voc_state = torch.as_tensor(voc_state_np, dtype, device).unsqueeze(0) + generated = torch.as_tensor(generated_np, dtype, device).unsqueeze(0) + target = torch.as_tensor(target_np, dtype, device).unsqueeze(0) + + target_generated_mae /= self.evaluation_steps + target_generated_mse /= self.evaluation_steps + + generated_sounds_for_log = np.concatenate(generated_sound_waves) + target_sounds_for_log = np.concatenate(target_sound_waves) + + # logging to tensorboard + generated_sounds_for_log + target_sounds_for_log + + loss_dict = { + "target_generated_mse": target_generated_mse, + "target_generated_mae": target_generated_mae, + } + + return loss_dict diff --git a/src/trainer.py b/src/trainer.py index 42a5dbd..f105066 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -1,8 +1,11 @@ import logging +from collections import OrderedDict +from typing import Any, Optional import gym import torch from torch.nn.utils import clip_grad_norm_ +from torch.optim import Optimizer from .datamodules import buffer_names from .datamodules.replay_buffer import ReplayBuffer @@ -12,6 +15,12 @@ logger = logging.getLogger(__name__) +class CheckPointNames: + MODEL = "model" + WORLD_OPTIMIZER = "world_optimizer" + CONTROLLER_OPTIMIZER = "controller_optimizer" + + class Trainer: """Trainer class for Dreamer.""" @@ -22,6 +31,11 @@ def __init__( batch_size: int = 8, chunk_size: int = 64, gradient_clip_value: float = 100.0, + evaluation_interval=10, + model_save_interval=20, + checkpoint_path: Optional[Any] = None, + device: Any = "cpu", + dtype: Any = torch.float32, ) -> None: """ Args: @@ -35,7 +49,13 @@ def fit(self, env: gym.Env, replay_buffer: ReplayBuffer, model: Dreamer) -> None """ + self.setup_model_attribute(model) + + model = model.to(self.device, self.dtype) + world_optimizer, controller_optimizer = model.configure_optimizers() + if self.checkpoint_path is not None: + self.load_checkpoint(model, world_optimizer, controller_optimizer) current_step = 0 @@ -43,6 +63,9 @@ def fit(self, env: gym.Env, replay_buffer: ReplayBuffer, model: 
Dreamer) -> None for episode in range(self.num_episode): logger.info(f"Episode {episode} is started.") + model.current_episode = episode + model.current_step = current_step + logger.debug("Collecting experiences...") replay_buffer = model.collect_experiences( env, replay_buffer, self.num_collect_experience_steps @@ -69,4 +92,60 @@ def fit(self, env: gym.Env, replay_buffer: ReplayBuffer, model: Dreamer) -> None # -- logging -- - # Training Controller model. + # ---- Training Controller model. ----- + loss_dict, experiences_dict = model.controller_training_step(experiences_dict) + + loss: torch.Tensor = loss_dict["loss"] + controller_optimizer.zero_grad() + loss.backward() + params = [] + for p in controller_optimizer.param_groups: + params += p["params"] + clip_grad_norm_(params, self.gradient_clip_value) + controller_optimizer.step() + + # logging + + if current_step % self.evaluation_interval == 0: + # ----- Evaluation steps ----- + loss_dict = model.evaluation_step(env) + + # logging + + if current_step % self.model_save_interval == 0: + self.save_checkpoint("/logs/...", model, world_optimizer, controller_optimizer) + + current_step += 1 + + self.save_checkpoint("/logs/...", model, world_optimizer, controller_optimizer) + + def setup_model_attribute(self, model: Dreamer): + """Add attribute for model training. + + Call this begin of training. + Args: + model (Dreamer): Dreamer model class. + """ + model.device = torch.device(self.device) + model.dtype = self.dtype + model.current_episode = 0 + model.current_step = 0 + + def save_checkpoint( + self, path: Any, model: Dreamer, world_optim: Optimizer, controller_optim: Optimizer + ) -> None: + """Saving checkpoint.""" + ckpt = OrderedDict() + ckpt[CheckPointNames.MODEL] = model.state_dict() + ckpt[CheckPointNames.WORLD_OPTIMIZER] = world_optim.state_dict() + ckpt[CheckPointNames.CONTROLLER_OPTIMIZER] = controller_optim.state_dict() + + torch.save(ckpt, path) + + def load_checkpoint( + self, path: Any, model: Dreamer, world_optim: Optimizer, controller_optim: Optimizer + ): + ckpt = torch.load(path, self.device) + model.load_state_dict(ckpt[CheckPointNames.MODEL]) + world_optim.load_state_dict(ckpt[CheckPointNames.WORLD_OPTIMIZER]) + controller_optim.load_state_dict(ckpt[CheckPointNames.CONTROLLER_OPTIMIZER]) From efee645303c59192811a06a1bb39eccd4a3b7167 Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Thu, 26 Jan 2023 19:22:50 +0900 Subject: [PATCH 06/28] ADD test_dreamer.py --- tests/models/test_dreamer.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 tests/models/test_dreamer.py diff --git a/tests/models/test_dreamer.py b/tests/models/test_dreamer.py new file mode 100644 index 0000000..861367c --- /dev/null +++ b/tests/models/test_dreamer.py @@ -0,0 +1,12 @@ +import torch +from src.models.dreamer import Dreamer +from ..datamodules import buffer_names +from src.datamodules.replay_buffer import ReplayBuffer +from src.env.array_voc_state import VocStateObsNames as ObsNames + + +cls = Dreamer + + +def test__init__(): + model = cls() \ No newline at end of file From 8ff4ed4aafa47d76782f4c71ad53e8e07191e0e7 Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Thu, 26 Jan 2023 23:34:02 +0900 Subject: [PATCH 07/28] ADD dummy layer to dummy classes --- tests/models/abc/dummy_classes.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/models/abc/dummy_classes.py b/tests/models/abc/dummy_classes.py index 91d2493..9f9691c 100644 --- a/tests/models/abc/dummy_classes.py +++ b/tests/models/abc/dummy_classes.py @@ -1,6 
+1,7 @@ from typing import Any import torch +from torch.nn import Linear from torch import Tensor from torch.distributions import Normal @@ -21,6 +22,7 @@ class DummyTransition(Transition): def __init__(self, hidden_shape: tuple[int], *args: Any, **kwds: Any) -> None: super().__init__() self._hidden_shape = hidden_shape + self.dmy_lyr = Linear(8, 16) def forward(self, hidden: Tensor, state: Tensor, action: Tensor) -> Tensor: return torch.randn(hidden.shape) @@ -36,6 +38,7 @@ class DummyPrior(Prior): def __init__(self, state_shape: tuple[int], *args: Any, **kwds: Any) -> None: super().__init__() self._state_shape = state_shape + self.dmy_lyr = Linear(8, 16) def forward(self, hidden: Tensor) -> Normal: shape = (hidden.size(0), *self.state_shape) @@ -57,6 +60,7 @@ def __init__( super().__init__() self._state_shape = state_shape self.embedded_obs_shape = embedded_obs_shape + self.dmy_lyr = Linear(8, 16) def embed_observation(self, obs: tuple[Tensor, Tensor]) -> Tensor: v, g = obs @@ -87,6 +91,7 @@ def __init__( super().__init__() self._voc_state_shape = voc_state_shape self._generated_sound_shape = generated_sound_shape + self.dmy_lyr = Linear(8, 16) def forward(self, hidden: Tensor, state: Tensor) -> Tensor: vs_shape = (hidden.size(0), *self._voc_state_shape) @@ -108,6 +113,7 @@ def __init__( self._action_shape = action_shape self._controller_hidden_shape = controller_hidden_shape + self.dmy_lyr = Linear(8, 16) def forward( self, From 9de156ce42aa08aba73e0a7740eb6735a61bf59e Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Fri, 27 Jan 2023 14:10:11 +0900 Subject: [PATCH 08/28] ADD test__init__, test_configure_optimizers, test_collect_experiences --- tests/models/test_dreamer.py | 94 ++++++++++++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 5 deletions(-) diff --git a/tests/models/test_dreamer.py b/tests/models/test_dreamer.py index 861367c..5e2aa42 100644 --- a/tests/models/test_dreamer.py +++ b/tests/models/test_dreamer.py @@ -1,12 +1,96 @@ +import glob +import pathlib +from functools import partial + +import numpy as np +import pytest import torch -from src.models.dreamer import Dreamer -from ..datamodules import buffer_names +from gym.spaces import Box +from pynktrombonegym.env import PynkTrombone as PT +from torch.optim import AdamW + +from src.env.array_voc_state import ARRAY_ORDER as AO_voc +from src.env.array_action import ArrayAction as AA +from pynktrombonegym.spaces import ObservationSpaceNames as OSN +from src.env.array_action import ARRAY_ORDER as AO_act from src.datamodules.replay_buffer import ReplayBuffer -from src.env.array_voc_state import VocStateObsNames as ObsNames +from src.env.array_voc_state import ArrayVocState as AVS, VSON +from src.models.dreamer import Dreamer +from tests.models.abc.dummy_classes import DummyAgent as DA +from tests.models.abc.dummy_classes import DummyController as DC +from tests.models.abc.dummy_classes import DummyObservationDecoder as DOD +from tests.models.abc.dummy_classes import DummyObservationEncoder as DOE +from tests.models.abc.dummy_classes import DummyPrior as DP +from tests.models.abc.dummy_classes import DummyTransition as DT +from tests.models.abc.dummy_classes import DummyWorld as DW +from src.datamodules import buffer_names + +target_file_path = pathlib.Path(__file__).parents[2].joinpath("data/sample_target_sounds/*.wav") +target_files = glob.glob(str(target_file_path)) +env = AA(AVS(PT(target_files))) + +hidden_shape = (16,) +ctrl_hidden_shape = (16,) +state_shape = (8,) +action_shape = (len(AO_act),) + +obs_space = 
env.observation_space +voc_stats_shape = obs_space[VSON.VOC_STATE].shape +rnn_input_shape = (8,) + +gen_sound_shape = obs_space[OSN.GENERATED_SOUND_SPECTROGRAM].shape +tgt_sound_shape = obs_space[OSN.TARGET_SOUND_SPECTROGRAM].shape +obs_shape = (24,) + +obs_enc = DOE(state_shape, obs_shape) +obs_dec = DOD( + voc_stats_shape, + gen_sound_shape, +) +prior = DP(state_shape) +trans = DT(hidden_shape) +ctrl = DC(action_shape, ctrl_hidden_shape) + +world_opt = AdamW +ctrl_opt = AdamW + -cls = Dreamer +bf_size = 32 +bf_space = { + buffer_names.ACTION: Box(-np.inf, np.inf, action_shape), + buffer_names.DONE: Box(-np.inf, np.inf, (1,)), + buffer_names.GENERATED_SOUND: Box(-np.inf,np.inf, gen_sound_shape), + buffer_names.TARGET_SOUND: Box(-np.inf, np.inf, tgt_sound_shape), + buffer_names.VOC_STATE: Box(-np.inf, np.inf, voc_stats_shape), +} +args = (trans, prior, obs_enc, obs_dec, ctrl, DW, DA, world_opt, ctrl_opt) def test__init__(): - model = cls() \ No newline at end of file + model = Dreamer(*args) + + +def test_configure_optimizers(): + model = Dreamer(*args) + opt1, opt2 = model.configure_optimizers() + + +@pytest.mark.parametrize("num_steps", [1, 2, 3]) +def test_collect_experiences(num_steps): + rb = ReplayBuffer(bf_space, bf_size) + model = Dreamer(*args, num_collect_experience_steps=num_steps) + model.collect_experiences(env, rb) + assert rb.current_index == num_steps + + +def test_world_training_step(): + pass + + +def test_controller_training_step(): + pass + + +def test_evaluation_step(): + pass From 82e1e7ab3db4f8fc9d24d7cdf6733aaabafebbdf Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Fri, 27 Jan 2023 14:12:05 +0900 Subject: [PATCH 09/28] ADD import --- tests/models/abc/dummy_classes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/abc/dummy_classes.py b/tests/models/abc/dummy_classes.py index 9f9691c..26f61b3 100644 --- a/tests/models/abc/dummy_classes.py +++ b/tests/models/abc/dummy_classes.py @@ -1,9 +1,9 @@ from typing import Any import torch -from torch.nn import Linear from torch import Tensor from torch.distributions import Normal +from torch.nn import Linear from src.models.abc.agent import Agent from src.models.abc.controller import Controller From 63910856e569b4c05d694732d17aa73069581c86 Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Fri, 27 Jan 2023 14:12:33 +0900 Subject: [PATCH 10/28] Fix as_tensor --- src/models/dreamer.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/models/dreamer.py b/src/models/dreamer.py index b7438bc..4315fc6 100644 --- a/src/models/dreamer.py +++ b/src/models/dreamer.py @@ -47,6 +47,8 @@ def __init__( imagination_horizon: int = 32, evaluation_steps: int = 44 * 60, evaluation_blank_length: int = 22050, + device: str = "cpu", + dtype: np.dtype = torch.float32, ) -> None: """ Args: @@ -64,6 +66,7 @@ def __init__( evaluation_blank_length (int): """ + super().__init__() self.transition = transition self.prior = prior self.obs_encoder = obs_encoder @@ -89,6 +92,9 @@ def __init__( self.evaluation_steps = evaluation_steps self.evaluation_blank_length = evaluation_blank_length + self.device = device + self.dtype = dtype + def configure_optimizers(self) -> tuple[Optimizer, Optimizer]: """Configure world optimizer and controller optimizer. 
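Regarding the `super().__init__()` call added above: `nn.Module` must initialize its own bookkeeping before any sub-module is assigned as an attribute, otherwise the assignment raises an AttributeError ("cannot assign module before Module.__init__() call"). A minimal reproduction, unrelated to this repo:

    import torch.nn as nn

    class Broken(nn.Module):
        def __init__(self):
            # super().__init__() is missing here
            self.layer = nn.Linear(8, 16)   # raises AttributeError at assignment time

    class Fixed(nn.Module):
        def __init__(self):
            super().__init__()              # sets up the _parameters/_modules registries first
            self.layer = nn.Linear(8, 16)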
@@ -128,9 +134,9 @@ def collect_experiences(self, env: gym.Env, replay_buffer: ReplayBuffer) -> Repl generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM] - voc_state = torch.as_tensor(voc_state_np, dtype, device).unsqueeze(0) - generated = torch.as_tensor(generated_np, dtype, device).unsqueeze(0) - target = torch.as_tensor(target_np, dtype, device).unsqueeze(0) + voc_state = torch.as_tensor(voc_state_np, dtype=dtype, device=device).unsqueeze(0) + generated = torch.as_tensor(generated_np, dtype=dtype, device=device).unsqueeze(0) + target = torch.as_tensor(target_np, dtype=dtype, device=device).unsqueeze(0) for _ in range(self.num_collect_experience_steps): action = self.agent.explore(obs=(voc_state, generated), target=target) @@ -140,7 +146,7 @@ def collect_experiences(self, env: gym.Env, replay_buffer: ReplayBuffer) -> Repl voc_state_np = obs[ObsNames.VOC_STATE] generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM] - + print(voc_state_np.shape) sample = { buffer_names.ACTION: action, buffer_names.VOC_STATE: voc_state_np, @@ -157,9 +163,9 @@ def collect_experiences(self, env: gym.Env, replay_buffer: ReplayBuffer) -> Repl generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM] - voc_state = torch.as_tensor(voc_state_np, dtype, device).squeeze(0) - generated = torch.as_tensor(generated_np, dtype, device).squeeze(0) - target = torch.as_tensor(target_np, dtype, device).squeeze(0) + voc_state = torch.as_tensor(voc_state_np, dtype=dtype, device=device).squeeze(0) + generated = torch.as_tensor(generated_np, dtype=dtype, device=device).squeeze(0) + target = torch.as_tensor(target_np, dtype=dtype, device=device).squeeze(0) def world_training_step( self, experiences: dict[str, np.ndarray] From c8c25cf1ceb5aec0fd25ae913abee70e3bcb0252 Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Fri, 27 Jan 2023 14:50:37 +0900 Subject: [PATCH 11/28] ADD test_world_training_step --- tests/models/test_dreamer.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/tests/models/test_dreamer.py b/tests/models/test_dreamer.py index 5e2aa42..0d798c1 100644 --- a/tests/models/test_dreamer.py +++ b/tests/models/test_dreamer.py @@ -7,14 +7,16 @@ import torch from gym.spaces import Box from pynktrombonegym.env import PynkTrombone as PT +from pynktrombonegym.spaces import ObservationSpaceNames as OSN from torch.optim import AdamW -from src.env.array_voc_state import ARRAY_ORDER as AO_voc -from src.env.array_action import ArrayAction as AA -from pynktrombonegym.spaces import ObservationSpaceNames as OSN -from src.env.array_action import ARRAY_ORDER as AO_act +from src.datamodules import buffer_names from src.datamodules.replay_buffer import ReplayBuffer -from src.env.array_voc_state import ArrayVocState as AVS, VSON +from src.env.array_action import ARRAY_ORDER as AO_act +from src.env.array_action import ArrayAction as AA +from src.env.array_voc_state import ARRAY_ORDER as AO_voc +from src.env.array_voc_state import VSON +from src.env.array_voc_state import ArrayVocState as AVS from src.models.dreamer import Dreamer from tests.models.abc.dummy_classes import DummyAgent as DA from tests.models.abc.dummy_classes import DummyController as DC @@ -23,7 +25,6 @@ from tests.models.abc.dummy_classes import DummyPrior as DP from tests.models.abc.dummy_classes import DummyTransition as DT from tests.models.abc.dummy_classes import 
DummyWorld as DW -from src.datamodules import buffer_names target_file_path = pathlib.Path(__file__).parents[2].joinpath("data/sample_target_sounds/*.wav") target_files = glob.glob(str(target_file_path)) @@ -55,12 +56,11 @@ ctrl_opt = AdamW - bf_size = 32 bf_space = { buffer_names.ACTION: Box(-np.inf, np.inf, action_shape), buffer_names.DONE: Box(-np.inf, np.inf, (1,)), - buffer_names.GENERATED_SOUND: Box(-np.inf,np.inf, gen_sound_shape), + buffer_names.GENERATED_SOUND: Box(-np.inf, np.inf, gen_sound_shape), buffer_names.TARGET_SOUND: Box(-np.inf, np.inf, tgt_sound_shape), buffer_names.VOC_STATE: Box(-np.inf, np.inf, voc_stats_shape), } @@ -85,7 +85,18 @@ def test_collect_experiences(num_steps): def test_world_training_step(): - pass + model = Dreamer(*args, num_collect_experience_steps=128) + rb = ReplayBuffer(bf_space, bf_size) + rb = model.collect_experiences(env, rb) + experience = rb.sample(1, chunk_length=16) + print(rb.current_index) + loss_dict, experience = model.world_training_step(experience) + assert experience.get("hiddens") is not None + assert experience.get("states") is not None + + model.world_optimizer.zero_grad() + loss_dict["loss"].backward() + model.world_optimizer.step() def test_controller_training_step(): From 1dd226188aa4a1cac1f485a52ec0dc1fe4dd14f4 Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Fri, 27 Jan 2023 14:51:49 +0900 Subject: [PATCH 12/28] Removed print() for debug --- src/models/dreamer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/models/dreamer.py b/src/models/dreamer.py index 4315fc6..46b66f8 100644 --- a/src/models/dreamer.py +++ b/src/models/dreamer.py @@ -146,7 +146,6 @@ def collect_experiences(self, env: gym.Env, replay_buffer: ReplayBuffer) -> Repl voc_state_np = obs[ObsNames.VOC_STATE] generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM] - print(voc_state_np.shape) sample = { buffer_names.ACTION: action, buffer_names.VOC_STATE: voc_state_np, From 961b1e0eb99c72fbe93f430a5cbde074a7eeeb4c Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Fri, 27 Jan 2023 15:04:21 +0900 Subject: [PATCH 13/28] Fix as_tensor arguments --- src/models/dreamer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/models/dreamer.py b/src/models/dreamer.py index 46b66f8..bd0d1a6 100644 --- a/src/models/dreamer.py +++ b/src/models/dreamer.py @@ -204,10 +204,10 @@ def world_training_step( all_kl_div_loss = 0.0 for idx in range(chunk_size): - action = torch.as_tensor(actions[idx], dtype, device) + action = torch.as_tensor(actions[idx], dtype=dtype, device=device) - voc_stat = torch.as_tensor(voc_states[idx], dtype, device) - gened_sound = torch.as_tensor(generated_sounds[idx], dtype, device) + voc_stat = torch.as_tensor(voc_states[idx], dtype=dtype, device=device) + gened_sound = torch.as_tensor(generated_sounds[idx], dtype=dtype, device=device) next_obs = (voc_stat, gened_sound) next_state_prior, next_state_posterior, next_hidden = self.world.forward( From e2aef99291f61abaa0c8e3e1a26b85380af8243c Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Fri, 27 Jan 2023 15:34:50 +0900 Subject: [PATCH 14/28] Fix test_world_training_step --- tests/models/test_dreamer.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/models/test_dreamer.py b/tests/models/test_dreamer.py index 0d798c1..af82c5a 100644 --- a/tests/models/test_dreamer.py +++ b/tests/models/test_dreamer.py @@ -87,17 +87,14 @@ def test_collect_experiences(num_steps): def test_world_training_step(): 
model = Dreamer(*args, num_collect_experience_steps=128) rb = ReplayBuffer(bf_space, bf_size) - rb = model.collect_experiences(env, rb) + _, __ = model.configure_optimizers() + model.collect_experiences(env, rb) experience = rb.sample(1, chunk_length=16) - print(rb.current_index) + print(rb.is_capacity_reached) loss_dict, experience = model.world_training_step(experience) assert experience.get("hiddens") is not None assert experience.get("states") is not None - model.world_optimizer.zero_grad() - loss_dict["loss"].backward() - model.world_optimizer.step() - def test_controller_training_step(): pass From bb4971709b2bfa264bb17aa7b71dba88da1db25c Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Fri, 27 Jan 2023 15:41:16 +0900 Subject: [PATCH 15/28] Fix loss computation --- src/models/dreamer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/models/dreamer.py b/src/models/dreamer.py index bd0d1a6..75b9c26 100644 --- a/src/models/dreamer.py +++ b/src/models/dreamer.py @@ -225,7 +225,6 @@ def world_training_step( batch_size, -1 ) all_kl_div_loss += kl_div_loss.sum(-1).mean() - rec_voc_state_loss += F.mse_loss(voc_stat, rec_voc_stat) rec_generated_sound_loss += F.mse_loss(gened_sound, rec_gened_sound) @@ -238,10 +237,9 @@ def world_training_step( rec_voc_state_loss /= chunk_size rec_generated_sound_loss /= chunk_size - kl_div_loss /= chunk_size + all_kl_div_loss /= chunk_size rec_loss = rec_voc_state_loss + rec_generated_sound_loss - - loss = rec_loss + (not kl_div_loss.item() < self.free_nats) * kl_div_loss + loss = rec_loss + (not all_kl_div_loss.item() < self.free_nats) * all_kl_div_loss loss_dict = { "loss": loss, @@ -249,9 +247,9 @@ def world_training_step( "rec_voc_state_loss": rec_voc_state_loss, "rec_generated_sound_loss": rec_generated_sound_loss, "kl_div_loss": kl_div_loss, - "over_free_nat": not kl_div_loss.item() < self.free_nats, + "over_free_nat": not all_kl_div_loss.item() < self.free_nats, } - + experiences["hiddens"] = all_hiddens experiences["states"] = all_states From 9cdca3919245f7c675998ab9f33263fefc0b5c49 Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Fri, 27 Jan 2023 18:12:22 +0900 Subject: [PATCH 16/28] Fix the order of wrapping env --- tests/models/test_dreamer.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/models/test_dreamer.py b/tests/models/test_dreamer.py index af82c5a..719cce1 100644 --- a/tests/models/test_dreamer.py +++ b/tests/models/test_dreamer.py @@ -6,14 +6,18 @@ import pytest import torch from gym.spaces import Box -from pynktrombonegym.env import PynkTrombone as PT +from pynktrombonegym.wrappers import Log1pMelSpectrogram as L1MS +from pynktrombonegym.wrappers import ActionByAcceleration as ABA +from src.env.normalize_action_range import NormalizeActionRange as NAR +from src.env.array_action import ArrayAction as AA +from src.env.array_voc_state import ArrayVocState as AVS from pynktrombonegym.spaces import ObservationSpaceNames as OSN from torch.optim import AdamW from src.datamodules import buffer_names from src.datamodules.replay_buffer import ReplayBuffer from src.env.array_action import ARRAY_ORDER as AO_act -from src.env.array_action import ArrayAction as AA + from src.env.array_voc_state import ARRAY_ORDER as AO_voc from src.env.array_voc_state import VSON from src.env.array_voc_state import ArrayVocState as AVS @@ -28,7 +32,7 @@ target_file_path = pathlib.Path(__file__).parents[2].joinpath("data/sample_target_sounds/*.wav") target_files = glob.glob(str(target_file_path)) -env 
= AA(AVS(PT(target_files))) +env = AVS(AA(NAR(ABA(L1MS(target_files), action_scaler=1.0)))) hidden_shape = (16,) ctrl_hidden_shape = (16,) From 1440aefe2abb3f78e5fd1ae09719f320b8e034ff Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Sat, 28 Jan 2023 10:08:42 +0900 Subject: [PATCH 17/28] Fix optimizer and initializing instances --- tests/models/test_dreamer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/models/test_dreamer.py b/tests/models/test_dreamer.py index 719cce1..1546c0e 100644 --- a/tests/models/test_dreamer.py +++ b/tests/models/test_dreamer.py @@ -12,7 +12,7 @@ from src.env.array_action import ArrayAction as AA from src.env.array_voc_state import ArrayVocState as AVS from pynktrombonegym.spaces import ObservationSpaceNames as OSN -from torch.optim import AdamW +from torch.optim import SGD from src.datamodules import buffer_names from src.datamodules.replay_buffer import ReplayBuffer @@ -56,8 +56,11 @@ trans = DT(hidden_shape) ctrl = DC(action_shape, ctrl_hidden_shape) -world_opt = AdamW -ctrl_opt = AdamW +d_world = partial(DW()) +d_agent = partial(DA()) + +world_opt = partial(SGD()) +ctrl_opt = partial(SGD()) bf_size = 32 @@ -68,7 +71,7 @@ buffer_names.TARGET_SOUND: Box(-np.inf, np.inf, tgt_sound_shape), buffer_names.VOC_STATE: Box(-np.inf, np.inf, voc_stats_shape), } -args = (trans, prior, obs_enc, obs_dec, ctrl, DW, DA, world_opt, ctrl_opt) +args = (trans, prior, obs_enc, obs_dec, ctrl, d_world, d_agent, world_opt, ctrl_opt) def test__init__(): From 85b596a800b9ba1b8482e09d7e90cb65443348eb Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Sat, 28 Jan 2023 11:04:18 +0900 Subject: [PATCH 18/28] Fix device, dtype --- src/models/dreamer.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/models/dreamer.py b/src/models/dreamer.py index 75b9c26..68a4945 100644 --- a/src/models/dreamer.py +++ b/src/models/dreamer.py @@ -28,8 +28,8 @@ class Dreamer(nn.Module): # Added attribute from Trainer current_step: int = 0 current_episode: int = 0 - device: torch.device - dtype: torch.dtype + device: torch.device = "cpu" + dtype: torch.dtype = torch.float32 def __init__( self, @@ -47,8 +47,6 @@ def __init__( imagination_horizon: int = 32, evaluation_steps: int = 44 * 60, evaluation_blank_length: int = 22050, - device: str = "cpu", - dtype: np.dtype = torch.float32, ) -> None: """ Args: @@ -92,8 +90,6 @@ def __init__( self.evaluation_steps = evaluation_steps self.evaluation_blank_length = evaluation_blank_length - self.device = device - self.dtype = dtype def configure_optimizers(self) -> tuple[Optimizer, Optimizer]: """Configure world optimizer and controller optimizer. 
@@ -279,10 +275,9 @@ def controller_training_step( old_hiddens = experiences["hiddens"] chunk_size, batch_size = actions.shape[:2] - start_indices = np.random.randint(0, chunk_size - self.imagination_horizon, (batch_size,)) batch_arange = np.arange(batch_size) - hidden = torch.as_tensor(old_hiddens[start_indices, batch_arange], dtype, device) + hidden = torch.as_tensor(old_hiddens[start_indices, batch_arange], dtype=dtype, device=device) controller_hidden = torch.zeros( batch_size, *self.controller.controller_hidden_shape, dtype=dtype, device=device ) @@ -291,9 +286,9 @@ def controller_training_step( loss = 0.0 for i in range(self.imagination_horizon): indices = start_indices + i - target = torch.as_tensor(target_sounds[indices, batch_arange], dtype, device) + target = torch.as_tensor(target_sounds[indices, batch_arange], dtype=dtype, device=device) action, controller_hidden = self.controller.forward( - hidden, state, target, controller_hidden + hidden, state, target, controller_hidden, probabilistic=True ) next_hidden = self.transition.forward(hidden, state, action) next_state = self.prior.forward(next_hidden).sample() From b3383723b7eddc740576d5e0827fdb3fc6cad9ba Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Sat, 28 Jan 2023 11:12:44 +0900 Subject: [PATCH 19/28] ADD test_controller_training_step --- tests/models/test_dreamer.py | 39 ++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/tests/models/test_dreamer.py b/tests/models/test_dreamer.py index 1546c0e..2d6d99a 100644 --- a/tests/models/test_dreamer.py +++ b/tests/models/test_dreamer.py @@ -56,11 +56,12 @@ trans = DT(hidden_shape) ctrl = DC(action_shape, ctrl_hidden_shape) -d_world = partial(DW()) -d_agent = partial(DA()) -world_opt = partial(SGD()) -ctrl_opt = partial(SGD()) +d_world = partial(DW) +d_agent = partial(DA) + +world_opt = partial(SGD, lr=1e-3) +ctrl_opt = partial(SGD, lr=1e-3) bf_size = 32 @@ -72,6 +73,15 @@ buffer_names.VOC_STATE: Box(-np.inf, np.inf, voc_stats_shape), } args = (trans, prior, obs_enc, obs_dec, ctrl, d_world, d_agent, world_opt, ctrl_opt) +del env + +def world_training_step(model, env): + rb = ReplayBuffer(bf_space, bf_size) + _, __ = model.configure_optimizers() + model.collect_experiences(env, rb) + experience = rb.sample(1, chunk_length=16) + loss_dict, experience = model.world_training_step(experience) + return loss_dict, experience def test__init__(): @@ -85,27 +95,34 @@ def test_configure_optimizers(): @pytest.mark.parametrize("num_steps", [1, 2, 3]) def test_collect_experiences(num_steps): + env = AVS(AA(NAR(ABA(L1MS(target_files), action_scaler=1.0)))) rb = ReplayBuffer(bf_space, bf_size) model = Dreamer(*args, num_collect_experience_steps=num_steps) model.collect_experiences(env, rb) assert rb.current_index == num_steps + del env def test_world_training_step(): + env = AVS(AA(NAR(ABA(L1MS(target_files), action_scaler=1.0)))) model = Dreamer(*args, num_collect_experience_steps=128) - rb = ReplayBuffer(bf_space, bf_size) - _, __ = model.configure_optimizers() - model.collect_experiences(env, rb) - experience = rb.sample(1, chunk_length=16) - print(rb.is_capacity_reached) - loss_dict, experience = model.world_training_step(experience) + loss_dict, experience = world_training_step(model, env) assert experience.get("hiddens") is not None assert experience.get("states") is not None + assert loss_dict.get("loss") is not None + del env def test_controller_training_step(): - pass + # World Training Step + env = AVS(AA(NAR(ABA(L1MS(target_files), 
action_scaler=1.0)))) + model = Dreamer(*args, imagination_horizon=4) + _, experience = world_training_step(model, env) + loss_dict, _ = model.controller_training_step(experience) + assert loss_dict.get("loss") is not None + del env def test_evaluation_step(): pass + From 35100ddd7b26f9b872107a733577e487829bc7a6 Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Sat, 28 Jan 2023 11:36:09 +0900 Subject: [PATCH 20/28] Fix device, dtype --- src/models/dreamer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/models/dreamer.py b/src/models/dreamer.py index 68a4945..03474cd 100644 --- a/src/models/dreamer.py +++ b/src/models/dreamer.py @@ -330,9 +330,9 @@ def evaluation_step(self, env: gym.Env) -> dict[str, Any]: generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM] - voc_state = torch.as_tensor(voc_state_np, dtype, device).unsqueeze(0) - generated = torch.as_tensor(generated_np, dtype, device).unsqueeze(0) - target = torch.as_tensor(target_np, dtype, device).unsqueeze(0) + voc_state = torch.as_tensor(voc_state_np, dtype=dtype, device=device).unsqueeze(0) + generated = torch.as_tensor(generated_np, dtype=dtype, device=device).unsqueeze(0) + target = torch.as_tensor(target_np, dtype=dtype, device=device).unsqueeze(0) generated_sound_waves = [] target_sound_waves = [] @@ -365,9 +365,9 @@ def evaluation_step(self, env: gym.Env) -> dict[str, Any]: generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM] - voc_state = torch.as_tensor(voc_state_np, dtype, device).unsqueeze(0) - generated = torch.as_tensor(generated_np, dtype, device).unsqueeze(0) - target = torch.as_tensor(target_np, dtype, device).unsqueeze(0) + voc_state = torch.as_tensor(voc_state_np, dtype=dtype, device=device).unsqueeze(0) + generated = torch.as_tensor(generated_np, dtype=dtype, device=device).unsqueeze(0) + target = torch.as_tensor(target_np, dtype=dtype, device=device).unsqueeze(0) target_generated_mae /= self.evaluation_steps target_generated_mse /= self.evaluation_steps From b6814b4188fe8514cf4f1431508873dbfeeb613d Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Sat, 28 Jan 2023 12:18:25 +0900 Subject: [PATCH 21/28] ADD configure_replay_buffer --- src/models/dreamer.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/models/dreamer.py b/src/models/dreamer.py index 03474cd..17cc550 100644 --- a/src/models/dreamer.py +++ b/src/models/dreamer.py @@ -3,6 +3,7 @@ from typing import Any import gym +from gym.spaces import Box import numpy as np import torch import torch.nn as nn @@ -20,6 +21,7 @@ from .abc.prior import Prior from .abc.transition import Transition from .abc.world import World +from ..env.array_voc_state import VocStateObsNames as VSON class Dreamer(nn.Module): @@ -110,6 +112,22 @@ def configure_optimizers(self) -> tuple[Optimizer, Optimizer]: return [world_optim, con_optim] + def configure_replay_buffer(self, env: gym.Env, buffer_size:int): + act_sp = env.action_space + voc_st_sp = env.observation_space[VSON.VOC_STATE] + tgt_sound_sp = env.observation_space[VSON.TARGET_SOUND_WAVE] + gen_sound_sp = env.observation_space[VSON.GENERATED_SOUND_WAVE] + spaces = {} + spaces[buffer_names.ACTION] = Box(act_sp.low, act_sp.high, act_sp.shape, act_sp.dtype) + spaces[buffer_names.VOC_STATE] = Box(voc_st_sp.low, voc_st_sp.high, voc_st_sp.shape, voc_st_sp.dtype) + spaces[buffer_names.GENERATED_SOUND] = Box(tgt_sound_sp.low, tgt_sound_sp.high, shape=tgt_sound_sp.shape, 
dtype=tgt_sound_sp.dtype) + spaces[buffer_names.TARGET_SOUND] =Box(gen_sound_sp.low, gen_sound_sp.high, shape=gen_sound_sp.shape, dtype=gen_sound_sp.dtype) + spaces[buffer_names.DONE] = Box(0, 1, shape=(1,), dtype=bool) + + replay_buffer = ReplayBuffer(spaces, buffer_size) + + return replay_buffer + @torch.no_grad() def collect_experiences(self, env: gym.Env, replay_buffer: ReplayBuffer) -> ReplayBuffer: """Explorer in env and collect experiences to replay buffer. From a81d124c993bff64e815563ce9e4fc45e37b48f8 Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Sat, 28 Jan 2023 12:23:55 +0900 Subject: [PATCH 22/28] ADD test_evaluation_step --- tests/models/test_dreamer.py | 38 +++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/tests/models/test_dreamer.py b/tests/models/test_dreamer.py index 2d6d99a..c641f9d 100644 --- a/tests/models/test_dreamer.py +++ b/tests/models/test_dreamer.py @@ -82,11 +82,38 @@ def world_training_step(model, env): experience = rb.sample(1, chunk_length=16) loss_dict, experience = model.world_training_step(experience) return loss_dict, experience - +def _hasattrs(model): + attributes = ( + "transition", + "prior", + "obs_encoder", + "obs_decoder", + "controller", + "world", + "agent", + "world_optimizer", + "controller_optimizer", + "free_nats", + "num_collect_experience_steps", + "imagination_horizon", + "evaluation_steps", + "evaluation_blank_length", + ) + have_attr = [] + for attr in attributes: + have_attr += [hasattr(model, attr)] + print(hasattr(model, attr)) + return have_attr, attributes def test__init__(): model = Dreamer(*args) - + has_attrs, attrs = _hasattrs(model) + for idx, has_attr_flag in enumerate(has_attrs): + if has_attr_flag: + pass + else: + assert False, f"attribute {attrs[idx]} doesn't set" + def test_configure_optimizers(): model = Dreamer(*args) @@ -124,5 +151,10 @@ def test_controller_training_step(): def test_evaluation_step(): - pass + env = AVS(AA(NAR(ABA(L1MS(target_files), action_scaler=1.0)))) + model = Dreamer(*args) + loss_dict = model.evaluation_step(env) + assert loss_dict.get("target_generated_mse") is not None + assert loss_dict.get("target_generated_mae") is not None + del env From c85a75c8516d436d8cd196a6362b7ade5604dd5f Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Sat, 28 Jan 2023 12:36:30 +0900 Subject: [PATCH 23/28] ADD test_configure_replay_buffer --- tests/models/test_dreamer.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/models/test_dreamer.py b/tests/models/test_dreamer.py index c641f9d..c7e4e32 100644 --- a/tests/models/test_dreamer.py +++ b/tests/models/test_dreamer.py @@ -158,3 +158,17 @@ def test_evaluation_step(): assert loss_dict.get("target_generated_mae") is not None del env +def test_configure_replay_buffer(): + env = AVS(AA(NAR(ABA(L1MS(target_files), action_scaler=1.0)))) + model = Dreamer(*args) + rb = model.configure_replay_buffer(env, bf_size) + spaces = { + buffer_names.VOC_STATE: env.observation_space[VSON.VOC_STATE], + buffer_names.ACTION : env.action_space, + buffer_names.TARGET_SOUND: env.observation_space[VSON.TARGET_SOUND_WAVE], + buffer_names.GENERATED_SOUND: env.observation_space[VSON.GENERATED_SOUND_WAVE], + buffer_names.DONE: Box(0, 1, shape=(1,), dtype=bool) + } + for name, box in rb.spaces.items(): + assert box == spaces.get(name), f"space {name} isn't set properly(box:{box}" + del env \ No newline at end of file From 1c97d61bb43926f557df81323058665824853eb4 Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Sat, 28 
Jan 2023 12:37:17 +0900
Subject: [PATCH 24/28] Fix configure_replay_buffer

---
 src/models/dreamer.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/models/dreamer.py b/src/models/dreamer.py
index 17cc550..f28e8d3 100644
--- a/src/models/dreamer.py
+++ b/src/models/dreamer.py
@@ -113,15 +113,15 @@ def configure_optimizers(self) -> tuple[Optimizer, Optimizer]:
         return [world_optim, con_optim]
 
     def configure_replay_buffer(self, env: gym.Env, buffer_size:int):
-        act_sp = env.action_space
-        voc_st_sp = env.observation_space[VSON.VOC_STATE]
-        tgt_sound_sp = env.observation_space[VSON.TARGET_SOUND_WAVE]
-        gen_sound_sp = env.observation_space[VSON.GENERATED_SOUND_WAVE]
+        action_box = env.action_space
+        vocal_state_box = env.observation_space[VSON.VOC_STATE]
+        target_sound_box = env.observation_space[VSON.TARGET_SOUND_WAVE]
+        generated_sound_box = env.observation_space[VSON.GENERATED_SOUND_WAVE]
         spaces = {}
-        spaces[buffer_names.ACTION] = Box(act_sp.low, act_sp.high, act_sp.shape, act_sp.dtype)
-        spaces[buffer_names.VOC_STATE] = Box(voc_st_sp.low, voc_st_sp.high, voc_st_sp.shape, voc_st_sp.dtype)
-        spaces[buffer_names.GENERATED_SOUND] = Box(tgt_sound_sp.low, tgt_sound_sp.high, shape=tgt_sound_sp.shape, dtype=tgt_sound_sp.dtype)
-        spaces[buffer_names.TARGET_SOUND] =Box(gen_sound_sp.low, gen_sound_sp.high, shape=gen_sound_sp.shape, dtype=gen_sound_sp.dtype)
+        spaces[buffer_names.ACTION] = action_box
+        spaces[buffer_names.VOC_STATE] = vocal_state_box
+        spaces[buffer_names.GENERATED_SOUND] = generated_sound_box
+        spaces[buffer_names.TARGET_SOUND] = target_sound_box
         spaces[buffer_names.DONE] = Box(0, 1, shape=(1,), dtype=bool)
 
         replay_buffer = ReplayBuffer(spaces, buffer_size)

From 6af48f3fc0202a186f0d6e8ee239f837c7d653d3 Mon Sep 17 00:00:00 2001
From: cehl-kurage
Date: Sat, 28 Jan 2023 12:37:55 +0900
Subject: [PATCH 25/28] pre-commit

---
 src/models/dreamer.py        | 19 +++++++++++--------
 tests/models/test_dreamer.py | 23 +++++++++++++----------
 2 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/src/models/dreamer.py b/src/models/dreamer.py
index f28e8d3..3d8a68c 100644
--- a/src/models/dreamer.py
+++ b/src/models/dreamer.py
@@ -3,11 +3,11 @@
 from typing import Any
 
 import gym
-from gym.spaces import Box
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from gym.spaces import Box
 from torch import Tensor
 from torch.distributions import kl_divergence
 from torch.optim import Optimizer
@@ -15,13 +15,13 @@
 from ..datamodules import buffer_names
 from ..datamodules.replay_buffer import ReplayBuffer
 from ..env.array_voc_state import VocStateObsNames as ObsNames
+from ..env.array_voc_state import VocStateObsNames as VSON
 from .abc.agent import Agent
 from .abc.controller import Controller
 from .abc.observation_auto_encoder import ObservationDecoder, ObservationEncoder
 from .abc.prior import Prior
 from .abc.transition import Transition
 from .abc.world import World
-from ..env.array_voc_state import VocStateObsNames as VSON
 
 
 class Dreamer(nn.Module):
@@ -92,7 +92,6 @@ def __init__(
 
         self.evaluation_steps = evaluation_steps
         self.evaluation_blank_length = evaluation_blank_length
-
 
     def configure_optimizers(self) -> tuple[Optimizer, Optimizer]:
         """Configure world optimizer and controller optimizer.
@@ -112,7 +111,7 @@ def configure_optimizers(self) -> tuple[Optimizer, Optimizer]: return [world_optim, con_optim] - def configure_replay_buffer(self, env: gym.Env, buffer_size:int): + def configure_replay_buffer(self, env: gym.Env, buffer_size: int): action_box = env.action_space vocal_state_box = env.observation_space[VSON.VOC_STATE] target_sound_box = env.observation_space[VSON.TARGET_SOUND_WAVE] @@ -125,7 +124,7 @@ def configure_replay_buffer(self, env: gym.Env, buffer_size:int): spaces[buffer_names.DONE] = Box(0, 1, shape=(1,), dtype=bool) replay_buffer = ReplayBuffer(spaces, buffer_size) - + return replay_buffer @torch.no_grad() @@ -263,7 +262,7 @@ def world_training_step( "kl_div_loss": kl_div_loss, "over_free_nat": not all_kl_div_loss.item() < self.free_nats, } - + experiences["hiddens"] = all_hiddens experiences["states"] = all_states @@ -295,7 +294,9 @@ def controller_training_step( chunk_size, batch_size = actions.shape[:2] start_indices = np.random.randint(0, chunk_size - self.imagination_horizon, (batch_size,)) batch_arange = np.arange(batch_size) - hidden = torch.as_tensor(old_hiddens[start_indices, batch_arange], dtype=dtype, device=device) + hidden = torch.as_tensor( + old_hiddens[start_indices, batch_arange], dtype=dtype, device=device + ) controller_hidden = torch.zeros( batch_size, *self.controller.controller_hidden_shape, dtype=dtype, device=device ) @@ -304,7 +305,9 @@ def controller_training_step( loss = 0.0 for i in range(self.imagination_horizon): indices = start_indices + i - target = torch.as_tensor(target_sounds[indices, batch_arange], dtype=dtype, device=device) + target = torch.as_tensor( + target_sounds[indices, batch_arange], dtype=dtype, device=device + ) action, controller_hidden = self.controller.forward( hidden, state, target, controller_hidden, probabilistic=True ) diff --git a/tests/models/test_dreamer.py b/tests/models/test_dreamer.py index c7e4e32..e0ffbe6 100644 --- a/tests/models/test_dreamer.py +++ b/tests/models/test_dreamer.py @@ -6,21 +6,19 @@ import pytest import torch from gym.spaces import Box -from pynktrombonegym.wrappers import Log1pMelSpectrogram as L1MS -from pynktrombonegym.wrappers import ActionByAcceleration as ABA -from src.env.normalize_action_range import NormalizeActionRange as NAR -from src.env.array_action import ArrayAction as AA -from src.env.array_voc_state import ArrayVocState as AVS from pynktrombonegym.spaces import ObservationSpaceNames as OSN +from pynktrombonegym.wrappers import ActionByAcceleration as ABA +from pynktrombonegym.wrappers import Log1pMelSpectrogram as L1MS from torch.optim import SGD from src.datamodules import buffer_names from src.datamodules.replay_buffer import ReplayBuffer from src.env.array_action import ARRAY_ORDER as AO_act - +from src.env.array_action import ArrayAction as AA from src.env.array_voc_state import ARRAY_ORDER as AO_voc from src.env.array_voc_state import VSON from src.env.array_voc_state import ArrayVocState as AVS +from src.env.normalize_action_range import NormalizeActionRange as NAR from src.models.dreamer import Dreamer from tests.models.abc.dummy_classes import DummyAgent as DA from tests.models.abc.dummy_classes import DummyController as DC @@ -75,6 +73,7 @@ args = (trans, prior, obs_enc, obs_dec, ctrl, d_world, d_agent, world_opt, ctrl_opt) del env + def world_training_step(model, env): rb = ReplayBuffer(bf_space, bf_size) _, __ = model.configure_optimizers() @@ -82,6 +81,8 @@ def world_training_step(model, env): experience = rb.sample(1, chunk_length=16) loss_dict, 
experience = model.world_training_step(experience) return loss_dict, experience + + def _hasattrs(model): attributes = ( "transition", @@ -105,6 +106,7 @@ def _hasattrs(model): print(hasattr(model, attr)) return have_attr, attributes + def test__init__(): model = Dreamer(*args) has_attrs, attrs = _hasattrs(model) @@ -113,7 +115,7 @@ def test__init__(): pass else: assert False, f"attribute {attrs[idx]} doesn't set" - + def test_configure_optimizers(): model = Dreamer(*args) @@ -158,17 +160,18 @@ def test_evaluation_step(): assert loss_dict.get("target_generated_mae") is not None del env + def test_configure_replay_buffer(): env = AVS(AA(NAR(ABA(L1MS(target_files), action_scaler=1.0)))) model = Dreamer(*args) rb = model.configure_replay_buffer(env, bf_size) spaces = { buffer_names.VOC_STATE: env.observation_space[VSON.VOC_STATE], - buffer_names.ACTION : env.action_space, + buffer_names.ACTION: env.action_space, buffer_names.TARGET_SOUND: env.observation_space[VSON.TARGET_SOUND_WAVE], buffer_names.GENERATED_SOUND: env.observation_space[VSON.GENERATED_SOUND_WAVE], - buffer_names.DONE: Box(0, 1, shape=(1,), dtype=bool) + buffer_names.DONE: Box(0, 1, shape=(1,), dtype=bool), } for name, box in rb.spaces.items(): assert box == spaces.get(name), f"space {name} isn't set properly(box:{box}" - del env \ No newline at end of file + del env From 83f26f4aa33cd2ff225f4de7c7bffbfa6b9dfd9b Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Sat, 28 Jan 2023 13:00:53 +0900 Subject: [PATCH 26/28] ADD docstring for configure_replay_buffer, __init__ and removed some unnecessary docstrings --- src/models/dreamer.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/models/dreamer.py b/src/models/dreamer.py index 3d8a68c..9bd3a5c 100644 --- a/src/models/dreamer.py +++ b/src/models/dreamer.py @@ -111,7 +111,16 @@ def configure_optimizers(self) -> tuple[Optimizer, Optimizer]: return [world_optim, con_optim] - def configure_replay_buffer(self, env: gym.Env, buffer_size: int): + def configure_replay_buffer(self, env: gym.Env, buffer_size: int) -> ReplayBuffer: + """Configure replay buffer to store experiences. + + Args: + env (gym.Env): PynkTrombone environment or its wrapper class. + buffer_size (int): Max length of experiences you can store. + + Returns: + ReplayBuffer: Replay buffer that can store experiences. + """ action_box = env.action_space vocal_state_box = env.observation_space[VSON.VOC_STATE] target_sound_box = env.observation_space[VSON.TARGET_SOUND_WAVE] @@ -134,7 +143,6 @@ def collect_experiences(self, env: gym.Env, replay_buffer: ReplayBuffer) -> Repl Args: env (gym.Env): PynkTrombone environment or its wrapper class. replay_buffer (ReplayBuffer): Storing experiences. - num_steps (int): How much experiences to store. Returns: replay_buffer(ReplayBuffer): Same pointer of input replay_buffer. From c4a32f97c1b0a8195c9295f8ebabde46982ac8ba Mon Sep 17 00:00:00 2001 From: cehl-kurage Date: Sat, 28 Jan 2023 13:01:26 +0900 Subject: [PATCH 27/28] ADD docstring for configure_replay_buffer, __init__ and removed some unnecessary docstrings --- src/models/dreamer.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/models/dreamer.py b/src/models/dreamer.py index 3d8a68c..a745b0d 100644 --- a/src/models/dreamer.py +++ b/src/models/dreamer.py @@ -63,7 +63,10 @@ def __init__( controller_optimizer (partial[Optimizer]): Partial instance of Optimizer class. free_nats (float): Ignore kl div loss when it is less then this value. 
-            evaluation_blank_length (int):
+            num_collect_experience_steps (int): Number of steps for which experiences are collected and stored.
+            imagination_horizon (int): Number of state transitions the controller imagines for learning.
+            evaluation_steps (int): Number of evaluation steps to run.
+            evaluation_blank_length (int): The blank length of the generated/target sounds.
         """
 
         super().__init__()
@@ -111,7 +114,16 @@ def configure_optimizers(self) -> tuple[Optimizer, Optimizer]:
 
         return [world_optim, con_optim]
 
-    def configure_replay_buffer(self, env: gym.Env, buffer_size: int):
+    def configure_replay_buffer(self, env: gym.Env, buffer_size: int) -> ReplayBuffer:
+        """Configure replay buffer to store experiences.
+
+        Args:
+            env (gym.Env): PynkTrombone environment or its wrapper class.
+            buffer_size (int): Max length of experiences you can store.
+
+        Returns:
+            ReplayBuffer: Replay buffer that can store experiences.
+        """
         action_box = env.action_space
         vocal_state_box = env.observation_space[VSON.VOC_STATE]
         target_sound_box = env.observation_space[VSON.TARGET_SOUND_WAVE]
@@ -134,7 +146,6 @@ def collect_experiences(self, env: gym.Env, replay_buffer: ReplayBuffer) -> Repl
         Args:
             env (gym.Env): PynkTrombone environment or its wrapper class.
             replay_buffer (ReplayBuffer): Storing experiences.
-            num_steps (int): How much experiences to store.
 
         Returns:
             replay_buffer(ReplayBuffer): Same pointer of input replay_buffer.

From 722cef98b13f6aebadcd6e6b3bf008a81cd3b109 Mon Sep 17 00:00:00 2001
From: cehl-kurage
Date: Sun, 29 Jan 2023 10:20:22 +0900
Subject: [PATCH 28/28] Fix attributes of Dreamer and type annotation in World

---
 src/models/abc/world.py | 2 +-
 src/models/dreamer.py   | 7 ++++---
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/models/abc/world.py b/src/models/abc/world.py
index 226ae0c..9efc2ae 100644
--- a/src/models/abc/world.py
+++ b/src/models/abc/world.py
@@ -45,7 +45,7 @@ def forward(
         next_obs: _tensor_or_any,
         *args: Any,
         **kwds: Any,
-    ) -> tuple[_dist_or_any, _dist_or_any, _tensor_or_any, _tensor_or_any]:
+    ) -> tuple[_dist_or_any, _dist_or_any, _tensor_or_any]:
         """Make world model transition.
 
         Args:
diff --git a/src/models/dreamer.py b/src/models/dreamer.py
index a745b0d..2909709 100644
--- a/src/models/dreamer.py
+++ b/src/models/dreamer.py
@@ -11,6 +11,7 @@
 from torch import Tensor
 from torch.distributions import kl_divergence
 from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
 
 from ..datamodules import buffer_names
 from ..datamodules.replay_buffer import ReplayBuffer
@@ -32,7 +33,7 @@ class Dreamer(nn.Module):
     current_episode: int = 0
     device: torch.device = "cpu"
     dtype: torch.dtype = torch.float32
-
+    tensorboard: SummaryWriter
     def __init__(
         self,
         transition: Transition,
@@ -203,8 +204,8 @@ def world_training_step(
             loss_dict (dict[str, Any]): loss and some other metric values.
             experiences (dict[str, np.ndarray]): Added `all_hiddens` and `all_states`.
         """
-        device = self.agent.hidden.device
-        dtype = self.agent.hidden.dtype
+        device = self.device
+        dtype = self.dtype
         self.world.train()
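
For readers following the series, a minimal usage sketch (not part of any patch) of how the pieces added above are intended to fit together: configure_replay_buffer builds the name-to-space mapping from the environment, and collect_experiences fills the buffer by exploration. The helper name and the default buffer_size are illustrative assumptions; only the Dreamer methods and buffer_names come from the diffs.

# Hedged usage sketch; configure_replay_buffer, collect_experiences and
# buffer_names are taken from the diffs above, everything else is assumed.
import gym

from src.datamodules import buffer_names
from src.models.dreamer import Dreamer


def build_and_fill_buffer(model: Dreamer, env: gym.Env, buffer_size: int = 100):
    """Build a replay buffer from the env's spaces and fill it by exploration."""
    # Maps env.action_space / observation_space entries onto the registered
    # buffer names (ACTION, VOC_STATE, GENERATED_SOUND, TARGET_SOUND, DONE).
    replay_buffer = model.configure_replay_buffer(env, buffer_size)

    # Steps the environment with the agent's exploration policy and pushes one
    # sample dict per step into the buffer (resetting the env when done).
    model.collect_experiences(env, replay_buffer)

    assert buffer_names.ACTION in replay_buffer.spaces
    return replay_buffer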
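
The reformatted controller_training_step lines in the pre-commit patch rely on NumPy advanced indexing to cut one imagination window per batch element. A standalone sketch of that indexing pattern follows; the shapes are illustrative assumptions, not values from the repository.

# Illustration of the start_indices / batch_arange indexing used in
# controller_training_step; only the pattern is taken from the diff.
import numpy as np

chunk_size, batch_size, hidden_dim = 16, 4, 8
imagination_horizon = 4

old_hiddens = np.random.randn(chunk_size, batch_size, hidden_dim)

# One random window start per batch element, leaving room for the horizon.
start_indices = np.random.randint(0, chunk_size - imagination_horizon, (batch_size,))
batch_arange = np.arange(batch_size)

# Pairing start_indices[b] with batch index b selects one hidden state per
# batch element, giving shape (batch_size, hidden_dim).
hidden = old_hiddens[start_indices, batch_arange]
assert hidden.shape == (batch_size, hidden_dim)

# Each imagined step i then reads the slice at start_indices + i.
for i in range(imagination_horizon):
    step = old_hiddens[start_indices + i, batch_arange]
    assert step.shape == (batch_size, hidden_dim)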
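
PATCH 28 has world_training_step read self.device and self.dtype and annotates a tensorboard attribute on the class, all of which are expected to be assigned from outside before training. A hypothetical trainer-side assignment, assuming a SummaryWriter-based logger; the helper name and log_dir are illustrative, only the attribute names come from the diffs.

# Hypothetical trainer-side setup; not part of the patch series.
import torch
from torch.utils.tensorboard import SummaryWriter

from src.models.dreamer import Dreamer


def attach_runtime_attributes(model: Dreamer, log_dir: str = "runs/dreamer") -> Dreamer:
    """Assign the attributes Dreamer declares but expects the trainer to provide."""
    model.device = torch.device("cpu")          # read by world_training_step
    model.dtype = torch.float32                 # read by world_training_step
    model.tensorboard = SummaryWriter(log_dir)  # annotated on the class in PATCH 28
    return model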