From b15a07b5e67be7499c51716ff5132b63193d209f Mon Sep 17 00:00:00 2001
From: Abhishek Maurya <124327945+Abhishek-TAMU@users.noreply.github.com>
Date: Thu, 12 Sep 2024 21:15:01 -0400
Subject: [PATCH] fix: remove lm_head post processing (#333)

* fix: Removal of lm head hack

Signed-off-by: Abhishek

* set fms_accelerate to true by default

Signed-off-by: Anh Uong

---------

Signed-off-by: Abhishek
Signed-off-by: Anh Uong
Co-authored-by: Anh Uong
Signed-off-by: Angel Luu
---
 build/Dockerfile           |  2 +-
 build/accelerate_launch.py | 92 --------------------------------------
 2 files changed, 1 insertion(+), 93 deletions(-)

diff --git a/build/Dockerfile b/build/Dockerfile
index 4bd9cab6a..ffae818da 100644
--- a/build/Dockerfile
+++ b/build/Dockerfile
@@ -21,7 +21,7 @@ ARG PYTHON_VERSION=3.11
 ARG WHEEL_VERSION=""
 ## Enable Aimstack if requested via ENABLE_AIM set to "true"
 ARG ENABLE_AIM=false
-ARG ENABLE_FMS_ACCELERATION=false
+ARG ENABLE_FMS_ACCELERATION=true
 
 ## Base Layer ##################################################################
 FROM registry.access.redhat.com/ubi9/ubi:${BASE_UBI_IMAGE_TAG} AS base
diff --git a/build/accelerate_launch.py b/build/accelerate_launch.py
index d7753728c..50d8eef0c 100644
--- a/build/accelerate_launch.py
+++ b/build/accelerate_launch.py
@@ -24,18 +24,13 @@
 import sys
 import traceback
 from pathlib import Path
-import json
 
 # Third Party
 from accelerate.commands.launch import launch_command
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from peft import PeftModel
-from torch import bfloat16
 
 # Local
 from build.utils import (
     process_accelerate_launch_args,
-    get_highest_checkpoint,
 )
 from tuning.utils.config_utils import get_json_config
 from tuning.utils.error_logging import (
@@ -43,18 +38,10 @@
     USER_ERROR_EXIT_CODE,
     INTERNAL_ERROR_EXIT_CODE,
 )
-from tuning.data import tokenizer_data_utils
 
 ERROR_LOG = "/dev/termination-log"
 
 
-def get_base_model_from_adapter_config(adapter_config):
-    """Given path to adapter_config.json file, returns the base model name"""
-    with open(adapter_config, "r", encoding="utf-8") as config_file:
-        adapter_config = json.load(config_file)
-    return adapter_config.get("base_model_name_or_path")
-
-
 def main():
     if not os.getenv("TERMINATION_LOG_FILE"):
         os.environ["TERMINATION_LOG_FILE"] = ERROR_LOG
@@ -128,85 +115,6 @@
         write_termination_log(f"Unhandled exception during training. {e}")
         sys.exit(INTERNAL_ERROR_EXIT_CODE)
 
-    # remove lm_head from granite with llama arch models
-    try:
-        checkpoint_dir = job_config.get("save_model_dir")
-        if not checkpoint_dir:
-            checkpoint_dir = os.path.join(
-                output_dir, get_highest_checkpoint(output_dir)
-            )
-
-        use_flash_attn = job_config.get("use_flash_attn", True)
-        adapter_config_path = os.path.join(checkpoint_dir, "adapter_config.json")
-        tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
-
-        if os.path.exists(adapter_config_path):
-            base_model_path = get_base_model_from_adapter_config(adapter_config_path)
-            base_model = AutoModelForCausalLM.from_pretrained(
-                base_model_path,
-                attn_implementation="flash_attention_2" if use_flash_attn else None,
-                torch_dtype=bfloat16 if use_flash_attn else None,
-            )
-
-            # since the peft library (PEFTModelForCausalLM) does not handle cases
-            # where the model's layers are modified, in our case the embedding layer
-            # is modified, so we resize the backbone model's embedding layer with our own
-            # utility before passing it along to load the PEFT model.
-            tokenizer_data_utils.tokenizer_and_embedding_resize(
-                {}, tokenizer=tokenizer, model=base_model
-            )
-            model = PeftModel.from_pretrained(
-                base_model,
-                checkpoint_dir,
-                attn_implementation="flash_attention_2" if use_flash_attn else None,
-                torch_dtype=bfloat16 if use_flash_attn else None,
-            )
-        else:
-            model = AutoModelForCausalLM.from_pretrained(
-                checkpoint_dir,
-                attn_implementation="flash_attention_2" if use_flash_attn else None,
-                torch_dtype=bfloat16 if use_flash_attn else None,
-            )
-
-        model_arch = model.config.model_type
-        # check that it is a granite model with llama architecture with tied weights
-        # ie. lm_head is duplicate of embeddings
-
-        # a fine tuned model will have params_dict.get("model.embed_tokens.weight")
-        # a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight")
-        # a lora adapter has params_dict.get("base_model.model.model.embed_tokens.weight")
-        if model_arch == "llama" and hasattr(model, "lm_head"):
-            if (
-                # lora tuned model has an addt model layer
-                (
-                    hasattr(model.model, "model")
-                    and model.lm_head.weight.untyped_storage().data_ptr()
-                    == model.model.model.embed_tokens.weight.untyped_storage().data_ptr()
-                )
-                # prompt tuned model or fine tuned model
-                or (
-                    hasattr(model.model, "embed_tokens")
-                    and model.lm_head.weight.untyped_storage().data_ptr()
-                    == model.model.embed_tokens.weight.untyped_storage().data_ptr()
-                )
-            ):
-
-                logging.info("Removing lm_head from checkpoint")
-                del model.lm_head.weight
-
-                if hasattr(model, "lm_head.weight"):
-                    logging.warning("Failed to delete lm_head.weight from model")
-
-                logging.info("Saving checkpoint to %s", output_dir)
-                model.save_pretrained(checkpoint_dir)
-                # save tokenizer with model
-                tokenizer.save_pretrained(checkpoint_dir)
-
-    except Exception as e:  # pylint: disable=broad-except
-        logging.error(traceback.format_exc())
-        write_termination_log(f"Exception encountered removing lm_head from model: {e}")
-        sys.exit(INTERNAL_ERROR_EXIT_CODE)
-
     # The .complete file will signal to users that we are finished copying
     # files over
     if os.path.exists(output_dir):
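Reviewer note (addendum, not part of the patch): the deleted post-processing keyed on whether lm_head shared storage with the embedding table. The snippet below is a minimal, standalone sketch of that tied-weights check for inspecting a saved checkpoint; it assumes transformers and a recent PyTorch, and the helper name has_tied_lm_head and the checkpoint path are placeholders rather than code from this repository.

# Standalone illustration: detect whether the output head (lm_head) shares
# storage with the input embeddings -- the tied-weights condition the removed
# block tested via untyped_storage().data_ptr().
from transformers import AutoModelForCausalLM


def has_tied_lm_head(model) -> bool:
    """Return True if the output head and the input embeddings share storage."""
    lm_head = model.get_output_embeddings()
    embeddings = model.get_input_embeddings()
    if lm_head is None or embeddings is None:
        return False
    return (
        lm_head.weight.untyped_storage().data_ptr()
        == embeddings.weight.untyped_storage().data_ptr()
    )


if __name__ == "__main__":
    # Placeholder path; point this at a real model name or checkpoint directory.
    model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")
    print("lm_head tied to embeddings:", has_tied_lm_head(model))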