server : output embeddings for all tokens when pooling = none #10861

Merged (12 commits) on Dec 18, 2024
common/common.cpp (3 additions, 1 deletion)

@@ -1780,7 +1780,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm)
             break;
         case 0: // max absolute
             for (int i = 0; i < n; i++) {
-                if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
+                if (sum < std::abs(inp[i])) {
+                    sum = std::abs(inp[i]);
+                }
             }
             sum /= 32760.0; // make an int16 range
             break;
common/common.h (2 additions, 1 deletion)

@@ -596,7 +596,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
 // Embedding utils
 //
 
-void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
+// TODO: replace embd_norm with an enum
+void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
 
 float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
 
examples/gritlm/gritlm.cpp (1 addition, 1 deletion)

@@ -75,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         }
 
         std::vector<float> emb_norm(emb_unorm.size());
-        common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
+        common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2);
         result.push_back(emb_norm);
 
 #ifdef GRIT_DEBUG
examples/retrieval/retrieval.cpp (1 addition, 1 deletion)

@@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }
 
         float * out = output + batch.seq_id[i][0] * n_embd;
-        common_embd_normalize(embd, out, n_embd);
+        common_embd_normalize(embd, out, n_embd, 2);
     }
 }
 
examples/server/README.md (42 additions, 0 deletions)

@@ -763,6 +763,8 @@ curl http://localhost:8080/v1/chat/completions \

 ### POST `/v1/embeddings`: OpenAI-compatible embeddings API
 
+This endpoint requires that the model uses a pooling type other than `none`. The embeddings are normalized using the Euclidean norm.
+
 *Options:*
 
 See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings).
@@ -795,6 +797,46 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r
 }'
 ```
 
+### POST `/embeddings`: non-OpenAI-compatible embeddings API
+
+This endpoint supports all pooling types, including `--pooling none`. When the pooling type is `none`, the response contains the *unnormalized* embeddings for *all* input tokens. For all other pooling types, only the pooled embedding is returned, normalized using the Euclidean norm.
+
+Note that the response format of this endpoint is different from `/v1/embeddings`.
+
+*Options:*
+
+Same as the `/v1/embeddings` endpoint.
+
+*Examples:*
+
+Same as the `/v1/embeddings` endpoint.
+
+**Response format**
+
+```json
+[
+  {
+    "index": 0,
+    "embedding": [
+      [ ... embeddings for token 0 ... ],
+      [ ... embeddings for token 1 ... ],
+      [ ... ],
+      [ ... embeddings for token N-1 ... ]
+    ]
+  },
+  ...
+  {
+    "index": P,
+    "embedding": [
+      [ ... embeddings for token 0 ... ],
+      [ ... embeddings for token 1 ... ],
+      [ ... ],
+      [ ... embeddings for token N-1 ... ]
+    ]
+  }
+]
+```
+
 ### GET `/slots`: Returns the current slots processing state
 
 > [!WARNING]
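To try the new `/embeddings` behavior end to end, a request along these lines can be used (an illustrative sketch, not part of the diff: it assumes a local `llama-server` built from this branch, started with `--embedding --pooling none` on the default host and port, and `jq` installed):

```sh
# count the per-token embedding vectors returned for one prompt
curl -s http://localhost:8080/embeddings \
    -H "Content-Type: application/json" \
    -d '{"content": "Hello world"}' | jq '.[0].embedding | length'
```

With `--pooling none`, the count should equal the number of tokens in the tokenized prompt rather than 1, and the inner vectors are unnormalized, matching the response format documented above.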
examples/server/server.cpp (57 additions, 18 deletions)

@@ -704,6 +704,7 @@ struct server_task_result_cmpl_partial : server_task_result {
{"delta",
json {
{"content", content},
{"tokens", tokens}
}},
}});
}
@@ -726,18 +727,32 @@
 
 struct server_task_result_embd : server_task_result {
     int index = 0;
-    std::vector<float> embedding;
+    std::vector<std::vector<float>> embedding;
 
     int32_t n_tokens;
 
+    // OAI-compat fields
+    bool oaicompat = false;
+
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
+        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+    }
+
+    json to_json_non_oaicompat() {
+        return json {
+            {"index", index},
+            {"embedding", embedding},
+        };
+    }
+
+    json to_json_oaicompat() {
         return json {
             {"index", index},
-            {"embedding", embedding},
+            {"embedding", embedding[0]},
             {"tokens_evaluated", n_tokens},
         };
     }
@@ -2017,9 +2032,10 @@ struct server_context {
 
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
         auto res = std::make_unique<server_task_result_embd>();
-        res->id       = slot.id_task;
-        res->index    = slot.index;
-        res->n_tokens = slot.n_prompt_tokens;
+        res->id        = slot.id_task;
+        res->index     = slot.index;
+        res->n_tokens  = slot.n_prompt_tokens;
+        res->oaicompat = slot.params.oaicompat;
 
         const int n_embd = llama_n_embd(model);
 
@@ -2038,12 +2054,18 @@
             if (embd == NULL) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
 
-                res->embedding = std::vector<float>(n_embd, 0.0f);
+                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
             }
 
-            common_embd_normalize(embd, embd_res.data(), n_embd);
-            res->embedding = embd_res;
+            // normalize only when there is pooling
+            // TODO: configurable
+            if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+                common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                res->embedding.push_back(embd_res);
+            } else {
+                res->embedding.push_back({ embd, embd + n_embd });
+            }
         }
 
         SLT_DBG(slot, "%s", "sending embeddings\n");
@@ -2657,7 +2679,10 @@ struct server_context {
 
                 // add prompt tokens for processing in the current batch
                 while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                    common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
+                    // without pooling, we want to output the embeddings for all the tokens in the batch
+                    const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+
+                    common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
 
                     if (slot.params.cache_prompt) {
                         slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3665,14 +3690,17 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
         const json body = json::parse(req.body);
-        bool oaicompat = false;
+
+        if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+            res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
 
+        // for the shape of input/content, see tokenize_input_prompts()
         json prompt;
-        if (body.contains("input")) {
-            oaicompat = true;
+        if (body.count("input") != 0) {
             prompt = body.at("input");
         } else if (body.contains("content")) {
-            oaicompat = false;
@@ -3697,10 +3725,15 @@
         {
             std::vector<server_task> tasks;
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
 
                task.id            = ctx_server.queue_tasks.get_new_id();
                task.index         = i;
                task.prompt_tokens = std::move(tokenized_prompts[i]);
+
+               // OAI-compat
+               task.params.oaicompat = oaicompat;
+
                tasks.push_back(task);
            }

@@ -3728,12 +3761,18 @@
         }
 
         // write JSON response
-        json root = oaicompat
-            ? format_embeddings_response_oaicompat(body, responses)
-            : responses.size() == 1 ? responses[0] : json(responses);
+        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
         res_ok(res, root);
     };
 
+    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, false);
+    };
+
+    const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, true);
+    };
+
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
         if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
@@ -3907,7 +3946,7 @@ int main(int argc, char ** argv) {
     svr->Post("/infill",              handle_infill);
     svr->Post("/embedding",           handle_embeddings); // legacy
     svr->Post("/embeddings",          handle_embeddings);
-    svr->Post("/v1/embeddings",       handle_embeddings);
+    svr->Post("/v1/embeddings",       handle_embeddings_oai);
     svr->Post("/rerank",              handle_rerank);
     svr->Post("/reranking",           handle_rerank);
     svr->Post("/v1/rerank",           handle_rerank);
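Given the guard added in `handle_embeddings_impl`, the OAI-compatible route should now reject requests when the server runs with `--pooling none`. A quick check (again an illustrative sketch, same assumed server setup as above):

```sh
# the OAI-compatible route refuses pooling type 'none'
curl -s http://localhost:8080/v1/embeddings \
    -H "Content-Type: application/json" \
    -d '{"input": "Hello world"}'
# expected: an invalid_request error containing
# "Pooling type 'none' is not OAI compatible. Please use a different pooling type"
```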