Skip to content

Commit

Permalink
Merge branch 'main' into resume_train
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-TAMU committed Sep 16, 2024
2 parents effa188 + 5dd5494 commit d2a4f0b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 15 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<25", "ninja>=1.11.1.1,<2.0", "sci
flash-attn = ["flash-attn>=2.5.3,<3.0"]
aim = ["aim>=3.19.0,<4.0"]
fms-accel = ["fms-acceleration>=0.1"]
gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]


[tool.setuptools.packages.find]
Expand Down
73 changes: 58 additions & 15 deletions scripts/run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# Third Party
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
import torch

# Local
Expand Down Expand Up @@ -176,20 +176,45 @@ def load(
else {}
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
device = "cuda" if torch.cuda.is_available() else None
print(f"Inferred device: {device}")
# Apply the configs to the adapter config of this model; if no overrides
# are provided, then the context manager doesn't have any effect.
try:
with AdapterConfigPatcher(checkpoint_path, overrides):
try:
if base_model_name_or_path is None:
raise ValueError("base_model_name_or_path has to be passed")
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name_or_path,
attn_implementation="flash_attention_2"
if use_flash_attn
else None,
torch_dtype=torch.bfloat16 if use_flash_attn else None,
)

if (
has_quantized_config(base_model_name_or_path)
and device == "cuda"
):
# Using GPTQConfig from HF, avail params are here
# https://huggingface.co/docs/transformers/en/main_classes/quantization#transformers.GPTQConfig
# We only support 4-bit AutoGPTQ, so setting bits to 4
# setting exllama kernel to version 2 as it's a faster kernel
gptq_config = GPTQConfig(bits=4, exllama_config={"version": 2})

# Since we are using exllama kernel, we need torch.float16 as torch_dtype
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name_or_path,
attn_implementation="flash_attention_2"
if use_flash_attn
else None,
device_map=device,
torch_dtype=torch.float16,
quantization_config=gptq_config,
)
else:
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name_or_path,
attn_implementation="flash_attention_2"
if use_flash_attn
else None,
torch_dtype=torch.bfloat16 if use_flash_attn else None,
)

# since the peft library (PEFTModelForCausalLM) does not handle cases
# where the model's layers are modified, in our case the embedding layer
# is modified, so we resize the backbone model's embedding layer with our own
Expand All @@ -211,14 +236,28 @@ def load(
except FileNotFoundError:
print("No adapter config found! Loading as a merged model...")
# Unable to find the adapter config; fall back to loading as a merged model
model = AutoModelForCausalLM.from_pretrained(
checkpoint_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=torch.bfloat16 if use_flash_attn else None,
)
if has_quantized_config(checkpoint_path) and device == "cuda":
# Using GPTQConfig from HF, avail params are here
# https://huggingface.co/docs/transformers/en/main_classes/quantization#transformers.GPTQConfig
# We only support 4-bit AutoGPTQ, so setting bits to 4
# setting exllama kernel to version 2 as it's a faster kernel
gptq_config = GPTQConfig(bits=4, exllama_config={"version": 2})

# Since we are using exllama kernel, we need torch.float16 as torch_dtype
model = AutoModelForCausalLM.from_pretrained(
checkpoint_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
device_map=device,
torch_dtype=torch.float16,
quantization_config=gptq_config,
)
else:
model = AutoModelForCausalLM.from_pretrained(
checkpoint_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=torch.bfloat16 if use_flash_attn else None,
)

device = "cuda" if torch.cuda.is_available() else None
print(f"Inferred device: {device}")
model.to(device)
return cls(model, tokenizer, device)

Expand Down Expand Up @@ -327,5 +366,9 @@ def main():
print(f"Exported results to: {args.out_file}")


def has_quantized_config(model_path: str):
return os.path.exists(os.path.join(model_path, "quantize_config.json"))


if __name__ == "__main__":
main()

0 comments on commit d2a4f0b

Please sign in to comment.