Skip to content

Commit

Permalink
fix cuda device bug
Browse files Browse the repository at this point in the history
  • Loading branch information
kexinhuang12345 committed Apr 23, 2021
1 parent fa3ee21 commit 3e61473
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 6 deletions.
5 changes: 1 addition & 4 deletions DeepPurpose/DTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,10 +587,7 @@ def load_pretrained(self, path):
if not os.path.exists(path):
os.makedirs(path)

if self.device[:4] == 'cuda':
state_dict = torch.load(path)
else:
state_dict = torch.load(path, map_location = torch.device('cpu'))
state_dict = torch.load(path, map_location = torch.device('cpu'))
# to support training from multi-gpus data-parallel:

if next(iter(state_dict))[:7] == 'module.':
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ Checkout [Dataset Tutorial](DEMO/load_data_tutorial.ipynb).
We provide more than 10 pretrained models. Please see [Pretraining Model Tutorial](DEMO/load_pretraining_models_tutorial.ipynb) on how to load them. It is as simple as

```python
from DeepPurpose import models
from DeepPurpose import DTI as models
net = models.model_pretrained(model = 'MPNN_CNN_DAVIS')
or
net = models.model_pretrained(FILE_PATH)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def readme():
name="DeepPurpose",
packages = ['DeepPurpose'],
package_data={'DeepPurpose': ['ESPF/*']},
version="0.1.0",
version="0.1.1",
author="Kexin Huang, Tianfan Fu",
license="BSD-3-Clause",
author_email="[email protected]",
Expand Down

0 comments on commit 3e61473

Please sign in to comment.