Skip to content

Commit

Permalink
fix: remove lm_head post processing (#333)
Browse files Browse the repository at this point in the history
* fix: Removal of lm head hack

Signed-off-by: Abhishek <[email protected]>

* set fms_accelerate to true by default

Signed-off-by: Anh Uong <[email protected]>

---------

Signed-off-by: Abhishek <[email protected]>
Signed-off-by: Anh Uong <[email protected]>
Co-authored-by: Anh Uong <[email protected]>
Signed-off-by: Angel Luu <[email protected]>
  • Loading branch information
2 people authored and aluu317 committed Sep 13, 2024
1 parent 427202f commit b15a07b
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 93 deletions.
2 changes: 1 addition & 1 deletion build/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ ARG PYTHON_VERSION=3.11
ARG WHEEL_VERSION=""
## Enable Aimstack if requested via ENABLE_AIM set to "true"
ARG ENABLE_AIM=false
ARG ENABLE_FMS_ACCELERATION=false
ARG ENABLE_FMS_ACCELERATION=true

## Base Layer ##################################################################
FROM registry.access.redhat.com/ubi9/ubi:${BASE_UBI_IMAGE_TAG} AS base
Expand Down
92 changes: 0 additions & 92 deletions build/accelerate_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,24 @@
import sys
import traceback
from pathlib import Path
import json

# Third Party
from accelerate.commands.launch import launch_command
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from torch import bfloat16

# Local
from build.utils import (
process_accelerate_launch_args,
get_highest_checkpoint,
)
from tuning.utils.config_utils import get_json_config
from tuning.utils.error_logging import (
write_termination_log,
USER_ERROR_EXIT_CODE,
INTERNAL_ERROR_EXIT_CODE,
)
from tuning.data import tokenizer_data_utils

ERROR_LOG = "/dev/termination-log"


def get_base_model_from_adapter_config(adapter_config):
"""Given path to adapter_config.json file, returns the base model name"""
with open(adapter_config, "r", encoding="utf-8") as config_file:
adapter_config = json.load(config_file)
return adapter_config.get("base_model_name_or_path")


def main():
if not os.getenv("TERMINATION_LOG_FILE"):
os.environ["TERMINATION_LOG_FILE"] = ERROR_LOG
Expand Down Expand Up @@ -128,85 +115,6 @@ def main():
write_termination_log(f"Unhandled exception during training. {e}")
sys.exit(INTERNAL_ERROR_EXIT_CODE)

# remove lm_head from granite with llama arch models
try:
checkpoint_dir = job_config.get("save_model_dir")
if not checkpoint_dir:
checkpoint_dir = os.path.join(
output_dir, get_highest_checkpoint(output_dir)
)

use_flash_attn = job_config.get("use_flash_attn", True)
adapter_config_path = os.path.join(checkpoint_dir, "adapter_config.json")
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

if os.path.exists(adapter_config_path):
base_model_path = get_base_model_from_adapter_config(adapter_config_path)
base_model = AutoModelForCausalLM.from_pretrained(
base_model_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=bfloat16 if use_flash_attn else None,
)

# since the peft library (PEFTModelForCausalLM) does not handle cases
# where the model's layers are modified, in our case the embedding layer
# is modified, so we resize the backbone model's embedding layer with our own
# utility before passing it along to load the PEFT model.
tokenizer_data_utils.tokenizer_and_embedding_resize(
{}, tokenizer=tokenizer, model=base_model
)
model = PeftModel.from_pretrained(
base_model,
checkpoint_dir,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=bfloat16 if use_flash_attn else None,
)
else:
model = AutoModelForCausalLM.from_pretrained(
checkpoint_dir,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=bfloat16 if use_flash_attn else None,
)

model_arch = model.config.model_type
# check that it is a granite model with llama architecture with tied weights
# ie. lm_head is duplicate of embeddings

# a fine tuned model will have params_dict.get("model.embed_tokens.weight")
# a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight")
# a lora adapter has params_dict.get("base_model.model.model.embed_tokens.weight")
if model_arch == "llama" and hasattr(model, "lm_head"):
if (
# lora tuned model has an addt model layer
(
hasattr(model.model, "model")
and model.lm_head.weight.untyped_storage().data_ptr()
== model.model.model.embed_tokens.weight.untyped_storage().data_ptr()
)
# prompt tuned model or fine tuned model
or (
hasattr(model.model, "embed_tokens")
and model.lm_head.weight.untyped_storage().data_ptr()
== model.model.embed_tokens.weight.untyped_storage().data_ptr()
)
):

logging.info("Removing lm_head from checkpoint")
del model.lm_head.weight

if hasattr(model, "lm_head.weight"):
logging.warning("Failed to delete lm_head.weight from model")

logging.info("Saving checkpoint to %s", output_dir)
model.save_pretrained(checkpoint_dir)
# save tokenizer with model
tokenizer.save_pretrained(checkpoint_dir)

except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log(f"Exception encountered removing lm_head from model: {e}")
sys.exit(INTERNAL_ERROR_EXIT_CODE)

# The .complete file will signal to users that we are finished copying
# files over
if os.path.exists(output_dir):
Expand Down

0 comments on commit b15a07b

Please sign in to comment.