From 400a5a15e3d0df4bd132d8f89802f32eb6214fbd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 17 Dec 2024 13:36:32 +0200 Subject: [PATCH] server : do not normalize embeddings when there is no pooling ggml-ci --- common/common.cpp | 4 +++- common/common.h | 3 ++- examples/gritlm/gritlm.cpp | 2 +- examples/retrieval/retrieval.cpp | 2 +- examples/server/server.cpp | 10 ++++++++-- examples/server/tests/unit/test_embedding.py | 5 +++++ 6 files changed, 20 insertions(+), 6 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index c0c98232ed3bb..05d3ba766e38b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1780,7 +1780,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) break; case 0: // max absolute for (int i = 0; i < n; i++) { - if (sum < std::abs(inp[i])) sum = std::abs(inp[i]); + if (sum < std::abs(inp[i])) { + sum = std::abs(inp[i]); + } } sum /= 32760.0; // make an int16 range break; diff --git a/common/common.h b/common/common.h index 5f556c24d933c..ec0e49f6f1806 100644 --- a/common/common.h +++ b/common/common.h @@ -596,7 +596,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si // Embedding utils // -void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2); +// TODO: replace embd_norm with an enum +void common_embd_normalize(const float * inp, float * out, int n, int embd_norm); float common_embd_similarity_cos(const float * embd1, const float * embd2, int n); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 6e42fa0734ecb..18a945b33905f 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -75,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve } std::vector<float> emb_norm(emb_unorm.size()); - common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd); + common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2); result.push_back(emb_norm); #ifdef GRIT_DEBUG 
diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 23ff4db27a420..a5c6fe7e58523 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } float * out = output + batch.seq_id[i][0] * n_embd; - common_embd_normalize(embd, out, n_embd); + common_embd_normalize(embd, out, n_embd, 2); } } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e4947c4576d9d..00c1639c497e0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2049,8 +2049,14 @@ struct server_context { continue; } - common_embd_normalize(embd, embd_res.data(), n_embd); - res->embedding.push_back(embd_res); + // normalize only when there is pooling + // TODO: configurable + if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + common_embd_normalize(embd, embd_res.data(), n_embd, 2); + res->embedding.push_back(embd_res); + } else { + res->embedding.push_back({ embd, embd + n_embd }); + } } SLT_DBG(slot, "%s", "sending embeddings\n"); diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py index f702565a110dc..e3d380de21f72 100644 --- a/examples/server/tests/unit/test_embedding.py +++ b/examples/server/tests/unit/test_embedding.py @@ -58,6 +58,10 @@ def test_embedding_pooling_none(): assert 'embedding' in res.body[0] assert len(res.body[0]['embedding']) == 3 + # make sure embedding vector is not normalized + for x in res.body[0]['embedding']: + assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON + def test_embedding_pooling_none_oai(): global server @@ -66,6 +70,7 @@ def test_embedding_pooling_none_oai(): res = server.make_request("POST", "/v1/embeddings", data={ "input": "hello hello hello", }) + # /v1/embeddings does not support pooling type 'none' assert res.status_code == 400