Skip to content

Commit

Permalink
Merge pull request #20 from RuiWang1998/main
Browse files Browse the repository at this point in the history
Added MPS support back
  • Loading branch information
mooninrain authored Aug 13, 2022
2 parents 03812bb + 95a42da commit cd8b5ad
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 10 deletions.
12 changes: 6 additions & 6 deletions omegafold/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,16 @@ def main():
f"{len(input_data[0]['p_msa'][0])} residues in this chain."
)
ts = time.time()
try:
output = model(
# try:
output = model(
input_data,
predict_with_confidence=True,
fwd_cfg=forward_config
)
except RuntimeError as e:
logging.info(f"Failed to generate {save_path} due to {e}")
logging.info(f"Skipping...")
continue
# except RuntimeError as e:
# logging.info(f"Failed to generate {save_path} due to {e}")
# logging.info(f"Skipping...")
# continue
logging.info(f"Finished prediction in {time.time() - ts:.2f} seconds.")

logging.info(f"Saving prediction to {save_path}")
Expand Down
2 changes: 1 addition & 1 deletion omegafold/embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def forward(
Returns:
"""
atom_mask = rc.restype2atom_mask[fasta].to(self.device)
atom_mask = rc.restype2atom_mask[fasta.cpu()].to(self.device)
prev_beta = utils.create_pseudo_beta(prev_x, atom_mask)
d = utils.get_norm(prev_beta.unsqueeze(-2) - prev_beta.unsqueeze(-3))
d = self.dgram(d)
Expand Down
2 changes: 1 addition & 1 deletion omegafold/geoformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def forward(
return node_repr, edge_repr

def _column_attention(self, node_repr, mask, fwd_cfg):
node_repr_col = utils.normalize(node_repr.transpose(-2, -3))
node_repr_col = utils.normalize(node_repr.transpose(-2, -3).contiguous())
node_repr_col = self.column_attention(
node_repr_col,
node_repr_col,
Expand Down
2 changes: 1 addition & 1 deletion omegafold/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ def _get_gated(self, edge_repr: torch.Tensor, mask: torch.Tensor, fwd_cfg):
):
act_col = self._get_act_col(edge_col, mask[s_col:e_col])
ab = torch.einsum('...ikrd,...jkrd->...ijrd', act_row, act_col)
ab = utils.normalize(ab)
ab = utils.normalize(ab.contiguous())
gated[s_row:e_row, s_col:e_col] = torch.einsum(
'...rd,rdc->...rc', ab, self.out_proj_w
)
Expand Down
7 changes: 6 additions & 1 deletion omegafold/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,17 @@
from Bio.PDB import StructureBuilder
import torch
from torch import hub
from torch.backends import cuda, cudnn, mps
from torch.backends import cuda, cudnn
from torch.utils.hipify import hipify_python

from omegafold import utils
from omegafold.utils.protein_utils import residue_constants as rc

try:
from torch.backends import mps # Compatibility with earlier versions
except IndexError:
mps = None


# =============================================================================
# Constants
Expand Down

0 comments on commit cd8b5ad

Please sign in to comment.