Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ MOVE update #89

Open
wants to merge 183 commits into
base: developer
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
183 commits
Select commit Hold shift + click to select a range
8b67825
:sparkles: Allow turning off input scaling
ri-heme Sep 14, 2023
e52e328
:sparkles: Add new layers module
richi3f Sep 26, 2023
16c73e6
:construction: Re-do encode data module
ri-heme Oct 2, 2023
9c89f6d
:sparkles: :construction: Add exceptions module
ri-heme Oct 2, 2023
4c0cf9e
:construction: Save encoded data in one file
ri-heme Oct 2, 2023
0cdbb18
:sparkles: :construction: Introduce new dataset/loader objects
ri-heme Oct 2, 2023
f079fd9
:sparkles: Add "SplitOutput" module
ri-heme Oct 3, 2023
07a152c
:construction: Add type hint to splitting modules
ri-heme Oct 3, 2023
3138e6d
:construction: Add shape data to dataset
ri-heme Oct 18, 2023
13abcae
:construction: Add sub-task, add in/output paths to task
ri-heme Oct 18, 2023
c279df1
:sparkles: Introduce Training Loop sub-task
ri-heme Oct 18, 2023
8e8e76b
:truck: Move legacy VAE
ri-heme Oct 18, 2023
523f361
:sparkles: :construction: Add new VAE
ri-heme Oct 18, 2023
275f5b8
:art: Sort imports
ri-heme Oct 18, 2023
7959e66
:sparkles: Make x-label configurable
ri-heme Oct 18, 2023
5c2ef2f
:construction: Rename params
ri-heme Oct 18, 2023
fe5a32b
:sparkles: :construction: Add train model task
ri-heme Oct 18, 2023
c2597b0
:construction: Add base VAE class
ri-heme Oct 20, 2023
e816f78
:sparkles: :construction: Add CSV writer
ri-heme Jan 8, 2024
0c8b53e
:construction: Add properties to data classes
ri-heme Jan 8, 2024
cec4e3f
:construction: Enhance base task module
ri-heme Jan 8, 2024
c29ef6f
:sparkles: :construction: Add metrics sub-task module
ri-heme Jan 8, 2024
8e5626a
:construction: Add latent space analysis sub-task
ri-heme Jan 8, 2024
6d54884
:recycle: :construction: Refactor tasks
ri-heme Jan 8, 2024
ca98638
:construction: Adjust MRO
ri-heme Jan 8, 2024
b19ac69
:recycle: :construction: Refactor CSV writer mixin
ri-heme Jan 10, 2024
d35e355
:sparkles: :construction: Make MOVE dataset perturbable
ri-heme Jan 11, 2024
4c9969e
:sparkles: :construction: Add feature importance sub-task module
ri-heme Jan 11, 2024
e1f4945
:bug: Fix circular import
ri-heme Jan 11, 2024
bb356b6
:bug: Fix properties
ri-heme Jan 11, 2024
87d23a0
:construction: Make no grad
ri-heme Jan 11, 2024
6461d3f
:construction: Save model
ri-heme Jan 11, 2024
0fc02a5
:construction: Add functions to save/reload models
ri-heme Jan 11, 2024
ea74b78
:bug: Properly save CSV results
ri-heme Jan 11, 2024
83eff46
:bug: Iterate through datasets
ri-heme Jan 11, 2024
f6be63f
:wrench: :construction: Add data loader config
ri-heme Jan 11, 2024
ed8dc82
:construction: Add split input layer
ri-heme Jan 15, 2024
eecdabe
:bug: :construction: Adjust logger name
ri-heme Jan 15, 2024
620fadf
:art: Type hinting
ri-heme Jan 15, 2024
3ff6e09
:speaker: Log progress during training
ri-heme Jan 15, 2024
f611820
:bug: Fix logger name
ri-heme Jan 16, 2024
17430f9
:sparkles: :construction: Add VAE-T module
ri-heme Jan 16, 2024
33d8d81
:construction: Log gradients
ri-heme Jan 16, 2024
2584b3c
:wrench: Add optimizer configs
ri-heme Jan 19, 2024
2622255
:bug: :construction: Remove extra layer
ri-heme Jan 19, 2024
8e99572
:construction: Add output dir task
ri-heme Jan 19, 2024
6e07474
:construction: Freeze model
ri-heme Jan 19, 2024
cb4fb6a
:bug: :construction: Fixes and grad clipping
ri-heme Jan 19, 2024
ea3e8f6
:sparkles: Save task as YAML
ri-heme Jan 19, 2024
d55135f
:wrench: :construction: Add config dataclasses
ri-heme Jan 19, 2024
3b4e15b
:construction: Add activation to continuous output
ri-heme Jan 19, 2024
3b3733e
:construction: Reload from abc class
ri-heme Jan 22, 2024
b1103af
:art: Type hints
ri-heme Jan 22, 2024
81bc5f6
:art: :construction: Add alias, type hint
ri-heme Jan 22, 2024
1db6790
:recycle: :construction: Add select feature, refactor feat. idx
ri-heme Jan 22, 2024
599805a
:art: Type hints
ri-heme Jan 22, 2024
856289c
:wrench: :construction: Add config dataclasses
ri-heme Jan 22, 2024
9939eef
:sparkles: Properly sanitize filenames
ri-heme Jan 22, 2024
d5d6b84
:recycle: :construction: Use new config data classes
ri-heme Jan 22, 2024
9adcf6e
:sparkles: :construction: Add projection sub-task
ri-heme Jan 22, 2024
3fd1709
:construction: Ensure no perturbations
ri-heme Jan 22, 2024
5259ba5
:truck: Move modules
ri-heme Jan 22, 2024
ead7d08
:bug: Correctly instantiate training loop
ri-heme Jan 22, 2024
ca7ef3d
:bug: Fix mapping unpacking
ri-heme Jan 22, 2024
64adf2e
:speaker: Logging
ri-heme Jan 22, 2024
72a166b
:construction: Save loop config
ri-heme Jan 23, 2024
3033270
:bug: Instantiate loop
ri-heme Jan 26, 2024
d5d1181
:construction: Imports
ri-heme Jan 26, 2024
19801cc
:recycle: :construction: Store datasets in list
ri-heme Jan 31, 2024
e2d083a
:sparkles: Add contrast module
ri-heme Jan 31, 2024
168d47e
:construction: Reconstruct
ri-heme Jan 31, 2024
7efce2b
:see_no_evil: Ignore Nbooks
ri-heme Feb 1, 2024
de314dd
:fire: Remove model
ri-heme Feb 1, 2024
25eba25
:sparkles: Add HDI module
ri-heme Feb 1, 2024
2f0a783
:art: Sort imports, type hinting
ri-heme Feb 1, 2024
17e7d75
:see_no_evil: Ignore non-default config files
ri-heme Feb 1, 2024
4dd7092
:recycle: :construction: Refactor
ri-heme Feb 9, 2024
ffdff69
:sparkles: :construction: Add reservoir module
ri-heme Feb 9, 2024
1f45b82
:recycle: Return tuple instead of stacked tensor
ri-heme Feb 9, 2024
60da1ca
:fire: Remove files
ri-heme Feb 9, 2024
505717d
:fire: Remove files
ri-heme Feb 9, 2024
f2d9932
:bug: Fix validation of stream size
ri-heme Feb 13, 2024
83876ae
:sparkles: :construction: Add FDR module
ri-heme Feb 13, 2024
989c823
:sparkles: Fill function
ri-heme Feb 15, 2024
90d6d74
:art: Improve IntArray type alias
ri-heme Feb 15, 2024
cfbe5f2
:sparkles: Facet grid plots
ri-heme Feb 15, 2024
538cd4a
:construction: Ignore NaNs in loss
ri-heme Mar 12, 2024
da0910e
:construction: Return small reservoir
ri-heme Mar 13, 2024
4f975d9
:bug: Split output
ri-heme Mar 21, 2024
eb79f37
:construction: Remove redundant class
ri-heme Mar 21, 2024
7f1de7b
:bug: Fix metrics
ri-heme Mar 25, 2024
22481df
:construction: Exception handling UMAP
ri-heme Mar 25, 2024
7c128f5
:construction: Save loop config
ri-heme Apr 8, 2024
73a6044
:art: Type hints
ri-heme Apr 16, 2024
63f48e1
:construction: Minor functions
ri-heme Apr 16, 2024
c78e457
:wrench: Add perturbation config
ri-heme Apr 16, 2024
3bae6cd
:recycle: :construction: Re-structure mixins
ri-heme Apr 18, 2024
be8b6e1
:construction: Make configs optional
ri-heme Apr 18, 2024
c5b2fc1
:sparkles: :construction: Add associations module
ri-heme Apr 22, 2024
6808a49
:recycle: :construction: Return dist args as dict
ri-heme Apr 23, 2024
83a971c
:art: Organize imports
ri-heme Apr 23, 2024
036a942
:truck: :construction: Move VAE-t
ri-heme Apr 23, 2024
a02e507
:recycle: :construction: Refactor VAE with distribution
ri-heme Apr 23, 2024
d5b16e8
:bug: :construction: Fix metrics calculation
ri-heme May 3, 2024
e278931
:wrench: Add config for VAE-normal
ri-heme May 3, 2024
d6ab4c9
:art: Sort imports
ri-heme May 23, 2024
4e7571a
:construction: Fix supported distributions
ri-heme May 23, 2024
007623b
:construction: Move qualname to own module
ri-heme May 23, 2024
9f64a96
:construction: Update config schema
ri-heme May 23, 2024
325edf6
:truck: :construction: Move deprecated schemas
ri-heme May 23, 2024
0ce3684
:construction: :fire: Fix imports of deprecated classes
ri-heme May 28, 2024
562db2f
:construction: :wrench: Add/move configs to store
ri-heme May 28, 2024
c4b3892
:construction: Update command line
ri-heme May 28, 2024
8c33fa7
:wrench: Update tutorial config
ri-heme May 28, 2024
e125f31
:construction: :bug: Register resolvers
ri-heme May 29, 2024
d65388b
:construction: Enable latent space task
ri-heme May 29, 2024
a04b26a
:sparkles: Create figure module
ri-heme May 29, 2024
28a5bde
:art: Update labels and type hints
ri-heme May 29, 2024
f92751a
:bug: Correctly handle DataFrame
ri-heme May 29, 2024
546d6e5
:construction: Plot regularization error
ri-heme May 29, 2024
f7b428f
:memo: :speaker: Warn about existing model
ri-heme May 29, 2024
8ed357e
:construction: :sparkles: Handle DataFrame
ri-heme May 29, 2024
b84d19a
:construction: Generate metrics boxplot
ri-heme May 29, 2024
31987d7
:construction: Make labels optional
ri-heme May 29, 2024
3354d88
:construction: Save plot PNG
ri-heme May 29, 2024
0d7587f
:construction: Plot feature importance
ri-heme May 29, 2024
2584720
:construction: Allow fig kwargs
ri-heme May 29, 2024
5b104ef
:art: Type hints
ri-heme May 29, 2024
44d19c4
:art: Sort imports
ri-heme May 29, 2024
69ddf4d
:construction: Configure associations task
ri-heme May 30, 2024
ad564db
:memo: Docstrings
ri-heme May 30, 2024
6349e62
:construction: Add dataset attributes
ri-heme Jul 8, 2024
cf07e22
:sparkles: Instantiate from config
ri-heme Jul 8, 2024
b7c2e5d
:memo: Update tutorial
ri-heme Jul 8, 2024
55c4278
:art: :construction: Clean-up imports/style
ri-heme Jul 8, 2024
34e3686
:wrench: Configure item sep for multirun
ri-heme Jul 12, 2024
4f7eb27
:sparkles: Allow CSV writer to append
ri-heme Jul 12, 2024
d0e2d65
:construction: Add tuning task
ri-heme Jul 12, 2024
1a76291
:memo: Add docstrings
ri-heme Jul 12, 2024
f688377
:construction: Ensure protocol 4 is used to save models
ri-heme Jul 12, 2024
e2cebfb
:construction: Tune based on loss
ri-heme Jul 12, 2024
0464fc6
:construction: Create splitting module
ri-heme Jul 12, 2024
760d222
:recycle: :construction: Do not sort indices
ri-heme Jul 16, 2024
08a76ee
:construction: Split data
ri-heme Jul 16, 2024
1c26538
:construction: Load correct data split
ri-heme Jul 16, 2024
967a6e7
:construction: Record loss using test dataloader
ri-heme Jul 16, 2024
a5e0e17
:bug: Fix orphan tasks
ri-heme Jul 17, 2024
4144fd6
:construction: Record accuracy metrics
ri-heme Jul 17, 2024
46fc7de
:construction: Add stability tuning
ri-heme Jul 17, 2024
a034cd8
:sparkles: Accept CSV files as input
ri-heme Jul 17, 2024
8c9afbb
:memo: Comment/docstrings
ri-heme Jul 17, 2024
bd43069
:loud_sound: Re-arrange logging
ri-heme Jul 30, 2024
ba654ea
:recycle: Plot epochs instead of steps
ri-heme Jul 30, 2024
430b5b6
:art: Add overload
ri-heme Jul 30, 2024
7e83bd6
:construction: Use all data for plots
ri-heme Jul 30, 2024
b1def92
:construction: Apply global seed
ri-heme Aug 2, 2024
3fc9acf
:recycle: :construction: Refactor
ri-heme Aug 6, 2024
66b2155
:construction: Read TSV files
ri-heme Aug 29, 2024
4534330
:recycle: Create generate grid function
ri-heme Aug 29, 2024
fa99c28
:recycle: :wrench: Change batch size config
ri-heme Sep 3, 2024
b8ed7ff
:bug: Handle missingness
ri-heme Sep 3, 2024
b7feceb
:bug: Do not shuffle test data
ri-heme Sep 3, 2024
8f341ee
:construction: Correct op names
ri-heme Sep 10, 2024
261d3f3
:bug: Handle NaNs in norm
ri-heme Sep 10, 2024
42045c0
:construction: Add/sort imports
ri-heme Sep 10, 2024
df0bc87
:construction: Add missing import
ri-heme Sep 10, 2024
5f61e2d
:bug: :construction: Fix circular imports
ri-heme Sep 10, 2024
b98af90
:bug: Plot boxplots if NaNs
ri-heme Sep 11, 2024
b9e3ec8
:construction: Handle NaNs in accuracy
ri-heme Sep 11, 2024
f3749c7
:recycle: Make mapping a property
ri-heme Sep 12, 2024
50e7e6a
:bug: Fix factory method
ri-heme Sep 12, 2024
3036c9e
:sparkles: Add scale module
ri-heme Sep 16, 2024
b90f788
:construction: Determine scale automatically
ri-heme Sep 16, 2024
ae115bd
:construction: :wrench: Update config
ri-heme Sep 27, 2024
9e16acf
:construction: Start epoch/step at 1
ri-heme Sep 27, 2024
2b5ab18
:construction: :bug: Handle no discrete datasets
ri-heme Sep 30, 2024
9d63bf7
:construction: Refactor empty loss tensors
ri-heme Oct 1, 2024
da7f0a9
:sparkles: Add Prodigy optimizer
ri-heme Sep 12, 2024
d7ce91b
:sparkles: Log LR
ri-heme Sep 12, 2024
0bbbac7
:mute: Remove print statement
ri-heme Sep 12, 2024
ebf1b16
:construction: Add import
ri-heme Sep 16, 2024
eda4c92
:bug: Correctly generate file names
ri-heme Oct 1, 2024
1d9b855
:wrench: :construction: Implement weights on VAE
ri-heme Oct 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
__pycache__/
*.py[cod]

# NumPy binary files
data*/*.npy
# NumPy/PyTorch binary files
*.npy
*.pt

# Distribution and packaging files
build/
Expand All @@ -31,15 +32,29 @@ outputs/
*.log

# Tutorial files
**/interim_data/
**/processed_data/
**/results/
tutorial/maize/data
tutorial/*
!tutorial/config/*maize*.yaml
!tutorial/config/*random_small*.yaml
!tutorial/data
!tutorial/maize/maize_dataset.py
!tutorial/notebooks/*.ipynb
!tutorial/README.md

# Virtual environment
venv/
virtualvenv/

# docs files
docs/build/
docs/source/_templates/
docs/source/_templates/

# Root folder
/*.*
!/.gitignore
!/.readthedocs.yaml
!/LICENSE
!/MANIFEST.in
!/README.md
!/pyproject.toml
!/requirements.txt
!/setup.cfg
9 changes: 4 additions & 5 deletions src/move/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

__license__ = "MIT"
__version__ = (1, 4, 9)
__all__ = ["conf", "data", "models", "training_loop", "VAE"]
__version__ = (2, 0, 0)
__all__ = ["conf", "data", "models", "tasks", "viz"]

HYDRA_VERSION_BASE = "1.2"

from move import conf, data, models
from move.models.vae import VAE
from move.training.training_loop import training_loop
import move.visualization as viz
from move import conf, data, models, tasks
24 changes: 8 additions & 16 deletions src/move/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,11 @@
import hydra
from omegaconf import OmegaConf

import move.tasks
from move import HYDRA_VERSION_BASE
from move.conf.schema import (
AnalyzeLatentConfig,
EncodeDataConfig,
IdentifyAssociationsConfig,
MOVEConfig,
TuneModelConfig,
)
from move.conf.schema import SUPPORTED_TASKS, MOVEConfig
from move.core.logging import get_logger
from move.core.seed import set_global_seed
from move.tasks.base import Task


@hydra.main(
Expand All @@ -32,14 +27,11 @@ def main(config: MOVEConfig) -> None:
if task_type is None:
logger = get_logger("move")
logger.info("No task specified.")
elif task_type is EncodeDataConfig:
move.tasks.encode_data(config.data)
elif issubclass(task_type, TuneModelConfig):
move.tasks.tune_model(config)
elif task_type is AnalyzeLatentConfig:
move.tasks.analyze_latent(config)
elif issubclass(task_type, IdentifyAssociationsConfig):
move.tasks.identify_associations(config)
elif issubclass(task_type, SUPPORTED_TASKS):
if config.seed is not None:
set_global_seed(config.seed)
task: Task = hydra.utils.instantiate(config.task, _recursive_=False)
task.run()
else:
raise ValueError("Unsupported type of task.")

Expand Down
18 changes: 18 additions & 0 deletions src/move/analysis/fdr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import math
from typing import cast

import numpy as np
from numpy.typing import NDArray


def argnearest(array: NDArray, target: float) -> int:
"""Find value in array closest to target. Assumes array is sorted in
ascending order."""
idx = np.searchsorted(array, target, side="left")
if idx > 0 and (
idx == len(array)
or math.fabs(target - array[idx - 1]) < math.fabs(target - array[idx])
):
return cast(int, idx - 1)
else:
return cast(int, idx)
99 changes: 99 additions & 0 deletions src/move/analysis/feature_importance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
__all__ = ["FeatureImportance"]

from typing import TYPE_CHECKING

import pandas as pd
import torch

import move.visualization as viz
from move.core.exceptions import UnsetProperty
from move.data.io import sanitize_filename
from move.tasks.base import CsvWriterMixin, ParentTask, SubTask

if TYPE_CHECKING:
from move.data.dataloader import MoveDataLoader
from move.models.base import BaseVae


class FeatureImportance(CsvWriterMixin, SubTask):
"""Compute feature importance in latent space.

Feature importance is computed per feature per dataset. For each dataset,
a file will be created.

Feature importance is computed as the sum of differences in latent
variables generated when a feature is present/removed."""

data_filename_fmt: str = "feature_importance_{}.csv"
plot_filename_fmt: str = "feature_importance_{}.png"

def __init__(
self, parent: ParentTask, model: "BaseVae", dataloader: "MoveDataLoader"
) -> None:
self.parent = parent
self.model = model
self.dataloader = dataloader

def plot(self) -> None:
if self.parent is None:
return
for dataset in self.dataloader.datasets:
csv_filename = self.data_filename_fmt.format(dataset.name)
csv_filepath = self.parent.output_dir / sanitize_filename(csv_filename)
fig_filename = self.plot_filename_fmt.format(dataset.name)
fig_filepath = self.parent.output_dir / sanitize_filename(fig_filename)

diffs = pd.read_csv(csv_filepath)

if dataset.data_type == "continuous":
fig = viz.plot_continuous_feature_importance(
diffs.values, dataset.tensor.numpy(), dataset.feature_names
)
else:
# Categorical dataset is re-shaped to 3D shape
dataset_shape = getattr(dataset, "original_shape")
fig = viz.plot_categorical_feature_importance(
diffs.values,
dataset.tensor.reshape(-1, *dataset_shape).numpy(),
dataset.feature_names,
getattr(dataset, "mapping"),
)

fig.savefig(fig_filepath, bbox_inches="tight")

@torch.no_grad()
def run(self) -> None:
for dataset in self.dataloader.datasets:
self.log(f"Computing feature importance: '{dataset}'")
# Create a file for each dataset
# File is transposed; each column is a sample, each row a feature
if self.parent:
csv_filename = sanitize_filename(self.data_filename_fmt.format(dataset))
csv_filepath = self.parent.output_dir / csv_filename
colnames = ["feature_name"] + [""] * len(self.dataloader.dataset)
self.init_csv_writer(
csv_filepath, fieldnames=colnames, extrasaction="ignore"
)
else:
raise UnsetProperty("Parent task")

# Make a perturbation for each feature
for feature_name in dataset.feature_names:
value = None if dataset.data_type == "discrete" else 0.0
self.dataloader.dataset.perturb(dataset.name, feature_name, value)
row = [feature_name]
for tup in self.dataloader:
batch, pert_batch, _ = tup
z = self.model.project(batch)
z_pert = self.model.project(pert_batch)
diff = torch.sum(z_pert - z, dim=-1)
row.extend(diff.tolist())
self.write_row(row)

self.close_csv_writer(clear=True)

# Transpose CSV file, so each row is a sample, each column a feature
pd.read_csv(csv_filepath).T.to_csv(csv_filepath, index=False, header=False)

# Clear perturbation
self.dataloader.dataset.perturbation = None
36 changes: 36 additions & 0 deletions src/move/analysis/hdi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import math

import torch


def hdi_bounds(
x: torch.Tensor, hdi_prob: float = 0.95
) -> tuple[torch.Tensor, torch.Tensor]:
"""Return highest density interval (HDI) of a samples-features matrix.
The HDI represents the range within which most of the samples are located.

Args:
x: Matrix (`num_samples` x `num_features`)
hdi_prob: Percentage of samples inside the HDI

Returns:
Lower and upper bounds of HDI
"""
# adapated from arviz

if x.dim() != 2:
raise ValueError("Can only calculate for matrices with two dimensions")

n = x.size(0)
x, _ = torch.sort(x, dim=0)

interval_idx_inc = math.floor(hdi_prob * n)
num_intervals = n - interval_idx_inc

interval_width = x[interval_idx_inc:] - x[:num_intervals]
min_idx = torch.argmin(interval_width, dim=0)

hdi_min = torch.diag(x[min_idx])
hdi_max = torch.diag(x[min_idx + interval_idx_inc])

return hdi_min, hdi_max
77 changes: 74 additions & 3 deletions src/move/analysis/metrics.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
__all__ = ["calculate_accuracy", "calculate_cosine_similarity"]

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, cast

import numpy as np
import pandas as pd
import torch

import move.visualization as viz
from move.core.typing import FloatArray
from move.tasks.base import CsvWriterMixin, ParentTask, SubTask

if TYPE_CHECKING:
from move.data.dataloader import MoveDataLoader
from move.models.base import BaseVae


def calculate_accuracy(
Expand Down Expand Up @@ -33,7 +44,7 @@ def calculate_accuracy(
y_pred = np.ma.masked_array(reconstruction, mask=is_nan)

num_features = np.ma.count(y_true, axis=1)
scores = np.ma.filled(np.sum(y_true == y_pred, axis=1) / num_features, 0)
scores = np.ma.filled(np.sum(y_true == y_pred, axis=1) / num_features, np.nan)

return scores

Expand Down Expand Up @@ -64,7 +75,7 @@ def calculate_cosine_similarity(

# Equivalent to `np.diag(sklearn.metrics.pairwise.cosine_similarity(x, y))`
# But can handle masked arrays
scores = np.ma.compressed(np.sum(x * y, axis=1)) / (norm(x) * norm(y))
scores = np.ma.filled(np.sum(x * y, axis=1), np.nan) / (norm(x) * norm(y))

return scores

Expand All @@ -80,4 +91,64 @@ def norm(x: np.ma.MaskedArray, axis: int = 1) -> FloatArray:
Returns:
1D array with the specified axis removed.
"""
return np.ma.compressed(np.sqrt(np.sum(x**2, axis=axis)))
return np.ma.filled(np.sqrt(np.sum(x**2, axis=axis)), np.nan)


class ComputeAccuracyMetrics(CsvWriterMixin, SubTask):
"""Compute accuracy metrics between original input and reconstruction (use
cosine similarity for continuous dataset reconstructions)."""

data_filename: str = "reconstruction_metrics.csv"
plot_filename: str = "reconstruction_metrics.png"

def __init__(
self, parent: ParentTask, model: "BaseVae", dataloader: "MoveDataLoader"
) -> None:
self.parent = parent
self.model = model
self.dataloader = dataloader

def plot(self) -> None:
if self.parent and self.csv_filepath:
scores = pd.read_csv(self.csv_filepath, index_col=None)
fig = viz.plot_metrics_boxplot(scores, labels=None)
fig_path = self.parent.output_dir / self.plot_filename
fig.savefig(fig_path, bbox_inches="tight")

@torch.no_grad()
def run(self) -> None:
if self.parent:
csv_filepath = self.parent.output_dir / self.data_filename
colnames = self.dataloader.dataset.dataset_names
self.init_csv_writer(
csv_filepath, fieldnames=colnames, extrasaction="ignore"
)
else:
self.log("No parent task, metrics will not be saved.", "WARNING")

self.log("Computing accuracy metrics")

datasets = self.dataloader.datasets
for batch in self.dataloader:
batch_disc, batch_cont = self.model.split_input(batch[0])
recon = self.model.reconstruct(batch[0], as_one=True)
recon_disc, recon_cont = self.model.split_input(recon)

scores_per_dataset = {}
for i, dataset in enumerate(datasets[: len(batch_disc)]):
target = batch_disc[i].numpy()
preds = torch.argmax(
(torch.log_softmax(recon_disc[i], dim=-1)), dim=-1
).numpy()
scores = calculate_accuracy(target, preds)
scores_per_dataset[dataset.name] = scores

for i, dataset in enumerate(datasets[len(batch_disc) :]):
target = batch_cont[i].numpy()
preds = recon_cont[i].numpy()
scores = calculate_cosine_similarity(target, preds)
scores_per_dataset[dataset.name] = scores

self.write_cols(scores_per_dataset)

self.close_csv_writer()
16 changes: 14 additions & 2 deletions src/move/conf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
__all__ = ["MOVEConfig"]
__all__ = [
"AdamConfig",
"AdamWConfig",
"ProdigyConfig",
"SgdConfig",
"TrainingDataLoaderConfig",
"TrainingLoopConfig",
"VaeConfig",
"VaeNormalConfig",
"VaeTConfig",
]

from move.conf.schema import MOVEConfig
from move.conf.models import VaeConfig, VaeNormalConfig, VaeTConfig
from move.conf.optim import AdamConfig, AdamWConfig, ProdigyConfig, SgdConfig
from move.conf.training import TrainingDataLoaderConfig, TrainingLoopConfig
6 changes: 6 additions & 0 deletions src/move/conf/config_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
__all__ = ["config_store"]

from hydra.core.config_store import ConfigStore

config_store = ConfigStore.instance()
"""Hydra's config store singleton"""
Loading