fix: remove lm_head post processing #333

Merged: 3 commits, Sep 13, 2024
build/Dockerfile (2 changes: 1 addition & 1 deletion)
@@ -21,7 +21,7 @@ ARG PYTHON_VERSION=3.11
ARG WHEEL_VERSION=""
## Enable Aimstack if requested via ENABLE_AIM set to "true"
ARG ENABLE_AIM=false
-ARG ENABLE_FMS_ACCELERATION=false
+ARG ENABLE_FMS_ACCELERATION=true

## Base Layer ##################################################################
FROM registry.access.redhat.com/ubi9/ubi:${BASE_UBI_IMAGE_TAG} AS base
build/accelerate_launch.py (92 changes: 0 additions & 92 deletions)
@@ -24,37 +24,24 @@
import sys
import traceback
from pathlib import Path
import json

# Third Party
from accelerate.commands.launch import launch_command
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from torch import bfloat16

# Local
from build.utils import (
process_accelerate_launch_args,
get_highest_checkpoint,
)
from tuning.utils.config_utils import get_json_config
from tuning.utils.error_logging import (
write_termination_log,
USER_ERROR_EXIT_CODE,
INTERNAL_ERROR_EXIT_CODE,
)
from tuning.data import tokenizer_data_utils

ERROR_LOG = "/dev/termination-log"


def get_base_model_from_adapter_config(adapter_config):
"""Given path to adapter_config.json file, returns the base model name"""
with open(adapter_config, "r", encoding="utf-8") as config_file:
adapter_config = json.load(config_file)
return adapter_config.get("base_model_name_or_path")


def main():
if not os.getenv("TERMINATION_LOG_FILE"):
os.environ["TERMINATION_LOG_FILE"] = ERROR_LOG
@@ -128,85 +115,6 @@ def main():
write_termination_log(f"Unhandled exception during training. {e}")
sys.exit(INTERNAL_ERROR_EXIT_CODE)

# remove lm_head from granite with llama arch models
try:
checkpoint_dir = job_config.get("save_model_dir")
if not checkpoint_dir:
checkpoint_dir = os.path.join(
output_dir, get_highest_checkpoint(output_dir)
)

use_flash_attn = job_config.get("use_flash_attn", True)
adapter_config_path = os.path.join(checkpoint_dir, "adapter_config.json")
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

if os.path.exists(adapter_config_path):
base_model_path = get_base_model_from_adapter_config(adapter_config_path)
base_model = AutoModelForCausalLM.from_pretrained(
base_model_path,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=bfloat16 if use_flash_attn else None,
)

# since the peft library (PEFTModelForCausalLM) does not handle cases
# where the model's layers are modified, in our case the embedding layer
# is modified, so we resize the backbone model's embedding layer with our own
# utility before passing it along to load the PEFT model.
tokenizer_data_utils.tokenizer_and_embedding_resize(
{}, tokenizer=tokenizer, model=base_model
)
model = PeftModel.from_pretrained(
base_model,
checkpoint_dir,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=bfloat16 if use_flash_attn else None,
)
else:
model = AutoModelForCausalLM.from_pretrained(
checkpoint_dir,
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=bfloat16 if use_flash_attn else None,
)

model_arch = model.config.model_type
# check that it is a granite model with llama architecture with tied weights
# ie. lm_head is duplicate of embeddings

# a fine tuned model will have params_dict.get("model.embed_tokens.weight")
# a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight")
# a lora adapter has params_dict.get("base_model.model.model.embed_tokens.weight")
if model_arch == "llama" and hasattr(model, "lm_head"):
if (
# lora tuned model has an addt model layer
(
hasattr(model.model, "model")
and model.lm_head.weight.untyped_storage().data_ptr()
== model.model.model.embed_tokens.weight.untyped_storage().data_ptr()
)
# prompt tuned model or fine tuned model
or (
hasattr(model.model, "embed_tokens")
and model.lm_head.weight.untyped_storage().data_ptr()
== model.model.embed_tokens.weight.untyped_storage().data_ptr()
)
):

logging.info("Removing lm_head from checkpoint")
del model.lm_head.weight

if hasattr(model, "lm_head.weight"):
logging.warning("Failed to delete lm_head.weight from model")

logging.info("Saving checkpoint to %s", output_dir)
model.save_pretrained(checkpoint_dir)
# save tokenizer with model
tokenizer.save_pretrained(checkpoint_dir)

except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log(f"Exception encountered removing lm_head from model: {e}")
sys.exit(INTERNAL_ERROR_EXIT_CODE)

# The .complete file will signal to users that we are finished copying
# files over
if os.path.exists(output_dir):
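For reference, the block removed above resized the base model's embeddings (via the repo's tokenizer_and_embedding_resize utility) before attaching the PEFT adapter, because PeftModel does not account for an embedding layer that was grown during tuning. Below is a minimal sketch of that pattern using only public transformers/peft APIs; BASE_MODEL and ADAPTER_DIR are hypothetical paths, and resize_token_embeddings stands in for the repo utility.

# Sketch: grow the base model's embedding table to match the tuned tokenizer
# before loading the adapter. Paths below are hypothetical placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL = "/path/to/base-model"            # hypothetical
ADAPTER_DIR = "/path/to/adapter-checkpoint"   # hypothetical

tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR)
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)

# If tokens were added during tuning, the saved adapter expects a larger
# embedding table than the stock base model provides.
if len(tokenizer) > base_model.get_input_embeddings().weight.shape[0]:
    base_model.resize_token_embeddings(len(tokenizer))

model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)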
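The core post-processing this PR removes checked whether a granite (llama-architecture) checkpoint's lm_head shared storage with its embedding table and, if so, deleted lm_head before re-saving. The following is a minimal sketch of that check with public transformers/torch APIs; CHECKPOINT_DIR is a hypothetical path, and the pointer comparison mirrors the deleted code above.

# Sketch: drop a tied lm_head from a saved checkpoint, as the removed block did.
from transformers import AutoModelForCausalLM

CHECKPOINT_DIR = "/path/to/tuned-checkpoint"  # hypothetical

model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_DIR)
if model.config.model_type == "llama" and hasattr(model, "lm_head"):
    embed_tokens = model.get_input_embeddings()
    # Tied weights share one underlying storage, so the data pointers match.
    if (
        model.lm_head.weight.untyped_storage().data_ptr()
        == embed_tokens.weight.untyped_storage().data_ptr()
    ):
        # Drop the duplicate lm_head and rely on weight tying at load time.
        del model.lm_head.weight
        model.save_pretrained(CHECKPOINT_DIR)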