
Commit

Add WAIT to the FETCH thread
黄宇扬 committed Jun 23, 2024
1 parent 7f7df7a commit 935854f
Showing 2 changed files with 33 additions and 215 deletions.
7 changes: 4 additions & 3 deletions include/models/basellm.h
@@ -7,6 +7,7 @@

#include <thread>
#include <mutex>
#include <condition_variable>

#ifdef PY_API
#include "Python.h"
@@ -86,9 +87,7 @@ namespace fastllm {
public:
basellm() {};

~basellm() {
this->weight.ReleaseWeight();
};
~basellm();

virtual void LoadFromFile(const std::string &fileName); // load the model from a file

@@ -207,6 +206,7 @@ namespace fastllm {

std::thread *mainLoop = nullptr;
std::mutex mainLoopLocker, dictLocker;
std::condition_variable dictCV;

std::map <std::string, int> deviceMap;

@@ -222,6 +222,7 @@
int lastPromptTokens = 0;

DataType dataType = DataType::FLOAT32;
bool isFree = false; // whether the model has been freed
};
}

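The header changes pair the existing dictLocker mutex with a std::condition_variable (dictCV) and an isFree shutdown flag, so the main loop can sleep while idle instead of polling. A minimal, self-contained sketch of the wait/notify pattern these members enable (work, worker, and done are illustrative stand-ins, not fastllm members):

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

int main() {
    std::mutex m;
    std::condition_variable cv;
    std::queue<int> work;
    bool done = false;

    std::thread worker([&] {
        std::unique_lock<std::mutex> lock(m);
        while (true) {
            // Park until there is work or shutdown is requested.
            cv.wait(lock, [&] { return !work.empty() || done; });
            if (done) break;
            int job = work.front(); work.pop();
            lock.unlock();                    // do heavy work outside the lock
            std::cout << "job " << job << "\n";
            lock.lock();
        }
    });

    { std::lock_guard<std::mutex> g(m); work.push(42); }
    cv.notify_one();                          // wake the worker for the new job
    { std::lock_guard<std::mutex> g(m); done = true; }
    cv.notify_all();                          // wake the worker so it can exit
    worker.join();
}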
241 changes: 29 additions & 212 deletions src/models/basellm.cpp
@@ -132,7 +132,15 @@ namespace fastllm {
void PastKVCacheManager::Unlock() {
locker.unlock();
}


basellm::~basellm() {
dictLocker.lock();
this->isFree = true;
dictLocker.unlock();
dictCV.notify_all();
this->weight.ReleaseWeight();
}
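The new destructor performs the standard shutdown handshake: set the flag under the mutex, release it, then notify_all so a waiting loop can observe the flag. Note the committed destructor does not join mainLoop; the standalone sketch below does, which is one way to guarantee the loop has exited before members are destroyed (Model is a stand-in, not fastllm::basellm):

#include <condition_variable>
#include <mutex>
#include <thread>

struct Model {
    std::mutex dictLocker;
    std::condition_variable dictCV;
    bool isFree = false;
    std::thread loop;

    Model() : loop([this] {
        std::unique_lock<std::mutex> lock(dictLocker);
        while (!isFree)
            dictCV.wait(lock);       // parked until notified; re-checks isFree
    }) {}

    ~Model() {
        dictLocker.lock();
        isFree = true;               // same flag the commit adds
        dictLocker.unlock();
        dictCV.notify_all();         // wake the parked loop so it can exit
        loop.join();
    }
};

int main() { Model m; }              // constructs, then tears down cleanly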

std::string basellm::Response(const std::string &oriInput, RuntimeResult retCb,
const fastllm::GenerationConfig &generationConfig) {
std::string input = oriInput;
@@ -452,206 +460,14 @@ namespace fastllm {

int basellm::LaunchResponseTokens(const std::vector<int> &inputTokens,
const fastllm::GenerationConfig &generationConfig) {
/*
mainLoopLocker.lock();
if (mainLoop == nullptr) {
if (mainLoop == nullptr) {
mainLoop = new std::thread([](basellm *model) {
while (true) {
model->dictLocker.lock();
std::vector <int> handles;
std::vector<std::vector<float> > inputTokens;
std::vector <std::map <std::string, int> > params;
std::vector <GenerationConfig> generationConfigs;
int index = 0;
int cnt = 0;
std::vector <std::pair <int, int> > lenIdVector;
for (auto &it : model->responseContextDict.dicts) {
if (it.second->isEnding) {
continue;
}
lenIdVector.push_back(std::make_pair(it.second->generationConfig.output_token_limit,
it.first));
}
std::sort(lenIdVector.begin(), lenIdVector.end());
std::set <int> currentIds;
int maxInput = 0;
for (int i = 0; i < lenIdVector.size(); i++) {
maxInput = std::max(maxInput,
(int)model->responseContextDict.dicts[lenIdVector[i].second]->currentTokens.size());
if ((maxInput + lenIdVector[i].first) * (i + 1) > 512 * 256) {
break;
}
currentIds.insert(lenIdVector[i].second);
}
int maxOutputLimit = 0;
for (auto &it: model->responseContextDict.dicts) {
if (it.second->isEnding) {
continue;
}
if (currentIds.find(it.first) == currentIds.end()) {
continue;
}
maxOutputLimit = std::max(maxOutputLimit, it.second->generationConfig.output_token_limit);
generationConfigs.push_back(it.second->generationConfig);
handles.push_back(it.first);
if (it.second->preTokens == 0) {
it.second->intParams["promptLen"] = it.second->currentTokens.size();
it.second->intParams["index"] = 0;
} else {
it.second->intParams["index"]++;
}
inputTokens.push_back(std::vector <float> ());
for (int i : it.second->currentTokens) {
inputTokens.back().push_back(i);
}
params.push_back(std::map <std::string, int> ());
params.back()["promptLen"] = it.second->currentTokens.size();
params.back()["index"] = 0;
it.second->preTokens += (int)inputTokens.back().size();
//if (inputTokens.size() == 64) {
// break;
//}
}
if (inputTokens.size() > 0) {
model->dictLocker.unlock();
#ifdef USE_CUDA
FastllmCudaClearBigBuffer();
#endif
int batch = (int)inputTokens.size();
int last_n = 64; // TODO: use real data
std::vector <int> ret;
ret.resize(batch);
std::vector <std::vector <std::pair <Data, Data> > > *pkvPointer = new std::vector <std::vector <std::pair <Data, Data> > >();
std::vector <std::vector <std::pair <Data, Data> > > &pastKeyValuess = *pkvPointer;
pastKeyValuess.resize(batch);
for (int b = 0; b < batch; b++) {
printf("%d / %d, (%d + %d = %d)\n", b, batch, inputTokens[b].size(), generationConfigs[b].output_token_limit, inputTokens[b].size() + generationConfigs[b].output_token_limit);
Data inputIds, attentionMask, positionIds;
std::vector<std::pair<Data, Data> > &pastKeyValues = pastKeyValuess[b];
for (int i = 0; i < model->block_cnt; i++) {
pastKeyValues.push_back(std::make_pair(Data(DataType::FLOAT32),
Data(DataType::FLOAT32)));
}
LastTokensManager tokens(1, generationConfigs[b].last_n);
int promptLen = inputTokens[b].size(), index = 0;
std::vector <std::vector <float> > curInputTokens = {inputTokens[b]};
model->FillLLMInputs(curInputTokens, {{"promptLen", promptLen}, {"index", index}}, inputIds, attentionMask, positionIds);
ret[b] = model->Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfigs[b], tokens);
}
Data inputIds, attentionMask, positionIds;
LastTokensManager tokensManager (batch, last_n);
std::vector <bool> isEnding = std::vector <bool> (batch, false);
std::vector <std::pair <Data, Data> > pastKeyValues;
for (int i = 0; i < model->block_cnt; i++) {
pastKeyValues.push_back(std::make_pair(Data(DataType::FLOAT32), Data(DataType::FLOAT32)));
}
for (int i = 0; i < model->block_cnt; i++) {
auto &key = pastKeyValues[i].first;
auto &value = pastKeyValues[i].second;
std::vector <int> dims = pastKeyValuess[0][i].first.dims;
for (int b = 1; b < batch; b++) {
dims[0] += pastKeyValuess[b][i].first.dims[0];
dims[1] = std::max(dims[1], pastKeyValuess[b][i].first.dims[1]);
}
std::vector <int> expandDims = dims;
expandDims[1] += maxOutputLimit;
key.ToDevice(DataDevice::CUDA);
value.ToDevice(DataDevice::CUDA);
key.Expansion(dims);
value.Expansion(dims);
key.Resize(dims);
value.Resize(dims);
int bs = dims[0], perbs = bs / batch, len = dims[1], inner = dims[2];
for (int b = 0; b < batch; b++) {
Data &oldKey = pastKeyValuess[b][i].first;
Data &oldValue = pastKeyValuess[b][i].second;
CopyKVCache(oldKey, key, 0, b * perbs, perbs, (dims[1] - oldKey.dims[1]));
CopyKVCache(oldValue, value, 0, b * perbs, perbs, (dims[1] - oldValue.dims[1]));
}
}
delete pkvPointer;
std::vector <std::vector <int> > results;
results.resize(batch);
bool first = true;
GenerationConfig config;
while (true) {
if (first) {
first = false;
} else {
auto st = std::chrono::system_clock::now();
ret = model->ForwardBatch(batch, inputIds, attentionMask, positionIds,
pastKeyValues, config, tokensManager);
printf("batch = %d, spend = %f s.\n", batch, GetSpan(st, std::chrono::system_clock::now()));
}
for (int i = 0; i < batch; i++) {
tokensManager.units[i].Push(ret[i]);
}
std::vector <float> fret;
int endingCount = 0;
std::vector <std::string> curStrings;
for (int i = 0; i < batch; i++) {
fret.push_back(ret[i]);
inputTokens[i] = std::vector <float> {(float)ret[i]};
if (ret[i] == model->eos_token_id || (results[i].size() >= generationConfigs[i].output_token_limit)) {
isEnding[i] = true;
}
if (isEnding[i]) {
endingCount++;
continue;
}
results[i].push_back(ret[i]);
}
printf("%d / %d\n", endingCount, batch);
if (endingCount == batch) {
break;
}
params[0]["index"]++;
model->FillLLMInputsBatch(inputTokens, params, inputIds, attentionMask, positionIds);
}
model->dictLocker.lock();
for (int i = 0; i < handles.size(); i++) {
auto &it = *model->responseContextDict.dicts.find(handles[i]);
for (int token : results[i]) {
it.second->resultTokenQueue.push(token);
}
it.second->isEnding = true;
}
if (model->isFree) {
break;
}
model->dictLocker.unlock();
MySleep(0);
}
}, this);
}
}
mainLoopLocker.unlock();
*/
mainLoopLocker.lock();
if (mainLoop == nullptr) {
if (mainLoop == nullptr) {
mainLoop = new std::thread([](basellm *model) {
while (true) {
std::vector <Data*> attentionMasks;
std::vector <Data*> positionIds;
std::vector <std::pair <Data*, Data*> > pastKeyValues;
@@ -661,7 +477,8 @@
std::vector <GenerationConfig> generationConfigs;
LastTokensManager tokensManager;
std::vector <std::vector <float>* > logits;
model->dictLocker.lock();

std::unique_lock<std::mutex> dictLocker(model->dictLocker);

int limit = model->tokensLimit > 0 ? model->tokensLimit : 1e9;
int lenSum = 0;
Expand Down Expand Up @@ -748,7 +565,11 @@ printf("%d / %d\n", endingCount, batch);
}
if (isPrompt) {
cnt += it.second->currentTokens.size();
break;

if (cnt > 300) {
break;
}
// break;
}
}
}
@@ -757,7 +578,7 @@ printf("%d / %d\n", endingCount, batch);
if (seqLens.size() == 1) {
pastKeyValue1 = &model->responseContextDict.dicts[handles[0]]->pastKeyValues;
}
model->dictLocker.unlock();
dictLocker.unlock();
#ifdef USE_CUDA
FastllmCudaClearBigBuffer();
#endif
@@ -790,14 +611,14 @@ auto st = std::chrono::system_clock::now();
*pastKeyValue1, generationConfigs[0], tokensManager, logits[0])};
}
}
/*
PrintProfiler();
int total = 0;

//PrintProfiler();
/*int total = 0;
for (int i : seqLens) total += i;
float spend = GetSpan(st, std::chrono::system_clock::now());
printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)total / spend);
*/
model->dictLocker.lock();
dictLocker.lock();
for (int i = 0; i < handles.size(); i++) {
auto &it = *model->responseContextDict.dicts.find(handles[i]);
int curRet = ret[i];
@@ -828,8 +649,9 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
delete positionIds[i];
}

model->dictLocker.unlock();
MySleep(0);
if (seqLens.size() == 0) {
model->dictCV.wait(dictLocker);
}
}
}, this);
}
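The hunk above carries the visible behavioral change: when a pass finds no active sequences (seqLens.size() == 0), the loop now parks on dictCV instead of unlocking and spinning through MySleep(0). A bare wait(dictLocker) can wake spuriously, which is harmless here because the loop re-scans responseContextDict at the top of each iteration; the predicate overload would make the wakeup condition explicit, as in this sketch (pending and stop are illustrative stand-ins, not fastllm members):

#include <condition_variable>
#include <mutex>
#include <thread>

int main() {
    std::mutex m;
    std::condition_variable cv;
    int pending = 0;      // stand-in for "some context has tokens queued"
    bool stop = false;    // stand-in for the isFree flag

    std::thread loop([&] {
        std::unique_lock<std::mutex> lock(m);
        // wait(lock, pred) re-checks pred after every wakeup, so a
        // spurious return never reaches the caller.
        cv.wait(lock, [&] { return pending > 0 || stop; });
    });

    { std::lock_guard<std::mutex> g(m); pending = 1; }  // enqueue work
    cv.notify_one();            // like LaunchResponseTokens waking the loop
    loop.join();
}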
@@ -844,26 +666,24 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
context->generationConfig = generationConfig;
context->tokens = LastTokensUnit(generationConfig.last_n);
dictLocker.unlock();
dictCV.notify_one();
return handleId;
}
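Note that LaunchResponseTokens releases dictLocker before calling dictCV.notify_one(). Both orders are valid, but notifying after the unlock lets the woken loop thread acquire the mutex immediately instead of waking only to block on a still-held lock; the sketch after the header diff above follows the same order.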

int basellm::FetchResponseTokens(int handleId) {
dictLocker.lock();
std::unique_lock<std::mutex> dictLocker(this->dictLocker);
ResponseContext *context = responseContextDict.GetHandle(handleId);
if (context == nullptr) {
dictLocker.unlock();
return -1;
} else {
while (true) {
if (context->resultTokenQueue.size() > 0) {
int ret = context->resultTokenQueue.front();
context->resultTokenQueue.pop();
dictLocker.unlock();
return ret;
} else {
if (context->isEnding) {
responseContextDict.RemoveHandle(handleId);
dictLocker.unlock();
return -1;
}
}
@@ -875,10 +695,9 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
}

int basellm::FetchResponseLogits(int handleId, std::vector<float> &logits) {
dictLocker.lock();
std::unique_lock<std::mutex> dictLocker(this->dictLocker);
ResponseContext *context = responseContextDict.GetHandle(handleId);
if (context == nullptr) {
dictLocker.unlock();
return -1;
} else {
while (true) {
Expand All @@ -890,12 +709,10 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
delete context->resultLogits.front();
context->resultLogits.pop();
}
dictLocker.unlock();
return ret;
} else {
if (context->isEnding) {
responseContextDict.RemoveHandle(handleId);
dictLocker.unlock();
return -1;
}
}
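A smaller cleanup in the same file: FetchResponseTokens and FetchResponseLogits now guard dictLocker with a std::unique_lock, so every early return releases the mutex through RAII, which is why the manual dictLocker.unlock() calls before each return are deleted in the hunks above. A minimal sketch of that pattern (dictMutex, fetch, and ready are illustrative):

#include <mutex>

std::mutex dictMutex;                // stand-in for basellm::dictLocker

int fetch(bool ready) {
    std::unique_lock<std::mutex> lock(dictMutex);
    if (!ready) return -1;           // lock released here automatically...
    return 0;                        // ...and here; no unlock() to forget
}

int main() { return fetch(true); }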
