server : update /embeddings and /v1/embeddings endpoints
ggml-ci
ggerganov committed Dec 17, 2024
1 parent b197797 commit c63d869
Showing 2 changed files with 57 additions and 26 deletions.
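
In short: /embeddings now always returns the server's native format, while /v1/embeddings is strictly OAI-compatible and rejects pooling type 'none'. A rough sketch of the resulting behavior, assuming a local llama-server on http://localhost:8080 (an assumption; response shapes inferred from the tests in this commit):

    import requests  # illustrative client; any HTTP client works

    base = "http://localhost:8080"

    # native endpoint: plain JSON array, one {index, embedding} object per
    # input; "embedding" holds a list of rows (one per token with pooling 'none')
    res = requests.post(f"{base}/embeddings", json={"content": "hello"})
    print(res.json()[0]["index"])

    # OAI-compatible endpoint: {"data": [...]} envelope with one flat vector
    # per input; returns HTTP 400 when the server pools with 'none'
    res = requests.post(f"{base}/v1/embeddings",
                        json={"input": "I believe the meaning of life is"})
    print(len(res.json()["data"][0]["embedding"]))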
55 changes: 38 additions & 17 deletions examples/server/server.cpp
@@ -729,24 +729,30 @@ struct server_task_result_embd : server_task_result {
     int index = 0;
     std::vector<std::vector<float>> embedding;
 
+    // OAI-compat fields
+    bool oaicompat = false;
+
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
-        if (embedding.size() == 1) {
-            // to be OAI compatible
-            return json {
-                {"index", index},
-                {"embedding", embedding[0]},
-            };
-        }
+        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+    }
+
+    json to_json_non_oaicompat() {
         return json {
             {"index", index},
             {"embedding", embedding},
         };
     }
+
+    json to_json_oaicompat() {
+        return json {
+            {"index", index},
+            {"embedding", embedding[0]},
+        };
+    }
 };
 
 struct server_task_result_rerank : server_task_result {
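
For reference, a sketch of the two result shapes the struct above can now emit (field layout taken from the code; values illustrative):

    # to_json_non_oaicompat(): "embedding" keeps all rows, e.g. one row per
    # token when the server runs with pooling type 'none'
    non_oai = {"index": 0, "embedding": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]}

    # to_json_oaicompat(): only the first row is returned, since OAI clients
    # expect a single flat vector per input
    oai = {"index": 0, "embedding": [0.1, 0.2]}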
@@ -2018,8 +2024,9 @@ struct server_context {
 
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
         auto res = std::make_unique<server_task_result_embd>();
-        res->id = slot.id_task;
-        res->index = slot.index;
+        res->id        = slot.id_task;
+        res->index     = slot.index;
+        res->oaicompat = slot.params.oaicompat;
 
         const int n_embd = llama_n_embd(model);
 
@@ -3667,14 +3674,17 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
         const json body = json::parse(req.body);
-        bool oaicompat = false;
+
+        if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+            res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
 
         // an input prompt can be a string or a list of tokens (integer)
         json prompt;
         if (body.count("input") != 0) {
-            oaicompat = true;
             prompt = body.at("input");
         } else if (body.count("content") != 0) {
             // with "content", we only support single prompt
@@ -3691,10 +3701,15 @@ int main(int argc, char ** argv) {
         std::vector<server_task> tasks;
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
             server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+
             task.id = ctx_server.queue_tasks.get_new_id();
             task.index = i;
             task.prompt_tokens = std::move(tokenized_prompts[i]);
+
+            // OAI-compat
+            task.params.oaicompat = oaicompat;
+
             tasks.push_back(task);
         }

@@ -3722,12 +3737,18 @@ int main(int argc, char ** argv) {
         }
 
         // write JSON response
-        json root = oaicompat
-            ? format_embeddings_response_oaicompat(body, responses)
-            : responses.size() == 1 ? responses[0] : json(responses);
+        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
         res_ok(res, root);
     };
 
+    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, false);
+    };
+
+    const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, true);
+    };
+
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
         if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
@@ -3901,7 +3922,7 @@ int main(int argc, char ** argv) {
     svr->Post("/infill", handle_infill);
     svr->Post("/embedding", handle_embeddings); // legacy
     svr->Post("/embeddings", handle_embeddings);
-    svr->Post("/v1/embeddings", handle_embeddings);
+    svr->Post("/v1/embeddings", handle_embeddings_oai);
     svr->Post("/rerank", handle_rerank);
    svr->Post("/reranking", handle_rerank);
     svr->Post("/v1/rerank", handle_rerank);
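As the test updates below show, OpenAI-client users should now point base_url at the /v1 prefix so the client's embeddings calls reach the OAI-compatible handler. A minimal sketch mirroring the tests (assumes the openai Python package and a local server; host and port are placeholders):

    from openai import OpenAI

    # base_url must include /v1 after this change; previously the bare host worked
    client = OpenAI(api_key="dummy", base_url="http://localhost:8080/v1")
    res = client.embeddings.create(model="text-embedding-3-small", input="hello")
    print(len(res.data[0].embedding))  # dimension of the pooled embedding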
28 changes: 19 additions & 9 deletions examples/server/tests/unit/test_embedding.py
@@ -16,7 +16,7 @@ def test_embedding_single():
     global server
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": "I believe the meaning of life is",
     })
     assert res.status_code == 200
@@ -32,7 +32,7 @@ def test_embedding_multiple():
     global server
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": [
             "I believe the meaning of life is",
             "Write a joke about AI from a very long prompt which will not be truncated",
@@ -55,16 +55,26 @@ def test_embedding_pooling_none():
         "input": "hello hello hello",
     })
     assert res.status_code == 200
-    assert len(res.body['data']) == 1
-    assert 'embedding' in res.body['data'][0]
-    assert len(res.body['data'][0]['embedding']) == 3
+    assert 'embedding' in res.body[0]
+    assert len(res.body[0]['embedding']) == 3
 
 
+def test_embedding_pooling_none_oai():
+    global server
+    server.pooling = 'none'
+    server.start()
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": "hello hello hello",
+    })
+    # /v1/embeddings does not support pooling type 'none'
+    assert res.status_code == 400
+
+
 def test_embedding_openai_library_single():
     global server
     server.pooling = 'last'
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
     assert len(res.data) == 1
     assert len(res.data[0].embedding) > 1
@@ -74,7 +84,7 @@ def test_embedding_openai_library_multiple():
     global server
     server.pooling = 'last'
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.embeddings.create(model="text-embedding-3-small", input=[
         "I believe the meaning of life is",
         "Write a joke about AI from a very long prompt which will not be truncated",
@@ -90,7 +100,7 @@ def test_embedding_error_prompt_too_long():
     global server
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": "This is a test " * 512,
     })
     assert res.status_code != 200
@@ -100,7 +110,7 @@ def test_embedding_error_prompt_too_long():
 def test_same_prompt_give_same_result():
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": [
             "I believe the meaning of life is",
             "I believe the meaning of life is",
