Commit: Adapt weight conversion script
manuelburger committed Jul 3, 2024
1 parent 1dd8036 commit a9f30c2
Showing 1 changed file with 38 additions and 5 deletions.
examples/llama/convert_nanotron_to_hf.py (38 additions, 5 deletions)
@@ -2,8 +2,11 @@
 Converts a nanotron model to HF format
 Command:
 torchrun --nproc_per_node=1 convert_nanotron_to_hf.py --checkpoint_path=nanotron-path --save_path=hf-path
+Modified for Petagraph usage.
 """
 
+import os
+import json
 from argparse import ArgumentParser
 from pathlib import Path
@@ -128,25 +131,55 @@ def check_converted_model_generation(save_path: Path):
"""Loads a huggingface model and tokenizer from `save_path` and
performs a dummy text generation."""

tokenizer = AutoTokenizer.from_pretrained(save_path)
input_ids = tokenizer(TEST_PROMPT, return_tensors="pt")["input_ids"].cuda()
print("Inputs:", tokenizer.batch_decode(input_ids))
# tokenizer = AutoTokenizer.from_pretrained(save_path)
# input_ids = tokenizer(TEST_PROMPT, return_tensors="pt")["input_ids"].cuda()
# print("Inputs:", tokenizer.batch_decode(input_ids))

# Vocabulary is defined in the run_train.py right now
# We need to write a proper tokenizer when we're getting serious
VOCABULARY = {
"BOS": 0, "EOS": 1, "PAD": 2, "UNK": 3,
"A": 4, "C": 5, "G": 6, "T": 7
}
INV_VOCABULARY = {v: k for k, v in VOCABULARY.items()}

TEST_SEQUENCE = "ACGTACGT"
test_sequence = [VOCABULARY["BOS"]] + [VOCABULARY[char] for char in TEST_SEQUENCE]
input_ids = torch.tensor(test_sequence).unsqueeze(0).cuda()
print(f"Test sequence: {TEST_SEQUENCE}")
print(f"Test sequence (converted): {input_ids}")

model = LlamaForCausalLM.from_pretrained(save_path).cuda().bfloat16()
out = model.generate(input_ids, max_new_tokens=100)
print("Generation (converted): ", tokenizer.batch_decode(out))
decoded = [INV_VOCABULARY[token] for token in out[0].tolist()]
print("Generation: ", out)
print("Generation (decoded): ", decoded)

# print("Generation (converted): ", tokenizer.batch_decode(out))



if __name__ == "__main__":
parser = ArgumentParser(description="Convert Nanotron weights to HF format")
parser.add_argument("--checkpoint_path", type=Path, default="llama-7b", help="Path to the checkpoint")
parser.add_argument("--save_path", type=Path, default="llama-7b-hf", help="Path to save the HF model")
parser.add_argument("--tokenizer_name", type=str, default="meta-llama/Llama-2-7b-chat-hf")
parser.add_argument("--set_single_gpu_env", action="store_true", help="Set the environment variables for single gpu", default=False)
args = parser.parse_args()

print(f"Converting Nanotron model in {args.checkpoint_path} to HF format and saving to {args.save_path}")
if args.set_single_gpu_env:
print("Setting environment variables for single GPU")
torch.cuda.set_device(0)

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
os.environ["WORLD_SIZE"] = "1"

# Convert Nanotron model to HF format.
convert_checkpoint_and_save(
checkpoint_path=args.checkpoint_path, save_path=args.save_path, tokenizer_name=args.tokenizer_name
checkpoint_path=args.checkpoint_path, save_path=args.save_path, tokenizer_name=None
)

# Check if the conversion was successful by generating some text.
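The new generation check hard-codes the DNA vocabulary and notes that a proper tokenizer is still missing. Below is a minimal stand-in sketch mirroring the eight-token mapping from the diff; the DnaCharTokenizer class and its method names are hypothetical, not part of this commit or of Petagraph.

class DnaCharTokenizer:
    """Character-level tokenizer over the fixed BOS/EOS/PAD/UNK + ACGT vocabulary."""

    def __init__(self):
        self.vocab = {
            "BOS": 0, "EOS": 1, "PAD": 2, "UNK": 3,
            "A": 4, "C": 5, "G": 6, "T": 7,
        }
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.special_tokens = {"BOS", "EOS", "PAD", "UNK"}

    def encode(self, sequence: str, add_bos: bool = True) -> list:
        # Map each base to its id; anything outside ACGT falls back to UNK.
        ids = [self.vocab.get(char, self.vocab["UNK"]) for char in sequence]
        return ([self.vocab["BOS"]] if add_bos else []) + ids

    def decode(self, ids, skip_special_tokens: bool = True) -> str:
        # Map ids back to tokens, optionally dropping the special tokens.
        tokens = [self.inv_vocab[i] for i in ids]
        if skip_special_tokens:
            tokens = [t for t in tokens if t not in self.special_tokens]
        return "".join(tokens)

# Round trip: "ACGTACGT" -> [0, 4, 5, 6, 7, 4, 5, 6, 7] -> "ACGTACGT"
tok = DnaCharTokenizer()
assert tok.decode(tok.encode("ACGTACGT")) == "ACGTACGT"

Persisting such a tokenizer into save_path (for example via a PreTrainedTokenizerFast wrapper) would presumably let the commented-out AutoTokenizer.from_pretrained path in check_converted_model_generation be restored.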

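With the new --set_single_gpu_env flag, the script can presumably be launched with plain python instead of torchrun, since that branch fills in MASTER_ADDR, MASTER_PORT, and WORLD_SIZE itself. A plausible invocation, with placeholder paths following the command in the docstring:

python examples/llama/convert_nanotron_to_hf.py --checkpoint_path=nanotron-path --save_path=hf-path --set_single_gpu_env

Note that convert_checkpoint_and_save is now called with tokenizer_name=None, so no tokenizer is written alongside the converted weights; this is consistent with the AutoTokenizer lines being commented out in the generation check.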