From 79781a2975dba5fc0105a82cf186c8264034a090 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Diaz?= Date: Thu, 22 Aug 2024 14:05:25 +0200 Subject: [PATCH] Update py_predictor_v2.py --- metapredict/backend/py_predictor_v2.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/metapredict/backend/py_predictor_v2.py b/metapredict/backend/py_predictor_v2.py index 2b62c18..3b15737 100644 --- a/metapredict/backend/py_predictor_v2.py +++ b/metapredict/backend/py_predictor_v2.py @@ -79,6 +79,10 @@ def __init__(self, saved_weights, dtype, gpuid='cpu'): if torch.cuda.is_available(): device_string = f"cuda:{gpuid}" device = torch.device(device_string) + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + # Use MPS if available on ARM-based MacBooks + device_string = "mps" + device = torch.device(device_string) else: device_string = "cpu" device = torch.device(device_string)