Merge pull request #66 from arnor-sigurdsson/add-more-array-output-modules

Add more array output modules
arnor-sigurdsson authored Aug 27, 2023
2 parents 814e7e3 + 847f6dd commit bd23207
Showing 57 changed files with 1,797 additions and 1,067 deletions.
17 changes: 14 additions & 3 deletions docs/automake_docs.py
@@ -19,6 +19,7 @@
c_image_captioning,
)
from docs.doc_modules.d_array_outputs import a_array_mnist_generation
from docs.doc_modules.e_pretraining import a_mini_foundation
from docs.doc_modules.experiments import AutoDocExperimentInfo, make_tutorial_data


@@ -72,18 +73,28 @@ def _get_d_array_outputs_experiments() -> Iterable[AutoDocExperimentInfo]:
)


def _get_e_pretraining_outputs_experiments() -> Iterable[AutoDocExperimentInfo]:
a_experiments = a_mini_foundation.get_experiments()

return chain(
a_experiments,
)


if __name__ == "__main__":
a_using_eir_experiments = _get_a_using_eir_experiments()
c_sequence_outputs_experiments = _get_c_sequence_outputs_experiments()
b_customizing_eir_experiments = _get_b_customizing_eir_experiments()
d_array_outputs_experiments = _get_d_array_outputs_experiments()
e_pretraining_experiments = _get_e_pretraining_outputs_experiments()

experiment_iter = chain.from_iterable(
[
# a_using_eir_experiments,
c_sequence_outputs_experiments,
b_customizing_eir_experiments,
d_array_outputs_experiments,
# c_sequence_outputs_experiments,
# b_customizing_eir_experiments,
# d_array_outputs_experiments,
e_pretraining_experiments,
]
)
for experiment in experiment_iter:
30 changes: 30 additions & 0 deletions eir/data_load/data_augmentation.py
@@ -345,3 +345,33 @@ def make_random_omics_columns_missing(
omics_array[:, :, random_to_drop] = missing_arr

return omics_array


def shuffle_random_omics_columns(
omics_array: torch.Tensor, percentage: float = 0.05, probability: float = 1.0
) -> torch.Tensor:
random_draw = torch.rand(1).item()
if random_draw > probability:
return omics_array

n_snps = omics_array.shape[2]
n_to_shuffle = int(n_snps * percentage)
random_to_shuffle = torch.randperm(n_snps)[:n_to_shuffle].to(dtype=torch.long)

one_hot_random = torch.zeros(
omics_array.shape[0],
4,
n_to_shuffle,
dtype=torch.bool,
)

random_indices = torch.randint(
0,
4,
(omics_array.shape[0], n_to_shuffle),
)
one_hot_random.scatter_(1, random_indices.unsqueeze(1), 1)

omics_array[:, :, random_to_shuffle] = one_hot_random

return omics_array
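
As a quick illustration, a minimal sketch of exercising the new shuffle augmentation on a one-hot omics tensor; the batch size and SNP count below are made up for the example, and note that the function modifies the input tensor in place as well as returning it:

import torch

from eir.data_load.data_augmentation import shuffle_random_omics_columns

# Illustrative one-hot omics batch: (batch, 4 genotype classes, n_snps).
omics = torch.zeros(8, 4, 100, dtype=torch.bool)
omics[:, 0, :] = True  # start with every SNP encoded as the first class

augmented = shuffle_random_omics_columns(
    omics_array=omics,
    percentage=0.05,  # shuffle roughly 5% of the SNP columns
    probability=1.0,  # always apply the augmentation
)

# Shape is unchanged; only the chosen columns get a new random one-hot genotype.
assert augmented.shape == omics.shape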
(Additional changed file; path not captured in this view)
@@ -74,6 +74,8 @@ def prepare_inputs_disk(
genotype_array=array_raw,
na_augment_perc=input_type_info.na_augment_perc,
na_augment_prob=input_type_info.na_augment_prob,
shuffle_augment_perc=input_type_info.shuffle_augment_perc,
shuffle_augment_prob=input_type_info.shuffle_augment_prob,
test_mode=test_mode,
)

@@ -178,6 +180,8 @@ def prepare_inputs_memory(
genotype_array=data,
na_augment_perc=input_type_info.na_augment_perc,
na_augment_prob=input_type_info.na_augment_prob,
shuffle_augment_perc=input_type_info.shuffle_augment_perc,
shuffle_augment_prob=input_type_info.shuffle_augment_prob,
test_mode=test_mode,
)

14 changes: 13 additions & 1 deletion eir/data_load/data_preparation_modules/prepare_omics.py
@@ -4,7 +4,10 @@
import numpy as np
import torch

from eir.data_load.data_augmentation import make_random_omics_columns_missing
from eir.data_load.data_augmentation import (
make_random_omics_columns_missing,
shuffle_random_omics_columns,
)
from eir.data_load.data_preparation_modules.common import _load_deeplake_sample
from eir.data_load.data_source_modules import deeplake_ops

@@ -39,6 +42,8 @@ def prepare_one_hot_omics_data(
genotype_array: np.ndarray,
na_augment_perc: float,
na_augment_prob: float,
shuffle_augment_perc: float,
shuffle_augment_prob: float,
test_mode: bool,
) -> torch.Tensor:
"""
@@ -55,5 +60,12 @@
probability=na_augment_prob,
)

if not test_mode and shuffle_augment_perc > 0 and shuffle_augment_prob > 0:
tensor_bool = shuffle_random_omics_columns(
omics_array=tensor_bool,
percentage=shuffle_augment_perc,
probability=shuffle_augment_prob,
)

assert tensor_bool.dtype == torch.bool
return tensor_bool
80 changes: 79 additions & 1 deletion eir/models/input/array/array_models.py
@@ -76,11 +76,89 @@ def get_array_model_init_kwargs(
match model_type:
case "lcl":
assert isinstance(model_config, LCLModelConfig)
kwargs["flatten_fn"] = partial(torch.flatten, start_dim=1)

if model_config.patch_size is not None:
assert isinstance(model_config.patch_size, (tuple, list))
assert len(model_config.patch_size) == 3, model_config.patch_size
kwargs["flatten_fn"] = partial(
patchify_and_flatten,
size=model_config.patch_size,
)
else:
kwargs["flatten_fn"] = partial(torch.flatten, start_dim=1)

return kwargs


def check_patch_and_input_size_compatibility(
patch_size: Union[tuple[int, int, int], list[int]],
data_dimensions: "DataDimensions",
) -> None:
assert isinstance(patch_size, (tuple, list))
assert len(patch_size) == 3, patch_size

channels, height, width = patch_size

if (
data_dimensions.channels % channels != 0
or data_dimensions.height % height != 0
or data_dimensions.width % width != 0
):
mismatch_details = (
f"Data dimensions {data_dimensions.full_shape()} "
f"cannot be evenly divided into patches of size {patch_size}. "
f"Mismatch in channels: {data_dimensions.channels % channels}, "
f"height: {data_dimensions.height % height}, "
f"width: {data_dimensions.width % width}."
)
raise ValueError(mismatch_details)
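
For illustration, a hedged sketch of the divisibility rule this check enforces; the DataDimensionsStandIn below is a hypothetical substitute for EIR's own DataDimensions class, used only to keep the example self-contained:

from dataclasses import dataclass

from eir.models.input.array.array_models import (
    check_patch_and_input_size_compatibility,
)


@dataclass
class DataDimensionsStandIn:
    # Hypothetical stand-in exposing only the attributes the check uses.
    channels: int
    height: int
    width: int

    def full_shape(self) -> tuple[int, int, int]:
        return (self.channels, self.height, self.width)


dims = DataDimensionsStandIn(channels=1, height=4, width=1000)

# (1, 4, 100): 1 % 1 == 0, 4 % 4 == 0, 1000 % 100 == 0 -> passes silently.
check_patch_and_input_size_compatibility(patch_size=(1, 4, 100), data_dimensions=dims)

# A patch of (1, 3, 100) would raise ValueError, since 4 % 3 != 0.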


def patchify_and_flatten(
x: torch.Tensor,
size: tuple[int, int, int],
) -> torch.Tensor:
stride = size
patches = patchify(x=x, size=size, stride=stride)
flattened = flatten_patches(patches=patches)
return flattened


def patchify(
x: torch.Tensor, size: tuple[int, int, int], stride: tuple[int, int, int]
) -> torch.Tensor:
"""
size: (C, H, W)
Input shape: [256, 3, 64, 64]
Batch size: batch_size
Channels: C
Vertical patches: height / H
Horizontal patches: width / W
Patch channels: C
Patch height: H
Patch width: W
After unfolding: [batch_size, C, height / H, width / W, C, H, W]
After permuting: [batch_size, height / H, width / W, C, C, H, W]
"""
patches = (
x.unfold(1, size[0], stride[0])
.unfold(2, size[1], stride[1])
.unfold(3, size[2], stride[2])
)
patches = patches.permute(0, 2, 3, 4, 1, 5, 6)
return patches


def flatten_patches(patches: torch.Tensor) -> torch.Tensor:
reshaped_patches = patches.reshape(patches.size(0), -1)
return reshaped_patches
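
As a small shape check of the patchify-and-flatten path, a sketch with a made-up input shape and patch size; the flattened feature count matches plain torch.flatten on the same input, just reordered patch by patch:

import torch

from eir.models.input.array.array_models import patchify_and_flatten

# Illustrative array input: (batch, channels, height, width).
x = torch.randn(2, 4, 8, 1000)

# 1 x 1 x 10 patches of 4 x 8 x 100 elements each -> 32,000 features per sample.
flat = patchify_and_flatten(x=x, size=(4, 8, 100))
assert flat.shape == (2, 4 * 8 * 1000)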


@dataclass
class ArrayModelConfig:

5 changes: 3 additions & 2 deletions eir/models/input/array/models_cnn.py
@@ -107,15 +107,16 @@ class CNNModelConfig:
:param attention_inclusion_cutoff:
If the dimension of width * height is less than this value, attention will be
included in the model across channels and width * height after that point.
included in the model from that point onwards, with width * height acting as
the embedding dimension and the channels as the sequence length.
:param l1:
L1 regularization to apply to the first layer.
"""

layers: Union[None, List[int]] = None

num_output_features: int = 32
num_output_features: int = 256

channel_exp_base: int = 2
first_channel_expansion: int = 1
41 changes: 35 additions & 6 deletions eir/models/input/array/models_locally_connected.py
@@ -108,14 +108,21 @@ class LCLModelConfig:
to ensure that the feature representations become smaller as they are propagated
through the network.
:param patch_size:
Controls the size of the patches used in the first layer. If set to ``None``,
the input is flattened according to the torch ``flatten`` function. Note that
when using this parameter, the kernel width should generally be set to the
product of the patch size dimensions (e.g. via ``kernel_width="patch"``).
:param layers:
Controls the number of layers in the model. If set to ``None``, the model will
automatically set up the number of layers according to the ``cutoff`` parameter
value.
:param kernel_width:
Width of the locally connected kernels. Note that this refers to the flattened
input, meaning that if we have a one-hot encoding of 4 values (e.g. SNPs), 12
Width of the locally connected kernels. Note that in the context of genomic
inputs this refers to the flattened input,
meaning that if we have a one-hot encoding of 4 values (e.g. SNPs), 12
refers to 12/4 = 3 SNPs per locally connected window. Can be set to ``None`` if
the ``num_lcl_chunks`` parameter is set, which means that the kernel width
will be set automatically according to
@@ -172,9 +179,11 @@ class LCLModelConfig:
attention cutoff >= 256, the attention block will be included.
"""

patch_size: Optional[tuple[int, int, int]] = None

layers: Union[None, List[int]] = None

kernel_width: int = 16
kernel_width: int | Literal["patch"] = 16
first_kernel_expansion: int = -2

channel_exp_base: int = 2
@@ -205,8 +214,13 @@ def __init__(
self.data_dimensions = data_dimensions
self.flatten_fn = flatten_fn

kernel_width = parse_kernel_width(
kernel_width=self.model_config.kernel_width,
patch_size=self.model_config.patch_size,
)

fc_0_kernel_size = calc_value_after_expansion(
base=self.model_config.kernel_width,
base=kernel_width,
expansion=self.model_config.first_kernel_expansion,
)
fc_0_channel_exponent = calc_value_after_expansion(
@@ -225,7 +239,7 @@

lcl_parameter_spec = LCParameterSpec(
in_features=self.fc_0.out_features,
kernel_width=self.model_config.kernel_width,
kernel_width=kernel_width,
channel_exp_base=self.model_config.channel_exp_base,
dropout_p=self.model_config.rb_do,
cutoff=cutoff,
@@ -264,6 +278,19 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return out


def parse_kernel_width(
kernel_width: int | Literal["patch"],
patch_size: Optional[tuple[int, int, int]],
) -> int:
if kernel_width == "patch":
if patch_size is None:
raise ValueError(
"kernel_width set to 'patch', but no patch_size was specified."
)
kernel_width = patch_size[0] * patch_size[1] * patch_size[2]
return kernel_width
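
A quick illustration of the new "patch" option for the kernel width (the patch size is made up):

from eir.models.input.array.models_locally_connected import parse_kernel_width

# "patch" resolves to the number of flattened elements in a single patch:
assert parse_kernel_width(kernel_width="patch", patch_size=(1, 4, 100)) == 400

# A plain integer passes through unchanged:
assert parse_kernel_width(kernel_width=16, patch_size=None) == 16

# "patch" without a patch_size raises a ValueError.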


def flatten_h_w_fortran(x: torch.Tensor) -> torch.Tensor:
"""
This is needed when e.g. flattening one-hot inputs that are ordered in a columns
Expand Down Expand Up @@ -485,10 +512,12 @@ def __init__(
num_heads: Union[int, Literal["auto"]] = "auto",
dropout_p: float = 0.0,
num_layers: int = 2,
dim_feedforward_factor: int = 4,
):
super().__init__()

self.embedding_dim = embedding_dim
self.dim_feedforward_factor = dim_feedforward_factor
self.in_features = in_features
self.dropout_p = dropout_p
self.num_layers = num_layers
@@ -500,7 +529,7 @@
encoder_layer = nn.TransformerEncoderLayer(
d_model=self.embedding_dim,
nhead=self.num_heads,
dim_feedforward=self.embedding_dim * 4,
dim_feedforward=self.embedding_dim * self.dim_feedforward_factor,
activation="gelu",
norm_first=True,
batch_first=True,
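
For reference, a standalone sketch of what the new dim_feedforward_factor controls in the encoder layer above; the embedding size and head count here are illustrative, not values from EIR:

import torch.nn as nn

embedding_dim = 64            # illustrative embedding size
dim_feedforward_factor = 4    # default added in this change

# The feed-forward width now scales with a configurable factor instead of a
# hard-coded 4x multiplier.
encoder_layer = nn.TransformerEncoderLayer(
    d_model=embedding_dim,
    nhead=8,
    dim_feedforward=embedding_dim * dim_feedforward_factor,  # 256 here
    activation="gelu",
    norm_first=True,
    batch_first=True,
)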
(Additional changed file; path not captured in this view)
@@ -8,6 +8,8 @@
from eir.models.model_setup_modules.output_model_setup_modules import al_output_modules
from eir.models.output.array.array_output_modules import (
ArrayOutputModuleConfig,
CNNUpscaleModel,
CNNUpscaleModelConfig,
LCLOutputModelConfig,
al_array_model_types,
al_output_array_model_classes,
@@ -86,6 +88,7 @@ def get_array_output_feature_extractor(
def get_array_output_model_mapping() -> Dict[str, al_output_array_model_classes]:
mapping = {
"lcl": LCLModel,
"cnn": CNNUpscaleModel,
}

return mapping
@@ -103,6 +106,7 @@ def get_array_output_config_dataclass_mapping() -> (
):
mapping = {
"lcl": LCLOutputModelConfig,
"cnn": CNNUpscaleModelConfig,
}

return mapping
@@ -120,9 +124,20 @@ def get_array_output_model_init_kwargs(
model_config: al_output_array_model_configs,
input_data_dimensions: DataDimensions,
output_data_dimensions: Optional[DataDimensions],
) -> dict[str, Union[DataDimensions, LCLOutputModelConfig, FlattenFunc, int]]:
) -> dict[
str,
Union[
DataDimensions, LCLOutputModelConfig | CNNUpscaleModelConfig, FlattenFunc, int
],
]:
kwargs: dict[
str, Union[DataDimensions, LCLOutputModelConfig, FlattenFunc, int]
str,
Union[
DataDimensions,
LCLOutputModelConfig | CNNUpscaleModelConfig,
FlattenFunc,
int,
],
] = {}

model_config_dataclass = get_array_output_model_config_dataclass(
@@ -147,4 +162,9 @@
)
kwargs["dynamic_cutoff"] = num_elements

case "cnn":
assert isinstance(model_config, CNNUpscaleModelConfig)
assert output_data_dimensions is not None
kwargs["target_dimensions"] = output_data_dimensions

return kwargs
(Diffs for the remaining changed files are not shown here.)
