Commit 0fa3dac
improve code comments
Signed-off-by: Sukriti-Sharma4 <[email protected]>
Ssukriti committed Sep 10, 2024
1 parent e5e4c27 commit 0fa3dac
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions tuning/utils/merge_model_utils.py
@@ -103,9 +103,8 @@ def fetch_base_model_from_checkpoint(checkpoint_model: str) -> str:
     )
     return adapter_dict["base_model_name_or_path"]
 
-def _copy_files_to_directory(src: str, dest: str, exclude_files: list[str] = ["adapter_model.safetensors"]):
+def _copy_files_to_directory(src: str, dest: str, exclude_files: list[str] = []):
     import shutil
-
     src_files = os.listdir(src)
     for file_name in src_files:
         if file_name in exclude_files:
@@ -119,10 +118,11 @@ def post_process_vLLM_adapters_new_tokens(path_to_checkpoint: str, modified_checkpoint_path
     from safetensors import safe_open
     from safetensors.torch import save_file
 
-
+    # if not set, original checkpoint will be modified
     if not modified_checkpoint_path:
         modified_checkpoint_path = path_to_checkpoint
 
+    # Get all values of new token indexes
     sorted_token_indexes = []
     if os.path.isfile(os.path.join(path_to_checkpoint, "added_tokens.json")):
         with open(os.path.join(path_to_checkpoint, "added_tokens.json"), "r") as fp:
@@ -133,37 +133,46 @@ def post_process_vLLM_adapters_new_tokens(path_to_checkpoint: str, modified_checkpoint_path
         new_embeddings = {}
         adapters = {}
         embeddings_weights_in_adapters = False
+        # Quickly check if post-processing is needed by checking adapters file for weights
         for k in f.keys():
             if 'lm_head.weight' in k or 'embed_tokens.weight' in k:
                 embeddings_weights_in_adapters = True
                 if len(sorted_token_indexes) == 0:
                     raise NotImplementedError("Seems like embeddings are resized without adding new tokens. \
-                        Cannot be post-processed to load on vLLM.")
+                        Cannot be post-processed to load on vLLM. Try setting \
+                        parameter `embedding_size_multiple_of` to 1")

+        # Post-processing is needed to copy out new vectors
         if embeddings_weights_in_adapters:
             for k in f.keys():
                 if 'lm_head.weight' in k:
                     lm_head = f.get_tensor(k)
+                    # pull out tensor values of new tokens
                     if len(sorted_token_indexes) == 1:
                         new_output_embeddings = lm_head[sorted_token_indexes[0]:sorted_token_indexes[0] + 1]
                     elif len(sorted_token_indexes) > 1:
                         new_output_embeddings = lm_head[sorted_token_indexes[0]:sorted_token_indexes[-1] + 1]
+                    # vLLM requires renaming to output_embeddings
                     new_embeddings['output_embeddings'] = new_output_embeddings

                 elif 'embed_tokens.weight' in k:
                     embed_tokens = f.get_tensor(k)
+                    # pull out tensor values of new tokens
                     if len(sorted_token_indexes) == 1:
                         new_input_embeddings = embed_tokens[sorted_token_indexes[0]:sorted_token_indexes[0] + 1]
                     elif len(sorted_token_indexes) > 1:
                         new_input_embeddings = embed_tokens[sorted_token_indexes[0]:sorted_token_indexes[-1] + 1]
+                    # vLLM requires renaming to input_embeddings
                     new_embeddings['input_embeddings'] = new_input_embeddings
                 else:
+                    # Retain all other weights in adapters.safetensors
                     adapters[k] = f.get_tensor(k)

             save_file(new_embeddings, os.path.join(modified_checkpoint_path, "new_embeddings.safetensors"))
             save_file(adapters, os.path.join(modified_checkpoint_path, "adapter_model.safetensors"))

+    # copy out remaining files to desired path
     if modified_checkpoint_path != path_to_checkpoint:
-        _copy_files_to_directory(path_to_checkpoint, modified_checkpoint_path)
+        _copy_files_to_directory(path_to_checkpoint, modified_checkpoint_path, exclude_files=["adapter_model.safetensors"])


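For reference, a minimal usage sketch of the post-processing entry point touched in this diff; the checkpoint paths are hypothetical placeholders, and the call assumes the checkpoint directory contains adapter_model.safetensors plus an added_tokens.json describing the new tokens:

    from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens

    # Copies the new-token rows of lm_head/embed_tokens out of
    # adapter_model.safetensors into new_embeddings.safetensors, keeping the
    # remaining adapter weights, so vLLM can load the LoRA checkpoint.
    post_process_vLLM_adapters_new_tokens(
        path_to_checkpoint="/path/to/tuned/checkpoint",         # placeholder
        modified_checkpoint_path="/path/to/postprocessed/dir",  # placeholder
    )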