Skip to content

Commit

Permalink
server : add support for "encoding_format": "base64" to the */embeddi…
Browse files Browse the repository at this point in the history
…ngs endpoints (#10967)

* add support for base64

* fix base64 test

* improve test

---------

Co-authored-by: Xuan Son Nguyen <[email protected]>
  • Loading branch information
elk-cloner and ngxson authored Dec 24, 2024
1 parent 2cd43f4 commit 9ba399d
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 7 deletions.
1 change: 1 addition & 0 deletions examples/server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ endforeach()
add_executable(${TARGET} ${TARGET_SRCS})
install(TARGETS ${TARGET} RUNTIME)

target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})

if (LLAMA_SERVER_SSL)
Expand Down
13 changes: 12 additions & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3790,6 +3790,17 @@ int main(int argc, char ** argv) {
return;
}

bool use_base64 = false;
if (body.count("encoding_format") != 0) {
const std::string& format = body.at("encoding_format");
if (format == "base64") {
use_base64 = true;
} else if (format != "float") {
res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
return;
}
}

std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
for (const auto & tokens : tokenized_prompts) {
// this check is necessary for models that do not add BOS token to the input
Expand Down Expand Up @@ -3841,7 +3852,7 @@ int main(int argc, char ** argv) {
}

// write JSON response
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
res_ok(res, root);
};

Expand Down
41 changes: 41 additions & 0 deletions examples/server/tests/unit/test_embedding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64
import struct
import pytest
from openai import OpenAI
from utils import *
Expand Down Expand Up @@ -194,3 +196,42 @@ def test_embedding_usage_multiple():
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == 2 * 9


def test_embedding_openai_library_base64():
server.start()
test_input = "Test base64 embedding output"

# get embedding in default format
res = server.make_request("POST", "/v1/embeddings", data={
"input": test_input
})
assert res.status_code == 200
vec0 = res.body["data"][0]["embedding"]

# get embedding in base64 format
res = server.make_request("POST", "/v1/embeddings", data={
"input": test_input,
"encoding_format": "base64"
})

assert res.status_code == 200
assert "data" in res.body
assert len(res.body["data"]) == 1

embedding_data = res.body["data"][0]
assert "embedding" in embedding_data
assert isinstance(embedding_data["embedding"], str)

# Verify embedding is valid base64
decoded = base64.b64decode(embedding_data["embedding"])
# Verify decoded data can be converted back to float array
float_count = len(decoded) // 4 # 4 bytes per float
floats = struct.unpack(f'{float_count}f', decoded)
assert len(floats) > 0
assert all(isinstance(x, float) for x in floats)
assert len(floats) == len(vec0)

# make sure the decoded data is the same as the original
for x, y in zip(floats, vec0):
assert abs(x - y) < EPSILON
28 changes: 22 additions & 6 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "common/base64.hpp"

#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
Expand Down Expand Up @@ -613,16 +614,31 @@ static json oaicompat_completion_params_parse(
return llama_params;
}

static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) {
json data = json::array();
int32_t n_tokens = 0;
int i = 0;
for (const auto & elem : embeddings) {
data.push_back(json{
{"embedding", json_value(elem, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
});
json embedding_obj;

if (use_base64) {
const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
const char* data_ptr = reinterpret_cast<const char*>(vec.data());
size_t data_size = vec.size() * sizeof(float);
embedding_obj = {
{"embedding", base64::encode(data_ptr, data_size)},
{"index", i++},
{"object", "embedding"},
{"encoding_format", "base64"}
};
} else {
embedding_obj = {
{"embedding", json_value(elem, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
};
}
data.push_back(embedding_obj);

n_tokens += json_value(elem, "tokens_evaluated", 0);
}
Expand Down

0 comments on commit 9ba399d

Please sign in to comment.