From 045d542300a0f1e6e0e7a1988f800823b566c6da Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Iv=C3=A1n=20Baldo?=
Date: Wed, 17 Jan 2024 20:34:40 -0300
Subject: [PATCH] hf/bench.py: need to specify bfloat16, otherwise it consumes
 twice as much memory on the A10.

BTW, float16 gives wrong results on the A10 but correct results on the T4
(which also doesn't work unless the float16 type is specified explicitly).
Maybe also related to https://github.com/hamelsmu/llama-inference/issues/4
---
 hf/bench.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/hf/bench.py b/hf/bench.py
index 1cb5ac3..4efcaf2 100644
--- a/hf/bench.py
+++ b/hf/bench.py
@@ -1,5 +1,6 @@
 # Load model directly
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 import time
 import sys
 sys.path.append('../common/')
@@ -8,7 +9,7 @@
 model_id = "meta-llama/Llama-2-7b-hf"
 
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
 model.to("cuda")
 
 def predict(prompt:str):
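
Note (reviewer addition, not part of the patch): a minimal sketch of why the explicit dtype matters, assuming the stock transformers API. Without torch_dtype, from_pretrained keeps the weights in float32 (about 4 bytes per parameter, roughly 28 GB for Llama-2-7B), while torch.bfloat16 halves that. The memory-footprint comparison below is illustrative only and is not part of bench.py.

    # Illustrative only: compare memory footprints of the two load paths.
    import torch
    from transformers import AutoModelForCausalLM

    model_id = "meta-llama/Llama-2-7b-hf"

    # Default load keeps weights in float32 (~4 bytes per parameter).
    model_fp32 = AutoModelForCausalLM.from_pretrained(model_id)
    print(f"float32 footprint:  {model_fp32.get_memory_footprint() / 1e9:.1f} GB")

    # Explicit bfloat16 load uses ~2 bytes per parameter, about half the
    # memory, and avoids the float16 accuracy issue observed on the A10.
    model_bf16 = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    print(f"bfloat16 footprint: {model_bf16.get_memory_footprint() / 1e9:.1f} GB")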