From a9f30c251fc518dba4dae0c149ee3005de3d45bb Mon Sep 17 00:00:00 2001
From: Manuel Burger
Date: Wed, 3 Jul 2024 10:49:07 +0200
Subject: [PATCH] Adapt weight conversion script

---
 examples/llama/convert_nanotron_to_hf.py | 43 +++++++++++++++++++++---
 1 file changed, 38 insertions(+), 5 deletions(-)

diff --git a/examples/llama/convert_nanotron_to_hf.py b/examples/llama/convert_nanotron_to_hf.py
index e11b27da..dbdb0497 100644
--- a/examples/llama/convert_nanotron_to_hf.py
+++ b/examples/llama/convert_nanotron_to_hf.py
@@ -2,8 +2,11 @@
 Converts a nanotron model to HF format
 Command:
 torchrun --nproc_per_node=1 convert_nanotron_to_hf.py --checkpoint_path=nanotron-path --save_path=hf-path
+
+Modified for Petagraph usage.
 """
 
+import os
 import json
 from argparse import ArgumentParser
 from pathlib import Path
@@ -128,13 +131,32 @@ def check_converted_model_generation(save_path: Path):
     """Loads a huggingface model and tokenizer from `save_path` and performs a dummy text generation."""
-    tokenizer = AutoTokenizer.from_pretrained(save_path)
-    input_ids = tokenizer(TEST_PROMPT, return_tensors="pt")["input_ids"].cuda()
-    print("Inputs:", tokenizer.batch_decode(input_ids))
+    # tokenizer = AutoTokenizer.from_pretrained(save_path)
+    # input_ids = tokenizer(TEST_PROMPT, return_tensors="pt")["input_ids"].cuda()
+    # print("Inputs:", tokenizer.batch_decode(input_ids))
+
+    # Vocabulary is defined in the run_train.py right now
+    # We need to write a proper tokenizer when we're getting serious
+    VOCABULARY = {
+        "BOS": 0, "EOS": 1, "PAD": 2, "UNK": 3,
+        "A": 4, "C": 5, "G": 6, "T": 7
+    }
+    INV_VOCABULARY = {v: k for k, v in VOCABULARY.items()}
+
+    TEST_SEQUENCE = "ACGTACGT"
+    test_sequence = [VOCABULARY["BOS"]] + [VOCABULARY[char] for char in TEST_SEQUENCE]
+    input_ids = torch.tensor(test_sequence).unsqueeze(0).cuda()
+    print(f"Test sequence: {TEST_SEQUENCE}")
+    print(f"Test sequence (converted): {input_ids}")
 
     model = LlamaForCausalLM.from_pretrained(save_path).cuda().bfloat16()
     out = model.generate(input_ids, max_new_tokens=100)
-    print("Generation (converted): ", tokenizer.batch_decode(out))
+    decoded = [INV_VOCABULARY[token] for token in out[0].tolist()]
+    print("Generation: ", out)
+    print("Generation (decoded): ", decoded)
+
+    # print("Generation (converted): ", tokenizer.batch_decode(out))
+
 
 
 if __name__ == "__main__":
@@ -142,11 +164,22 @@ def check_converted_model_generation(save_path: Path):
     parser.add_argument("--checkpoint_path", type=Path, default="llama-7b", help="Path to the checkpoint")
     parser.add_argument("--save_path", type=Path, default="llama-7b-hf", help="Path to save the HF model")
     parser.add_argument("--tokenizer_name", type=str, default="meta-llama/Llama-2-7b-chat-hf")
+    parser.add_argument("--set_single_gpu_env", action="store_true", help="Set the environment variables for single gpu", default=False)
     args = parser.parse_args()
 
+    print(f"Converting Nanotron model in {args.checkpoint_path} to HF format and saving to {args.save_path}")
+    if args.set_single_gpu_env:
+        print("Setting environment variables for single GPU")
+        torch.cuda.set_device(0)
+
+        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "12355"
+        os.environ["WORLD_SIZE"] = "1"
+
     # Convert Nanotron model to HF format.
     convert_checkpoint_and_save(
-        checkpoint_path=args.checkpoint_path, save_path=args.save_path, tokenizer_name=args.tokenizer_name
+        checkpoint_path=args.checkpoint_path, save_path=args.save_path, tokenizer_name=None
     )
 
     # Check if the conversion was successful by generating some text.
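
Note on usage: the patch hard-codes the eight-token DNA vocabulary inside check_converted_model_generation instead of loading a HF tokenizer, and the conversion is now called with tokenizer_name=None. The sketch below shows how the converted checkpoint could be exercised standalone with that same vocabulary; the save path "petagraph-hf" and the encode/decode helpers are illustrative assumptions, not part of the patch.

    # Standalone sanity-check sketch (illustrative, not part of the patch).
    # Assumes the converted HF checkpoint was saved to "petagraph-hf" and uses
    # the same 8-token vocabulary hard-coded in check_converted_model_generation.
    import torch
    from transformers import LlamaForCausalLM

    VOCABULARY = {"BOS": 0, "EOS": 1, "PAD": 2, "UNK": 3, "A": 4, "C": 5, "G": 6, "T": 7}
    INV_VOCABULARY = {v: k for k, v in VOCABULARY.items()}

    def encode(sequence: str) -> torch.Tensor:
        # Prepend BOS and map each base to its token id; add a batch dimension of 1.
        ids = [VOCABULARY["BOS"]] + [VOCABULARY.get(ch, VOCABULARY["UNK"]) for ch in sequence]
        return torch.tensor(ids).unsqueeze(0)

    def decode(token_ids) -> list:
        # Map token ids back to symbols; ids outside the vocabulary fall back to "UNK".
        return [INV_VOCABULARY.get(int(t), "UNK") for t in token_ids]

    model = LlamaForCausalLM.from_pretrained("petagraph-hf").cuda().bfloat16()  # assumed save_path
    out = model.generate(encode("ACGTACGT").cuda(), max_new_tokens=32)
    print(decode(out[0].tolist()))

With the new flag, the conversion itself can be launched on a single GPU as in the docstring, e.g. torchrun --nproc_per_node=1 convert_nanotron_to_hf.py --checkpoint_path=nanotron-path --save_path=hf-path --set_single_gpu_env (paths are placeholders).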