diff --git a/.gitignore b/.gitignore index d9b88844..26e3ab4b 100644 --- a/.gitignore +++ b/.gitignore @@ -5,8 +5,9 @@ __pycache__/ *.py[cod] -# NumPy binary files -data*/*.npy +# NumPy/PyTorch binary files +*.npy +*.pt # Distribution and packaging files build/ @@ -31,10 +32,13 @@ outputs/ *.log # Tutorial files -**/interim_data/ -**/processed_data/ -**/results/ -tutorial/maize/data +tutorial/* +!tutorial/config/*maize*.yaml +!tutorial/config/*random_small*.yaml +!tutorial/data +!tutorial/maize/maize_dataset.py +!tutorial/notebooks/*.ipynb +!tutorial/README.md # Virtual environment venv/ @@ -42,4 +46,15 @@ virtualvenv/ # docs files docs/build/ -docs/source/_templates/ \ No newline at end of file +docs/source/_templates/ + +# Root folder +/*.* +!/.gitignore +!/.readthedocs.yaml +!/LICENSE +!/MANIFEST.in +!/README.md +!/pyproject.toml +!/requirements.txt +!/setup.cfg diff --git a/src/move/__init__.py b/src/move/__init__.py index a4ce9fd3..fe5c08c3 100644 --- a/src/move/__init__.py +++ b/src/move/__init__.py @@ -1,11 +1,10 @@ from __future__ import annotations __license__ = "MIT" -__version__ = (1, 4, 9) -__all__ = ["conf", "data", "models", "training_loop", "VAE"] +__version__ = (2, 0, 0) +__all__ = ["conf", "data", "models", "tasks", "viz"] HYDRA_VERSION_BASE = "1.2" -from move import conf, data, models -from move.models.vae import VAE -from move.training.training_loop import training_loop +import move.visualization as viz +from move import conf, data, models, tasks diff --git a/src/move/__main__.py b/src/move/__main__.py index 485e5bac..15d4f991 100644 --- a/src/move/__main__.py +++ b/src/move/__main__.py @@ -3,16 +3,11 @@ import hydra from omegaconf import OmegaConf -import move.tasks from move import HYDRA_VERSION_BASE -from move.conf.schema import ( - AnalyzeLatentConfig, - EncodeDataConfig, - IdentifyAssociationsConfig, - MOVEConfig, - TuneModelConfig, -) +from move.conf.schema import SUPPORTED_TASKS, MOVEConfig from move.core.logging import get_logger +from move.core.seed import set_global_seed +from move.tasks.base import Task @hydra.main( @@ -32,14 +27,11 @@ def main(config: MOVEConfig) -> None: if task_type is None: logger = get_logger("move") logger.info("No task specified.") - elif task_type is EncodeDataConfig: - move.tasks.encode_data(config.data) - elif issubclass(task_type, TuneModelConfig): - move.tasks.tune_model(config) - elif task_type is AnalyzeLatentConfig: - move.tasks.analyze_latent(config) - elif issubclass(task_type, IdentifyAssociationsConfig): - move.tasks.identify_associations(config) + elif issubclass(task_type, SUPPORTED_TASKS): + if config.seed is not None: + set_global_seed(config.seed) + task: Task = hydra.utils.instantiate(config.task, _recursive_=False) + task.run() else: raise ValueError("Unsupported type of task.") diff --git a/src/move/analysis/fdr.py b/src/move/analysis/fdr.py new file mode 100644 index 00000000..4020f87b --- /dev/null +++ b/src/move/analysis/fdr.py @@ -0,0 +1,18 @@ +import math +from typing import cast + +import numpy as np +from numpy.typing import NDArray + + +def argnearest(array: NDArray, target: float) -> int: + """Find value in array closest to target. Assumes array is sorted in + ascending order.""" + idx = np.searchsorted(array, target, side="left") + if idx > 0 and ( + idx == len(array) + or math.fabs(target - array[idx - 1]) < math.fabs(target - array[idx]) + ): + return cast(int, idx - 1) + else: + return cast(int, idx) diff --git a/src/move/analysis/feature_importance.py b/src/move/analysis/feature_importance.py new file mode 100644 index 00000000..850c830c --- /dev/null +++ b/src/move/analysis/feature_importance.py @@ -0,0 +1,99 @@ +__all__ = ["FeatureImportance"] + +from typing import TYPE_CHECKING + +import pandas as pd +import torch + +import move.visualization as viz +from move.core.exceptions import UnsetProperty +from move.data.io import sanitize_filename +from move.tasks.base import CsvWriterMixin, ParentTask, SubTask + +if TYPE_CHECKING: + from move.data.dataloader import MoveDataLoader + from move.models.base import BaseVae + + +class FeatureImportance(CsvWriterMixin, SubTask): + """Compute feature importance in latent space. + + Feature importance is computed per feature per dataset. For each dataset, + a file will be created. + + Feature importance is computed as the sum of differences in latent + variables generated when a feature is present/removed.""" + + data_filename_fmt: str = "feature_importance_{}.csv" + plot_filename_fmt: str = "feature_importance_{}.png" + + def __init__( + self, parent: ParentTask, model: "BaseVae", dataloader: "MoveDataLoader" + ) -> None: + self.parent = parent + self.model = model + self.dataloader = dataloader + + def plot(self) -> None: + if self.parent is None: + return + for dataset in self.dataloader.datasets: + csv_filename = self.data_filename_fmt.format(dataset.name) + csv_filepath = self.parent.output_dir / sanitize_filename(csv_filename) + fig_filename = self.plot_filename_fmt.format(dataset.name) + fig_filepath = self.parent.output_dir / sanitize_filename(fig_filename) + + diffs = pd.read_csv(csv_filepath) + + if dataset.data_type == "continuous": + fig = viz.plot_continuous_feature_importance( + diffs.values, dataset.tensor.numpy(), dataset.feature_names + ) + else: + # Categorical dataset is re-shaped to 3D shape + dataset_shape = getattr(dataset, "original_shape") + fig = viz.plot_categorical_feature_importance( + diffs.values, + dataset.tensor.reshape(-1, *dataset_shape).numpy(), + dataset.feature_names, + getattr(dataset, "mapping"), + ) + + fig.savefig(fig_filepath, bbox_inches="tight") + + @torch.no_grad() + def run(self) -> None: + for dataset in self.dataloader.datasets: + self.log(f"Computing feature importance: '{dataset}'") + # Create a file for each dataset + # File is transposed; each column is a sample, each row a feature + if self.parent: + csv_filename = sanitize_filename(self.data_filename_fmt.format(dataset)) + csv_filepath = self.parent.output_dir / csv_filename + colnames = ["feature_name"] + [""] * len(self.dataloader.dataset) + self.init_csv_writer( + csv_filepath, fieldnames=colnames, extrasaction="ignore" + ) + else: + raise UnsetProperty("Parent task") + + # Make a perturbation for each feature + for feature_name in dataset.feature_names: + value = None if dataset.data_type == "discrete" else 0.0 + self.dataloader.dataset.perturb(dataset.name, feature_name, value) + row = [feature_name] + for tup in self.dataloader: + batch, pert_batch, _ = tup + z = self.model.project(batch) + z_pert = self.model.project(pert_batch) + diff = torch.sum(z_pert - z, dim=-1) + row.extend(diff.tolist()) + self.write_row(row) + + self.close_csv_writer(clear=True) + + # Transpose CSV file, so each row is a sample, each column a feature + pd.read_csv(csv_filepath).T.to_csv(csv_filepath, index=False, header=False) + + # Clear perturbation + self.dataloader.dataset.perturbation = None diff --git a/src/move/analysis/hdi.py b/src/move/analysis/hdi.py new file mode 100644 index 00000000..15ff1ee5 --- /dev/null +++ b/src/move/analysis/hdi.py @@ -0,0 +1,36 @@ +import math + +import torch + + +def hdi_bounds( + x: torch.Tensor, hdi_prob: float = 0.95 +) -> tuple[torch.Tensor, torch.Tensor]: + """Return highest density interval (HDI) of a samples-features matrix. + The HDI represents the range within which most of the samples are located. + + Args: + x: Matrix (`num_samples` x `num_features`) + hdi_prob: Percentage of samples inside the HDI + + Returns: + Lower and upper bounds of HDI + """ + # adapated from arviz + + if x.dim() != 2: + raise ValueError("Can only calculate for matrices with two dimensions") + + n = x.size(0) + x, _ = torch.sort(x, dim=0) + + interval_idx_inc = math.floor(hdi_prob * n) + num_intervals = n - interval_idx_inc + + interval_width = x[interval_idx_inc:] - x[:num_intervals] + min_idx = torch.argmin(interval_width, dim=0) + + hdi_min = torch.diag(x[min_idx]) + hdi_max = torch.diag(x[min_idx + interval_idx_inc]) + + return hdi_min, hdi_max diff --git a/src/move/analysis/metrics.py b/src/move/analysis/metrics.py index 35f5bc60..dd2660f5 100644 --- a/src/move/analysis/metrics.py +++ b/src/move/analysis/metrics.py @@ -1,8 +1,19 @@ __all__ = ["calculate_accuracy", "calculate_cosine_similarity"] +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, cast + import numpy as np +import pandas as pd +import torch +import move.visualization as viz from move.core.typing import FloatArray +from move.tasks.base import CsvWriterMixin, ParentTask, SubTask + +if TYPE_CHECKING: + from move.data.dataloader import MoveDataLoader + from move.models.base import BaseVae def calculate_accuracy( @@ -33,7 +44,7 @@ def calculate_accuracy( y_pred = np.ma.masked_array(reconstruction, mask=is_nan) num_features = np.ma.count(y_true, axis=1) - scores = np.ma.filled(np.sum(y_true == y_pred, axis=1) / num_features, 0) + scores = np.ma.filled(np.sum(y_true == y_pred, axis=1) / num_features, np.nan) return scores @@ -64,7 +75,7 @@ def calculate_cosine_similarity( # Equivalent to `np.diag(sklearn.metrics.pairwise.cosine_similarity(x, y))` # But can handle masked arrays - scores = np.ma.compressed(np.sum(x * y, axis=1)) / (norm(x) * norm(y)) + scores = np.ma.filled(np.sum(x * y, axis=1), np.nan) / (norm(x) * norm(y)) return scores @@ -80,4 +91,64 @@ def norm(x: np.ma.MaskedArray, axis: int = 1) -> FloatArray: Returns: 1D array with the specified axis removed. """ - return np.ma.compressed(np.sqrt(np.sum(x**2, axis=axis))) + return np.ma.filled(np.sqrt(np.sum(x**2, axis=axis)), np.nan) + + +class ComputeAccuracyMetrics(CsvWriterMixin, SubTask): + """Compute accuracy metrics between original input and reconstruction (use + cosine similarity for continuous dataset reconstructions).""" + + data_filename: str = "reconstruction_metrics.csv" + plot_filename: str = "reconstruction_metrics.png" + + def __init__( + self, parent: ParentTask, model: "BaseVae", dataloader: "MoveDataLoader" + ) -> None: + self.parent = parent + self.model = model + self.dataloader = dataloader + + def plot(self) -> None: + if self.parent and self.csv_filepath: + scores = pd.read_csv(self.csv_filepath, index_col=None) + fig = viz.plot_metrics_boxplot(scores, labels=None) + fig_path = self.parent.output_dir / self.plot_filename + fig.savefig(fig_path, bbox_inches="tight") + + @torch.no_grad() + def run(self) -> None: + if self.parent: + csv_filepath = self.parent.output_dir / self.data_filename + colnames = self.dataloader.dataset.dataset_names + self.init_csv_writer( + csv_filepath, fieldnames=colnames, extrasaction="ignore" + ) + else: + self.log("No parent task, metrics will not be saved.", "WARNING") + + self.log("Computing accuracy metrics") + + datasets = self.dataloader.datasets + for batch in self.dataloader: + batch_disc, batch_cont = self.model.split_input(batch[0]) + recon = self.model.reconstruct(batch[0], as_one=True) + recon_disc, recon_cont = self.model.split_input(recon) + + scores_per_dataset = {} + for i, dataset in enumerate(datasets[: len(batch_disc)]): + target = batch_disc[i].numpy() + preds = torch.argmax( + (torch.log_softmax(recon_disc[i], dim=-1)), dim=-1 + ).numpy() + scores = calculate_accuracy(target, preds) + scores_per_dataset[dataset.name] = scores + + for i, dataset in enumerate(datasets[len(batch_disc) :]): + target = batch_cont[i].numpy() + preds = recon_cont[i].numpy() + scores = calculate_cosine_similarity(target, preds) + scores_per_dataset[dataset.name] = scores + + self.write_cols(scores_per_dataset) + + self.close_csv_writer() diff --git a/src/move/conf/__init__.py b/src/move/conf/__init__.py index 7669b5c3..0f058009 100644 --- a/src/move/conf/__init__.py +++ b/src/move/conf/__init__.py @@ -1,3 +1,15 @@ -__all__ = ["MOVEConfig"] +__all__ = [ + "AdamConfig", + "AdamWConfig", + "ProdigyConfig", + "SgdConfig", + "TrainingDataLoaderConfig", + "TrainingLoopConfig", + "VaeConfig", + "VaeNormalConfig", + "VaeTConfig", +] -from move.conf.schema import MOVEConfig +from move.conf.models import VaeConfig, VaeNormalConfig, VaeTConfig +from move.conf.optim import AdamConfig, AdamWConfig, ProdigyConfig, SgdConfig +from move.conf.training import TrainingDataLoaderConfig, TrainingLoopConfig diff --git a/src/move/conf/config_store.py b/src/move/conf/config_store.py new file mode 100644 index 00000000..38c02b28 --- /dev/null +++ b/src/move/conf/config_store.py @@ -0,0 +1,6 @@ +__all__ = ["config_store"] + +from hydra.core.config_store import ConfigStore + +config_store = ConfigStore.instance() +"""Hydra's config store singleton""" diff --git a/src/move/conf/data/base_data.yaml b/src/move/conf/data/base_data.yaml index 27904c3d..7f251788 100644 --- a/src/move/conf/data/base_data.yaml +++ b/src/move/conf/data/base_data.yaml @@ -6,33 +6,17 @@ raw_data_path: data/ interim_data_path: interim_data/ results_path: processed_data/ -sample_names: baseline_ids +sample_names: sample_names -categorical_inputs: - - name: diabetes_genotypes - weight: 1 - - name: baseline_drugs - weight: 1 - - name: baseline_categorical - weight: 1 +categorical_inputs: [] -continuous_inputs: - - name: baseline_continuous - weight: 2 - - name: baseline_transcriptomics - weight: 1 - - name: baseline_diet_wearables - weight: 1 - - name: baseline_proteomic_antibodies - weight: 1 - - name: baseline_target_metabolomics - weight: 1 - - name: baseline_untarget_metabolomics - weight: 1 - - name: baseline_metagenomics - weight: 1 +continuous_inputs: [] categorical_names: ${names:${data.categorical_inputs}} continuous_names: ${names:${data.continuous_inputs}} categorical_weights: ${weights:${data.categorical_inputs}} continuous_weights: ${weights:${data.continuous_inputs}} + +train_frac: 0.9 +test_frac: 0.1 +valid_frac: 0.0 diff --git a/src/move/conf/legacy.py b/src/move/conf/legacy.py new file mode 100644 index 00000000..be0b41b4 --- /dev/null +++ b/src/move/conf/legacy.py @@ -0,0 +1,145 @@ +__all__ = [] + +from dataclasses import dataclass, field +from typing import Any, Optional + +from omegaconf import MISSING + +from move.core.qualname import get_fully_qualname +from move.models.vae_legacy import VAE +from move.training.training_loop import training_loop + + +@dataclass +class ModelConfig: + _target_: str = MISSING + cuda: bool = MISSING + + +@dataclass +class VAEConfig(ModelConfig): + """Configuration for the VAE module.""" + + _target_: str = get_fully_qualname(VAE) + categorical_weights: list[int] = MISSING + continuous_weights: list[int] = MISSING + num_hidden: list[int] = MISSING + num_latent: int = MISSING + beta: float = MISSING + dropout: float = MISSING + cuda: bool = False + + +@dataclass +class TrainingLoopConfig: + _target_: str = get_fully_qualname(training_loop) + num_epochs: int = MISSING + lr: float = MISSING + kld_warmup_steps: list[int] = MISSING + batch_dilation_steps: list[int] = MISSING + early_stopping: bool = MISSING + patience: int = MISSING + + +@dataclass +class TaskConfig: + """Configure a MOVE task.""" + + +@dataclass +class ModelTaskConfig(TaskConfig): + """Configure a MOVE task involving a training loop. + + Attributes: + batch_size: Number of samples in a training batch. + model: Configuration for a model. + training_loop: Configuration for a training loop. + """ + + batch_size: Optional[int] + model: Optional[VAEConfig] + training_loop: Optional[TrainingLoopConfig] + + +@dataclass +class TuneModelConfig(ModelTaskConfig): + """Configure the "tune model" task.""" + + ... + + +@dataclass +class TuneModelStabilityConfig(TuneModelConfig): + """Configure the "tune model" task.""" + + num_refits: int = MISSING + + +@dataclass +class TuneModelReconstructionConfig(TuneModelConfig): + """Configure the "tune model" task.""" + + ... + + +@dataclass +class AnalyzeLatentConfig(ModelTaskConfig): + """Configure the "analyze latents" task. + + Attributes: + feature_names: + Names of features to visualize. + """ + + feature_names: list[str] = field(default_factory=list) + reducer: dict[str, Any] = MISSING + + +@dataclass +class IdentifyAssociationsConfig(ModelTaskConfig): + """Configure the "identify associations" task. + + Attributes: + target_dataset: + Name of categorical dataset to perturb. + target_value: + The value to change to. It should be a category name. + num_refits: + Number of times to refit the model. + sig_threshold: + Threshold used to determine whether an association is significant. + In the t-test approach, this is called significance level (alpha). + In the probabilistc approach, significant associations are selected + if their FDR is below this threshold. + + This value should be within the range [0, 1]. + save_models: + Whether to save the weights of each refit. If weights are saved, + rerunning the task will load them instead of training. + """ + + target_dataset: str = MISSING + target_value: str = MISSING + num_refits: int = MISSING + sig_threshold: float = 0.05 + save_refits: bool = False + + +@dataclass +class IdentifyAssociationsBayesConfig(IdentifyAssociationsConfig): + """Configure the probabilistic approach to identify associations.""" + + ... + + +@dataclass +class IdentifyAssociationsTTestConfig(IdentifyAssociationsConfig): + """Configure the t-test approach to identify associations. + + Args: + num_latent: + List of latent space dimensions to train. It should contain four + elements. + """ + + num_latent: list[int] = MISSING diff --git a/src/move/conf/main.yaml b/src/move/conf/main.yaml index f6bf981c..4e40eda6 100644 --- a/src/move/conf/main.yaml +++ b/src/move/conf/main.yaml @@ -20,6 +20,7 @@ hydra: job: config: override_dirname: + item_sep: ; exclude_keys: - experiment diff --git a/src/move/conf/models.py b/src/move/conf/models.py new file mode 100644 index 00000000..a9b270af --- /dev/null +++ b/src/move/conf/models.py @@ -0,0 +1,58 @@ +__all__ = ["VaeConfig", "VaeNormalConfig", "VaeTConfig"] + +from dataclasses import dataclass, field + +from move.conf.config_store import config_store +from move.core.qualname import get_fully_qualname +from move.models.vae import Vae +from move.models.vae_distribution import VaeNormal +from move.models.vae_t import VaeT + + +@dataclass +class ModelConfig: + """Configure a model.""" + + _target_: str + + +@dataclass +class VaeConfig(ModelConfig): + """Configure a variational encoder.""" + + _target_: str = field(default=get_fully_qualname(Vae), init=False) + + num_hidden: list[int] + num_latent: int + kl_weight: float + dropout_rate: float + use_cuda: bool = False + discrete_weights: list[float] = "${data.categorical_weights}" # type: ignore + continuous_weights: list[float] = "${data.continuous_weights}" # type: ignore + + +@dataclass +class VaeNormalConfig(VaeConfig): + """Configure a t-distribution variational autoencoder.""" + + _target_: str = field(default=get_fully_qualname(VaeNormal), init=False) + + +@dataclass +class VaeTConfig(VaeConfig): + """Configure a t-distribution variational autoencoder.""" + + _target_: str = field(default=get_fully_qualname(VaeT), init=False) + + +config_store.store( + group="task/model_config", + name="vae", + node=VaeConfig, +) + +config_store.store( + group="task/model_config", + name="vae_normal", + node=VaeNormalConfig, +) diff --git a/src/move/conf/optim.py b/src/move/conf/optim.py new file mode 100644 index 00000000..4890e003 --- /dev/null +++ b/src/move/conf/optim.py @@ -0,0 +1,140 @@ +__all__ = [ + "AdamConfig", + "AdamWConfig", + "ProdigyConfig", + "SgdConfig", + "ExponentialLrConfig", + "CosineAnnealingLrConfig", + "ReduceLrOnPlateauConfig", +] + +from dataclasses import dataclass, field + +from torch.optim import SGD, Adam, AdamW +from torch.optim.lr_scheduler import ( + CosineAnnealingLR, + ExponentialLR, + ReduceLROnPlateau, +) + +from move.conf.config_store import config_store +from move.core.qualname import get_fully_qualname +from move.training.optim.prodigy import Prodigy + + +@dataclass +class OptimizerConfig: + """Configure an optimizer algorithm.""" + + _target_: str + + +@dataclass +class AdamConfig(OptimizerConfig): + """Configure Adam algorithm.""" + + _target_: str = field(default=get_fully_qualname(Adam), init=False) + lr: float + weight_decay: float = 0.0 + + +@dataclass +class AdamWConfig(AdamConfig): + """Configure AdamW algorithm.""" + + _target_: str = field(default=get_fully_qualname(AdamW), init=False) + + +@dataclass +class ProdigyConfig(OptimizerConfig): + """Configure Prodigy algorithm.""" + + _target_: str = field(default=get_fully_qualname(Prodigy), init=False) + weight_decay: float = 0.0 + decouple: bool = True + d_coef: float = 1.0 + + +@dataclass +class SgdConfig(OptimizerConfig): + """Configure stochastic gradient descent algorithm.""" + + _target_: str = field(default=get_fully_qualname(SGD), init=False) + lr: float + momentum: float = 0.0 + weight_decay: float = 0.0 + + +@dataclass +class LrSchedulerConfig: + """Configure a learning rate scheduler.""" + + _target_: str + + +@dataclass +class CosineAnnealingLrConfig(LrSchedulerConfig): + """Configure a cosine annealing learning rate scheduler.""" + + _target_: str = field(default=get_fully_qualname(CosineAnnealingLR), init=False) + + T_max: int + eta_min: float = 0.0 + + +@dataclass +class ExponentialLrConfig(LrSchedulerConfig): + """Configure exponential decay learning rate scheduler.""" + + _target_: str = field(default=get_fully_qualname(ExponentialLR), init=False) + gamma: float + + +@dataclass +class ReduceLrOnPlateauConfig(LrSchedulerConfig): + """Configure learning rate scheduler set to decay when a metric stops + improving.""" + + _target_: str = field(default=get_fully_qualname(ReduceLROnPlateau), init=False) + + +config_store.store( + group="task/training_loop_config/optimizer_config", + name="optim_adam", + node=AdamConfig, +) +config_store.store( + group="task/training_loop_config/optimizer_config", + name="optim_adamw", + node=AdamWConfig, +) +config_store.store( + group="task/training_loop_config/optimizer_config", + name="optim_prodigy", + node=ProdigyConfig, +) +config_store.store( + group="task/training_loop_config/optimizer_config", + name="optim_sgd", + node=SgdConfig, +) +config_store.store( + group="task/training_loop_config/lr_scheduler_config", + name="optim_lr_scheduler", + node=LrSchedulerConfig, +) +config_store.store( + group="task/training_loop_config/lr_scheduler_config", + name="optim_cosine_annealing_lr", + node=CosineAnnealingLrConfig, +) +config_store.store( + group="task/training_loop_config/lr_scheduler_config", + name="optim_exponential_lr", + node=ExponentialLrConfig, +) +config_store.store( + group="task/training_loop_config/lr_scheduler_config", + name="optim_reduce_lr_on_plateau", + node=ReduceLrOnPlateauConfig, +) diff --git a/src/move/conf/resolvers.py b/src/move/conf/resolvers.py new file mode 100644 index 00000000..f58f124b --- /dev/null +++ b/src/move/conf/resolvers.py @@ -0,0 +1,22 @@ +__all__ = ["register_resolvers"] + +from omegaconf import OmegaConf + +from move.conf.tasks import InputConfig + + +def extract_weights(configs: list[InputConfig]) -> list[int]: + """Extract the weights from a list of input configs. If not specified, + weight defaults to 1.""" + return [1 if not hasattr(item, "weight") else item.weight for item in configs] + + +def extract_names(configs: list[InputConfig]) -> list[str]: + """Extract the names from a list of input configs.""" + return [item.name for item in configs] + + +def register_resolvers() -> None: + """Register custom resolvers.""" + OmegaConf.register_new_resolver("weights", extract_weights) + OmegaConf.register_new_resolver("names", extract_names) diff --git a/src/move/conf/schema.py b/src/move/conf/schema.py index c9dee984..a787591c 100644 --- a/src/move/conf/schema.py +++ b/src/move/conf/schema.py @@ -1,32 +1,23 @@ __all__ = [ "MOVEConfig", "EncodeDataConfig", - "AnalyzeLatentConfig", - "TuneModelReconstructionConfig", - "TuneModelStabilityConfig", - "IdentifyAssociationsConfig", - "IdentifyAssociationsBayesConfig", - "IdentifyAssociationsTTestConfig", ] from dataclasses import dataclass, field -from typing import Any, Optional +from typing import Any, Optional, Type -from hydra.core.config_store import ConfigStore -from omegaconf import MISSING, OmegaConf +from omegaconf import MISSING -from move.models.vae import VAE -from move.training.training_loop import training_loop - - -def get_fully_qualname(sth: Any) -> str: - return ".".join((sth.__module__, sth.__qualname__)) - - -@dataclass -class InputConfig: - name: str - weight: int = 1 +from move.conf.config_store import config_store +from move.conf.models import ModelConfig +from move.conf.resolvers import register_resolvers +from move.conf.tasks import InputConfig, PerturbationConfig, ReducerConfig +from move.conf.training import TrainingLoopConfig +from move.core.qualname import get_fully_qualname +from move.tasks.associations import Associations +from move.tasks.encode_data import EncodeData +from move.tasks.latent_space_analysis import LatentSpaceAnalysis +from move.tasks.tuning import TuneModel, TuneStability @dataclass @@ -41,199 +32,164 @@ class DataConfig: continuous_names: list[str] = MISSING categorical_weights: list[int] = MISSING continuous_weights: list[int] = MISSING - - -@dataclass -class ModelConfig: - _target_: str = MISSING - cuda: bool = MISSING - - -@dataclass -class VAEConfig(ModelConfig): - """Configuration for the VAE module.""" - - _target_: str = get_fully_qualname(VAE) - categorical_weights: list[int] = MISSING - continuous_weights: list[int] = MISSING - num_hidden: list[int] = MISSING - num_latent: int = MISSING - beta: float = MISSING - dropout: float = MISSING - cuda: bool = False - - -@dataclass -class TrainingLoopConfig: - _target_: str = get_fully_qualname(training_loop) - num_epochs: int = MISSING - lr: float = MISSING - kld_warmup_steps: list[int] = MISSING - batch_dilation_steps: list[int] = MISSING - early_stopping: bool = MISSING - patience: int = MISSING + train_frac: float = MISSING + test_frac: float = MISSING + valid_frac: float = MISSING @dataclass class TaskConfig: - """Configuration for a MOVE task. - - Attributes: - batch_size: Number of samples in a training batch. - model: Configuration for a model. - training_loop: Configuration for a training loop. - """ - - batch_size: Optional[int] - model: Optional[VAEConfig] - training_loop: Optional[TrainingLoopConfig] + """Configure a task.""" @dataclass class EncodeDataConfig(TaskConfig): - """Configuration for a data-encoding task.""" + """Configure data encoding.""" - batch_size = None - model = None - training_loop = None + _target_: str = field( + default=get_fully_qualname(EncodeData), init=False, repr=False + ) + raw_data_path: str = "${data.raw_data_path}" + interim_data_path: str = "${data.interim_data_path}" + sample_names_filename: str = "${data.sample_names}" + discrete_inputs: list[dict[str, Any]] = "${data.categorical_inputs}" # type: ignore + continuous_inputs: list[dict[str, Any]] = "${data.continuous_inputs}" # type: ignore + train_frac: float = "${data.train_frac}" # type: ignore + test_frac: float = "${data.test_frac}" # type: ignore + valid_frac: float = "${data.valid_frac}" # type: ignore @dataclass -class TuneModelConfig(TaskConfig): - """Configure the "tune model" task.""" +class MoveTaskConfig(TaskConfig): + """Configure generic MOVE task.""" - ... + discrete_dataset_names: list[str] = "${data.categorical_names}" # type: ignore + continuous_dataset_names: list[str] = "${data.continuous_names}" # type: ignore + batch_size: int = 16 + model_config: Optional[ModelConfig] = MISSING # "${model}" # type: ignore + training_loop_config: Optional[TrainingLoopConfig] = MISSING @dataclass -class TuneModelStabilityConfig(TuneModelConfig): - """Configure the "tune model" task.""" - - num_refits: int = MISSING +class LatentSpaceAnalysisConfig(MoveTaskConfig): + """Configure latent space analysis.""" + + defaults: list[Any] = field( + default_factory=lambda: [ + dict(reducer_config="tsne"), + dict(training_loop_config="schema_training_loop"), + ] + ) + + _target_: str = field( + default=get_fully_qualname(LatentSpaceAnalysis), init=False, repr=False + ) + interim_data_path: str = "${data.interim_data_path}" + results_path: str = "${data.results_path}" + compute_accuracy_metrics: bool = MISSING + compute_feature_importance: bool = MISSING + reducer_config: Optional[ReducerConfig] = MISSING + features_to_plot: Optional[list[str]] = MISSING @dataclass -class TuneModelReconstructionConfig(TuneModelConfig): - """Configure the "tune model" task.""" - - ... - - -@dataclass -class AnalyzeLatentConfig(TaskConfig): - """Configure the "analyze latents" task. - - Attributes: - feature_names: - Names of features to visualize.""" - - feature_names: list[str] = field(default_factory=list) - reducer: dict[str, Any] = MISSING - - -@dataclass -class IdentifyAssociationsConfig(TaskConfig): - """Configure the "identify associations" task. - - Attributes: - target_dataset: - Name of categorical dataset to perturb. - target_value: - The value to change to. It should be a category name. - num_refits: - Number of times to refit the model. - sig_threshold: - Threshold used to determine whether an association is significant. - In the t-test approach, this is called significance level (alpha). - In the probabilistc approach, significant associations are selected - if their FDR is below this threshold. - - This value should be within the range [0, 1]. - save_models: - Whether to save the weights of each refit. If weights are saved, - rerunning the task will load them instead of training. - """ - - target_dataset: str = MISSING - target_value: str = MISSING +class AssociationsConfig(MoveTaskConfig): + """Configure associations.""" + + defaults: list[Any] = field( + default_factory=lambda: [ + dict(perturbation_config="perturbation"), + dict(training_loop_config="schema_training_loop"), + ] + ) + + _target_: str = field( + default=get_fully_qualname(Associations), init=False, repr=False + ) + interim_data_path: str = "${data.interim_data_path}" + results_path: str = "${data.results_path}" + perturbation_config: PerturbationConfig = MISSING num_refits: int = MISSING sig_threshold: float = 0.05 - save_refits: bool = False + write_only_sig: bool = True @dataclass -class IdentifyAssociationsBayesConfig(IdentifyAssociationsConfig): - """Configure the probabilistic approach to identify associations.""" +class TuningConfig(MoveTaskConfig): + """Configure tuning.""" + + defaults: list[Any] = field( + default_factory=lambda: [ + dict(training_loop_config="schema_training_loop"), + ] + ) - ... + _target_: str = field(default=get_fully_qualname(TuneModel), init=False, repr=False) + interim_data_path: str = "${data.interim_data_path}" + results_path: str = "${data.results_path}" @dataclass -class IdentifyAssociationsTTestConfig(IdentifyAssociationsConfig): - """Configure the t-test approach to identify associations. +class StabilityTuningConfig(TuningConfig): + """Configure tuning for stability.""" - Args: - num_latent: - List of latent space dimensions to train. It should contain four - elements. - """ + defaults: list[Any] = field( + default_factory=lambda: [ + dict(training_loop_config="schema_training_loop"), + ] + ) - num_latent: list[int] = MISSING + _target_: str = field( + default=get_fully_qualname(TuneStability), init=False, repr=False + ) + num_refits: int = 10 @dataclass class MOVEConfig: + """Configure MOVE command line.""" + defaults: list[Any] = field(default_factory=lambda: [dict(data="base_data")]) data: DataConfig = MISSING task: TaskConfig = MISSING seed: Optional[int] = None -def extract_weights(configs: list[InputConfig]) -> list[int]: - """Extracts the weights from a list of input configs.""" - return [1 if not hasattr(item, "weight") else item.weight for item in configs] - - -def extract_names(configs: list[InputConfig]) -> list[str]: - """Extracts the weights from a list of input configs.""" - return [item.name for item in configs] - - # Store config schema -cs = ConfigStore.instance() -cs.store(name="config_schema", node=MOVEConfig) -cs.store( +config_store.store(name="config_schema", node=MOVEConfig) +config_store.store( group="task", name="encode_data", node=EncodeDataConfig, ) -cs.store( +config_store.store( group="task", - name="tune_model_reconstruction_schema", - node=TuneModelReconstructionConfig, + name="task_latent_space", + node=LatentSpaceAnalysisConfig, ) - -cs.store( - group="task", - name="tune_model_stability_schema", - node=TuneModelStabilityConfig, -) -cs.store( +config_store.store( group="task", - name="analyze_latent_schema", - node=AnalyzeLatentConfig, + name="task_associations", + node=AssociationsConfig, ) -cs.store( +config_store.store( group="task", - name="identify_associations_bayes_schema", - node=IdentifyAssociationsBayesConfig, + name="task_tuning", + node=TuningConfig, ) -cs.store( +config_store.store( group="task", - name="identify_associations_ttest_schema", - node=IdentifyAssociationsTTestConfig, + name="task_tune_stability", + node=StabilityTuningConfig, ) -# Register custom resolvers -OmegaConf.register_new_resolver("weights", extract_weights) -OmegaConf.register_new_resolver("names", extract_names) +register_resolvers() + +SUPPORTED_TASKS: tuple[Type, ...] = ( + AssociationsConfig, + EncodeDataConfig, + LatentSpaceAnalysisConfig, + TuningConfig, + StabilityTuningConfig, +) +"""List of tasks that can be ran from the command line.""" diff --git a/src/move/conf/tasks.py b/src/move/conf/tasks.py new file mode 100644 index 00000000..efa0e242 --- /dev/null +++ b/src/move/conf/tasks.py @@ -0,0 +1,82 @@ +__all__ = ["PcaConfig", "TsneConfig", "PerturbationConfig"] + +from dataclasses import dataclass, field +from typing import Optional, Union + +from omegaconf import MISSING +from sklearn.decomposition import PCA +from sklearn.manifold import TSNE + +from move.conf.config_store import config_store +from move.core.qualname import get_fully_qualname +from move.data.preprocessing import PreprocessingOpName + + +@dataclass +class InputConfig: + name: str + weight: int = 1 + preprocessing: PreprocessingOpName = "none" + + +@dataclass +class DiscreteInputConfig(InputConfig): + preprocessing: PreprocessingOpName = "one_hot_encode" + + +@dataclass +class ContinuousInputConfig(InputConfig): + preprocessing: PreprocessingOpName = "standardize" + + +@dataclass +class ReducerConfig: + _target_: str + n_components: int = 2 + + +@dataclass +class PcaConfig(ReducerConfig): + _target_: str = field(default=get_fully_qualname(PCA), init=False, repr=False) + + +@dataclass +class TsneConfig(ReducerConfig): + _target_: str = field(default=get_fully_qualname(TSNE), init=False, repr=False) + perplexity: float = 30.0 + + +try: + from umap import UMAP + + @dataclass + class UmapConfig(ReducerConfig): + _target_: str = field(default=get_fully_qualname(UMAP), init=False, repr=False) + n_neighbors: int = 15 + +except (ModuleNotFoundError, SystemError, TypeError): + pass + + +@dataclass +class PerturbationConfig: + target_dataset_name: str + target_feature_name: Optional[str] = None + target_value: Union[float, int, str] = MISSING + + +config_store.store( + group="task/reducer_config", + name="tsne", + node=TsneConfig, +) +config_store.store( + group="task/reducer_config", + name="pca", + node=PcaConfig, +) +config_store.store( + group="task/perturbation_config", + name="perturbation", + node=PerturbationConfig, +) diff --git a/src/move/conf/training.py b/src/move/conf/training.py new file mode 100644 index 00000000..1d73a837 --- /dev/null +++ b/src/move/conf/training.py @@ -0,0 +1,79 @@ +__all__ = [ + "DataLoaderConfig", + "TrainingDataLoaderConfig", + "TestDataLoaderConfig", + "TrainingLoopConfig", +] + +from dataclasses import dataclass, field +from typing import Optional + +from omegaconf import MISSING + +from move.conf.config_store import config_store +from move.conf.optim import LrSchedulerConfig, OptimizerConfig +from move.core.qualname import get_fully_qualname +from move.data.dataloader import MoveDataLoader +from move.training.loop import TrainingLoop + + +@dataclass +class DataLoaderConfig: + """Configure a data loader.""" + + _target_: str = field( + default=get_fully_qualname(MoveDataLoader), init=False, repr=False + ) + batch_size: int + shuffle: bool + drop_last: bool + + +@dataclass +class TrainingDataLoaderConfig(DataLoaderConfig): + """Configure a training data loader, which shuffles data and drops the last + batch.""" + + shuffle: bool = True + drop_last: bool = True + + +@dataclass +class TestDataLoaderConfig(DataLoaderConfig): + """Configure a test data loader, which does not shuffle data and does not + drop the last batch.""" + + shuffle: bool = False + drop_last: bool = False + + +@dataclass +class TrainingLoopConfig: + """Configure a training loop.""" + + _target_: str = field( + default=get_fully_qualname(TrainingLoop), init=False, repr=False + ) + + max_epochs: int = MISSING + + optimizer_config: OptimizerConfig = MISSING + lr_scheduler_config: Optional[LrSchedulerConfig] = None + + max_grad_norm: Optional[float] = None + + annealing_epochs: int = 0 + annealing_function: str = "linear" + annealing_schedule: str = "monotonic" + + prog_every_n_epoch: Optional[int] = 10 + + log_grad: bool = False + log_lr: bool = False + + +config_store.store( + group="task/training_loop_config", + name="schema_training_loop", + node=TrainingLoopConfig, +) diff --git a/src/move/core/exceptions.py b/src/move/core/exceptions.py new file mode 100644 index 00000000..99ef3a55 --- /dev/null +++ b/src/move/core/exceptions.py @@ -0,0 +1,25 @@ +__all__ = [] + + +FILE_EXISTS_WARNING = "File '{}' already exists. It will be overwritten." + + +class CudaIsNotAvailable(RuntimeError): + """CUDA is not available.""" + + def __init__(self) -> None: + super().__init__(self.__class__.__doc__) + + +class ShapeAndWeightMismatch(ValueError): + def __init__(self, num_shapes, num_weights) -> None: + message = ( + f"Mismatch between supplied number of dataset shapes ({num_shapes})" + f" and number of dataset weights ({num_weights})." + ) + super().__init__(message) + + +class UnsetProperty(ValueError): + def __init__(self, property_name: str) -> None: + super().__init__(f"{property_name} has not been set") diff --git a/src/move/core/logging.py b/src/move/core/logging.py index e665dfab..f7d06f44 100644 --- a/src/move/core/logging.py +++ b/src/move/core/logging.py @@ -1,3 +1,5 @@ +__all__ = ["get_logger"] + import logging from pathlib import Path diff --git a/src/move/core/qualname.py b/src/move/core/qualname.py new file mode 100644 index 00000000..3459f8af --- /dev/null +++ b/src/move/core/qualname.py @@ -0,0 +1,18 @@ +__all__ = ["get_fully_qualname"] + +from typing import Any + + +def get_fully_qualname(sth: Any) -> str: + """Get the fully-qualified name of a class or object instance. + + Args: + sth: Anything""" + if not isinstance(sth, type): + class_ = type(sth) + else: + class_ = sth + module_name = class_.__module__ + if module_name == "builtins": + return class_.__qualname__ + return f"{module_name}.{class_.__qualname__}" diff --git a/src/move/core/typing.py b/src/move/core/typing.py index 9fc16ced..edfc16a1 100644 --- a/src/move/core/typing.py +++ b/src/move/core/typing.py @@ -1,14 +1,50 @@ -__all__ = ["BoolArray", "FloatArray", "IntArray", "ObjectArray", "PathLike"] +__all__ = [ + "BoolArray", + "FloatArray", + "IntArray", + "ObjectArray", + "EncodedData", + "PathLike", +] import os -from typing import Union +from typing import Literal, TypedDict, Union import numpy as np +import torch from numpy.typing import NDArray +LoggingLevel = Union[ + int, + Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], +] PathLike = Union[str, os.PathLike] +Split = Literal["all", "train", "valid", "test"] + + +class IndicesDict(TypedDict): + train_indices: torch.Tensor + test_indices: torch.Tensor + valid_indices: torch.Tensor + + BoolArray = NDArray[np.bool_] -IntArray = NDArray[np.int_] +IntArray = Union[NDArray[np.int_], NDArray[np.uint]] FloatArray = NDArray[np.float_] ObjectArray = NDArray[np.object_] + + +class EncodedData(TypedDict): + """Dictionary containing a tensor, a name, and a list of feature names.""" + + dataset_name: str + tensor: torch.Tensor + feature_names: list[str] + + +class EncodedDiscreteData(EncodedData): + """Dictionary containing a tensor, a name, a list of feature names, and a + mapping.""" + + mapping: dict[str, int] diff --git a/src/move/data/__init__.py b/src/move/data/__init__.py index 4d43b968..6a119839 100644 --- a/src/move/data/__init__.py +++ b/src/move/data/__init__.py @@ -1,3 +1,5 @@ -__all__ = ["io", "preprocessing"] +__all__ = ["io", "preprocessing", "MoveDataset", "MoveDataLoader"] from move.data import io, preprocessing +from move.data.dataloader import MoveDataLoader +from move.data.dataset import MoveDataset diff --git a/src/move/data/dataloader.py b/src/move/data/dataloader.py new file mode 100644 index 00000000..3918626c --- /dev/null +++ b/src/move/data/dataloader.py @@ -0,0 +1,19 @@ +__all__ = ["MoveDataLoader"] + +from typing import Iterator + +import torch +from torch.utils.data import DataLoader + +from move.data.dataset import MoveDataset, NamedDataset + + +class MoveDataLoader(DataLoader): + dataset: MoveDataset + + @property + def datasets(self) -> list[NamedDataset]: + return list(self.dataset.datasets.values()) + + def __iter__(self) -> Iterator[tuple[torch.Tensor, ...]]: + return super().__iter__() diff --git a/src/move/data/dataset.py b/src/move/data/dataset.py new file mode 100644 index 00000000..5738f2e5 --- /dev/null +++ b/src/move/data/dataset.py @@ -0,0 +1,483 @@ +__all__ = ["DiscreteDataset", "ContinuousDataset", "MoveDataset"] + +import operator +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Literal, Optional, Type, TypeVar, Union, cast + +import pandas as pd +import torch +from torch import nn +from torch.utils.data import Dataset + +from move.core.exceptions import UnsetProperty +from move.core.typing import EncodedData, IndicesDict, Split +from move.tasks.encode_data import EncodeData + +DataType = Literal["continuous", "discrete"] +Index = Union[int, tuple[str, int], tuple[int, int]] +T = TypeVar("T", bound="NamedDataset") + + +class NamedDataset(Dataset, ABC): + """A dataset with a name and names for its features. + + Args: + tensor: Data + name: Name of the dataset + feature_names: Name of each feature contained in dataset""" + + def __init__( + self, + tensor: torch.Tensor, + dataset_name: str, + feature_names: Optional[list[str]] = None, + ) -> None: + self.tensor = tensor + self.name = dataset_name + if feature_names is not None: + self._validate_names(feature_names) + else: + feature_names = [f"{self.name}_{i}" for i in range(self.num_features)] + self.feature_names = feature_names + + def __add__(self, other: "NamedDataset") -> "MoveDataset": # type: ignore[override] + return MoveDataset(self, other) + + def __getitem__(self, index: int) -> torch.Tensor: + return self.tensor[index] + + def __len__(self): + return self.tensor.size(0) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}("{self.name}")' + + def __str__(self) -> str: + return self.name + + def _validate_names(self, feature_names: list[str]) -> None: + num_feature_names = len(feature_names) + if num_feature_names != self.num_feature_names: + raise ValueError( + f"Number of features ({self.num_features}) must match " + f"number of feature names {len(feature_names)}." + ) + if num_feature_names != len(set(feature_names)): + raise ValueError("Duplicate feature names") + + @property + @abstractmethod + def data_type(self) -> DataType: + raise NotImplementedError() + + @property + def mapping(self) -> dict[str, int]: + raise NotImplementedError() + + @property + def num_features(self) -> int: + return self.tensor.size(1) + + @property + def num_feature_names(self) -> int: + return self.num_features + + @classmethod + def load(cls: Type[T], path: Path, indices: Optional[torch.Tensor] = None) -> T: + """Load dataset. + + Args: + path: Path to encoded data + indices: Use to load only a subset of the data. Load all data if None. + """ + enc_data = cast(EncodedData, torch.load(path)) + if indices is not None: + enc_data["tensor"] = enc_data["tensor"][indices, :] + return cls(**enc_data) + + def select(self, feature_name: str) -> torch.Tensor: + """Slice and return values corresponding to a single feature.""" + slice_ = self.feature_slice(feature_name) + return self.tensor[:, slice_] + + def feature_slice(self, feature_name: str) -> slice: + """Return a slice object containing start and end position of a feature + in the dataset.""" + if feature_name not in self.feature_names: + raise KeyError(f"{feature_name} not found") + idx = self.feature_names.index(feature_name) + num_classes = getattr(self, "num_classes", 1) + start = idx * num_classes + stop = (idx + 1) * num_classes + return slice(start, stop) + + +class DiscreteDataset(NamedDataset): + """A dataset for discrete values. Discrete data is expected to be a one-hot + encoded tensor of three dimensions corresponding to number of samples, + features, and classes.""" + + def __init__( + self, + tensor: torch.Tensor, + dataset_name: str, + feature_names: Optional[list[str]] = None, + mapping: Optional[dict[str, int]] = None, + ): + if tensor.dim() != 3: + raise ValueError("Discrete datasets must have three dimensions.") + *_, dim0, dim1 = tensor.shape + self.original_shape = (dim0, dim1) + self._mapping = mapping + flattened_tensor = torch.flatten(tensor, start_dim=1) + super().__init__(flattened_tensor, dataset_name, feature_names) + + @property + def data_type(self) -> DataType: + return "discrete" + + @property + def mapping(self) -> dict[str, int]: + if self._mapping is not None: + return self._mapping + return {str(i): i for i in range(self.num_classes)} + + @property + def num_classes(self) -> int: + return self.original_shape[1] + + @property + def num_features(self) -> int: + return operator.mul(*self.original_shape) + + @property + def num_feature_names(self) -> int: + return self.original_shape[0] + + def one_hot_encode(self, value: Union[str, float, None]) -> torch.Tensor: + """One-hot encode a single value. + + Args: + value: category""" + if self.mapping is None: + raise ValueError("Unknown encoding") + encoded_value = torch.zeros(len(self.mapping)) + if not pd.isna(value): + code = self.mapping[str(value)] + encoded_value[code] = 1 + return encoded_value + + +class ContinuousDataset(NamedDataset): + """A dataset for continuous values.""" + + @property + def data_type(self) -> DataType: + return "continuous" + + +class MoveDataset(Dataset): + """Multi-omics dataset composed of one or more datasets (both categorical + and continuous). + + When indexed, returns a flat concatenation of the indexed elements of all + constituent datasets. + + A MOVE dataset can have a perturbation in one of its features. This + changes all the values of that feature. A perturbed dataset will return a + tuple when indexed. In the first position, it will contain the original + output as if there was no perturbation. The second element of the tuple + will correspond to the output affected by the perturbation. Lastly, the + third element is a boolean indicating whether the perturbation changed or + not the original value.""" + + def __init__(self, *args: NamedDataset) -> None: + if len(args) > 1 and not all( + len(args[0]) == len(dataset) for dataset in args[1:] + ): + raise ValueError("Size mismatch between datasets") + self._list = sorted(args, key=operator.attrgetter("data_type"), reverse=True) + self.datasets = {dataset.name: dataset for dataset in self._list} + if len(self.datasets) != len(args): + raise ValueError("One or more datasets have the same name") + self._perturbation = None + + def __getitem__(self, index: int) -> tuple[torch.Tensor, ...]: + indices = None + items = [[dataset[index] for dataset in self.datasets.values()]] + if self.perturbation is not None: + values: list[torch.Tensor] = [] + for dataset in self._list: + if dataset.name == self.perturbation.dataset_name: + left, _, right = torch.tensor_split( + dataset[index], self.perturbation.feature_indices + ) + values.extend((left, self.perturbation.mapped_value, right)) + indices = torch.all( + dataset[index][self.perturbation.feature_slice] + != self.perturbation.mapped_value + ) + else: + values.append(dataset[index]) + items.append(values) + out = tuple(torch.cat(item, dim=-1) for item in items) + if indices is not None: + out += (indices,) + return out + + def __len__(self) -> int: + return len(self._list[0]) + + def __repr__(self) -> str: + dataset_count = len(self._list) + s = "s" if dataset_count != 1 else "" + return f"{self.__class__.__name__}({dataset_count} dataset{s})" + + def _repr_html_(self) -> str: + rows = "" + for dataset in self._list: + num_classes = ( + str(dataset.num_classes) + if isinstance(dataset, DiscreteDataset) + else "N/A" + ) + rows += ( + "" + + "".join( + ( + dataset.name, + dataset.data_type, + f"{dataset.num_feature_names:,}", + num_classes, + ) + ) + + "" + ) + return f""" + + + + + + + + + + + + {rows} +
+ + MOVE dataset ({len(self):,} samples) +
datatype# features# classes
""" + + @property + def num_features(self) -> int: + return sum(dataset.num_features for dataset in self._list) + + @property + def num_discrete_features(self) -> int: + return sum( + dataset.num_features + for dataset in self._list + if isinstance(dataset, DiscreteDataset) + ) + + @property + def num_continuous_features(self) -> int: + return sum( + dataset.num_features + for dataset in self._list + if isinstance(dataset, ContinuousDataset) + ) + + @property + def discrete_datasets(self) -> list[DiscreteDataset]: + return [ + dataset for dataset in self._list if isinstance(dataset, DiscreteDataset) + ] + + @property + def continuous_datasets(self) -> list[ContinuousDataset]: + return [ + dataset for dataset in self._list if isinstance(dataset, ContinuousDataset) + ] + + @property + def discrete_shapes(self) -> list[tuple[int, int]]: + return [dataset.original_shape for dataset in self.discrete_datasets] + + @property + def continuous_shapes(self) -> list[int]: + return [dataset.num_features for dataset in self.continuous_datasets] + + @property + def dataset_names(self) -> list[str]: + return list(self.datasets.keys()) + + @property + def feature_names(self) -> list[str]: + feature_names = [] + for dataset in self._list: + if dataset.feature_names: + feature_names.extend(dataset.feature_names) + else: + raise ValueError("Missing feature names in one or more datasets") + return feature_names + + @property + def discrete_feature_names(self) -> list[str]: + feature_names = [] + for dataset in self.discrete_datasets: + feature_names.extend(dataset.feature_names) + return feature_names + + @property + def continuous_feature_names(self) -> list[str]: + feature_names = [] + for dataset in self.continuous_datasets: + feature_names.extend(dataset.feature_names) + return feature_names + + @property + def perturbation(self) -> Optional["Perturbation"]: + return self._perturbation + + @perturbation.setter + def perturbation(self, value: Optional["Perturbation"]) -> None: + if value is not None: + if value.dataset_name not in self.datasets: + raise KeyError( + f"Target dataset '{value.dataset_name}' not found in " + "MOVE dataset" + ) + dataset = self.datasets[value.dataset_name] + if value.feature_name not in dataset.feature_names: + raise KeyError( + f"Target feature {value.feature_name} not found in " + f"'{dataset}' dataset" + ) + if isinstance(dataset, DiscreteDataset): + value.mapped_value = dataset.one_hot_encode(value.target_value) + else: + value.mapped_value = torch.FloatTensor([value.target_value]) + value.feature_slice = dataset.feature_slice(value.feature_name) + self._perturbation = value + + @classmethod + def load( + cls, + path: Path, + discrete_dataset_names: list[str], + continuous_dataset_names: list[str], + split: Split = "all", + ) -> "MoveDataset": + """Load dataset. + + Args: + path: Path to encoded data + discrete_dataset_names: Names of discrete datasets + continuous_dataset_names: Names of continuous datasets + split: Subset of data to load ('train', 'test', 'valid', or 'all') + """ + if split != "all": + ind_dict: IndicesDict = torch.load(path / EncodeData.indices_filename) + indices = ind_dict.get(f"{split}_indices") + if indices is None: + raise KeyError(f"Unknown data subset: '{split}'") + else: + indices = None + datasets: list[NamedDataset] = [] + for dataset_name in discrete_dataset_names: + dataset = DiscreteDataset.load(path / f"{dataset_name}.pt", indices) + datasets.append(dataset) + for dataset_name in continuous_dataset_names: + dataset = ContinuousDataset.load(path / f"{dataset_name}.pt", indices) + datasets.append(dataset) + return cls(*datasets) + + def feature_names_of(self, dataset_name: str) -> list[str]: + """Return feature names of a constituent dataset.""" + return self.datasets[dataset_name].feature_names + + def find(self, feature_name) -> NamedDataset: + """Return constituent dataset which contains feature name.""" + for dataset in self._list: + if feature_name in dataset.feature_names: + return dataset + raise KeyError(f"{feature_name} not found in any dataset") + + def perturb( + self, dataset_name: str, feature_name: str, value: Union[str, float, None] + ) -> None: + """Add a perturbation to a feature in a constituent dataset. + + Args: + dataset_name: Name of dataset to perturb + feature_name: Name of feature in dataset to perturb + value: Value of perturbation + """ + self.perturbation = Perturbation(dataset_name, feature_name, value) + + def remove_perturbation(self) -> None: + """Remove perturbation from dataset.""" + self.perturbation = None + + def select(self, feature_name: str) -> torch.Tensor: + """Slice and return values corresponding to a single feature. If the + same feature name exists in more than one dataset, the first matching + feature will be returned.""" + for dataset in self._list: + if feature_name in dataset.feature_names: + return dataset.select(feature_name) + raise KeyError(f"{feature_name} not found in any dataset") + + +class Perturbation: + """Perturbation in a MOVE dataset. A perturbation will target a feature in + one of the MOVE datasets. All the values of that feature will be replaced + by the defined target value. For example, target 'metformin' feature in + 'drugs' dataset and change value from 0 to 1.""" + + def __init__( + self, + target_dataset_name: str, + target_feature_name: str, + target_value: Union[str, float, None], + ) -> None: + self.dataset_name = target_dataset_name + self.feature_name = target_feature_name + self.target_value = target_value + + def __repr__(self) -> str: + return str(self) + + def __str__(self) -> str: + return f'{self.__class__.__name__}("{self.dataset_name}/{self.feature_name}")' + + @property + def feature_indices(self) -> tuple[int, int]: + slice_ = self.feature_slice + return slice_.start, slice_.stop + + @property + def feature_slice(self) -> slice: + if (slice_ := getattr(self, "_feature_slice", None)) is None: + raise UnsetProperty("Target feature indices") + return slice_ + + @feature_slice.setter + def feature_slice(self, value: slice) -> None: + self._feature_slice = value + + @property + def mapped_value(self) -> torch.Tensor: + if (value := getattr(self, "_mapped_value", None)) is None: + raise UnsetProperty("Encoded target value") + return value + + @mapped_value.setter + def mapped_value(self, mapped_value: torch.Tensor) -> None: + self._mapped_value = mapped_value diff --git a/src/move/data/io.py b/src/move/data/io.py index 309fbd04..14200f94 100644 --- a/src/move/data/io.py +++ b/src/move/data/io.py @@ -9,6 +9,7 @@ ] import json +import re from pathlib import Path from typing import Optional @@ -142,10 +143,17 @@ def read_tsv( Tuple containing (1) feature names and (2) 2D matrix (samples x features) """ - data = pd.read_csv(path, index_col=0, sep="\t") + extension = Path(path).suffix + if extension == ".tsv": + sep = "\t" + elif extension == ".csv": + sep = "," + else: + raise ValueError(f"Unsupported file type: {extension}") + data = pd.read_csv(path, index_col=0, sep=sep, na_values=["./."]) if sample_names is not None: data.index = data.index.astype(str, False) - data = data.loc[sample_names] + data = data.reindex(sample_names) return data.columns.values, data.values @@ -162,3 +170,44 @@ def dump_mappings(path: PathLike, mappings: dict[str, dict[str, int]]) -> None: def dump_names(path: PathLike, names: np.ndarray) -> None: with open(path, "w", encoding="utf-8") as file: file.writelines([f"{name}\n" for name in names]) + + +RESERVED_NAMES = [ + "CON", + "PRN", + "AUX", + "NUL", + "COM1", + "COM2", + "COM3", + "COM4", + "COM5", + "COM6", + "COM7", + "COM8", + "COM9", + "LPT1", + "LPT2", + "LPT3", + "LPT4", + "LPT5", + "LPT6", + "LPT7", + "LPT8", + "LPT9", +] +MAX_FILENAME_LEN = 255 + + +def sanitize_filename(string: str) -> str: + """Sanitize a filename.""" + # Replace non-alpha characters with underscore + filename = re.sub(r'[<>:"/\\|?*\0-\x1f\x7f]', "_", string) + # Check if reserved Windows name + sep_idx = filename.rindex(".") + stem, suffix = filename[:sep_idx], filename[sep_idx:] + if stem.upper() in RESERVED_NAMES: + stem += "_" + # Truncate + filename = stem[: MAX_FILENAME_LEN - len(suffix)] + suffix + return filename diff --git a/src/move/data/preprocessing.py b/src/move/data/preprocessing.py index 1c76df57..878187d9 100644 --- a/src/move/data/preprocessing.py +++ b/src/move/data/preprocessing.py @@ -1,14 +1,24 @@ -__all__ = ["one_hot_encode", "one_hot_encode_single", "scale"] +__all__ = [ + "one_hot_encode", + "one_hot_encode_single", + "log_n_standardize", + "standardize", +] -from typing import Any, Optional +from typing import Any, Literal, Optional, Union, cast import numpy as np import pandas as pd +import torch from numpy.typing import ArrayLike -from sklearn.preprocessing import scale as standardize +from sklearn.preprocessing import StandardScaler from move.core.typing import BoolArray, FloatArray, IntArray +PreprocessingOpName = Literal[ + "one_hot_encode", "log_and_standardize", "standardize", "none" +] + def _category_name(value: Any) -> str: return value if isinstance(value, str) else str(int(value)) @@ -47,7 +57,7 @@ def one_hot_encode(x_: ArrayLike) -> tuple[IntArray, dict[str, int]]: return encoded_x, mapping -def one_hot_encode_single(mapping: dict[str, int], value: Optional[str]) -> IntArray: +def one_hot_encode_single(mapping: dict[str, int], value: Optional[str]) -> FloatArray: """One-hot encode a single value given an existing mapping. Args: @@ -64,18 +74,56 @@ def one_hot_encode_single(mapping: dict[str, int], value: Optional[str]) -> IntA return encoded_value -def scale(x: np.ndarray) -> tuple[FloatArray, BoolArray]: - """Center to mean and scale to unit variance. Convert NaN values to 0. +Indices = Optional[Union[IntArray, torch.Tensor]] + + +def log_n_standardize( + x: np.ndarray, train_indices: Optional[Indices] = None +) -> FloatArray: + """Apply base-2 logarithm. Then, center to mean and scale to unit variance. + Convert NaN values to 0. Args: x: 2D array with samples in its rows and features in its columns + train_indices: Array with indices corresponding to training data subset Returns: - Tuple containing (1) scaled output and (2) a 1D mask marking columns - (i.e., features) without zero variance + Tuple containing standardized output """ logx = np.log2(x + 1) - mask_1d = ~np.isclose(np.nanstd(logx, axis=0), 0.0) - scaled_x = standardize(logx[:, mask_1d], axis=0) - scaled_x[np.isnan(scaled_x)] = 0 - return scaled_x, mask_1d + return standardize(logx, train_indices) + + +def standardize(x: np.ndarray, train_indices: Optional[Indices] = None) -> FloatArray: + """Center to mean and scale to unit variance. Convert NaN values to 0. + + Args: + x: 2D array with samples in its rows and features in its columns + train_indices: Array with indices corresponding to training data subset + + Returns: + Tuple containing standardized output + """ + op = StandardScaler() + if train_indices is None: + scaled_x = op.fit_transform(x) + else: + # Standardize based only on training subset + train_x = np.take(x, train_indices, axis=0) + op.fit(train_x) + # Apply transformation to all data + scaled_x = op.transform(x) + # Fill NaNs with zeros + return fill(cast(FloatArray, scaled_x)) + + +def fill(x: np.ndarray) -> FloatArray: + """Replace NaNs with zeroes. + + Args: + x: Array + + Returns: + Array with no NaNs""" + x[np.isnan(x)] = 0 + return x diff --git a/src/move/data/reservoir.py b/src/move/data/reservoir.py new file mode 100644 index 00000000..ba5bee15 --- /dev/null +++ b/src/move/data/reservoir.py @@ -0,0 +1,169 @@ +__all__ = ["Reservoir", "PairedReservoir"] + +import math +import random + +import torch + +from move.core.exceptions import UnsetProperty + + +class ReservoirTest: + """Generate a random sample of k items from a stream of n items, where + n is either unknown or very large. + + This implementation uses the so-called Algorithm R.""" + + # based on: https://richardstartin.github.io/posts/reservoir-sampling + + def __init__(self, capacity: int): + self.capacity = capacity + self.reservoir = torch.empty(capacity) + self.idx = 0 + + def add(self, value: torch.Tensor) -> None: + if self.idx < self.capacity: + self.reservoir[self.idx] = value + else: + repl_idx = math.floor(random.random() * self.idx) + if repl_idx < self.capacity: + self.reservoir[repl_idx] = value + self.idx += 1 + + +class Reservoir: + """Generate a random sample (reservoir) from a stream of `n` items, where + `n` is either unknown or very large. + + This implementation uses the so-called Algorithm R. + + Args: + capacity: Number of items in the reservoir + """ + + def __init__(self, capacity: int): + self.capacity = capacity + self._reservoir = None + self.idx = 0 + self.total_samples = 0 + + def __call__(self) -> torch.Tensor: + return self.reservoir + + @property + def reservoir(self) -> torch.Tensor: + if self._reservoir is None: + raise UnsetProperty("Reservoir") + if self.total_samples < self.capacity: + return self._reservoir[: self.total_samples] + return self._reservoir + + def add(self, stream: torch.Tensor): + """Select a random sample from stream and add it to the reservoir. + + Args: + stream: tensor with samples in its first dimension + """ + + num_samples = stream.size(0) + self.total_samples += num_samples + + # Init reservoir + if self._reservoir is None: + self._reservoir = torch.empty((self.capacity, *stream.shape[1:])) + elif self._reservoir.shape[1:] != stream.shape[1:]: + raise ValueError(f"Shape mismatch between reservoir and stream") + + # Fill empty reservoir + if self.idx < self.capacity: + stop = min(self.capacity - self.idx, num_samples) + self._reservoir[self.idx : self.idx + stop] = stream[:stop] + self.idx += stop + if stop == num_samples: + return + + stream = stream[stop:] + num_samples -= stop + + # Sample and fill reservoir + i = torch.arange(self.idx, self.total_samples) + j = torch.floor(torch.rand(i.shape) * i).long() + replace = j < self.capacity + + for a, b in zip(j[replace], i[replace] - self.idx): + self._reservoir[a] = stream[b] + + self.idx += num_samples + + +class PairedReservoir(Reservoir): + """Genereate a paired set of random samples (reservoirs) from a paired set + of streams, whose size is either unknown or very large.""" + + def __init__(self, capacity: int): + super().__init__(capacity) + + def __call__(self) -> tuple[torch.Tensor, ...]: + return self.reservoir + + @property + def reservoir(self) -> tuple[torch.Tensor, ...]: + if self._reservoir is None: + raise UnsetProperty("Reservoir") + if self.total_samples < self.capacity: + return tuple([r[: self.total_samples] for r in self._reservoir]) + return self._reservoir + + def add(self, *streams: torch.Tensor): + """Select a random sample from a paired set of streams and add it to + the corresponding reservoir. + + Args: + streams: Tensors with samples in its first dimension + """ + streams_ = list(streams) + stream1 = streams_[0] + num_samples = stream1.size(0) + + if not all(stream.size(0) == num_samples for stream in streams_[1:]): + raise ValueError("Streams must have the same number of samples") + + self.total_samples += num_samples + + # Init reservoir + if self._reservoir is None: + self._reservoir = tuple( + [ + torch.empty((self.capacity, *stream1.shape[1:])) + for _ in range(len(streams_)) + ] + ) + elif len(streams_) != len(self._reservoir): + raise ValueError("Size mismatch between number of streams and reservoirs") + + # Fill empty reservoir + if self.idx < self.capacity: + stop = min(self.capacity - self.idx, num_samples) + + for i, reservoir in enumerate(self._reservoir): + stream = streams_[i] + reservoir[self.idx : self.idx + stop] = stream[:stop] + streams_[i] = stream[stop:] + + self.idx += stop + + if stop == num_samples: + return + + num_samples -= stop + + # Sample and fill reservoir + i = torch.arange(self.idx, self.total_samples) + j = torch.floor(torch.rand(i.shape) * i).long() + replace = j < self.capacity + + for m, n in zip(j[replace], i[replace] - self.idx): + for reservoir, stream in zip(self._reservoir, streams_): + reservoir[m] = stream[n] + + self.idx += num_samples diff --git a/src/move/data/splitting.py b/src/move/data/splitting.py new file mode 100644 index 00000000..f2c16dac --- /dev/null +++ b/src/move/data/splitting.py @@ -0,0 +1,33 @@ +__all__ = ["split_samples"] + +import torch + + +def split_samples( + num_samples: int, + train_frac: float = 0.9, + test_frac: float = 0.1, + valid_frac: float = 0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Randomly split samples into training, test, and validation sets. + + Args: + num_samples: Number of samples to split. + train_frac: Fraction of samples corresponding to training set. + test_frac: Fraction of samples corresponding to test set. + valid_frac: Fraction of samples corresponding to validation set. + + Returns: + Tuple containing indices corresponding to each subset. + """ + if (train_frac + test_frac + valid_frac) != 1.0: + raise ValueError("The sum of the subset fractions must be equal to one.") + + train_size = int(train_frac * num_samples) + test_size = int(test_frac * num_samples) + + perm = torch.randperm(num_samples) + + tup = tuple(torch.tensor_split(perm, (train_size, train_size + test_size))) + assert len(tup) == 3 + return tup diff --git a/src/move/data/writer.py b/src/move/data/writer.py new file mode 100644 index 00000000..69675cb9 --- /dev/null +++ b/src/move/data/writer.py @@ -0,0 +1,26 @@ +__all__ = ["CsvWriter"] + +import csv +from typing import Any, Mapping, Sequence, Union + +from numpy.typing import NDArray + + +class CsvWriter(csv.DictWriter): + """Create a CSV writer that maps dictionaries onto output rows and columns.""" + + def writecols(self, cols: Mapping[str, Union[Sequence[Any], NDArray]]): + """Write all elements in columns to the writer's file object.""" + colnames = set(self.fieldnames) + if not colnames.issubset(cols.keys()): + raise ValueError("Missing a column") + if self.extrasaction == "raise" and not colnames.issuperset(cols.keys()): + raise ValueError("Extra column found") + key1, *keys = self.fieldnames + if not all(len(cols[key]) == len(cols[key1]) for key in keys): + raise ValueError("Columns of varying length") + rows = ( + tuple(cols[key][i] for key in self.fieldnames) + for i in range(len(cols[key1])) + ) + self.writer.writerows(rows) diff --git a/src/move/models/__init__.py b/src/move/models/__init__.py index 01d2dadf..febbfba8 100644 --- a/src/move/models/__init__.py +++ b/src/move/models/__init__.py @@ -1,3 +1,6 @@ -__all__ = ["VAE"] +__all__ = ["Vae", "VaeNormal", "VaeT", "reload_vae"] -from move.models.vae import VAE +from move.models.base import reload_vae +from move.models.vae import Vae +from move.models.vae_distribution import VaeNormal +from move.models.vae_t import VaeT diff --git a/src/move/models/base.py b/src/move/models/base.py new file mode 100644 index 00000000..3603210f --- /dev/null +++ b/src/move/models/base.py @@ -0,0 +1,139 @@ +__all__ = ["BaseVae"] + +import inspect +from abc import ABC, abstractmethod +from importlib import import_module +from pathlib import Path +from typing import ( + Any, + Literal, + OrderedDict, + Type, + TypedDict, + TypeVar, + Union, + cast, + overload, +) + +import torch +from torch import nn + +from move.core.qualname import get_fully_qualname +from move.models.layers.chunk import SplitInput, SplitOutput + +T = TypeVar("T", bound="BaseVae") + + +class VaeOutput(TypedDict): + z_loc: torch.Tensor + z_scale: torch.Tensor + x_recon: torch.Tensor + + +class LossDict(TypedDict): + elbo: torch.Tensor + discrete_loss: torch.Tensor + continuous_loss: torch.Tensor + kl_div: torch.Tensor + kl_weight: float + + +class SerializedModel(TypedDict): + config: dict[str, Any] + state_dict: OrderedDict[str, torch.Tensor] + + +class BaseVae(nn.Module, ABC): + embedding_args: int = 2 + output_args: int = 1 + encoder: nn.Module + decoder: nn.Module + split_input: SplitInput + split_output: SplitOutput + num_latent: int + + def __call__(self, *args: Any, **kwds: Any) -> VaeOutput: + return super().__call__(*args, **kwds) + + @abstractmethod + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: ... + + @abstractmethod + def reparameterize( + self, loc: torch.Tensor, scale: torch.Tensor + ) -> torch.Tensor: ... + + @abstractmethod + def decode(self, z: torch.Tensor) -> tuple[torch.Tensor, ...]: ... + + @abstractmethod + def compute_loss( + self, batch: torch.Tensor, annealing_factor: float + ) -> LossDict: ... + + @torch.no_grad() + @abstractmethod + def project(self, batch: torch.Tensor) -> torch.Tensor: + """Create latent representation.""" + ... + + @overload + @abstractmethod + def reconstruct( + self, batch: torch.Tensor, as_one: Literal[False] + ) -> tuple[torch.Tensor, torch.Tensor]: ... + + @overload + @abstractmethod + def reconstruct( + self, batch: torch.Tensor, as_one: Literal[True] + ) -> torch.Tensor: ... + + @overload + @abstractmethod + def reconstruct( + self, batch: torch.Tensor, as_one: bool = False + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ... + + @torch.no_grad() + @abstractmethod + def reconstruct( + self, batch: torch.Tensor, as_one: bool = False + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Create reconstruction.""" + ... + + @classmethod + def reload(cls: Type[T], model_path: Path) -> T: + """Reload a model from its serialized config and state dict.""" + model_dict = cast(SerializedModel, torch.load(model_path)) + target = model_dict["config"].pop("_target_") + module_name, class_name = target.rsplit(".", 1) + module = import_module(module_name) + cls_: Type = getattr(module, class_name) + model = cls_(**model_dict["config"]) + model.load_state_dict(model_dict["state_dict"]) + return model + + def freeze(self) -> None: + """Freeze all parameters.""" + for param in self.parameters(): + param.requires_grad = False + self.eval() + + def save(self, model_path: Path) -> None: + """Save the serialized config and state dict of the model to disk.""" + argnames = inspect.signature(self.__class__).parameters.keys() + config = {argname: getattr(self, argname) for argname in argnames} + config["_target_"] = get_fully_qualname(self) + model = SerializedModel( + config=config, + state_dict=self.state_dict(), + ) + torch.save(model, model_path, pickle_protocol=4) + + +def reload_vae(model_path: Path) -> BaseVae: + """Alias of `BaseVae.reload`.""" + return BaseVae.reload(model_path) diff --git a/src/move/models/layers/__init__.py b/src/move/models/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/move/models/layers/chunk.py b/src/move/models/layers/chunk.py new file mode 100644 index 00000000..87b658a8 --- /dev/null +++ b/src/move/models/layers/chunk.py @@ -0,0 +1,241 @@ +__all__ = ["Chunk", "SplitInput", "SplitOutput"] + +import itertools +import operator +from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union, cast + +import torch +from torch import nn +from torch.distributions import Distribution, constraints + +if TYPE_CHECKING: + from move.data.dataset import MoveDataset + +DiscreteData = list[torch.Tensor] +ContinuousData = list[torch.Tensor] +ContinuousDistribution = list[dict[str, torch.Tensor]] +SplitData = Union[ + tuple[DiscreteData, ContinuousData], tuple[DiscreteData, ContinuousDistribution] +] + +SUPPORTED_DISTRIBUTIONS = [ + "Normal", + "StudentT", + # other distributions possible but untested +] + +T = TypeVar("T", bound="SplitOutput") + + +class Chunk(nn.Module): + """Split output into a given number of chunks. + + Args: + chunks: Number of chunks + dim: Dimension to apply operation in + """ + + chunks: int + dim: int + + def __init__(self, chunks: int, dim: int = -1) -> None: + super().__init__() + self.chunks = chunks + self.dim = dim + + def __call__(self, *args: Any, **kwds: Any) -> tuple[torch.Tensor, ...]: + return super().__call__(*args, **kwds) + + def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, ...]: + if self.chunks == 1: + return (input,) + return tuple(torch.chunk(input, self.chunks, self.dim)) + + +class SplitOutput(nn.Module): + """Split and re-shape concatenated datasets into their original shapes. + + Args: + discrete_dataset_shapes: + Number of features and classes of each discrete dataset. + continuous_dataset_shapes: + Number of features of each continuous dataset. + distribution_name_or_cls: + If given, continuous variables will be treated as distribution + arguments. For instance, for a normal distribution, the continuous + subset of the output will be split into mean and standard deviation. + This can be either the name of a class from the `torch.distributions` + module or a class that can be instantiated. + """ + + num_discrete_features: int + num_continuous_features: int + num_features: int + num_expected_features: int + num_distribution_args: int + discrete_dataset_shapes: list[tuple[int, int]] + discrete_dataset_shapes_1d: list[int] + continuous_dataset_shapes: list[int] + discrete_split_indices: list[int] + continuous_split_indices: list[int] + discrete_activation: Optional[nn.Module] + continuous_activation: Optional[nn.Module] + + def __init__( + self, + discrete_dataset_shapes: list[tuple[int, int]], + continuous_dataset_shapes: list[int], + distribution_name_or_cls: Union[str, Type[Distribution], None] = None, + discrete_activation_name: Optional[str] = None, + continuous_activation_name: Optional[str] = None, + ) -> None: + super().__init__() + + self.distribution: Optional[Type[Distribution]] = None + self.num_distribution_args = 1 + if distribution_name_or_cls is not None: + if isinstance(distribution_name_or_cls, str): + if distribution_name_or_cls not in SUPPORTED_DISTRIBUTIONS: + raise ValueError("Unsupported distribution") + self.distribution = getattr( + torch.distributions, distribution_name_or_cls, None + ) + else: + if not issubclass(distribution_name_or_cls, Distribution): + raise ValueError("Not a distribution") + self.distribution = distribution_name_or_cls + if self.distribution is not None: + self.num_distribution_args = len(self.distribution.arg_constraints) # type: ignore + + activation_funs = [] + for name in [discrete_activation_name, continuous_activation_name]: + if name is not None: + activation_fun = getattr(nn, name) + assert issubclass(activation_fun, nn.Module) + activation_funs.append(activation_fun()) + else: + activation_funs.append(None) + self.discrete_activation, self.continuous_activation = activation_funs + + self.discrete_dataset_shapes = discrete_dataset_shapes + self.continuous_dataset_shapes = continuous_dataset_shapes + + # Flatten discrete dataset shapes (normally, 2D) + (*self.discrete_dataset_shapes_1d,) = itertools.starmap( + operator.mul, self.discrete_dataset_shapes + ) + + # Count num. features + self.num_discrete_features = sum(self.discrete_dataset_shapes_1d) + self.num_continuous_features = sum(self.continuous_dataset_shapes) + self.num_features = self.num_discrete_features + self.num_continuous_features + self.num_expected_features = self.num_discrete_features + ( + self.num_continuous_features * self.num_distribution_args + ) + + # Compute split indices + if len(self.discrete_dataset_shapes_1d) > 0: + *self.discrete_split_indices, _ = itertools.accumulate( + self.discrete_dataset_shapes_1d + ) + else: + self.discrete_split_indices = [] + if len(self.continuous_dataset_shapes) > 0: + *self.continuous_split_indices, _ = itertools.accumulate( + [ + shape * self.num_distribution_args + for shape in self.continuous_dataset_shapes + ] + ) + else: + self.continuous_split_indices = [] + + def __call__(self, *args: Any, **kwds: Any) -> SplitData: + return super().__call__(*args, **kwds) + + @classmethod + def from_move_dataset(cls: Type[T], move_dataset: "MoveDataset") -> T: + """Create layer from shapes of discrete and continuous datasets contained in a + MOVE dataset.""" + discrete_dataset_shapes = [] + continuous_dataset_shapes = [] + for dataset in move_dataset.discrete_datasets: + discrete_dataset_shapes.append(dataset.original_shape) + for dataset in move_dataset.continuous_datasets: + continuous_dataset_shapes.append(dataset.num_features) + return cls(discrete_dataset_shapes, continuous_dataset_shapes) + + def forward(self, x: torch.Tensor) -> SplitData: + if x.dim() != 2: + raise ValueError("Input expected to be 2D.") + + if x.size(1) != self.num_expected_features: + raise ValueError( + f"Size mismatch: input ({x.size(1)}) is not equal to expected " + f"number of features ({self.num_expected_features})." + ) + + # Split into discrete/continuous sets + discrete_x, continuous_x = torch.tensor_split( + x, [self.num_discrete_features], dim=-1 + ) + if self.discrete_activation is not None: + discrete_x = self.discrete_activation(discrete_x) + if self.continuous_activation is not None: + continuous_x = self.continuous_activation(continuous_x) + + # Split and re-shape discrete set into 3D subsets + discrete_subsets_flat = torch.tensor_split( + discrete_x, self.discrete_split_indices, dim=-1 + ) + discrete_subsets = [ + torch.reshape(subset, (-1, *shape)) + for subset, shape in zip( + discrete_subsets_flat, self.discrete_dataset_shapes + ) + ] + + # Split continuous set into subsets + # If outputs are distributions, split into correct # of arguments + continous_subsets = list( + torch.tensor_split(continuous_x, self.continuous_split_indices, dim=-1) + ) + if self.num_distribution_args > 1: + if self.distribution is not None: + continous_distributions = [] + # For each distribution, split into correct # arguments + # Example: if modeling a Normal distribution, split into loc and scale + # Chunks are saved in dictionary + # If distribution arg constrained positive (e.g. scale), apply transform + for subset in continous_subsets: + chunks = torch.chunk(subset, self.num_distribution_args, dim=-1) + args = {} + for arg, (arg_name, arg_constraint) in zip( + chunks, self.distribution.arg_constraints.items() # type: ignore + ): + if arg_constraint is constraints.positive: + arg = torch.exp(arg * 0.5) + args[arg_name] = arg + continous_distributions.append(args) + return discrete_subsets, continous_distributions + + if isinstance(self, SplitInput): + return discrete_subsets, list(continous_subsets) + + return discrete_subsets, continous_subsets + + +class SplitInput(SplitOutput): + """Alias of `SplitOutput`.""" + + def __init__( + self, + discrete_dataset_shapes: list[tuple[int, int]], + continuous_dataset_shapes: list[int], + ) -> None: + super().__init__(discrete_dataset_shapes, continuous_dataset_shapes, None) + + def __call__(self, *args: Any, **kwds: Any) -> tuple[DiscreteData, ContinuousData]: + return cast( + tuple[DiscreteData, ContinuousData], super().__call__(*args, **kwds) + ) diff --git a/src/move/models/layers/encoder_decoder.py b/src/move/models/layers/encoder_decoder.py new file mode 100644 index 00000000..3981ae68 --- /dev/null +++ b/src/move/models/layers/encoder_decoder.py @@ -0,0 +1,85 @@ +__all__ = ["Encoder", "Decoder"] + +from typing import Any, Sequence + +import torch +from torch import nn + +from move.models.layers.chunk import Chunk + + +def build_network( + input_dim: int, + compress_dims: Sequence[int], + output_dim: int, + dropout_rate: float, + activation_fun_name: str, +) -> list[nn.Module]: + """Build a network that takes # input dimensions, (de)compresses them, and + returns # output dimensions using a sequence of linear, non-linear, dropout, + and batch normalization layers.""" + + activation_fun = getattr(nn, activation_fun_name) + assert issubclass(activation_fun, nn.Module) + + layers: list[nn.Module] = [] + layer_dims = [input_dim, *compress_dims] + + out_features = None + for in_features, out_features in zip(layer_dims, layer_dims[1:]): + layers.append(nn.Linear(in_features, out_features)) + layers.append(activation_fun()) + if dropout_rate > 0: + layers.append(nn.Dropout(dropout_rate)) + layers.append(nn.BatchNorm1d(out_features)) + + assert out_features is not None + layers.append(nn.Linear(out_features, output_dim)) + + return layers + + +class Encoder(nn.Sequential): + num_args: int + + def __init__( + self, + input_dim: int, + compress_dims: Sequence[int], + embedding_dim: int, + embedding_args: int = 2, + dropout_rate: float = 0.0, + activation_fun_name: str = "LeakyReLU", + ) -> None: + self.num_args = embedding_args + layers = build_network( + input_dim, + compress_dims, + embedding_dim * embedding_args, + dropout_rate, + activation_fun_name, + ) + layers.append(Chunk(embedding_args)) + super().__init__(*layers) + + def __call__(self, *args: Any, **kwds: Any) -> tuple[torch.Tensor, ...]: + return super().__call__(*args, **kwds) + + +class Decoder(Encoder): + def __init__( + self, + embedding_dim: int, + compress_dims: Sequence[int], + output_dim: int, + dropout_rate: float = 0.0, + activation_fun_name: str = "LeakyReLU", + ) -> None: + super().__init__( + embedding_dim, + compress_dims, + output_dim, + 1, + dropout_rate, + activation_fun_name, + ) diff --git a/src/move/models/vae.py b/src/move/models/vae.py index 08f639ae..5df2c58b 100644 --- a/src/move/models/vae.py +++ b/src/move/models/vae.py @@ -1,747 +1,198 @@ -__all__ = ["VAE"] +__all__ = ["Vae"] -import logging -from typing import Optional, Callable +import itertools +import operator +from typing import Optional, cast import torch -from torch import nn, optim -from torch.utils.data import DataLoader +import torch.optim +from torch import nn -from move.core.typing import FloatArray, IntArray +from move.core.exceptions import CudaIsNotAvailable, ShapeAndWeightMismatch +from move.models.base import BaseVae, LossDict, VaeOutput +from move.models.layers.chunk import ContinuousData, SplitInput, SplitOutput +from move.models.layers.encoder_decoder import Decoder, Encoder -logger = logging.getLogger("vae.py") - -class VAE(nn.Module): - """Variational autoencoder. - - Instantiate with: - continuous_shapes: shape of the different continuous datasets if any - categorical_shapes: shape of the different categorical datasets if any - num_hidden: List of n_neurons in the hidden layers [[200, 200]] - num_latent: Number of neurons in the latent layer [15] - beta: Multiply KLD by the inverse of this value [0.0001] - continuous_weights: list of weights for each continuous dataset - categorical_weights: list of weights for each categorical dataset - dropout: Probability of dropout on forward pass [0.2] - cuda: Use CUDA (GPU accelerated training) [False] - - Raises: - ValueError: Minimum 1 latent unit - ValueError: Beta must be greater than zero. - ValueError: Dropout must be between zero and one. - ValueError: Shapes of the input data must be provided. - ValueError: Number of continuous weights must be the same as number of - continuous datasets - ValueError: Number of categorical weights must be the same as number of - categorical datasets - """ +class Vae(BaseVae): + """Variational autoencoder""" def __init__( self, - categorical_shapes: Optional[list[tuple[int, ...]]] = None, - continuous_shapes: Optional[list[int]] = None, - categorical_weights: Optional[list[int]] = None, - continuous_weights: Optional[list[int]] = None, + discrete_shapes: list[tuple[int, int]], + continuous_shapes: list[int], + discrete_weights: Optional[list[float]] = None, + continuous_weights: Optional[list[float]] = None, num_hidden: list[int] = [200, 200], num_latent: int = 20, - beta: float = 0.01, - dropout: float = 0.2, - cuda: bool = False, - ): + kl_weight: float = 0.01, + dropout_rate: float = 0.2, + use_cuda: bool = False, + ) -> None: + super().__init__() + + # Validate and save arguments + if sum(num_hidden) <= 0: + raise ValueError( + "Number of hidden units in encoder/decoder must be non-negative." + ) + self.num_hidden = num_hidden if num_latent < 1: - raise ValueError(f"Minimum 1 latent unit. Input was {num_latent}.") - - if beta <= 0: - raise ValueError("Beta must be greater than zero.") - - if not (0 <= dropout < 1): - raise ValueError("Dropout must be between zero and one.") - - if continuous_shapes is None and categorical_shapes is None: - raise ValueError("Shapes of the input data must be provided.") + raise ValueError("Latent space size must be non-negative.") + self.num_latent = num_latent - self.input_size = 0 - if continuous_shapes is not None: - self.num_continuous = sum(continuous_shapes) - self.input_size += self.num_continuous - self.continuous_shapes = continuous_shapes + if kl_weight <= 0: + raise ValueError("KLD weight must be greater than zero.") + self.kl_weight = kl_weight + + if not (0 <= dropout_rate < 1): + raise ValueError("Dropout rate must be between [0, 1).") + self.dropout_rate = dropout_rate + + if discrete_shapes is None and continuous_shapes is None: + raise ValueError("Shapes of input datasets must be provided.") + + self.discrete_shapes = discrete_shapes + self.num_disc_features = 0 + self.discrete_weights = [1.0] * len(self.discrete_shapes) + if discrete_shapes is not None and len(discrete_shapes) > 0: + (*shapes_1d,) = itertools.starmap(operator.mul, discrete_shapes) + self.num_disc_features = sum(shapes_1d) + if discrete_weights is not None: + if len(discrete_shapes) != len(discrete_weights): + raise ShapeAndWeightMismatch( + len(discrete_shapes), len(discrete_weights) + ) + self.discrete_weights = discrete_weights + self.continuous_shapes = continuous_shapes + self.num_cont_features = 0 + self.continuous_weights = [1.0] * len(self.continuous_shapes) + if continuous_shapes is not None and len(continuous_shapes) > 0: + self.num_cont_features = sum(continuous_shapes) if continuous_weights is not None: - self.continuous_weights = continuous_weights if len(continuous_shapes) != len(continuous_weights): - raise ValueError( - "Number of continuous weights must be the same as" - " number of continuous datasets" + raise ShapeAndWeightMismatch( + len(continuous_shapes), len(continuous_weights) ) - else: - self.num_continuous = 0 - - if categorical_shapes is not None: - self.num_categorical = sum( - [int.__mul__(*shape) for shape in categorical_shapes] - ) - self.input_size += self.num_categorical - self.categorical_shapes = categorical_shapes - - if categorical_weights is not None: - self.categorical_weights = categorical_weights - if len(categorical_shapes) != len(categorical_weights): - raise ValueError( - "Number of categorical weights must be the same as" - " number of categorical datasets" - ) - else: - self.num_categorical = 0 - - super(VAE, self).__init__() - - # Initialize simple attributes - self.beta = beta - self.num_hidden = num_hidden - self.num_latent = num_latent - self.dropout = dropout - - self.device = torch.device("cuda" if cuda == True else "cpu") - - # Activation functions - self.relu = nn.LeakyReLU() - self.log_softmax = nn.LogSoftmax(dim=1) - self.dropoutlayer = nn.Dropout(p=self.dropout) - - # Initialize lists for holding hidden layers - self.encoderlayers = nn.ModuleList() - self.encodernorms = nn.ModuleList() - self.decoderlayers = nn.ModuleList() - self.decodernorms = nn.ModuleList() - - ### Layers - # Hidden layers - for nin, nout in zip([self.input_size] + self.num_hidden, self.num_hidden): - self.encoderlayers.append(nn.Linear(nin, nout)) - self.encodernorms.append(nn.BatchNorm1d(nout)) - - # Latent layers - self.mu = nn.Linear(self.num_hidden[-1], self.num_latent) # mu layer - self.var = nn.Linear(self.num_hidden[-1], self.num_latent) # logvariance layer - - # Decoding layers - for nin, nout in zip( - [self.num_latent] + self.num_hidden[::-1], self.num_hidden[::-1] - ): - self.decoderlayers.append(nn.Linear(nin, nout)) - self.decodernorms.append(nn.BatchNorm1d(nout)) - - # Reconstruction - output layers - self.out = nn.Linear(self.num_hidden[0], self.input_size) # to output - - def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Encodes the data in the data loader and returns the encoded matrix. - - Args: - x: input data - - Returns: - A tuple containing: - mean latent vector - log-variance latent vector - """ - # Hidden layers - for encoderlayer, encodernorm in zip(self.encoderlayers, self.encodernorms): - x = encoderlayer(x) - x = self.relu(x) - x = self.dropoutlayer(x) - x = encodernorm(x) - - return self.mu(x), self.var(x) - - def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: - """ - Performs reparametrization trick - - Args: - mu: mean latent vector - logvar: log-variance latent vector - - Returns: - sample from latent space distribution - """ - std = torch.exp(0.5 * logvar) - eps = torch.randn_like(std) - - return eps.mul(std).add_(mu) - - def decompose_categorical(self, reconstruction: torch.Tensor) -> list[torch.Tensor]: - """ - Returns list of final reconstructions (after applying - log-softmax to the outputs of decoder) of each categorical class - - Args: - reconstruction: results of final layer of decoder - - Returns: - final reconstructions of each categorical class - """ - cat_tmp = reconstruction.narrow(1, 0, self.num_categorical) - - # handle soft max for each categorical dataset - cat_out = [] - pos = 0 - for cat_shape in self.categorical_shapes: - cat_dataset = cat_tmp[:, pos : (cat_shape[0] * cat_shape[1] + pos)] - - cat_out_tmp = cat_dataset.view( - cat_dataset.shape[0], cat_shape[0], cat_shape[1] - ) - cat_out_tmp = cat_out_tmp.transpose(1, 2) - cat_out_tmp = self.log_softmax(cat_out_tmp) - - cat_out.append(cat_out_tmp) - pos += cat_shape[0] * cat_shape[1] - - return cat_out - - def decode( - self, x: torch.Tensor - ) -> tuple[Optional[list[torch.Tensor]], Optional[torch.Tensor]]: - """ - Decode to the input space from the latent space - - Args: - x: sample from latent space distribution - - Returns: - A tuple containing: - cat_out: - list of reconstructions of every categorical data class - con_out: - reconstruction of continuous data - """ - for decoderlayer, decodernorm in zip(self.decoderlayers, self.decodernorms): - x = decoderlayer(x) - x = self.relu(x) - x = self.dropoutlayer(x) - x = decodernorm(x) - - reconstruction = self.out(x) - - # Decompose reconstruction to categorical and continuous variables - # if both types are in the input - cat_out, con_out = None, None - if self.num_categorical > 0: - cat_out = self.decompose_categorical(reconstruction) - if self.num_continuous > 0: - con_out = reconstruction.narrow( - 1, self.num_categorical, self.num_continuous - ) - - return cat_out, con_out - - def forward( - self, tensor: torch.Tensor - ) -> tuple[list[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Forward propagate through the VAE network - - Args: - tensor (torch.Tensor): input data - - Returns: - (tuple): a tuple containing: - cat_out (list): list of reconstructions of every categorical - data class - con_out (torch.Tensor): reconstructions of continuous data - mu (torch.Tensor): mean latent vector - logvar (torch.Tensor): mean log-variance vector - """ - mu, logvar = self.encode(tensor) - z = self.reparameterize(mu, logvar) - cat_out, con_out = self.decode(z) - - return cat_out, con_out, mu, logvar - - def calculate_cat_error( - self, - cat_in: torch.Tensor, - cat_out: list[torch.Tensor], - ) -> torch.Tensor: - """ - Calculates errors (cross-entropy) for categorical data reconstructions - - Args: - cat_in: - input categorical data - cat_out: - list of reconstructions of every categorical data class - - Returns: - torch.Tensor: - Errors (cross-entropy) for categorical data reconstructions - """ - batch_size = cat_in.shape[0] - - # calcualte target values for all cat datasets - count = 0 - cat_errors = [] - pos = 0 - for cat_shape in self.categorical_shapes: - cat_dataset = cat_in[:, pos : (cat_shape[0] * cat_shape[1] + pos)] - - cat_dataset = cat_dataset.view(cat_in.shape[0], cat_shape[0], cat_shape[1]) - cat_target = cat_dataset - cat_target = cat_target.argmax(2) - cat_target[cat_dataset.sum(dim=2) == 0] = -1 - cat_target = cat_target.to(self.device) - - # Cross entropy loss for categroical - loss = nn.NLLLoss(reduction="sum", ignore_index=-1) - cat_errors.append( - loss(cat_out[count], cat_target) / (batch_size * cat_shape[0]) - ) - count += 1 - pos += cat_shape[0] * cat_shape[1] - - cat_errors = torch.stack(cat_errors) - return cat_errors - - def calculate_con_error( - self, con_in: torch.Tensor, con_out: torch.Tensor, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] - ) -> torch.Tensor: - """ - Calculates errors (MSE) for continuous data reconstructions - - Args: - con_in: input continuous data - con_out: reconstructions of continuous data - loss: loss function - - Returns: - MSE loss - """ - batch_size = con_in.shape[0] - total_shape = 0 - con_errors_list: list[torch.Tensor] = [] - for s in self.continuous_shapes: - c_in = con_in[:, total_shape : (s + total_shape - 1)] - c_re = con_out[:, total_shape : (s + total_shape - 1)] - error = loss(c_re, c_in) / batch_size - con_errors_list.append(error) - total_shape += s - - con_errors = torch.stack(con_errors_list) - con_errors = con_errors / torch.Tensor(self.continuous_shapes).to(self.device) - MSE = torch.sum( - con_errors * torch.Tensor(self.continuous_weights).to(self.device) - ) - return MSE - - # Reconstruction + KL divergence losses summed over all elements and batch - def loss_function( - self, - cat_in: torch.Tensor, - cat_out: list[torch.Tensor], - con_in: torch.Tensor, - con_out: torch.Tensor, - mu: torch.Tensor, - logvar: torch.Tensor, - kld_w: float, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Calculates the loss for data reconstructions - - Args: - cat_in: input categorical data - cat_out: list of reconstructions of every categorical data class - con_in: input continuous data - con_out: reconstructions of continuous data - mu: mean latent vector - logvar: mean log-variance vector - kld_w: kld weight - - Returns: - (tuple): a tuple containing: - total loss on train set during the training of the epoch - BCE loss on train set during the training of the epoch - SSE loss on train set during the training of the epoch - KLD loss on train set during the training of the epoch - """ - - MSE = 0 - CE = 0 - # calculate loss for catecorical data if in the input - if cat_out is not None: - cat_errors = self.calculate_cat_error(cat_in, cat_out) - if self.categorical_weights is not None: - CE = torch.sum( - cat_errors * torch.Tensor(self.categorical_weights).to(self.device) - ) - else: - CE = torch.sum(cat_errors) / len(cat_errors) - - # calculate loss for continuous data if in the input - if con_out is not None: - batch_size = con_in.shape[0] - # Mean square error loss for continauous - loss = nn.MSELoss(reduction="sum") - # set missing data to 0 to remove any loss these would provide - con_out[con_in == 0] = 0 - - # include different weights for each omics dataset - if self.continuous_weights is not None: - MSE = self.calculate_con_error(con_in, con_out, loss) - else: - MSE = loss(con_out, con_in) / (batch_size * self.num_continuous) - - # see Appendix B from VAE paper: - # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 - # https://arxiv.org/abs/1312.6114 - # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) - KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / (batch_size) - - KLD_weight = self.beta * kld_w - loss = CE + MSE + KLD * KLD_weight - - return loss, CE, MSE, KLD * KLD_weight - - def encoding( - self, - train_loader: DataLoader, - epoch: int, - lrate: float, - kld_w: float, - ) -> tuple[float, float, float, float]: - """ - One iteration of VAE - - Args: - train_loader: Dataloader with train dataset - epoch: the epoch - lrate: learning rate for the model - kld_w: float of KLD weight - - Returns: - (tuple): a tuple containing: - total loss on train set during the training of the epoch - BCE loss on train set during the training of the epoch - SSE loss on train set during the training of the epoch - KLD loss on train set during the training of the epoch - """ - self.train() - optimizer = optim.Adam(self.parameters(), lr=lrate) - - epoch_loss = 0 - epoch_kldloss = 0 - epoch_sseloss = 0 - epoch_bceloss = 0 - - for _, (cat, con) in enumerate(train_loader): - # Move input to GPU if requested - cat = cat.to(self.device) - con = con.to(self.device) - - if self.num_categorical > 0 and self.num_continuous > 0: - tensor = torch.cat((cat, con), 1) - elif self.num_categorical > 0: - tensor = cat - elif self.num_continuous > 0: - tensor = con - else: - assert False, "Must have at least 1 categorial or 1 continuous feature" - - optimizer.zero_grad() - - cat_out, con_out, mu, logvar = self(tensor) - - loss, bce, sse, kld = self.loss_function( - cat, cat_out, con, con_out, mu, logvar, kld_w - ) - loss.backward() - - epoch_loss += loss.data.item() - epoch_kldloss += kld.data.item() - - if self.num_continuous > 0: - epoch_sseloss += sse.data.item() - - if self.num_categorical > 0: - epoch_bceloss += bce.data.item() + self.continuous_weights = continuous_weights - optimizer.step() + self.in_features = self.num_disc_features + self.num_cont_features - logger.info( - "\tEpoch: {}\tLoss: {:.6f}\tCE: {:.7f}\tSSE: {:.6f}\t" - "KLD: {:.4f}\tBatchsize: {}".format( - epoch, - epoch_loss / len(train_loader), - epoch_bceloss / len(train_loader), - epoch_sseloss / len(train_loader), - epoch_kldloss / len(train_loader), - train_loader.batch_size, - ) + self.encoder = Encoder( + self.in_features, + num_hidden, + num_latent, + embedding_args=self.embedding_args, + dropout_rate=dropout_rate, ) - return ( - epoch_loss / len(train_loader), - epoch_bceloss / len(train_loader), - epoch_sseloss / len(train_loader), - epoch_kldloss / len(train_loader), + self.decoder = Decoder( + num_latent, + num_hidden[::-1], + self.in_features, + dropout_rate=dropout_rate, ) + self.log_softmax = nn.LogSoftmax(dim=-1) + self.split_input = SplitInput(self.discrete_shapes, self.continuous_shapes) + self.split_output = SplitOutput(self.discrete_shapes, self.continuous_shapes) - def make_cat_recon_out(self, length: int) -> tuple[torch.Tensor, torch.Tensor, int]: - """ - Initiate empty tensors for categorical data - - Args: - length: number of samples - - Returns: - (tuple): a tuple containing: - cat_class: empty tensor for input categorical data - cat_recon: empty tensor for reconstructed categorical data - cat_total_shape: number of features of linearized one hot - categorical data - """ - cat_total_shape = 0 - for cat_shape in self.categorical_shapes: - cat_total_shape += cat_shape[0] - - cat_class = torch.empty((length, cat_total_shape)).int() - cat_recon = torch.empty((length, cat_total_shape)).int() - return cat_class, cat_recon, cat_total_shape - - def get_cat_recon( - self, batch: int, cat_total_shape: int, cat: torch.Tensor, cat_out: torch.Tensor - ) -> tuple[IntArray, IntArray]: - """ - Generates reconstruction data of categorical data class - - Args: - batch: number of samples in the batch - cat_total_shape: number of features of linearized one hot - categorical data - cat: input categorical data - cat_out: reconstructed categorical data - - Returns: - (tuple): a tuple containing: - cat_out_class: reconstructed categorical data - cat_target: input categorical data - """ - count = 0 - cat_out_class = torch.empty((batch, cat_total_shape)).int() - cat_target = torch.empty((batch, cat_total_shape)).int() - pos = 0 - shape_1 = 0 - for cat_shape in self.categorical_shapes: - # Get input categorical data - cat_in_tmp = cat[:, pos : (cat_shape[0] * cat_shape[1] + pos)] - cat_in_tmp = cat_in_tmp.view(cat.shape[0], cat_shape[0], cat_shape[1]) - - # Calculate target values for input - cat_target_tmp = cat_in_tmp - cat_target_tmp = torch.argmax(cat_target_tmp.detach(), dim=2) - cat_target_tmp[cat_in_tmp.sum(dim=2) == 0] = -1 - cat_target[ - :, shape_1 : (cat_shape[0] + shape_1) - ] = cat_target_tmp # .numpy() - - # Get reconstructed categorical data - cat_out_tmp = cat_out[count] - cat_out_tmp = cat_out_tmp.transpose(1, 2) - cat_out_class[:, shape_1 : (cat_shape[0] + shape_1)] = torch.argmax( - cat_out_tmp, dim=2 - ) # .numpy() + self.nll_loss = nn.NLLLoss(reduction="sum", ignore_index=-1) + self.mse_loss = nn.MSELoss(reduction="sum") - # make counts for next dataset - pos += cat_shape[0] * cat_shape[1] - shape_1 += cat_shape[0] - count += 1 + if use_cuda and not torch.cuda.is_available(): + raise CudaIsNotAvailable() + self.use_cuda = use_cuda - cat_target = cat_target.numpy() - cat_out_class = cat_out_class.numpy() + device = torch.device("cuda" if use_cuda else "cpu") + self.to(device) - return cat_out_class, cat_target - - def _validate_batch(self, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - """ - Returns the batch of categorical and continuous data if they are not - None - - Args: - batch: batches of categorical and continuous data - - Returns: - a formed batch - """ - cat, con = batch - cat = cat.to(self.device) - con = con.to(self.device) - - if self.num_categorical == 0: - return con - elif self.num_continuous == 0: - return cat - return torch.cat((cat, con), dim=1) - - @torch.no_grad() - def project(self, dataloader: DataLoader) -> FloatArray: - """Generates an embedding of the data contained in the DataLoader. - - Args: - dataloader: A DataLoader with categorical or continuous data - - Returns: - FloatArray: Embedding - """ - self.eval() - embedding = [] - for batch in dataloader: - batch = self._validate_batch(batch) - *_, mu, _ = self(batch) - embedding.append(mu) - embedding = torch.cat(embedding, dim=0).cpu().numpy() - return embedding - - @torch.no_grad() - def reconstruct( - self, dataloader: DataLoader - ) -> tuple[list[FloatArray], FloatArray]: - """ - Generates a reconstruction of the data contained in the DataLoader. - - Args: - dataloader: A DataLoader with categorical or continuous data - - Returns: - A list of categorical reconstructions and the continuous - reconstruction - """ - self.eval() - cat_recons = [[] for _ in range(len(self.categorical_shapes))] - con_recons = [] - for batch in dataloader: - batch = self._validate_batch(batch) - cat_recon, con_recon, *_ = self(batch) - if cat_recon is not None: - for i, cat in enumerate(cat_recon): - cat_recons[i].append(torch.argmax(cat, dim=1)) - if con_recon is not None: - con_recons.append(con_recon) - if cat_recons: - cat_recons = [torch.cat(cats, dim=0).cpu().numpy() for cats in cat_recons] - if con_recons: - con_recons = torch.cat(con_recons, dim=0).cpu().numpy() - return cat_recons, con_recons - - @torch.no_grad() - def latent( - self, dataloader: DataLoader, kld_weight: float - ) -> tuple[FloatArray, FloatArray, IntArray, IntArray, FloatArray, float, float]: - """ - Iterate through validation or test dataset - - Args: - dataloader: Dataloader with test dataset - kld_weight: KLD weight - - Returns: - (tuple): a tuple containing: - latent: array of VAE latent space mean vectors values - latent_var: array of VAE latent space logvar vectors values - cat_recon: reconstructed categorical data - cat_class: input categorical data - con_recon: reconstructions of continuous data - test_loss: total loss on test set - test_likelihood: total likelihood on test set - """ - - self.eval() - test_loss = 0 - test_likelihood = 0 - - num_samples = dataloader.dataset.num_samples - - latent = torch.empty((num_samples, self.num_latent)) - latent_var = torch.empty((num_samples, self.num_latent)) - - # reconstructed output - if self.num_categorical > 0: - cat_class, cat_recon, cat_total_shape = self.make_cat_recon_out(num_samples) - else: - cat_class = None - cat_recon = None - - con_recon = ( - None - if self.num_continuous == 0 - else torch.empty((num_samples, self.num_continuous)) + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + z_loc, z_logvar, *_ = self.encoder(x) + return z_loc, torch.exp(z_logvar * 0.5) + + def reparameterize(self, loc: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + eps = torch.randn_like(scale) + return eps.mul(scale).add_(loc) + + def decode(self, z: torch.Tensor) -> tuple[torch.Tensor, ...]: + return self.decoder(z) + + def project(self, batch: torch.Tensor) -> torch.Tensor: + return self.encode(batch)[0] + + def reconstruct(self, batch: torch.Tensor, as_one: bool = False): + out = self(batch)["x_recon"] + if as_one: + return out + out_disc, out_cont = self.split_output(out) + recon_disc = torch.cat( + [logits.flatten(start_dim=1) for logits in out_disc], dim=1 ) + recon_cont = torch.cat(cast(ContinuousData, out_cont), dim=1) + return recon_disc, recon_cont + + def forward(self, x: torch.Tensor) -> VaeOutput: + z_loc, z_scale = self.encode(x) + z = self.reparameterize(z_loc, z_scale) + x_recon, *_ = self.decode(z) + return { + "z_loc": z_loc, + "z_scale": z_scale, + "x_recon": x_recon, + } + + def compute_loss(self, batch: torch.Tensor, annealing_factor: float) -> LossDict: + # Split concatenated input + batch_disc, batch_cont = self.split_input(batch) + # Split concatenated output + out = self(batch) + out_disc, out_cont = self.split_output(out["x_recon"]) + + # Compute discrete dataset losses + disc_losses = [] + for disc_input, disc_logits, disc_wt in zip( + batch_disc, out_disc, self.discrete_weights + ): + disc_recon = self.log_softmax(disc_logits).transpose(1, 2) + disc_cats = disc_input.argmax(dim=-1) + na_mask = disc_input.sum(dim=-1) == 0 + disc_cats[na_mask] = -1.0 + multiplier = disc_wt / operator.mul(*disc_input.shape[:-1]) + loss = self.nll_loss(disc_recon, disc_cats) * multiplier + disc_losses.append(loss) + if len(disc_losses) > 0: + disc_loss = torch.stack(disc_losses).sum() + else: + disc_loss = torch.zeros(()) - row = 0 - for (cat, con) in dataloader: - cat = cat.to(self.device) - con = con.to(self.device) - - # get dataset - if self.num_categorical > 0 and self.num_continuous > 0: - tensor = torch.cat((cat, con), 1) - elif self.num_categorical > 0: - tensor = cat - elif self.num_continuous > 0: - tensor = con - else: - assert False, "Must have at least 1 categorial or 1 continuous feature" - - # Evaluate - cat_out, con_out, mu, logvar = self(tensor) - - mu = mu.to(self.device) - logvar = logvar.to(self.device) - batch = len(mu) - - loss, bce, sse, _ = self.loss_function( - cat, cat_out, con, con_out, mu, logvar, kld_weight - ) - test_likelihood += bce + sse - test_loss += loss.data.item() - - if self.num_categorical > 0: - cat_out_class, cat_target = self.get_cat_recon( - batch, cat_total_shape, cat, cat_out - ) - cat_recon[row : row + len(cat_out_class)] = torch.Tensor(cat_out_class) - cat_class[row : row + len(cat_target)] = torch.Tensor(cat_target) - - if self.num_continuous > 0: - con_recon[row : row + len(con_out)] = con_out - - latent_var[row : row + len(logvar)] = logvar - latent[row : row + len(mu)] = mu - row += len(mu) - - test_loss /= len(dataloader) - logger.info("====> Test set loss: {:.4f}".format(test_loss)) - - latent = latent.numpy() - latent_var = latent_var.numpy() - cat_recon = cat_recon.numpy() - cat_class = cat_class.numpy() - con_recon = con_recon.numpy() + # Compute continuous dataset losses + cont_losses = [] + for cont_input, cont_recon, cont_wt in zip( + batch_cont, out_cont, self.continuous_weights + ): + na_mask = (cont_input == 0).logical_not().float() + multiplier = cont_wt / operator.mul(*cont_input.shape) + loss = self.mse_loss(na_mask * cont_recon, cont_input) * multiplier + cont_losses.append(loss) + if len(cont_losses) > 0: + cont_loss = torch.stack(cont_losses).sum() + else: + cont_loss = torch.zeros(()) - assert row == num_samples - return ( - latent, - latent_var, - cat_recon, - cat_class, - con_recon, - test_loss, - test_likelihood, + # Compute KL divergence + z_loc, z_var = out["z_loc"], out["z_scale"] ** 2 + kl_div = ( + -0.5 * torch.sum(1 + z_var.log() - z_loc.pow(2) - z_var) / batch.size(0) ) - def __repr__(self) -> str: - return ( - f"VAE ({self.input_size} ⇄ {' ⇄ '.join(map(str, self.num_hidden))}" - f" ⇄ {self.num_latent})" - ) + # Compute ELBO + kl_weight = annealing_factor * self.kl_weight + elbo = disc_loss + cont_loss + kl_div * kl_weight + return { + "elbo": elbo, + "discrete_loss": disc_loss, + "continuous_loss": cont_loss, + "kl_div": kl_div, + "kl_weight": kl_weight, + } diff --git a/src/move/models/vae_distribution.py b/src/move/models/vae_distribution.py new file mode 100644 index 00000000..94f49e02 --- /dev/null +++ b/src/move/models/vae_distribution.py @@ -0,0 +1,227 @@ +__all__ = ["VaeDistribution", "VaeNormal"] + +import itertools +import operator +from typing import Optional, Type, cast + +import torch +import torch.optim +from torch.distributions import ( + Categorical, + Distribution, + Normal, + kl_divergence, +) + +from move.core.exceptions import CudaIsNotAvailable, ShapeAndWeightMismatch +from move.models.base import BaseVae, LossDict, VaeOutput +from move.models.layers.chunk import ( + ContinuousDistribution, + SplitInput, + SplitOutput, +) +from move.models.layers.encoder_decoder import Decoder, Encoder + + +class VaeDistribution(BaseVae): + """Variational autoencoder with a distribution on its decoder.""" + + def __init__( + self, + discrete_shapes: list[tuple[int, int]], + continuous_shapes: list[int], + discrete_weights: Optional[list[float]] = None, + continuous_weights: Optional[list[float]] = None, + num_hidden: list[int] = [200, 200], + num_latent: int = 20, + kl_weight: float = 0.01, + dropout_rate: float = 0.2, + use_cuda: bool = False, + ) -> None: + super().__init__() + + # Validate and save arguments + if sum(num_hidden) <= 0: + raise ValueError( + "Number of hidden units in encoder/decoder must be non-negative." + ) + self.num_hidden = num_hidden + + if num_latent < 1: + raise ValueError("Latent space size must be non-negative.") + self.num_latent = num_latent + + if kl_weight <= 0: + raise ValueError("KLD weight must be greater than zero.") + self.kl_weight = kl_weight + + if not (0 <= dropout_rate < 1): + raise ValueError("Dropout rate must be between [0, 1).") + self.dropout_rate = dropout_rate + + if discrete_shapes is None and continuous_shapes is None: + raise ValueError("Shapes of input datasets must be provided.") + + self.discrete_shapes = discrete_shapes + self.num_disc_features = 0 + self.discrete_weights = [1.0] * len(self.discrete_shapes) + if discrete_shapes is not None and len(discrete_shapes) > 0: + (*shapes_1d,) = itertools.starmap(operator.mul, discrete_shapes) + self.num_disc_features = sum(shapes_1d) + if discrete_weights is not None: + if len(discrete_shapes) != len(discrete_weights): + raise ShapeAndWeightMismatch( + len(discrete_shapes), len(discrete_weights) + ) + self.discrete_weights = discrete_weights + + self.continuous_shapes = continuous_shapes + self.num_cont_features = 0 + self.continuous_weights = [1.0] * len(self.continuous_shapes) + if continuous_shapes is not None and len(continuous_shapes) > 0: + self.num_cont_features = sum(continuous_shapes) + if continuous_weights is not None: + if len(continuous_shapes) != len(continuous_weights): + raise ShapeAndWeightMismatch( + len(continuous_shapes), len(continuous_weights) + ) + self.continuous_weights = continuous_weights + + self.split_input = SplitInput(self.discrete_shapes, self.continuous_shapes) + self.split_output = SplitOutput( + self.discrete_shapes, + self.continuous_shapes, + self.decoder_distribution, + # continuous_activation_name="Tanh", + ) + + self.in_features = self.num_disc_features + self.num_cont_features + + self.encoder = Encoder( + self.in_features, + num_hidden, + num_latent, + embedding_args=self.embedding_args, + dropout_rate=dropout_rate, + ) + + self.decoder = Decoder( + num_latent, + num_hidden[::-1], + self.split_output.num_expected_features, + dropout_rate=dropout_rate, + ) + + if use_cuda and not torch.cuda.is_available(): + raise CudaIsNotAvailable() + self.use_cuda = use_cuda + + device = torch.device("cuda" if use_cuda else "cpu") + self.to(device) + + @property + def decoder_distribution(self) -> Type[Distribution]: + return Normal + + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + z_loc, z_logvar, *_ = self.encoder(x) + return z_loc, torch.exp(z_logvar * 0.5) + + def reparameterize(self, loc: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + eps = torch.randn_like(scale) + return eps.mul(scale).add_(loc) + + def decode(self, z: torch.Tensor) -> tuple[torch.Tensor, ...]: + return self.decoder(z) + + def project(self, batch: torch.Tensor) -> torch.Tensor: + return self.encode(batch)[0] + + def reconstruct(self, batch: torch.Tensor, as_one: bool = False): + out = self(batch)["x_recon"] + out_disc, out_cont = self.split_output(out) + out_cont = cast(ContinuousDistribution, out_cont) + recon_disc = torch.cat( + [logits.flatten(start_dim=1) for logits in out_disc], dim=1 + ) + # get location only (mean) + recon_cont = torch.cat([args["loc"] for args in out_cont], dim=1) + if as_one: + return torch.cat((recon_disc, recon_cont), dim=1) + return recon_disc, recon_cont + + def forward(self, x: torch.Tensor) -> VaeOutput: + z_loc, z_scale = self.encode(x) + z = self.reparameterize(z_loc, z_scale) + x_recon, *_ = self.decode(z) + return { + "z_loc": z_loc, + "z_scale": z_scale, + "x_recon": x_recon, + } + + @staticmethod + def compute_log_prob( + dist: Type[Distribution], + x: torch.Tensor, + ignore_mask: Optional[torch.Tensor] = None, + **dist_args + ): + """Compute the log of the probability density of the likelihood p(x|z).""" + px = dist(**dist_args) + out = px.log_prob(x) + if ignore_mask is None: + return torch.sum(out, dim=-1).mean() + out = out * ignore_mask + return out.sum(dim=-1).sum() / ignore_mask.sum(-1).sum() + + @staticmethod + def compute_kl_div(qz_loc: torch.Tensor, qz_scale: torch.Tensor): + """Compute the KL divergence between posterior q(z|x) and prior p(z). + The prior has a Normal(0, 1) distribution.""" + qz = Normal(qz_loc, qz_scale) + pz = Normal(0.0, 1.0) + return kl_divergence(qz, pz).sum(dim=-1).mean() + + def compute_loss(self, batch: torch.Tensor, annealing_factor: float) -> LossDict: + # Split concatenated input + batch_disc, batch_cont = self.split_input(batch) + # Split concatenated output + out = self(batch) + out_disc, out_cont = self.split_output(out["x_recon"]) + out_cont = cast(ContinuousDistribution, out_cont) + + # Compute discrete dataset losses + disc_rec_loss = torch.zeros(()) + for i, args in enumerate(out_disc): + y = torch.argmax(batch_disc[i], dim=-1) + ignore_mask = torch.any(batch_disc[i] == 1, dim=-1).float() # Ignore NaNs + disc_rec_loss -= self.discrete_weights[i] * self.compute_log_prob( + Categorical, y, ignore_mask, logits=args + ) + + # Compute continuous dataset losses + cont_rec_loss = torch.zeros(()) + for i, args in enumerate(out_cont): + ignore_mask = torch.logical_not(batch_cont[i] == 0.0) # Ignore NaNs + cont_rec_loss -= self.continuous_weights[i] * self.compute_log_prob( + self.decoder_distribution, batch_cont[i], ignore_mask, **args + ) + + # Calculate overall reconstruction and regularization loss + rec_loss = disc_rec_loss + cont_rec_loss + reg_loss = self.compute_kl_div(out["z_loc"], out["z_scale"]) + + # Compute ELBO + kl_weight = annealing_factor * self.kl_weight + elbo = rec_loss + reg_loss * kl_weight + return { + "elbo": elbo, + "discrete_loss": disc_rec_loss, + "continuous_loss": cont_rec_loss, + "kl_div": reg_loss, + "kl_weight": kl_weight, + } + + +VaeNormal = VaeDistribution diff --git a/src/move/models/vae_legacy.py b/src/move/models/vae_legacy.py new file mode 100644 index 00000000..4000b672 --- /dev/null +++ b/src/move/models/vae_legacy.py @@ -0,0 +1,747 @@ +__all__ = ["VAE"] + +import logging +from typing import Callable, Optional + +import torch +from torch import nn, optim +from torch.utils.data import DataLoader + +from move.core.typing import FloatArray, IntArray + +logger = logging.getLogger("vae.py") + + +class VAE(nn.Module): + """Variational autoencoder. + + Instantiate with: + continuous_shapes: shape of the different continuous datasets if any + categorical_shapes: shape of the different categorical datasets if any + num_hidden: List of n_neurons in the hidden layers [[200, 200]] + num_latent: Number of neurons in the latent layer [15] + beta: Multiply KLD by the inverse of this value [0.0001] + continuous_weights: list of weights for each continuous dataset + categorical_weights: list of weights for each categorical dataset + dropout: Probability of dropout on forward pass [0.2] + cuda: Use CUDA (GPU accelerated training) [False] + + Raises: + ValueError: Minimum 1 latent unit + ValueError: Beta must be greater than zero. + ValueError: Dropout must be between zero and one. + ValueError: Shapes of the input data must be provided. + ValueError: Number of continuous weights must be the same as number of + continuous datasets + ValueError: Number of categorical weights must be the same as number of + categorical datasets + """ + + def __init__( + self, + categorical_shapes: Optional[list[tuple[int, ...]]] = None, + continuous_shapes: Optional[list[int]] = None, + categorical_weights: Optional[list[int]] = None, + continuous_weights: Optional[list[int]] = None, + num_hidden: list[int] = [200, 200], + num_latent: int = 20, + beta: float = 0.01, + dropout: float = 0.2, + cuda: bool = False, + ): + + if num_latent < 1: + raise ValueError(f"Minimum 1 latent unit. Input was {num_latent}.") + + if beta <= 0: + raise ValueError("Beta must be greater than zero.") + + if not (0 <= dropout < 1): + raise ValueError("Dropout must be between zero and one.") + + if continuous_shapes is None and categorical_shapes is None: + raise ValueError("Shapes of the input data must be provided.") + + self.input_size = 0 + if continuous_shapes is not None: + self.num_continuous = sum(continuous_shapes) + self.input_size += self.num_continuous + self.continuous_shapes = continuous_shapes + + if continuous_weights is not None: + self.continuous_weights = continuous_weights + if len(continuous_shapes) != len(continuous_weights): + raise ValueError( + "Number of continuous weights must be the same as" + " number of continuous datasets" + ) + else: + self.num_continuous = 0 + + if categorical_shapes is not None: + self.num_categorical = sum( + [int.__mul__(*shape) for shape in categorical_shapes] + ) + self.input_size += self.num_categorical + self.categorical_shapes = categorical_shapes + + if categorical_weights is not None: + self.categorical_weights = categorical_weights + if len(categorical_shapes) != len(categorical_weights): + raise ValueError( + "Number of categorical weights must be the same as" + " number of categorical datasets" + ) + else: + self.num_categorical = 0 + + super(VAE, self).__init__() + + # Initialize simple attributes + self.beta = beta + self.num_hidden = num_hidden + self.num_latent = num_latent + self.dropout = dropout + + self.device = torch.device("cuda" if cuda == True else "cpu") + + # Activation functions + self.relu = nn.LeakyReLU() + self.log_softmax = nn.LogSoftmax(dim=1) + self.dropoutlayer = nn.Dropout(p=self.dropout) + + # Initialize lists for holding hidden layers + self.encoderlayers = nn.ModuleList() + self.encodernorms = nn.ModuleList() + self.decoderlayers = nn.ModuleList() + self.decodernorms = nn.ModuleList() + + ### Layers + # Hidden layers + for nin, nout in zip([self.input_size] + self.num_hidden, self.num_hidden): + self.encoderlayers.append(nn.Linear(nin, nout)) + self.encodernorms.append(nn.BatchNorm1d(nout)) + + # Latent layers + self.mu = nn.Linear(self.num_hidden[-1], self.num_latent) # mu layer + self.var = nn.Linear(self.num_hidden[-1], self.num_latent) # logvariance layer + + # Decoding layers + for nin, nout in zip( + [self.num_latent] + self.num_hidden[::-1], self.num_hidden[::-1] + ): + self.decoderlayers.append(nn.Linear(nin, nout)) + self.decodernorms.append(nn.BatchNorm1d(nout)) + + # Reconstruction - output layers + self.out = nn.Linear(self.num_hidden[0], self.input_size) # to output + + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Encodes the data in the data loader and returns the encoded matrix. + + Args: + x: input data + + Returns: + A tuple containing: + mean latent vector + log-variance latent vector + """ + # Hidden layers + for encoderlayer, encodernorm in zip(self.encoderlayers, self.encodernorms): + x = encoderlayer(x) + x = self.relu(x) + x = self.dropoutlayer(x) + x = encodernorm(x) + + return self.mu(x), self.var(x) + + def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + """ + Performs reparametrization trick + + Args: + mu: mean latent vector + logvar: log-variance latent vector + + Returns: + sample from latent space distribution + """ + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + + return eps.mul(std).add_(mu) + + def decompose_categorical(self, reconstruction: torch.Tensor) -> list[torch.Tensor]: + """ + Returns list of final reconstructions (after applying + log-softmax to the outputs of decoder) of each categorical class + + Args: + reconstruction: results of final layer of decoder + + Returns: + final reconstructions of each categorical class + """ + cat_tmp = reconstruction.narrow(1, 0, self.num_categorical) + + # handle soft max for each categorical dataset + cat_out = [] + pos = 0 + for cat_shape in self.categorical_shapes: + cat_dataset = cat_tmp[:, pos : (cat_shape[0] * cat_shape[1] + pos)] + + cat_out_tmp = cat_dataset.view( + cat_dataset.shape[0], cat_shape[0], cat_shape[1] + ) + cat_out_tmp = cat_out_tmp.transpose(1, 2) + cat_out_tmp = self.log_softmax(cat_out_tmp) + + cat_out.append(cat_out_tmp) + pos += cat_shape[0] * cat_shape[1] + + return cat_out + + def decode( + self, x: torch.Tensor + ) -> tuple[Optional[list[torch.Tensor]], Optional[torch.Tensor]]: + """ + Decode to the input space from the latent space + + Args: + x: sample from latent space distribution + + Returns: + A tuple containing: + cat_out: + list of reconstructions of every categorical data class + con_out: + reconstruction of continuous data + """ + for decoderlayer, decodernorm in zip(self.decoderlayers, self.decodernorms): + x = decoderlayer(x) + x = self.relu(x) + x = self.dropoutlayer(x) + x = decodernorm(x) + + reconstruction = self.out(x) + + # Decompose reconstruction to categorical and continuous variables + # if both types are in the input + cat_out, con_out = None, None + if self.num_categorical > 0: + cat_out = self.decompose_categorical(reconstruction) + if self.num_continuous > 0: + con_out = reconstruction.narrow( + 1, self.num_categorical, self.num_continuous + ) + + return cat_out, con_out + + def forward( + self, tensor: torch.Tensor + ) -> tuple[list[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward propagate through the VAE network + + Args: + tensor (torch.Tensor): input data + + Returns: + (tuple): a tuple containing: + cat_out (list): list of reconstructions of every categorical + data class + con_out (torch.Tensor): reconstructions of continuous data + mu (torch.Tensor): mean latent vector + logvar (torch.Tensor): mean log-variance vector + """ + mu, logvar = self.encode(tensor) + z = self.reparameterize(mu, logvar) + cat_out, con_out = self.decode(z) + + return cat_out, con_out, mu, logvar + + def calculate_cat_error( + self, + cat_in: torch.Tensor, + cat_out: list[torch.Tensor], + ) -> torch.Tensor: + """ + Calculates errors (cross-entropy) for categorical data reconstructions + + Args: + cat_in: + input categorical data + cat_out: + list of reconstructions of every categorical data class + + Returns: + torch.Tensor: + Errors (cross-entropy) for categorical data reconstructions + """ + batch_size = cat_in.shape[0] + + # calcualte target values for all cat datasets + count = 0 + cat_errors = [] + pos = 0 + for cat_shape in self.categorical_shapes: + cat_dataset = cat_in[:, pos : (cat_shape[0] * cat_shape[1] + pos)] + + cat_dataset = cat_dataset.view(cat_in.shape[0], cat_shape[0], cat_shape[1]) + cat_target = cat_dataset + cat_target = cat_target.argmax(2) + cat_target[cat_dataset.sum(dim=2) == 0] = -1 + cat_target = cat_target.to(self.device) + + # Cross entropy loss for categroical + loss = nn.NLLLoss(reduction="sum", ignore_index=-1) + cat_errors.append( + loss(cat_out[count], cat_target) / (batch_size * cat_shape[0]) + ) + count += 1 + pos += cat_shape[0] * cat_shape[1] + + cat_errors = torch.stack(cat_errors) + return cat_errors + + def calculate_con_error( + self, con_in: torch.Tensor, con_out: torch.Tensor, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + ) -> torch.Tensor: + """ + Calculates errors (MSE) for continuous data reconstructions + + Args: + con_in: input continuous data + con_out: reconstructions of continuous data + loss: loss function + + Returns: + MSE loss + """ + batch_size = con_in.shape[0] + total_shape = 0 + con_errors_list: list[torch.Tensor] = [] + for s in self.continuous_shapes: + c_in = con_in[:, total_shape : (s + total_shape - 1)] + c_re = con_out[:, total_shape : (s + total_shape - 1)] + error = loss(c_re, c_in) / batch_size + con_errors_list.append(error) + total_shape += s + + con_errors = torch.stack(con_errors_list) + con_errors = con_errors / torch.Tensor(self.continuous_shapes).to(self.device) + MSE = torch.sum( + con_errors * torch.Tensor(self.continuous_weights).to(self.device) + ) + return MSE + + # Reconstruction + KL divergence losses summed over all elements and batch + def loss_function( + self, + cat_in: torch.Tensor, + cat_out: list[torch.Tensor], + con_in: torch.Tensor, + con_out: torch.Tensor, + mu: torch.Tensor, + logvar: torch.Tensor, + kld_w: float, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Calculates the loss for data reconstructions + + Args: + cat_in: input categorical data + cat_out: list of reconstructions of every categorical data class + con_in: input continuous data + con_out: reconstructions of continuous data + mu: mean latent vector + logvar: mean log-variance vector + kld_w: kld weight + + Returns: + (tuple): a tuple containing: + total loss on train set during the training of the epoch + BCE loss on train set during the training of the epoch + SSE loss on train set during the training of the epoch + KLD loss on train set during the training of the epoch + """ + + MSE = 0 + CE = 0 + # calculate loss for catecorical data if in the input + if cat_out is not None: + cat_errors = self.calculate_cat_error(cat_in, cat_out) + if self.categorical_weights is not None: + CE = torch.sum( + cat_errors * torch.Tensor(self.categorical_weights).to(self.device) + ) + else: + CE = torch.sum(cat_errors) / len(cat_errors) + + # calculate loss for continuous data if in the input + if con_out is not None: + batch_size = con_in.shape[0] + # Mean square error loss for continauous + loss = nn.MSELoss(reduction="sum") + # set missing data to 0 to remove any loss these would provide + con_out[con_in == 0] = 0 + + # include different weights for each omics dataset + if self.continuous_weights is not None: + MSE = self.calculate_con_error(con_in, con_out, loss) + else: + MSE = loss(con_out, con_in) / (batch_size * self.num_continuous) + + # see Appendix B from VAE paper: + # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 + # https://arxiv.org/abs/1312.6114 + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / (batch_size) + + KLD_weight = self.beta * kld_w + loss = CE + MSE + KLD * KLD_weight + + return loss, CE, MSE, KLD * KLD_weight + + def encoding( + self, + train_loader: DataLoader, + epoch: int, + lrate: float, + kld_w: float, + ) -> tuple[float, float, float, float]: + """ + One iteration of VAE + + Args: + train_loader: Dataloader with train dataset + epoch: the epoch + lrate: learning rate for the model + kld_w: float of KLD weight + + Returns: + (tuple): a tuple containing: + total loss on train set during the training of the epoch + BCE loss on train set during the training of the epoch + SSE loss on train set during the training of the epoch + KLD loss on train set during the training of the epoch + """ + self.train() + optimizer = optim.Adam(self.parameters(), lr=lrate) + + epoch_loss = 0 + epoch_kldloss = 0 + epoch_sseloss = 0 + epoch_bceloss = 0 + + for _, (cat, con) in enumerate(train_loader): + # Move input to GPU if requested + cat = cat.to(self.device) + con = con.to(self.device) + + if self.num_categorical > 0 and self.num_continuous > 0: + tensor = torch.cat((cat, con), 1) + elif self.num_categorical > 0: + tensor = cat + elif self.num_continuous > 0: + tensor = con + else: + assert False, "Must have at least 1 categorial or 1 continuous feature" + + optimizer.zero_grad() + + cat_out, con_out, mu, logvar = self(tensor) + + loss, bce, sse, kld = self.loss_function( + cat, cat_out, con, con_out, mu, logvar, kld_w + ) + loss.backward() + + epoch_loss += loss.data.item() + epoch_kldloss += kld.data.item() + + if self.num_continuous > 0: + epoch_sseloss += sse.data.item() + + if self.num_categorical > 0: + epoch_bceloss += bce.data.item() + + optimizer.step() + + logger.info( + "\tEpoch: {}\tLoss: {:.6f}\tCE: {:.7f}\tSSE: {:.6f}\t" + "KLD: {:.4f}\tBatchsize: {}".format( + epoch, + epoch_loss / len(train_loader), + epoch_bceloss / len(train_loader), + epoch_sseloss / len(train_loader), + epoch_kldloss / len(train_loader), + train_loader.batch_size, + ) + ) + return ( + epoch_loss / len(train_loader), + epoch_bceloss / len(train_loader), + epoch_sseloss / len(train_loader), + epoch_kldloss / len(train_loader), + ) + + def make_cat_recon_out(self, length: int) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Initiate empty tensors for categorical data + + Args: + length: number of samples + + Returns: + (tuple): a tuple containing: + cat_class: empty tensor for input categorical data + cat_recon: empty tensor for reconstructed categorical data + cat_total_shape: number of features of linearized one hot + categorical data + """ + cat_total_shape = 0 + for cat_shape in self.categorical_shapes: + cat_total_shape += cat_shape[0] + + cat_class = torch.empty((length, cat_total_shape)).int() + cat_recon = torch.empty((length, cat_total_shape)).int() + return cat_class, cat_recon, cat_total_shape + + def get_cat_recon( + self, batch: int, cat_total_shape: int, cat: torch.Tensor, cat_out: torch.Tensor + ) -> tuple[IntArray, IntArray]: + """ + Generates reconstruction data of categorical data class + + Args: + batch: number of samples in the batch + cat_total_shape: number of features of linearized one hot + categorical data + cat: input categorical data + cat_out: reconstructed categorical data + + Returns: + (tuple): a tuple containing: + cat_out_class: reconstructed categorical data + cat_target: input categorical data + """ + count = 0 + cat_out_class = torch.empty((batch, cat_total_shape)).int() + cat_target = torch.empty((batch, cat_total_shape)).int() + pos = 0 + shape_1 = 0 + for cat_shape in self.categorical_shapes: + # Get input categorical data + cat_in_tmp = cat[:, pos : (cat_shape[0] * cat_shape[1] + pos)] + cat_in_tmp = cat_in_tmp.view(cat.shape[0], cat_shape[0], cat_shape[1]) + + # Calculate target values for input + cat_target_tmp = cat_in_tmp + cat_target_tmp = torch.argmax(cat_target_tmp.detach(), dim=2) + cat_target_tmp[cat_in_tmp.sum(dim=2) == 0] = -1 + cat_target[ + :, shape_1 : (cat_shape[0] + shape_1) + ] = cat_target_tmp # .numpy() + + # Get reconstructed categorical data + cat_out_tmp = cat_out[count] + cat_out_tmp = cat_out_tmp.transpose(1, 2) + cat_out_class[:, shape_1 : (cat_shape[0] + shape_1)] = torch.argmax( + cat_out_tmp, dim=2 + ) # .numpy() + + # make counts for next dataset + pos += cat_shape[0] * cat_shape[1] + shape_1 += cat_shape[0] + count += 1 + + cat_target = cat_target.numpy() + cat_out_class = cat_out_class.numpy() + + return cat_out_class, cat_target + + def _validate_batch(self, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + """ + Returns the batch of categorical and continuous data if they are not + None + + Args: + batch: batches of categorical and continuous data + + Returns: + a formed batch + """ + cat, con = batch + cat = cat.to(self.device) + con = con.to(self.device) + + if self.num_categorical == 0: + return con + elif self.num_continuous == 0: + return cat + return torch.cat((cat, con), dim=1) + + @torch.no_grad() + def project(self, dataloader: DataLoader) -> FloatArray: + """Generates an embedding of the data contained in the DataLoader. + + Args: + dataloader: A DataLoader with categorical or continuous data + + Returns: + FloatArray: Embedding + """ + self.eval() + embedding = [] + for batch in dataloader: + batch = self._validate_batch(batch) + *_, mu, _ = self(batch) + embedding.append(mu) + embedding = torch.cat(embedding, dim=0).cpu().numpy() + return embedding + + @torch.no_grad() + def reconstruct( + self, dataloader: DataLoader + ) -> tuple[list[FloatArray], FloatArray]: + """ + Generates a reconstruction of the data contained in the DataLoader. + + Args: + dataloader: A DataLoader with categorical or continuous data + + Returns: + A list of categorical reconstructions and the continuous + reconstruction + """ + self.eval() + cat_recons = [[] for _ in range(len(self.categorical_shapes))] + con_recons = [] + for batch in dataloader: + batch = self._validate_batch(batch) + cat_recon, con_recon, *_ = self(batch) + if cat_recon is not None: + for i, cat in enumerate(cat_recon): + cat_recons[i].append(torch.argmax(cat, dim=1)) + if con_recon is not None: + con_recons.append(con_recon) + if cat_recons: + cat_recons = [torch.cat(cats, dim=0).cpu().numpy() for cats in cat_recons] + if con_recons: + con_recons = torch.cat(con_recons, dim=0).cpu().numpy() + return cat_recons, con_recons + + @torch.no_grad() + def latent( + self, dataloader: DataLoader, kld_weight: float + ) -> tuple[FloatArray, FloatArray, IntArray, IntArray, FloatArray, float, float]: + """ + Iterate through validation or test dataset + + Args: + dataloader: Dataloader with test dataset + kld_weight: KLD weight + + Returns: + (tuple): a tuple containing: + latent: array of VAE latent space mean vectors values + latent_var: array of VAE latent space logvar vectors values + cat_recon: reconstructed categorical data + cat_class: input categorical data + con_recon: reconstructions of continuous data + test_loss: total loss on test set + test_likelihood: total likelihood on test set + """ + + self.eval() + test_loss = 0 + test_likelihood = 0 + + num_samples = dataloader.dataset.num_samples + + latent = torch.empty((num_samples, self.num_latent)) + latent_var = torch.empty((num_samples, self.num_latent)) + + # reconstructed output + if self.num_categorical > 0: + cat_class, cat_recon, cat_total_shape = self.make_cat_recon_out(num_samples) + else: + cat_class = None + cat_recon = None + + con_recon = ( + None + if self.num_continuous == 0 + else torch.empty((num_samples, self.num_continuous)) + ) + + row = 0 + for (cat, con) in dataloader: + cat = cat.to(self.device) + con = con.to(self.device) + + # get dataset + if self.num_categorical > 0 and self.num_continuous > 0: + tensor = torch.cat((cat, con), 1) + elif self.num_categorical > 0: + tensor = cat + elif self.num_continuous > 0: + tensor = con + else: + assert False, "Must have at least 1 categorial or 1 continuous feature" + + # Evaluate + cat_out, con_out, mu, logvar = self(tensor) + + mu = mu.to(self.device) + logvar = logvar.to(self.device) + batch = len(mu) + + loss, bce, sse, _ = self.loss_function( + cat, cat_out, con, con_out, mu, logvar, kld_weight + ) + test_likelihood += bce + sse + test_loss += loss.data.item() + + if self.num_categorical > 0: + cat_out_class, cat_target = self.get_cat_recon( + batch, cat_total_shape, cat, cat_out + ) + cat_recon[row : row + len(cat_out_class)] = torch.Tensor(cat_out_class) + cat_class[row : row + len(cat_target)] = torch.Tensor(cat_target) + + if self.num_continuous > 0: + con_recon[row : row + len(con_out)] = con_out + + latent_var[row : row + len(logvar)] = logvar + latent[row : row + len(mu)] = mu + row += len(mu) + + test_loss /= len(dataloader) + logger.info("====> Test set loss: {:.4f}".format(test_loss)) + + latent = latent.numpy() + latent_var = latent_var.numpy() + cat_recon = cat_recon.numpy() + cat_class = cat_class.numpy() + con_recon = con_recon.numpy() + + assert row == num_samples + return ( + latent, + latent_var, + cat_recon, + cat_class, + con_recon, + test_loss, + test_likelihood, + ) + + def __repr__(self) -> str: + return ( + f"VAE ({self.input_size} ⇄ {' ⇄ '.join(map(str, self.num_hidden))}" + f" ⇄ {self.num_latent})" + ) diff --git a/src/move/models/vae_t.py b/src/move/models/vae_t.py new file mode 100644 index 00000000..47ffa8a7 --- /dev/null +++ b/src/move/models/vae_t.py @@ -0,0 +1,15 @@ +__all__ = ["VaeT"] + +from typing import Type + +from torch.distributions import Distribution, StudentT + +from move.models.vae_distribution import VaeDistribution + + +class VaeT(VaeDistribution): + """Variational autoencoder with a Student-t distribution on its decoder.""" + + @property + def decoder_distribution(self) -> Type[Distribution]: + return StudentT diff --git a/src/move/tasks/__init__.py b/src/move/tasks/__init__.py index 324027c9..21e1b532 100644 --- a/src/move/tasks/__init__.py +++ b/src/move/tasks/__init__.py @@ -1,11 +1,14 @@ __all__ = [ - "analyze_latent", - "encode_data", - "identify_associations", - "tune_model", + "Associations", + "EncodeData", + "LatentSpaceAnalysis", + "TrainModel", + "TuneModel", + "TuneStability", ] -from move.tasks.analyze_latent import analyze_latent -from move.tasks.encode_data import encode_data -from move.tasks.identify_associations import identify_associations -from move.tasks.tune_model import tune_model +from move.tasks.associations import Associations +from move.tasks.encode_data import EncodeData +from move.tasks.latent_space_analysis import LatentSpaceAnalysis +from move.tasks.train_model import TrainModel +from move.tasks.tuning import TuneModel, TuneStability diff --git a/src/move/tasks/analyze_latent.py b/src/move/tasks/analyze_latent.py index 132a371a..f74173b1 100644 --- a/src/move/tasks/analyze_latent.py +++ b/src/move/tasks/analyze_latent.py @@ -15,7 +15,8 @@ calculate_accuracy, calculate_cosine_similarity, ) -from move.conf.schema import AnalyzeLatentConfig, MOVEConfig +from move.conf.legacy import AnalyzeLatentConfig +from move.conf.schema import MOVEConfig from move.core.logging import get_logger from move.core.typing import FloatArray from move.data import io @@ -25,7 +26,7 @@ perturb_continuous_data, ) from move.data.preprocessing import one_hot_encode_single -from move.models.vae import VAE +from move.models.vae_legacy import VAE from move.training.training_loop import TrainingLoopOutput diff --git a/src/move/tasks/associations.py b/src/move/tasks/associations.py new file mode 100644 index 00000000..e57ca3cd --- /dev/null +++ b/src/move/tasks/associations.py @@ -0,0 +1,209 @@ +__all__ = [] + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Iterator + +import numpy as np +import pandas as pd +import torch + +import move.visualization as viz +from move.conf.tasks import PerturbationConfig +from move.core.typing import PathLike +from move.tasks.base import CsvWriterMixin +from move.tasks.move import MoveTask + +if TYPE_CHECKING: + from move.models.base import BaseVae + + +class Associations(CsvWriterMixin, MoveTask): + """Find associations by applying a perturbation to dataset and checking + which features significantly change after perturbation. + + 1. Train and refit # models (or reload them if already existing). + 2. Perturb a feature in the dataset. + 3. Obtain reconstructions before/after perturbation. + 4. Compute difference between reconstructions. + 5. Average difference over # re-fits + 6. Compute indicator variable as probability difference > 0. Significant + features are either: majorly lower than 0 or majorly greater than 0. + Thus, when ranking features by this probability those with 50% will + be considered not significant. + 7. Rank by indicator variable, and calculate FDR as cumulative + probability of being significant. + 8. Select as significant the features below a determined threshold. + + Args: + interim_data_path: + Directory where encoded data is stored + results_path: + Directory where results will be saved + discrete_dataset_names: + Names of discrete datasets + continuous_dataset_names: + Names of continuous datasets + batch_size: + Number of samples in one batch (used during training and testing) + model_config: + Config of the VAEexit + training_loop_config: + Config of the training loop + perturbation_config: + Config of the perturbation + num_refits: + Number of times to refit the model + sig_threshold: + Threshold used to determine whether an association is significant. + Significant associations are selected if their FDR is below this + threshold. Value should range between (0, 1) + write_only_sig: + Whether all or only significant hits are written in output file. + """ + + loop_filename: str = "loop.yaml" + model_filename_fmt: str = "model_{}.pt" + results_subdir: str = "associations" + results_filename: str = "associations.csv" + + def __init__( + self, + interim_data_path: PathLike, + results_path: PathLike, + perturbation_config: PerturbationConfig, + num_refits: int, + sig_threshold: float = 0.05, + write_only_sig: bool = True, + **kwargs, + ) -> None: + if not (0 < sig_threshold < 1.0): + raise ValueError("Significant threshold should be in range (0, 1)") + + super().__init__( + input_dir=interim_data_path, + output_dir=Path(results_path) / self.results_subdir, + **kwargs, + ) + self.perturbation_config = perturbation_config + self.num_refits = num_refits + self.sig_threshold = sig_threshold + self.write_only_sig = write_only_sig + + def get_trained_model(self, train_dataloader) -> Iterator["BaseVae"]: + """Yield a trained model. Model will be trained from scratch or + re-loaded if it already exists.""" + from move.models.base import reload_vae + + for i in range(self.num_refits): + model_path = self.output_dir / self.model_filename_fmt.format(i) + + if model_path.exists(): + if i == 0: + self.logger.debug(f"Re-loading models") + model = reload_vae(model_path) + else: + if i == 0: + self.logger.debug(f"Training models from scratch") + model = self.init_model(train_dataloader) + training_loop = self.init_training_loop(False) # prevent write + training_loop.train(model, train_dataloader) + if i == 1: + training_loop.to_yaml(self.output_dir / self.loop_filename) + model.save(model_path) + model.freeze() + yield model + + def run(self) -> Any: + + colnames = ["perturbed_feature", "target_feature", "prob", "bayes_k"] + self.init_csv_writer( + self.output_dir / self.results_filename, + fieldnames=colnames, + extrasaction="ignore", + ) + + # Prep dataloaders + train_dataloader = self.make_dataloader() + test_dataloader = self.make_dataloader(split="test") + num_discrete_indices = [test_dataloader.dataset.num_discrete_features] + dataset = test_dataloader.dataset + + # Gather perturbations + perturbation_names: list[str] = [] + if self.perturbation_config.target_feature_name is None: + perturbation_names.extend( + dataset.feature_names_of(self.perturbation_config.target_dataset_name) + ) + else: + perturbation_names.append(self.perturbation_config.target_feature_name) + num_perturbations = len(perturbation_names) + + # Compute bayes factor per feature per perturbation + for i, perturbed_feature_name in enumerate(perturbation_names, 1): + mean_diff = None + for model in self.get_trained_model(train_dataloader): + # Perturb feature + dataset.perturb( + self.perturbation_config.target_dataset_name, + perturbed_feature_name, + self.perturbation_config.target_value, + ) + + # Compute reconstruction differences (only continuous) + diff_list = [] + for orig_batch, pert_batch, pert_mask in test_dataloader: + _, orig_recon = model.reconstruct(orig_batch) + _, pert_recon = model.reconstruct(pert_batch) + _, orig_input = torch.tensor_split( + orig_batch, num_discrete_indices, dim=-1 + ) + diff = pert_recon - orig_recon + diff[orig_input == 0] = 0.0 # mark NaN as 0 + diff = diff[pert_mask, :] # ignore unperturbed features + diff_list.append(diff) + + dataset.remove_perturbation() + + # Concatenate and normalize reconstruction differences + cat_diff = torch.cat(diff_list) / self.num_refits + if mean_diff is None: + mean_diff = cat_diff + else: + mean_diff += cat_diff + + self.logger.info(f"Perturbing ({i}/{num_perturbations})") + assert mean_diff is not None + + prob = torch.sum(mean_diff > 1e-8, dim=0) / mean_diff.count_nonzero(dim=0) + bayes_k = torch.log(prob + 1e-8) - torch.log(1 - prob + 1e-8) + abs_prob = torch.special.expit(torch.abs(bayes_k)) + + self.write_cols( + { + "perturbed_feature": [perturbed_feature_name] + * dataset.num_continuous_features, + "target_feature": dataset.continuous_feature_names, + "prob": abs_prob.numpy(), + "bayes_k": bayes_k.numpy(), + } + ) + + self.logger.info("Complete! Writing out results") + self.close_csv_writer() + + # Sort results, compute FDR + assert self.csv_filepath is not None + + results = pd.read_csv(self.csv_filepath) + results.sort_values("prob", ascending=False, inplace=True, ignore_index=True) + results["fdr"] = np.cumsum(1 - results["prob"]) / np.arange(1, len(results) + 1) + results["pred_significant"] = results["fdr"] < self.sig_threshold + + sig = results[results.pred_significant] + self.logger.info(f"Significant associations found: {len(sig)}") + + if self.write_only_sig: + results = sig + results.drop(columns=["pred_significant"], inplace=True) + + results.to_csv(self.csv_filepath, index=False) diff --git a/src/move/tasks/base.py b/src/move/tasks/base.py new file mode 100644 index 00000000..9242efe1 --- /dev/null +++ b/src/move/tasks/base.py @@ -0,0 +1,266 @@ +__all__ = ["Task", "ParentTask", "SubTask", "CsvWriterMixin"] + +import inspect +import logging +from abc import ABC, abstractmethod +from io import TextIOWrapper +from pathlib import Path +from typing import Any, Optional, Sequence, Type, TypeVar, Union, cast + +import hydra +from numpy.typing import NDArray +from omegaconf import DictConfig, OmegaConf + +from move.core.exceptions import FILE_EXISTS_WARNING, UnsetProperty +from move.core.logging import get_logger +from move.core.qualname import get_fully_qualname +from move.core.typing import LoggingLevel, PathLike +from move.data.writer import CsvWriter + + +class InputDirMixin: + """Mixin class for adding an input directory property to a class.""" + + @property + def input_dir(self) -> Path: + if path := getattr(self, "_input_dir", None): + return path + raise UnsetProperty("Input directory") + + @input_dir.setter + def input_dir(self, pathlike: PathLike) -> None: + self._input_dir = Path(pathlike) + self._input_dir.mkdir(parents=True, exist_ok=True) + + +class OutputDirMixin: + """Mixin class for adding an output directory property to a class.""" + + @property + def output_dir(self) -> Path: + if path := getattr(self, "_output_dir", None): + return path + raise UnsetProperty("Output directory") + + @output_dir.setter + def output_dir(self, pathlike: PathLike) -> None: + self._output_dir = Path(pathlike) + self._output_dir.mkdir(parents=True, exist_ok=True) + + +class LoggerMixin: + """Mixin class for logging.""" + + @property + def logger(self) -> logging.Logger: + if issubclass(self.__class__, SubTaskMixin): + task = cast(SubTaskMixin, self).parent + if task: + return task.logger + raise UnsetProperty("Parent task") + if getattr(self, "_logger", None) is None: + self._logger = get_logger(self.__class__.__name__) + return self._logger + + def log(self, message: str, level: LoggingLevel = "INFO") -> None: + """Log a message. + + Args: + message: logged message + level: predefined logging level name or numeric value.""" + if isinstance(level, str): + level_num = logging.getLevelName(level) + if not isinstance(level_num, int): + raise ValueError(f"Unexpected logging level: {level}") + else: + level_num = level + self.logger.log(level_num, message) + + +T = TypeVar("T", bound="Task") + + +class Task(ABC, LoggerMixin): + """Base class for a task""" + + @abstractmethod + def run(self, *args, **kwargs) -> Any: + raise NotImplementedError() + + @classmethod + def from_config(cls: Type[T], config: DictConfig) -> T: + """Instantiate a task from its config file.""" + if not hasattr(config, "task"): + raise UnsetProperty("Task configuration") + target_qualname = config.task._target_ + current_qualname = get_fully_qualname(cls) + if target_qualname == current_qualname: + return hydra.utils.instantiate(config.task, _recursive_=False) + raise ValueError( + f"Received config for `{target_qualname}`, but current class is `{current_qualname}`" + ) + + def to_yaml(self, filepath: PathLike) -> None: + """Save task config as YAML file.""" + signature = inspect.signature(self.__class__).parameters.keys() + config = OmegaConf.create({name: getattr(self, name) for name in signature}) + with open(filepath, "w") as file: + file.write(OmegaConf.to_yaml(config)) + + +class ParentTask(InputDirMixin, OutputDirMixin, Task): + """A simple task with an input and output directory. This task may have + children (sub-tasks). + + Args: + input_path: Path where input files are read from + output_path: Path where output files will be saved to + """ + + def __init__(self, input_dir: PathLike, output_dir: PathLike) -> None: + self.input_dir = input_dir + self.output_dir = output_dir + + +class OutputDir(ParentTask): + """Task used to set an output directory.""" + + def __init__(self, output_dir: PathLike) -> None: + super().__init__(Path.cwd(), output_dir) + + def run(self) -> None: + raise NotImplementedError() + + +class TestTask(Task): + """Task used for testing""" + + def run(self) -> Any: + pass + + +class SubTaskMixin(LoggerMixin): + """Mixin class to designate a task is child of another task.""" + + @property + def parent(self) -> Optional[ParentTask]: + return getattr(self, "_parent", None) + + @parent.setter + def parent(self, task: ParentTask) -> None: + self._parent = task + + +class SubTask(SubTaskMixin, Task): + """Base class for sub-tasks.""" + + ... + + +CsvRow = dict[str, Any] + + +class CsvWriterMixin(LoggerMixin): + """Mixin class to designate a sub-task that has its own CSV writer.""" + + csv_filepath: Optional[Path] = None + buffer_size = 1000 + + @property + def can_write(self) -> bool: + return ( + getattr(self, "_csv_writer", None) is not None + and getattr(self, "_csv_file", None) is not None + and not self.csv_file.closed + ) + + @property + def csv_file(self) -> TextIOWrapper: + return getattr(self, "_csv_file") + + @csv_file.setter + def csv_file(self, value: Optional[TextIOWrapper]) -> None: + self._csv_file = value + + @property + def csv_writer(self) -> CsvWriter: + return getattr(self, "_csv_writer") + + @csv_writer.setter + def csv_writer(self, value: Optional[CsvWriter]) -> None: + self._csv_writer = value + + @property + def row_buffer(self) -> list[CsvRow]: + if getattr(self, "_buffer", None) is None: + self._buffer: list[CsvRow] = [] + return self._buffer + + @property + def output_dir(self) -> Path: + if self.parent is None: + raise UnsetProperty("Output directory") + return self.parent.output_dir + + @output_dir.setter + def output_dir(self, value: PathLike) -> None: + self.parent = OutputDir(value) + + def init_csv_writer(self, filepath: Path, mode: str = "w", **writer_kwargs) -> None: + """Initialize the CSV writer. + + Args: + filepath: Where to save the CSV file + mode: Whether to open the file in 'w' or 'a' mode + writer_args: Args passed to the CSV Writer object + """ + self.csv_filepath = filepath + exists = self.csv_filepath.exists() + if exists and mode == "w": # Warn about overwriting + self.log(FILE_EXISTS_WARNING.format(self.csv_filepath)) + self.csv_file = open(self.csv_filepath, mode, newline="") # type: ignore + self.csv_writer = CsvWriter(self.csv_file, **writer_kwargs) + # Do not write header if file exists and appending + if (not exists) or mode != "a": + self.csv_writer.writeheader() + + def write_cols(self, cols: dict[str, Union[Sequence[Any], NDArray]]) -> None: + """Directly write columns to CSV file. + + Args: + cols: Column name to values dictionary.""" + if self.can_write: + self.csv_writer.writecols(cols) + + def write_row(self, row: list[Any]) -> None: + """Directly write a row to CSV file.""" + if self.can_write: + self.csv_writer.writer.writerow(row) + + def add_row_to_buffer(self, csv_row: CsvRow) -> None: + """Add row to buffer and flush buffer if it has reached its limit. + + Args: + csv_row: Header names to values dictionary, representing a row + """ + if self.can_write: + self.row_buffer.append(csv_row) + # Flush + if len(self.row_buffer) >= self.buffer_size: + self.csv_writer.writerows(self.row_buffer) + self.row_buffer.clear() + + def close_csv_writer(self, clear: bool = False) -> None: + """Close file and flush buffer. + + Args: + clear: whether to nullify the writer and file object""" + if self.can_write: + if len(self.row_buffer) > 0: + self.csv_writer.writerows(self.row_buffer) + self.row_buffer.clear() + self.csv_file.close() + if clear: + self.csv_file = None + self.csv_filepath = None + self.csv_writer = None diff --git a/src/move/tasks/encode_data.py b/src/move/tasks/encode_data.py index 5092064a..1bc19084 100644 --- a/src/move/tasks/encode_data.py +++ b/src/move/tasks/encode_data.py @@ -1,50 +1,125 @@ -__all__ = ["encode_data"] - -from pathlib import Path +__all__ = ["EncodeData"] import numpy as np +import torch -from move.conf.schema import DataConfig -from move.core.logging import get_logger +from move.conf.tasks import InputConfig +from move.core.typing import PathLike from move.data import io, preprocessing +from move.data.splitting import split_samples +from move.tasks.base import ParentTask -def encode_data(config: DataConfig): - """Encodes categorical and continuous datasets specified in configuration. - Categorical data is one-hot encoded, whereas continuous data is z-score - normalized. +class EncodeData(ParentTask): + """Encode discrete and continuous datasets. By default, discrete data is + one-hot encoded, whereas continuous data is z-score normalized. Args: - config: data configuration + raw_data_path: + Directory where "raw data" is stored + interim_data_path: + Directory where pre-processed data will be saved + sample_names_filename: + Filename of file containing names given to each sample + discrete_inputs: + List of configs for each discrete dataset. Each config is a dict + containing keys 'name' and 'preprocessing' + continuous_inputs: + Same as `discrete_inputs`, but for continuous datasets + train_frac: + Fraction of samples corresponding to training set. + test_frac: + Fraction of samples corresponding to test set. + valid_frac: + Fraction of samples corresponding to validation set. """ - logger = get_logger(__name__) - logger.info("Beginning task: encode data") - - raw_data_path = Path(config.raw_data_path) - raw_data_path.mkdir(exist_ok=True) - interim_data_path = Path(config.interim_data_path) - interim_data_path.mkdir(exist_ok=True, parents=True) - - sample_names = io.read_names(raw_data_path / f"{config.sample_names}.txt") - - mappings = {} - for dataset_name in config.categorical_names: - logger.info(f"Encoding '{dataset_name}'") - filepath = raw_data_path / f"{dataset_name}.tsv" - names, values = io.read_tsv(filepath, sample_names) - values, mapping = preprocessing.one_hot_encode(values) - mappings[dataset_name] = mapping - io.dump_names(interim_data_path / f"{dataset_name}.txt", names) - np.save(interim_data_path / f"{dataset_name}.npy", values) - if mappings: - io.dump_mappings(interim_data_path / "mappings.json", mappings) - - for dataset_name in config.continuous_names: - logger.info(f"Encoding '{dataset_name}'") - filepath = raw_data_path / f"{dataset_name}.tsv" - names, values = io.read_tsv(filepath, sample_names) - values, mask_1d = preprocessing.scale(values) - names = names[mask_1d] - logger.debug(f"Columns with zero variance: {np.sum(~mask_1d)}") - io.dump_names(interim_data_path / f"{dataset_name}.txt", names) - np.save(interim_data_path / f"{dataset_name}.npy", values) + + indices_filename = "indices.pt" + sample_names: list[str] + train_indices: torch.Tensor + + def __init__( + self, + raw_data_path: PathLike, + interim_data_path: PathLike, + sample_names_filename: str, + discrete_inputs: list[InputConfig], + continuous_inputs: list[InputConfig], + train_frac: float = 0.9, + test_frac: float = 0.1, + valid_frac: float = 0.0, + ) -> None: + super().__init__(raw_data_path, interim_data_path) + self.sample_names_filepath = self.input_dir / f"{sample_names_filename}.txt" + self.discrete_inputs = discrete_inputs + self.continuous_inputs = continuous_inputs + self.split_fracs = (train_frac, test_frac, valid_frac) + + def encode_datasets( + self, + input_configs: list[InputConfig], + default_op_name: preprocessing.PreprocessingOpName, + ) -> None: + """Read CSV or TSV files containing datasets and run pre-processing operations. + + Args: + input_configs: + List of configs, each with a dataset file name and operation + name. Valid operation names are 'none', 'one_hot_encode', 'standardize', + 'log_and_standardize' + default_op_name: + Default operation if no operation in config + """ + for config in input_configs: + op_name: preprocessing.PreprocessingOpName = getattr( + config, "preprocessing", default_op_name + ) + action_name = "Reading" if op_name == "none" else "Encoding" + dataset_name = getattr(config, "name") + self.logger.info(f"{action_name} '{dataset_name}'") + dataset_path = self.input_dir / f"{dataset_name}.tsv" + if not dataset_path.exists(): + # Try to load CSV if TSV does not exist + dataset_path = self.input_dir / f"{dataset_name}.csv" + enc_data_path = self.output_dir / f"{dataset_name}.pt" + if enc_data_path.exists(): + self.logger.warning( + f"File '{enc_data_path.name}' already exists. It will be " + "overwritten." + ) + # Read and encode data + feature_names, values = io.read_tsv(dataset_path, self.sample_names) + mapping = None + if op_name in ("standardize", "log_and_standardize"): + values = preprocessing.standardize(values, self.train_indices) + elif op_name == "one_hot_encode": + values, mapping = preprocessing.one_hot_encode(values) + else: + values = preprocessing.fill(values) + tensor = torch.from_numpy(values).float() + # Save data + data = { + "dataset_name": dataset_name, + "tensor": tensor, + "feature_names": feature_names.tolist(), + } + if mapping is not None: + data["mapping"] = mapping + torch.save(data, enc_data_path, pickle_protocol=4) + + def split_samples(self) -> None: + """Create indices to split data into training, test, and validation subsets.""" + indices = split_samples(len(self.sample_names), *self.split_fracs) + ind_dict = dict( + zip(("train_indices", "test_indices", "valid_indices"), indices) + ) + torch.save(ind_dict, self.output_dir / self.indices_filename, pickle_protocol=4) + self.train_indices = ind_dict["train_indices"] + + def run(self) -> None: + """Encode data.""" + self.logger.info("Beginning task: encode data") + self.sample_names = io.read_names(self.sample_names_filepath) + self.split_samples() + self.encode_datasets(self.discrete_inputs, "one_hot_encode") + self.encode_datasets(self.continuous_inputs, "standardize") diff --git a/src/move/tasks/identify_associations.py b/src/move/tasks/identify_associations.py index 2c26933c..e0338546 100644 --- a/src/move/tasks/identify_associations.py +++ b/src/move/tasks/identify_associations.py @@ -10,19 +10,19 @@ import torch from omegaconf import OmegaConf -from move.conf.schema import ( +from move.conf.legacy import ( IdentifyAssociationsBayesConfig, IdentifyAssociationsConfig, IdentifyAssociationsTTestConfig, - MOVEConfig, ) +from move.conf.schema import MOVEConfig from move.core.logging import get_logger from move.core.typing import FloatArray, IntArray from move.data import io from move.data.dataloaders import MOVEDataset, make_dataloader from move.data.perturbations import perturb_categorical_data from move.data.preprocessing import one_hot_encode_single -from move.models.vae import VAE +from move.models.vae_legacy import VAE TaskType = Literal["bayes", "ttest"] @@ -210,7 +210,6 @@ def _ttest_approach( for k, num_latent in enumerate(task_config.num_latent): for j in range(task_config.num_refits): - # Initialize model model: VAE = hydra.utils.instantiate( task_config.model, @@ -303,10 +302,8 @@ def _ttest_approach( b_df = pd.DataFrame(dict(feature_b_name=con_names)) b_df.index.name = "feature_b_id" b_df.reset_index(inplace=True) - results = ( - results - .merge(a_df, on="feature_a_id", how="left") - .merge(b_df, on="feature_b_id", how="left") + results = results.merge(a_df, on="feature_a_id", how="left").merge( + b_df, on="feature_b_id", how="left" ) results["feature_b_dataset"] = pd.cut( cast(IntArray, results["feature_b_id"].values), diff --git a/src/move/tasks/latent_space_analysis.py b/src/move/tasks/latent_space_analysis.py new file mode 100644 index 00000000..0f24eb63 --- /dev/null +++ b/src/move/tasks/latent_space_analysis.py @@ -0,0 +1,237 @@ +__all__ = [] + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional + +import hydra +import numpy as np +import pandas as pd +import torch +from numpy.typing import NDArray +from sklearn.base import TransformerMixin + +import move.visualization as viz +from move.analysis.feature_importance import FeatureImportance +from move.analysis.metrics import ComputeAccuracyMetrics +from move.conf.tasks import ReducerConfig +from move.core.exceptions import UnsetProperty +from move.core.typing import PathLike +from move.data.io import sanitize_filename +from move.tasks.base import CsvWriterMixin, SubTask +from move.tasks.move import MoveTask + +if TYPE_CHECKING: + from move.data.dataloader import MoveDataLoader + from move.models.base import BaseVae + + +class LatentSpaceAnalysis(MoveTask): + """Analyze latent space. + + 1. Train a model (or reload it if it already exists). + 2. Obtain latent representation of input data. + 3. If set, reduce latent representation to 2 dimensions (e.g. using t-SNE or PCA). + 4. If set, compute how accurately each dataset can be reconstructed. + 5. If set, compute the feature importance in the latent space. + + Args: + interim_data_path: + Directory where encoded data is stored + results_path: + Directory where results will be saved + discrete_dataset_names: + Names of discrete datasets + continuous_dataset_names: + Names of continuous datasets + batch_size: + Number of samples in one batch (used during training and testing) + model_config: + Config of the VAE + training_loop_config: + Config of the training loop + compute_accuracy_metrics: + Whether accuracy metrics for each dataset will be computed. + compute_feature_importance: + Whether feature importance for each feature in the latent space will + be computed. May take a while depending on number of features. + reducer_config: + Config of the reducer used to further reduce the dimensions of the + latent space to two dimensions. Expected to behave like a + transformer from scikit-learn. + features_to_plot: + List of feature names to generate color-coded latent space plots. + If not given, no latent space will be generated, but a CSV file + containing all latent representations will still be created. + """ + + loop_filename: str = "loop.yaml" + model_filename: str = "model.pt" + results_subdir: str = "latent_space" + + def __init__( + self, + interim_data_path: PathLike, + results_path: PathLike, + compute_accuracy_metrics: bool, + compute_feature_importance: bool, + reducer_config: Optional[ReducerConfig] = None, + features_to_plot: Optional[list[str]] = None, + **kwargs, + ) -> None: + super().__init__( + input_dir=interim_data_path, + output_dir=Path(results_path) / self.results_subdir, + **kwargs, + ) + self.reducer_config = reducer_config + self.features_to_plot = features_to_plot + self.compute_accuracy_metrics = compute_accuracy_metrics + self.compute_feature_importance = compute_feature_importance + + def run(self) -> Any: + from move.models.base import reload_vae + + model_path = self.output_dir / self.model_filename + + if model_path.exists(): + self.logger.warning( + f"A model file was found: '{model_path}' and will be reloaded. " + "Erase the file if you wish to train a new model." + ) + self.logger.debug("Re-loading model") + model = reload_vae(model_path) + else: + self.logger.debug("Training a new model") + train_dataloader = self.make_dataloader() + model = self.init_model(train_dataloader) + training_loop = self.init_training_loop() + training_loop.train(model, train_dataloader) + training_loop.plot() + training_loop.to_yaml(self.output_dir / self.loop_filename) + model.save(model_path) + + model.eval() + + all_dataloader = self.make_dataloader(shuffle=False, drop_last=False) + test_dataloader = self.make_dataloader(split="test") + + project_subtask = Project(model, all_dataloader, self.reducer_config) + project_subtask.parent = self + project_subtask.run() + if self.features_to_plot: + project_subtask.plot(self.features_to_plot) + + if self.compute_accuracy_metrics: + metrics_subtask = ComputeAccuracyMetrics(self, model, test_dataloader) + metrics_subtask.run() + metrics_subtask.plot() + + if self.compute_feature_importance: + fi_subtask = FeatureImportance(self, model, all_dataloader) + fi_subtask.run() + fi_subtask.plot() + + +class Project(CsvWriterMixin, SubTask): + """Use a variational autoencoder to compress input data from a dataloader + into a latent space. Additionally, use a reducer to further compress the + latent space into an even lower-dimensional space that can be easily + visualized (e.g., in 2D or 3D). + + Args: + model: Variational autoencoder model + dataloader: Data loader + """ + + filename: str = "latent_space.csv" + plot_filename_fmt: str = "latent_space_{}.png" + reducer: Optional[TransformerMixin] + + def __init__( + self, + model: "BaseVae", + dataloader: "MoveDataLoader", + reducer_config: Optional[ReducerConfig], + output_dir: Optional[PathLike] = None, + ): + self.model = model + self.dataloader = dataloader + if reducer_config is None: + self.reducer = None + else: + self.reducer = hydra.utils.instantiate(reducer_config) + if output_dir is not None: + self.output_dir = output_dir + + @property + def num_features(self) -> int: + return self.model.num_latent + + @property + def num_reduced_features(self) -> int: + if self.reducer is None: + return 0 + return getattr(self.reducer, "n_components") + + def plot(self, feature_names: list[str]) -> None: + # NOTE: assumes 2D + if self.csv_filepath is None: + raise ValueError("No CSV data found") + data = pd.read_csv(self.csv_filepath) + latent_space = np.take(data.values, (0, 1), axis=1) + for name in feature_names: + try: + dataset = self.dataloader.dataset.find(name) + except KeyError as e: + self.log(str(e), "WARNING") + continue + # Obtain target values for color coding + target_values = dataset.select(name).numpy() + if dataset.data_type == "discrete": + # Convert one-hot encoded values to category codes + is_nan = target_values.sum(axis=1) == 0 + target_values = np.argmax(target_values, axis=1) + code2cat_map = { + str(code): category for category, code in dataset.mapping.items() + } + fig = viz.plot_latent_space_with_cat( + latent_space, name, target_values, code2cat_map, is_nan + ) + else: + fig = viz.plot_latent_space_with_con(latent_space, name, target_values) + fig_filename = sanitize_filename(self.plot_filename_fmt.format(name)) + fig_path = str(self.output_dir / fig_filename) + fig.savefig(fig_path, bbox_inches="tight") + + @torch.no_grad() + def run(self) -> None: + if self.parent is None: + raise UnsetProperty("Output directory") + + colnames = [f"reduced_dim{i}" for i in range(self.num_reduced_features)] + colnames.extend([f"dim{i}" for i in range(self.num_features)]) + + csv_filepath = self.parent.output_dir / self.filename + self.init_csv_writer(csv_filepath, fieldnames=colnames, extrasaction="ignore") + + self.log("Compressing input data") + + if self.dataloader.dataset.perturbation is not None: + self.log("Dataset's perturbation will be removed", "WARNING") + self.dataloader.dataset.remove_perturbation() + + tensors = [] + for (batch,) in self.dataloader: + tensors.append(self.model.project(batch)) + + latent_space: NDArray = torch.cat(tensors, dim=0).numpy() + if self.reducer is None: + array = latent_space + else: + self.log("Reducing data to two dimensions") + reduced_latent_space = self.reducer.fit_transform(latent_space) + array = np.hstack((reduced_latent_space, latent_space)) + + self.write_cols({colname: array[:, i] for i, colname in enumerate(colnames)}) + + self.close_csv_writer() diff --git a/src/move/tasks/move.py b/src/move/tasks/move.py new file mode 100644 index 00000000..764400b0 --- /dev/null +++ b/src/move/tasks/move.py @@ -0,0 +1,93 @@ +__all__ = ["MoveTask"] + +from typing import TYPE_CHECKING, Optional + +import hydra +from torch import nn + +from move.core.exceptions import UnsetProperty +from move.core.typing import Split +from move.tasks.base import ParentTask + +if TYPE_CHECKING: + from move.conf.models import ModelConfig + from move.conf.training import TrainingLoopConfig + from move.data.dataloader import MoveDataLoader + from move.models.base import BaseVae + from move.training.loop import TrainingLoop + + +class MoveTask(ParentTask): + """A task that can initialize a MOVE model, dataloader, and training loop.""" + + def __init__( + self, + discrete_dataset_names: list[str], + continuous_dataset_names: list[str], + batch_size: int, + model_config: Optional["ModelConfig"], + training_loop_config: Optional["TrainingLoopConfig"], + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.discrete_dataset_names = discrete_dataset_names + self.continuous_dataset_names = continuous_dataset_names + self.model_config = model_config + self.batch_size = batch_size + self.training_loop_config = training_loop_config + + def make_dataloader( + self, split: Split = "all", **dataloader_kwargs + ) -> "MoveDataLoader": + """Make a MOVE dataloader. For the training split, data will be shuffled + and the last batch will be dropped.""" + from move.conf.training import DataLoaderConfig + from move.data.dataset import MoveDataset + + dataset = MoveDataset.load( + self.input_dir, + self.discrete_dataset_names, + self.continuous_dataset_names, + split, + ) + + is_training = not (split == "test" or split == "valid") + if "shuffle" not in dataloader_kwargs: + dataloader_kwargs["shuffle"] = is_training + + if "drop_last" not in dataloader_kwargs: + dataloader_kwargs["drop_last"] = is_training + + dataloader_kwargs["batch_size"] = self.batch_size + config = DataLoaderConfig(**dataloader_kwargs) + return hydra.utils.instantiate(config, dataset=dataset) + + def init_model(self, dataloader: "MoveDataLoader") -> "BaseVae": + """Initialize a MOVE model.""" + if self.model_config is None: + raise UnsetProperty("Model config") + return hydra.utils.instantiate( + self.model_config, + discrete_shapes=dataloader.dataset.discrete_shapes, + continuous_shapes=dataloader.dataset.continuous_shapes, + ) + + def init_training_loop(self, set_parent: bool = True) -> "TrainingLoop": + """Initialize a training loop. + + Args: + set_parent: + Whether the training task is linked to a parent task. If + orphaned, this task cannot use the logger to track its progress. + """ + if self.training_loop_config is None: + raise UnsetProperty("Training loop config") + training_loop: "TrainingLoop" = hydra.utils.instantiate( + self.training_loop_config, _recursive_=False + ) + if set_parent: + training_loop.parent = self + else: + # if orphan, cannot use logger + training_loop.prog_every_n_epoch = None + return training_loop diff --git a/src/move/tasks/train_model.py b/src/move/tasks/train_model.py new file mode 100644 index 00000000..79dae46f --- /dev/null +++ b/src/move/tasks/train_model.py @@ -0,0 +1,80 @@ +__all__ = ["TrainModel"] + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Union + +import hydra + +from move.core.exceptions import FILE_EXISTS_WARNING +from move.core.typing import PathLike +from move.tasks.base import ParentTask + +if TYPE_CHECKING: + from move.conf.models import ModelConfig + from move.conf.training import TrainingLoopConfig + from move.data.dataloader import MoveDataLoader + from move.models.base import BaseVae + from move.training.loop import TrainingLoop + + +class TrainModel(ParentTask): + """Train a single model.""" + + model_filename: str = "model.pt" + loop_filename: str = "loop.yaml" + results_subdir: str = "train_model" + + def __init__( + self, + interim_data_path: PathLike, + results_path: PathLike, + discrete_dataset_names: list[str], + continuous_dataset_names: list[str], + batch_size: int, + model_config: Union["ModelConfig", dict[str, Any]], + training_loop_config: Union["TrainingLoopConfig", dict[str, Any]], + ) -> None: + super().__init__( + input_dir=interim_data_path, + output_dir=Path(results_path) / self.results_subdir, + ) + self.discrete_dataset_names = discrete_dataset_names + self.continuous_dataset_names = continuous_dataset_names + self.batch_size = batch_size + self.training_loop_config = training_loop_config + self.model_config = model_config + + def make_dataloader(self, **kwargs) -> "MoveDataLoader": + from move.data.dataloader import MoveDataLoader + from move.data.dataset import MoveDataset + + dataset = MoveDataset.load( + self.input_dir, self.discrete_dataset_names, self.continuous_dataset_names + ) + return MoveDataLoader(dataset, **kwargs) + + def run(self) -> None: + model_path = self.output_dir / self.model_filename + if model_path.exists(): + self.logger.warning(FILE_EXISTS_WARNING.format(model_path)) + # Init data/model + dataloader = self.make_dataloader( + batch_size=self.batch_size, shuffle=True, drop_last=True + ) + model: "BaseVae" = hydra.utils.instantiate( + self.model_config, + discrete_shapes=dataloader.dataset.discrete_shapes, + continuous_shapes=dataloader.dataset.continuous_shapes, + ) + self.logger.info("Training model") + # Train + training_loop: "TrainingLoop" = hydra.utils.instantiate( + self.training_loop_config, _recursive_=False + ) + training_loop.parent = self + training_loop.run(model, dataloader) + training_loop.plot() + self.logger.info("Training complete!") + # Save model/config + training_loop.to_yaml(self.output_dir / self.loop_filename) + model.save(model_path) diff --git a/src/move/tasks/tune_model.py b/src/move/tasks/tune_model.py index 0e5d417c..90c52fca 100644 --- a/src/move/tasks/tune_model.py +++ b/src/move/tasks/tune_model.py @@ -19,17 +19,17 @@ calculate_accuracy, calculate_cosine_similarity, ) -from move.conf.schema import ( - MOVEConfig, +from move.conf.legacy import ( TuneModelConfig, TuneModelReconstructionConfig, TuneModelStabilityConfig, ) +from move.conf.schema import MOVEConfig from move.core.logging import get_logger from move.core.typing import BoolArray, FloatArray from move.data import io from move.data.dataloaders import MOVEDataset, make_dataloader, split_samples -from move.models.vae import VAE +from move.models.vae_legacy import VAE TaskType = Literal["reconstruction", "stability"] @@ -208,7 +208,9 @@ def _tune_reconstruction( batch_size=task_config.batch_size, ) cat_recons, con_recons = model.reconstruct(dataloader) - con_recons = np.split(con_recons, np.cumsum(model.continuous_shapes[:-1]), axis=1) + con_recons = np.split( + con_recons, np.cumsum(model.continuous_shapes[:-1]), axis=1 + ) for cat, cat_recon, dataset_name in zip( cat_list, cat_recons, config.data.categorical_names ): diff --git a/src/move/tasks/tuning.py b/src/move/tasks/tuning.py new file mode 100644 index 00000000..68741de6 --- /dev/null +++ b/src/move/tasks/tuning.py @@ -0,0 +1,324 @@ +__all__ = ["TuneModel", "TuneStability"] + +from collections import defaultdict +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal, cast + +import numpy as np +import torch +from hydra.core.hydra_config import HydraConfig +from hydra.types import RunMode +from matplotlib.cbook import boxplot_stats +from sklearn.metrics.pairwise import cosine_similarity + +from move.analysis.metrics import ( + calculate_accuracy, + calculate_cosine_similarity, +) +from move.core.exceptions import FILE_EXISTS_WARNING +from move.core.typing import FloatArray, PathLike +from move.tasks.base import CsvWriterMixin +from move.tasks.move import MoveTask + +if TYPE_CHECKING: + from move.data.dataloader import MoveDataLoader + from move.models.base import BaseVae + +TaskType = Literal["reconstruction", "stability"] + +BOXPLOT_STATS = ["mean", "med", "q1", "q3", "iqr", "cilo", "cihi", "whislo", "whishi"] + + +class TuneModel(CsvWriterMixin, MoveTask): + """Run a model with a set of hyperparameters and report metrics, such as, + reconstruction accuracy, loss, or stability. + + Args: + interim_data_path: + Directory where encoded data is stored + results_path: + Directory where results will be saved + discrete_dataset_names: + Names of discrete datasets + continuous_dataset_names: + Names of continuous datasets + batch_size: + Number of samples in one batch (used during training and testing) + model_config: + Config of the VAE + training_loop_config: + Config of the training loop + """ + + loss_filename: str = "loss.csv" + metrics_filename: str = "metrics.csv" + model_filename_fmt: str = "model_{}.pt" + results_subdir: str = "tuning" + + def __init__( + self, + interim_data_path: PathLike, + results_path: PathLike, + **kwargs, + ) -> None: + # Check that Hydra is being used (e.g., not called from a script/notebook) + try: + hydra_config = HydraConfig.get() + except ValueError: + raise ValueError("Use the command line to run this task.") + + if hydra_config.mode != RunMode.MULTIRUN: + raise ValueError("This task must run in multirun mode.") + + super().__init__( + input_dir=interim_data_path, + output_dir=Path(results_path) / self.results_subdir, + **kwargs, + ) + + # Delete sweep run config + sweep_config_path = Path(hydra_config.sweep.dir).joinpath("multirun.yaml") + if sweep_config_path.exists(): + sweep_config_path.unlink() + + kv_sep = hydra_config.job.config.override_dirname.kv_sep + item_sep = hydra_config.job.config.override_dirname.item_sep + + self.job_num = hydra_config.job.num + 1 + self.job_name = hydra_config.job.override_dirname + self.override_dict: dict[str, str] = {} + for item in hydra_config.job.override_dirname.split(item_sep): + key, value = item.split(kv_sep) + self.override_dict[key] = value + + def record_metrics( + self, model: "BaseVae", dataloader_dict: dict[str, "MoveDataLoader"] + ): + """Record accuracy or cosine similarity metric for each dataloader. + + Args: + model: VAE model + dataloader_dict: Dict of dataloaders corresponding to different data subsets + """ + + colnames = [ + "job_num", + *self.override_dict.keys(), + "metric", + "dataset_name", + "split", + ] + colnames.extend(BOXPLOT_STATS) + + self.init_csv_writer( + self.output_dir / self.metrics_filename, + mode="a", + fieldnames=colnames, + extrasaction="ignore", + ) + + for split, dataloader in dataloader_dict.items(): + datasets = dataloader.datasets + scores_per_dataset: dict[str, list[FloatArray]] = defaultdict(list) + + for (batch,) in dataloader: + batch_disc, batch_cont = model.split_input(batch) + recon = model.reconstruct(batch, as_one=True) + recon_disc, recon_cont = model.split_input(recon) + + for i, dataset in enumerate(datasets[: len(batch_disc)]): + target = batch_disc[i].numpy() + preds = torch.argmax( + (torch.log_softmax(recon_disc[i], dim=-1)), dim=-1 + ).numpy() + scores = calculate_accuracy(target, preds) + scores_per_dataset[dataset.name].append(scores) + + for i, dataset in enumerate(datasets[len(batch_disc) :]): + target = batch_cont[i].numpy() + preds = recon_cont[i].numpy() + scores = calculate_cosine_similarity(target, preds) + scores_per_dataset[dataset.name].append(scores) + + for dataset in datasets: + metric = ( + "accuracy" + if dataset.data_type == "discrete" + else "cosine_similarity" + ) + csv_row: dict[str, Any] = dict( + job_num=self.job_num, + **self.override_dict, + metric=metric, + dataset_name=dataset.name, + split=split, + ) + scores = np.concatenate(scores_per_dataset[dataset.name], axis=0) + bxp_stas, *_ = boxplot_stats(scores) + bxp_stas.pop("fliers") + csv_row.update(bxp_stas) + + self.add_row_to_buffer(csv_row) + + # Append to file + self.close_csv_writer(True) + + def record_loss(self, model: "BaseVae", dataloader: "MoveDataLoader"): + """Record final loss in a CSV row.""" + from move.models.base import LossDict + + colnames = ["job_num", *self.override_dict.keys()] + colnames.extend(LossDict.__annotations__.keys()) + + self.init_csv_writer( + self.output_dir / self.loss_filename, + mode="a", + fieldnames=colnames, + extrasaction="ignore", + ) + + loss_epoch = None + + for (batch,) in dataloader: + loss_batch = model.compute_loss(batch, 1.0) + if loss_epoch is None: + loss_epoch = loss_batch + else: + for key in loss_batch.keys(): + loss_epoch[key] += loss_batch[key] + + csv_row: dict[str, Any] = dict(job_num=self.job_num, **self.override_dict) + + assert loss_epoch is not None + for key, value in loss_epoch.items(): + if isinstance(value, torch.Tensor): + csv_row[key] = value.item() / len(dataloader) + else: + csv_row[key] = cast(float, value) / len(dataloader) + + # Append to file + self.add_row_to_buffer(csv_row) + self.close_csv_writer(True) + + def run(self): + from move.models.base import reload_vae + + model_path = self.output_dir / self.model_filename_fmt.format(self.job_num) + loss_filepath = self.output_dir / self.loss_filename + metrics_filepath = self.output_dir / self.metrics_filename + + for filepath in (loss_filepath, metrics_filepath): + if filepath.exists() and self.job_num == 1: + filepath.unlink() + self.logger.warning(FILE_EXISTS_WARNING.format(filepath)) + + dataloaders = { + "train": self.make_dataloader("train"), + "test": self.make_dataloader("test"), + } + + if model_path.exists(): + self.logger.warning( + f"A model file was found: '{model_path}' and will be reloaded. " + "Erase the file if you wish to train a new model." + ) + self.logger.debug(f"Re-loading model {self.job_num}") + model = reload_vae(model_path) + else: + self.logger.debug(f"Training a new model {self.job_num}") + model = self.init_model(dataloaders["train"]) + training_loop = self.init_training_loop(False) + training_loop.train(model, dataloaders["train"]) + model.save(model_path) + + model.freeze() + + self.record_loss(model, dataloaders["test"]) + self.record_metrics(model, dataloaders) + + +class TuneStability(TuneModel): + """Train a number of models and compute the stability of their latent space. + + Args: + num_refits: Number of models to train + """ + + stabilility_filename: str = "stability.csv" + + def __init__(self, num_refits: int, **kwargs) -> None: + super().__init__(**kwargs) + self.num_refits = num_refits + self.baseline_cosine_sim = None + + def calculate_stability(self, latent_repr: FloatArray) -> float: + """Compute stability (mean difference between the cosine similarities of two + latent representations). + + Args: + latent_repr: Latent representation""" + + if self.baseline_cosine_sim is None: + raise ValueError("Cannot calculate stability without a baseline.") + cosine_sim = cosine_similarity(latent_repr) + abs_diff = np.absolute(cosine_sim - self.baseline_cosine_sim) + # Remove diagonal (cosine similarity with itself) + diff = abs_diff[~np.eye(abs_diff.shape[0], dtype=bool)].reshape( + abs_diff.shape[0], -1 + ) + return np.mean(diff).item() + + def run(self) -> None: + results_filepath = self.output_dir / self.stabilility_filename + + if results_filepath.exists() and self.job_num == 1: + results_filepath.unlink() + self.logger.warning(FILE_EXISTS_WARNING.format(results_filepath)) + + train_dataloader = self.make_dataloader("train") + test_dataloader = self.make_dataloader("test") + + stabs: list[float] = [] + + for i in range(self.num_refits): + self.logger.debug(f"Refit: {i+1}/{self.num_refits}") + model = self.init_model(train_dataloader) + training_loop = self.init_training_loop(False) + training_loop.train(model, train_dataloader) + model.freeze() + + # Create latent representation for all samples + latent_reprs = [] + for (batch,) in test_dataloader: + latent_reprs.append(model.project(batch)) + latent_repr = torch.concat(latent_reprs).numpy() + + if self.baseline_cosine_sim is None: + # Store first cosine similarity as baseline + self.baseline_cosine_sim = cosine_similarity(latent_repr) + else: + # Calculate stability from a pair of cosine similarity arrays + stability = self.calculate_stability(latent_repr) + stabs.append(stability) + + # Append row to CSV file + colnames = ["job_num", *self.override_dict.keys(), "metric"] + colnames.extend(BOXPLOT_STATS) + + self.init_csv_writer( + results_filepath, + mode="a", + fieldnames=colnames, + extrasaction="ignore", + ) + csv_row: dict[str, Any] = dict( + job_num=self.job_num, + **self.override_dict, + metric="stability", + ) + bxp_stas, *_ = boxplot_stats(stabs) + bxp_stas.pop("fliers") + csv_row.update(bxp_stas) + + self.add_row_to_buffer(csv_row) + self.close_csv_writer() diff --git a/src/move/training/loop.py b/src/move/training/loop.py new file mode 100644 index 00000000..815439eb --- /dev/null +++ b/src/move/training/loop.py @@ -0,0 +1,296 @@ +__all__ = [] + +import math +from typing import Literal, Optional, cast + +import hydra +import pandas as pd +import torch +from torch.nn.utils.clip_grad import clip_grad_norm_ +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler + +import move.visualization as viz +from move.conf.optim import LrSchedulerConfig, OptimizerConfig +from move.data.dataloader import MoveDataLoader +from move.models.base import BaseVae, LossDict +from move.tasks.base import CsvWriterMixin, SubTask +from move.training.optim.prodigy import Prodigy + +AnnealingFunction = Literal["linear", "cosine", "sigmoid", "stairs"] +AnnealingSchedule = Literal["monotonic", "cyclical"] + + +class TrainingLoop(CsvWriterMixin, SubTask): + """Train a VAE model. + + Args: + optimizer_config: + Configuration for the optimizer. + lr_scheduler_config: + Configuration for the learning rate scheduler. + max_epochs: + Max training epochs, may be lower if early stopping is implemented. + max_grad_norm: + If other than none, clip gradient norm. + annealing_epochs: + Epochs required to fully warm KL divergence. Set to 0 and a + `monotonic` schedue to turn off KL divergence annealing. + annealing_function: + Function to warm KL divergence. + annealing_schedule: + Whether KL divergence is warmed monotonically or cyclically. + prog_every_n_epoch: + Log progress every n-th epoch. Note this only controls a message + displaying the current epoch. Loss and other metrics are logged at + every step. + log_grad: + Whether gradients should be logged. + log_lr: + Whether LR should be logged. + """ + + max_steps: int + global_step: int + + def __init__( + self, + optimizer_config: OptimizerConfig, + lr_scheduler_config: Optional[LrSchedulerConfig] = None, + max_epochs: int = 100, + max_grad_norm: Optional[float] = None, + annealing_epochs: int = 20, + annealing_function: AnnealingFunction = "linear", + annealing_schedule: AnnealingSchedule = "monotonic", + prog_every_n_epoch: Optional[int] = 10, + log_grad: bool = False, + log_lr: bool = False, + ): + if annealing_epochs < 0: + raise ValueError("Annealing epochs must be a non-negative integer") + if annealing_epochs == 0 and annealing_schedule == "cyclical": + raise ValueError( + "Annealing epochs must be a positive integer if schedule is cyclical" + ) + self.optimizer_config = optimizer_config + self.lr_scheduler_config = lr_scheduler_config + self.max_epochs = max_epochs + self.max_grad_norm = max_grad_norm + self.annealing_epochs = annealing_epochs + self.annealing_function = annealing_function + self.annealing_schedule = annealing_schedule + self.current_epoch = 0 + self.prog_every_n_epoch = prog_every_n_epoch + self.log_grad = log_grad + self.log_lr = log_lr + + def _repr_html_(self) -> str: + return "" + + @property + def annealing_factor(self) -> float: + epoch = self.current_epoch + if ( + self.annealing_schedule == "monotonic" and epoch < self.annealing_epochs + ) or (self.annealing_schedule == "cyclical"): + if self.annealing_function == "stairs": + num_epochs_cyc = self.max_epochs / self.num_cycles + # location in cycle: 0 (start) - 1 (end) + loc = (epoch % math.ceil(num_epochs_cyc)) / num_epochs_cyc + # first half of the cycle, KL weight is warmed up + # second half, it is fixed + if loc <= 0.5: + return loc * 2 + else: + num_steps_cyc = self.max_steps / self.num_cycles + step = self.global_step + loc = (step % math.ceil(num_steps_cyc)) / num_steps_cyc + if loc < 0.5: + if self.annealing_function == "linear": + return loc * 2 + elif self.annealing_function == "sigmoid": + # ensure it reaches 0.5 at 1/4 of the cycle + shift = 0.25 + slope = self.annealing_epochs + return 1 / (1 + math.exp(slope * (shift - loc))) + elif self.annealing_function == "cosine": + return math.cos((loc - 0.5) * math.pi) + return 1.0 + + @property + def kl_weight(self) -> float: + return self.annealing_factor * 1.0 + + @property + def num_cycles(self) -> float: + return self.max_epochs / (self.annealing_epochs * 2) + + def plot(self) -> None: + if self.parent is not None and self.csv_filepath is not None: + data = pd.read_csv(self.csv_filepath) + data["kl_div"] *= data["kl_weight"] + fig = viz.plot_loss_curves(data) + fig_path = str(self.parent.output_dir / "loss_curve.png") + fig.savefig(fig_path, bbox_inches="tight") + + def run(self, model: BaseVae, train_dataloader: MoveDataLoader) -> None: + return self.train(model, train_dataloader) + + def get_colnames(self, model: Optional[BaseVae] = None) -> list[str]: + """Return the list of column names of the CSV being generated. If set + to log gradients, a model would be required to obtain the names of its + parameters. + + Args: + model: a deep learning model""" + colnames = ["epoch", "step"] + colnames.extend(LossDict.__annotations__.keys()) + if self.log_grad and model is not None: + for module_name, module in model.named_children(): + param_names = [] + for param_name, param in module.named_parameters(module_name): + if param.requires_grad: + param_names.append(param_name) + if param_names: + colnames.append(module_name) + colnames.extend(param_names) + if self.log_lr: + colnames.append("lr") + return colnames + + def get_last_lr( + self, optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] + ) -> float: + if isinstance(optimizer, Prodigy): + d = optimizer.param_groups[0]["d"] + lr = optimizer.param_groups[0]["lr"] + return d * lr + elif lr_scheduler is None: + return optimizer.param_groups[0]["lr"] + else: + return lr_scheduler.get_last_lr()[0] + + def make_row( + self, loss_dict: LossDict, model: BaseVae, lr: Optional[float] + ) -> dict[str, float]: + """Format a loss dictionary and the model's gradients into a dictionary + representing a CSV row. + + Args: + loss_dict: dictionary with loss metrics + model: deep-learning model + """ + csv_row: dict[str, float] = { + "epoch": self.current_epoch + 1, + "step": self.global_step + 1, + } + for key, value in loss_dict.items(): + if isinstance(value, torch.Tensor): + csv_row[key] = value.item() + else: + csv_row[key] = cast(float, value) + if self.log_grad and model is not None: + for module_name, module in model.named_children(): + grads = [] + for param_name, param in module.named_parameters(module_name): + if param.grad is not None: + grad = torch.norm(param.grad.detach()) + grads.append(grad) + csv_row[param_name] = grad.item() + if len(grads) > 0: + csv_row[module_name] = torch.norm(torch.stack(grads)).item() + if self.log_lr and lr is not None: + csv_row["lr"] = lr + return csv_row + + def train( + self, + model: BaseVae, + train_dataloader: MoveDataLoader, + ) -> None: + """Train a VAE model. + + Args: + model: VAE model + train_dataloader: Training data loader + """ + num_batches = len(train_dataloader) + self.max_steps = self.max_epochs * num_batches + self.global_step = 0 + if self.parent: + self.init_csv_writer( + self.parent.output_dir / "loss_curve.csv", + fieldnames=self.get_colnames(model), + ) + + if train_dataloader.dataset.perturbation is not None: + self.log("Dataset's perturbation will be removed", "WARNING") + train_dataloader.dataset.perturbation = None + + optimizer: Optimizer = hydra.utils.instantiate( + self.optimizer_config, params=model.parameters() + ) + if self.lr_scheduler_config: + lr_scheduler: Optional[LRScheduler] = hydra.utils.instantiate( + self.lr_scheduler_config, optimizer=optimizer + ) + else: + lr_scheduler = None + + for epoch in range(0, self.max_epochs): + self.current_epoch = epoch + + model.train() + + for (batch,) in train_dataloader: + # Zero gradients + optimizer.zero_grad() + + # Forward pass + try: + loss_dict = model.compute_loss(batch, self.annealing_factor) + except (KeyboardInterrupt, ValueError) as exception: + self.close_csv_writer() + raise exception + # Backward pass and optimize + loss_dict["elbo"].backward() + + if self.max_grad_norm is not None: + clip_grad_norm_(model.parameters(), self.max_grad_norm) + + optimizer.step() + + lr = ( + None + if not self.log_lr + else self.get_last_lr(optimizer, lr_scheduler) + ) + + csv_row = self.make_row(loss_dict, model, lr) + self.add_row_to_buffer(csv_row) + + self.global_step += 1 + + """ if valid_dataloader is not None: + model.eval() + with torch.no_grad(): + for (batch,) in valid_dataloader: + loss_dict = model.compute_loss(batch, self.annealing_factor) + for key, value in loss_dict.items(): + if isinstance(value, torch.Tensor): + epoch_loss[f"valid_{key}"] += value.item() / num_batches """ + + if lr_scheduler is not None: + lr_scheduler.step() + + self.current_epoch += 1 + + if ( + self.prog_every_n_epoch is not None + and self.current_epoch % self.prog_every_n_epoch == 0 + ): + num_zeros = int(math.log10(self.max_epochs)) + 1 + self.log(f"Epoch {self.current_epoch:0{num_zeros}}") + + model.freeze() + self.close_csv_writer() diff --git a/src/move/training/optim/__init__.py b/src/move/training/optim/__init__.py new file mode 100644 index 00000000..ff0e3a61 --- /dev/null +++ b/src/move/training/optim/__init__.py @@ -0,0 +1,3 @@ +__all__ = ["Prodigy"] + +from move.training.optim.prodigy import Prodigy diff --git a/src/move/training/optim/prodigy.py b/src/move/training/optim/prodigy.py new file mode 100644 index 00000000..49c21e09 --- /dev/null +++ b/src/move/training/optim/prodigy.py @@ -0,0 +1,277 @@ +# SEE: https://github.com/konstmish/prodigy + +__all__ = ["Prodigy"] + +import math +from typing import TYPE_CHECKING, Callable, Optional + +import torch +import torch.optim +import torch.distributed as dist + +if TYPE_CHECKING: + from torch.optim.optimizer import _params_t + + +class Prodigy(torch.optim.Optimizer): + """Implement Adam with Prodigy step-sizes. + Leave LR set to 1 unless you encounter instability. + + Args: + params: + Iterable of parameters to optimize or dicts defining parameter groups. + lr: + Learning rate adjustment parameter. Increases or decreases the Prodigy + learning rate. + betas: + Coefficients used for computing running averages of gradient and its square + beta3: + Coefficients for computing the Prodidy stepsize using running averages. + If set to None, uses the value of square root of beta2 + eps: + Term added to the denominator outside of the root operation to improve + numerical stability. + weight_decay: + Weight decay, i.e. a L2 penalty + decouple: + Use AdamW style decoupled weight decay + use_bias_correction: + Turn on Adam's bias correction. + safeguard_warmup: + Remove lr from the denominator of D estimate to avoid issues during warm-up + stage. + d0: + Initial D estimate for D-adaptation. Rarely needs changing. + d_coef: + Coefficient in the expression for the estimate of d + Values such as 0.5 and 2.0 typically work as well. + Changing this parameter is the preferred way to tune the method. + growth_rate: + prevent the D estimate from growing faster than this multiplicative rate. + Default is inf, for unrestricted. Values like 1.02 give a kind of learning + rate warmup effect. + fsdp_in_use: + If you're using sharded parameters, this should be set to True. The optimizer + will attempt to auto-detect this, but if you're using an implementation other + than PyTorch's builtin version, the auto-detection won't work. + """ + + def __init__( + self, + params: "_params_t", + lr: float = 1.0, + betas: tuple[float, float] = (0.9, 0.999), + beta3: Optional[float] = None, + eps: float = 1e-8, + weight_decay: float = 0, + decouple: bool = True, + use_bias_correction: bool = False, + safeguard_warmup: bool = False, + d0: float = 1e-6, + d_coef: float = 1.0, + growth_rate: float = float("inf"), + fsdp_in_use: bool = False, + ): + if not 0.0 < d0: + raise ValueError("Invalid d0 value: {}".format(d0)) + if not 0.0 < lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 < eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + defaults = dict( + lr=lr, + betas=betas, + beta3=beta3, + eps=eps, + weight_decay=weight_decay, + d=d0, + d0=d0, + d_max=d0, + d_numerator=0.0, + d_coef=d_coef, + k=0, + growth_rate=growth_rate, + use_bias_correction=use_bias_correction, + decouple=decouple, + safeguard_warmup=safeguard_warmup, + fsdp_in_use=fsdp_in_use, + ) + self.d0 = d0 + super().__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return False + + @property + def supports_flat_params(self): + return True + + def step(self, closure: Optional[Callable] = None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + d_denom = 0.0 + + group = self.param_groups[0] + use_bias_correction = group["use_bias_correction"] + beta1, beta2 = group["betas"] + beta3 = group["beta3"] + if beta3 is None: + beta3 = math.sqrt(beta2) + k = group["k"] + + d = group["d"] + d_max = group["d_max"] + d_coef = group["d_coef"] + lr = max(group["lr"] for group in self.param_groups) + + if use_bias_correction: + bias_correction = ((1 - beta2 ** (k + 1)) ** 0.5) / (1 - beta1 ** (k + 1)) + else: + bias_correction = 1 + + dlr = d * lr * bias_correction + + growth_rate = group["growth_rate"] + decouple = group["decouple"] + fsdp_in_use = group["fsdp_in_use"] + + d_numerator = group["d_numerator"] + d_numerator *= beta3 + + for group in self.param_groups: + decay = group["weight_decay"] + k = group["k"] + eps = group["eps"] + group_lr = group["lr"] + d0 = group["d0"] + safeguard_warmup = group["safeguard_warmup"] + + if group_lr not in [lr, 0.0]: + raise RuntimeError( + f"Setting different lr values in different parameter groups is only supported for values of 0" + ) + + for p in group["params"]: + if p.grad is None: + continue + if hasattr(p, "_fsdp_flattened"): + fsdp_in_use = True + + grad = p.grad.data + + # Apply weight decay (coupled variant) + if decay != 0 and not decouple: + grad.add_(p.data, alpha=decay) + + state = self.state[p] + + # State initialization + if "step" not in state: + state["step"] = 0 + state["s"] = torch.zeros_like(p.data).detach() + state["p0"] = p.detach().clone() + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data).detach() + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data).detach() + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + s = state["s"] + p0 = state["p0"] + + if group_lr > 0.0: + # we use d / d0 instead of just d to avoid getting values that are too small + d_numerator += ( + (d / d0) + * dlr + * torch.dot(grad.flatten(), (p0.data - p.data).flatten()).item() + ) + + # Adam EMA updates + exp_avg.mul_(beta1).add_(grad, alpha=d * (1 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_( + grad, grad, value=d * d * (1 - beta2) + ) + + if safeguard_warmup: + s.mul_(beta3).add_(grad, alpha=((d / d0) * d)) + else: + s.mul_(beta3).add_(grad, alpha=((d / d0) * dlr)) + d_denom += s.abs().sum().item() + + ###### + + d_hat = d + + # if we have not done any progres, return + # if we have any gradients available, will have d_denom > 0 (unless \|g\|=0) + if d_denom == 0: + return loss + + if lr > 0.0: + if fsdp_in_use: + dist_tensor = torch.zeros(2).cuda() + dist_tensor[0] = d_numerator + dist_tensor[1] = d_denom + dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM) + global_d_numerator = dist_tensor[0] + global_d_denom = dist_tensor[1] + else: + global_d_numerator = d_numerator + global_d_denom = d_denom + + d_hat = d_coef * global_d_numerator / global_d_denom + if d == group["d0"]: + d = max(d, d_hat) + d_max = max(d_max, d_hat) + d = min(d_max, d * growth_rate) + + for group in self.param_groups: + group["d_numerator"] = global_d_numerator + group["d_denom"] = global_d_denom + group["d"] = d + group["d_max"] = d_max + group["d_hat"] = d_hat + + decay = group["weight_decay"] + k = group["k"] + eps = group["eps"] + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + + state = self.state[p] + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + state["step"] += 1 + + denom = exp_avg_sq.sqrt().add_(d * eps) + + # Apply weight decay (decoupled variant) + if decay != 0 and decouple: + p.data.add_(p.data, alpha=-decay * dlr) + + ### Take step + p.data.addcdiv_(exp_avg, denom, value=-dlr) + + group["k"] = k + 1 + + return loss diff --git a/src/move/training/training_loop.py b/src/move/training/training_loop.py index 8f59afc4..33b4386a 100644 --- a/src/move/training/training_loop.py +++ b/src/move/training/training_loop.py @@ -1,8 +1,10 @@ +__all__ = [] + from typing import Optional from torch.utils.data import DataLoader -from move.models.vae import VAE +from move.models.vae_legacy import VAE TrainingLoopOutput = tuple[list[float], list[float], list[float], list[float], float] diff --git a/src/move/visualization/__init__.py b/src/move/visualization/__init__.py index f841e309..60afe95e 100644 --- a/src/move/visualization/__init__.py +++ b/src/move/visualization/__init__.py @@ -1,6 +1,7 @@ __all__ = [ "LOSS_LABELS", "color_cycle", + "generate_grid", "plot_categorical_feature_importance", "plot_continuous_feature_importance", "plot_latent_space_with_cat", @@ -15,6 +16,7 @@ plot_categorical_feature_importance, plot_continuous_feature_importance, ) +from move.visualization.grid import generate_grid from move.visualization.latent_space import ( plot_latent_space_with_cat, plot_latent_space_with_con, diff --git a/src/move/visualization/contrast.py b/src/move/visualization/contrast.py new file mode 100644 index 00000000..1320a166 --- /dev/null +++ b/src/move/visualization/contrast.py @@ -0,0 +1,50 @@ +__all__ = ["get_luminance", "get_contrast_ratio", "MIN_CONTRAST"] + +from typing import Sequence, Union + +import numpy as np +from numpy.typing import NDArray + +MIN_CONTRAST = 4.5 +# https://www.w3.org/WAI/WCAG22/Understanding/contrast-minimum.html + +Color = Union[Sequence[float], NDArray] + + +def get_luminance(color: Color) -> float: + """Compute relative brightness of any point in a colorspace, normalized + to 0 for darkest black and 1 for lightest white. + + Args: + color: + Array or three-element tuple representing a color in terms of red, + green, and blue coordinates + """ + # http://www.w3.org/TR/2008/REC-WCAG20-20081211/#relativeluminancedef + color = np.asarray(color) + w = np.array([0.2126, 0.7152, 0.0722]) + rgb = np.piecewise( + color, + [color <= 0.03928, color > 0.03928], + [lambda x: x / 12.92, lambda x: ((x + 0.055) / 1.055) ** 2.4], + ) + return (w @ rgb).item() + + +def get_contrast_ratio(color1: Color, color2: Color) -> float: + """Compute ratio between the relative luminance of the lighter and the + darker of two colors. Contrast ratios range from 1 to 21. + + Args: + color1: + color2: + Array or three-element tuple representing a color in terms of red, + green, and blue coordinates + """ + # http://www.w3.org/TR/2008/REC-WCAG20-20081211/#contrast-ratiodef + lum1 = get_luminance(color1) + lum2 = get_luminance(color2) + ratio = (lum1 + 0.05) / (lum2 + 0.05) + if lum2 > lum1: + return ratio**-1 + return ratio diff --git a/src/move/visualization/feature_importance.py b/src/move/visualization/feature_importance.py index f16cf55f..18f2e554 100644 --- a/src/move/visualization/feature_importance.py +++ b/src/move/visualization/feature_importance.py @@ -10,11 +10,11 @@ from matplotlib.colors import TwoSlopeNorm from move.core.typing import FloatArray +from move.visualization.figure import create_figure from move.visualization.style import ( DEFAULT_DIVERGING_PALETTE, DEFAULT_PLOT_STYLE, DEFAULT_QUALITATIVE_PALETTE, - color_cycle, style_settings, ) @@ -87,8 +87,10 @@ def plot_categorical_feature_importance( ) ) with style_settings(style): - fig, ax = plt.subplots(figsize=figsize) - sns.stripplot(data=data, x="x", y="y", hue="category", size=1, ax=ax, palette=colormap) + fig, ax = create_figure(figsize=figsize) + sns.stripplot( + data=data, x="x", y="y", hue="category", size=1, ax=ax, palette=colormap + ) ax.set(xlabel="Impact on latent space", ylabel="Feature") # Fix labels in legend legend = ax.get_legend() @@ -165,14 +167,15 @@ def plot_continuous_feature_importance( vmin, vmax = data["value"].min(), data["value"].max() norm = TwoSlopeNorm(0.0, vmin, vmax) sm = ScalarMappable(norm, colormap) - data["category"] = np.ma.compressed(norm(data["value"]) * 25).astype(int) + category_values = np.multiply(norm(data["value"]), 25) + data["category"] = np.ma.compressed(category_values).astype(int) palette = np.empty((25, 4)) # 25 colors x 4 channels palette[:13, :] = sm.to_rgba(np.linspace(vmin, 0, 13)) # first slope palette[12:, :] = sm.to_rgba(np.linspace(0, vmax, 13)) # second slope palette = palette.tolist() # NDArray not always supported with style_settings(style): - fig, ax = plt.subplots(figsize=figsize) + fig, ax = create_figure(figsize=figsize) sns.stripplot( data=data, x="x", y="y", hue="category", ax=ax, palette=palette, size=2 ) diff --git a/src/move/visualization/figure.py b/src/move/visualization/figure.py new file mode 100644 index 00000000..8c96fdfa --- /dev/null +++ b/src/move/visualization/figure.py @@ -0,0 +1,18 @@ +__all__ = ["create_figure"] + +from matplotlib import pyplot as plt +from matplotlib.axes import Axes +from matplotlib.figure import Figure + + +def create_figure(**fig_kw) -> tuple[Figure, Axes]: + """Create a figure. + + Returns: + A tuple containing a Figure and an Axes object. Unlike the customary + (and equivalent) `matplotlib.pyplot.subplots` function, this method is + correctly typed. That's the only difference.""" + fig, ax = plt.subplots(**fig_kw) + assert isinstance(fig, Figure) + assert isinstance(ax, Axes) + return fig, ax diff --git a/src/move/visualization/grid.py b/src/move/visualization/grid.py new file mode 100644 index 00000000..fbbd7c13 --- /dev/null +++ b/src/move/visualization/grid.py @@ -0,0 +1,229 @@ +__all__ = ["find_grid_dimensions", "facet_grid", "generate_grid"] + +import math +from typing import Literal, Optional + +import matplotlib.axes +import matplotlib.figure +import matplotlib.pyplot as plt +import pandas as pd +from matplotlib.colors import LogNorm, Normalize + +from move.visualization.style import style_settings + +Orientation = Literal["vertical", "horizontal"] +Location = Literal["left", "right", "top", "bottom"] + + +def find_grid_dimensions(num_subplots: int) -> tuple[int, int]: + """Compute minimum number of columns and number of rows necessary to + accommodate a given number of subplots into a nearly square grid.""" + # Adapted from: https://gist.github.com/pganssle/5e921b0dfc93ac54f3c35fea2cbcff2f + + num_sqrt_f = math.sqrt(num_subplots) + num_sqrt = math.ceil(math.sqrt(num_subplots)) + + if num_sqrt == num_sqrt_f: + # perfect square + return num_sqrt, num_sqrt + elif num_subplots <= num_sqrt * (num_sqrt - 1): + # try horizontal rectangle + x, y = num_sqrt, num_sqrt - 1 + elif not (num_sqrt % 2) and num_subplots % 2: + # try horizontal rectangle + x, y = num_sqrt + 1, num_sqrt - 1 + else: + # square grid + x, y = num_sqrt, num_sqrt + return x, y + + +def facet_grid( + data: pd.DataFrame, + x_name: str, + y_name: str, + hue_name: str, + facet_name: str, + facet_title_fmt: str, + hue_label: Optional[str] = None, + x_label: Optional[str] = None, + y_label: Optional[str] = None, + cbar_orientation: Orientation = "vertical", + cbar_location: Location = "right", + sharex: bool = False, + sharey: bool = False, + use_lognorm: bool = True, +) -> matplotlib.figure.Figure: + """Form a matrix of panels to visualize four variables. The facet variable + should have discrete values, whereas the other three be continuous. + + Two variables will be represented by a scatter plot, and the remaining + variable will be represented by the hue of the scatter dots. An + accompanying color bar will be generated. + + Args: + x_name: + Name of column corresponding to variable represented by x-axis + y_name: + Name of column corresponding to variable represented by y-axis + hue_name: + Name of column corresponding to variable represented by hue + facet_name: + Name of column corresponding to variable represented by facet + facet_title_fmt: + Format used to create the label of each subplot + x_label: + Label of the x-axis + y_label: + Label of the y-axis + hue_label: + Label of the colorbar + cbar_orientation: + Whether the colorbar is displayed horizontally or vertically + cbar_location: + Where the colorbar should be positioned + sharex: + Whether all subplots should share the same x-axis + sharey: + Whether all subplots should share the same y-axis + use_lognorm: + Whether the hue variable should be represented in the log dimension + """ + levels = data[facet_name].unique() + if len(levels) == len(data): + raise ValueError(f"f{facet_name} is not discrete.") + + vmin, vmax = data[hue_name].min(), data[hue_name].max() + norm_class = LogNorm if use_lognorm else Normalize + norm = norm_class(vmin, vmax) + + with style_settings("ggplot"): + fig, axs, cax = generate_grid( + len(levels), + x_label, + y_label, + cbar_orientation, + cbar_location, + sharex, + sharey, + ) + markers = None + + for i, level in enumerate(levels): + subset = data[data[facet_name] == level] + + ax = axs[i] + ax.plot( + subset[x_name], subset[y_name], color="k", alpha=0.75, linestyle=":" + ) + markers = ax.scatter( + subset[x_name], + subset[y_name], + c=subset[hue_name], + norm=norm, + zorder=100, + ) + ax.set(title=facet_title_fmt.format(level)) + + assert markers is not None + fig.colorbar(markers, cax, orientation=cbar_orientation) + if hue_label: + if cbar_orientation == "horizontal": + cax.set(xlabel=hue_label) + else: + cax.set(ylabel=hue_label) + + fig.tight_layout() + + return fig + + +def generate_grid( + num_subplots: int, + x_label: Optional[str] = None, + y_label: Optional[str] = None, + cbar_orientation: Orientation = "vertical", + cbar_location: Location = "right", + sharex: bool = False, + sharey: bool = False, +) -> tuple[matplotlib.figure.Figure, list[matplotlib.axes.Axes], matplotlib.axes.Axes]: + """Form a matrix of panels to visualize multiple variables. + + Args: + num_subplots: + Number of subplots to accomodate in the grid + x_label: + Label of the x-axis + y_label: + Label of the y-axis + cbar_orientation: + Whether the colorbar is displayed horizontally or vertically + cbar_location: + Where the colorbar should be positioned + sharex: + Whether all subplots should share the same x-axis + sharey: + Whether all subplots should share the same y-axis + """ + if cbar_orientation == "horizontal": + if cbar_location not in ("top", "bottom"): + raise ValueError( + "Only 'top' or 'bottom' location is valid for 'horizontal' alignment" + ) + elif cbar_orientation == "vertical": + if cbar_location not in ("left", "right"): + raise ValueError( + "Only 'left' or 'right' location is valid for 'vertical' alignment" + ) + else: + raise ValueError("Only 'horizontal' or 'vertical' alignment allowed") + + ncols, nrows = find_grid_dimensions(num_subplots) + num_unused = ncols * nrows - num_subplots + + fig = plt.figure(figsize=(4 * ncols, 3 * nrows)) + + if cbar_orientation == "horizontal": + if cbar_location == "top": + cax_idx = 0 + height_ratios = [1] + [3 * ncols] * nrows + else: + cax_idx = -1 + height_ratios = [3 * ncols] * nrows + [1] + + gs = fig.add_gridspec(nrows + 1, ncols, height_ratios=height_ratios) + cax = fig.add_subplot(gs[cax_idx, :]) + else: + if cbar_location == "left": + cax_idx = 0 + width_ratios = [1] + [4 * nrows] * ncols + else: + cax_idx = -1 + width_ratios = [4 * nrows] * ncols + [1] + + gs = fig.add_gridspec(nrows, ncols + 1, width_ratios=width_ratios) + cax = fig.add_subplot(gs[:, cax_idx]) + + axs = [] + for i in range(num_subplots): + x_coord = (i // ncols) + 1 * (cbar_location == "top") + y_coord = (i % ncols) + 1 * (cbar_location == "left") + + kwargs = {} + if sharex and len(axs) > 0: + kwargs["sharex"] = axs[0] + if sharey and len(axs) > 0: + kwargs["sharey"] = axs[0] + + ax = fig.add_subplot(gs[x_coord, y_coord], **kwargs) + axs.append(ax) + + if x_label: + for ax in axs[-(ncols - num_unused) :]: + ax.set(xlabel=x_label) + + if y_label: + for i in range(0, len(axs), ncols): + axs[i].set(ylabel=y_label) + + return fig, axs, cax diff --git a/src/move/visualization/latent_space.py b/src/move/visualization/latent_space.py index b07826b5..6b6720cf 100644 --- a/src/move/visualization/latent_space.py +++ b/src/move/visualization/latent_space.py @@ -1,7 +1,8 @@ __all__ = ["plot_latent_space_with_cat", "plot_latent_space_with_con"] -from typing import Any +from typing import Any, cast +import matplotlib.axes import matplotlib.figure import matplotlib.pyplot as plt import matplotlib.style @@ -56,13 +57,15 @@ def plot_latent_space_with_cat( raise ValueError("Expected at least two dimensions in latent space.") with style_settings(style), color_cycle(colormap): fig, ax = plt.subplots() + assert isinstance(fig, matplotlib.figure.Figure) + assert isinstance(ax, matplotlib.axes.Axes) codes = np.unique(feature_values) for code in codes: category = feature_mapping[str(code)] is_category = (feature_values == code) & ~is_nan - dims = np.take(latent_space.compress(is_category, axis=0), [0, 1], axis=1).T + dims = np.take(latent_space.compress(is_category, axis=0), (0, 1), axis=1).T ax.scatter(*dims, label=category) - dims = np.take(latent_space.compress(is_nan, axis=0), [0, 1], axis=1).T + dims = np.take(latent_space.compress(is_nan, axis=0), (0, 1), axis=1).T ax.scatter(*dims, label="NaN") ax.set(xlabel="dim 0", ylabel="dim 1") legend = ax.legend() @@ -98,8 +101,10 @@ def plot_latent_space_with_con( norm = TwoSlopeNorm(0.0, min(feature_values), max(feature_values)) with style_settings(style): fig, ax = plt.subplots() + assert isinstance(fig, matplotlib.figure.Figure) + assert isinstance(ax, matplotlib.axes.Axes) dims = latent_space[:, 0], latent_space[:, 1] - pts = ax.scatter(*dims, c=feature_values, cmap=colormap, norm=norm) + pts = ax.scatter(*dims, c=cast(list, feature_values), cmap=colormap, norm=norm) cbar = fig.colorbar(pts, ax=ax) cbar.ax.set(ylabel=feature_name) ax.set(xlabel="dim 0", ylabel="dim 1") diff --git a/src/move/visualization/loss_curves.py b/src/move/visualization/loss_curves.py index 41daa290..381a8282 100644 --- a/src/move/visualization/loss_curves.py +++ b/src/move/visualization/loss_curves.py @@ -1,11 +1,14 @@ __all__ = ["LOSS_LABELS", "plot_loss_curves"] from collections.abc import Sequence +from typing import Union import matplotlib.figure -import matplotlib.pyplot as plt import numpy as np +import pandas as pd +from move.visualization.figure import create_figure +from move.visualization.scale import axis_scale from move.visualization.style import ( DEFAULT_PLOT_STYLE, DEFAULT_QUALITATIVE_PALETTE, @@ -13,19 +16,25 @@ style_settings, ) -LOSS_LABELS = ("Loss", "Cross-Entropy", "Sum of Squared Errors", "KLD") +LOSS_LABELS = ( + "Loss", + "Reconstruction error (discrete)", + "Reconstruction error (continuous)", + "Regularization error", +) def plot_loss_curves( - losses: Sequence[list[float]], + losses: Union[Sequence[list[float]], pd.DataFrame], labels: Sequence[str] = LOSS_LABELS, style: str = DEFAULT_PLOT_STYLE, colormap: str = DEFAULT_QUALITATIVE_PALETTE, + xlabel: str = "Epochs", ) -> matplotlib.figure.Figure: """Plot one or more loss curves. Args: - losses: List containing lists of loss values + losses: List containing lists of loss values or a DataFrame labels: List containing names of each loss line style: Name of style to apply to the plot colormap: Name of colormap to use for the curves @@ -33,12 +42,27 @@ def plot_loss_curves( Returns: Figure """ - num_epochs = len(losses[0]) - epochs = np.arange(num_epochs) + is_df = isinstance(losses, pd.DataFrame) + if is_df: + # Calculate epoch from steps + max_epochs = losses["epoch"].max() + max_steps = losses["step"].max() + steps_epoch = max_steps / max_epochs + x_values = losses["step"] / steps_epoch + losses.drop(["epoch", "step"], axis=1, inplace=True) + yscale = axis_scale(losses.iloc[:, 0]) + else: + x_values = np.arange(len(losses[0])) + yscale = axis_scale(losses[0]) with style_settings(style), color_cycle(colormap): - fig, ax = plt.subplots() - for loss, label in zip(losses, labels): - ax.plot(epochs, loss, label=label, linestyle="-") + fig, ax = create_figure() + for i, label in enumerate(labels): + if is_df: + colname = losses.columns[i] + loss = losses[colname] + else: + loss = losses[i] + ax.plot(x_values, loss, label=label, linestyle="-") ax.legend() - ax.set(xlabel="Epochs", ylabel="Loss") + ax.set(xlabel=xlabel, ylabel="Loss", yscale=yscale) return fig diff --git a/src/move/visualization/metrics.py b/src/move/visualization/metrics.py index 26e28f19..d9058587 100644 --- a/src/move/visualization/metrics.py +++ b/src/move/visualization/metrics.py @@ -1,12 +1,14 @@ __all__ = ["plot_metrics_boxplot"] from collections.abc import Sequence +from typing import Callable, Optional, Union, cast import matplotlib import matplotlib.figure -import matplotlib.pyplot as plt +import pandas as pd from move.core.typing import FloatArray +from move.visualization.figure import create_figure from move.visualization.style import ( DEFAULT_PLOT_STYLE, DEFAULT_QUALITATIVE_PALETTE, @@ -16,8 +18,8 @@ def plot_metrics_boxplot( - scores: Sequence[FloatArray], - labels: Sequence[str], + scores: Union[Sequence[FloatArray], pd.DataFrame], + labels: Optional[Sequence[str]], style: str = DEFAULT_PLOT_STYLE, colormap: str = DEFAULT_QUALITATIVE_PALETTE, ) -> matplotlib.figure.Figure: @@ -25,20 +27,27 @@ def plot_metrics_boxplot( score corresponds (for example) to a sample. Args: - scores: List of dataset metrics - labels: List of dataset names + scores: List of dataset metrics or DataFrame + labels: List of dataset names. If None, DataFrame column names will be used. style: Name of style to apply to the plot colormap: Name of colormap to use for the boxes Returns: Figure """ + is_df = isinstance(scores, pd.DataFrame) + not_na: Callable[[pd.Series], pd.Series] = lambda sr: sr.notna() + values = [scores[col][not_na].values for col in scores.columns] if is_df else scores # type: ignore + if labels is None: + if not is_df: + raise ValueError("Label names missing") + labels = cast(Sequence[str], scores.columns) with style_settings(style), color_cycle(colormap): labelcolor = matplotlib.rcParams["axes.labelcolor"] - fig, ax = plt.subplots() + fig, ax = create_figure() comps = ax.boxplot( - scores[::-1], - labels=labels[::-1], + values, + labels=labels, patch_artist=True, vert=False, capprops=dict(color=labelcolor), @@ -55,4 +64,5 @@ def plot_metrics_boxplot( for box, prop in zip(comps["boxes"], prop_cycle()): box.update(dict(facecolor=prop["color"], edgecolor=labelcolor)) ax.set(xlim=(-0.05, 1.05), xlabel="Score", ylabel="Dataset") + ax.invert_yaxis() return fig diff --git a/src/move/visualization/scale.py b/src/move/visualization/scale.py new file mode 100644 index 00000000..75e1937d --- /dev/null +++ b/src/move/visualization/scale.py @@ -0,0 +1,21 @@ +__all__ = ["axis_scale"] + + +import math +from typing import Literal, Sequence, Union + +import numpy as np +import pandas as pd + +from move.core.typing import FloatArray + +Scale = Literal["log", "linear"] + + +def axis_scale(data: Union[FloatArray, pd.Series, Sequence[float]]) -> Scale: + """Determine which scale (either log or linear) to use when plotting. If the data + spans more than two orders of magnitude, log scale will be used.""" + ratio = math.log10(np.nanmax(data) / np.nanmin(data)) + if ratio > 2: + return "log" + return "linear" diff --git a/src/move/visualization/style.py b/src/move/visualization/style.py index d414c2ef..d76a2068 100644 --- a/src/move/visualization/style.py +++ b/src/move/visualization/style.py @@ -23,13 +23,13 @@ def color_cycle(colormap: str) -> ContextManager: """Returns a context manager for using a color cycle in plots. Args: - colormap: Name of qualitative color map. + colormap: Name of palette. Returns: Context manager """ registry: ColormapRegistry = getattr(matplotlib, "colormaps") - colormap = registry[colormap] + colormap = registry[colormap] # type: ignore if not isinstance(colormap, ListedColormap): raise ValueError("Only colormaps that are list of colors supported.") prop_cycle = cycler(color=getattr(colormap, "colors")) diff --git a/tutorial/config/task/random_small__latent.yaml b/tutorial/config/task/random_small__latent.yaml index 76552e14..1d1e1ce2 100644 --- a/tutorial/config/task/random_small__latent.yaml +++ b/tutorial/config/task/random_small__latent.yaml @@ -1,18 +1,32 @@ defaults: - - analyze_latent + - task_latent_space + - model_config: vae + - training_loop_config/optimizer_config: optim_adam + - training_loop_config/lr_scheduler_config: null + - override reducer_config: pca -batch_size: 10 - -feature_names: +features_to_plot: - drugs_20 - metagenomics_174 - proteomics_195 -model: +compute_accuracy_metrics: true +compute_feature_importance: true + +model_config: num_hidden: - 1000 num_latent: 150 + kl_weight: 1e-4 + dropout_rate: 0.1 + +batch_size: 10 + +training_loop_config: + max_epochs: 40 + log_grad: false -training_loop: - lr: 1e-4 - num_epochs: 40 + # change type of optimizer above + optimizer_config: + lr: 1e-4 + weight_decay: 0 \ No newline at end of file diff --git a/tutorial/notebooks/01 Encoding data.ipynb b/tutorial/notebooks/01 Encoding data.ipynb index 894ec047..1b71f2bc 100644 --- a/tutorial/notebooks/01 Encoding data.ipynb +++ b/tutorial/notebooks/01 Encoding data.ipynb @@ -6,8 +6,10 @@ "metadata": {}, "outputs": [], "source": [ - "from move.data import io\n", - "from move.tasks import encode_data" + "from pathlib import Path\n", + "\n", + "from move.conf.schema import InputConfig\n", + "from move.tasks.encode_data import EncodeData" ] }, { @@ -20,9 +22,9 @@ "\n", "⚠️ The notebook takes user-defined configs in a `config/data` directory.\n", "\n", - "For encoding the data you need to have each dataset in a TSV format. Each table has `N` × `M` shape, where `N` is the numer of samples/individuals and `M` is the number of features. The continuous data is z-score normalized, whereas the categorical data is one-hot encoded. Below is an example of processing a continuous and categorical datasets.\n", + "For encoding the data you need to have each dataset in a TSV format. Each table has `N` × `M` shape, where `N` is the numer of samples/individuals and `M` is the number of features. The continuous data is z-score normalized (`standardization`), whereas the categorical data is one-hot encoded (`one_hot_encoding`). Below is an example of processing a continuous and categorical datasets.\n", "\n", - "First step is to read the configuration called `random_small` and specify the pre-defined task called `encode_data`." + "First step is to locate where our `random_small` dataset is, which datasets it comprises, and what kind of pre-processing each dataset will undergo." ] }, { @@ -31,14 +33,23 @@ "metadata": {}, "outputs": [], "source": [ - "config = io.read_config(\"random_small\", \"encode_data\")" + "base_path = Path(\"../data\")\n", + "interim_path = Path(\"../interim_data\")\n", + "\n", + "discrete_dnames = [\"random.small.drugs\"]\n", + "continuous_dnames = [\"random.small.proteomics\", \"random.small.metagenomics\"]\n", + "\n", + "# Indicate which kind of pre-processing is required for each config file.\n", + "# If a dataset has already been pre-processed, you can set this to 'none'\n", + "disc_conf = [InputConfig(name, preprocessing=\"one_hot_encode\") for name in discrete_dnames]\n", + "cont_conf = [InputConfig(name, preprocessing=\"standardize\") for name in continuous_dnames]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The next step is to run the `encode_data` task, passing our `config` object." + "The next step is to run the `EncodeData` task." ] }, { @@ -50,15 +61,16 @@ "name": "stderr", "output_type": "stream", "text": [ - "[INFO - encode_data]: Beginning task: encode data\n", - "[INFO - encode_data]: Encoding 'random.small.drugs'\n", - "[INFO - encode_data]: Encoding 'random.small.proteomics'\n", - "[INFO - encode_data]: Encoding 'random.small.metagenomics'\n" + "[INFO - EncodeData]: Beginning task: encode data\n", + "[INFO - EncodeData]: Encoding 'random.small.drugs'\n", + "[INFO - EncodeData]: Encoding 'random.small.proteomics'\n", + "[INFO - EncodeData]: Encoding 'random.small.metagenomics'\n" ] } ], "source": [ - "encode_data(config.data)" + "task = EncodeData(base_path, interim_path, \"random.small.ids\", disc_conf, cont_conf)\n", + "task.run()" ] }, { @@ -67,87 +79,114 @@ "source": [ "Data will be encoded accordingly and saved to the directory defined as `interim_data_path` in the `data` configuration.\n", "\n", - "We can confirm how the data looks by loading it." + "We can confirm how the data looks by loading it and creating a `MoveDataset` object. This type of object concatenates our datasets and keeps the information such as original dataset shapes and feature names.\n", + "\n", + "The drug dataset has been encoded as a matrix of 500 samples × 20 drugs × 2 categories (either took or did not take the drug), whereas the proteomics and metagenomics datasets keep their original shapes." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " MOVE dataset (500 samples)\n", + "
datatype# features# classes
random.small.drugsdiscrete202
random.small.proteomicscontinuous200N/A
random.small.metagenomicscontinuous1,000N/A
" + ], + "text/plain": [ + "MoveDataset(3 datasets)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "from pathlib import Path\n", - "\n", - "path = Path(config.data.interim_data_path)\n", + "from move.data.dataset import MoveDataset\n", "\n", - "cat_datasets, cat_names, con_datasets, con_names = io.load_preprocessed_data(path, config.data.categorical_names, config.data.continuous_names)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "assert len(cat_datasets) == 1 # one categorical dataset\n", - "assert len(con_datasets) == 2 # two continuous datasets\n", - "assert len(cat_names) == 1\n", - "assert len(con_names) == 2" + "dataset = MoveDataset.load(interim_path, discrete_dnames, continuous_dnames)\n", + "dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The drug dataset has been encoded as a matrix of 500 samples × 20 drugs × 2 categories (either took or did not take the drug), whereas the proteomics and metagenomics datasets keep their original shapes." + "We can also confirm that the mean of the continuous datasets is now close to 0, and the standard deviation is close to 1." ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "random.small.drugs: (500, 20, 2)\n", - "random.small.proteomics: (500, 200)\n", - "random.small.metagenomics: (500, 1000)\n" + "random.small.proteomics: mean = -0.000, std = 0.975\n", + "random.small.metagenomics: mean = 0.000, std = 0.975\n" ] } ], "source": [ - "dataset_names = config.data.categorical_names + config.data.continuous_names\n", - "\n", - "for dataset, dataset_name in zip(cat_datasets + con_datasets, dataset_names):\n", - " print(f\"{dataset_name}: {dataset.shape}\")" + "for continuous_dataset in dataset.continuous_datasets:\n", + " print(f\"{continuous_dataset.name}: mean = {continuous_dataset.tensor.mean():.3f}, std = {continuous_dataset.tensor.std():.3f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We can also confirm that the mean of the continuous datasets is now close to 0, and the standard deviation is close to 1." + "Alternatively, you can directly read the config YAML and create a task object from its content.\n", + "\n", + "Note, that for this to work, the config directory structure must be in the same directory as your notebook." ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "random.small.proteomics: mean = -0.000, std = 0.975\n", - "random.small.metagenomics: mean = 0.000, std = 0.975\n" + "[INFO - EncodeData]: Beginning task: encode data\n", + "[INFO - EncodeData]: Encoding 'random.small.drugs'\n", + "[INFO - EncodeData]: Encoding 'random.small.proteomics'\n", + "[INFO - EncodeData]: Encoding 'random.small.metagenomics'\n" ] } ], "source": [ - "for dataset, dataset_name in zip(con_datasets, dataset_names[1:]):\n", - " print(f\"{dataset_name}: mean = {dataset.mean():.3f}, std = {dataset.std():.3f}\")" + "from move.data.io import read_config\n", + "\n", + "if not Path.cwd().joinpath(\"config/data\").exists():\n", + " raise FileNotFoundError(\"Requires a config files in the current working directory.\")\n", + "\n", + "config = read_config(\"random_small\", \"encode_data\", \"data.raw_data_path='../data'\")\n", + "task = EncodeData.from_config(config)\n", + "task.run()" ] } ], @@ -167,7 +206,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.9.18" }, "orig_nbformat": 4, "vscode": {