Support MPS on Apple Silicon (#132)
See https://pytorch.org/docs/stable/notes/mps.html

Testing with one image on a Mac Studio (M1 Max), processing time decreased from ~21 minutes to ~2 minutes.

I'm not sure whether the change in `utils.py` is necessary, or whether changes are needed elsewhere; it looks like the report in `infer.py` should be updated if this is accepted.

Also updated `.gitignore` for more Mac (and PyCharm) friendliness.
petebankhead authored Jun 28, 2023
1 parent f7f1dc7 commit 3afff37
Showing 3 changed files with 20 additions and 6 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -167,4 +167,7 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
+
+# Extras
+.DS_Store
13 changes: 10 additions & 3 deletions wsinfer/_modellib/run_inference.py
@@ -327,9 +327,16 @@ def run_inference(
     model = weights.load_model()
     model.eval()
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    if torch.cuda.device_count() > 1:
-        model = torch.nn.DataParallel(model)
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        if torch.cuda.device_count() > 1:
+            model = torch.nn.DataParallel(model)
+    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+    print(f'Using device "{device}"')
 
     model.to(device)
 
     if speedup:
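The device-selection logic in this diff can be tried on its own. A minimal self-contained sketch (the helper name `pick_device` is mine, not part of the repository):

```python
import torch


def pick_device() -> torch.device:
    """Prefer CUDA, then Apple's Metal Performance Shaders (MPS), then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # is_built() checks that this PyTorch install was compiled with MPS
    # support; is_available() additionally checks that the running macOS
    # version can actually use it.
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
print(f'Using device "{device}"')
```

Note that `DataParallel` is only applied on the CUDA branch, since it wraps a model across multiple CUDA devices; the MPS and CPU paths use the bare model.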
8 changes: 6 additions & 2 deletions wsinfer/_patchlib/utils/utils.py
@@ -14,8 +14,12 @@
 import math
 import collections
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")
 
 class SubsetSequentialSampler(Sampler):
     """Samples elements sequentially from a given list of indices, without replacement.
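One practical caveat for the MPS path: not every PyTorch operator is implemented on MPS yet. The MPS notes linked above describe an opt-in fallback to CPU for unsupported ops via an environment variable, which must be set before `torch` is imported. A sketch (not part of this commit):

```python
import os

# PYTORCH_ENABLE_MPS_FALLBACK makes ops without an MPS implementation run
# on the CPU instead of raising NotImplementedError. It is read once at
# import time, so it must be set before importing torch.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch  # must come after the environment variable is set

device = (
    torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)
```

This trades some speed for correctness on models that hit unimplemented ops, which may be worth it while MPS coverage is still growing.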
