diff --git a/hf/bench.py b/hf/bench.py
index 1cb5ac3..4efcaf2 100644
--- a/hf/bench.py
+++ b/hf/bench.py
@@ -1,5 +1,6 @@
 # Load model directly
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 import time
 import sys
 sys.path.append('../common/')
@@ -8,7 +9,7 @@
 
 model_id = "meta-llama/Llama-2-7b-hf"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
 model.to("cuda")
 
 def predict(prompt:str):
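
Note (not part of the diff): the change loads the weights in bfloat16 (~2 bytes per parameter) instead of the float32 default (~4 bytes), roughly halving the memory footprint of the 7B model. A minimal sketch to verify the difference, assuming transformers' get_memory_footprint() helper and enough memory to load the model twice, sequentially:

# Sketch: compare weight memory for the default float32 load vs. bfloat16.
# Uses the model_id from the diff; swap in a smaller model to test cheaply.
import torch
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"

# Default load keeps weights in float32 (~4 bytes per parameter).
model = AutoModelForCausalLM.from_pretrained(model_id)
print(f"fp32: {model.get_memory_footprint() / 1e9:.1f} GB")
del model  # free the fp32 copy before loading again

# The diff's version: weights stay in bfloat16 (~2 bytes per parameter).
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
print(f"bf16: {model.get_memory_footprint() / 1e9:.1f} GB")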