Skip to content

Commit

Permalink
Put launch script code back
Browse files Browse the repository at this point in the history
Signed-off-by: Angel Luu <[email protected]>
  • Loading branch information
aluu317 committed Aug 30, 2024
1 parent 8254224 commit 6bde694
Showing 1 changed file with 6 additions and 23 deletions.
29 changes: 6 additions & 23 deletions build/accelerate_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@

# Third Party
from accelerate.commands.launch import launch_command
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from torch import bfloat16, float16
from torch import bfloat16

# Local
from build.utils import (
Expand Down Expand Up @@ -142,28 +142,11 @@ def main():

if os.path.exists(adapter_config_path):
base_model_path = get_base_model_from_adapter_config(adapter_config_path)

is_quantized = os.path.exists(
os.path.join(base_model_path, "quantize_config.json")
base_model = AutoModelForCausalLM.from_pretrained(
base_model_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=bfloat16 if use_flash_attn else None,
)
if is_quantized:
print("ANGEL QLORA DEBUG: this model is quantized")

gptq_config = GPTQConfig(bits=4, exllama_config={"version": 2})
base_model = AutoModelForCausalLM.from_pretrained(
base_model_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
device_map="auto",
torch_dtype=float16 if use_flash_attn else None,
quantization_config=gptq_config,
)
else:
print("ANGEL QLORA DEBUG: this model is NOT quantized")
base_model = AutoModelForCausalLM.from_pretrained(
base_model_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=bfloat16 if use_flash_attn else None,
)

# since the peft library (PEFTModelForCausalLM) does not handle cases
# where the model's layers are modified, in our case the embedding layer
Expand Down

0 comments on commit 6bde694

Please sign in to comment.