Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Developer continuous #84

Merged
merged 73 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
50bfa61
:loud_sound: Capture warnings
ri-heme Jan 4, 2023
63f7c2f
:bug: VAE and data objects handle missing data type
ri-heme Jan 4, 2023
5c0c90b
:construction: Update tasks for missing data type
ri-heme Jan 4, 2023
d831a49
:construction: Committing provisional changes (not finished)
mpielies Nov 10, 2022
b5e1044
:construction: min max feature pertubation
mpielies Nov 10, 2022
ff8f8aa
Last batch of changes (unfinished)
mpielies Nov 10, 2022
26bb304
preprocessing changes
mpielies Nov 10, 2022
aa28897
Version that reaches training
mpielies Nov 11, 2022
0826212
last version
mpielies Nov 11, 2022
98a5dfa
Changes on writing the final files
mpielies Nov 14, 2022
c838976
Last commit:
mpielies Nov 14, 2022
9d560bc
# Added plus_std and minus_std as target values
mpielies Nov 14, 2022
081e737
New changes for plotting new features
mpielies Nov 15, 2022
85ce21f
Note in funtion description updated
mpielies Nov 16, 2022
337ddc5
Perturbation visualization added:
mpielies Nov 16, 2022
c7b4f1f
:build: Perturbation visualization
mpielies Nov 16, 2022
f02c63e
:art: :sparkles:
mpielies Nov 22, 2022
5248274
:pencil: Make output_subpath optional
mpielies Nov 22, 2022
79b2849
:rewind: Revert styling in conflicting files
ri-heme Nov 22, 2022
1036656
:art: Rename results directory
ri-heme Nov 22, 2022
71dd894
:art:
mpielies Nov 23, 2022
c9cf123
:sparkles: Random continuous dataset added.
mpielies Dec 8, 2022
a3ddf48
:art: Type hinting
ri-heme Jan 5, 2023
1c64573
:sparkles: Update ID associations module
ri-heme Jan 5, 2023
59a6d7a
:art: :recycle: Refactoring/styling
ri-heme Jan 5, 2023
0530139
:wrench: Update default config
ri-heme Jan 6, 2023
6fd7589
:sparkles: :construction: :monocle_face:
mpielies Jan 24, 2023
c83fb42
Added random_basic config files
mpielies Feb 6, 2023
3300a86
:construction:
mpielies Feb 6, 2023
2cac930
:sparkles: :construction:
mpielies Feb 22, 2023
00dd85d
(previous commit's description)
mpielies Feb 22, 2023
968ff14
:sparkles: :construction:
mpielies Mar 29, 2023
2813ebd
:art: :fire: :wrench: :bulb:
mpielies Mar 29, 2023
831d8ce
:fire: :see_no_evil: Ignore VS Code settings
ri-heme Apr 25, 2023
0b23637
:bug: Fix plot legend (feature importance plot)
ri-heme Apr 3, 2023
048117e
:art: :fire: :sparkles:
mpielies Apr 28, 2023
8cd44d8
:twisted_rightwards_arrows: Merge branch 'developer'
ri-heme Jun 28, 2023
1e524eb
:fire: :see_no_evil: Remove/ignore DS Store files
ri-heme Jun 28, 2023
37a7590
:fire: :see_no_evil: Remove/ignore supplementary outputs
ri-heme Jun 28, 2023
9912e11
:art: Styling/sorting imports
ri-heme Jun 28, 2023
f8bd98d
:art: Styling
ri-heme Jul 3, 2023
8c49d23
:bug: Properly re-shape categorical recon
ri-heme Jul 3, 2023
e03eec2
:art: :zap: :bug: Editing pull request.
mpielies Jul 12, 2023
3ab3a6d
:bug: :wrench:
mpielies Aug 15, 2023
030e019
:bug: Keep dimensions if NaN
ri-heme Feb 1, 2024
a54cf35
:fire: Remove
ri-heme Feb 2, 2024
fb0dd98
Merge branch 'developer' into developer-continuous-v3
May 16, 2024
4164f41
:sparkles: add workflow for formatting checks
May 16, 2024
439c8cf
:art: format files
May 16, 2024
3ff9842
:sparkles: lint src files with flake8 - might introduce regression er…
May 16, 2024
d722c0b
:bug: configure flake8 for black formatting
May 16, 2024
a7e6275
:bug: :rewind: HYDRA_VERSION_BASE needed for imports
May 16, 2024
6184235
:bug: regression: merging branches lead to name mismatch
May 16, 2024
daa74b7
:bug: correct value in yaml file
May 16, 2024
e6368a9
:white_check_mark: Add integration test based on tutorial
May 16, 2024
0542ff2
:art: format yaml files (some had no last blank line)
May 31, 2024
087dcb7
:art: add typehints and always return figure for plotting fcts
May 31, 2024
2b5b098
:art: split bayes and t-test run in CI
May 31, 2024
2cb9aa5
:art: format file
Jun 1, 2024
0dc2a0f
:rewind: random small defaults restored
Jun 4, 2024
25fc96b
:wrench: Add default KS config
ri-heme Jun 4, 2024
82c6624
:wrench: Use default KS config
ri-heme Jun 4, 2024
09ca32f
:zap: reduce action runtime by configuring tasks
Jun 4, 2024
6a02d0e
:art: Make sure correct colormap is used
ri-heme Jun 4, 2024
2a12ed1
Merge branch 'developer-continuous-v3' of https://github.com/Rasmusse…
ri-heme Jun 4, 2024
2e1c6ba
:wrench: udpate flake8 configuration according to defaults
Jun 4, 2024
6207350
:sparkles: dry-run continuous configuration sample data
Jun 4, 2024
e4e2283
:art: make line 88 characters long, remove some expections from flake8
Jun 4, 2024
5e23447
:bug: error due to black formatting?
Jun 4, 2024
7796f5e
:sparkles: run continous example
Jun 4, 2024
48a76b7
:bug: four latent dimensions are required for t-test
Jun 4, 2024
5533eca
:zap: run latent job, try to increase speed for t-test
Jun 7, 2024
d0a4409
:zap: speed up and balance both jobs' runtime
Jun 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,48 @@
name: release on pypi
on:
push:
branches:
- main
# branches:
# - main

jobs:
format:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: psf/black@stable
lint:
name: Lint with flake8
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install flake8
run: pip install flake8 flake8-bugbear
- name: Lint with flake8
run: flake8 src

publish:
name: Publish package
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags')
needs:
- format
- lint
steps:
- uses: actions/checkout@v3
- name: Publish package
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install twine and build
run: python -m pip install --upgrade twine build
- name: Build
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
Expand Down
12 changes: 12 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ tutorial/*
!tutorial/notebooks/*.ipynb
!tutorial/README.md

# Supplementary files
supplementary_files/*.png
supplementary_files/*.tsv
supplementary_files/*.txt

# Virtual environment
venv/
virtualvenv/
Expand All @@ -48,6 +53,12 @@ virtualvenv/
docs/build/
docs/source/_templates/

# VS Code settings
.vscode

# macOS
.DS_Store

# Root folder
/*.*
!/.gitignore
Expand All @@ -58,3 +69,4 @@ docs/source/_templates/
!/pyproject.toml
!/requirements.txt
!/setup.cfg
!/.github
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

html_theme = "sphinx_rtd_theme"
html_theme_options = {
"collapse_navigation" : False,
"collapse_navigation": False,
}
html_static_path = []

Expand Down
8 changes: 7 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ install_requires =
matplotlib
seaborn
scikit-learn
scipy
scipy>=1.10.0

package_dir =
= src
Expand All @@ -34,3 +34,9 @@ where = src
[options.entry_points]
console_scripts =
move-dl=move.__main__:main

[flake8]
max-line-length = 120
ri-heme marked this conversation as resolved.
Show resolved Hide resolved
aggressive = 2
extend-select = B950
extend-ignore = E203,E501,E701
7 changes: 3 additions & 4 deletions src/move/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations
from move.training.training_loop import training_loop
from move.models.vae import VAE
from move import conf, data, models

__license__ = "MIT"
__version__ = (1, 4, 10)
__all__ = ["conf", "data", "models", "training_loop", "VAE"]

HYDRA_VERSION_BASE = "1.2"

from move import conf, data, models
from move.models.vae import VAE
from move.training.training_loop import training_loop
16 changes: 16 additions & 0 deletions src/move/analysis/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,19 @@ def norm(x: np.ma.MaskedArray, axis: int = 1) -> np.ma.MaskedArray:
1D array with the specified axis removed.
"""
return np.sqrt(np.sum(x**2, axis=axis))


def get_2nd_order_polynomial(x_array, y_array, n_points=100):
"""
Given a set of x an y values, find the 2nd oder polynomial fitting best the data.

Returns:
x_pol: x coordinates for the polynomial function evaluation.
y_pol: y coordinates for the polynomial function evaluation.
"""
a2, a1, a = np.polyfit(x_array, y_array, deg=2)

x_pol = np.linspace(np.min(x_array), np.max(x_array), n_points)
y_pol = np.array([a2 * x * x + a1 * x + a for x in x_pol])

return x_pol, y_pol, (a2, a1, a)
1 change: 1 addition & 0 deletions src/move/conf/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ hydra:
job:
config:
override_dirname:
item_sep: ";"
exclude_keys:
- experiment

Expand Down
27 changes: 27 additions & 0 deletions src/move/conf/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class InputConfig:
name: str
weight: int = 1


@dataclass
class ContinuousInputConfig(InputConfig):
scale: bool = True
Expand Down Expand Up @@ -185,6 +186,27 @@ class IdentifyAssociationsTTestConfig(IdentifyAssociationsConfig):
num_latent: list[int] = MISSING


@dataclass
class IdentifyAssociationsKSConfig(IdentifyAssociationsConfig):
"""Configure the Kolmogorov-Smirnov approach to identify associations.

Args:
perturbed_feature_names: names of the perturbed features of interest.
target_feature_names: names of the target features of interest.

Description:
For each perturbed feature - target feature pair, we will plot:
- Input vs. reconstruction correlation plot: to assess reconstruction
quality of both target and perturbed features.
- Distribution of reconstruction values for the target feature before
and after the perturbation of the perturbed feature.

"""

perturbed_feature_names: list[str] = field(default_factory=list)
target_feature_names: list[str] = field(default_factory=list)


@dataclass
class MOVEConfig:
defaults: list[Any] = field(default_factory=lambda: [dict(data="base_data")])
Expand Down Expand Up @@ -237,6 +259,11 @@ def extract_names(configs: list[InputConfig]) -> list[str]:
name="identify_associations_ttest_schema",
node=IdentifyAssociationsTTestConfig,
)
cs.store(
group="task",
name="identify_associations_ks_schema",
node=IdentifyAssociationsKSConfig,
)

# Register custom resolvers
OmegaConf.register_new_resolver("weights", extract_weights)
Expand Down
2 changes: 1 addition & 1 deletion src/move/conf/task/analyze_latent.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ model:
num_latent: 150
beta: 0.0001
dropout: 0.1
cuda: false
cuda: False
ri-heme marked this conversation as resolved.
Show resolved Hide resolved

training_loop:
lr: 1e-4
Expand Down
2 changes: 2 additions & 0 deletions src/move/conf/task/identify_associations_bayes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,5 @@ training_loop:
- 25
early_stopping: false
patience: 0


2 changes: 2 additions & 0 deletions src/move/conf/task/identify_associations_ttest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,5 @@ training_loop:
- 25
early_stopping: false
patience: 0


92 changes: 91 additions & 1 deletion src/move/data/perturbations.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
__all__ = ["perturb_categorical_data", "perturb_continuous_data"]

from typing import cast
from pathlib import Path
from typing import Literal, Optional, cast

import numpy as np
import torch
from torch.utils.data import DataLoader

from move.data.dataloaders import MOVEDataset
from move.data.preprocessing import feature_stats
from move.visualization.dataset_distributions import plot_value_distributions

ContinuousPerturbationType = Literal["minimum", "maximum", "plus_std", "minus_std"]


def perturb_categorical_data(
Expand Down Expand Up @@ -111,3 +116,88 @@ def perturb_continuous_data(
dataloaders.append(perturbed_dataloader)

return dataloaders


def perturb_continuous_data_extended(
baseline_dataloader: DataLoader,
con_dataset_names: list[str],
target_dataset_name: str,
perturbation_type: ContinuousPerturbationType,
output_subpath: Optional[Path] = None,
) -> list[DataLoader]:
"""Add perturbations to continuous data. For each feature in the target
dataset, change the feature's value in all samples (in rows):
1,2) substituting this feature in all samples by the feature's minimum/maximum value.
3,4) Adding/Substracting one standard deviation to the sample's feature value.

Args:
baseline_dataloader: Baseline dataloader
con_dataset_names: List of continuous dataset names
target_dataset_name: Target continuous dataset to perturb
perturbation_type: 'minimum', 'maximum', 'plus_std' or 'minus_std'.
output_subpath: path where the figure showing the perturbation will be saved

Returns:
- List of dataloaders containing all perturbed datasets
- Plot of the feature value distribution after the perturbation. Note that
all perturbations are collapsed into one single plot.

Note:
This function was created so that it could generalize to non-normalized
datasets. Scaling is done per dataset, not per feature -> slightly different stds
feature to feature.
"""

baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
assert baseline_dataset.con_shapes is not None
assert baseline_dataset.con_all is not None

target_idx = con_dataset_names.index(target_dataset_name) # dataset index
splits = np.cumsum([0] + baseline_dataset.con_shapes)
slice_ = slice(*splits[target_idx : target_idx + 2])

num_features = baseline_dataset.con_shapes[target_idx]
dataloaders = []
perturbations_list = []

for i in range(num_features):
perturbed_con = baseline_dataset.con_all.clone()
target_dataset = perturbed_con[:, slice_]
# Change the desired feature value by:
min_feat_val_list, max_feat_val_list, std_feat_val_list = feature_stats(
target_dataset
)
if perturbation_type == "minimum":
target_dataset[:, i] = torch.FloatTensor([min_feat_val_list[i]])
elif perturbation_type == "maximum":
target_dataset[:, i] = torch.FloatTensor([max_feat_val_list[i]])
elif perturbation_type == "plus_std":
target_dataset[:, i] += torch.FloatTensor([std_feat_val_list[i]])
elif perturbation_type == "minus_std":
target_dataset[:, i] -= torch.FloatTensor([std_feat_val_list[i]])

perturbations_list.append(target_dataset[:, i].numpy())

perturbed_dataset = MOVEDataset(
baseline_dataset.cat_all,
perturbed_con,
baseline_dataset.cat_shapes,
baseline_dataset.con_shapes,
)

perturbed_dataloader = DataLoader(
perturbed_dataset,
shuffle=False,
batch_size=baseline_dataloader.batch_size,
)
dataloaders.append(perturbed_dataloader)

# Plot the perturbations for all features, collapsed in one plot:
if output_subpath is not None:
fig = plot_value_distributions(np.array(perturbations_list).transpose())
fig_path = str(
output_subpath / f"perturbation_distribution_{target_dataset_name}.png"
)
fig.savefig(fig_path)

return dataloaders
22 changes: 21 additions & 1 deletion src/move/data/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def one_hot_encode_single(mapping: dict[str, int], value: Optional[str]) -> IntA
Returns:
2D array
"""
encoded_value = np.zeros((1, len(mapping)))
encoded_value = np.zeros((1, len(mapping)), dtype=int)
if not pd.isna(value):
code = mapping[str(value)]
encoded_value[0, code] = 1
Expand All @@ -79,3 +79,23 @@ def scale(x: np.ndarray) -> tuple[FloatArray, BoolArray]:
scaled_x = standardize(logx[:, mask_1d], axis=0)
scaled_x[np.isnan(scaled_x)] = 0
return scaled_x, mask_1d


def feature_stats(x: ArrayLike) -> tuple[FloatArray, FloatArray, FloatArray]:
"""
Read an array of continuous values and extract the
minimum, maximum and standard deviation per column (feature).

Args:
x: 2D array with samples in its rows and features in its columns

Returns:
minimum: list with minimum value per feature (column)
ri-heme marked this conversation as resolved.
Show resolved Hide resolved
maximum: list with maximum " "
std: list with std " "
"""

minimum = np.nanmin(x, axis=0)
maximum = np.nanmax(x, axis=0)
std = np.nanstd(x, axis=0)
return minimum, maximum, std
Loading