diff --git a/patches/llama-server-tools.patch b/patches/llama-server-tools.patch new file mode 100644 index 000000000..92d821a96 --- /dev/null +++ b/patches/llama-server-tools.patch @@ -0,0 +1,370 @@ +diff --git a/examples/server/server.cpp b/examples/server/server.cpp +index 18bcad3..520b558 100644 +--- a/examples/server/server.cpp ++++ b/examples/server/server.cpp +@@ -3007,8 +3007,90 @@ int main(int argc, char ** argv) { + + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + } else { +- const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) { ++ const bool stream_tools = data.contains("__oaicompat_tools"); ++ const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, stream_tools](size_t, httplib::DataSink & sink) { ++ std::string tool_stream_content; ++ bool sent_tool_role = false; ++ + ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool { ++ if (stream_tools) { ++ std::time_t t = std::time(0); ++ const std::string modelname = json_value(result.data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); ++ ++ if (!sent_tool_role) { ++ json role_chunk = json{ ++ {"choices", json::array({json{ ++ {"finish_reason", nullptr}, ++ {"index", 0}, ++ {"delta", json{{"role", "assistant"}}} ++ }})}, ++ {"created", t}, ++ {"id", completion_id}, ++ {"model", modelname}, ++ {"object", "chat.completion.chunk"} ++ }; ++ if (!server_sent_event(sink, "data", role_chunk)) { ++ return false; ++ } ++ sent_tool_role = true; ++ } ++ ++ tool_stream_content += json_value(result.data, "content", std::string()); ++ ++ const bool stopped_word = json_value(result.data, "stopped_word", false); ++ const bool stopped_eos = json_value(result.data, "stopped_eos", false); ++ const bool stopped_limit = json_value(result.data, "stopped_limit", false); ++ if (!stopped_word && !stopped_eos && !stopped_limit) { ++ return true; ++ } ++ ++ std::string finish_reason = stopped_limit ? "length" : "stop"; ++ json tool_calls = parse_tool_calls_from_content(tool_stream_content, completion_id); ++ if (!tool_calls.empty()) { ++ finish_reason = "tool_calls"; ++ json tool_chunk = json{ ++ {"choices", json::array({json{ ++ {"finish_reason", nullptr}, ++ {"index", 0}, ++ {"delta", json{{"tool_calls", tool_calls}}} ++ }})}, ++ {"created", t}, ++ {"id", completion_id}, ++ {"model", modelname}, ++ {"object", "chat.completion.chunk"} ++ }; ++ if (!server_sent_event(sink, "data", tool_chunk)) { ++ return false; ++ } ++ } else if (!tool_stream_content.empty()) { ++ json content_chunk = json{ ++ {"choices", json::array({json{ ++ {"finish_reason", nullptr}, ++ {"index", 0}, ++ {"delta", json{{"content", tool_stream_content}}} ++ }})}, ++ {"created", t}, ++ {"id", completion_id}, ++ {"model", modelname}, ++ {"object", "chat.completion.chunk"} ++ }; ++ if (!server_sent_event(sink, "data", content_chunk)) { ++ return false; ++ } ++ } ++ ++ json final_chunk = json{ ++ {"choices", json::array({json{{"finish_reason", finish_reason}, ++ {"index", 0}, ++ {"delta", json::object()}}})}, ++ {"created", t}, ++ {"id", completion_id}, ++ {"model", modelname}, ++ {"object", "chat.completion.chunk"} ++ }; ++ return server_sent_event(sink, "data", final_chunk); ++ } ++ + std::vector result_array = format_partial_response_oaicompat(result.data, completion_id); + for (auto & event_data : result_array) { + if (event_data.empty()) { +diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp +index 69519ef..b5d5eea 100644 +--- a/examples/server/utils.hpp ++++ b/examples/server/utils.hpp +@@ -65,9 +65,21 @@ inline std::string format_chat(const struct llama_model * model, const std::stri + std::string role = json_value(curr_msg, "role", std::string("")); + + std::string content; +- if (curr_msg.contains("content")) { ++ if (role == "tool") { ++ // Most GGUF chat templates do not define a native "tool" role. ++ // Present tool output as a user-visible observation for the next ++ // assistant turn instead of relying on template-specific behavior. ++ role = "user"; ++ content = "Tool result"; ++ if (curr_msg.contains("tool_call_id") && curr_msg["tool_call_id"].is_string()) { ++ content += " for " + curr_msg["tool_call_id"].get(); ++ } ++ content += ":\n"; ++ } ++ ++ if (curr_msg.contains("content") && !curr_msg["content"].is_null()) { + if (curr_msg["content"].is_string()) { +- content = curr_msg["content"].get(); ++ content += curr_msg["content"].get(); + } else if (curr_msg["content"].is_array()) { + for (const auto & part : curr_msg["content"]) { + if (part.contains("text")) { +@@ -77,10 +89,18 @@ inline std::string format_chat(const struct llama_model * model, const std::stri + } else { + throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + } ++ } else if (curr_msg.contains("tool_calls")) { ++ content += "Tool calls:\n" + curr_msg["tool_calls"].dump(); + } else { + throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + } + ++ if (json_value(curr_msg, "role", std::string("")) == "tool") { ++ content += "\n\nUse the tool result above to answer the user's request now. " ++ "Do not call another tool. Do not output a tool name, an Input line, " ++ "tool_calls JSON, or any action syntax."; ++ } ++ + chat.push_back({role, content}); + } + +@@ -316,6 +336,151 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons + return sink.write(str.c_str(), str.size()); + } + ++static std::string trim(const std::string & str) { ++ const auto first = str.find_first_not_of(" \t\n\r"); ++ if (first == std::string::npos) { ++ return ""; ++ } ++ ++ const auto last = str.find_last_not_of(" \t\n\r"); ++ return str.substr(first, last - first + 1); ++} ++ ++static json normalize_tool_call_arguments(const json & args) { ++ if (args.is_string()) { ++ try { ++ return json::parse(args.get()); ++ } catch (const json::parse_error &) { ++ return json::object({{"input", args.get()}}); ++ } ++ } ++ ++ if (args.is_object()) { ++ return args; ++ } ++ ++ if (args.is_null()) { ++ return json::object(); ++ } ++ ++ return json::object({{"value", args}}); ++} ++ ++static json normalize_tool_call(const json & call, const std::string & completion_id, size_t index) { ++ std::string name; ++ json arguments = json::object(); ++ ++ if (call.contains("function") && call["function"].is_object()) { ++ const auto & fn = call["function"]; ++ name = json_value(fn, "name", std::string()); ++ arguments = normalize_tool_call_arguments(json_value(fn, "arguments", json::object())); ++ } else { ++ name = json_value(call, "name", std::string()); ++ arguments = normalize_tool_call_arguments(json_value(call, "arguments", json::object())); ++ } ++ ++ if (name.empty()) { ++ throw std::runtime_error("Tool call is missing function name"); ++ } ++ ++ return json{ ++ {"id", json_value(call, "id", completion_id + "_tool_" + std::to_string(index))}, ++ {"type", "function"}, ++ {"function", json{ ++ {"name", name}, ++ {"arguments", arguments.dump()}, ++ }}, ++ }; ++} ++ ++static json parse_tool_calls_from_content(const std::string & content, const std::string & completion_id) { ++ std::string text = trim(content); ++ ++ if (text.rfind("```json", 0) == 0) { ++ text = trim(text.substr(7)); ++ } else if (text.rfind("```", 0) == 0) { ++ text = trim(text.substr(3)); ++ } ++ if (text.size() >= 3 && text.substr(text.size() - 3) == "```") { ++ text = trim(text.substr(0, text.size() - 3)); ++ } ++ ++ json parsed; ++ try { ++ parsed = json::parse(text); ++ } catch (const json::parse_error &) { ++ return json::array(); ++ } ++ ++ if (parsed.is_string()) { ++ try { ++ parsed = json::parse(trim(parsed.get())); ++ } catch (const json::parse_error &) { ++ return json::array(); ++ } ++ } ++ ++ if (!parsed.is_object()) { ++ return json::array(); ++ } ++ ++ json calls = json::array(); ++ json normalized_tool_calls = json::array(); ++ const json * tool_calls_json = nullptr; ++ if (parsed.contains("tool_calls") && parsed["tool_calls"].is_array()) { ++ tool_calls_json = &parsed["tool_calls"]; ++ } else if (parsed.contains("tool_call") && parsed["tool_call"].is_array()) { ++ tool_calls_json = &parsed["tool_call"]; ++ } else if (parsed.contains("tool_call") && parsed["tool_call"].is_object()) { ++ normalized_tool_calls.push_back(parsed["tool_call"]); ++ tool_calls_json = &normalized_tool_calls; ++ } ++ ++ if (tool_calls_json != nullptr) { ++ for (size_t i = 0; i < tool_calls_json->size(); ++i) { ++ try { ++ calls.push_back(normalize_tool_call((*tool_calls_json)[i], completion_id, i)); ++ } catch (const std::runtime_error &) { ++ return json::array(); ++ } ++ } ++ } else if ((parsed.contains("function") && parsed["function"].is_object() && parsed["function"].contains("name")) ++ || (parsed.contains("name") && parsed["name"].is_string() && parsed.contains("arguments"))) { ++ try { ++ calls.push_back(normalize_tool_call(parsed, completion_id, 0)); ++ } catch (const std::runtime_error &) { ++ return json::array(); ++ } ++ } ++ ++ return calls; ++} ++ ++static std::string tools_prompt(const json & body) { ++ if (!body.contains("tools") || !body["tools"].is_array() || body["tools"].empty()) { ++ return ""; ++ } ++ ++ if (body.contains("tool_choice") && body["tool_choice"].is_string() && body["tool_choice"].get() == "none") { ++ return ""; ++ } ++ ++ std::ostringstream ss; ++ ss << "You can call tools when needed. Available tools are provided as JSON below.\n"; ++ ss << body["tools"].dump(2) << "\n\n"; ++ ss << "If a tool is required, respond only with strict JSON in this exact shape:\n"; ++ ss << "{\"tool_calls\":[{\"name\":\"tool_name\",\"arguments\":{\"arg\":\"value\"}}]}\n"; ++ ss << "Prefer the \"tool_calls\" array key. A single \"tool_call\" object is accepted for compatibility.\n"; ++ ss << "Do not include markdown or explanatory text around a tool call. "; ++ ss << "If no tool is required, answer normally."; ++ ++ if (body.contains("tool_choice") && body["tool_choice"].is_object()) { ++ ss << "\nThe requested tool_choice is:\n" << body["tool_choice"].dump(2); ++ } ++ ++ return ss.str(); ++} ++ + // + // OAI utils + // +@@ -329,7 +494,13 @@ static json oaicompat_completion_params_parse( + llama_params["__oaicompat"] = true; + + // Apply chat template to the list of messages +- llama_params["prompt"] = format_chat(model, chat_template, body.at("messages")); ++ json messages = body.at("messages"); ++ const std::string tool_instructions = tools_prompt(body); ++ if (!tool_instructions.empty()) { ++ messages.insert(messages.begin(), json{{"role", "system"}, {"content", tool_instructions}}); ++ llama_params["__oaicompat_tools"] = body["tools"]; ++ } ++ llama_params["prompt"] = format_chat(model, chat_template, messages); + + // Handle "stop" field + if (body.contains("stop") && body.at("stop").is_string()) { +@@ -367,7 +538,7 @@ static json oaicompat_completion_params_parse( + } + + // Params supported by OAI but unsupported by llama.cpp +- static const std::vector unsupported_params { "tools", "tool_choice" }; ++ static const std::vector unsupported_params {}; + for (const auto & param : unsupported_params) { + if (body.contains(param)) { + throw std::runtime_error("Unsupported param: " + param); +@@ -378,6 +549,9 @@ static json oaicompat_completion_params_parse( + // This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint. + // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp + for (const auto & item : body.items()) { ++ if (item.key() == "tools" || item.key() == "tool_choice") { ++ continue; ++ } + // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { + llama_params[item.key()] = item.value(); +@@ -399,14 +573,33 @@ static json format_final_response_oaicompat(const json & request, const json & r + finish_reason = "stop"; + } + +- json choices = +- streaming ? json::array({json{{"finish_reason", finish_reason}, +- {"index", 0}, +- {"delta", json::object()}}}) +- : json::array({json{{"finish_reason", finish_reason}, +- {"index", 0}, +- {"message", json{{"content", content}, +- {"role", "assistant"}}}}}); ++ json choices; ++ json tool_calls = json::array(); ++ if (!streaming) { ++ // Some agent runtimes describe tools in the prompt instead of sending ++ // the OpenAI "tools" parameter. Still upgrade strict tool-call JSON ++ // content into OpenAI-compatible message.tool_calls. ++ // This keeps n8n/LangChain agents from treating the call as plain text. ++ tool_calls = parse_tool_calls_from_content(content, completion_id); ++ } ++ ++ if (!tool_calls.empty()) { ++ finish_reason = "tool_calls"; ++ choices = json::array({json{{"finish_reason", finish_reason}, ++ {"index", 0}, ++ {"message", json{{"content", nullptr}, ++ {"role", "assistant"}, ++ {"tool_calls", tool_calls}}}}}); ++ } else { ++ choices = ++ streaming ? json::array({json{{"finish_reason", finish_reason}, ++ {"index", 0}, ++ {"delta", json::object()}}}) ++ : json::array({json{{"finish_reason", finish_reason}, ++ {"index", 0}, ++ {"message", json{{"content", content}, ++ {"role", "assistant"}}}}}); ++ } + + std::time_t t = std::time(0); diff --git a/setup_env.py b/setup_env.py index 3bf5fb8f7..0352cc679 100644 --- a/setup_env.py +++ b/setup_env.py @@ -215,7 +215,18 @@ def compile(): # run_command(["cmake", "--build", "build", "--target", "llama-cli", "--config", "Release"]) run_command(["cmake", "--build", "build", "--config", "Release"], log_step="compile") +def apply_local_patches(): + llama_cpp = Path("3rdparty") / "llama.cpp" + server_cpp = llama_cpp / "examples" / "server" / "server.cpp" + utils_hpp = llama_cpp / "examples" / "server" / "utils.hpp" + if not llama_cpp.exists() or not server_cpp.exists() or not utils_hpp.exists(): + logging.info("Skipping local llama.cpp patches: 3rdparty/llama.cpp server sources not found.") + return + + run_command([sys.executable, "utils/apply_local_patches.py"], log_step="apply_local_patches") + def main(): + apply_local_patches() setup_gguf() gen_code() compile() diff --git a/utils/apply_local_patches.py b/utils/apply_local_patches.py new file mode 100644 index 000000000..190a0639d --- /dev/null +++ b/utils/apply_local_patches.py @@ -0,0 +1,246 @@ +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + + +ROOT = Path(__file__).resolve().parents[1] +SERVER_CPP = ROOT / "3rdparty" / "llama.cpp" / "examples" / "server" / "server.cpp" +LLAMA_CPP = ROOT / "3rdparty" / "llama.cpp" +PATCHES = ROOT / "patches" + +NEW_CORS_BLOCK = """ // CORS preflight + svr->Options(R\"(.*)\", [](const httplib::Request & req, httplib::Response & res) { + // Access-Control-Allow-Origin is already set by middleware + res.set_header(\"Access-Control-Allow-Credentials\", \"true\"); + res.set_header(\"Access-Control-Allow-Methods\", \"GET, POST, OPTIONS\"); + + const auto requested_headers = req.get_header_value(\"Access-Control-Request-Headers\"); + if (!requested_headers.empty()) { + res.set_header(\"Access-Control-Allow-Headers\", requested_headers); + } else { + res.set_header(\"Access-Control-Allow-Headers\", \"*\"); + } + + return res.set_content(\"\", \"text/html\"); // blank response, no data + }); +""" + + +@dataclass +class PatchHunk: + old_start: int + old_lines: list[str] + new_lines: list[str] + + +@dataclass +class FilePatch: + target: Path + hunks: list[PatchHunk] + + +def ensure_server_cors_patch() -> None: + if not SERVER_CPP.exists(): + print(f"Skipping llama.cpp CORS patch: file not found at {SERVER_CPP}") + return + + content = SERVER_CPP.read_text(encoding="utf-8") + cors_comment = " // CORS preflight" + start = content.find(cors_comment) + if start == -1: + print("Failed to locate CORS preflight block in server.cpp", file=sys.stderr) + sys.exit(1) + + end_marker = " });" + end = content.find(end_marker, start) + if end == -1: + print("Failed to locate end of CORS preflight block in server.cpp", file=sys.stderr) + sys.exit(1) + end += len(end_marker) + + current_block = content[start:end] + if "Access-Control-Request-Headers" in current_block: + print("llama.cpp CORS patch already applied") + return + + required_markers = ( + "svr->Options", + "httplib::Request &", + "httplib::Response & res", + 'res.set_header("Access-Control-Allow-Methods"', + 'res.set_header("Access-Control-Allow-Headers"', + ) + if not all(marker in current_block for marker in required_markers): + print("Failed to locate expected CORS preflight lines in server.cpp", file=sys.stderr) + sys.exit(1) + + newline = "\r\n" if "\r\n" in content else "\n" + cors_block = NEW_CORS_BLOCK.rstrip("\n").replace("\n", newline) + SERVER_CPP.write_text(content[:start] + cors_block + content[end:], encoding="utf-8") + print("Applied llama.cpp CORS patch") + + +def parse_hunk_header(line: str) -> int: + old_range = line.split(" ", 2)[1] + return int(old_range[1:].split(",", 1)[0]) + + +def parse_unified_patch(patch: Path) -> Optional[list[FilePatch]]: + patch_lines = patch.read_text(encoding="utf-8").splitlines() + i = 0 + file_patches: list[FilePatch] = [] + + while i < len(patch_lines): + if not patch_lines[i].startswith("diff --git "): + i += 1 + continue + + i += 1 + while i < len(patch_lines) and not patch_lines[i].startswith("--- "): + i += 1 + if i >= len(patch_lines): + return None + + i += 1 + if i >= len(patch_lines) or not patch_lines[i].startswith("+++ b/"): + return None + + rel_path = patch_lines[i][6:] + target = LLAMA_CPP / rel_path + if not target.exists(): + print(f"Patch target not found: {target}", file=sys.stderr) + return None + + i += 1 + hunks: list[PatchHunk] = [] + while i < len(patch_lines): + if patch_lines[i].startswith("diff --git "): + break + if not patch_lines[i].startswith("@@ "): + i += 1 + continue + + old_start = parse_hunk_header(patch_lines[i]) + i += 1 + old_lines: list[str] = [] + new_lines: list[str] = [] + + while i < len(patch_lines) and not patch_lines[i].startswith("@@ ") and not patch_lines[i].startswith("diff --git "): + line = patch_lines[i] + if line.startswith("\\ No newline"): + i += 1 + continue + if line == "": + old_lines.append("") + new_lines.append("") + i += 1 + continue + + marker = line[:1] + value = line[1:] + if marker == " ": + old_lines.append(value) + new_lines.append(value) + elif marker == "-": + old_lines.append(value) + elif marker == "+": + new_lines.append(value) + else: + return None + i += 1 + + hunks.append(PatchHunk(old_start=old_start, old_lines=old_lines, new_lines=new_lines)) + + file_patches.append(FilePatch(target=target, hunks=hunks)) + + return file_patches + + +def simulate_file_patch(file_patch: FilePatch) -> tuple[str, Optional[str]]: + original = file_patch.target.read_text(encoding="utf-8") + newline = "\r\n" if "\r\n" in original else "\n" + has_trailing_newline = original.endswith(("\n", "\r")) + original_lines = original.splitlines() + patched_lines = original_lines.copy() + apply_offset = 0 + already_offset = 0 + can_apply = True + already_applied = True + + for hunk in file_patch.hunks: + apply_start = hunk.old_start - 1 + apply_offset + if patched_lines[apply_start:apply_start + len(hunk.old_lines)] == hunk.old_lines: + patched_lines[apply_start:apply_start + len(hunk.old_lines)] = hunk.new_lines + apply_offset += len(hunk.new_lines) - len(hunk.old_lines) + else: + can_apply = False + + already_start = hunk.old_start - 1 + already_offset + if original_lines[already_start:already_start + len(hunk.new_lines)] != hunk.new_lines: + already_applied = False + already_offset += len(hunk.new_lines) - len(hunk.old_lines) + + if can_apply: + patched_content = newline.join(patched_lines) + if has_trailing_newline: + patched_content += newline + return "apply", patched_content + if already_applied: + return "already", None + return "failed", None + + +def apply_unified_patch(patch: Path) -> str: + file_patches = parse_unified_patch(patch) + if not file_patches: + return "failed" + + pending_writes: list[tuple[Path, str]] = [] + applied_any = False + already_count = 0 + + for file_patch in file_patches: + status, new_content = simulate_file_patch(file_patch) + if status == "failed": + return "failed" + if status == "already": + already_count += 1 + continue + pending_writes.append((file_patch.target, new_content or "")) + applied_any = True + + if applied_any and already_count: + print(f"Refusing to partially apply {patch.name}: some hunks are already present", file=sys.stderr) + return "failed" + + for target, new_content in pending_writes: + target.write_text(new_content, encoding="utf-8") + + return "applied" if applied_any else "already" + + +def apply_patch_file(patch: Path) -> None: + if not patch.exists(): + print(f"Skipping patch: file not found at {patch}") + return + + patch_status = apply_unified_patch(patch) + if patch_status == "applied": + print(f"Applied {patch.name}") + return + if patch_status == "already": + print(f"{patch.name} already applied") + return + + print(f"Failed to apply {patch.name}", file=sys.stderr) + sys.exit(1) + + +def main() -> None: + ensure_server_cors_patch() + apply_patch_file(PATCHES / "llama-server-tools.patch") + + +if __name__ == "__main__": + main()