From 3ac1cf1a8fed8604562b1d6962c69bd5a1345ee2 Mon Sep 17 00:00:00 2001 From: mkulakow Date: Thu, 30 Apr 2026 14:28:53 +0200 Subject: [PATCH] Support functions in responses api --- src/llm/apis/openai_responses.cpp | 459 +++++++++++++++++++----- src/llm/py_jinja_template_processor.cpp | 2 +- 2 files changed, 376 insertions(+), 85 deletions(-) diff --git a/src/llm/apis/openai_responses.cpp b/src/llm/apis/openai_responses.cpp index e5d63985e6..77f0073a5c 100644 --- a/src/llm/apis/openai_responses.cpp +++ b/src/llm/apis/openai_responses.cpp @@ -57,6 +57,88 @@ static std::string joinServerSideEvents(const std::vector& events) return ss.str(); } +static void appendReasoningSummaryText(const rapidjson::Value::ConstObject& itemObj, std::string& reasoningBuffer) { + auto summaryIt = itemObj.FindMember("summary"); + if (summaryIt == itemObj.MemberEnd() || !summaryIt->value.IsArray()) { + return; + } + for (const auto& summaryItem : summaryIt->value.GetArray()) { + if (!summaryItem.IsObject()) + continue; + auto summaryObj = summaryItem.GetObject(); + auto stTypeIt = summaryObj.FindMember("type"); + if (stTypeIt == summaryObj.MemberEnd() || !stTypeIt->value.IsString()) + continue; + if (std::string(stTypeIt->value.GetString()) != "summary_text") + continue; + auto textIt = summaryObj.FindMember("text"); + if (textIt == summaryObj.MemberEnd() || !textIt->value.IsString()) + continue; + if (!reasoningBuffer.empty()) + reasoningBuffer += "\n"; + reasoningBuffer += textIt->value.GetString(); + } +} + +template +static absl::Status iterateResponsesInputArrayItems( + const rapidjson::Value& inputArray, + OnReasoning&& onReasoning, + OnFunctionCall&& onFunctionCall, + OnFunctionCallOutput&& onFunctionCallOutput, + OnRoleItem&& onRoleItem, + OnMissingRole&& onMissingRole) { + if (!inputArray.IsArray()) { + return absl::InvalidArgumentError("input is not an array"); + } + + for (rapidjson::SizeType i = 0; i < inputArray.GetArray().Size(); ++i) { + const auto& item = inputArray.GetArray()[i]; + if (!item.IsObject()) { + return absl::InvalidArgumentError("input array items must be objects"); + } + + auto itemObj = item.GetObject(); + auto itemTypeIt = itemObj.FindMember("type"); + const std::string itemType = (itemTypeIt != itemObj.MemberEnd() && itemTypeIt->value.IsString()) + ? itemTypeIt->value.GetString() + : ""; + + if (itemType == "reasoning") { + auto status = onReasoning(itemObj, i); + if (!status.ok()) + return status; + continue; + } + if (itemType == "function_call") { + auto status = onFunctionCall(item, itemObj, i); + if (!status.ok()) + return status; + continue; + } + if (itemType == "function_call_output") { + auto status = onFunctionCallOutput(itemObj, i); + if (!status.ok()) + return status; + continue; + } + + auto roleIt = itemObj.FindMember("role"); + if (roleIt == itemObj.MemberEnd() || !roleIt->value.IsString()) { + auto status = onMissingRole(itemObj, i); + if (!status.ok()) + return status; + continue; + } + + auto status = onRoleItem(itemObj, std::string(roleIt->value.GetString()), i); + if (!status.ok()) + return status; + } + + return absl::OkStatus(); +} + // --- Request parsing --- absl::Status OpenAIResponsesHandler::parseRequest(std::optional maxTokensLimit, uint32_t bestOfLimit, std::optional maxModelLength, @@ -88,86 +170,123 @@ absl::Status OpenAIResponsesHandler::parseInput(std::optional allow return absl::InvalidArgumentError("Messages array cannot be empty"); } - for (size_t i = 0; i < inputIt->value.GetArray().Size(); ++i) { - auto& item = inputIt->value.GetArray()[i]; - if (!item.IsObject()) { - return absl::InvalidArgumentError("input array items must be objects"); - } - - auto itemObj = item.GetObject(); - auto roleIt = itemObj.FindMember("role"); - if (roleIt == itemObj.MemberEnd() || !roleIt->value.IsString()) { - return absl::InvalidArgumentError("input item role is missing or invalid"); - } - - request.chatHistory.push_back({}); - request.chatHistory.last()["role"] = roleIt->value.GetString(); - - auto contentIt = itemObj.FindMember("content"); - if (contentIt == itemObj.MemberEnd()) { - return absl::InvalidArgumentError("input item content is missing"); - } + std::string pendingReasoning; + auto parseStatus = iterateResponsesInputArrayItems( + inputIt->value, + [&](const rapidjson::Value::ConstObject& itemObj, rapidjson::SizeType) -> absl::Status { + appendReasoningSummaryText(itemObj, pendingReasoning); + return absl::OkStatus(); + }, + [&](const rapidjson::Value&, const rapidjson::Value::ConstObject&, rapidjson::SizeType) -> absl::Status { + // For chatHistory (non-Python path), represent function call as an assistant message with empty content. + request.chatHistory.push_back({}); + request.chatHistory.last()["role"] = "assistant"; + request.chatHistory.last()["content"] = ""; + if (!pendingReasoning.empty()) { + request.chatHistory.last()["reasoning_content"] = pendingReasoning; + pendingReasoning.clear(); + } + return absl::OkStatus(); + }, + [&](const rapidjson::Value::ConstObject& itemObj, rapidjson::SizeType) -> absl::Status { + auto callIdIt = itemObj.FindMember("call_id"); + auto outputIt = itemObj.FindMember("output"); + request.chatHistory.push_back({}); + request.chatHistory.last()["role"] = "tool"; + if (callIdIt != itemObj.MemberEnd() && callIdIt->value.IsString()) { + request.chatHistory.last()["tool_call_id"] = callIdIt->value.GetString(); + } + const std::string outputContent = (outputIt != itemObj.MemberEnd() && outputIt->value.IsString()) + ? outputIt->value.GetString() + : ""; + request.chatHistory.last()["content"] = outputContent; + return absl::OkStatus(); + }, + [&](const rapidjson::Value::ConstObject& itemObj, const std::string& role, rapidjson::SizeType index) -> absl::Status { + request.chatHistory.push_back({}); + request.chatHistory.last()["role"] = role; + if (role == "assistant" && !pendingReasoning.empty()) { + request.chatHistory.last()["reasoning_content"] = pendingReasoning; + pendingReasoning.clear(); + } - if (contentIt->value.IsString()) { - request.chatHistory.last()["content"] = contentIt->value.GetString(); - continue; - } + auto contentIt = itemObj.FindMember("content"); + if (contentIt == itemObj.MemberEnd()) { + // Allow messages without content (e.g., assistant message paired with tool calls) + request.chatHistory.last()["content"] = ""; + return absl::OkStatus(); + } - if (!contentIt->value.IsArray()) { - return absl::InvalidArgumentError("input item content must be a string or array"); - } - if (contentIt->value.GetArray().Size() == 0) { - return absl::InvalidArgumentError("Invalid message structure - content array is empty"); - } + if (contentIt->value.IsString()) { + request.chatHistory.last()["content"] = contentIt->value.GetString(); + return absl::OkStatus(); + } - std::string contentText = ""; - for (auto& contentItem : contentIt->value.GetArray()) { - if (!contentItem.IsObject()) { - return absl::InvalidArgumentError("input content items must be objects"); + if (!contentIt->value.IsArray()) { + return absl::InvalidArgumentError("input item content must be a string or array"); } - auto contentObj = contentItem.GetObject(); - auto typeIt = contentObj.FindMember("type"); - if (typeIt == contentObj.MemberEnd() || !typeIt->value.IsString()) { - return absl::InvalidArgumentError("input content item type is missing or invalid"); + if (contentIt->value.GetArray().Size() == 0) { + // Empty content array is allowed (e.g., assistant message with only tool calls) + request.chatHistory.last()["content"] = ""; + return absl::OkStatus(); } - const std::string type = typeIt->value.GetString(); - if (type == "input_text") { - auto textIt = contentObj.FindMember("text"); - if (textIt == contentObj.MemberEnd() || !textIt->value.IsString()) { - return absl::InvalidArgumentError("input_text requires a valid text field"); + std::string contentText = ""; + for (const auto& contentItem : contentIt->value.GetArray()) { + if (!contentItem.IsObject()) { + return absl::InvalidArgumentError("input content items must be objects"); } - contentText = textIt->value.GetString(); - } else if (type == "input_image") { - std::string imageUrl; - auto imageUrlIt = contentObj.FindMember("image_url"); - if (imageUrlIt == contentObj.MemberEnd()) { - return absl::InvalidArgumentError("input_image requires image_url field"); + auto contentObj = contentItem.GetObject(); + auto typeIt = contentObj.FindMember("type"); + if (typeIt == contentObj.MemberEnd() || !typeIt->value.IsString()) { + return absl::InvalidArgumentError("input content item type is missing or invalid"); } - if (imageUrlIt->value.IsString()) { - imageUrl = imageUrlIt->value.GetString(); - } else if (imageUrlIt->value.IsObject()) { - auto imageUrlObj = imageUrlIt->value.GetObject(); - auto urlIt = imageUrlObj.FindMember("url"); - if (urlIt == imageUrlObj.MemberEnd() || !urlIt->value.IsString()) { - return absl::InvalidArgumentError("input_image.image_url.url is missing or invalid"); + + const std::string type = typeIt->value.GetString(); + if (type == "input_text" || type == "output_text") { + auto textIt = contentObj.FindMember("text"); + if (textIt == contentObj.MemberEnd() || !textIt->value.IsString()) { + return absl::InvalidArgumentError(absl::StrCat(type, " requires a valid text field")); + } + contentText = textIt->value.GetString(); + } else if (type == "input_image") { + std::string imageUrl; + auto imageUrlIt = contentObj.FindMember("image_url"); + if (imageUrlIt == contentObj.MemberEnd()) { + return absl::InvalidArgumentError("input_image requires image_url field"); + } + if (imageUrlIt->value.IsString()) { + imageUrl = imageUrlIt->value.GetString(); + } else if (imageUrlIt->value.IsObject()) { + auto imageUrlObj = imageUrlIt->value.GetObject(); + auto urlIt = imageUrlObj.FindMember("url"); + if (urlIt == imageUrlObj.MemberEnd() || !urlIt->value.IsString()) { + return absl::InvalidArgumentError("input_image.image_url.url is missing or invalid"); + } + imageUrl = urlIt->value.GetString(); + } else { + return absl::InvalidArgumentError("input_image.image_url must be a string or object"); } - imageUrl = urlIt->value.GetString(); - } else { - return absl::InvalidArgumentError("input_image.image_url must be a string or object"); - } - auto tensorResult = loadImage(imageUrl, allowedLocalMediaPath, allowedMediaDomains); - if (!tensorResult.ok()) { - return tensorResult.status(); + auto tensorResult = loadImage(imageUrl, allowedLocalMediaPath, allowedMediaDomains); + if (!tensorResult.ok()) { + return tensorResult.status(); + } + request.imageHistory.push_back({index, tensorResult.value()}); + } else { + // Skip unrecognised content item types for forward compatibility + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Skipping unsupported content type: {}", type); } - request.imageHistory.push_back({i, tensorResult.value()}); - } else { - return absl::InvalidArgumentError("Unsupported content type. Supported types are input_text and input_image."); } - } - request.chatHistory.last()["content"] = contentText; + request.chatHistory.last()["content"] = contentText; + return absl::OkStatus(); + }, + [&](const rapidjson::Value::ConstObject&, rapidjson::SizeType) -> absl::Status { + return absl::InvalidArgumentError("input item role is missing or invalid"); + }); + if (!parseStatus.ok()) { + return parseStatus; } } else { return absl::InvalidArgumentError("input is not a string or array"); @@ -228,34 +347,206 @@ absl::Status OpenAIResponsesHandler::parseResponsesPart(std::optional } #if (PYTHON_DISABLE == 0) - // Build processedJson with "messages" array from chatHistory so that - // the Python chat template path (which reads request_json["messages"]) - // can consume Responses API input without a separate code path. + // Build processedJson with a "messages" array in chat/completions format so that + // the Python Jinja template path can consume Responses API input without a separate code path. + // Handles reasoning, function_call (merged into assistant tool_calls), and + // function_call_output (converted to role:tool messages). { Document processedDoc; processedDoc.SetObject(); auto& alloc = processedDoc.GetAllocator(); Value messagesArray(kArrayType); - for (size_t i = 0; i < request.chatHistory.size(); ++i) { - Value msgObj(kObjectType); - auto role = request.chatHistory[i]["role"].as_string(); - if (role.has_value()) { - msgObj.AddMember("role", Value(role.value().c_str(), alloc), alloc); - } - auto content = request.chatHistory[i]["content"].as_string(); - if (content.has_value()) { - msgObj.AddMember("content", Value(content.value().c_str(), alloc), alloc); + + auto inputArrIt = doc.FindMember("input"); + if (inputArrIt != doc.MemberEnd() && inputArrIt->value.IsArray()) { + // Pending function_call items to be merged into the next assistant message + std::vector pendingFunctionCalls; + std::string pendingReasoningJson; + + // Helper: flush pending function_calls as an assistant message with the given text content + auto flushPendingFunctionCalls = [&](const std::string& textContent) { + if (pendingFunctionCalls.empty()) { + return; + } + Value msgObj(kObjectType); + msgObj.AddMember("role", Value("assistant", alloc), alloc); + msgObj.AddMember("content", Value(textContent.c_str(), alloc), alloc); + if (!pendingReasoningJson.empty()) { + msgObj.AddMember("reasoning_content", Value(pendingReasoningJson.c_str(), alloc), alloc); + pendingReasoningJson.clear(); + } + Value toolCallsArray(kArrayType); + for (const auto* fc : pendingFunctionCalls) { + auto fcObj = fc->GetObject(); + Value tcObj(kObjectType); + auto idIt = fcObj.FindMember("id"); + const std::string tcId = (idIt != fcObj.MemberEnd() && idIt->value.IsString()) + ? idIt->value.GetString() + : ""; + tcObj.AddMember("id", Value(tcId.c_str(), alloc), alloc); + tcObj.AddMember("type", Value("function", alloc), alloc); + Value funcObj(kObjectType); + auto nameIt = fcObj.FindMember("name"); + const std::string funcName = (nameIt != fcObj.MemberEnd() && nameIt->value.IsString()) + ? nameIt->value.GetString() + : ""; + funcObj.AddMember("name", Value(funcName.c_str(), alloc), alloc); + auto argsIt = fcObj.FindMember("arguments"); + const std::string args = (argsIt != fcObj.MemberEnd() && argsIt->value.IsString()) + ? argsIt->value.GetString() + : ""; + funcObj.AddMember("arguments", Value(args.c_str(), alloc), alloc); + tcObj.AddMember("function", funcObj, alloc); + toolCallsArray.PushBack(tcObj, alloc); + } + msgObj.AddMember("tool_calls", toolCallsArray, alloc); + messagesArray.PushBack(msgObj, alloc); + pendingFunctionCalls.clear(); + }; + + // Helper: extract text content from a Responses API content field (string or array) + auto extractTextContent = [&](const rapidjson::Value& contentVal) -> std::string { + if (contentVal.IsString()) { + return contentVal.GetString(); + } + if (contentVal.IsArray()) { + for (auto& ci : contentVal.GetArray()) { + if (!ci.IsObject()) + continue; + auto ctTypeIt = ci.GetObject().FindMember("type"); + if (ctTypeIt == ci.GetObject().MemberEnd() || !ctTypeIt->value.IsString()) + continue; + const std::string ctType = ctTypeIt->value.GetString(); + if (ctType == "input_text" || ctType == "output_text") { + auto textIt = ci.GetObject().FindMember("text"); + if (textIt != ci.GetObject().MemberEnd() && textIt->value.IsString()) { + return textIt->value.GetString(); + } + } + } + } + return ""; + }; + + auto processedStatus = iterateResponsesInputArrayItems( + inputArrIt->value, + [&](const rapidjson::Value::ConstObject& itemObj, rapidjson::SizeType) -> absl::Status { + appendReasoningSummaryText(itemObj, pendingReasoningJson); + return absl::OkStatus(); + }, + [&](const rapidjson::Value& item, const rapidjson::Value::ConstObject&, rapidjson::SizeType) -> absl::Status { + pendingFunctionCalls.push_back(&item); + return absl::OkStatus(); + }, + [&](const rapidjson::Value::ConstObject& itemObj, rapidjson::SizeType) -> absl::Status { + flushPendingFunctionCalls(""); + Value msgObj(kObjectType); + msgObj.AddMember("role", Value("tool", alloc), alloc); + auto callIdIt = itemObj.FindMember("call_id"); + if (callIdIt != itemObj.MemberEnd() && callIdIt->value.IsString()) { + msgObj.AddMember("tool_call_id", Value(callIdIt->value.GetString(), alloc), alloc); + } + auto outputIt = itemObj.FindMember("output"); + const std::string outputContent = (outputIt != itemObj.MemberEnd() && outputIt->value.IsString()) + ? outputIt->value.GetString() + : ""; + msgObj.AddMember("content", Value(outputContent.c_str(), alloc), alloc); + messagesArray.PushBack(msgObj, alloc); + return absl::OkStatus(); + }, + [&](const rapidjson::Value::ConstObject& itemObj, const std::string& role, rapidjson::SizeType) -> absl::Status { + std::string contentText = ""; + auto contentIt = itemObj.FindMember("content"); + if (contentIt != itemObj.MemberEnd()) { + contentText = extractTextContent(contentIt->value); + } + + if (role == "assistant") { + if (!pendingFunctionCalls.empty()) { + // Merge buffered function_call items into this assistant message + flushPendingFunctionCalls(contentText); + } else { + // Plain assistant message with no associated tool calls + Value msgObj(kObjectType); + msgObj.AddMember("role", Value("assistant", alloc), alloc); + msgObj.AddMember("content", Value(contentText.c_str(), alloc), alloc); + if (!pendingReasoningJson.empty()) { + msgObj.AddMember("reasoning_content", Value(pendingReasoningJson.c_str(), alloc), alloc); + pendingReasoningJson.clear(); + } + messagesArray.PushBack(msgObj, alloc); + } + } else { + // Non-assistant message — flush any pending function calls first + flushPendingFunctionCalls(""); + Value msgObj(kObjectType); + msgObj.AddMember("role", Value(role.c_str(), alloc), alloc); + msgObj.AddMember("content", Value(contentText.c_str(), alloc), alloc); + messagesArray.PushBack(msgObj, alloc); + } + return absl::OkStatus(); + }, + [&](const rapidjson::Value::ConstObject&, rapidjson::SizeType) -> absl::Status { + return absl::OkStatus(); // Skip unknown items without a role in processed JSON path + }); + if (!processedStatus.ok()) { + return processedStatus; } - messagesArray.PushBack(msgObj, alloc); + + // Flush any trailing buffered function_calls + flushPendingFunctionCalls(""); } + processedDoc.AddMember("messages", messagesArray, alloc); - // Copy tools from original doc if present + // Convert tools from Responses API flat format to chat/completions nested format. + // Responses API: {"type": "function", "name": "foo", "description": "...", "parameters": {...}} + // Chat/completions: {"type": "function", "function": {"name": "foo", "description": "...", "parameters": {...}}} auto toolsIt = doc.FindMember("tools"); - if (toolsIt != doc.MemberEnd() && !toolsIt->value.IsNull()) { - Value toolsCopy(toolsIt->value, alloc); - processedDoc.AddMember("tools", toolsCopy, alloc); + if (toolsIt != doc.MemberEnd() && !toolsIt->value.IsNull() && toolsIt->value.IsArray()) { + Value toolsArray(kArrayType); + for (const auto& tool : toolsIt->value.GetArray()) { + if (!tool.IsObject()) + continue; + auto toolObj = tool.GetObject(); + // Check if this tool already has a nested "function" key (chat/completions format) + if (toolObj.FindMember("function") != toolObj.MemberEnd()) { + // Already in chat/completions format — copy as-is + Value toolCopy(tool, alloc); + toolsArray.PushBack(toolCopy, alloc); + } else { + auto typeIt = toolObj.FindMember("type"); + const std::string toolType = (typeIt != toolObj.MemberEnd() && typeIt->value.IsString()) + ? typeIt->value.GetString() + : ""; + + if (toolType == "function") { + // Responses API flat function format — wrap under "function" key. + Value convertedTool(kObjectType); + convertedTool.AddMember("type", Value("function", alloc), alloc); + Value funcObj(kObjectType); + // Copy all fields except "type" and "response" into the nested function object. + for (auto it2 = toolObj.MemberBegin(); it2 != toolObj.MemberEnd(); ++it2) { + if (!it2->name.IsString()) + continue; + const std::string fieldName = it2->name.GetString(); + if (fieldName == "type" || fieldName == "response") + continue; + Value keyCopy(it2->name, alloc); + Value valCopy(it2->value, alloc); + funcObj.AddMember(keyCopy, valCopy, alloc); + } + convertedTool.AddMember("function", funcObj, alloc); + toolsArray.PushBack(convertedTool, alloc); + } else { + // Preserve non-function tools as-is instead of rewriting type. + Value toolCopy(tool, alloc); + toolsArray.PushBack(toolCopy, alloc); + } + } + } + processedDoc.AddMember("tools", toolsArray, alloc); } // Copy chat_template_kwargs from original doc if present diff --git a/src/llm/py_jinja_template_processor.cpp b/src/llm/py_jinja_template_processor.cpp index 432aa8e722..61116d3c5d 100644 --- a/src/llm/py_jinja_template_processor.cpp +++ b/src/llm/py_jinja_template_processor.cpp @@ -40,7 +40,7 @@ bool PyJinjaTemplateProcessor::applyChatTemplate(PyJinjaTemplateProcessor& templ output = "Error: Chat template not loaded correctly, so it cannot be applied"; return false; } - + SPDLOG_DEBUG("Before chat template: \n {}", requestBody); py::gil_scoped_acquire acquire; try { auto locals = py::dict("request_body"_a = requestBody, "chat_template"_a = templateProcessor.chatTemplate->getObject(),