Commit

Add support for reading llama3 models directly
黄宇扬 committed May 28, 2024
1 parent 708e090 commit ddadf49
Showing 5 changed files with 287 additions and 6 deletions.
6 changes: 5 additions & 1 deletion include/fastllm.h
@@ -183,7 +183,7 @@ namespace fastllm {
};

enum WeightType {
NONE = 0, LINEAR = 1, EMBEDDING = 2
NONE = 0, LINEAR = 1, EMBEDDING = 2, AUTO = 99999
};

struct FileMmap {
@@ -439,6 +439,8 @@ namespace fastllm {

std::set <std::string> embeddingNames;

std::set <std::string> linearNames;

void LoadFromFile(const std::string &fileName); // read from a file

void SaveLowBitModel(const std::string &fileName, int bit); // save as a quantized model; bit = 0 means store as-is
@@ -458,6 +460,8 @@ namespace fastllm {
void AddQLinearWeight(const std::string &key, const std::vector <int> &dims,
int bit, float *scales, uint8_t *oriData); // insert the weights of a QLinear layer; the quantization rule is float value = scales * oriData

WeightType GetWeightType(const std::string &key); // get the type of a weight (NONE if it cannot be determined)

Data &operator [] (const std::string &key);
};

4 changes: 4 additions & 0 deletions include/model.h
@@ -14,6 +14,10 @@ namespace fastllm {
std::unique_ptr<basellm> CreateLLMModelFromFile(const std::string &fileName);

std::unique_ptr<basellm> CreateEmptyLLMModel(const std::string &modelType);

std::unique_ptr<basellm> CreateLLMModelFromHF(const std::string &modelPath,
DataType linearDataType,
int groupCnt = -1);
}

#endif //FASTLLM_MODEL_H
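
For orientation, a minimal sketch of calling the new entry point from application code; the model path is a placeholder and DataType::FLOAT16 merely stands in for whatever linear weight type is wanted:

#include "model.h"

int main() {
    // Placeholder path: any HF folder holding config.json, tokenizer_config.json,
    // tokenizer.json, model.safetensors.index.json and the *.safetensors shards.
    auto model = fastllm::CreateLLMModelFromHF("/path/to/Meta-Llama-3-8B-Instruct",
                                               fastllm::DataType::FLOAT16);
    return 0;
}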
44 changes: 44 additions & 0 deletions src/fastllm.cpp
@@ -1986,6 +1986,40 @@ namespace fastllm {
this->peftDict[name][key] = value;
}

WeightType WeightMap::GetWeightType(const std::string &key) {
if (this->embeddingNames.find(key) != this->embeddingNames.end()) {
return WeightType::EMBEDDING;
}
for (auto &linearName : this->linearNames) {
int n = key.size(), m = linearName.size();
std::vector <std::vector <bool> > f = std::vector <std::vector <bool> > (n + 1, std::vector <bool>(m + 1, 0));
f[0][0] = 1;
for (int i = 0; i <= n; i++) {
for (int j = 0; j <= m; j++) {
if (f[i][j]) {
if (i + 1 <= n && key[i] == '*') {
for (int l = j; l <= m; l++) {
f[i + 1][l] = 1;
}
}
if (j + 1 <= m && linearName[j] == '*') {
for (int l = i; l <= n; l++) {
f[l][j + 1] = 1;
}
}
if (i + 1 <= n && j + 1 <= m && key[i] == linearName[j]) {
f[i + 1][j + 1] = 1;
}
}
}
}
if (f[n][m]) {
return WeightType::LINEAR;
}
}
return WeightType::NONE;
}
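
The nested loops above form a prefix-matching DP: f[i][j] is true when the first i characters of key can match the first j characters of a pattern from linearNames, and a '*' on either side may absorb any number of characters from the other string. A sketch of the intended lookups, with patterns taken from the llama.cpp change later in this commit:

WeightMap weight;
weight.embeddingNames.insert("model.embed_tokens.weight");
weight.linearNames = {"model.layers.*.self_attn.q_proj.weight"};
// Expected (illustrative) results:
//   weight.GetWeightType("model.layers.0.self_attn.q_proj.weight") -> WeightType::LINEAR
//   weight.GetWeightType("model.embed_tokens.weight")              -> WeightType::EMBEDDING
//   weight.GetWeightType("model.norm.weight")                      -> WeightType::NONE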

void WeightMap::AddQLinearWeight(const std::string &key, const std::vector <int> &dims,
int bit, float *scales, uint8_t *oriData) {
AssertInFastLLM(bit == 4 || bit == 8, "Error: only support 8 bit or 4 bit QLinear.\n");
@@ -2041,6 +2075,16 @@ namespace fastllm {

void WeightMap::AddWeight(const std::string &key, const std::vector<int> &dims, fastllm::DataType dataType,
fastllm::WeightType weightType, fastllm::DataType oriDataType, uint8_t *oriData, int groupCnt) {
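// Illustrative note: with WeightType::AUTO the real type is resolved from the tensor name
// via GetWeightType(); with the patterns registered by llama.cpp later in this commit:
//   "model.layers.0.mlp.down_proj.weight" -> LINEAR    (stored with the requested dataType)
//   "model.embed_tokens.weight"           -> EMBEDDING (kept in oriDataType)
//   "model.norm.weight"                   -> NONE      (kept in oriDataType)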
if (weightType == WeightType::AUTO) {
weightType = GetWeightType(key);
if (weightType == WeightType::EMBEDDING) {
dataType = oriDataType;
}
if (weightType == WeightType::NONE) {
dataType = oriDataType;
}
}

this->weight[key] = Data(dataType, dims);
this->weight[key].name = std::string(key);
Data &data = this->weight[key];
223 changes: 223 additions & 0 deletions src/model.cpp
@@ -1,8 +1,10 @@
#include "utils.h"
#include "json11.hpp"

#include "model.h"
#include "fastllm.h"
#include <sstream>
#include <fstream>

#include "chatglm.h"
#include "moss.h"
@@ -16,6 +18,40 @@
#include "bert.h"

namespace fastllm {
std::string ReadAllFile(const std::string &fileName) {
if (access(fileName.c_str(), R_OK) != 0) {
ErrorInFastLLM("Read error: can't find \"" + fileName + "\".");
}

std::ifstream t(fileName.c_str());
std::string ret((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
t.close();
return ret;
}

void ConvertDataType(uint8_t *src, DataType srcDtype, uint8_t *dst, DataType dstDtype, uint64_t len) {
if (srcDtype == dstDtype) {
int unitSize = 4;
if (dstDtype == DataType::FLOAT32) {
unitSize = 4;
} else if (dstDtype == DataType::FLOAT16 || dstDtype == DataType::BFLOAT16) {
unitSize = 2;
} else {
ErrorInFastLLM("ConvertDataType Failed. (" + std::to_string(srcDtype) + " -> " + std::to_string(dstDtype) + ")");
}
memcpy(dst, src, len * unitSize);
} else if (srcDtype == DataType::BFLOAT16 && dstDtype == DataType::FLOAT32) {
uint16_t *u16dst = (uint16_t*)dst;
uint16_t *u16src = (uint16_t*)src;
for (int i = 0; i < len; i++) {
u16dst[i * 2] = 0;
u16dst[i * 2 + 1] = u16src[i];
}
} else {
ErrorInFastLLM("ConvertDataType Failed. (" + std::to_string(srcDtype) + " -> " + std::to_string(dstDtype) + ")");
}
}
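
The BFLOAT16 to FLOAT32 branch above relies on bfloat16 being the upper 16 bits of an IEEE-754 float32, so each destination value is produced by zeroing the low half-word and copying the source half-word into the high half; this assumes a little-endian layout. The same widening for a single value, as a standalone sketch:

#include <cstdint>
#include <cstring>

// Minimal sketch, little-endian assumed: widen one bfloat16 value to float32.
float Bf16ToFp32(uint16_t bf16) {
    uint32_t bits = (uint32_t)bf16 << 16;  // bf16 occupies the high 16 bits of the fp32 pattern
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}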

void basellm::LoadFromFile(const std::string &fileName) {
this->weight.LoadFromFile(fileName);
this->InitParams();
@@ -153,4 +189,191 @@ namespace fastllm {
basellm *model = CreateModelWithType(modelType);
return std::unique_ptr<fastllm::basellm> (model);
}

struct SafeTensorItem {
std::string tensorName;
std::string fileName;
std::string dtype;
std::vector <std::uint64_t> shape;
std::vector <int> intShape;
std::vector <std::uint64_t> data_offsets;

uint64_t len, bytes;
uint8_t *buffer = nullptr;

SafeTensorItem() {}

SafeTensorItem(const std::string &tensorName, const std::string &fileName, const json11::Json &config, uint64_t baseOffset) {
this->tensorName = tensorName;
this->fileName = fileName;

this->dtype = config["dtype"].string_value();
for (auto &it : config["data_offsets"].array_items()) {
this->data_offsets.push_back(baseOffset + it.ll_value());
}
for (auto &it : config["shape"].array_items()) {
this->shape.push_back(it.ll_value());
this->intShape.push_back(this->shape.back());
}

len = 1;
for (auto &it : shape) {
len *= it;
}
bytes = this->data_offsets[1] - this->data_offsets[0];
}

void CreateBuffer(DataType dstType) {
DataType srcType;
if (this->dtype == "BF16") {
srcType = DataType::BFLOAT16;
} else {
ErrorInFastLLM("SafeTensorItem.CreateBuffer: unsupport src dtype " + this->dtype + "\n");
}

int unitSize = 4;
if (dstType == DataType::FLOAT32) {
unitSize = 4;
} else if (dstType == DataType::FLOAT16 || dstType == DataType::BFLOAT16) {
unitSize = 2;
} else {
ErrorInFastLLM("SafeTensorItem.CreateBuffer: unsupport dst dtype " + std::to_string(dstType) + "\n");
}
ClearBuffer();
buffer = new uint8_t[len * unitSize];

FILE *fi = fopen(this->fileName.c_str(), "rb");
int ret;
#if defined(_WIN32) or defined(_WIN64)
_fseeki64(fi, this->data_offsets[0], 0);
#else
fseek(fi, this->data_offsets[0], 0);
#endif
if (dstType == srcType) {
ret = fread(buffer, 1, this->bytes, fi);
} else {
uint8_t *ori = new uint8_t[this->bytes];
ret = fread(ori, 1, this->bytes, fi);
ConvertDataType(ori, srcType, buffer, dstType, len);
delete[] ori;
}
fclose(fi);
}

void ClearBuffer() {
delete[] buffer;
buffer = nullptr;
}
};

struct SafeTensors {
std::set <std::string> fileNames;
std::map <std::string, SafeTensorItem> itmeDict;

SafeTensors (const std::set <std::string> &fileNames) {
std::string error;
this->fileNames = fileNames;
for (auto &fileName : fileNames) {
FILE *f = fopen(fileName.c_str(), "rb");
uint64_t configBytes;
int ret = fread(&configBytes, 8, 1, f);
char *configString = new char[configBytes + 5];
ret = fread(configString, 1, configBytes, f);
configString[configBytes] = 0;
auto config = json11::Json::parse(configString, error);
for (auto it : config.object_items()) {
if (it.first != "__metadata__" ) {
itmeDict[it.first] = SafeTensorItem(it.first, fileName, it.second, 8 + configBytes);
}
}

delete[] configString;
}
}

std::vector <std::string> GetSortedItemNames() {
std::vector <std::pair <std::pair <std::string, uint64_t>, std::string> > v;
for (auto &it : itmeDict) {
v.push_back(std::make_pair(std::make_pair(it.second.fileName, it.second.data_offsets[0]), it.first));
}
std::sort(v.begin(), v.end());
std::vector <std::string> ret;
for (int i = 0; i < v.size(); i++) {
ret.push_back(v[i].second);
}
return ret;
}
};
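
For context on what the constructor above parses: a .safetensors file starts with an 8-byte little-endian header length, followed by a JSON header that maps each tensor name to its dtype, shape and data_offsets (relative to the end of the header). A minimal sketch of reading just that header (error handling omitted):

// Minimal sketch, mirroring the SafeTensors constructor: read only the JSON header.
std::string ReadSafetensorsHeader(const std::string &fileName) {
    FILE *f = fopen(fileName.c_str(), "rb");
    uint64_t configBytes = 0;
    fread(&configBytes, 8, 1, f);          // 8-byte little-endian header length
    std::string header(configBytes, ' ');
    fread(&header[0], 1, configBytes, f);  // JSON: {"tensor.name": {"dtype": "BF16",
    fclose(f);                             //   "shape": [...], "data_offsets": [...]}, ...}
    return header;
}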

// Load from a HuggingFace folder; only models in safetensors format are supported
std::unique_ptr <basellm> CreateLLMModelFromHF(const std::string &modelPath,
DataType linearDataType, int groupCnt) {
std::string path = modelPath;
if (path.back() != '/' && path.back() != '\\') {
path += "/";
}

// 1. Check whether model.safetensors.index.json exists; if it does, read it
std::string stIndexFile = path + "model.safetensors.index.json";
std::string error;
auto stIndex = json11::Json::parse(ReadAllFile(stIndexFile), error)["weight_map"];
std::set <std::string> stFiles;
for (auto it : stIndex.object_items()) {
stFiles.insert(path + it.second.string_value());
}
SafeTensors safeTensors(stFiles);

// 2. Build the basic network information
std::string configFile = path + "config.json";
auto config = json11::Json::parse(ReadAllFile(configFile), error);
basellm *model = CreateModelWithType(config["model_type"].string_value());
for (auto &it : config.object_items()) {
model->weight.AddDict(it.first, it.second.dump().c_str());
}

// 3. Read the tokenizer
std::string tokenizerConfigFile = path + "tokenizer_config.json";
auto tokenizerConfig = json11::Json::parse(ReadAllFile(tokenizerConfigFile), error);
std::string tokenizerClass = tokenizerConfig["tokenizer_class"].string_value();
if (tokenizerClass == "PreTrainedTokenizerFast") {
// PreTrainedTokenizerFast
std::string tokenizerFile = path + "tokenizer.json";
auto tokenizer = json11::Json::parse(ReadAllFile(tokenizerFile), error);
auto tokenizerModel = tokenizer["model"];
auto vocab = tokenizerModel["vocab"];
for (auto &it : vocab.object_items()) {
model->weight.AddTokenizerWord(it.first, it.second.int_value(), 1.0f);
}
std::map<std::string, int> spTokens;
for (auto &it : tokenizer["added_tokens"].array_items()) {
spTokens[it["content"].string_value()] = it["id"].int_value();
}
model->weight.tokenizer.SetSpecialTokens(spTokens);

if (!tokenizer["decoder"].is_null() && !tokenizer["decoder"]["type"].is_null() &&
tokenizer["decoder"]["type"].string_value() == "ByteLevel") {
model->weight.tokenizer.byteAsChar = true;
}
} else {
ErrorInFastLLM("Unsupport tokenizer_class: " + tokenizerClass);
}

// 4. Read the weights
int cur = 0;
for (auto &weightName : safeTensors.GetSortedItemNames()) {
auto &tensor = safeTensors.itmeDict[weightName];
tensor.CreateBuffer(DataType::FLOAT32);
model->weight.AddWeight(weightName, tensor.intShape, linearDataType, WeightType::AUTO, DataType::FLOAT32, tensor.buffer, groupCnt);
tensor.ClearBuffer();

printf("Load (%d / %d) \r", (++cur), (int)safeTensors.itmeDict.size());
fflush(stdout);
}
printf("\n");
fflush(stdout);

model->InitParams();
model->WarmUp();
return std::unique_ptr<fastllm::basellm> (model);
}
}
16 changes: 11 additions & 5 deletions src/models/llama.cpp
@@ -46,16 +46,22 @@ namespace fastllm {
LlamaModel::LlamaModel() {
this->model_type = "llama";

// Use the alpaca prompt and instruction format by default
this->pre_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n";
this->user_role = "### Instruction:\n";
this->bot_role = "\n\n### Response:";
this->history_sep = "</s>";
// Use the llama3 prompt and instruction format by default
this->pre_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant.<|eot_id|>";
this->user_role = "<|start_header_id|>user<|end_header_id|>\n";
this->bot_role = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n";
this->history_sep = "<|eot_id|>\n";

block_cnt = 32;
rotary_dim = 128;

weight.embeddingNames.insert("model.embed_tokens.weight");
weight.linearNames = {
"lm_head.weight", "model.layers.*.mlp.down_proj.weight", "model.layers.*.mlp.up_proj.weight",
"model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.gateup_proj.weight",
"model.layers.*.self_attn.o_proj.weight", "model.layers.*.self_attn.q_proj.weight", "model.layers.*.self_attn.k_proj.weight",
"model.layers.*.self_attn.v_proj.weight", "model.layers.*.self_attn.mergeqkv.weight", "model.layers.*.self_attn.W_pack.weight"
};
}
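
Assuming fastllm's usual round assembly (pre_prompt first, then user_role + query + bot_role per round, with history_sep between rounds), the first round of a llama3 conversation would be rendered roughly as the string below; this illustrates the new defaults rather than quoting the prompt-building code:

std::string prompt =
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
    "You are a helpful assistant.<|eot_id|>"                        // pre_prompt
    "<|start_header_id|>user<|end_header_id|>\n"                    // user_role
    "What is fastllm?"                                              // example user query
    "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n";    // bot_role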

void LlamaModel::InitParams() {
