
The new CodeGen2 Release #23

czkoko opened this issue May 6, 2023 · 3 comments



czkoko commented May 6, 2023

Salesforce has released a new model, CodeGen2. It is capable of infilling and supports more programming languages.
https://huggingface.co/Salesforce/codegen2-1B
I tested codegen2-1B and the code quality has improved, but it seems the existing CodeGen conversion script is no longer suitable for CodeGen2.
The CodeGen2 > GPT-J > GGML conversion completes without any error, but the output is abnormal code.

For example, it returns:

<mask_52>...(<mask_1>hasportvt(porthandleport
,formhasvarformvarform(ODUCT(
((form((vchas<mask_52>,polyportvtpoly#(<mask_1>formoform(
in<mask_1>polyhandle(form,(...haswisetextpm,wisewise
(portformoform<mask_1>var(text(form/
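
For reference, the infilling mode relies on sentinel tokens like the <mask_1> above. A rough sketch of the prompt format as I understand it from the model card (illustrative only, not verified here):

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen2-1B")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen2-1B", trust_remote_code=True)

# The span to infill is replaced by <mask_1>; the prompt then ends with
# <|endoftext|><sep><mask_1> so the model generates the masked span.
prefix = "def hello_world():\n    "
suffix = "\n    return name\n"
prompt = prefix + "<mask_1>" + suffix + "<|endoftext|>" + "<sep>" + "<mask_1>"

inputs = tokenizer(prompt, return_tensors="pt")
out = model.generate(inputs.input_ids, max_length=128)
print(tokenizer.decode(out[0]))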

The following is the CodeGen2 > GPT-J conversion script I used:

#!/usr/bin/env python3

import torch
from transformers import GPTJForCausalLM, GPTJConfig
# Note: these need the git version of Transformers as of 7/22/2022
from transformers import AutoTokenizer, AutoModelForCausalLM

print("Creating tokenizer")
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen2-1B")

print('Loading CodeGen model')
cg_model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen2-1B", trust_remote_code=True)
cg_config = cg_model.config

# Create empty GPTJ model
print('Creating empty GPTJ model')
config = GPTJConfig(
    vocab_size=cg_config.vocab_size,
    n_positions=cg_config.n_positions,
    n_embd=cg_config.n_embd,
    n_layer=cg_config.n_layer,
    n_head=cg_config.n_head,
    rotary_dim=cg_config.rotary_dim,
    n_inner=cg_config.n_inner,
    activation_function=cg_config.activation_function,
    resid_pdrop=cg_config.resid_pdrop,
    embd_pdrop=cg_config.embd_pdrop,
    attn_pdrop=cg_config.attn_pdrop,
    layer_norm_epsilon=cg_config.layer_norm_epsilon,
    initializer_range=cg_config.initializer_range,
    scale_attn_weights=cg_config.scale_attn_weights,
    use_cache=cg_config.use_cache,
    bos_token_id=cg_config.bos_token_id,
    eos_token_id=cg_config.eos_token_id,
    torch_dtype=cg_config.torch_dtype,
)

# Fix tokenizer type
config.tokenizer_class = 'CodeGenTokenizer'

gptj_model = GPTJForCausalLM(config)
embed_dim = config.n_embd

# Sample input for validating the conversion went OK
inputs = tokenizer.encode('fun download(url string', return_tensors='pt')

def replace(model, weights, name):
    model.state_dict()[name].copy_(weights.detach())

def replace_by_name(dest_model, src_model, old_name, new_name):
    assert old_name in src_model.state_dict()
    assert new_name in dest_model.state_dict()
    replace(dest_model, src_model.state_dict()[old_name], new_name)

print('Converting...')
# Copy weights from CodeGen model
with torch.no_grad():
    cg_model.eval()
    gptj_model.eval()
    
    for name, param in cg_model.named_parameters():
        # Handle the qkv weights separately because we need to split them
        if 'qkv_proj' in name:
            qkv_proj = param.detach().clone()
            mp_num = 4 # number of cores on their TPU I guess?
            local_dim = embed_dim // mp_num
            # GPT-J and CodeGen slice up the qkv projection slightly differently.
            # After a great deal of pain, I figured out that this permutation on
            # the weights of the qkv_proj fixes it.
            base_permutation = [0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11]
            permutation = torch.cat([torch.arange(i*local_dim, (i+1)*local_dim) for i in base_permutation])
            # NB: we permute the *rows* here because the computation is xA.T
            new_qkv_proj = qkv_proj[permutation,:]
            # NB: the name QKV is misleading here; they are actually stored in
            #     the order QVK
            query, value, key = torch.split(new_qkv_proj, embed_dim, dim=0)
            replace(gptj_model, query, name.replace('qkv_proj', 'q_proj'))
            replace(gptj_model, key, name.replace('qkv_proj', 'k_proj'))
            replace(gptj_model, value, name.replace('qkv_proj', 'v_proj'))
        else:
            replace_by_name(gptj_model, cg_model, name, name)

    print('Conversion complete, running inference')
    cg_out = cg_model.generate(inputs, min_length=32, max_length=32, do_sample=False, pad_token_id=50256)
    gptj_out = gptj_model.generate(inputs, min_length=32, max_length=32, do_sample=False, pad_token_id=50256)
    print(cg_out[0])
    print(gptj_out[0])
    cg_dec, gptj_dec = tokenizer.batch_decode(torch.stack([cg_out,gptj_out]).squeeze())
    print("====== CodeGen ======")
    print(cg_dec)
    print("======  GPT-J  ======")
    print(gptj_dec)
    assert cg_dec == gptj_dec
    print("Saving model...")
    gptj_model.save_pretrained("gpt-j")
    tokenizer.save_pretrained("gpt-j")
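
For what it's worth, a quick way to sanity-check the saved model on its own (just a sketch: the prompt is arbitrary and it assumes the script above has already written ./gpt-j):

import torch
from transformers import AutoTokenizer, GPTJForCausalLM

tok = AutoTokenizer.from_pretrained("gpt-j")
converted = GPTJForCausalLM.from_pretrained("gpt-j")
prompt = tok.encode("def fibonacci(n):", return_tensors="pt")
with torch.no_grad():
    out = converted.generate(prompt, max_length=48, do_sample=False, pad_token_id=50256)
print(tok.decode(out[0]))
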
ravenscroftj (Owner) commented

hey there - yes I definitely plan to allow the use of the new CodeGen models and possibly others like SantaCoder.

The new models have a slightly different architecture from the original CodeGen models, which means some modification will be needed to get them to work - the best solution may be to directly implement the new architecture in GGML rather than converting via GPT-J - I need to do some exploration to work out the best path.


czkoko commented May 24, 2023

After some exploration, I have completed the following conversion script. It converts the original CodeGen2 model directly to GGML, with no need to convert to GPT-J first.

codegen2-1B runs successfully, but the output of codegen2-7B still seems to be abnormal.

import sys
import struct
import json
import torch
import numpy as np
from accelerate import init_empty_weights

from transformers import AutoModelForCausalLM, AutoTokenizer

def bytes_to_unicode():
    bs = (
        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

if len(sys.argv) < 2:
    print("Usage: codegen2-to-ggml.py codegen2-1B(dir)\n")
    sys.exit(1)

# output in the same directory as the model
dir_model = sys.argv[1]

with open(dir_model + "/vocab.json", "r", encoding="utf8") as f:
    encoder = json.load(f)

with open(dir_model + "/added_tokens.json", "r") as f:
    encoder_added = json.load(f)

with open(dir_model + "/config.json", "r") as f:
    hparams = json.load(f)

ftype = 0
fname_out = sys.argv[1] + "/ggml-model-f32.bin"

model = AutoModelForCausalLM.from_pretrained(dir_model, trust_remote_code=True, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(dir_model)

list_vars = model.state_dict()

fout = open(fname_out, "wb")

fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
fout.write(struct.pack("i", hparams['vocab_size']))
fout.write(struct.pack("i", hparams["n_positions"]))
fout.write(struct.pack("i", hparams["n_embd"]))
fout.write(struct.pack("i", hparams["n_head"]))
fout.write(struct.pack("i", hparams["n_layer"]))
fout.write(struct.pack("i", hparams["rotary_dim"]))
fout.write(struct.pack("i", ftype))

byte_encoder = bytes_to_unicode()
byte_decoder = {v:k for k, v in byte_encoder.items()}

fout.write(struct.pack("i", hparams['vocab_size']))

for word, idx in sorted(tokenizer.vocab.items(), key=lambda x: x[1]):
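    # Write the tokenizer vocabulary in token-id order: each BPE token is stored
    # as its raw bytes, recovered via the byte-level decoder built above.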
    text = bytearray([byte_decoder[c] for c in word if c in byte_decoder])

    if len(text) < 1:
        text = bytearray(word.encode('utf8'))

    fout.write(struct.pack("i", len(text)))
    fout.write(text)

empty_vocab = hparams['vocab_size'] - tokenizer.vocab_size

for i in range( hparams['vocab_size'] - len(encoder) - len(encoder_added)):
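    # The model's embedding table (vocab_size) is larger than the tokenizer's
    # vocabulary, so pad the remaining slots with a placeholder token.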
    text = "<|endoftext|>".encode("utf8")
    fout.write(struct.pack("i", len(text)))
    fout.write(text)

new_list_vars = {}

for name in list_vars.keys():
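    # CodeGen2 stores attention as a single fused qkv_proj tensor; split it into
    # separate q/k/v tensors. As in the GPT-J conversion above, the fused layout
    # appears to be Q, V, K. (The hard-coded reshape factor of 8 works for codegen2-1B.)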
    if name.endswith("attn.qkv_proj.weight"):
        data = list_vars[name]
        n_dims = len(data.shape)
        assert n_dims == 2
        n_embd = hparams["n_embd"]
        q_unshaped, v_unshaped, k_unshaped = torch.split(data.reshape(8, -1, n_embd), n_embd//8, dim=1)
        q_shaped, v_shaped, k_shaped = (q_unshaped.reshape(-1, n_embd), v_unshaped.reshape(-1, n_embd), k_unshaped.reshape(-1, n_embd))
        new_list_vars[name.replace(".qkv_proj.", ".q_proj.")] = q_shaped
        new_list_vars[name.replace(".qkv_proj.", ".v_proj.")] = v_shaped
        new_list_vars[name.replace(".qkv_proj.", ".k_proj.")] = k_shaped
    else:
        new_list_vars[name] = list_vars[name]
        
list_vars = new_list_vars


for name in list_vars.keys():
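    # Write each tensor: a header of (n_dims, name length, ftype), the dims in
    # reverse order (innermost first), the name bytes, then the raw float32 data.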
    data = list_vars[name].squeeze().numpy()

    if name.endswith("attn.masked_bias") or name.endswith(".attn.bias") or name.endswith("attn.causal_mask"):
        continue

    n_dims = len(data.shape)
    ftype_cur = 0

    if data.dtype != np.float32:
        print("  Converting to float32")
        data = data.astype(np.float32)
        ftype_cur = 0

    name_bytes = name.encode('utf-8')

    fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype_cur))
    for i in range(n_dims):
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(name_bytes)
    data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")

ravenscroftj (Owner) commented

hey there @czkoko - this is great progress, thank you for contributing. I'd love to add this script to the repo, which will allow us to support the 1B CodeGen2 model. Along with merging the recent GGML changes that support StarCoder, this will allow Turbopilot to support some really interesting new use cases and additional models.
