diff --git a/examples/server/README.md b/examples/server/README.md
index 5e3d6a6e643a6..c7d91be9976c4 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -450,6 +450,8 @@ These words will not be included in the completion, so make sure to add them to
 
 `post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
 
+`response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If a specified field is missing, it will simply be omitted from the response without triggering an error.
+
 **Response format**
 
 - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 476a9225f75b7..3fbfb13c49b72 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -92,6 +92,7 @@ struct slot_params {
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
     std::vector<std::string> antiprompt;
+    std::vector<std::string> response_fields;
     bool timings_per_token = false;
     bool post_sampling_probs = false;
     bool ignore_eos = false;
@@ -209,6 +210,7 @@ struct server_task {
         params.n_discard        = json_value(data, "n_discard",        defaults.n_discard);
       //params.t_max_prompt_ms  = json_value(data, "t_max_prompt_ms",  defaults.t_max_prompt_ms); // TODO: implement
         params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
+        params.response_fields  = json_value(data, "response_fields",  std::vector<std::string>());
 
         params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
         params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
@@ -522,6 +524,7 @@ struct server_task_result_cmpl_final : server_task_result {
     bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
+    std::vector<std::string> response_fields;
 
     slot_params generation_params;
@@ -568,7 +571,7 @@ struct server_task_result_cmpl_final : server_task_result {
         if (!stream && !probs_output.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
         }
-        return res;
+        return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
     }
 
     json to_json_oaicompat_chat() {
@@ -2066,6 +2069,7 @@ struct server_context {
         res->tokens          = slot.generated_tokens;
         res->timings         = slot.get_timings();
         res->prompt          = common_detokenize(ctx, slot.prompt_tokens, true);
+        res->response_fields = slot.params.response_fields;
 
         res->truncated       = slot.truncated;
         res->n_decoded       = slot.n_decoded;
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index b88d45f18547f..00d5ce391d8f0 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -257,6 +257,40 @@ def check_slots_status():
     # assert match_regex(re_content, res.body["content"])
 
 
+@pytest.mark.parametrize(
+    "prompt,n_predict,response_fields",
+    [
+        ("I believe the meaning of life is", 8, []),
+        ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
+    ],
+)
+def test_completion_response_fields(
+    prompt: str, n_predict: int, response_fields: list[str]
+):
+    global server
+    server.start()
+    res = server.make_request(
+        "POST",
+        "/completion",
+        data={
+            "n_predict": n_predict,
+            "prompt": prompt,
+            "response_fields": response_fields,
+        },
+    )
+    assert res.status_code == 200
+    assert "content" in res.body
+    assert len(res.body["content"])
+    if len(response_fields):
+        assert res.body["generation_settings/n_predict"] == n_predict
+        assert res.body["prompt"] == "<s> " + prompt
+        assert isinstance(res.body["content"], str)
+        assert len(res.body) == len(response_fields)
+    else:
+        assert len(res.body)
+        assert "generation_settings" in res.body
+
+
 def test_n_probs():
     global server
     server.start()
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 1987acac89159..043d8b52897db 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -90,6 +90,28 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
     return false;
 }
 
+// get the value at a nested path (e.g. "key1/key2")
+static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
+    json result = json::object();
+
+    for (const std::string & path : paths) {
+        json current = js;
+        const auto keys = string_split<std::string>(path, /*separator*/ '/');
+        bool valid_path = true;
+        for (const std::string & k : keys) {
+            if (valid_path && current.is_object() && current.contains(k)) {
+                current = current[k];
+            } else {
+                valid_path = false;
+            }
+        }
+        if (valid_path) {
+            result[path] = current;
+        }
+    }
+    return result;
+}
+
 /**
  * this handles 2 cases:
  * - only string, example: "string"
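
For reviewers who want to try the feature out of context, here is a minimal client-side sketch of what this diff enables. It assumes a `llama-server` instance already listening on `localhost:8080` (the default port) and uses the third-party `requests` package; both the endpoint URL and the client library are illustrative assumptions, not part of the change.

```python
import requests

# Ask /completion to return only two fields. Nested fields are addressed
# with '/'-separated paths, matching the path walk in json_get_nested_values().
res = requests.post(
    "http://localhost:8080/completion",  # assumed host/port
    json={
        "prompt": "I believe the meaning of life is",
        "n_predict": 8,
        "response_fields": ["content", "generation_settings/n_predict"],
    },
)
body = res.json()

# Each requested path comes back keyed by its full path string; paths that
# do not resolve are silently dropped rather than producing an error.
assert set(body.keys()) == {"content", "generation_settings/n_predict"}
print(body["generation_settings/n_predict"])  # 8
```

Note the design choice this illustrates: the trimmed response is keyed by the full path string rather than a rebuilt nested object, which keeps `json_get_nested_values` a single pass per path and is what the new test asserts via `res.body["generation_settings/n_predict"]`.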