diff --git a/qlora.py b/qlora.py index 23e675ee..d085e5a1 100644 --- a/qlora.py +++ b/qlora.py @@ -324,7 +324,7 @@ def get_accelerate_model(args, checkpoint_dir): bnb_4bit_use_double_quant=args.double_quant, bnb_4bit_quant_type=args.quant_type, ), - torch_dtype=(torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)), + torch_dtype=(torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)), trust_remote_code=args.trust_remote_code, use_auth_token=args.use_auth_token ) @@ -341,7 +341,7 @@ def get_accelerate_model(args, checkpoint_dir): setattr(model, 'model_parallel', True) setattr(model, 'is_parallelizable', True) - model.config.torch_dtype=(torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)) + model.config.torch_dtype=(torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)) # Tokenizer tokenizer = AutoTokenizer.from_pretrained(