diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 186fc7cd040a4b..e4947c4576d9d7 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -729,24 +729,30 @@ struct server_task_result_embd : server_task_result {
     int index = 0;
     std::vector<std::vector<float>> embedding;
 
+    // OAI-compat fields
+    bool oaicompat = false;
+
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
-        if (embedding.size() == 1) {
-            // to be OAI compatible
-            return json {
-                {"index",     index},
-                {"embedding", embedding[0]},
-            };
-        }
+        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+    }
 
+    json to_json_non_oaicompat() {
         return json {
             {"index",     index},
             {"embedding", embedding},
         };
     }
+
+    json to_json_oaicompat() {
+        return json {
+            {"index",     index},
+            {"embedding", embedding[0]},
+        };
+    }
 };
 
 struct server_task_result_rerank : server_task_result {
@@ -2018,8 +2024,9 @@ struct server_context {
 
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
         auto res = std::make_unique<server_task_result_embd>();
-        res->id    = slot.id_task;
-        res->index = slot.index;
+        res->id        = slot.id_task;
+        res->index     = slot.index;
+        res->oaicompat = slot.params.oaicompat;
 
         const int n_embd = llama_n_embd(model);
 
@@ -3667,14 +3674,17 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
         const json body = json::parse(req.body);
-        bool oaicompat = false;
+
+        if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+            res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
 
         // an input prompt can be a string or a list of tokens (integer)
         json prompt;
         if (body.count("input") != 0) {
-            oaicompat = true;
             prompt = body.at("input");
         } else if (body.count("content") != 0) {
             // with "content", we only support single prompt
@@ -3691,10 +3701,15 @@ int main(int argc, char ** argv) {
         std::vector<server_task> tasks;
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-            server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+            server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+
             task.id            = ctx_server.queue_tasks.get_new_id();
             task.index         = i;
             task.prompt_tokens = std::move(tokenized_prompts[i]);
+
+            // OAI-compat
+            task.params.oaicompat = oaicompat;
+
             tasks.push_back(task);
         }
 
@@ -3722,12 +3737,18 @@ int main(int argc, char ** argv) {
         }
 
         // write JSON response
-        json root = oaicompat
-            ? format_embeddings_response_oaicompat(body, responses)
-            : responses.size() == 1 ? responses[0] : json(responses);
+        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
         res_ok(res, root);
     };
 
+    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, false);
+    };
+
+    const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, true);
+    };
+
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
         if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
@@ -3901,7 +3922,7 @@ int main(int argc, char ** argv) {
     svr->Post("/infill",        handle_infill);
     svr->Post("/embedding",     handle_embeddings); // legacy
     svr->Post("/embeddings",    handle_embeddings);
-    svr->Post("/v1/embeddings", handle_embeddings);
+    svr->Post("/v1/embeddings", handle_embeddings_oai);
     svr->Post("/rerank",        handle_rerank);
     svr->Post("/reranking",     handle_rerank);
     svr->Post("/v1/rerank",     handle_rerank);
diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py
index 371fed637418ac..f702565a110dc3 100644
--- a/examples/server/tests/unit/test_embedding.py
+++ b/examples/server/tests/unit/test_embedding.py
@@ -16,7 +16,7 @@ def test_embedding_single():
     global server
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": "I believe the meaning of life is",
     })
     assert res.status_code == 200
@@ -32,7 +32,7 @@ def test_embedding_multiple():
     global server
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": [
             "I believe the meaning of life is",
             "Write a joke about AI from a very long prompt which will not be truncated",
@@ -55,16 +55,26 @@ def test_embedding_pooling_none():
         "input": "hello hello hello",
     })
     assert res.status_code == 200
-    assert len(res.body['data']) == 1
-    assert 'embedding' in res.body['data'][0]
-    assert len(res.body['data'][0]['embedding']) == 3
+    assert 'embedding' in res.body[0]
+    assert len(res.body[0]['embedding']) == 3
+
+
+def test_embedding_pooling_none_oai():
+    global server
+    server.pooling = 'none'
+    server.start()
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": "hello hello hello",
+    })
+    # /v1/embeddings does not support pooling type 'none'
+    assert res.status_code == 400
 
 
 def test_embedding_openai_library_single():
     global server
     server.pooling = 'last'
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
     assert len(res.data) == 1
     assert len(res.data[0].embedding) > 1
@@ -74,7 +84,7 @@ def test_embedding_openai_library_multiple():
     global server
     server.pooling = 'last'
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.embeddings.create(model="text-embedding-3-small", input=[
         "I believe the meaning of life is",
         "Write a joke about AI from a very long prompt which will not be truncated",
@@ -90,7 +100,7 @@ def test_embedding_error_prompt_too_long():
     global server
     server.pooling = 'last'
    server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": "This is a test " * 512,
     })
     assert res.status_code != 200
@@ -100,7 +110,7 @@ def test_same_prompt_give_same_result():
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": [
             "I believe the meaning of life is",
             "I believe the meaning of life is",