
Merge branch 'main' into reward
abhishekkrthakur committed Oct 12, 2023
2 parents c368363 + 1931728 commit cb7de45
Showing 3 changed files with 12 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/autotrain/__init__.py
@@ -30,4 +30,4 @@
 warnings.filterwarnings("ignore", category=UserWarning, module="tensorflow")
 
 
-__version__ = "0.6.35.dev0"
+__version__ = "0.6.36.dev0"
12 changes: 9 additions & 3 deletions src/autotrain/cli/run_llm.py
@@ -429,10 +429,16 @@ def __init__(self, args):
                 break
             print(f"Bot: {tgi.chat(prompt)}")
 
-        if not torch.cuda.is_available():
-            raise ValueError("No GPU found. Please install CUDA and try again.")
+        cuda_available = torch.cuda.is_available()
+        mps_available = torch.backends.mps.is_available()
 
-        self.num_gpus = torch.cuda.device_count()
+        if not cuda_available and not mps_available:
+            raise ValueError("No GPU/MPS device found. LLM training requires an accelerator")
+
+        if cuda_available:
+            self.num_gpus = torch.cuda.device_count()
+        elif mps_available:
+            self.num_gpus = 1
 
     def run(self):
         from autotrain.backend import EndpointsRunner, SpaceRunner
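
For context, a minimal standalone sketch of the accelerator check this hunk introduces, assuming only that PyTorch 1.12+ is installed (the release that added torch.backends.mps.is_available()). The detect_num_gpus wrapper is a hypothetical name; the commit runs the same logic inline in __init__:

    import torch

    def detect_num_gpus() -> int:
        # Hypothetical helper mirroring the new check in run_llm.py.
        cuda_available = torch.cuda.is_available()
        mps_available = torch.backends.mps.is_available()  # Apple Silicon backend

        if not cuda_available and not mps_available:
            raise ValueError("No GPU/MPS device found. LLM training requires an accelerator")

        if cuda_available:
            return torch.cuda.device_count()
        # PyTorch exposes MPS as a single logical device, hence the fixed 1.
        return 1

The elif branch hard-codes one device because MPS has no multi-device enumeration analogous to torch.cuda.device_count().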
2 changes: 2 additions & 0 deletions src/autotrain/trainers/clm/__main__.py
@@ -134,8 +134,10 @@ def train(config):
             bnb_4bit_compute_dtype=torch.float16,
             bnb_4bit_use_double_quant=False,
         )
+        config.fp16 = True
     elif config.use_int8:
         bnb_config = BitsAndBytesConfig(load_in_8bit=config.use_int8)
+        config.fp16 = True
     else:
         bnb_config = None
 
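
A hedged sketch of the whole quantization branch after this change, to show why config.fp16 is now forced: the bitsandbytes int4 path already computes in torch.float16, so the trainer should run in matching mixed precision. The load_in_4bit and bnb_4bit_quant_type arguments sit above the visible hunk and are assumed here, as is the shape of the config object:

    import torch
    from transformers import BitsAndBytesConfig

    def build_bnb_config(config):
        # Assumed reconstruction; only the two fp16 lines are new in this commit.
        if config.use_int4:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=config.use_int4,  # assumed, not shown in the hunk
                bnb_4bit_quant_type="nf4",     # assumed, not shown in the hunk
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=False,
            )
            config.fp16 = True  # new: match the fp16 compute dtype above
        elif config.use_int8:
            bnb_config = BitsAndBytesConfig(load_in_8bit=config.use_int8)
            config.fp16 = True  # new: keep int8 training in fp16 mixed precision
        else:
            bnb_config = None  # full precision, no quantization
        return bnb_config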
