From 5dd5494a4e0fae9bd568bb868925b80dac6653ae Mon Sep 17 00:00:00 2001
From: Angel Luu
Date: Mon, 16 Sep 2024 09:45:40 -0600
Subject: [PATCH] feat: Add deps to evaluate qLora tuned model (#312)

* Add support to load qLora tuned model in run_inference.py script

Signed-off-by: Angel Luu

* Remove comment

Signed-off-by: Angel Luu

* Disable gptq by default

Signed-off-by: Angel Luu

* Remove the gptq-dev install in Dockerfile

Signed-off-by: Angel Luu

* Rename gptq-dev package from gptq

Signed-off-by: Angel Luu

* Add comments in run_inference.py

Signed-off-by: Angel Luu

* Update device to cuda

Signed-off-by: Angel Luu

* Add in the case that there's no adapter found

Signed-off-by: Angel Luu

* Use torch.float16 for quantized

Signed-off-by: Angel Luu

---------

Signed-off-by: Angel Luu
---
 pyproject.toml           |  1 +
 scripts/run_inference.py | 73 +++++++++++++++++++++++++++++++---------
 2 files changed, 59 insertions(+), 15 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index aae1a9dd7..fcb049821 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<25", "ninja>=1.11.1.1,<2.0", "sci
 flash-attn = ["flash-attn>=2.5.3,<3.0"]
 aim = ["aim>=3.19.0,<4.0"]
 fms-accel = ["fms-acceleration>=0.1"]
+gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]
 
 
 [tool.setuptools.packages.find]
diff --git a/scripts/run_inference.py b/scripts/run_inference.py
index d64bf926b..7e4465cac 100644
--- a/scripts/run_inference.py
+++ b/scripts/run_inference.py
@@ -30,7 +30,7 @@
 # Third Party
 from peft import PeftModel
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
 import torch
 
 # Local
@@ -176,6 +176,8 @@ def load(
             else {}
         )
         tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
+        device = "cuda" if torch.cuda.is_available() else None
+        print(f"Inferred device: {device}")
         # Apply the configs to the adapter config of this model; if no overrides
         # are provided, then the context manager doesn't have any effect.
         try:
@@ -183,13 +185,36 @@ def load(
                 try:
                     if base_model_name_or_path is None:
                         raise ValueError("base_model_name_or_path has to be passed")
-                    base_model = AutoModelForCausalLM.from_pretrained(
-                        base_model_name_or_path,
-                        attn_implementation="flash_attention_2"
-                        if use_flash_attn
-                        else None,
-                        torch_dtype=torch.bfloat16 if use_flash_attn else None,
-                    )
+
+                    if (
+                        has_quantized_config(base_model_name_or_path)
+                        and device == "cuda"
+                    ):
+                        # Using GPTQConfig from HF, avail params are here
+                        # https://huggingface.co/docs/transformers/en/main_classes/quantization#transformers.GPTQConfig
+                        # We only support 4-bit AutoGPTQ, so setting bits to 4
+                        # setting exllama kernel to version 2 as it's a faster kernel
+                        gptq_config = GPTQConfig(bits=4, exllama_config={"version": 2})
+
+                        # Since we are using exllama kernel, we need torch.float16 as torch_dtype
+                        base_model = AutoModelForCausalLM.from_pretrained(
+                            base_model_name_or_path,
+                            attn_implementation="flash_attention_2"
+                            if use_flash_attn
+                            else None,
+                            device_map=device,
+                            torch_dtype=torch.float16,
+                            quantization_config=gptq_config,
+                        )
+                    else:
+                        base_model = AutoModelForCausalLM.from_pretrained(
+                            base_model_name_or_path,
+                            attn_implementation="flash_attention_2"
+                            if use_flash_attn
+                            else None,
+                            torch_dtype=torch.bfloat16 if use_flash_attn else None,
+                        )
+
                     # since the peft library (PEFTModelForCausalLM) does not handle cases
                     # where the model's layers are modified, in our case the embedding layer
                     # is modified, so we resize the backbone model's embedding layer with our own
@@ -211,14 +236,28 @@ def load(
         except FileNotFoundError:
             print("No adapter config found! Loading as a merged model...")
             # Unable to find the adapter config; fall back to loading as a merged model
-            model = AutoModelForCausalLM.from_pretrained(
-                checkpoint_path,
-                attn_implementation="flash_attention_2" if use_flash_attn else None,
-                torch_dtype=torch.bfloat16 if use_flash_attn else None,
-            )
+            if has_quantized_config(checkpoint_path) and device == "cuda":
+                # Using GPTQConfig from HF, avail params are here
+                # https://huggingface.co/docs/transformers/en/main_classes/quantization#transformers.GPTQConfig
+                # We only support 4-bit AutoGPTQ, so setting bits to 4
+                # setting exllama kernel to version 2 as it's a faster kernel
+                gptq_config = GPTQConfig(bits=4, exllama_config={"version": 2})
+
+                # Since we are using exllama kernel, we need torch.float16 as torch_dtype
+                model = AutoModelForCausalLM.from_pretrained(
+                    checkpoint_path,
+                    attn_implementation="flash_attention_2" if use_flash_attn else None,
+                    device_map=device,
+                    torch_dtype=torch.float16,
+                    quantization_config=gptq_config,
+                )
+            else:
+                model = AutoModelForCausalLM.from_pretrained(
+                    checkpoint_path,
+                    attn_implementation="flash_attention_2" if use_flash_attn else None,
+                    torch_dtype=torch.bfloat16 if use_flash_attn else None,
+                )
 
-        device = "cuda" if torch.cuda.is_available() else None
-        print(f"Inferred device: {device}")
         model.to(device)
 
         return cls(model, tokenizer, device)
@@ -327,5 +366,9 @@ def main():
     print(f"Exported results to: {args.out_file}")
 
 
+def has_quantized_config(model_path: str):
+    return os.path.exists(os.path.join(model_path, "quantize_config.json"))
+
+
 if __name__ == "__main__":
     main()
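A minimal usage sketch (not part of the patch) of the qLoRA/GPTQ load path the diff wires into run_inference.py: it assumes the new gptq-dev extra is installed (e.g. pip install ".[gptq-dev]"), that a CUDA device is available, and that the two paths below are placeholders for a GPTQ-quantized base model directory (one that contains quantize_config.json) and a qLoRA adapter checkpoint.

    # Hedged sketch of the GPTQ load path added above; both paths are
    # hypothetical placeholders, not values taken from the patch.
    import os

    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

    base_model_path = "path/to/gptq-quantized-base-model"  # placeholder
    adapter_path = "path/to/qlora-adapter-checkpoint"      # placeholder

    # Same signal has_quantized_config() checks: GPTQ models ship quantize_config.json.
    assert os.path.exists(os.path.join(base_model_path, "quantize_config.json"))

    # 4-bit GPTQ with the exllama v2 kernel; this kernel requires torch.float16.
    gptq_config = GPTQConfig(bits=4, exllama_config={"version": 2})
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        device_map="cuda",
        torch_dtype=torch.float16,
        quantization_config=gptq_config,
    )

    # Attach the qLoRA adapter on top of the quantized base model.
    tokenizer = AutoTokenizer.from_pretrained(adapter_path)
    model = PeftModel.from_pretrained(base_model, adapter_path)
    model.eval()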