Skip to content

Commit

Permalink
post processing script
Browse files Browse the repository at this point in the history
Signed-off-by: Sukriti-Sharma4 <[email protected]>
  • Loading branch information
Ssukriti committed Sep 20, 2024
1 parent 2c0c206 commit 5e4b796
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions scripts/post_process_adapters_vLLM.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Standard
import argparse
import json
import os


# Local
from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens

### Main & arg parsing
def main():
    """Parse CLI arguments and post-process tuned adapters for vLLM.

    Reads ``num_added_tokens`` from ``added_tokens_info.json`` inside
    ``--model_path`` and forwards it, together with ``--model_path`` and
    ``--output_model_path``, to ``post_process_vLLM_adapters_new_tokens``.

    Raises:
        SystemExit: with status 1 when ``added_tokens_info.json`` cannot be
            found under ``--model_path`` (a message is printed first), or
            with argparse's own status when the arguments are invalid.
        KeyError: when the JSON file has no ``num_added_tokens`` entry.
    """
    parser = argparse.ArgumentParser(
        description="Post processes adapters due to addition of new tokens, as needed by vLLM"
    )
    parser.add_argument(
        "--model_path",
        help=(
            "Path to tuned model containing either one or multiple checkpoints "
            "Path should have file added_tokens_info.json produced by tuning "
            "Hint: This will be either output_dir specified while tuning or save_model_dir"
        ),
        required=True,
    )
    parser.add_argument(
        "--output_model_path",
        help=(
            "Output directory where post-processed artifacts will be stored. "
            "If not provided, artifacts will be modified in place"
        ),
        default=None,
    )
    args = parser.parse_args()

    added_tokens_info_path = os.path.join(args.model_path, "added_tokens_info.json")
    try:
        with open(added_tokens_info_path, encoding="utf-8") as json_data:
            # json.load reads from a file object; json.loads only accepts
            # str/bytes and would raise TypeError here.
            added_tokens_info = json.load(json_data)
    except FileNotFoundError:
        # Without the token info there is nothing to post-process; stop here
        # instead of falling through with num_added_tokens undefined.
        print("file added_tokens_info.json not in model_path. Cannot post-processes")
        raise SystemExit(1) from None
    num_added_tokens = added_tokens_info["num_added_tokens"]

    post_process_vLLM_adapters_new_tokens(
        args.model_path, args.output_model_path, num_added_tokens
    )


if __name__ == "__main__":
    main()

0 comments on commit 5e4b796

Please sign in to comment.