-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Gambiarra to accept tool calls. #559
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
rivaldodev
wants to merge
5
commits into
microsoft:main
Choose a base branch
from
rivaldodev:gambiarra
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+627
−0
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
9b76190
Gambiarra to accept tool calls.
rivaldodev 7a0f8af
feat: enhance CORS patching and refactor patch application logic
rivaldodev dca304a
Address Copilot patch review feedback
rivaldodev acd69a4
Handle escaped tool call JSON
rivaldodev 709ec29
Support streaming tool call conversion
rivaldodev File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,370 @@ | ||
| diff --git a/examples/server/server.cpp b/examples/server/server.cpp | ||
| index 18bcad3..520b558 100644 | ||
| --- a/examples/server/server.cpp | ||
| +++ b/examples/server/server.cpp | ||
| @@ -3007,8 +3007,90 @@ int main(int argc, char ** argv) { | ||
|
|
||
| ctx_server.queue_results.remove_waiting_task_ids(task_ids); | ||
| } else { | ||
| - const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) { | ||
| + const bool stream_tools = data.contains("__oaicompat_tools"); | ||
| + const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, stream_tools](size_t, httplib::DataSink & sink) { | ||
| + std::string tool_stream_content; | ||
| + bool sent_tool_role = false; | ||
| + | ||
| ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool { | ||
| + if (stream_tools) { | ||
| + std::time_t t = std::time(0); | ||
| + const std::string modelname = json_value(result.data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); | ||
| + | ||
| + if (!sent_tool_role) { | ||
| + json role_chunk = json{ | ||
| + {"choices", json::array({json{ | ||
| + {"finish_reason", nullptr}, | ||
| + {"index", 0}, | ||
| + {"delta", json{{"role", "assistant"}}} | ||
| + }})}, | ||
| + {"created", t}, | ||
| + {"id", completion_id}, | ||
| + {"model", modelname}, | ||
| + {"object", "chat.completion.chunk"} | ||
| + }; | ||
| + if (!server_sent_event(sink, "data", role_chunk)) { | ||
| + return false; | ||
| + } | ||
| + sent_tool_role = true; | ||
| + } | ||
| + | ||
| + tool_stream_content += json_value(result.data, "content", std::string()); | ||
| + | ||
| + const bool stopped_word = json_value(result.data, "stopped_word", false); | ||
| + const bool stopped_eos = json_value(result.data, "stopped_eos", false); | ||
| + const bool stopped_limit = json_value(result.data, "stopped_limit", false); | ||
| + if (!stopped_word && !stopped_eos && !stopped_limit) { | ||
| + return true; | ||
| + } | ||
| + | ||
| + std::string finish_reason = stopped_limit ? "length" : "stop"; | ||
| + json tool_calls = parse_tool_calls_from_content(tool_stream_content, completion_id); | ||
| + if (!tool_calls.empty()) { | ||
| + finish_reason = "tool_calls"; | ||
| + json tool_chunk = json{ | ||
| + {"choices", json::array({json{ | ||
| + {"finish_reason", nullptr}, | ||
| + {"index", 0}, | ||
| + {"delta", json{{"tool_calls", tool_calls}}} | ||
| + }})}, | ||
| + {"created", t}, | ||
| + {"id", completion_id}, | ||
| + {"model", modelname}, | ||
| + {"object", "chat.completion.chunk"} | ||
| + }; | ||
| + if (!server_sent_event(sink, "data", tool_chunk)) { | ||
| + return false; | ||
| + } | ||
| + } else if (!tool_stream_content.empty()) { | ||
| + json content_chunk = json{ | ||
| + {"choices", json::array({json{ | ||
| + {"finish_reason", nullptr}, | ||
| + {"index", 0}, | ||
| + {"delta", json{{"content", tool_stream_content}}} | ||
| + }})}, | ||
| + {"created", t}, | ||
| + {"id", completion_id}, | ||
| + {"model", modelname}, | ||
| + {"object", "chat.completion.chunk"} | ||
| + }; | ||
| + if (!server_sent_event(sink, "data", content_chunk)) { | ||
| + return false; | ||
| + } | ||
| + } | ||
| + | ||
| + json final_chunk = json{ | ||
| + {"choices", json::array({json{{"finish_reason", finish_reason}, | ||
| + {"index", 0}, | ||
| + {"delta", json::object()}}})}, | ||
| + {"created", t}, | ||
| + {"id", completion_id}, | ||
| + {"model", modelname}, | ||
| + {"object", "chat.completion.chunk"} | ||
| + }; | ||
| + return server_sent_event(sink, "data", final_chunk); | ||
| + } | ||
| + | ||
| std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id); | ||
| for (auto & event_data : result_array) { | ||
| if (event_data.empty()) { | ||
| diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp | ||
| index 69519ef..b5d5eea 100644 | ||
| --- a/examples/server/utils.hpp | ||
| +++ b/examples/server/utils.hpp | ||
| @@ -65,9 +65,21 @@ inline std::string format_chat(const struct llama_model * model, const std::stri | ||
| std::string role = json_value(curr_msg, "role", std::string("")); | ||
|
|
||
| std::string content; | ||
| - if (curr_msg.contains("content")) { | ||
| + if (role == "tool") { | ||
| + // Most GGUF chat templates do not define a native "tool" role. | ||
| + // Present tool output as a user-visible observation for the next | ||
| + // assistant turn instead of relying on template-specific behavior. | ||
| + role = "user"; | ||
| + content = "Tool result"; | ||
| + if (curr_msg.contains("tool_call_id") && curr_msg["tool_call_id"].is_string()) { | ||
| + content += " for " + curr_msg["tool_call_id"].get<std::string>(); | ||
| + } | ||
| + content += ":\n"; | ||
| + } | ||
| + | ||
| + if (curr_msg.contains("content") && !curr_msg["content"].is_null()) { | ||
| if (curr_msg["content"].is_string()) { | ||
| - content = curr_msg["content"].get<std::string>(); | ||
| + content += curr_msg["content"].get<std::string>(); | ||
| } else if (curr_msg["content"].is_array()) { | ||
| for (const auto & part : curr_msg["content"]) { | ||
| if (part.contains("text")) { | ||
| @@ -77,10 +89,18 @@ inline std::string format_chat(const struct llama_model * model, const std::stri | ||
| } else { | ||
| throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); | ||
| } | ||
| + } else if (curr_msg.contains("tool_calls")) { | ||
| + content += "Tool calls:\n" + curr_msg["tool_calls"].dump(); | ||
| } else { | ||
| throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); | ||
| } | ||
|
|
||
| + if (json_value(curr_msg, "role", std::string("")) == "tool") { | ||
| + content += "\n\nUse the tool result above to answer the user's request now. " | ||
| + "Do not call another tool. Do not output a tool name, an Input line, " | ||
| + "tool_calls JSON, or any action syntax."; | ||
| + } | ||
| + | ||
| chat.push_back({role, content}); | ||
| } | ||
|
|
||
| @@ -316,6 +336,151 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons | ||
| return sink.write(str.c_str(), str.size()); | ||
| } | ||
|
|
||
| +static std::string trim(const std::string & str) { /* Returns a copy of str without leading/trailing spaces, tabs, LF and CR. */ | ||
| + const auto first = str.find_first_not_of(" \t\n\r"); | ||
| + if (first == std::string::npos) { /* Empty or whitespace-only input yields an empty string. */ | ||
| + return ""; | ||
| + } | ||
| + | ||
| + const auto last = str.find_last_not_of(" \t\n\r"); | ||
| + return str.substr(first, last - first + 1); /* Inclusive range [first, last]. */ | ||
| +} | ||
| + | ||
| +static json normalize_tool_call_arguments(const json & args) { /* Coerces tool-call arguments of any JSON type into a JSON object. */ | ||
| + if (args.is_string()) { /* A string may hold encoded JSON or free text. */ | ||
| + try { | ||
| + return normalize_tool_call_arguments(json::parse(args.get<std::string>())); /* The decoded value may itself be a non-object (e.g. "5" or "[1]"), so normalize it again. */ | ||
| + } catch (const json::parse_error &) { | ||
| + return json::object({{"input", args.get<std::string>()}}); /* Not JSON: keep the raw text under "input". */ | ||
| + } | ||
| + } | ||
| + | ||
| + if (args.is_object()) { /* Already the required shape. */ | ||
| + return args; | ||
| + } | ||
| + | ||
| + if (args.is_null()) { /* No arguments supplied. */ | ||
| + return json::object(); | ||
| + } | ||
| + | ||
| + return json::object({{"value", args}}); /* Number, boolean or array: wrap under "value". */ | ||
| +} | ||
| + | ||
| +static json normalize_tool_call(const json & call, const std::string & completion_id, size_t index) { /* Builds an {id, type, function{name, arguments}} object from one model-emitted call. */ | ||
| + std::string name; | ||
| + json arguments = json::object(); | ||
| + | ||
| + if (call.contains("function") && call["function"].is_object()) { /* Nested shape: name/arguments live under "function". */ | ||
| + const auto & fn = call["function"]; | ||
| + name = json_value(fn, "name", std::string()); | ||
| + arguments = normalize_tool_call_arguments(json_value(fn, "arguments", json::object())); | ||
| + } else { /* Flat shape: name/arguments are read from the call itself. */ | ||
| + name = json_value(call, "name", std::string()); | ||
| + arguments = normalize_tool_call_arguments(json_value(call, "arguments", json::object())); | ||
| + } | ||
| + | ||
| + if (name.empty()) { /* Callers in this file catch std::runtime_error and discard the result. */ | ||
| + throw std::runtime_error("Tool call is missing function name"); | ||
| + } | ||
| + | ||
| + return json{ | ||
| + {"id", json_value(call, "id", completion_id + "_tool_" + std::to_string(index))}, /* Third argument is the fallback id; presumably used when "id" is absent - confirm against json_value. */ | ||
| + {"type", "function"}, | ||
| + {"function", json{ | ||
| + {"name", name}, | ||
| + {"arguments", arguments.dump()}, /* Arguments are emitted as serialized JSON text, not as a nested object. */ | ||
| + }}, | ||
| + }; | ||
| +} | ||
| + | ||
| +static json parse_tool_calls_from_content(const std::string & content, const std::string & completion_id) { /* Extracts normalized tool calls from model output text; returns an empty array when none are found. */ | ||
| + std::string text = trim(content); | ||
| + | ||
| + if (text.rfind("```json", 0) == 0) { /* rfind(..., 0) == 0 is a starts-with test: drop an opening ```json fence. */ | ||
| + text = trim(text.substr(7)); | ||
| + } else if (text.rfind("```", 0) == 0) { /* Opening fence without a language tag. */ | ||
| + text = trim(text.substr(3)); | ||
| + } | ||
| + if (text.size() >= 3 && text.substr(text.size() - 3) == "```") { /* Drop a closing fence if present. */ | ||
| + text = trim(text.substr(0, text.size() - 3)); | ||
| + } | ||
| + | ||
| + json parsed; | ||
| + try { | ||
| + parsed = json::parse(text); | ||
| + } catch (const json::parse_error &) { /* Text that is not JSON is treated as ordinary content. */ | ||
| + return json::array(); | ||
| + } | ||
| + | ||
| + if (parsed.is_string()) { /* The payload was itself a JSON string: decode one more level. */ | ||
| + try { | ||
| + parsed = json::parse(trim(parsed.get<std::string>())); | ||
| + } catch (const json::parse_error &) { | ||
| + return json::array(); | ||
| + } | ||
| + } | ||
| + | ||
| + if (!parsed.is_object()) { /* Only a top-level object can carry tool calls here. */ | ||
| + return json::array(); | ||
| + } | ||
| + | ||
| + json calls = json::array(); | ||
| + json normalized_tool_calls = json::array(); /* Holds a single "tool_call" object wrapped as a one-element list. */ | ||
| + const json * tool_calls_json = nullptr; /* Points at whichever list of calls was found, if any. */ | ||
| + if (parsed.contains("tool_calls") && parsed["tool_calls"].is_array()) { | ||
| + tool_calls_json = &parsed["tool_calls"]; | ||
| + } else if (parsed.contains("tool_call") && parsed["tool_call"].is_array()) { | ||
| + tool_calls_json = &parsed["tool_call"]; | ||
| + } else if (parsed.contains("tool_call") && parsed["tool_call"].is_object()) { | ||
| + normalized_tool_calls.push_back(parsed["tool_call"]); | ||
| + tool_calls_json = &normalized_tool_calls; | ||
| + } | ||
| + | ||
| + if (tool_calls_json != nullptr) { | ||
| + for (size_t i = 0; i < tool_calls_json->size(); ++i) { | ||
| + try { | ||
| + calls.push_back(normalize_tool_call((*tool_calls_json)[i], completion_id, i)); | ||
| + } catch (const std::runtime_error &) { /* One call without a name discards the whole batch. */ | ||
| + return json::array(); | ||
| + } | ||
| + } | ||
| + } else if ((parsed.contains("function") && parsed["function"].is_object() && parsed["function"].contains("name")) | ||
| + || (parsed.contains("name") && parsed["name"].is_string() && parsed.contains("arguments"))) { /* The object itself is a single bare call. */ | ||
| + try { | ||
| + calls.push_back(normalize_tool_call(parsed, completion_id, 0)); | ||
| + } catch (const std::runtime_error &) { | ||
| + return json::array(); | ||
| + } | ||
| + } | ||
| + | ||
| + return calls; | ||
| +} | ||
| + | ||
| +static std::string tools_prompt(const json & body) { /* Builds the instruction text describing the request's tools; returns "" when no prompt should be added. */ | ||
| + if (!body.contains("tools") || !body["tools"].is_array() || body["tools"].empty()) { /* No usable tool list in the request. */ | ||
| + return ""; | ||
| + } | ||
| + | ||
| + if (body.contains("tool_choice") && body["tool_choice"].is_string() && body["tool_choice"].get<std::string>() == "none") { /* Caller explicitly disabled tools. */ | ||
| + return ""; | ||
| + } | ||
| + | ||
| + std::ostringstream ss; | ||
| + ss << "You can call tools when needed. Available tools are provided as JSON below.\n"; | ||
| + ss << body["tools"].dump(2) << "\n\n"; | ||
| + ss << "If a tool is required, respond only with strict JSON in this exact shape:\n"; | ||
| + ss << "{\"tool_calls\":[{\"name\":\"tool_name\",\"arguments\":{\"arg\":\"value\"}}]}\n"; | ||
| + ss << "Prefer the \"tool_calls\" array key. A single \"tool_call\" object is accepted for compatibility.\n"; | ||
| + ss << "Do not include markdown or explanatory text around a tool call. "; | ||
|
rivaldodev marked this conversation as resolved.
|
||
| + ss << "If no tool is required, answer normally."; | ||
| + | ||
| + if (body.contains("tool_choice") && body["tool_choice"].is_object()) { /* An object-valued tool_choice is appended verbatim as pretty-printed JSON. */ | ||
| + ss << "\nThe requested tool_choice is:\n" << body["tool_choice"].dump(2); | ||
| + } | ||
| + | ||
| + return ss.str(); | ||
| +} | ||
| + | ||
| // | ||
| // OAI utils | ||
| // | ||
| @@ -329,7 +494,13 @@ static json oaicompat_completion_params_parse( | ||
| llama_params["__oaicompat"] = true; | ||
|
|
||
| // Apply chat template to the list of messages | ||
| - llama_params["prompt"] = format_chat(model, chat_template, body.at("messages")); | ||
| + json messages = body.at("messages"); | ||
| + const std::string tool_instructions = tools_prompt(body); | ||
| + if (!tool_instructions.empty()) { | ||
| + messages.insert(messages.begin(), json{{"role", "system"}, {"content", tool_instructions}}); | ||
| + llama_params["__oaicompat_tools"] = body["tools"]; | ||
| + } | ||
| + llama_params["prompt"] = format_chat(model, chat_template, messages); | ||
|
|
||
| // Handle "stop" field | ||
| if (body.contains("stop") && body.at("stop").is_string()) { | ||
| @@ -367,7 +538,7 @@ static json oaicompat_completion_params_parse( | ||
| } | ||
|
|
||
| // Params supported by OAI but unsupported by llama.cpp | ||
| - static const std::vector<std::string> unsupported_params { "tools", "tool_choice" }; | ||
| + static const std::vector<std::string> unsupported_params {}; | ||
| for (const auto & param : unsupported_params) { | ||
| if (body.contains(param)) { | ||
| throw std::runtime_error("Unsupported param: " + param); | ||
| @@ -378,6 +549,9 @@ static json oaicompat_completion_params_parse( | ||
| // This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint. | ||
| // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp | ||
| for (const auto & item : body.items()) { | ||
| + if (item.key() == "tools" || item.key() == "tool_choice") { | ||
| + continue; | ||
| + } | ||
| // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" | ||
| if (!llama_params.contains(item.key()) || item.key() == "n_predict") { | ||
| llama_params[item.key()] = item.value(); | ||
| @@ -399,14 +573,33 @@ static json format_final_response_oaicompat(const json & request, const json & r | ||
| finish_reason = "stop"; | ||
| } | ||
|
|
||
| - json choices = | ||
| - streaming ? json::array({json{{"finish_reason", finish_reason}, | ||
| - {"index", 0}, | ||
| - {"delta", json::object()}}}) | ||
| - : json::array({json{{"finish_reason", finish_reason}, | ||
| - {"index", 0}, | ||
| - {"message", json{{"content", content}, | ||
| - {"role", "assistant"}}}}}); | ||
| + json choices; | ||
| + json tool_calls = json::array(); | ||
| + if (!streaming) { | ||
| + // Some agent runtimes describe tools in the prompt instead of sending | ||
| + // the OpenAI "tools" parameter. Still upgrade strict tool-call JSON | ||
| + // content into OpenAI-compatible message.tool_calls. | ||
| + // This keeps n8n/LangChain agents from treating the call as plain text. | ||
| + tool_calls = parse_tool_calls_from_content(content, completion_id); | ||
| + } | ||
| + | ||
| + if (!tool_calls.empty()) { | ||
| + finish_reason = "tool_calls"; | ||
| + choices = json::array({json{{"finish_reason", finish_reason}, | ||
| + {"index", 0}, | ||
| + {"message", json{{"content", nullptr}, | ||
| + {"role", "assistant"}, | ||
| + {"tool_calls", tool_calls}}}}}); | ||
| + } else { | ||
| + choices = | ||
| + streaming ? json::array({json{{"finish_reason", finish_reason}, | ||
| + {"index", 0}, | ||
| + {"delta", json::object()}}}) | ||
| + : json::array({json{{"finish_reason", finish_reason}, | ||
| + {"index", 0}, | ||
| + {"message", json{{"content", content}, | ||
| + {"role", "assistant"}}}}}); | ||
| + } | ||
|
|
||
| std::time_t t = std::time(0); | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.