Skip to content

Commit

Permalink
Updated documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinmicha committed Dec 12, 2023
1 parent d8592e2 commit 0f188f7
Show file tree
Hide file tree
Showing 6 changed files with 3,739 additions and 5,558 deletions.
9 changes: 2 additions & 7 deletions antipasti/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
import numpy as np
import torch
from torch.nn import Linear, ReLU, Conv2d, MaxPool2d, Module, Dropout, Parameter
from torch.nn import Linear, ReLU, Conv2d, MaxPool2d, Module

class ANTIPASTI(Module):
r"""Predicting the binding affinity of an antibody from its normal mode correlation map.
Expand All @@ -25,7 +25,7 @@ class ANTIPASTI(Module):
l1_lambda: float
Weight of L1 regularisation.
mode: str
To use the full model, provide ``full``. Otherwise, ANTIPASTI corresponds to a linear classifier.
To use the full model, provide ``full``. Otherwise, ANTIPASTI corresponds to a linear map.
"""
def __init__(
Expand All @@ -47,12 +47,10 @@ def __init__(
self.fully_connected_input = n_filters * ((input_shape-filter_size+1)//pooling_size) ** 2
self.conv1 = Conv2d(1, n_filters, filter_size)
self.pool = MaxPool2d(pooling_size, pooling_size)
#self.dropit = Dropout(p=0.05)
self.relu = ReLU()
else:
self.fully_connected_input = self.input_shape ** 2
self.fc1 = Linear(self.fully_connected_input, 1, bias=False)
#self.fc2 = Linear(4, 1, bias=False)
self.l1_lambda = l1_lambda

def forward(self, x):
Expand All @@ -72,10 +70,7 @@ def forward(self, x):
x = self.relu(x)
inter = x = self.pool(x)
x = x.view(x.size(0), -1)
#if self.mode == 'full':
# x = self.dropit(x)
x = self.fc1(x)
#x = self.fc2(x)

return x.float(), inter

Expand Down
10 changes: 9 additions & 1 deletion antipasti/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ class Preprocessing(object):
Compute all the normal mode correlation maps.
renew_residues: bool
Retrieve the lists of residues for each entry.
cmaps: bool
If ``True``, ANTIPASTI computes the contact maps of the complexes instead of the Normal Modes.
cmaps_thr: float
Thresholding distance for alpha (α) carbons to build the contact maps.
ag_agnostic: bool
If ``True``, Normal Mode correlation maps are computed in the complete absence of the antigen.
affinity_entries_only: bool
This is ``False`` in general, but the ANTIPASTI pipeline could be used for other types of projects and thus consider data without affinity values.
stage: str
Choose between ``training`` and ``predicting``.
test_data_path: str
Expand All @@ -64,7 +72,7 @@ class Preprocessing(object):
Amount of absent residues between positions 1 and 25 in the heavy chain.
l_offset: int
Amount of absent residues between positions 1 and 23 in the light chain.
"""

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions antipasti/utils/biology_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def remove_nanobodies(pdb_codes, representations, embedding=None, labels=[], num
Low-dimensional version of ``representations``.
labels: list
Data point labels.
numerical_values: list
If data is numerical (e.g., affinity values), it is necessary to include a list here. In this way, values associated with nanobodies can be removed.
"""
input_shape = representations.shape[-1]
Expand Down
96 changes: 90 additions & 6 deletions antipasti/utils/explaining_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ def get_maps_of_interest(preprocessed_data, learnt_filter, affinity_thr=-8):
return mean_learnt, mean_image, mean_diff_image

def get_output_representations(preprocessed_data, model):
r"""Returns maps that reveal the important residue interactions for the binding affinity. We call them 'output layer representations'.
Parameters
----------
preprocessed_data: antipasti.preprocessing.preprocessing.Preprocessing
The ``Preprocessing`` class.
model: antipasti.model.model.ANTIPASTI
The model class, i.e., ``ANTIPASTI``.
"""
input_shape = preprocessed_data.test_x.shape[-1]
each_img_enl = np.zeros((preprocessed_data.train_x.shape[0], input_shape**2))
size_le = int(np.sqrt(model.fc1.weight.data.numpy().shape[-1] / model.n_filters))
Expand Down Expand Up @@ -110,7 +120,7 @@ def plot_map_with_regions(preprocessed_data, map, title='Normal mode correlation
title: str
The image title.
interactive: bool
Set to ``True`` when running a script or Pytest.
Set to ``True`` when running a script or ``pytest``.
"""
# Font sizes
Expand Down Expand Up @@ -212,8 +222,8 @@ def compute_umap(preprocessed_data, model, scheme='heavy_species', categorical=T
external_cdict: dictionary
Option to provide an external dictionary of the UMAP labels.
interactive: bool
Set to ``True`` when running a script or Pytest.
remove_nanobodies: bool
Set to ``True`` when running a script or ``pytest``.
exclude_nanobodies: bool
Set to ``True`` to exclude nanobodies from the UMAP plot.
"""
Expand Down Expand Up @@ -337,8 +347,10 @@ def plot_umap(embedding, colours, scheme, pdb_codes, categorical=True, include_e
``True`` if ``scheme`` is categorical.
include_ellipses: bool
``True`` to include ellipses comprising 85% of the points of a given class.
cdict: dictionary
External dictionary of the UMAP labels.
interactive: bool
Set to ``True`` when running a script or Pytest.
Set to ``True`` when running a script or ``pytest``.
"""
fig = plt.figure(figsize=(20,20))
Expand Down Expand Up @@ -411,6 +423,22 @@ def plot_umap(embedding, colours, scheme, pdb_codes, categorical=True, include_e
plt.close('all')

def plot_region_importance(importance_factor, importance_factor_ob, antigen_type, mode='region', interactive=False):
r"""Plots ranking of important regions.
Parameters
----------
importance_factor: list
Measure of importance (0-100) for each antibody region.
importance_factor_ob: list
Measure of importance (0-100) for each antibody region attributable to off-block correlations. This can be inter-region or inter-chain depending on the selected ``mode``.
antigen_type: int
Plot corresponding to antigens of a given type. These can be proteins (0), haptens (1), peptides (2) or carbohydrates (3).
mode: str
``region`` to explicitly show which correlations are inter/intra-region (likewise for ``chain``).
interactive: bool
Set to ``True`` when running a script or ``pytest``.
"""

labels = ['FR-H1', 'CDR-H1', 'FR-H2', 'CDR-H2', 'FR-H3', 'CDR-H3', 'FR-H4', 'FR-L1', 'CDR-L1', 'FR-L2', 'CDR-L2', 'FR-L3', 'CDR-L3', 'FR-L4']
mapping_dict = {0: 0, 1: 2, 2: 1, 3: 5}
Expand Down Expand Up @@ -446,7 +474,7 @@ def plot_region_importance(importance_factor, importance_factor_ob, antigen_type
plt.close('all')

def add_region_based_on_range(list_residues):

r"""Given a list of residues in Chothia numbering, this function adds the corresponding regions in brackets for each of them."""
output_list_residues = []

new_coord = np.array([range(0, 26), range(26, 38), range(38, 57), range(57, 67), range(67, 116), range(116, 142),
Expand All @@ -463,6 +491,20 @@ def add_region_based_on_range(list_residues):
return output_list_residues

def plot_residue_importance(preprocessed_data, importance_factor, antigen_type, interactive=False):
r"""Plots ranking of important residues.
Parameters
----------
preprocessed_data: antipasti.preprocessing.preprocessing.Preprocessing
The ``Preprocessing`` class.
importance_factor: list
Measure of importance (0-100) for each antibody residue.
antigen_type: int
Plot corresponding to antigens of a given type. These can be proteins (0), haptens (1), peptides (2) or carbohydrates (3).
interactive: bool
Set to ``True`` when running a script or ``pytest``.
"""

res_labels = add_region_based_on_range(preprocessed_data.max_res_list_h+preprocessed_data.max_res_list_l)
mapping_dict = {0: 0, 1: 2, 2: 1, 3: 5}
Expand Down Expand Up @@ -495,6 +537,14 @@ def plot_residue_importance(preprocessed_data, importance_factor, antigen_type,


def get_colours_ag_type(preprocessed_data):
r"""Returns a different colour according to the antigen type.
Parameters
----------
preprocessed_data: antipasti.preprocessing.preprocessing.Preprocessing
The ``Preprocessing`` class.
"""

cluster_according_to = 'antigen_type'
db = pd.read_csv(preprocessed_data.data_path+'sabdab_summary_all.tsv', sep='\t')
Expand Down Expand Up @@ -522,7 +572,25 @@ def get_colours_ag_type(preprocessed_data):
return colours

def compute_region_importance(preprocessed_data, model, type_of_antigen, nanobodies, mode='region', interactive=False):

r"""Computes the importance factors (0-100) of all the Fv antibody regions.
Parameters
----------
preprocessed_data: antipasti.preprocessing.preprocessing.Preprocessing
The ``Preprocessing`` class.
model: antipasti.model.model.ANTIPASTI
The model class, i.e., ``ANTIPASTI``.
type_of_antigen: int
Choose between: proteins (0), haptens (1), peptides (2) or carbohydrates (3).
nanobodies: list
PDB codes of nanobodies in the dataset.
mode: str
``region`` to explicitly calculate which correlations are inter/intra-region (likewise for ``chain``).
interactive: bool
Set to ``True`` when running a script or ``pytest``.
"""

colours = get_colours_ag_type(preprocessed_data)
each_img_enl = get_output_representations(preprocessed_data, model)

Expand Down Expand Up @@ -581,6 +649,22 @@ def compute_region_importance(preprocessed_data, model, type_of_antigen, nanobod
plot_region_importance(tot, ob, type_of_antigen, mode, interactive=interactive)

def compute_residue_importance(preprocessed_data, model, type_of_antigen, nanobodies, interactive=False):
r"""Computes the importance factors (0-100) of all the amino acids of the antibody variable region.
Parameters
----------
preprocessed_data: antipasti.preprocessing.preprocessing.Preprocessing
The ``Preprocessing`` class.
model: antipasti.model.model.ANTIPASTI
The model class, i.e., ``ANTIPASTI``.
type_of_antigen: int
Choose between: proteins (0), haptens (1), peptides (2) or carbohydrates (3).
nanobodies: list
PDB codes of nanobodies in the dataset.
interactive: bool
Set to ``True`` when running a script or ``pytest``.
"""

colours = get_colours_ag_type(preprocessed_data)
each_img_enl = get_output_representations(preprocessed_data, model)
Expand Down
28 changes: 4 additions & 24 deletions antipasti/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def create_test_set(train_x, train_y, test_size=None, random_state=0):
Labels.
test_size: float
Fraction of original samples to be included in the test set.
random_state: int
Random seed used when creating the test set.
Returns
-------
Expand Down Expand Up @@ -105,18 +107,7 @@ def training_step(model, criterion, optimiser, train_x, test_x, train_y, test_y,
inter_filter = np.zeros((x_train.size()[0], model.n_filters, size_inter, size_inter))
if model.mode != 'full':
inter_filter = np.zeros((x_train.size()[0], 1, model.input_shape, model.input_shape))
#perm_paired = []
#perm_nano = []
permutation = torch.randperm(x_train.size()[0])
#for i in range(x_train.size()[0]):
# if torch.numel(torch.nonzero(x_train[i,0,-80:,-80:])) == 0:
# perm_nano.append(i)
# else:
# perm_paired.append(i)
#np.random.shuffle(perm_nano)
#np.random.shuffle(perm_paired)
#print(len(perm_nano))
#permutation = perm_nano + perm_paired

for i in range(0, x_train.size()[0], batch_size):
indices = permutation[i:i+batch_size]
Expand All @@ -130,17 +121,6 @@ def training_step(model, criterion, optimiser, train_x, test_x, train_y, test_y,
inter_filter[i:i+batch_size] = inter_filters_detached.numpy()

# Training loss, clearing gradients and updating weights
#def closure():
# optimiser.zero_grad()
# output_train, inter_filters = model(batch_x)
# loss_train = criterion(output_train[:, 0], batch_y[:, 0])
# loss_train.backward()
# return loss_train

#optimiser.step(closure)

#with torch.no_grad():
# loss_train = criterion(output_train[:, 0], batch_y[:, 0]).detach()
optimiser.zero_grad()
l1_loss = model.l1_regularization_loss()
mse_loss = criterion(output_train[:, 0], batch_y[:, 0])
Expand Down Expand Up @@ -246,10 +226,10 @@ def load_checkpoint(path, input_shape, n_filters=None, pooling_size=None, filter
Shape of the normal mode correlation maps.
n_filters: int
Number of filters in the convolutional layer.
filter_size: int
Size of filters in the convolutional layer.
pooling_size: int
Size of the max pooling operation.
filter_size: int
Size of filters in the convolutional layer.
Returns
-------
Expand Down
9,152 changes: 3,632 additions & 5,520 deletions notebooks/[Tutorial] Training ANTIPASTI.ipynb

Large diffs are not rendered by default.

0 comments on commit 0f188f7

Please sign in to comment.