diff --git a/cellfinder/core/tools/system.py b/cellfinder/core/tools/system.py index 612ad4da..c5ad9045 100644 --- a/cellfinder/core/tools/system.py +++ b/cellfinder/core/tools/system.py @@ -1,5 +1,6 @@ from pathlib import Path +import keras from brainglobe_utils.general.exceptions import CommandLineInputError @@ -80,3 +81,12 @@ def memory_in_bytes(memory_amount, unit): ) else: return memory_amount * 10 ** supported_units[unit] + + +def force_cpu(): + """ + Forces the CPU to be used, even if a GPU is available + """ + keras.src.backend.common.global_state.set_global_attribute( + "torch_device", "cpu" + ) diff --git a/tests/core/conftest.py b/tests/core/conftest.py index e97b878f..d78e5cf5 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -1,7 +1,6 @@ import os from typing import Tuple -import keras.src.backend.common.global_state import numpy as np import pytest import torch.backends.mps @@ -11,6 +10,7 @@ DEFAULT_DOWNLOAD_DIRECTORY, download_models, ) +from cellfinder.core.tools.system import force_cpu @pytest.fixture(scope="session", autouse=True) @@ -24,9 +24,7 @@ def set_device_arm_macos_ci(): os.getenv("GITHUB_ACTIONS") == "true" and torch.backends.mps.is_available() ): - keras.src.backend.common.global_state.set_global_attribute( - "torch_device", "cpu" - ) + force_cpu() @pytest.fixture(scope="session")