diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 958860bfb8a47..de1382a141b09 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -744,16 +744,16 @@ struct server_task_result_embd : server_task_result { json to_json_non_oaicompat() { return json { - {"index", index}, - {"embedding", embedding}, - {"tokens_evaluated", n_tokens}, + {"index", index}, + {"embedding", embedding}, }; } json to_json_oaicompat() { return json { - {"index", index}, - {"embedding", embedding[0]}, + {"index", index}, + {"embedding", embedding[0]}, + {"tokens_evaluated", n_tokens}, }; } }; diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py index b5348120a74c6..e32d745829605 100644 --- a/examples/server/tests/unit/test_embedding.py +++ b/examples/server/tests/unit/test_embedding.py @@ -48,7 +48,7 @@ def test_embedding_multiple(): @pytest.mark.parametrize( - "content,is_multi_prompt", + "input,is_multi_prompt", [ # single prompt ("string", False), @@ -61,19 +61,20 @@ def test_embedding_multiple(): ([[12, 34, 56], [12, "string", 34, 56]], True), ] ) -def test_embedding_mixed_input(content, is_multi_prompt: bool): +def test_embedding_mixed_input(input, is_multi_prompt: bool): global server server.start() - res = server.make_request("POST", "/embeddings", data={"content": content}) + res = server.make_request("POST", "/v1/embeddings", data={"input": input}) assert res.status_code == 200 + data = res.body['data'] if is_multi_prompt: - assert len(res.body) == len(content) - for d in res.body: + assert len(data) == len(input) + for d in data: assert 'embedding' in d assert len(d['embedding']) > 1 else: - assert 'embedding' in res.body - assert len(res.body['embedding']) > 1 + assert 'embedding' in data[0] + assert len(data[0]['embedding']) > 1 def test_embedding_pooling_none(): @@ -85,7 +86,7 @@ def test_embedding_pooling_none(): }) assert res.status_code == 200 assert 'embedding' in res.body[0] - assert len(res.body[0]['embedding']) == 3 + assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special # make sure embedding vector is not normalized for x in res.body[0]['embedding']: @@ -172,7 +173,7 @@ def test_same_prompt_give_same_result(): def test_embedding_usage_single(content, n_tokens): global server server.start() - res = server.make_request("POST", "/embeddings", data={"input": content}) + res = server.make_request("POST", "/v1/embeddings", data={"input": content}) assert res.status_code == 200 assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] assert res.body['usage']['prompt_tokens'] == n_tokens @@ -181,7 +182,7 @@ def test_embedding_usage_single(content, n_tokens): def test_embedding_usage_multiple(): global server server.start() - res = server.make_request("POST", "/embeddings", data={ + res = server.make_request("POST", "/v1/embeddings", data={ "input": [ "I believe the meaning of life is", "I believe the meaning of life is",