From 1fccfc9eb61a0eefeb9e418936ece7ae62418c9b Mon Sep 17 00:00:00 2001
From: Emreerdog
Date: Wed, 25 Dec 2024 14:13:50 +0300
Subject: [PATCH] Remove unnecessary iteration over batch n_tokens when
 generating sequence embeddings.

---
 examples/embedding/embedding.cpp | 39 ++++++++++++++++++++---------------
 1 file changed, 24 insertions(+), 15 deletions(-)

diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 3f18fc6a70878..13bc670bbe9ec 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -53,28 +53,37 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }
     }
 
-    for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
-            continue;
-        }
+    const float* embd = nullptr;
+    int embd_pos = 0;
 
-        const float * embd = nullptr;
-        int embd_pos = 0;
+    if(pooling_type == LLAMA_POOLING_TYPE_NONE)
+    {
+        for (int i = 0; i < batch.n_tokens; i++)
+        {
+            if (!batch.logits[i]) {
+                continue;
+            }
 
-        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-            // try to get token embeddings
             embd = llama_get_embeddings_ith(ctx, i);
             embd_pos = i;
             GGML_ASSERT(embd != NULL && "failed to get token embeddings");
-        } else {
-            // try to get sequence embeddings - supported only when pooling_type is not NONE
-            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-            embd_pos = batch.seq_id[i][0];
-            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+
+            float * out = output + embd_pos * n_embd;
+            common_embd_normalize(embd, out, n_embd, embd_norm);
         }
+    }
 
-        float * out = output + embd_pos * n_embd;
-        common_embd_normalize(embd, out, n_embd, embd_norm);
+    else
+    {
+        for(int i = 0; i < n_seq; i++)
+        {
+            embd = llama_get_embeddings_seq(ctx, i);
+            embd_pos = i;
+            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+
+            float * out = output + embd_pos * n_embd;
+            common_embd_normalize(embd, out, n_embd, embd_norm);
+        }
     }
 }
 