server : add support for "encoding_format": "base64" to the */embeddings endpoints #10967

Merged · 3 commits · Dec 24, 2024

1 change: 1 addition & 0 deletions examples/server/CMakeLists.txt
@@ -34,6 +34,7 @@ endforeach()
 add_executable(${TARGET} ${TARGET_SRCS})
 install(TARGETS ${TARGET} RUNTIME)
 
+target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
 target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
 
 if (LLAMA_SERVER_SSL)

13 changes: 12 additions & 1 deletion examples/server/server.cpp
@@ -3786,6 +3786,17 @@ int main(int argc, char ** argv) {
             return;
         }
 
+        bool use_base64 = false;
+        if (body.count("encoding_format") != 0) {
+            const std::string & format = body.at("encoding_format");
+            if (format == "base64") {
+                use_base64 = true;
+            } else if (format != "float") {
+                res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        }
+
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
         for (const auto & tokens : tokenized_prompts) {
             // this check is necessary for models that do not add BOS token to the input
@@ -3837,7 +3848,7 @@ int main(int argc, char ** argv) {
         }
 
         // write JSON response
-        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
+        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
         res_ok(res, root);
     };

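Note: to make the new parameter concrete, the following is a minimal client-side sketch of the handler above. The base URL, the use of the "requests" package, and the HTTP 400 status for the error path are assumptions, not part of the diff.

import base64
import struct

import requests  # assumed HTTP client; any other would do

BASE_URL = "http://localhost:8080"  # assumed local llama.cpp server address

# "encoding_format": "base64" switches the embedding payload from a JSON
# float array to a single base64 string, mirroring the OpenAI parameter.
resp = requests.post(BASE_URL + "/v1/embeddings",
                     json={"input": "hello", "encoding_format": "base64"})
resp.raise_for_status()
blob = resp.json()["data"][0]["embedding"]  # a base64 string, not a list

# The string decodes to the raw float32 buffer of the embedding vector.
raw = base64.b64decode(blob)
vec = struct.unpack(f"{len(raw) // 4}f", raw)
print(f"{len(vec)}-dim embedding, first values: {vec[:3]}")

# Any value other than "float" or "base64" takes the else branch above and
# is rejected; ERROR_TYPE_INVALID_REQUEST is assumed to map to HTTP 400.
bad = requests.post(BASE_URL + "/v1/embeddings",
                    json={"input": "hello", "encoding_format": "int8"})
assert bad.status_code == 400
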
41 changes: 41 additions & 0 deletions examples/server/tests/unit/test_embedding.py
@@ -1,3 +1,5 @@
+import base64
+import struct
 import pytest
 from openai import OpenAI
 from utils import *
@@ -194,3 +196,42 @@ def test_embedding_usage_multiple():
     assert res.status_code == 200
     assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
     assert res.body['usage']['prompt_tokens'] == 2 * 9
+
+
+def test_embedding_openai_library_base64():
+    server.start()
+    test_input = "Test base64 embedding output"
+
+    # get embedding in the default (float list) format
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": test_input
+    })
+    assert res.status_code == 200
+    vec0 = res.body["data"][0]["embedding"]
+
+    # get embedding in base64 format
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": test_input,
+        "encoding_format": "base64"
+    })
+
+    assert res.status_code == 200
+    assert "data" in res.body
+    assert len(res.body["data"]) == 1
+
+    embedding_data = res.body["data"][0]
+    assert "embedding" in embedding_data
+    assert isinstance(embedding_data["embedding"], str)
+
+    # verify the embedding is valid base64
+    decoded = base64.b64decode(embedding_data["embedding"])
+    # verify the decoded data can be converted back to a float array
+    float_count = len(decoded) // 4  # 4 bytes per float32
+    floats = struct.unpack(f'{float_count}f', decoded)
+    assert len(floats) > 0
+    assert all(isinstance(x, float) for x in floats)
+    assert len(floats) == len(vec0)
+
+    # make sure the decoded data matches the original
+    for x, y in zip(floats, vec0):
+        assert abs(x - y) < EPSILON

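The test above drives the endpoint with raw HTTP. For comparison, the same round trip through the official "openai" Python client is sketched below; the base URL, API key, and model name are placeholders, and the claim that openai-python 1.x leaves the embedding as a base64 string when the format is requested explicitly is an assumption about the SDK, not something the diff guarantees.

import base64
import struct

from openai import OpenAI

# The llama.cpp server ignores the API key, but the client requires one.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-placeholder")

res = client.embeddings.create(
    model="model-placeholder",  # hypothetical; the server uses its loaded model
    input="Test base64 embedding output",
    encoding_format="base64",
)

# When encoding_format is passed explicitly, the SDK is assumed not to decode
# the response, so .embedding holds the raw base64 string from the server.
blob = res.data[0].embedding
raw = base64.b64decode(str(blob))
floats = struct.unpack(f"{len(raw) // 4}f", raw)
print(len(floats))
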
28 changes: 22 additions & 6 deletions examples/server/utils.hpp
@@ -3,6 +3,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "common/base64.hpp"
 
 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
@@ -591,16 +592,31 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
+static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) {
     json data = json::array();
     int32_t n_tokens = 0;
     int i = 0;
     for (const auto & elem : embeddings) {
-        data.push_back(json{
-            {"embedding", json_value(elem, "embedding", json::array())},
-            {"index", i++},
-            {"object", "embedding"}
-        });
+        json embedding_obj;
+
+        if (use_base64) {
+            const auto & vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
+            const char * data_ptr = reinterpret_cast<const char *>(vec.data());
+            size_t data_size = vec.size() * sizeof(float);
+            embedding_obj = {
+                {"embedding", base64::encode(data_ptr, data_size)},
+                {"index", i++},
+                {"object", "embedding"},
+                {"encoding_format", "base64"}
+            };
+        } else {
+            embedding_obj = {
+                {"embedding", json_value(elem, "embedding", json::array())},
+                {"index", i++},
+                {"object", "embedding"}
+            };
+        }
+        data.push_back(embedding_obj);
 
         n_tokens += json_value(elem, "tokens_evaluated", 0);
     }

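For intuition about the wire format produced by base64::encode(data_ptr, data_size) above: the server serializes the vector's in-memory float32 buffer, so byte order is the host's (typically little-endian). A reference round trip in Python, as a sketch:

import base64
import struct

def encode_embedding(vec: list[float]) -> str:
    # Pack as contiguous float32 values, matching the C++ reinterpret_cast
    # over vec.size() * sizeof(float) bytes.
    raw = struct.pack(f"{len(vec)}f", *vec)
    return base64.b64encode(raw).decode("ascii")

def decode_embedding(blob: str) -> list[float]:
    raw = base64.b64decode(blob)
    return list(struct.unpack(f"{len(raw) // 4}f", raw))

vec = [0.25, -1.5, 3.0]  # all exactly representable as float32
blob = encode_embedding(vec)
assert decode_embedding(blob) == vec
print(blob)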