refactor saving tokens metadata
Signed-off-by: Sukriti-Sharma4 <[email protected]>
Ssukriti committed Sep 20, 2024
1 parent c8d8f98 commit e799f7d
Showing 2 changed files with 17 additions and 21 deletions.
2 changes: 1 addition & 1 deletion tuning/data/tokenizer_data_utils.py
@@ -44,4 +44,4 @@ def tokenizer_and_embedding_resize(

input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
-return num_new_tokens
+return {"num_new_tokens": num_new_tokens, "new_embedding_size": embedding_size}
36 changes: 16 additions & 20 deletions tuning/sft_trainer.py
@@ -291,7 +291,7 @@ def train(

# TODO: lower priority but understand if resizing impacts inference quality and why its needed.
# It makes sense if we manipulate tokenizer that we also save it and provide it to inference.
-num_added_tokens = tokenizer_data_utils.tokenizer_and_embedding_resize(
+added_tokens_dict = tokenizer_data_utils.tokenizer_and_embedding_resize(
special_tokens_dict=special_tokens_dict,
tokenizer=tokenizer,
model=model,
@@ -411,8 +411,9 @@ def train(
)

trainer.train(resume_from_checkpoint)

-return trainer, num_added_tokens
+additional_metadata = {}
+additional_metadata["added_tokens_info"] = added_tokens_dict
+return trainer, additional_metadata


def save(path: str, trainer: SFTTrainer, log_level="WARNING"):
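A hedged, self-contained sketch of the new return contract shown in the two hunks above: the dict produced by the embedding resize is nested under an "added_tokens_info" key, so callers of train() unpack a (trainer, metadata) pair, and further metadata fields can be added later without changing the signature again. The train_sketch name and the stub values are illustrative, not from the repository.

def train_sketch():
    # Stand-in for the real tokenizer_and_embedding_resize(...) call inside train().
    added_tokens_dict = {"num_new_tokens": 2, "new_embedding_size": 32002}  # illustrative values

    additional_metadata = {}
    additional_metadata["added_tokens_info"] = added_tokens_dict
    return "trainer", additional_metadata  # the real train() returns an SFTTrainer instance


trainer, additional_train_info = train_sketch()
assert additional_train_info["added_tokens_info"]["num_new_tokens"] == 2
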
@@ -463,13 +464,7 @@ def get_parser():
choices=["pt", "lora", None, "none"],
default="none",
)
-parser.add_argument(
-"--post_process_vllm",
-type=bool,
-default=False,
-help="Bool to indicate if post processing of LoRA adapters for vLLM \
-is required.",
-)

parser.add_argument(
"--exp_metadata",
type=str,
@@ -644,7 +639,7 @@ def main():
combined_tracker_configs.aim_config = aim_config

try:
-trainer, num_added_tokens = train(
+trainer, additional_train_info = train(
model_args=model_args,
data_args=data_args,
train_args=training_args,
@@ -697,19 +692,20 @@ def main():
)
sys.exit(INTERNAL_ERROR_EXIT_CODE)

-# post process lora
-if post_process_vllm and isinstance(tune_config, peft_config.LoraConfig):
+if isinstance(tune_config, peft_config.LoraConfig):
try:
-checkpoint_dir = training_args.save_model_dir
-if checkpoint_dir:
-print(f"Post processing LoRA adapters in {checkpoint_dir}")
-post_process_vLLM_adapters_new_tokens(
-path_to_checkpoint=checkpoint_dir, num_added_tokens=num_added_tokens
-)
+if training_args.save_model_dir:
+# Write number of added tokens to artifacts
+with open(os.path.join(training_args.save_model_dir,'added_tokens_info.json'), 'w') as f:
+json.dump(additional_train_info["added_tokens_info"], f)
+if training_args.output_dir:
+# Write number of added tokens to artifacts
+with open(os.path.join(training_args.output_dir,'added_tokens_info.json'), 'w') as f:
+json.dump(additional_train_info["added_tokens_info"], f)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log(
f"Exception encountered while lora post-processing model: {e}"
f"Exception encountered when saving metadata with model artifacts: {e}"
)
sys.exit(INTERNAL_ERROR_EXIT_CODE)
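A hedged sketch of the persistence step added in the hunk above: for LoRA runs, the added-token info returned by train() is written to added_tokens_info.json in both the save_model_dir and the output_dir, so downstream consumers (for example, vLLM adapter post-processing) can find it next to the model artifacts. The helper name and the paths below are hypothetical.

import json
import os


def write_added_tokens_info(added_tokens_info, save_model_dir=None, output_dir=None):
    # Mirrors the diff: dump the same JSON into each destination directory that is set.
    for target_dir in (save_model_dir, output_dir):
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)  # the real flow assumes these dirs already exist
            with open(os.path.join(target_dir, "added_tokens_info.json"), "w") as f:
                json.dump(added_tokens_info, f)


write_added_tokens_info(
    {"num_new_tokens": 2, "new_embedding_size": 32002},  # illustrative values
    save_model_dir="/tmp/save_model_dir",
    output_dir="/tmp/output_dir",
)
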

