From f4a8e09660bf4a04b7040ffe230fee6aa80751e1 Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Sat, 13 Aug 2022 22:50:40 +0800 Subject: [PATCH] Added back supports for MPS --- omegafold/__main__.py | 12 ++++++------ omegafold/embedders.py | 2 +- omegafold/geoformer.py | 2 +- omegafold/modules.py | 2 +- omegafold/pipeline.py | 7 ++++++- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/omegafold/__main__.py b/omegafold/__main__.py index 085ee14..d9dde86 100644 --- a/omegafold/__main__.py +++ b/omegafold/__main__.py @@ -70,16 +70,16 @@ def main(): f"{len(input_data[0]['p_msa'][0])} residues in this chain." ) ts = time.time() - try: - output = model( + # try: + output = model( input_data, predict_with_confidence=True, fwd_cfg=forward_config ) - except RuntimeError as e: - logging.info(f"Failed to generate {save_path} due to {e}") - logging.info(f"Skipping...") - continue + # except RuntimeError as e: + # logging.info(f"Failed to generate {save_path} due to {e}") + # logging.info(f"Skipping...") + # continue logging.info(f"Finished prediction in {time.time() - ts:.2f} seconds.") logging.info(f"Saving prediction to {save_path}") diff --git a/omegafold/embedders.py b/omegafold/embedders.py index f25e8ce..4175f25 100644 --- a/omegafold/embedders.py +++ b/omegafold/embedders.py @@ -263,7 +263,7 @@ def forward( Returns: """ - atom_mask = rc.restype2atom_mask[fasta].to(self.device) + atom_mask = rc.restype2atom_mask[fasta.cpu()].to(self.device) prev_beta = utils.create_pseudo_beta(prev_x, atom_mask) d = utils.get_norm(prev_beta.unsqueeze(-2) - prev_beta.unsqueeze(-3)) d = self.dgram(d) diff --git a/omegafold/geoformer.py b/omegafold/geoformer.py index 9e7a645..1f11344 100644 --- a/omegafold/geoformer.py +++ b/omegafold/geoformer.py @@ -126,7 +126,7 @@ def forward( return node_repr, edge_repr def _column_attention(self, node_repr, mask, fwd_cfg): - node_repr_col = utils.normalize(node_repr.transpose(-2, -3)) + node_repr_col = utils.normalize(node_repr.transpose(-2, -3).contiguous()) node_repr_col = self.column_attention( node_repr_col, node_repr_col, diff --git a/omegafold/modules.py b/omegafold/modules.py index 05f3191..9cf4603 100644 --- a/omegafold/modules.py +++ b/omegafold/modules.py @@ -633,7 +633,7 @@ def _get_gated(self, edge_repr: torch.Tensor, mask: torch.Tensor, fwd_cfg): ): act_col = self._get_act_col(edge_col, mask[s_col:e_col]) ab = torch.einsum('...ikrd,...jkrd->...ijrd', act_row, act_col) - ab = utils.normalize(ab) + ab = utils.normalize(ab.contiguous()) gated[s_row:e_row, s_col:e_col] = torch.einsum( '...rd,rdc->...rc', ab, self.out_proj_w ) diff --git a/omegafold/pipeline.py b/omegafold/pipeline.py index ef40a39..146c39e 100644 --- a/omegafold/pipeline.py +++ b/omegafold/pipeline.py @@ -35,12 +35,17 @@ from Bio.PDB import StructureBuilder import torch from torch import hub -from torch.backends import cuda, cudnn, mps +from torch.backends import cuda, cudnn from torch.utils.hipify import hipify_python from omegafold import utils from omegafold.utils.protein_utils import residue_constants as rc +try: + from torch.backends import mps # Compatibility with earlier versions +except IndexError: + mps = None + # ============================================================================= # Constants