Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
370 changes: 370 additions & 0 deletions patches/llama-server-tools.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,370 @@
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 18bcad3..520b558 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3007,8 +3007,90 @@ int main(int argc, char ** argv) {

ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
- const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
+ const bool stream_tools = data.contains("__oaicompat_tools");
+ const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, stream_tools](size_t, httplib::DataSink & sink) {
+ std::string tool_stream_content;
+ bool sent_tool_role = false;
+
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
+ if (stream_tools) {
+ std::time_t t = std::time(0);
+ const std::string modelname = json_value(result.data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
+
+ if (!sent_tool_role) {
+ json role_chunk = json{
+ {"choices", json::array({json{
+ {"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", json{{"role", "assistant"}}}
+ }})},
+ {"created", t},
+ {"id", completion_id},
+ {"model", modelname},
+ {"object", "chat.completion.chunk"}
+ };
+ if (!server_sent_event(sink, "data", role_chunk)) {
+ return false;
+ }
+ sent_tool_role = true;
+ }
+
+ tool_stream_content += json_value(result.data, "content", std::string());
+
+ const bool stopped_word = json_value(result.data, "stopped_word", false);
+ const bool stopped_eos = json_value(result.data, "stopped_eos", false);
+ const bool stopped_limit = json_value(result.data, "stopped_limit", false);
+ if (!stopped_word && !stopped_eos && !stopped_limit) {
+ return true;
+ }
+
+ std::string finish_reason = stopped_limit ? "length" : "stop";
+ json tool_calls = parse_tool_calls_from_content(tool_stream_content, completion_id);
+ if (!tool_calls.empty()) {
+ finish_reason = "tool_calls";
+ json tool_chunk = json{
+ {"choices", json::array({json{
+ {"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", json{{"tool_calls", tool_calls}}}
+ }})},
+ {"created", t},
+ {"id", completion_id},
+ {"model", modelname},
+ {"object", "chat.completion.chunk"}
+ };
+ if (!server_sent_event(sink, "data", tool_chunk)) {
+ return false;
+ }
+ } else if (!tool_stream_content.empty()) {
+ json content_chunk = json{
+ {"choices", json::array({json{
+ {"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", json{{"content", tool_stream_content}}}
+ }})},
+ {"created", t},
+ {"id", completion_id},
+ {"model", modelname},
+ {"object", "chat.completion.chunk"}
+ };
+ if (!server_sent_event(sink, "data", content_chunk)) {
+ return false;
+ }
+ }
+
+ json final_chunk = json{
+ {"choices", json::array({json{{"finish_reason", finish_reason},
+ {"index", 0},
+ {"delta", json::object()}}})},
+ {"created", t},
+ {"id", completion_id},
+ {"model", modelname},
+ {"object", "chat.completion.chunk"}
+ };
+ return server_sent_event(sink, "data", final_chunk);
+ }
+
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
for (auto & event_data : result_array) {
if (event_data.empty()) {
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 69519ef..b5d5eea 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -65,9 +65,21 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
std::string role = json_value(curr_msg, "role", std::string(""));

std::string content;
- if (curr_msg.contains("content")) {
+ if (role == "tool") {
+ // Most GGUF chat templates do not define a native "tool" role.
+ // Present tool output as a user-visible observation for the next
+ // assistant turn instead of relying on template-specific behavior.
+ role = "user";
+ content = "Tool result";
+ if (curr_msg.contains("tool_call_id") && curr_msg["tool_call_id"].is_string()) {
+ content += " for " + curr_msg["tool_call_id"].get<std::string>();
+ }
+ content += ":\n";
+ }
+
+ if (curr_msg.contains("content") && !curr_msg["content"].is_null()) {
if (curr_msg["content"].is_string()) {
- content = curr_msg["content"].get<std::string>();
+ content += curr_msg["content"].get<std::string>();
} else if (curr_msg["content"].is_array()) {
for (const auto & part : curr_msg["content"]) {
if (part.contains("text")) {
@@ -77,10 +89,18 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
} else {
throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
}
+ } else if (curr_msg.contains("tool_calls")) {
+ content += "Tool calls:\n" + curr_msg["tool_calls"].dump();
} else {
throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
}

+ if (json_value(curr_msg, "role", std::string("")) == "tool") {
+ content += "\n\nUse the tool result above to answer the user's request now. "
+ "Do not call another tool. Do not output a tool name, an Input line, "
+ "tool_calls JSON, or any action syntax.";
+ }
+
chat.push_back({role, content});
}

@@ -316,6 +336,151 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
return sink.write(str.c_str(), str.size());
}

+static std::string trim(const std::string & str) {
+ const auto first = str.find_first_not_of(" \t\n\r");
+ if (first == std::string::npos) {
+ return "";
+ }
+
+ const auto last = str.find_last_not_of(" \t\n\r");
+ return str.substr(first, last - first + 1);
+}
+
+static json normalize_tool_call_arguments(const json & args) {
+ if (args.is_string()) {
+ try {
+ return json::parse(args.get<std::string>());
+ } catch (const json::parse_error &) {
+ return json::object({{"input", args.get<std::string>()}});
+ }
+ }
+
+ if (args.is_object()) {
+ return args;
+ }
+
+ if (args.is_null()) {
+ return json::object();
+ }
+
+ return json::object({{"value", args}});
+}
+
+static json normalize_tool_call(const json & call, const std::string & completion_id, size_t index) {
+ std::string name;
+ json arguments = json::object();
+
+ if (call.contains("function") && call["function"].is_object()) {
+ const auto & fn = call["function"];
+ name = json_value(fn, "name", std::string());
+ arguments = normalize_tool_call_arguments(json_value(fn, "arguments", json::object()));
+ } else {
+ name = json_value(call, "name", std::string());
+ arguments = normalize_tool_call_arguments(json_value(call, "arguments", json::object()));
+ }
+
+ if (name.empty()) {
+ throw std::runtime_error("Tool call is missing function name");
+ }
+
+ return json{
+ {"id", json_value(call, "id", completion_id + "_tool_" + std::to_string(index))},
+ {"type", "function"},
+ {"function", json{
+ {"name", name},
+ {"arguments", arguments.dump()},
+ }},
+ };
+}
+
+static json parse_tool_calls_from_content(const std::string & content, const std::string & completion_id) {
+ std::string text = trim(content);
+
+ if (text.rfind("```json", 0) == 0) {
+ text = trim(text.substr(7));
+ } else if (text.rfind("```", 0) == 0) {
+ text = trim(text.substr(3));
+ }
+ if (text.size() >= 3 && text.substr(text.size() - 3) == "```") {
+ text = trim(text.substr(0, text.size() - 3));
+ }
+
+ json parsed;
+ try {
+ parsed = json::parse(text);
+ } catch (const json::parse_error &) {
+ return json::array();
+ }
+
+ if (parsed.is_string()) {
+ try {
+ parsed = json::parse(trim(parsed.get<std::string>()));
+ } catch (const json::parse_error &) {
+ return json::array();
+ }
+ }
+
+ if (!parsed.is_object()) {
+ return json::array();
+ }
+
+ json calls = json::array();
+ json normalized_tool_calls = json::array();
+ const json * tool_calls_json = nullptr;
+ if (parsed.contains("tool_calls") && parsed["tool_calls"].is_array()) {
+ tool_calls_json = &parsed["tool_calls"];
+ } else if (parsed.contains("tool_call") && parsed["tool_call"].is_array()) {
+ tool_calls_json = &parsed["tool_call"];
+ } else if (parsed.contains("tool_call") && parsed["tool_call"].is_object()) {
+ normalized_tool_calls.push_back(parsed["tool_call"]);
+ tool_calls_json = &normalized_tool_calls;
+ }
+
+ if (tool_calls_json != nullptr) {
+ for (size_t i = 0; i < tool_calls_json->size(); ++i) {
+ try {
+ calls.push_back(normalize_tool_call((*tool_calls_json)[i], completion_id, i));
Comment thread
rivaldodev marked this conversation as resolved.
+ } catch (const std::runtime_error &) {
+ return json::array();
+ }
+ }
+ } else if ((parsed.contains("function") && parsed["function"].is_object() && parsed["function"].contains("name"))
+ || (parsed.contains("name") && parsed["name"].is_string() && parsed.contains("arguments"))) {
+ try {
+ calls.push_back(normalize_tool_call(parsed, completion_id, 0));
+ } catch (const std::runtime_error &) {
+ return json::array();
+ }
+ }
+
+ return calls;
+}
+
+static std::string tools_prompt(const json & body) {
+ if (!body.contains("tools") || !body["tools"].is_array() || body["tools"].empty()) {
+ return "";
+ }
+
+ if (body.contains("tool_choice") && body["tool_choice"].is_string() && body["tool_choice"].get<std::string>() == "none") {
+ return "";
+ }
+
+ std::ostringstream ss;
+ ss << "You can call tools when needed. Available tools are provided as JSON below.\n";
+ ss << body["tools"].dump(2) << "\n\n";
+ ss << "If a tool is required, respond only with strict JSON in this exact shape:\n";
+ ss << "{\"tool_calls\":[{\"name\":\"tool_name\",\"arguments\":{\"arg\":\"value\"}}]}\n";
+ ss << "Prefer the \"tool_calls\" array key. A single \"tool_call\" object is accepted for compatibility.\n";
+ ss << "Do not include markdown or explanatory text around a tool call. ";
Comment thread
rivaldodev marked this conversation as resolved.
+ ss << "If no tool is required, answer normally.";
+
+ if (body.contains("tool_choice") && body["tool_choice"].is_object()) {
+ ss << "\nThe requested tool_choice is:\n" << body["tool_choice"].dump(2);
+ }
+
+ return ss.str();
+}
+
//
// OAI utils
//
@@ -329,7 +494,13 @@ static json oaicompat_completion_params_parse(
llama_params["__oaicompat"] = true;

// Apply chat template to the list of messages
- llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
+ json messages = body.at("messages");
+ const std::string tool_instructions = tools_prompt(body);
+ if (!tool_instructions.empty()) {
+ messages.insert(messages.begin(), json{{"role", "system"}, {"content", tool_instructions}});
+ llama_params["__oaicompat_tools"] = body["tools"];
+ }
+ llama_params["prompt"] = format_chat(model, chat_template, messages);

// Handle "stop" field
if (body.contains("stop") && body.at("stop").is_string()) {
@@ -367,7 +538,7 @@ static json oaicompat_completion_params_parse(
}

// Params supported by OAI but unsupported by llama.cpp
- static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
+ static const std::vector<std::string> unsupported_params {};
for (const auto & param : unsupported_params) {
if (body.contains(param)) {
throw std::runtime_error("Unsupported param: " + param);
@@ -378,6 +549,9 @@ static json oaicompat_completion_params_parse(
// This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint.
// See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
for (const auto & item : body.items()) {
+ if (item.key() == "tools" || item.key() == "tool_choice") {
+ continue;
+ }
// Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
llama_params[item.key()] = item.value();
@@ -399,14 +573,33 @@ static json format_final_response_oaicompat(const json & request, const json & r
finish_reason = "stop";
}

- json choices =
- streaming ? json::array({json{{"finish_reason", finish_reason},
- {"index", 0},
- {"delta", json::object()}}})
- : json::array({json{{"finish_reason", finish_reason},
- {"index", 0},
- {"message", json{{"content", content},
- {"role", "assistant"}}}}});
+ json choices;
+ json tool_calls = json::array();
+ if (!streaming) {
+ // Some agent runtimes describe tools in the prompt instead of sending
+ // the OpenAI "tools" parameter. Still upgrade strict tool-call JSON
+ // content into OpenAI-compatible message.tool_calls.
+ // This keeps n8n/LangChain agents from treating the call as plain text.
+ tool_calls = parse_tool_calls_from_content(content, completion_id);
+ }
+
+ if (!tool_calls.empty()) {
+ finish_reason = "tool_calls";
+ choices = json::array({json{{"finish_reason", finish_reason},
+ {"index", 0},
+ {"message", json{{"content", nullptr},
+ {"role", "assistant"},
+ {"tool_calls", tool_calls}}}}});
+ } else {
+ choices =
+ streaming ? json::array({json{{"finish_reason", finish_reason},
+ {"index", 0},
+ {"delta", json::object()}}})
+ : json::array({json{{"finish_reason", finish_reason},
+ {"index", 0},
+ {"message", json{{"content", content},
+ {"role", "assistant"}}}}});
+ }

std::time_t t = std::time(0);
11 changes: 11 additions & 0 deletions setup_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,18 @@ def compile():
# run_command(["cmake", "--build", "build", "--target", "llama-cli", "--config", "Release"])
run_command(["cmake", "--build", "build", "--config", "Release"], log_step="compile")

def apply_local_patches():
llama_cpp = Path("3rdparty") / "llama.cpp"
server_cpp = llama_cpp / "examples" / "server" / "server.cpp"
utils_hpp = llama_cpp / "examples" / "server" / "utils.hpp"
if not llama_cpp.exists() or not server_cpp.exists() or not utils_hpp.exists():
logging.info("Skipping local llama.cpp patches: 3rdparty/llama.cpp server sources not found.")
return

run_command([sys.executable, "utils/apply_local_patches.py"], log_step="apply_local_patches")

def main():
apply_local_patches()
setup_gguf()
gen_code()
compile()
Comment thread
rivaldodev marked this conversation as resolved.
Expand Down
Loading