Skip to content

Commit

Permalink
server : do not normalize embeddings when there is no pooling
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Dec 17, 2024
1 parent c63d869 commit 400a5a1
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 6 deletions.
4 changes: 3 additions & 1 deletion common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1780,7 +1780,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm)
break;
case 0: // max absolute
for (int i = 0; i < n; i++) {
if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
if (sum < std::abs(inp[i])) {
sum = std::abs(inp[i]);
}
}
sum /= 32760.0; // make an int16 range
break;
Expand Down
3 changes: 2 additions & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
// Embedding utils
//

void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
// TODO: replace embd_norm with an enum
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);

float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);

Expand Down
2 changes: 1 addition & 1 deletion examples/gritlm/gritlm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
}

std::vector<float> emb_norm(emb_unorm.size());
common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2);
result.push_back(emb_norm);

#ifdef GRIT_DEBUG
Expand Down
2 changes: 1 addition & 1 deletion examples/retrieval/retrieval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}

float * out = output + batch.seq_id[i][0] * n_embd;
common_embd_normalize(embd, out, n_embd);
common_embd_normalize(embd, out, n_embd, 2);
}
}

Expand Down
10 changes: 8 additions & 2 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2049,8 +2049,14 @@ struct server_context {
continue;
}

common_embd_normalize(embd, embd_res.data(), n_embd);
res->embedding.push_back(embd_res);
// normalize only when there is pooling
// TODO: configurable
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
res->embedding.push_back(embd_res);
} else {
res->embedding.push_back({ embd, embd + n_embd });
}
}

SLT_DBG(slot, "%s", "sending embeddings\n");
Expand Down
5 changes: 5 additions & 0 deletions examples/server/tests/unit/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def test_embedding_pooling_none():
assert 'embedding' in res.body[0]
assert len(res.body[0]['embedding']) == 3

# make sure embedding vector is not normalized
for x in res.body[0]['embedding']:
assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON


def test_embedding_pooling_none_oai():
global server
Expand All @@ -66,6 +70,7 @@ def test_embedding_pooling_none_oai():
res = server.make_request("POST", "/v1/embeddings", data={
"input": "hello hello hello",
})

# /v1/embeddings does not support pooling type 'none'
assert res.status_code == 400

Expand Down

0 comments on commit 400a5a1

Please sign in to comment.