Skip to content

Commit

Permalink
tool-call: stabilize server tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ochafik committed Dec 15, 2024
1 parent 7bfcd0a commit 7e3feff
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 57 deletions.
12 changes: 6 additions & 6 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ class llama_antiprompts {
};

std::vector<std::string> stop_words;
std::vector<std::string> grammar_trigger_words;
std::vector<std::string> grammar_triggers;

private:
// The Aho–Corasick algorithm allows efficient string matching with multiple patterns.
Expand Down Expand Up @@ -740,25 +740,25 @@ class llama_antiprompts {
stop_tokens.clear();
}

void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
build(
[&](const std::string & text) {
return common_tokenize(ctx, text, /* special= */ true);
},
stop_words,
grammar_trigger_words
grammar_triggers
);
}

void build(const std::function<std::vector<llama_token>(const std::string &)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
void build(const std::function<std::vector<llama_token>(const std::string &)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
clear();
this->stop_words = stop_words;
this->grammar_trigger_words = grammar_trigger_words;
this->grammar_triggers = grammar_triggers;

for (const std::string & stop_word : stop_words) {
antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false});
}
for (const std::string & trigger : grammar_trigger_words) {
for (const std::string & trigger : grammar_triggers) {
antiprompts.push_back({trigger, /* is_grammar_trigger= */ true});
}

Expand Down
2 changes: 1 addition & 1 deletion common/tool-call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
if (!parallel) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\"[TOOL_CALLS]\"? " + builder.add_schema("tool_calls", schema));
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
});
if (allow_content) {
handler.grammar_triggers.push_back("[TOOL_CALLS]");
Expand Down
59 changes: 25 additions & 34 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ struct slot_params {
json input_prefix;
json input_suffix;
std::vector<std::string> antiprompt;
std::vector<std::string> grammar_triggers;
bool timings_per_token = false;
bool ignore_eos = false;

Expand Down Expand Up @@ -318,47 +317,39 @@ struct server_task {
}
}

if (data.contains("grammar_triggers")) {
const auto & triggers = data.at("grammar_triggers");
if (triggers.is_array()) {
for (const auto & trigger : triggers) {
if (trigger.is_string()) {
params.grammar_triggers.push_back(trigger);
auto to_string_vec = [](const json & j) {
std::vector<std::string> out;
if (j.is_array()) {
for (const auto & e : j) {
if (e.is_string()) {
out.push_back(e);
}
}
}
}
return out;
};

{
params.antiprompt.clear();
const auto grammar_trigger_words = data.find("grammar_trigger_words");
if (grammar_trigger_words != data.end()) {
params.sampling.grammar_trigger_words = to_string_vec(*grammar_trigger_words);
}
}

const auto & stop = data.find("stop");
if (stop != data.end() && stop->is_array()) {
for (const auto & word : *stop) {
if (!word.empty()) {
params.antiprompt.push_back(word);
}
}
{
const auto stop = data.find("stop");
if (stop != data.end()) {
params.antiprompt = to_string_vec(*stop);
}
}

{
const auto & samplers = data.find("samplers");
const auto samplers = data.find("samplers");
if (samplers != data.end()) {
if (samplers->is_array()) {
std::vector<std::string> sampler_names;
for (const auto & name : *samplers) {
if (name.is_string()) {
sampler_names.emplace_back(name);
}
}
params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
params.sampling.samplers = common_sampler_types_from_names(to_string_vec(*samplers), false);
} else if (samplers->is_string()){
std::string sampler_string;
for (const auto & name : *samplers) {
sampler_string += name;
}
params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
}
} else {
params.sampling.samplers = defaults.sampling.samplers;
Expand Down Expand Up @@ -546,7 +537,7 @@ struct server_task_result_cmpl_final : server_task_result {
llama_tool_calls parsed_tool_calls;
json tool_calls;
json message_content;
if (!oaicompat_tools.is_null()) {
if (oaicompat_tool_call_style != llama_tool_call_style::None && !oaicompat_tools.is_null()) {
parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content);
if (!parsed_tool_calls.tool_calls.empty()) {
finish_reason = "tool_calls";
Expand Down Expand Up @@ -1759,7 +1750,7 @@ struct server_context {

{
slot.antiprompts.clear();
slot.antiprompts.build(ctx, slot.params.antiprompt, slot.params.grammar_triggers);
slot.antiprompts.build(ctx, slot.params.antiprompt, slot.params.sampling.grammar_trigger_words);
}

{
Expand Down Expand Up @@ -1805,7 +1796,7 @@ struct server_context {

if (match.pos != std::string::npos && !match.is_partial) {
if (match.is_grammar_trigger) {
common_sampler_trigger_grammar(model, slot.smpl, common_token_to_piece(ctx, result.tok, params_base.special));
common_sampler_trigger_grammar(model, slot.smpl, token_str);
} else {
// slot.stopped_word = true;
slot.stopping_word = match.pattern;
Expand Down Expand Up @@ -2014,7 +2005,7 @@ struct server_context {
{"mirostat_eta", slot.params.sampling.mirostat_eta},
{"penalize_nl", slot.params.sampling.penalize_nl},
{"stop", slot.params.antiprompt},
{"grammar_trigger", slot.params.grammar_triggers},
{"grammar_trigger_words", slot.params.sampling.grammar_trigger_words},
{"max_tokens", slot.params.n_predict}, // User configured n_predict
{"n_keep", slot.params.n_keep},
{"n_discard", slot.params.n_discard},
Expand Down Expand Up @@ -3564,7 +3555,7 @@ int main(int argc, char ** argv) {
task.params.oaicompat = oaicompat;
task.params.oaicompat_chat = oaicompat_chat;
task.params.oaicompat_cmpl_id = completion_id;
task.params.oaicompat_tools = json_value(data, "tools", json::array());
task.params.oaicompat_tools = json_value(data, "tools", json());
task.params.oaicompat_tool_call_style = tool_call_style;

// oaicompat_model is already populated by params_from_json_cmpl
Expand Down
28 changes: 16 additions & 12 deletions examples/server/tests/unit/test_chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,23 +202,24 @@ def test_chat_completion_with_timings_per_token():

@pytest.mark.parametrize("template_name,n_predict,tool,expected_arguments", [
("meetkai-functionary-medium-v3.1", 32, TEST_TOOL, {} ),
("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": ". She was so excited to go to the park and s"} ),
("meetkai-functionary-medium-v3.2", 32, TEST_TOOL, {} ),
("meetkai-functionary-medium-v3.2", 32, PYTHON_TOOL, {"code": "Yes,"} ),
("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": " and played all day.\" exclasted her pare"} ),
("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, {} ),
("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, {"code": "Sure, I cannything,"} ),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, {} ),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes,"} ),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, {} ),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes,"} ),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ),
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, {} ),
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "It's a shark."} ),
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "Let's feel out cooking fun together,"} ),
("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, {} ),
("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "It's a shark."} ),
("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "Well you fight. Peopballs donto cheep and come again."} ),
("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, {} ),
("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "It's a small cost."} ),
("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "I can cannot count."} ),
])
def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict):
global server
server.use_jinja = True
server.n_predict = n_predict
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
server.start()
res = server.make_request("POST", "/chat/completions", data={
Expand All @@ -227,13 +228,14 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"tool_choice": tool["function"]["name"],
"tool_choice": "required",
"tools": [tool],
"parallel_tool_calls": False,
})
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls==1), f'Expected 1 tool call in {choice["message"]}'
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
assert tool["function"]["name"] == tool_call["function"]["name"]
actual_arguments = json.loads(tool_call["function"]["arguments"])
Expand All @@ -254,6 +256,7 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
global server
server.use_jinja = True
server.n_predict = n_predict
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
server.start()
res = server.make_request("POST", "/chat/completions", data={
Expand All @@ -267,7 +270,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
})
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
assert "tool_calls" not in choice["message"], f'Expected no tool call in {choice["message"]}'
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'


@pytest.mark.slow
Expand Down Expand Up @@ -296,6 +299,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
global server
server.use_jinja = True
server.n_predict = 128
server.model_hf_repo = hf_repo
server.model_hf_file = hf_file
if template_override:
Expand All @@ -314,7 +318,7 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls==1), f'Expected 1 tool call in {choice["message"]}'
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
assert tool["function"]["name"] == tool_call["function"]["name"]
actual_arguments = json.loads(tool_call["function"]["arguments"])
Expand Down
9 changes: 5 additions & 4 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ static json oaicompat_completion_params_parse(
auto tools = json_value(body, "tools", json());
auto has_tools = tools.is_array() && !tools.empty();

auto stream = json_value(body, "stream", json());
auto stream = json_value(body, "stream", false);
if (stream && has_tools) {
throw std::runtime_error("Cannot use tools with stream");
}
Expand Down Expand Up @@ -561,11 +561,12 @@ static json oaicompat_completion_params_parse(
llama_params["stop"].push_back(stop);
}
if (!handler.grammar_triggers.empty()) {
auto triggers = json::array();
auto trigger_words = json::array();
for (const auto & word : handler.grammar_triggers) {
triggers.push_back(word);
trigger_words.push_back(word);

}
llama_params["grammar_triggers"] = triggers;
llama_params["grammar_trigger_words"] = trigger_words;
}
if (!handler.grammar.empty()) {
if (llama_params.contains("grammar")) {
Expand Down

0 comments on commit 7e3feff

Please sign in to comment.