Update to TorchANI 2.2.2 (#35)
* Update to TorchANI 2.2.1

* Update to PyTorch 1.10

* Add sysroot

* Fix a variable name

* Convert TorchANIBatchedNN to a ModuleList

* Circumvent TorchANI issue

* Set the oldest and latest dependency versions

* Set the minimum CMake version

* Set Python 3.8 as a minimum version

* Update to TorchANI 2.2.2
Raimondas Galvelis authored Nov 4, 2021
1 parent 24a6f7e commit 6939536
Showing 4 changed files with 26 additions and 11 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/ci.yml
@@ -12,16 +12,18 @@ jobs:
strategy:
matrix:
include:
# Oldest supported versions
- cuda: 10.2.89
gcc: 8.5.0
nvcc: 10.2
python: 3.8
pytorch: 1.7.1
pytorch: 1.8.0
# Latest supported versions
- cuda: 11.2.2
gcc: 10.3.0
nvcc: 11.2
python: 3.9
pytorch: 1.8.0 # Cannot test with PyTorch 1.9, because of https://github.com/aiqm/torchani/issues/598
pytorch: 1.10.0

steps:
- name: Check out
7 changes: 4 additions & 3 deletions environment.yml
@@ -1,13 +1,14 @@
channels:
- conda-forge
dependencies:
- cmake
- cmake >=3.20
- cudatoolkit 11.2.2
- gxx_linux-64 10.3.0
- make
- mdtraj
- nvcc_linux-64 11.2
- torchani 2.2
- torchani 2.2.2
- pytest
- python 3.9
- pytorch-gpu 1.8.0
- pytorch-gpu 1.10.0
- sysroot_linux-64 2.17
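(For reference, and not part of this diff: an environment like this is normally instantiated with the standard command `conda env create -f environment.yml`, after which the pinned CUDA, PyTorch and TorchANI versions above are resolved from conda-forge.)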
20 changes: 16 additions & 4 deletions src/pytorch/BatchedNN.py
@@ -34,7 +34,7 @@
batchedLinear = torch.ops.NNPOpsBatchedNN.BatchedLinear


class TorchANIBatchedNN(torch.nn.Module):
class _BatchedNN(torch.nn.Module):

def __init__(self, converter: SpeciesConverter, ensemble: Union[ANIModel, Ensemble], atomicNumbers: Tensor):

@@ -44,10 +44,10 @@ def __init__(self, converter: SpeciesConverter, ensemble: Union[ANIModel, Ensemb
species_list = converter((atomicNumbers, torch.empty(0))).species[0].tolist()

# Handle the case when the ensemble is just one model
ensemble = [ensemble] if type(ensemble) == ANIModel else ensemble
self._ensemble = [ensemble] if type(ensemble) == ANIModel else ensemble

# Convert models to the list of linear layers
models = [list(model.values()) for model in ensemble]
models = [list(model.values()) for model in self._ensemble]

# Extract the weights and biases of the linear layers
for ilayer in [0, 2, 4, 6]:
@@ -82,6 +82,9 @@ def batchLinearLayers(layers: List[List[nn.Linear]]) -> Tuple[nn.Parameter, nn.P

return nn.Parameter(weights), nn.Parameter(biases)

def _atomic_energies(self, species_aev: Tuple[Tensor, Tensor]) -> Tensor:
return self._ensemble[0]._atomic_energies(species_aev)

def forward(self, species_aev: Tuple[Tensor, Tensor]) -> SpeciesEnergies:

species, aev = species_aev
@@ -101,4 +104,13 @@ def forward(self, species_aev: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
# Mean: [num_mols, num_models] --> [num_mols]
energies = torch.mean(torch.sum(vectors, (1, 3, 4)), 1)

return SpeciesEnergies(species, energies)
return SpeciesEnergies(species, energies)


class TorchANIBatchedNN(torch.nn.ModuleList):

def __init__(self, converter: SpeciesConverter, ensemble: Union[ANIModel, Ensemble], atomicNumbers: Tensor):
super().__init__([_BatchedNN(converter, ensemble, atomicNumbers)])

def forward(self, species_aev: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
return self[0].forward(species_aev)
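
For orientation, a minimal usage sketch of the refactored class follows. Only the constructor and forward signatures come from the diff above; the `NNPOps.BatchedNN` import path, the ANI2x model and the example molecule are assumptions for illustration.

```python
import torch
import torchani

# Assumed import path; in this repository the class is defined in src/pytorch/BatchedNN.py
from NNPOps.BatchedNN import TorchANIBatchedNN

device = torch.device('cuda')

# A stock TorchANI model and one example molecule (water), given as atomic numbers
model = torchani.models.ANI2x(periodic_table_index=True).to(device)
atomicNumbers = torch.tensor([[8, 1, 1]], device=device)
positions = torch.tensor([[[0.00, 0.00, 0.00],
                           [0.00, 0.00, 0.96],
                           [0.93, 0.00, -0.25]]], device=device)

# Replace the ensemble of per-element networks with the batched implementation.
# After this commit TorchANIBatchedNN is a ModuleList holding a single _BatchedNN,
# and its forward() simply delegates to that inner module.
model.neural_networks = TorchANIBatchedNN(model.species_converter,
                                          model.neural_networks,
                                          atomicNumbers).to(device)

energies = model((atomicNumbers, positions)).energies
```

The ModuleList wrapper and the delegated `_atomic_energies` call presumably keep TorchANI code paths that index into or iterate over `neural_networks` working; the commit messages describe it only as circumventing a TorchANI issue.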
4 changes: 2 additions & 2 deletions src/pytorch/SymmetryFunctions.py
@@ -72,7 +72,7 @@ def __init__(self, symmFunc: torchani.AEVComputer):
"""
super().__init__()

self.numSpecies = symmFunc.num_species
self.num_species = symmFunc.num_species
self.Rcr = symmFunc.Rcr
self.Rca = symmFunc.Rca
self.EtaR = symmFunc.EtaR[:, 0].tolist()
@@ -121,7 +121,7 @@ def forward(self, speciesAndPositions: Tuple[Tensor, Tensor],

if not self.holder.is_initialized():
species_: List[int] = species[0].tolist() # Explicit type casting for TorchScript
self.holder = Holder(self.numSpecies, self.Rcr, self.Rca,
self.holder = Holder(self.num_species, self.Rcr, self.Rca,
self.EtaR, self.ShfR,
self.EtaA, self.Zeta, self.ShfA, self.ShfZ,
species_, positions)
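
In the same spirit, a hedged sketch of how the wrapper touched by this file is typically applied; only the `TorchANISymmetryFunctions(symmFunc)` constructor signature and the renamed `num_species` attribute come from the diff, the rest is assumed.

```python
import torch
import torchani

# Assumed import path; the class is defined in src/pytorch/SymmetryFunctions.py
from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions

device = torch.device('cuda')
model = torchani.models.ANI2x(periodic_table_index=True).to(device)

# Swap the stock AEV computer for the optimized one; the wrapper copies
# num_species, Rcr, Rca and the radial/angular parameters from it.
model.aev_computer = TorchANISymmetryFunctions(model.aev_computer).to(device)
```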
