diff --git a/src/fastllm.cpp b/src/fastllm.cpp
index edc699c..01cff72 100644
--- a/src/fastllm.cpp
+++ b/src/fastllm.cpp
@@ -549,7 +549,9 @@ namespace fastllm {
         data.UpdateUnitSize();
         data.Allocate();
         if (dataType == oriDataType) {
-            memcpy(data.cpuData, oriData, data.GetBytes());
+            if (oriData != nullptr) {
+                memcpy(data.cpuData, oriData, data.GetBytes());
+            }
         } else if (oriDataType == DataType::BFLOAT16
                    && dataType == DataType::FLOAT16) {
             uint16_t *a = (uint16_t*)data.cpuData;
diff --git a/src/model.cpp b/src/model.cpp
index ec8f53f..4b06059 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -287,6 +287,9 @@ namespace fastllm {
             if (dstType != DataType::FLOAT32) {
                 ErrorInFastLLM("SafeTensorItem.CreateBuffer: unsupport src dtype " + this->dtype + "\n");
             }
+        } else if (this->dtype == "I64") {
+            printf("skip I64 tensor %s\n", this->tensorName.c_str());
+            return;
         } else {
             ErrorInFastLLM("SafeTensorItem.CreateBuffer: unsupport src dtype " + this->dtype + "\n");
         }
diff --git a/src/models/bert.cpp b/src/models/bert.cpp
index 4518633..02875a8 100644
--- a/src/models/bert.cpp
+++ b/src/models/bert.cpp
@@ -212,7 +212,17 @@ namespace fastllm {

     void BertModel::WarmUp() {
         printf("Warmup...\n");
-        EmbeddingSentence({"1"}, true);
+        int batch = 1, len = 1;
+        std::vector <float> ids = std::vector <float> (batch * len, 0.0f);
+        std::vector <float> seqLens = std::vector <float> (batch, 0.0f);
+        std::vector <float> token_type_ids = std::vector <float> (batch * len, 0.0f);
+        std::vector <float> attention_mask = std::vector <float> (batch * len, -1e10f);
+        std::vector <float> position_ids = std::vector <float> (batch * len, 0.0f);
+        fastllm::Data inputIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids);
+        fastllm::Data attentionMask = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask);
+        fastllm::Data tokenTypeIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids);
+        fastllm::Data positionIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids);
+        ForwardAll(inputIds, attentionMask, tokenTypeIds, positionIds, true);
         printf("finish.\n");
     }
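
The rewritten WarmUp drives the raw tensor path directly rather than going through the tokenizer-dependent EmbeddingSentence. A minimal standalone sketch of the same dummy-batch pattern follows; the Data(DataType, dims, values) constructor and the ForwardAll argument list are taken from the patch itself, while the header path, the WarmUpBert helper name, and calling ForwardAll as a public member on an already-loaded model are assumptions:

#include <vector>
#include "models/bert.h"   // assumption: header that declares fastllm::BertModel

// Hypothetical helper mirroring the patched BertModel::WarmUp():
// build a 1x1 all-dummy batch of FLOAT32 tensors and run one forward pass.
void WarmUpBert(fastllm::BertModel &model) {
    int batch = 1, len = 1;
    std::vector <float> ids(batch * len, 0.0f);              // dummy token id 0
    std::vector <float> token_type_ids(batch * len, 0.0f);   // single segment
    std::vector <float> attention_mask(batch * len, -1e10f); // additive mask value from the patch
    std::vector <float> position_ids(batch * len, 0.0f);     // position 0

    fastllm::Data inputIds(fastllm::DataType::FLOAT32, {batch, len}, ids);
    fastllm::Data attentionMask(fastllm::DataType::FLOAT32, {batch, len}, attention_mask);
    fastllm::Data tokenTypeIds(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids);
    fastllm::Data positionIds(fastllm::DataType::FLOAT32, {batch, len}, position_ids);

    // Trailing `true` mirrors the flag passed in the patch's warmup call.
    model.ForwardAll(inputIds, attentionMask, tokenTypeIds, positionIds, true);
}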