diff --git a/metapredict/backend/py_predictor_v2.py b/metapredict/backend/py_predictor_v2.py index 2b62c18..3b15737 100644 --- a/metapredict/backend/py_predictor_v2.py +++ b/metapredict/backend/py_predictor_v2.py @@ -79,6 +79,10 @@ def __init__(self, saved_weights, dtype, gpuid='cpu'): if torch.cuda.is_available(): device_string = f"cuda:{gpuid}" device = torch.device(device_string) + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + # Use MPS if available on ARM-based MacBooks + device_string = "mps" + device = torch.device(device_string) else: device_string = "cpu" device = torch.device(device_string)