# The new CodeGen2 Release #23
Comments
hey there - yes, I definitely plan to allow the use of the new codegen model and possibly others like Santacoder. The new models have a slightly different architecture from the original codegen models, which means some modification will be needed to get them working - the best solution may be to implement the new architecture directly in GGML rather than converting via GPT-J. I need to do some exploration to work out the best path.
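If it helps the exploration, the config differences are easy to eyeball from the hub. A quick sketch of mine (the first repo id is just one example of an original codegen checkpoint):

```python
from transformers import AutoConfig

# compare an original codegen config against the new codegen2 one;
# codegen2 ships custom modeling code, hence trust_remote_code
for repo in ("Salesforce/codegen-350M-mono", "Salesforce/codegen2-1B"):
    cfg = AutoConfig.from_pretrained(repo, trust_remote_code=True)
    print(repo, cfg.model_type, cfg.n_layer, cfg.n_head, cfg.n_embd, cfg.vocab_size)
```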
After some exploration, I have completed the following conversion script, which converts the original codegen2 model directly to ggml - there is no need to convert to GPT-J first. codegen2-1B works correctly, but the output of codegen2-7B seems to be abnormal.

```python
import sys
import struct
import json
import torch
import numpy as np

# note: accelerate must be installed for low_cpu_mem_usage=True below
from transformers import AutoModelForCausalLM, AutoTokenizer


# standard GPT-2 byte <-> unicode table used by the BPE tokenizer
def bytes_to_unicode():
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


if len(sys.argv) < 2:
    print("Usage: codegen2-to-ggml.py codegen2-1B(dir)\n")
    sys.exit(1)

# output in the same directory as the model
dir_model = sys.argv[1]

with open(dir_model + "/vocab.json", "r", encoding="utf8") as f:
    encoder = json.load(f)

with open(dir_model + "/added_tokens.json", "r") as f:
    encoder_added = json.load(f)

with open(dir_model + "/config.json", "r") as f:
    hparams = json.load(f)

ftype = 0  # 0 = float32
fname_out = sys.argv[1] + "/ggml-model-f32.bin"

model = AutoModelForCausalLM.from_pretrained(dir_model, trust_remote_code=True, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(dir_model)

list_vars = model.state_dict()

fout = open(fname_out, "wb")

# ggml header: magic number followed by the hyperparameters
fout.write(struct.pack("i", 0x67676D6C))  # magic: "ggml" in hex
fout.write(struct.pack("i", hparams["vocab_size"]))
fout.write(struct.pack("i", hparams["n_positions"]))
fout.write(struct.pack("i", hparams["n_embd"]))
fout.write(struct.pack("i", hparams["n_head"]))
fout.write(struct.pack("i", hparams["n_layer"]))
fout.write(struct.pack("i", hparams["rotary_dim"]))
fout.write(struct.pack("i", ftype))

byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

# vocab section: token count, then each token as (length, bytes) in id order
fout.write(struct.pack("i", hparams["vocab_size"]))

for word, idx in sorted(tokenizer.vocab.items(), key=lambda x: x[1]):
    text = bytearray([byte_decoder[c] for c in word if c in byte_decoder])
    if len(text) < 1:
        text = bytearray(word.encode("utf8"))
    fout.write(struct.pack("i", len(text)))
    fout.write(text)

# pad the vocab out to vocab_size with placeholder tokens
for i in range(hparams["vocab_size"] - len(encoder) - len(encoder_added)):
    text = "<|endoftext|>".encode("utf8")
    fout.write(struct.pack("i", len(text)))
    fout.write(text)

# split each fused qkv_proj weight into separate q/k/v projections;
# codegen stores qkv in 8 shards, each laid out in q, v, k order
new_list_vars = {}
for name in list_vars.keys():
    if name.endswith("attn.qkv_proj.weight"):
        data = list_vars[name]
        n_dims = len(data.shape)
        assert n_dims == 2
        n_embd = hparams["n_embd"]
        q_unshaped, v_unshaped, k_unshaped = torch.split(data.reshape(8, -1, n_embd), n_embd // 8, dim=1)
        new_list_vars[name.replace(".qkv_proj.", ".q_proj.")] = q_unshaped.reshape(-1, n_embd)
        new_list_vars[name.replace(".qkv_proj.", ".v_proj.")] = v_unshaped.reshape(-1, n_embd)
        new_list_vars[name.replace(".qkv_proj.", ".k_proj.")] = k_unshaped.reshape(-1, n_embd)
    else:
        new_list_vars[name] = list_vars[name]

list_vars = new_list_vars

for name in list_vars.keys():
    data = list_vars[name].squeeze().numpy()

    # attention mask buffers are recreated at load time, so skip them
    if name.endswith("attn.masked_bias") or name.endswith(".attn.bias") or name.endswith("attn.causal_mask"):
        continue

    n_dims = len(data.shape)

    ftype_cur = 0
    if data.dtype != np.float32:
        print("  Converting to float32")
        data = data.astype(np.float32)
        ftype_cur = 0

    # tensor header: n_dims, name length, ftype, then the dims (reversed) and name
    name_bytes = name.encode("utf-8")
    fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype_cur))
    for i in range(n_dims):
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(name_bytes)

    data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")
```
hey there @czkoko - this is great progress, thank you for contributing. I'd love to add this script to the repo, which will allow us to support the 1B codegen2 model - along with merging the recent ggml libraries that support starcoder, this will allow Turbopilot to support some really interesting new use cases and additional libraries.
Salesforce has released a new model, CodeGen2. CodeGen2 is capable of infilling and supports more programming languages.
https://huggingface.co/Salesforce/codegen2-1B
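For reference, the model card describes causal infilling with a `<mask_1>`/`<sep>` prompt format - roughly like this sketch (untested here; the example code and generation settings are my own):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen2-1B")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen2-1B", trust_remote_code=True)

# infill prompt: prefix, a mask where code is missing, the suffix,
# then <|endoftext|><sep><mask_1> asks the model to fill the mask
prefix = "def hello_world():\n    "
suffix = "    return name"
prompt = prefix + "<mask_1>" + suffix + "<|endoftext|>" + "<sep>" + "<mask_1>"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
generated = model.generate(input_ids, max_length=128)
# the infilled span is whatever the model emits after the prompt
print(tokenizer.decode(generated[0], skip_special_tokens=False)[len(prompt):])
```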
I tested codegen2-1B and the code quality has improved, but the existing conversion script for codegen no longer works with codegen2.
The Codegen2 > gpt-j > ggml conversion completes with no error messages, but the output is abnormal code.
For example, the return:
The following is the conversion script for Codegen2 > gpt-j: