
Commit

Removed unnecessary iteration of batch n_tokens on sequence embeddings generation.
Emreerdog committed Dec 25, 2024
1 parent 9ba399d commit 1fccfc9
Showing 1 changed file with 24 additions and 15 deletions.
39 changes: 24 additions & 15 deletions examples/embedding/embedding.cpp
@@ -53,28 +53,37 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }
     }
 
-    for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
-            continue;
-        }
+    const float* embd = nullptr;
+    int embd_pos = 0;
 
-        const float * embd = nullptr;
-        int embd_pos = 0;
+    if(pooling_type == LLAMA_POOLING_TYPE_NONE)
+    {
+        for (int i = 0; i < batch.n_tokens; i++)
+        {
+            if (!batch.logits[i]) {
+                continue;
+            }
 
-        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
             // try to get token embeddings
             embd = llama_get_embeddings_ith(ctx, i);
             embd_pos = i;
             GGML_ASSERT(embd != NULL && "failed to get token embeddings");
-        } else {
-            // try to get sequence embeddings - supported only when pooling_type is not NONE
-            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-            embd_pos = batch.seq_id[i][0];
-            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+
+            float * out = output + embd_pos * n_embd;
+            common_embd_normalize(embd, out, n_embd, embd_norm);
         }
-
-        float * out = output + embd_pos * n_embd;
-        common_embd_normalize(embd, out, n_embd, embd_norm);
+    }
+    else
+    {
+        for(int i = 0; i < n_seq; i++)
+        {
+            embd = llama_get_embeddings_seq(ctx, i);
+            embd_pos = i;
+            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+
+            float * out = output + embd_pos * n_embd;
+            common_embd_normalize(embd, out, n_embd, embd_norm);
+        }
     }
 }
 
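
For context, here is a minimal standalone sketch of the idea behind this change, assuming decoding has already run on a context whose pooling type is not LLAMA_POOLING_TYPE_NONE and that output has room for n_seq * n_embd floats. In that mode llama.cpp keeps one pooled embedding per sequence, so a single llama_get_embeddings_seq call per sequence id is enough; walking every token in the batch only re-fetches the same vectors. The helper name extract_seq_embeddings is illustrative, not part of the commit.

#include "common.h"  // common_embd_normalize
#include "llama.h"

// Illustrative helper (an assumption for this sketch, not part of the commit):
// copy one pooled embedding per sequence into a flat buffer of n_seq * n_embd
// floats. Assumes llama_decode()/llama_encode() has already been called on ctx.
static void extract_seq_embeddings(llama_context * ctx, float * output, int n_seq, int n_embd, int embd_norm) {
    GGML_ASSERT(llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE);

    for (int s = 0; s < n_seq; s++) {
        // one lookup per sequence id - no per-token iteration needed
        const float * embd = llama_get_embeddings_seq(ctx, s);
        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");

        common_embd_normalize(embd, output + s * n_embd, n_embd, embd_norm);
    }
}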

