diff --git a/src/autotrain/__init__.py b/src/autotrain/__init__.py
index 5e21f2a6c0..15bf731d86 100644
--- a/src/autotrain/__init__.py
+++ b/src/autotrain/__init__.py
@@ -30,4 +30,4 @@
 warnings.filterwarnings("ignore", category=UserWarning, module="tensorflow")
 
-__version__ = "0.6.35.dev0"
+__version__ = "0.6.36.dev0"
 
diff --git a/src/autotrain/cli/run_llm.py b/src/autotrain/cli/run_llm.py
index 2198afc704..a1306b9900 100644
--- a/src/autotrain/cli/run_llm.py
+++ b/src/autotrain/cli/run_llm.py
@@ -429,10 +429,16 @@ def __init__(self, args):
                     break
                 print(f"Bot: {tgi.chat(prompt)}")
 
-        if not torch.cuda.is_available():
-            raise ValueError("No GPU found. Please install CUDA and try again.")
+        cuda_available = torch.cuda.is_available()
+        mps_available = torch.backends.mps.is_available()
 
-        self.num_gpus = torch.cuda.device_count()
+        if not cuda_available and not mps_available:
+            raise ValueError("No GPU/MPS device found. LLM training requires an accelerator")
+
+        if cuda_available:
+            self.num_gpus = torch.cuda.device_count()
+        elif mps_available:
+            self.num_gpus = 1
 
     def run(self):
         from autotrain.backend import EndpointsRunner, SpaceRunner
diff --git a/src/autotrain/trainers/clm/__main__.py b/src/autotrain/trainers/clm/__main__.py
index 406ed3b08d..72379d337f 100644
--- a/src/autotrain/trainers/clm/__main__.py
+++ b/src/autotrain/trainers/clm/__main__.py
@@ -134,8 +134,10 @@ def train(config):
             bnb_4bit_compute_dtype=torch.float16,
             bnb_4bit_use_double_quant=False,
         )
+        config.fp16 = True
     elif config.use_int8:
         bnb_config = BitsAndBytesConfig(load_in_8bit=config.use_int8)
+        config.fp16 = True
     else:
         bnb_config = None
 
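For context, the `run_llm.py` hunk replaces the CUDA-only gate with a check that accepts either CUDA or Apple's MPS backend, falling back to a device count of 1 on Apple silicon since MPS is exposed as a single device. A minimal standalone sketch of that logic, assuming torch >= 1.12 (where `torch.backends.mps` exists); the helper name `detect_num_gpus` is illustrative, not part of the patch:

```python
# Minimal sketch of the accelerator check introduced in run_llm.py.
# Assumes torch >= 1.12, where torch.backends.mps is available.
import torch


def detect_num_gpus() -> int:
    cuda_available = torch.cuda.is_available()
    mps_available = torch.backends.mps.is_available()

    if not cuda_available and not mps_available:
        raise ValueError("No GPU/MPS device found. LLM training requires an accelerator")

    if cuda_available:
        # Multi-GPU CUDA hosts report one device per GPU.
        return torch.cuda.device_count()
    # Apple silicon exposes MPS as a single device, so the count is fixed at 1.
    return 1
```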
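The `trainers/clm/__main__.py` hunk forces `config.fp16 = True` whenever int4 or int8 quantization is selected, since bitsandbytes computes in float16 (the 4-bit path already pins `bnb_4bit_compute_dtype=torch.float16`). A sketch of the resulting branch; `build_bnb_config` is a hypothetical wrapper, `config` stands in for the trainer's params object, and the `load_in_4bit`/`bnb_4bit_quant_type` arguments are assumptions, since those lines fall outside the diff context:

```python
# Sketch of the quantization branch in trainers/clm/__main__.py after
# this patch. load_in_4bit and bnb_4bit_quant_type are assumed values.
import torch
from transformers import BitsAndBytesConfig


def build_bnb_config(config):
    if config.use_int4:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=config.use_int4,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
        )
        # bitsandbytes computes in fp16, so mixed precision must also be
        # enabled on the surrounding Trainer.
        config.fp16 = True
    elif config.use_int8:
        bnb_config = BitsAndBytesConfig(load_in_8bit=config.use_int8)
        config.fp16 = True
    else:
        bnb_config = None
    return bnb_config
```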