utilities to post process checkpoint for LoRA
Signed-off-by: Sukriti-Sharma4 <[email protected]>
Ssukriti committed Sep 10, 2024
1 parent 5946949 commit fa42c73
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions tuning/utils/merge_model_utils.py
@@ -102,3 +102,68 @@ def fetch_base_model_from_checkpoint(checkpoint_model: str) -> str:
"Base model adapter config exists, but has no base_model_name_or_path!"
)
return adapter_dict["base_model_name_or_path"]

def _copy_files_to_directory(
    src: str, dest: str, exclude_files: list[str] = ["adapter_model.safetensors"]
):
    """Copy every regular file from src into dest, skipping names in exclude_files."""
    import shutil

    src_files = os.listdir(src)
    for file_name in src_files:
        if file_name in exclude_files:
            continue
        full_file_name = os.path.join(src, file_name)
        # Copy regular files only; subdirectories are left behind.
        if os.path.isfile(full_file_name):
            shutil.copy(full_file_name, dest)

def post_process_vLLM_adapters_new_tokens(
    path_to_checkpoint: str, modified_checkpoint_path: str = None
):
    """Split embedding weights for newly added tokens out of a LoRA adapter
    checkpoint so that the adapter can be loaded by vLLM."""
    from safetensors import safe_open
    from safetensors.torch import save_file

    # Default to post-processing the checkpoint in place.
    if not modified_checkpoint_path:
        modified_checkpoint_path = path_to_checkpoint

    # added_tokens.json maps each newly added token to its vocabulary index;
    # collect the indexes so the matching embedding rows can be sliced out below.
    sorted_token_indexes = []
    if os.path.isfile(os.path.join(path_to_checkpoint, "added_tokens.json")):
        with open(os.path.join(path_to_checkpoint, "added_tokens.json"), "r") as fp:
            added_tokens = json.load(fp)
            sorted_token_indexes = sorted(added_tokens.values())

    with safe_open(
        os.path.join(path_to_checkpoint, "adapter_model.safetensors"), framework="pt"
    ) as f:
        new_embeddings = {}
        adapters = {}
        embeddings_weights_in_adapters = False
        # First pass: detect whether the adapter file carries resized embedding
        # weights at all.
        for k in f.keys():
            if "lm_head.weight" in k or "embed_tokens.weight" in k:
                embeddings_weights_in_adapters = True
                # Embedding weights present without any added tokens means the
                # embeddings were resized without adding new tokens; that case
                # cannot be post-processed for vLLM.
                if not sorted_token_indexes:
                    raise NotImplementedError(
                        "Seems like embeddings are resized without adding new tokens. "
                        "Cannot be post-processed to load on vLLM."
                    )

        if embeddings_weights_in_adapters:
            # Second pass: split the embedding rows for the new tokens out of the
            # adapter weights so vLLM can load them separately.
            for k in f.keys():
                if "lm_head.weight" in k:
                    lm_head = f.get_tensor(k)
                    # Keep only the rows for the newly added tokens; the slice end
                    # is +1 so the last added index is included.
                    new_embeddings["output_embeddings"] = lm_head[
                        sorted_token_indexes[0] : sorted_token_indexes[-1] + 1
                    ]
                elif "embed_tokens.weight" in k:
                    embed_tokens = f.get_tensor(k)
                    new_embeddings["input_embeddings"] = embed_tokens[
                        sorted_token_indexes[0] : sorted_token_indexes[-1] + 1
                    ]
                else:
                    # Everything else is a regular LoRA adapter tensor.
                    adapters[k] = f.get_tensor(k)

            # Only rewrite files when post-processing actually happened, so an
            # unchanged checkpoint is never overwritten with empty tensors.
            save_file(
                new_embeddings,
                os.path.join(modified_checkpoint_path, "new_embeddings.safetensors"),
            )
            save_file(
                adapters,
                os.path.join(modified_checkpoint_path, "adapter_model.safetensors"),
            )

            # When writing to a new directory, bring along the rest of the
            # checkpoint (tokenizer files, configs, ...) minus the original
            # adapter weights, which were rewritten above.
            if modified_checkpoint_path != path_to_checkpoint:
                _copy_files_to_directory(path_to_checkpoint, modified_checkpoint_path)
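
A minimal usage sketch of the new utility, assuming the module is importable as tuning.utils.merge_model_utils; the checkpoint paths are hypothetical placeholders, not paths from this repository:

from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens

# Hypothetical paths for illustration only.
post_process_vLLM_adapters_new_tokens(
    path_to_checkpoint="/tmp/lora-ckpt/checkpoint-100",
    modified_checkpoint_path="/tmp/lora-ckpt/checkpoint-100-vllm",
)
# The output directory now holds new_embeddings.safetensors plus an
# adapter_model.safetensors without the resized embedding rows, ready for vLLM.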

