From 51dd27f790778ee33ed7401bdca37e7d816b96e5 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 24 Dec 2024 17:00:17 +0100 Subject: [PATCH] improve test --- examples/server/tests/unit/test_embedding.py | 28 +++++++++++++------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py index 98bec2c2f453f..8b0eb42b0926f 100644 --- a/examples/server/tests/unit/test_embedding.py +++ b/examples/server/tests/unit/test_embedding.py @@ -202,6 +202,14 @@ def test_embedding_openai_library_base64(): server.start() test_input = "Test base64 embedding output" + # get embedding in default format + res = server.make_request("POST", "/v1/embeddings", data={ + "input": test_input + }) + assert res.status_code == 200 + vec0 = res.body["data"][0]["embedding"] + + # get embedding in base64 format res = server.make_request("POST", "/v1/embeddings", data={ "input": test_input, "encoding_format": "base64" @@ -216,12 +224,14 @@ def test_embedding_openai_library_base64(): assert isinstance(embedding_data["embedding"], str) # Verify embedding is valid base64 - try: - decoded = base64.b64decode(embedding_data["embedding"]) - # Verify decoded data can be converted back to float array - float_count = len(decoded) // 4 # 4 bytes per float - floats = struct.unpack(f'{float_count}f', decoded) - assert len(floats) > 0 - assert all(isinstance(x, float) for x in floats) - except Exception as e: - pytest.fail(f"Invalid base64 format: {str(e)}") + decoded = base64.b64decode(embedding_data["embedding"]) + # Verify decoded data can be converted back to float array + float_count = len(decoded) // 4 # 4 bytes per float + floats = struct.unpack(f'{float_count}f', decoded) + assert len(floats) > 0 + assert all(isinstance(x, float) for x in floats) + assert len(floats) == len(vec0) + + # make sure the decoded data is the same as the original + for x, y in zip(floats, vec0): + assert abs(x - y) < EPSILON