Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server : add support for "encoding_format": "base64" to the */embeddings endpoints #10967

Merged
merged 3 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ endforeach()
add_executable(${TARGET} ${TARGET_SRCS})
install(TARGETS ${TARGET} RUNTIME)

target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})

if (LLAMA_SERVER_SSL)
Expand Down
13 changes: 12 additions & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3786,6 +3786,17 @@ int main(int argc, char ** argv) {
return;
}

bool use_base64 = false;
if (body.count("encoding_format") != 0) {
const std::string& format = body.at("encoding_format");
if (format == "base64") {
use_base64 = true;
} else if (format != "float") {
res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
return;
}
}

std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
for (const auto & tokens : tokenized_prompts) {
// this check is necessary for models that do not add BOS token to the input
Expand Down Expand Up @@ -3837,7 +3848,7 @@ int main(int argc, char ** argv) {
}

// write JSON response
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
res_ok(res, root);
};

Expand Down
31 changes: 31 additions & 0 deletions examples/server/tests/unit/test_embedding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64
import struct
import pytest
from openai import OpenAI
from utils import *
Expand Down Expand Up @@ -194,3 +196,32 @@ def test_embedding_usage_multiple():
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == 2 * 9


def test_embedding_openai_library_base64():
server.start()
test_input = "Test base64 embedding output"

res = server.make_request("POST", "/v1/embeddings", data={
"input": test_input,
"encoding_format": "base64"
})

assert res.status_code == 200
assert "data" in res.body
assert len(res.body["data"]) == 1

embedding_data = res.body["data"][0]
assert "embedding" in embedding_data
assert isinstance(embedding_data["embedding"], str)

# Verify embedding is valid base64
try:
decoded = base64.b64decode(embedding_data["embedding"])
# Verify decoded data can be converted back to float array
float_count = len(decoded) // 4 # 4 bytes per float
floats = struct.unpack(f'{float_count}f', decoded)
assert len(floats) > 0
assert all(isinstance(x, float) for x in floats)
except Exception as e:
pytest.fail(f"Invalid base64 format: {str(e)}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bad practice to use try..catch inside a test, because if one of the assert fails, it won't let you know it happens at exactly which line of code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ngxson, thanks for the hint. my bad, I forgot to remove it.

28 changes: 22 additions & 6 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "common/base64.hpp"

#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
Expand Down Expand Up @@ -591,16 +592,31 @@ static json oaicompat_completion_params_parse(
return llama_params;
}

static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) {
json data = json::array();
int32_t n_tokens = 0;
int i = 0;
for (const auto & elem : embeddings) {
data.push_back(json{
{"embedding", json_value(elem, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
});
json embedding_obj;

if (use_base64) {
const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
const char* data_ptr = reinterpret_cast<const char*>(vec.data());
size_t data_size = vec.size() * sizeof(float);
embedding_obj = {
{"embedding", base64::encode(data_ptr, data_size)},
{"index", i++},
{"object", "embedding"},
{"encoding_format", "base64"}
};
} else {
embedding_obj = {
{"embedding", json_value(elem, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
};
}
data.push_back(embedding_obj);

n_tokens += json_value(elem, "tokens_evaluated", 0);
}
Expand Down
Loading