diff --git a/src/datamodules/buffer_names.py b/src/datamodules/buffer_names.py
new file mode 100644
index 0000000..501111b
--- /dev/null
+++ b/src/datamodules/buffer_names.py
@@ -0,0 +1,5 @@
+ACTION = "action"
+VOC_STATE = "voc_state"
+GENERATED_SOUND = "generated_sound"
+TARGET_SOUND = "target_sound"
+DONE = "done"
diff --git a/src/models/abc/world.py b/src/models/abc/world.py
index 46fd458..9efc2ae 100644
--- a/src/models/abc/world.py
+++ b/src/models/abc/world.py
@@ -65,3 +65,17 @@ def forward(
         next_state_posterior = self.obs_encoder.forward(next_hidden, next_obs)
 
         return next_state_prior, next_state_posterior, next_hidden
+
+    def eval(self):
+        """Set models to evaluation mode."""
+        self.transition.eval()
+        self.prior.eval()
+        self.obs_encoder.eval()
+        self.obs_decoder.eval()
+
+    def train(self):
+        """Set models to training mode."""
+        self.transition.train()
+        self.prior.train()
+        self.obs_encoder.train()
+        self.obs_decoder.train()
diff --git a/src/models/dreamer.py b/src/models/dreamer.py
new file mode 100644
index 0000000..2909709
--- /dev/null
+++ b/src/models/dreamer.py
@@ -0,0 +1,420 @@
+from collections import OrderedDict
+from functools import partial
+from typing import Any
+
+import gym
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from gym.spaces import Box
+from torch import Tensor
+from torch.distributions import kl_divergence
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+
+from ..datamodules import buffer_names
+from ..datamodules.replay_buffer import ReplayBuffer
+from ..env.array_voc_state import VocStateObsNames as ObsNames
+from ..env.array_voc_state import VocStateObsNames as VSON
+from .abc.agent import Agent
+from .abc.controller import Controller
+from .abc.observation_auto_encoder import ObservationDecoder, ObservationEncoder
+from .abc.prior import Prior
+from .abc.transition import Transition
+from .abc.world import World
+
+
+class Dreamer(nn.Module):
+    """Dreamer model class."""
+
+    # Attributes set by the Trainer before training starts.
+    current_step: int = 0
+    current_episode: int = 0
+    device: torch.device = "cpu"
+    dtype: torch.dtype = torch.float32
+    tensorboard: SummaryWriter
+
+    def __init__(
+        self,
+        transition: Transition,
+        prior: Prior,
+        obs_encoder: ObservationEncoder,
+        obs_decoder: ObservationDecoder,
+        controller: Controller,
+        world: partial[World],
+        agent: partial[Agent],
+        world_optimizer: partial[Optimizer],
+        controller_optimizer: partial[Optimizer],
+        free_nats: float = 3.0,
+        num_collect_experience_steps: int = 100,
+        imagination_horizon: int = 32,
+        evaluation_steps: int = 44 * 60,
+        evaluation_blank_length: int = 22050,
+    ) -> None:
+        """
+        Args:
+            transition (Transition): Instance of Transition model class.
+            prior (Prior): Instance of Prior model class.
+            obs_encoder (ObservationEncoder): Instance of ObservationEncoder model class.
+            obs_decoder (ObservationDecoder): Instance of ObservationDecoder model class.
+            controller (Controller): Instance of Controller model class.
+            world (partial[World]): Partial instance of World interface class.
+            agent (partial[Agent]): Partial instance of Agent interface class.
+            world_optimizer (partial[Optimizer]): Partial instance of Optimizer class.
+            controller_optimizer (partial[Optimizer]): Partial instance of Optimizer class.
+
+            free_nats (float): Ignore the KL divergence loss when it is less than this value.
+            num_collect_experience_steps (int): Number of environment steps of experience to collect.
+            imagination_horizon (int): Number of imagined state transitions used to train the controller.
+            evaluation_steps (int): Number of environment steps run during evaluation.
+            evaluation_blank_length (int): Length of the blank waveform inserted between generated/target sounds.
+        """
+
+        super().__init__()
+        self.transition = transition
+        self.prior = prior
+        self.obs_encoder = obs_encoder
+        self.obs_decoder = obs_decoder
+        self.controller = controller
+
+        self.world = world(
+            transition=transition, prior=prior, obs_encoder=obs_encoder, obs_decoder=obs_decoder
+        )
+
+        self.agent = agent(
+            controller=controller,
+            transition=transition,
+            obs_encoder=obs_encoder,
+        )
+
+        self.world_optimizer = world_optimizer
+        self.controller_optimizer = controller_optimizer
+
+        self.free_nats = free_nats
+        self.num_collect_experience_steps = num_collect_experience_steps
+        self.imagination_horizon = imagination_horizon
+        self.evaluation_steps = evaluation_steps
+        self.evaluation_blank_length = evaluation_blank_length
+
+    def configure_optimizers(self) -> tuple[Optimizer, Optimizer]:
+        """Configure world optimizer and controller optimizer.
+
+        Returns:
+            world_optimizer (Optimizer): Updates Transition, Prior, ObservationEncoder and Decoder.
+            controller_optimizer (Optimizer): Updates Controller.
+        """
+        world_params = (
+            list(self.transition.parameters())
+            + list(self.prior.parameters())
+            + list(self.obs_encoder.parameters())
+            + list(self.obs_decoder.parameters())
+        )
+
+        world_optim = self.world_optimizer(params=world_params)
+        con_optim = self.controller_optimizer(params=self.controller.parameters())
+
+        return world_optim, con_optim
+
+    def configure_replay_buffer(self, env: gym.Env, buffer_size: int) -> ReplayBuffer:
+        """Configure replay buffer to store experiences.
+
+        Args:
+            env (gym.Env): PynkTrombone environment or its wrapper class.
+            buffer_size (int): Max number of experiences the buffer can store.
+
+        Returns:
+            ReplayBuffer: Replay buffer that can store experiences.
+        """
+        action_box = env.action_space
+        vocal_state_box = env.observation_space[VSON.VOC_STATE]
+        target_sound_box = env.observation_space[VSON.TARGET_SOUND_WAVE]
+        generated_sound_box = env.observation_space[VSON.GENERATED_SOUND_WAVE]
+        spaces = {}
+        # Map each buffer name to its matching observation space.
+        spaces[buffer_names.ACTION] = action_box
+        spaces[buffer_names.VOC_STATE] = vocal_state_box
+        spaces[buffer_names.GENERATED_SOUND] = generated_sound_box
+        spaces[buffer_names.TARGET_SOUND] = target_sound_box
+        spaces[buffer_names.DONE] = Box(0, 1, shape=(1,), dtype=bool)
+
+        replay_buffer = ReplayBuffer(spaces, buffer_size)
+
+        return replay_buffer
+
+    @torch.no_grad()
+    def collect_experiences(self, env: gym.Env, replay_buffer: ReplayBuffer) -> ReplayBuffer:
+        """Explore in env and collect experiences into the replay buffer.
+
+        Args:
+            env (gym.Env): PynkTrombone environment or its wrapper class.
+            replay_buffer (ReplayBuffer): Buffer storing experiences.
+
+        Returns:
+            replay_buffer (ReplayBuffer): Same pointer as the input replay_buffer.
+        """
+        device = self.device
+        dtype = self.dtype
+
+        obs = env.reset()
+        voc_state_np = obs[ObsNames.VOC_STATE]
+        generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM]
+        target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM]
+
+        voc_state = torch.as_tensor(voc_state_np, dtype=dtype, device=device).unsqueeze(0)
+        generated = torch.as_tensor(generated_np, dtype=dtype, device=device).unsqueeze(0)
+        target = torch.as_tensor(target_np, dtype=dtype, device=device).unsqueeze(0)
+
+        for _ in range(self.num_collect_experience_steps):
+            action = self.agent.explore(obs=(voc_state, generated), target=target)
+            action = action.cpu().squeeze(0).numpy()
+            obs, _, done, _ = env.step(action)
+
+            voc_state_np = obs[ObsNames.VOC_STATE]
+            generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM]
+            target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM]
+            sample = {
+                buffer_names.ACTION: action,
+                buffer_names.VOC_STATE: voc_state_np,
+                buffer_names.GENERATED_SOUND: generated_np,
+                buffer_names.DONE: done,
+                buffer_names.TARGET_SOUND: target_np,
+            }
+
+            replay_buffer.push(sample)
+
+            if done:
+                obs = env.reset()
+                voc_state_np = obs[ObsNames.VOC_STATE]
+                generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM]
+                target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM]
+
+            # Re-batch the observation tensors for the next agent step.
+            voc_state = torch.as_tensor(voc_state_np, dtype=dtype, device=device).unsqueeze(0)
+            generated = torch.as_tensor(generated_np, dtype=dtype, device=device).unsqueeze(0)
+            target = torch.as_tensor(target_np, dtype=dtype, device=device).unsqueeze(0)
+
+        return replay_buffer
+
+    def world_training_step(
+        self, experiences: dict[str, np.ndarray]
+    ) -> tuple[dict[str, Any], dict[str, np.ndarray]]:
+        """Compute loss for training the world model, and add all hiddens and states to
+        `experiences` for controller training.
+
+        Args:
+            experiences (dict[str, np.ndarray]): Collected experiences.
+
+        Returns:
+            loss_dict (dict[str, Any]): Loss and some other metric values.
+            experiences (dict[str, np.ndarray]): Input experiences with `"hiddens"` and `"states"` added.
+        """
+        device = self.device
+        dtype = self.dtype
+
+        self.world.train()
+
+        actions = experiences[buffer_names.ACTION]
+        voc_states = experiences[buffer_names.VOC_STATE]
+        generated_sounds = experiences[buffer_names.GENERATED_SOUND]
+        dones = experiences[buffer_names.DONE]
+
+        chunk_size, batch_size = actions.shape[:2]
+
+        hidden = torch.zeros(
+            (batch_size, *self.transition.hidden_shape), dtype=dtype, device=device
+        )
+        state = torch.zeros((batch_size, *self.prior.state_shape), dtype=dtype, device=device)
+
+        all_hiddens = torch.empty((chunk_size, *hidden.shape), dtype=dtype, device="cpu")
+        all_states = torch.empty((chunk_size, *state.shape), dtype=dtype, device="cpu")
+
+        rec_voc_state_loss = 0.0
+        rec_generated_sound_loss = 0.0
+        all_kl_div_loss = 0.0
+
+        for idx in range(chunk_size):
+            action = torch.as_tensor(actions[idx], dtype=dtype, device=device)
+
+            voc_stat = torch.as_tensor(voc_states[idx], dtype=dtype, device=device)
+            gened_sound = torch.as_tensor(generated_sounds[idx], dtype=dtype, device=device)
+            next_obs = (voc_stat, gened_sound)
+
+            next_state_prior, next_state_posterior, next_hidden = self.world.forward(
+                hidden, state, action, next_obs
+            )
+
+            next_state = next_state_posterior.rsample()
+            rec_voc_stat, rec_gened_sound = self.obs_decoder.forward(next_hidden, next_state)
+
+            all_states[idx] = next_state.detach()
+            all_hiddens[idx] = next_hidden.detach()
+
+            # compute losses
+            kl_div_loss = kl_divergence(next_state_posterior, next_state_prior).view(
+                batch_size, -1
+            )
+            all_kl_div_loss += kl_div_loss.sum(-1).mean()
+            rec_voc_state_loss += F.mse_loss(voc_stat, rec_voc_stat)
+            rec_generated_sound_loss += F.mse_loss(gened_sound, rec_gened_sound)
+
+            # next step
+            next_state[dones[idx]] = 0.0  # Initialize with zero.
+            next_hidden[dones[idx]] = 0.0  # Initialize with zero.
+
+            state = next_state
+            hidden = next_hidden
+
+        rec_voc_state_loss /= chunk_size
+        rec_generated_sound_loss /= chunk_size
+        all_kl_div_loss /= chunk_size
+        rec_loss = rec_voc_state_loss + rec_generated_sound_loss
+        # Only apply the KL term when it exceeds the free nats threshold.
+        over_free_nats = all_kl_div_loss.item() >= self.free_nats
+        loss = rec_loss + over_free_nats * all_kl_div_loss
+
+        loss_dict = {
+            "loss": loss,
+            "rec_loss": rec_loss,
+            "rec_voc_state_loss": rec_voc_state_loss,
+            "rec_generated_sound_loss": rec_generated_sound_loss,
+            "kl_div_loss": all_kl_div_loss,
+            "over_free_nat": over_free_nats,
+        }
+
+        experiences["hiddens"] = all_hiddens
+        experiences["states"] = all_states
+
+        return loss_dict, experiences
+
+    def controller_training_step(
+        self, experiences: dict[str, np.ndarray]
+    ) -> tuple[dict[str, Any], dict[str, np.ndarray]]:
+        """Compute loss for training the controller model.
+
+        Args:
+            experiences (dict[str, np.ndarray]): Collected experiences.
+
+        Returns:
+            loss_dict (dict[str, Any]): Loss and some other metric values.
+            experiences (dict[str, np.ndarray]): Collected experiences (no modification).
+        """
+
+        self.controller.train()
+        self.world.eval()
+
+        device = self.device
+        dtype = self.dtype
+
+        actions = experiences[buffer_names.ACTION]
+        dones = experiences[buffer_names.DONE]
+        target_sounds = experiences[buffer_names.TARGET_SOUND]
+        old_hiddens = experiences["hiddens"]
+
+        chunk_size, batch_size = actions.shape[:2]
+        start_indices = np.random.randint(0, chunk_size - self.imagination_horizon, (batch_size,))
+        batch_arange = np.arange(batch_size)
+        hidden = torch.as_tensor(
+            old_hiddens[start_indices, batch_arange], dtype=dtype, device=device
+        )
+        controller_hidden = torch.zeros(
+            batch_size, *self.controller.controller_hidden_shape, dtype=dtype, device=device
+        )
+        state = self.prior.forward(hidden).sample()
+
+        loss = 0.0
+        for i in range(self.imagination_horizon):
+            indices = start_indices + i
+            target = torch.as_tensor(
+                target_sounds[indices, batch_arange], dtype=dtype, device=device
+            )
+            action, controller_hidden = self.controller.forward(
+                hidden, state, target, controller_hidden, probabilistic=True
+            )
+            next_hidden = self.transition.forward(hidden, state, action)
+            next_state = self.prior.forward(next_hidden).sample()
+            rec_next_obs = self.obs_decoder.forward(next_hidden, next_state)
+            _, rec_gened_sound = rec_next_obs
+
+            loss += F.mse_loss(target, rec_gened_sound)
+
+            hidden = next_hidden
+
+            # Where an episode ended, restart imagination from the recorded next hidden state.
+            hidden[dones[indices, batch_arange]] = old_hiddens[indices + 1, batch_arange]
+            state = self.prior.forward(hidden).sample()
+
+        loss /= self.imagination_horizon
+
+        loss_dict = {"loss": loss}
+
+        return loss_dict, experiences
+
+    @torch.no_grad()
+    def evaluation_step(self, env: gym.Env) -> dict[str, Any]:
+        """Evaluation step.
+
+        Args:
+            env (gym.Env): PynkTrombone environment or its wrapper class.
+
+        Returns:
+            loss_dict (dict[str, Any]): Returned metric values.
+ """ + self.world.eval() + self.controller.eval() + + device = self.device + dtype = self.dtype + + self.agent.reset() + + obs = env.reset() + voc_state_np = obs[ObsNames.VOC_STATE] + generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] + target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM] + + voc_state = torch.as_tensor(voc_state_np, dtype=dtype, device=device).unsqueeze(0) + generated = torch.as_tensor(generated_np, dtype=dtype, device=device).unsqueeze(0) + target = torch.as_tensor(target_np, dtype=dtype, device=device).unsqueeze(0) + + generated_sound_waves = [] + target_sound_waves = [] + + blank = np.zeros(self.evaluation_blank_length) + + target_generated_mse = 0.0 + target_generated_mae = 0.0 + + for i in range(self.evaluation_steps): + target_sound_waves.append(obs[ObsNames.TARGET_SOUND_WAVE]) + + action = self.agent.act(obs=(voc_state, generated), target=target, probabilistic=False) + action = action.cpu().squeeze(0).numpy() + obs, _, done, _ = env.step(action) + + generated_sound_waves.append(obs[ObsNames.GENERATED_SOUND_WAVE]) + generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] + + target_generated_mse += np.mean((target_np - generated_np) ** 2) + target_generated_mae += np.mean(np.abs(target_np - generated_np)) + + if done: + obs = env.reset() + generated_sound_waves.append(blank) + target_sound_waves.append(blank) + self.agent.reset() + + voc_state_np = obs[ObsNames.VOC_STATE] + generated_np = obs[ObsNames.GENERATED_SOUND_SPECTROGRAM] + target_np = obs[ObsNames.TARGET_SOUND_SPECTROGRAM] + + voc_state = torch.as_tensor(voc_state_np, dtype=dtype, device=device).unsqueeze(0) + generated = torch.as_tensor(generated_np, dtype=dtype, device=device).unsqueeze(0) + target = torch.as_tensor(target_np, dtype=dtype, device=device).unsqueeze(0) + + target_generated_mae /= self.evaluation_steps + target_generated_mse /= self.evaluation_steps + + generated_sounds_for_log = np.concatenate(generated_sound_waves) + target_sounds_for_log = np.concatenate(target_sound_waves) + + # logging to tensorboard + generated_sounds_for_log + target_sounds_for_log + + loss_dict = { + "target_generated_mse": target_generated_mse, + "target_generated_mae": target_generated_mae, + } + + return loss_dict diff --git a/src/trainer.py b/src/trainer.py new file mode 100644 index 0000000..f105066 --- /dev/null +++ b/src/trainer.py @@ -0,0 +1,151 @@ +import logging +from collections import OrderedDict +from typing import Any, Optional + +import gym +import torch +from torch.nn.utils import clip_grad_norm_ +from torch.optim import Optimizer + +from .datamodules import buffer_names +from .datamodules.replay_buffer import ReplayBuffer +from .env.array_voc_state import VocStateObsNames as ObsNames +from .models.dreamer import Dreamer + +logger = logging.getLogger(__name__) + + +class CheckPointNames: + MODEL = "model" + WORLD_OPTIMIZER = "world_optimizer" + CONTROLLER_OPTIMIZER = "controller_optimizer" + + +class Trainer: + """Trainer class for Dreamer.""" + + def __init__( + self, + num_episode: int = 1, + collect_experience_interval: int = 100, + batch_size: int = 8, + chunk_size: int = 64, + gradient_clip_value: float = 100.0, + evaluation_interval=10, + model_save_interval=20, + checkpoint_path: Optional[Any] = None, + device: Any = "cpu", + dtype: Any = torch.float32, + ) -> None: + """ + Args: + + """ + self.__dict__.update(locals()) # Add all input args to class attribute. 
+    def fit(self, env: gym.Env, replay_buffer: ReplayBuffer, model: Dreamer) -> None:
+        """Fit the Dreamer model to the environment.
+
+        Args:
+            env (gym.Env): PynkTrombone environment or its wrapper class.
+            replay_buffer (ReplayBuffer): Replay buffer used during training.
+            model (Dreamer): Dreamer model to train.
+        """
+
+        self.setup_model_attribute(model)
+
+        model = model.to(self.device, self.dtype)
+
+        world_optimizer, controller_optimizer = model.configure_optimizers()
+        if self.checkpoint_path is not None:
+            self.load_checkpoint(self.checkpoint_path, model, world_optimizer, controller_optimizer)
+
+        current_step = 0
+
+        logger.info("Fit started.")
+        for episode in range(self.num_episode):
+            logger.info(f"Episode {episode} is started.")
+
+            model.current_episode = episode
+            model.current_step = current_step
+
+            logger.debug("Collecting experiences...")
+            replay_buffer = model.collect_experiences(env, replay_buffer)
+            logger.debug("Collected experiences.")
+
+            for collect_interval in range(self.collect_experience_interval):
+                logger.debug(f"Collect interval: {collect_interval}")
+
+                # ----- Training World Model. -----
+                experiences_dict = replay_buffer.sample(
+                    self.batch_size, self.chunk_size, chunk_first=True
+                )
+                loss_dict, experiences_dict = model.world_training_step(experiences_dict)
+
+                loss: torch.Tensor = loss_dict["loss"]
+                world_optimizer.zero_grad()
+                loss.backward()
+                params = []
+                for p in world_optimizer.param_groups:
+                    params += p["params"]
+                clip_grad_norm_(params, self.gradient_clip_value)
+                world_optimizer.step()
+
+                # -- logging --
+
+                # ----- Training Controller model. -----
+                loss_dict, experiences_dict = model.controller_training_step(experiences_dict)
+
+                loss: torch.Tensor = loss_dict["loss"]
+                controller_optimizer.zero_grad()
+                loss.backward()
+                params = []
+                for p in controller_optimizer.param_groups:
+                    params += p["params"]
+                clip_grad_norm_(params, self.gradient_clip_value)
+                controller_optimizer.step()
+
+                # logging
+
+                if current_step % self.evaluation_interval == 0:
+                    # ----- Evaluation steps -----
+                    loss_dict = model.evaluation_step(env)
+
+                    # logging
+
+                if current_step % self.model_save_interval == 0:
+                    self.save_checkpoint("/logs/...", model, world_optimizer, controller_optimizer)
+
+                current_step += 1
+
+        self.save_checkpoint("/logs/...", model, world_optimizer, controller_optimizer)
+
+    def setup_model_attribute(self, model: Dreamer):
+        """Add attributes required for model training.
+
+        Call this at the beginning of training.
+        Args:
+            model (Dreamer): Dreamer model class.
+ """ + model.device = torch.device(self.device) + model.dtype = self.dtype + model.current_episode = 0 + model.current_step = 0 + + def save_checkpoint( + self, path: Any, model: Dreamer, world_optim: Optimizer, controller_optim: Optimizer + ) -> None: + """Saving checkpoint.""" + ckpt = OrderedDict() + ckpt[CheckPointNames.MODEL] = model.state_dict() + ckpt[CheckPointNames.WORLD_OPTIMIZER] = world_optim.state_dict() + ckpt[CheckPointNames.CONTROLLER_OPTIMIZER] = controller_optim.state_dict() + + torch.save(ckpt, path) + + def load_checkpoint( + self, path: Any, model: Dreamer, world_optim: Optimizer, controller_optim: Optimizer + ): + ckpt = torch.load(path, self.device) + model.load_state_dict(ckpt[CheckPointNames.MODEL]) + world_optim.load_state_dict(ckpt[CheckPointNames.WORLD_OPTIMIZER]) + controller_optim.load_state_dict(ckpt[CheckPointNames.CONTROLLER_OPTIMIZER]) diff --git a/tests/models/abc/dummy_classes.py b/tests/models/abc/dummy_classes.py index 91d2493..26f61b3 100644 --- a/tests/models/abc/dummy_classes.py +++ b/tests/models/abc/dummy_classes.py @@ -3,6 +3,7 @@ import torch from torch import Tensor from torch.distributions import Normal +from torch.nn import Linear from src.models.abc.agent import Agent from src.models.abc.controller import Controller @@ -21,6 +22,7 @@ class DummyTransition(Transition): def __init__(self, hidden_shape: tuple[int], *args: Any, **kwds: Any) -> None: super().__init__() self._hidden_shape = hidden_shape + self.dmy_lyr = Linear(8, 16) def forward(self, hidden: Tensor, state: Tensor, action: Tensor) -> Tensor: return torch.randn(hidden.shape) @@ -36,6 +38,7 @@ class DummyPrior(Prior): def __init__(self, state_shape: tuple[int], *args: Any, **kwds: Any) -> None: super().__init__() self._state_shape = state_shape + self.dmy_lyr = Linear(8, 16) def forward(self, hidden: Tensor) -> Normal: shape = (hidden.size(0), *self.state_shape) @@ -57,6 +60,7 @@ def __init__( super().__init__() self._state_shape = state_shape self.embedded_obs_shape = embedded_obs_shape + self.dmy_lyr = Linear(8, 16) def embed_observation(self, obs: tuple[Tensor, Tensor]) -> Tensor: v, g = obs @@ -87,6 +91,7 @@ def __init__( super().__init__() self._voc_state_shape = voc_state_shape self._generated_sound_shape = generated_sound_shape + self.dmy_lyr = Linear(8, 16) def forward(self, hidden: Tensor, state: Tensor) -> Tensor: vs_shape = (hidden.size(0), *self._voc_state_shape) @@ -108,6 +113,7 @@ def __init__( self._action_shape = action_shape self._controller_hidden_shape = controller_hidden_shape + self.dmy_lyr = Linear(8, 16) def forward( self, diff --git a/tests/models/test_dreamer.py b/tests/models/test_dreamer.py new file mode 100644 index 0000000..e0ffbe6 --- /dev/null +++ b/tests/models/test_dreamer.py @@ -0,0 +1,177 @@ +import glob +import pathlib +from functools import partial + +import numpy as np +import pytest +import torch +from gym.spaces import Box +from pynktrombonegym.spaces import ObservationSpaceNames as OSN +from pynktrombonegym.wrappers import ActionByAcceleration as ABA +from pynktrombonegym.wrappers import Log1pMelSpectrogram as L1MS +from torch.optim import SGD + +from src.datamodules import buffer_names +from src.datamodules.replay_buffer import ReplayBuffer +from src.env.array_action import ARRAY_ORDER as AO_act +from src.env.array_action import ArrayAction as AA +from src.env.array_voc_state import ARRAY_ORDER as AO_voc +from src.env.array_voc_state import VSON +from src.env.array_voc_state import ArrayVocState as AVS +from src.env.normalize_action_range 
import NormalizeActionRange as NAR +from src.models.dreamer import Dreamer +from tests.models.abc.dummy_classes import DummyAgent as DA +from tests.models.abc.dummy_classes import DummyController as DC +from tests.models.abc.dummy_classes import DummyObservationDecoder as DOD +from tests.models.abc.dummy_classes import DummyObservationEncoder as DOE +from tests.models.abc.dummy_classes import DummyPrior as DP +from tests.models.abc.dummy_classes import DummyTransition as DT +from tests.models.abc.dummy_classes import DummyWorld as DW + +target_file_path = pathlib.Path(__file__).parents[2].joinpath("data/sample_target_sounds/*.wav") +target_files = glob.glob(str(target_file_path)) +env = AVS(AA(NAR(ABA(L1MS(target_files), action_scaler=1.0)))) + +hidden_shape = (16,) +ctrl_hidden_shape = (16,) +state_shape = (8,) +action_shape = (len(AO_act),) + +obs_space = env.observation_space +voc_stats_shape = obs_space[VSON.VOC_STATE].shape +rnn_input_shape = (8,) + +gen_sound_shape = obs_space[OSN.GENERATED_SOUND_SPECTROGRAM].shape +tgt_sound_shape = obs_space[OSN.TARGET_SOUND_SPECTROGRAM].shape +obs_shape = (24,) + +obs_enc = DOE(state_shape, obs_shape) +obs_dec = DOD( + voc_stats_shape, + gen_sound_shape, +) +prior = DP(state_shape) +trans = DT(hidden_shape) +ctrl = DC(action_shape, ctrl_hidden_shape) + + +d_world = partial(DW) +d_agent = partial(DA) + +world_opt = partial(SGD, lr=1e-3) +ctrl_opt = partial(SGD, lr=1e-3) + + +bf_size = 32 +bf_space = { + buffer_names.ACTION: Box(-np.inf, np.inf, action_shape), + buffer_names.DONE: Box(-np.inf, np.inf, (1,)), + buffer_names.GENERATED_SOUND: Box(-np.inf, np.inf, gen_sound_shape), + buffer_names.TARGET_SOUND: Box(-np.inf, np.inf, tgt_sound_shape), + buffer_names.VOC_STATE: Box(-np.inf, np.inf, voc_stats_shape), +} +args = (trans, prior, obs_enc, obs_dec, ctrl, d_world, d_agent, world_opt, ctrl_opt) +del env + + +def world_training_step(model, env): + rb = ReplayBuffer(bf_space, bf_size) + _, __ = model.configure_optimizers() + model.collect_experiences(env, rb) + experience = rb.sample(1, chunk_length=16) + loss_dict, experience = model.world_training_step(experience) + return loss_dict, experience + + +def _hasattrs(model): + attributes = ( + "transition", + "prior", + "obs_encoder", + "obs_decoder", + "controller", + "world", + "agent", + "world_optimizer", + "controller_optimizer", + "free_nats", + "num_collect_experience_steps", + "imagination_horizon", + "evaluation_steps", + "evaluation_blank_length", + ) + have_attr = [] + for attr in attributes: + have_attr += [hasattr(model, attr)] + print(hasattr(model, attr)) + return have_attr, attributes + + +def test__init__(): + model = Dreamer(*args) + has_attrs, attrs = _hasattrs(model) + for idx, has_attr_flag in enumerate(has_attrs): + if has_attr_flag: + pass + else: + assert False, f"attribute {attrs[idx]} doesn't set" + + +def test_configure_optimizers(): + model = Dreamer(*args) + opt1, opt2 = model.configure_optimizers() + + +@pytest.mark.parametrize("num_steps", [1, 2, 3]) +def test_collect_experiences(num_steps): + env = AVS(AA(NAR(ABA(L1MS(target_files), action_scaler=1.0)))) + rb = ReplayBuffer(bf_space, bf_size) + model = Dreamer(*args, num_collect_experience_steps=num_steps) + model.collect_experiences(env, rb) + assert rb.current_index == num_steps + del env + + +def test_world_training_step(): + env = AVS(AA(NAR(ABA(L1MS(target_files), action_scaler=1.0)))) + model = Dreamer(*args, num_collect_experience_steps=128) + loss_dict, experience = world_training_step(model, env) + assert 
experience.get("hiddens") is not None + assert experience.get("states") is not None + assert loss_dict.get("loss") is not None + del env + + +def test_controller_training_step(): + # World Training Step + env = AVS(AA(NAR(ABA(L1MS(target_files), action_scaler=1.0)))) + model = Dreamer(*args, imagination_horizon=4) + _, experience = world_training_step(model, env) + loss_dict, _ = model.controller_training_step(experience) + assert loss_dict.get("loss") is not None + del env + + +def test_evaluation_step(): + env = AVS(AA(NAR(ABA(L1MS(target_files), action_scaler=1.0)))) + model = Dreamer(*args) + loss_dict = model.evaluation_step(env) + assert loss_dict.get("target_generated_mse") is not None + assert loss_dict.get("target_generated_mae") is not None + del env + + +def test_configure_replay_buffer(): + env = AVS(AA(NAR(ABA(L1MS(target_files), action_scaler=1.0)))) + model = Dreamer(*args) + rb = model.configure_replay_buffer(env, bf_size) + spaces = { + buffer_names.VOC_STATE: env.observation_space[VSON.VOC_STATE], + buffer_names.ACTION: env.action_space, + buffer_names.TARGET_SOUND: env.observation_space[VSON.TARGET_SOUND_WAVE], + buffer_names.GENERATED_SOUND: env.observation_space[VSON.GENERATED_SOUND_WAVE], + buffer_names.DONE: Box(0, 1, shape=(1,), dtype=bool), + } + for name, box in rb.spaces.items(): + assert box == spaces.get(name), f"space {name} isn't set properly(box:{box}" + del env
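Not part of the diff: for review context, below is a minimal end-to-end sketch of how the new pieces are meant to compose (environment wrappers, Dreamer, configure_replay_buffer, Trainer.fit). It mirrors the setup in tests/models/test_dreamer.py and reuses the dummy submodels from tests/models/abc/dummy_classes.py; in real training those would be replaced by concrete implementations, and the buffer size, learning rates, and Trainer arguments shown here are illustrative assumptions.

# Usage sketch only; dummy submodels and the hyperparameters below are assumptions.
import glob
from functools import partial

import torch
from pynktrombonegym.spaces import ObservationSpaceNames as OSN
from pynktrombonegym.wrappers import ActionByAcceleration as ABA
from pynktrombonegym.wrappers import Log1pMelSpectrogram as L1MS
from torch.optim import SGD

from src.env.array_action import ARRAY_ORDER as AO_act
from src.env.array_action import ArrayAction as AA
from src.env.array_voc_state import VSON
from src.env.array_voc_state import ArrayVocState as AVS
from src.env.normalize_action_range import NormalizeActionRange as NAR
from src.models.dreamer import Dreamer
from src.trainer import Trainer
from tests.models.abc import dummy_classes as dmy

# Wrap the PynkTrombone environment exactly as in the tests.
target_files = glob.glob("data/sample_target_sounds/*.wav")
env = AVS(AA(NAR(ABA(L1MS(target_files), action_scaler=1.0))))
obs_space = env.observation_space

model = Dreamer(
    transition=dmy.DummyTransition((16,)),  # hidden_shape
    prior=dmy.DummyPrior((8,)),  # state_shape
    obs_encoder=dmy.DummyObservationEncoder((8,), (24,)),  # state / embedded obs shapes
    obs_decoder=dmy.DummyObservationDecoder(
        obs_space[VSON.VOC_STATE].shape,
        obs_space[OSN.GENERATED_SOUND_SPECTROGRAM].shape,
    ),
    controller=dmy.DummyController((len(AO_act),), (16,)),  # action / controller hidden shapes
    world=partial(dmy.DummyWorld),
    agent=partial(dmy.DummyAgent),
    world_optimizer=partial(SGD, lr=1e-3),
    controller_optimizer=partial(SGD, lr=1e-3),
)

# Buffer size and Trainer arguments are illustrative; fit runs a short training loop.
replay_buffer = model.configure_replay_buffer(env, buffer_size=1024)
trainer = Trainer(num_episode=1, batch_size=8, chunk_size=64, device="cpu", dtype=torch.float32)
trainer.fit(env, replay_buffer, model)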