Skip to content

Commit

Permalink
fix 直接读bge-largh-zh
Browse files Browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed Sep 23, 2024
1 parent 5b89fab commit 4736a6d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/fastllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,9 @@ namespace fastllm {
data.UpdateUnitSize();
data.Allocate();
if (dataType == oriDataType) {
memcpy(data.cpuData, oriData, data.GetBytes());
if (oriData != nullptr) {
memcpy(data.cpuData, oriData, data.GetBytes());
}
} else if (oriDataType == DataType::BFLOAT16
&& dataType == DataType::FLOAT16) {
uint16_t *a = (uint16_t*)data.cpuData;
Expand Down
3 changes: 3 additions & 0 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ namespace fastllm {
if (dstType != DataType::FLOAT32) {
ErrorInFastLLM("SafeTensorItem.CreateBuffer: unsupport src dtype " + this->dtype + "\n");
}
} else if (this->dtype == "I64") {
printf("skip I64 tensor %s\n", this->tensorName.c_str());
return;
} else {
ErrorInFastLLM("SafeTensorItem.CreateBuffer: unsupport src dtype " + this->dtype + "\n");
}
Expand Down
12 changes: 11 additions & 1 deletion src/models/bert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,17 @@ namespace fastllm {

void BertModel::WarmUp() {
printf("Warmup...\n");
EmbeddingSentence({"1"}, true);
int batch = 1, len = 1;
std::vector <float> ids = std::vector <float> (batch * len, 0.0f);
std::vector <float> seqLens = std::vector <float> (batch, 0.0f);
std::vector <float> token_type_ids = std::vector <float> (batch * len, 0.0f);
std::vector <float> attention_mask = std::vector <float> (batch * len, -1e10f);
std::vector <float> position_ids = std::vector <float> (batch * len, 0.0f);
fastllm::Data inputIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids);
fastllm::Data attentionMask = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask);
fastllm::Data tokenTypeIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids);
fastllm::Data positionIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids);
ForwardAll(inputIds, attentionMask, tokenTypeIds, positionIds, true);
printf("finish.\n");
}

Expand Down

0 comments on commit 4736a6d

Please sign in to comment.