Skip to content

Commit

Permalink
tests : update server tests
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Dec 18, 2024
1 parent 87df601 commit 2a5510e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
10 changes: 5 additions & 5 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,16 +744,16 @@ struct server_task_result_embd : server_task_result {

json to_json_non_oaicompat() {
return json {
{"index", index},
{"embedding", embedding},
{"tokens_evaluated", n_tokens},
{"index", index},
{"embedding", embedding},
};
}

json to_json_oaicompat() {
return json {
{"index", index},
{"embedding", embedding[0]},
{"index", index},
{"embedding", embedding[0]},
{"tokens_evaluated", n_tokens},
};
}
};
Expand Down
21 changes: 11 additions & 10 deletions examples/server/tests/unit/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_embedding_multiple():


@pytest.mark.parametrize(
"content,is_multi_prompt",
"input,is_multi_prompt",
[
# single prompt
("string", False),
Expand All @@ -61,19 +61,20 @@ def test_embedding_multiple():
([[12, 34, 56], [12, "string", 34, 56]], True),
]
)
def test_embedding_mixed_input(content, is_multi_prompt: bool):
def test_embedding_mixed_input(input, is_multi_prompt: bool):
global server
server.start()
res = server.make_request("POST", "/embeddings", data={"content": content})
res = server.make_request("POST", "/v1/embeddings", data={"input": input})
assert res.status_code == 200
data = res.body['data']
if is_multi_prompt:
assert len(res.body) == len(content)
for d in res.body:
assert len(data) == len(input)
for d in data:
assert 'embedding' in d
assert len(d['embedding']) > 1
else:
assert 'embedding' in res.body
assert len(res.body['embedding']) > 1
assert 'embedding' in data[0]
assert len(data[0]['embedding']) > 1


def test_embedding_pooling_none():
Expand All @@ -85,7 +86,7 @@ def test_embedding_pooling_none():
})
assert res.status_code == 200
assert 'embedding' in res.body[0]
assert len(res.body[0]['embedding']) == 3
assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special

# make sure embedding vector is not normalized
for x in res.body[0]['embedding']:
Expand Down Expand Up @@ -172,7 +173,7 @@ def test_same_prompt_give_same_result():
def test_embedding_usage_single(content, n_tokens):
global server
server.start()
res = server.make_request("POST", "/embeddings", data={"input": content})
res = server.make_request("POST", "/v1/embeddings", data={"input": content})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == n_tokens
Expand All @@ -181,7 +182,7 @@ def test_embedding_usage_single(content, n_tokens):
def test_embedding_usage_multiple():
global server
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",
Expand Down

0 comments on commit 2a5510e

Please sign in to comment.