Developer continuous #84

Merged: 73 commits merged on Jun 7, 2024
Commits (73)
50bfa61
:loud_sound: Capture warnings
ri-heme Jan 4, 2023
63f7c2f
:bug: VAE and data objects handle missing data type
ri-heme Jan 4, 2023
5c0c90b
:construction: Update tasks for missing data type
ri-heme Jan 4, 2023
d831a49
:construction: Committing provisional changes (not finished)
mpielies Nov 10, 2022
b5e1044
:construction: min max feature perturbation
mpielies Nov 10, 2022
ff8f8aa
Last batch of changes (unfinished)
mpielies Nov 10, 2022
26bb304
preprocessing changes
mpielies Nov 10, 2022
aa28897
Version that reaches training
mpielies Nov 11, 2022
0826212
last version
mpielies Nov 11, 2022
98a5dfa
Changes on writing the final files
mpielies Nov 14, 2022
c838976
Last commit:
mpielies Nov 14, 2022
9d560bc
# Added plus_std and minus_std as target values
mpielies Nov 14, 2022
081e737
New changes for plotting new features
mpielies Nov 15, 2022
85ce21f
Note in function description updated
mpielies Nov 16, 2022
337ddc5
Perturbation visualization added:
mpielies Nov 16, 2022
c7b4f1f
:build: Perturbation visualization
mpielies Nov 16, 2022
f02c63e
:art: :sparkles:
mpielies Nov 22, 2022
5248274
:pencil: Make output_subpath optional
mpielies Nov 22, 2022
79b2849
:rewind: Revert styling in conflicting files
ri-heme Nov 22, 2022
1036656
:art: Rename results directory
ri-heme Nov 22, 2022
71dd894
:art:
mpielies Nov 23, 2022
c9cf123
:sparkles: Random continuous dataset added.
mpielies Dec 8, 2022
a3ddf48
:art: Type hinting
ri-heme Jan 5, 2023
1c64573
:sparkles: Update ID associations module
ri-heme Jan 5, 2023
59a6d7a
:art: :recycle: Refactoring/styling
ri-heme Jan 5, 2023
0530139
:wrench: Update default config
ri-heme Jan 6, 2023
6fd7589
:sparkles: :construction: :monocle_face:
mpielies Jan 24, 2023
c83fb42
Added random_basic config files
mpielies Feb 6, 2023
3300a86
:construction:
mpielies Feb 6, 2023
2cac930
:sparkles: :construction:
mpielies Feb 22, 2023
00dd85d
(previous commit's description)
mpielies Feb 22, 2023
968ff14
:sparkles: :construction:
mpielies Mar 29, 2023
2813ebd
:art: :fire: :wrench: :bulb:
mpielies Mar 29, 2023
831d8ce
:fire: :see_no_evil: Ignore VS Code settings
ri-heme Apr 25, 2023
0b23637
:bug: Fix plot legend (feature importance plot)
ri-heme Apr 3, 2023
048117e
:art: :fire: :sparkles:
mpielies Apr 28, 2023
8cd44d8
:twisted_rightwards_arrows: Merge branch 'developer'
ri-heme Jun 28, 2023
1e524eb
:fire: :see_no_evil: Remove/ignore DS Store files
ri-heme Jun 28, 2023
37a7590
:fire: :see_no_evil: Remove/ignore supplementary outputs
ri-heme Jun 28, 2023
9912e11
:art: Styling/sorting imports
ri-heme Jun 28, 2023
f8bd98d
:art: Styling
ri-heme Jul 3, 2023
8c49d23
:bug: Properly re-shape categorical recon
ri-heme Jul 3, 2023
e03eec2
:art: :zap: :bug: Editing pull request.
mpielies Jul 12, 2023
3ab3a6d
:bug: :wrench:
mpielies Aug 15, 2023
030e019
:bug: Keep dimensions if NaN
ri-heme Feb 1, 2024
a54cf35
:fire: Remove
ri-heme Feb 2, 2024
fb0dd98
Merge branch 'developer' into developer-continuous-v3
May 16, 2024
4164f41
:sparkles: add workflow for formatting checks
May 16, 2024
439c8cf
:art: format files
May 16, 2024
3ff9842
:sparkles: lint src files with flake8 - might introduce regression er…
May 16, 2024
d722c0b
:bug: configure flake8 for black formatting
May 16, 2024
a7e6275
:bug: :rewind: HYDRA_VERSION_BASE needed for imports
May 16, 2024
6184235
:bug: regression: merging branches led to name mismatch
May 16, 2024
daa74b7
:bug: correct value in yaml file
May 16, 2024
e6368a9
:white_check_mark: Add integration test based on tutorial
May 16, 2024
0542ff2
:art: format yaml files (some had no last blank line)
May 31, 2024
087dcb7
:art: add typehints and always return figure for plotting fcts
May 31, 2024
2b5b098
:art: split bayes and t-test run in CI
May 31, 2024
2cb9aa5
:art: format file
Jun 1, 2024
0dc2a0f
:rewind: random small defaults restored
Jun 4, 2024
25fc96b
:wrench: Add default KS config
ri-heme Jun 4, 2024
82c6624
:wrench: Use default KS config
ri-heme Jun 4, 2024
09ca32f
:zap: reduce action runtime by configuring tasks
Jun 4, 2024
6a02d0e
:art: Make sure correct colormap is used
ri-heme Jun 4, 2024
2a12ed1
Merge branch 'developer-continuous-v3' of https://github.com/Rasmusse…
ri-heme Jun 4, 2024
2e1c6ba
:wrench: update flake8 configuration according to defaults
Jun 4, 2024
6207350
:sparkles: dry-run continuous configuration sample data
Jun 4, 2024
e4e2283
:art: make line 88 characters long, remove some exceptions from flake8
Jun 4, 2024
5e23447
:bug: error due to black formatting?
Jun 4, 2024
7796f5e
:sparkles: run continuous example
Jun 4, 2024
48a76b7
:bug: four latent dimensions are required for t-test
Jun 4, 2024
5533eca
:zap: run latent job, try to increase speed for t-test
Jun 7, 2024
d0a4409
:zap: speed up and balance both jobs' runtime
Jun 7, 2024
Changes from 1 commit
setup.cfg (4 changes: 3 additions & 1 deletion)

@@ -37,4 +37,6 @@ console_scripts =

[flake8]
max-line-length = 120
aggressive = 2
aggressive = 2
extend-select = B950
extend-ignore = E203,E501,E701
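
This flake8 block appears to follow the Bugbear-based setup suggested in black's documentation. As a rough illustration only (the `head` function below is hypothetical, not repository code), the ignored codes cover the patterns black produces:

```python
def head(values: list[int], start: int) -> list[int]:
    # black writes compound slice bounds as `start : start + 2`, which default
    # flake8 flags as E203 ("whitespace before ':'"); with B950 selected, the
    # hard E501 line-length check is relaxed to the limit plus roughly 10%.
    return values[start : start + 2]
```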
src/move/data/perturbations.py (6 changes: 3 additions & 3 deletions)

@@ -41,7 +41,7 @@ def perturb_categorical_data(
splits = np.cumsum(
[0] + [int.__mul__(*shape) for shape in baseline_dataset.cat_shapes]
)
slice_ = slice(*splits[target_idx: target_idx + 2])
slice_ = slice(*splits[target_idx : target_idx + 2])

target_shape = baseline_dataset.cat_shapes[target_idx]
num_features = target_shape[0] # CHANGE
@@ -93,7 +93,7 @@ def perturb_continuous_data(

target_idx = con_dataset_names.index(target_dataset_name)
splits = np.cumsum([0] + baseline_dataset.con_shapes)
slice_ = slice(*splits[target_idx: target_idx + 2])
slice_ = slice(*splits[target_idx : target_idx + 2])

num_features = baseline_dataset.con_shapes[target_idx]

@@ -154,7 +154,7 @@ def perturb_continuous_data_extended(

target_idx = con_dataset_names.index(target_dataset_name) # dataset index
splits = np.cumsum([0] + baseline_dataset.con_shapes)
slice_ = slice(*splits[target_idx: target_idx + 2])
slice_ = slice(*splits[target_idx : target_idx + 2])

num_features = baseline_dataset.con_shapes[target_idx]
dataloaders = []
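The hunks above only change slice spacing, but they all revolve around the same pattern: a cumulative sum of dataset widths turns a dataset index into a column range of the concatenated matrix. A standalone sketch with made-up shapes (`con_shapes` and `data` are assumptions, not values from MOVE):

```python
import numpy as np

con_shapes = [5, 3, 7]                 # assumed widths of three continuous datasets
splits = np.cumsum([0] + con_shapes)   # array([0, 5, 8, 15]): dataset boundaries
target_idx = 1                         # index of the dataset to perturb
slice_ = slice(*splits[target_idx : target_idx + 2])  # slice(5, 8)

data = np.random.rand(4, sum(con_shapes))  # 4 samples x 15 concatenated features
perturbed = data.copy()
perturbed[:, slice_] = 0.0             # e.g. zero out the targeted dataset's columns
```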
src/move/models/vae.py (32 changes: 18 additions & 14 deletions)

@@ -190,7 +190,7 @@ def decompose_categorical(self, reconstruction: torch.Tensor) -> list[torch.Tens
cat_out = []
pos = 0
for cat_shape in self.categorical_shapes:
cat_dataset = cat_tmp[:, pos: (cat_shape[0] * cat_shape[1] + pos)]
cat_dataset = cat_tmp[:, pos : (cat_shape[0] * cat_shape[1] + pos)]

cat_out_tmp = cat_dataset.view(
cat_dataset.shape[0], cat_shape[0], cat_shape[1]
@@ -287,7 +287,7 @@ def calculate_cat_error(
cat_errors = []
pos = 0
for cat_shape in self.categorical_shapes:
cat_dataset = cat_in[:, pos: (cat_shape[0] * cat_shape[1] + pos)]
cat_dataset = cat_in[:, pos : (cat_shape[0] * cat_shape[1] + pos)]

cat_dataset = cat_dataset.view(cat_in.shape[0], cat_shape[0], cat_shape[1])
cat_target = cat_dataset
@@ -327,8 +327,8 @@ def calculate_con_error(
total_shape = 0
con_errors_list: list[torch.Tensor] = []
for s in self.continuous_shapes:
c_in = con_in[:, total_shape: (s + total_shape - 1)]
c_re = con_out[:, total_shape: (s + total_shape - 1)]
c_in = con_in[:, total_shape : (s + total_shape - 1)]
c_re = con_out[:, total_shape : (s + total_shape - 1)]
error = loss(c_re, c_in) / batch_size
con_errors_list.append(error)
total_shape += s
@@ -451,7 +451,9 @@ def encoding(
elif self.num_continuous > 0:
tensor = con
else:
raise ValueError("Must have at least 1 categorial or 1 continuous feature")
raise ValueError(
"Must have at least 1 categorial or 1 continuous feature"
)

optimizer.zero_grad()

@@ -538,21 +540,21 @@ def get_cat_recon(
shape_1 = 0
for cat_shape in self.categorical_shapes:
# Get input categorical data
cat_in_tmp = cat[:, pos: (cat_shape[0] * cat_shape[1] + pos)]
cat_in_tmp = cat[:, pos : (cat_shape[0] * cat_shape[1] + pos)]
cat_in_tmp = cat_in_tmp.view(cat.shape[0], cat_shape[0], cat_shape[1])

# Calculate target values for input
cat_target_tmp = cat_in_tmp
cat_target_tmp = torch.argmax(cat_target_tmp.detach(), dim=2)
cat_target_tmp[cat_in_tmp.sum(dim=2) == 0] = -1
cat_target[:, shape_1: (cat_shape[0] + shape_1)] = (
cat_target[:, shape_1 : (cat_shape[0] + shape_1)] = (
cat_target_tmp # .numpy()
)

# Get reconstructed categorical data
cat_out_tmp = cat_out[count]
cat_out_tmp = cat_out_tmp.transpose(1, 2)
cat_out_class[:, shape_1: (cat_shape[0] + shape_1)] = torch.argmax(
cat_out_class[:, shape_1 : (cat_shape[0] + shape_1)] = torch.argmax(
cat_out_tmp, dim=2
) # .numpy()

@@ -694,7 +696,9 @@ def latent(
elif self.num_continuous > 0:
tensor = con
else:
raise ValueError("Must have at least 1 categorial or 1 continuous feature")
raise ValueError(
"Must have at least 1 categorial or 1 continuous feature"
)

# Evaluate
cat_out, con_out, mu, logvar = self(tensor)
@@ -713,14 +717,14 @@
cat_out_class, cat_target = self.get_cat_recon(
batch, cat_total_shape, cat, cat_out
)
cat_recon[row: row + len(cat_out_class)] = torch.Tensor(cat_out_class)
cat_class[row: row + len(cat_target)] = torch.Tensor(cat_target)
cat_recon[row : row + len(cat_out_class)] = torch.Tensor(cat_out_class)
cat_class[row : row + len(cat_target)] = torch.Tensor(cat_target)

if self.num_continuous > 0:
con_recon[row: row + len(con_out)] = con_out
con_recon[row : row + len(con_out)] = con_out

latent_var[row: row + len(logvar)] = logvar
latent[row: row + len(mu)] = mu
latent_var[row : row + len(logvar)] = logvar
latent[row : row + len(mu)] = mu
row += len(mu)

test_loss /= len(dataloader)
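Apart from the re-wrapped ValueError, these hunks are again black's slice spacing. The recurring idea is that each categorical dataset occupies num_features * num_classes consecutive columns of the flattened input, which are sliced out and viewed back to three dimensions before an argmax. A minimal sketch with assumed shapes (`categorical_shapes` and `batch` are illustrative, not the model's real inputs):

```python
import torch

categorical_shapes = [(3, 4), (2, 5)]  # (num_features, num_classes) per dataset
batch = torch.zeros(8, sum(f * c for f, c in categorical_shapes))

pos = 0
for num_features, num_classes in categorical_shapes:
    # slice this dataset's flattened one-hot block and restore its 3-D shape
    block = batch[:, pos : (num_features * num_classes + pos)]
    block = block.view(batch.shape[0], num_features, num_classes)
    classes = torch.argmax(block, dim=2)  # class index per sample and feature
    pos += num_features * num_classes
```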
src/move/tasks/identify_associations.py (4 changes: 2 additions & 2 deletions)

@@ -598,11 +598,11 @@ def _ks_approach(
edges,
hist_base,
hist_pert,
f"Cumulative_perturbed_{i}_measuring_{k}_stats_{stats[j,i,k]}",
f"Cumulative_perturbed_{i}_measuring_{k}_stats_{stats[j, i, k]}",
)
fig.savefig(
figure_path
/ f"Cumulative_refit_{j}_perturbed_{i}_measuring_{k}_stats_{stats[j,i,k]}.png"
/ f"Cumulative_refit_{j}_perturbed_{i}_measuring_{k}_stats_{stats[j, i, k]}.png"
)

# Feature changes:
src/move/tasks/tune_model.py (2 changes: 1 addition & 1 deletion)

@@ -116,7 +116,7 @@ def _tune_stability(
cosine_sim0 = None
cosine_sim_diffs = []
for j in range(task_config.num_refits):
logger.debug(f"Refit: {j+1}/{task_config.num_refits}")
logger.debug(f"Refit: {j + 1}/{task_config.num_refits}")
model: VAE = hydra.utils.instantiate(
task_config.model,
continuous_shapes=train_dataset.con_shapes,
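The last two files only gain whitespace inside the f-string expressions (`stats[j, i, k]`, `{j + 1}`), presumably for consistency with the rest of the formatted code. A toy example of the resulting style (the array and names below are stand-ins, not MOVE's actual data):

```python
import numpy as np

num_refits, num_perturbed, num_features = 2, 3, 4
stats = np.zeros((num_refits, num_perturbed, num_features))  # stand-in array

for j in range(num_refits):
    print(f"Refit: {j + 1}/{num_refits}")  # spaces around '+' inside the f-string

j, i, k = 0, 1, 2
name = f"Cumulative_perturbed_{i}_measuring_{k}_stats_{stats[j, i, k]}"
```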