diff --git a/prepare_llm_models.sh b/prepare_llm_models.sh index 31efa76ac4..c590c3d063 100755 --- a/prepare_llm_models.sh +++ b/prepare_llm_models.sh @@ -38,6 +38,7 @@ PHI4_MODEL="microsoft/Phi-4-mini-instruct" MISTRAL_MODEL="mistralai/Mistral-7B-Instruct-v0.3" GPT_OSS_MODEL="openai/gpt-oss-20b" DEVSTRAL_MODEL="unsloth/Devstral-Small-2507" +GEMMA4_MODEL="google/gemma-4-26B-A4B-it" if [ "$(python3 -c 'import sys; print(sys.version_info[1])')" -le "8" ]; then echo "Prepare models with python > 3.8."; exit 1 ; fi @@ -217,3 +218,14 @@ if [ ! -f "$1/$DEVSTRAL_MODEL/$TOKENIZER_FILE" ]; then echo "[ERROR] Models file $1/$DEVSTRAL_MODEL/$TOKENIZER_FILE does not exist." exit 1 fi + +if [ -f "$1/$GEMMA4_MODEL/$TOKENIZER_FILE" ]; then + echo "Models file $1/$GEMMA4_MODEL/$TOKENIZER_FILE exists. Skipping downloading models." +else + mkdir -p $1/$GEMMA4_MODEL + convert_tokenizer $GEMMA4_MODEL --with_detokenizer -o $1/$GEMMA4_MODEL +fi +if [ ! -f "$1/$GEMMA4_MODEL/$TOKENIZER_FILE" ]; then + echo "[ERROR] Models file $1/$GEMMA4_MODEL/$TOKENIZER_FILE does not exist." 
+ exit 1 +fi diff --git a/spelling-whitelist.txt b/spelling-whitelist.txt index 0f55b618d5..4b5ee05f30 100644 --- a/spelling-whitelist.txt +++ b/spelling-whitelist.txt @@ -27,4 +27,5 @@ release_files/thirdparty-licenses/libgt2.LICENSE.txt:1083: publically ==> public src/test/llm/output_parsers/qwen3coder_output_parser_test.cpp demos/vlm_npu/README.md:157: mane ==> main, many, maine demos/vlm_npu/README.md:218: mane ==> main, many, maine -demos/integration_with_OpenWebUI/README.md:423: Buildin ==> Building, Build in \ No newline at end of file +src/test/llm/output_parsers/gemma4_output_parser_test.cpp +demos/integration_with_OpenWebUI/README.md:423: Buildin ==> Building, Build in diff --git a/src/llm/BUILD b/src/llm/BUILD index 8fe6059d71..aaf94b7bd3 100644 --- a/src/llm/BUILD +++ b/src/llm/BUILD @@ -143,6 +143,12 @@ ovms_cc_library( name = "io_processing_utils", hdrs = ["io_processing/utils.hpp"], srcs = ["io_processing/utils.cpp"], + deps = [ + "@com_github_tencent_rapidjson//:rapidjson", + "//src/port:rapidjson_stringbuffer", + "//src/port:rapidjson_writer", + "//src/port:rapidjson_document", + ], visibility = ["//visibility:public"], ) @@ -175,6 +181,23 @@ ovms_cc_library( ], visibility = ["//visibility:public"], ) + +ovms_cc_library( + name = "io_processing_gemma4_tool_parser", + hdrs = ["io_processing/gemma4/tool_parser.hpp"], + srcs = ["io_processing/gemma4/tool_parser.cpp"], + deps = [ + "@com_github_tencent_rapidjson//:rapidjson", + "//src/port:rapidjson_document", + "//src:libovmslogging", + "//src:libovmsstring_utils", + ":io_processing_utils", + ":io_processing_base_output_parser", + "//third_party:genai", + ], + visibility = ["//visibility:public"], +) + ovms_cc_library( # TODO split further so we don't have to recompile everything when changing one parser ... 
name = "output_parsers", hdrs = [ @@ -210,6 +233,7 @@ ovms_cc_library( # TODO split further so we don't have to recompile everything w ":partial_json_builder", ":io_processing_base_output_parser", ":io_processing_qwen3coder_tool_parser", + ":io_processing_gemma4_tool_parser", ":io_processing_utils", ":apis_tool_schema_wrapper", ], diff --git a/src/llm/apis/openai_api_handler.hpp b/src/llm/apis/openai_api_handler.hpp index 9071e6addc..7c56bcbf95 100644 --- a/src/llm/apis/openai_api_handler.hpp +++ b/src/llm/apis/openai_api_handler.hpp @@ -164,7 +164,7 @@ class OpenAIApiHandler { // Serialization - pure virtual, each handler produces its own response format virtual std::string serializeUnaryResponse(const std::vector& generationOutputs) = 0; virtual std::string serializeUnaryResponse(ov::genai::EncodedResults& results) = 0; - virtual std::string serializeUnaryResponse(ov::genai::VLMDecodedResults& results) = 0; + virtual std::string serializeUnaryResponse(ov::genai::VLMDecodedResults& results, const std::string& textResponse) = 0; virtual std::string serializeStreamingChunk(const std::string& chunkResponse, ov::genai::GenerationFinishReason finishReason) = 0; virtual std::string serializeStreamingUsageChunk() = 0; virtual std::string serializeStreamingHandshakeChunk() = 0; diff --git a/src/llm/apis/openai_completions.cpp b/src/llm/apis/openai_completions.cpp index 433ce59c3f..7435706351 100644 --- a/src/llm/apis/openai_completions.cpp +++ b/src/llm/apis/openai_completions.cpp @@ -458,7 +458,7 @@ std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(ov::genai::Enco return jsonResponse.ToString(); } -std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(ov::genai::VLMDecodedResults& results) { +std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(ov::genai::VLMDecodedResults& results, const std::string& textResponse) { OVMS_PROFILE_FUNCTION(); usage.promptTokens = results.perf_metrics.get_num_input_tokens(); usage.completionTokens = 
results.perf_metrics.get_num_generated_tokens(); @@ -470,13 +470,12 @@ std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(ov::genai::VLMD jsonResponse.StartArray("choices"); int index = 0; - for (int i = 0; i < results.texts.size(); i++) { - const std::string& text = results.texts[i]; - SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Generated text: {}", text); + if (!textResponse.empty()) { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Generated text: {}", textResponse); // Workaround to use OVMS unary parsers: get tokens from string // This way we have detokenized text from GenAI and calculate tokens, to further convert back to text again, in parseOutputIfNeeded... - auto generatedTokens = encodeTextToTokens(text); + auto generatedTokens = encodeTextToTokens(textResponse); SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Generated tokens: {}", generatedTokens); ParsedOutput parsedOutput = parseOutputIfNeeded(generatedTokens); diff --git a/src/llm/apis/openai_completions.hpp b/src/llm/apis/openai_completions.hpp index 9ebe529637..cbb8f2645f 100644 --- a/src/llm/apis/openai_completions.hpp +++ b/src/llm/apis/openai_completions.hpp @@ -39,7 +39,7 @@ class OpenAIChatCompletionsHandler : public OpenAIApiHandler { std::string serializeUnaryResponse(const std::vector& generationOutputs) override; std::string serializeUnaryResponse(ov::genai::EncodedResults& results) override; - std::string serializeUnaryResponse(ov::genai::VLMDecodedResults& results) override; + std::string serializeUnaryResponse(ov::genai::VLMDecodedResults& results, const std::string& textResponse) override; std::string serializeStreamingChunk(const std::string& chunkResponse, ov::genai::GenerationFinishReason finishReason) override; std::string serializeStreamingUsageChunk() override; std::string serializeStreamingHandshakeChunk() override; diff --git a/src/llm/apis/openai_responses.cpp b/src/llm/apis/openai_responses.cpp index e5d63985e6..49703c0fc2 100644 --- 
a/src/llm/apis/openai_responses.cpp +++ b/src/llm/apis/openai_responses.cpp @@ -655,21 +655,21 @@ std::string OpenAIResponsesHandler::serializeUnaryResponse(ov::genai::EncodedRes return serializeUnaryResponseImpl(parsedOutputs); } -std::string OpenAIResponsesHandler::serializeUnaryResponse(ov::genai::VLMDecodedResults& results) { +std::string OpenAIResponsesHandler::serializeUnaryResponse(ov::genai::VLMDecodedResults& results, const std::string& textResponse) { OVMS_PROFILE_FUNCTION(); usage.promptTokens = results.perf_metrics.get_num_input_tokens(); usage.completionTokens = results.perf_metrics.get_num_generated_tokens(); // Usage is already correctly set from perf_metrics above — no need for updateUsage. std::vector parsedOutputs; - for (const std::string& text : results.texts) { + if (!textResponse.empty()) { if (outputParser != nullptr) { // Same workaround as in chat completions - auto generatedTokens = encodeTextToTokens(text); + auto generatedTokens = encodeTextToTokens(textResponse); parsedOutputs.push_back(parseOutputIfNeeded(generatedTokens)); } else { // Fast path: no output parser, use decoded text directly. 
ParsedOutput output; - output.content = text; + output.content = textResponse; parsedOutputs.push_back(std::move(output)); } } diff --git a/src/llm/apis/openai_responses.hpp b/src/llm/apis/openai_responses.hpp index 0e5fd892b7..6a10400952 100644 --- a/src/llm/apis/openai_responses.hpp +++ b/src/llm/apis/openai_responses.hpp @@ -97,7 +97,7 @@ class OpenAIResponsesHandler : public OpenAIApiHandler { std::string serializeUnaryResponse(const std::vector& generationOutputs) override; std::string serializeUnaryResponse(ov::genai::EncodedResults& results) override; - std::string serializeUnaryResponse(ov::genai::VLMDecodedResults& results) override; + std::string serializeUnaryResponse(ov::genai::VLMDecodedResults& results, const std::string& textResponse) override; std::string serializeStreamingChunk(const std::string& chunkResponse, ov::genai::GenerationFinishReason finishReason) override; std::string serializeStreamingUsageChunk() override; std::string serializeStreamingHandshakeChunk() override; diff --git a/src/llm/io_processing/gemma4/tool_parser.cpp b/src/llm/io_processing/gemma4/tool_parser.cpp new file mode 100644 index 0000000000..53385c334d --- /dev/null +++ b/src/llm/io_processing/gemma4/tool_parser.cpp @@ -0,0 +1,507 @@ +//***************************************************************************** +// Copyright 2026 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** +#include "tool_parser.hpp" +#include "../utils.hpp" +#include "../../../logging.hpp" +#include "../../../stringutils.hpp" +#include "rapidjson/error/en.h" +#include <algorithm> +#include <string> +#include <vector> + +namespace ovms { + +const std::string Gemma4ToolParser::TOOL_CALL_START_TAG = "<|tool_call>"; +const std::string Gemma4ToolParser::TOOL_CALL_END_TAG = "</|tool_call>"; +const std::string Gemma4ToolParser::TOOL_CALL_NAME_PREFIX = "call:"; + +const std::string Gemma4ToolParser::TOOL_ARGS_START_INDICATOR = "{"; +const std::string Gemma4ToolParser::TOOL_ARGS_END_INDICATOR = "}"; +const std::string Gemma4ToolParser::TOOL_ARGS_STRING_INDICATOR = "<|\"|>"; +const std::string Gemma4ToolParser::TOOL_ARGS_SEPARATOR_STR = ","; + +const int64_t Gemma4ToolParser::botTokenId = 48; +const int64_t Gemma4ToolParser::eotTokenId = 49; + +std::string Gemma4ToolParser::parseArrayParameter(const std::string& argumentStr) { + size_t pos = 1; + std::string parsedArguments = "["; + + while (pos != std::string::npos) { + size_t stringStartPos = argumentStr.find(TOOL_ARGS_STRING_INDICATOR, pos); + if (stringStartPos == std::string::npos) { + break; + } + stringStartPos += TOOL_ARGS_STRING_INDICATOR.size(); + size_t stringEndPos = argumentStr.find(TOOL_ARGS_STRING_INDICATOR, stringStartPos); + if (stringEndPos == std::string::npos) { + break; + } + + std::string originalStr = argumentStr.substr(stringStartPos, stringEndPos - stringStartPos); + size_t quotePos = 0; + while ((quotePos = originalStr.find('\"', quotePos)) != std::string::npos) { + originalStr.insert(quotePos, "\\"); + quotePos += 2; + } + parsedArguments += "\"" + originalStr + "\","; + + pos = stringEndPos + TOOL_ARGS_STRING_INDICATOR.size() + 1; + } + + parsedArguments.back() = ']'; + + return parsedArguments; +} + +std::string Gemma4ToolParser::parseObjectParameter(std::string argumentStr) { + size_t pos = 1; + std::vector<std::pair<std::string, std::string>> keyValuePairs; + + while (pos != std::string::npos) 
{ + std::string key, value; + bool isStringValue = false; + size_t keyEndPos = argumentStr.find(':', pos); + if (keyEndPos == std::string::npos) { + break; + } + key = argumentStr.substr(pos, keyEndPos - pos); + size_t valueStartPos = keyEndPos + 1; + size_t valueEndPos; + if (argumentStr.substr(valueStartPos, TOOL_ARGS_STRING_INDICATOR.size()) == TOOL_ARGS_STRING_INDICATOR) { + valueStartPos = valueStartPos + TOOL_ARGS_STRING_INDICATOR.size(); + valueEndPos = argumentStr.find(TOOL_ARGS_STRING_INDICATOR, valueStartPos); + isStringValue = true; + } else { + valueEndPos = argumentStr.find(',', valueStartPos); + } + + if (valueEndPos == std::string::npos) { + valueEndPos = argumentStr.size() - 1; + } + value = argumentStr.substr(valueStartPos, valueEndPos - valueStartPos); + if (isStringValue) { + value = "\"" + value + "\""; + } + keyValuePairs.emplace_back(key, value); + if (valueEndPos == argumentStr.size() - 1) { + break; + } else if (isStringValue) { + pos = valueEndPos + TOOL_ARGS_STRING_INDICATOR.size() + 1; + } else { + pos = valueEndPos + 1; + } + } + + if (keyValuePairs.empty()) { + return argumentStr; + } + + std::string parsedObject = "{"; + for (const auto& [key, value] : keyValuePairs) { + parsedObject += "\"" + key + "\":" + value + ","; + } + parsedObject.back() = '}'; + return parsedObject; +} + +std::string Gemma4ToolParser::normalizeArgStr(const std::string& arg) { + if (arg.empty()) { + return arg; + } + + std::string normalized = arg; + trim(normalized); + std::string lower = normalized; + std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower); + + if (lower == "true" || lower == "false" || lower == "null") { + return lower; + } + + const char first = normalized.front(); + const char last = normalized.back(); + if (first == '{' && last == '}') { + normalized = parseObjectParameter(normalized); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Argument contains is an object, changed it to correct JSON format. 
Modified string: {}", normalized); + } + + if (first == '[' && last == ']' && normalized.find(TOOL_ARGS_STRING_INDICATOR) != std::string::npos) { + normalized = parseArrayParameter(normalized); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Argument is an array, normalized quotes for JSON parsing. Modified string: {}", normalized); + } + + if (normalized.substr(0, TOOL_ARGS_STRING_INDICATOR.size()) == TOOL_ARGS_STRING_INDICATOR && + normalized.substr(normalized.size() - TOOL_ARGS_STRING_INDICATOR.size(), TOOL_ARGS_STRING_INDICATOR.size()) == TOOL_ARGS_STRING_INDICATOR) { + normalized = "\"" + normalized.substr(TOOL_ARGS_STRING_INDICATOR.size(), normalized.size() - 2 * TOOL_ARGS_STRING_INDICATOR.size()) + "\""; + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Argument is enclosed in string indicators, removed them for JSON parsing. Modified string: {}", normalized); + } + + rapidjson::Document tempDoc; + rapidjson::Value finalValue; + tempDoc.Parse(normalized.c_str()); + if (tempDoc.HasParseError()) { + auto errorCode = tempDoc.GetParseError(); + auto errorMessage = rapidjson::GetParseError_En(errorCode); + size_t errorOffset = tempDoc.GetErrorOffset(); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Failed to parse argument string as JSON. 
Argument string: {}, Error: {} Offset: {}", normalized, errorMessage, errorOffset); + + if (normalized.front() == '\"' && normalized.back() == '\"') { + normalized = normalized.substr(1, normalized.size() - 2); + } + finalValue.SetString(normalized.c_str(), static_cast<rapidjson::SizeType>(normalized.size()), tempDoc.GetAllocator()); + } else { + finalValue.CopyFrom(tempDoc, tempDoc.GetAllocator()); + } + + { + rapidjson::StringBuffer buffer; + rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); + finalValue.Accept(writer); + normalized = buffer.GetString(); + } + + return normalized; +} + +void Gemma4ToolParser::writeArgumentToWriter(const std::string& arg, rapidjson::Writer<rapidjson::StringBuffer>& writer) { + std::string normalized = normalizeArgStr(arg); + + rapidjson::Document doc; + doc.Parse(normalized.c_str()); + + rapidjson::Value& argumentDoc = doc; + writeArgumentOfAnyType(argumentDoc, writer); +} + +std::pair<std::string, std::string> Gemma4ToolParser::parseSingleArgument(const std::string& argumentStr) { + std::pair<std::string, std::string> argument; + + size_t colonPos = argumentStr.find(':'); + if (colonPos != std::string::npos) { + argument.first = argumentStr.substr(0, colonPos); + std::string value = argumentStr.substr(colonPos + 1); + argument.second = value; + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed argument - name: {}, value: {}", argument.first, argument.second); + } else { + argument.first = argumentStr; + argument.second = ""; + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Argument string: {} does not contain ':', setting name as entire string and value as empty", argumentStr); + } + return argument; +} + +std::vector<std::pair<std::string, std::string>> Gemma4ToolParser::parseArguments(const std::string& argumentsStr) { + std::vector<std::string> args; + std::vector<std::pair<std::string, std::string>> parsedArgs; + + size_t argPos = 0; + while (argPos < argumentsStr.length()) { + size_t commaPos = findInStringRespectingSpecialChars(argumentsStr, TOOL_ARGS_SEPARATOR_STR, argPos); + if (commaPos == std::string::npos) { + auto remainingStr = argumentsStr.substr(argPos); + args.push_back(remainingStr); + 
SPDLOG_LOGGER_TRACE(llm_calculator_logger, "No more commas found, adding remaining argument string: {}", remainingStr); + break; + } + std::string argStr = argumentsStr.substr(argPos, commaPos - argPos); + args.push_back(argStr); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed argument string: {}", argStr); + argPos = commaPos + TOOL_ARGS_SEPARATOR_STR.length(); + } + + for (const std::string& arg : args) { + parsedArgs.push_back(parseSingleArgument(arg)); + } + return parsedArgs; +} + +bool Gemma4ToolParser::parseInContentState() { + size_t toolCallStartTagPos = this->streamingContent.find(TOOL_CALL_START_TAG, this->streamingPosition); + if (toolCallStartTagPos != std::string::npos) { + if (toolCallStartTagPos > this->streamingPosition) { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Content found before tool call start tag at position: {}", toolCallStartTagPos); + return true; + } + this->streamingPosition = toolCallStartTagPos + TOOL_CALL_START_TAG.length(); + this->currentState = State::ToolCallStarted; + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Detected start of tool call at position: {}", toolCallStartTagPos); + return false; + } + + return true; +} + +bool Gemma4ToolParser::parseInToolCallState() { + size_t argsPos = this->streamingContent.find(TOOL_ARGS_START_INDICATOR, this->streamingPosition); + if (argsPos == std::string::npos) { + return false; + } + + size_t toolNameStart = this->streamingContent.find(TOOL_CALL_NAME_PREFIX, this->streamingPosition); + if (toolNameStart != std::string::npos && toolNameStart < argsPos) { + toolNameStart += TOOL_CALL_NAME_PREFIX.length(); + } else { + toolNameStart = this->streamingPosition; + } + + std::string toolName = this->streamingContent.substr(toolNameStart, argsPos - toolNameStart); + this->toolCall = ToolCall{generateRandomId(), toolName, ""}; + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed tool name: {}", toolName); + this->streamingPosition = argsPos + TOOL_ARGS_START_INDICATOR.length(); + 
this->currentState = State::ToolCallParameters; + this->toolCallIndex++; + return true; +} + +bool Gemma4ToolParser::parseToolCallParametersState() { + if (this->streamingContent.back() == TOOL_ARGS_END_INDICATOR.back()) { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Tool arguments end indicator found at the end of streaming content, attempting to parse arguments: {}", this->streamingContent.substr(this->streamingPosition)); + } + size_t pos = findInStringRespectingSpecialChars(this->streamingContent, TOOL_ARGS_END_INDICATOR, this->streamingPosition); + if (pos == std::string::npos) { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Tool arguments end indicator not found in streaming content starting from position: {}", this->streamingPosition); + return false; + } + std::string argumentsStr = this->streamingContent.substr(this->streamingPosition, pos - this->streamingPosition); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed arguments string: {}", argumentsStr); + std::vector<std::pair<std::string, std::string>> arguments = parseArguments(argumentsStr); + + rapidjson::Document argsDoc(rapidjson::kObjectType); + rapidjson::StringBuffer sb; + rapidjson::Writer<rapidjson::StringBuffer> argsWriter(sb); + argsWriter.StartObject(); + + for (const std::pair<std::string, std::string>& argument : arguments) { + argsWriter.Key(argument.first.c_str()); + writeArgumentToWriter(argument.second, argsWriter); + } + + argsWriter.EndObject(); + this->toolCall.arguments = sb.GetString(); + this->currentState = State::ToolCallEnded; + this->streamingPosition = pos + TOOL_ARGS_END_INDICATOR.length(); + + return true; +} + +bool Gemma4ToolParser::parseInToolCallEndedState() { + size_t nextToolCallPos = this->streamingContent.find(TOOL_CALL_NAME_PREFIX, this->streamingPosition); + size_t toolCallEndTagPos = this->streamingContent.find(TOOL_CALL_END_TAG, this->streamingPosition); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Current state: ToolCallEnded. 
Streaming content from current position: {}", this->streamingContent.substr(this->streamingPosition)); + if (nextToolCallPos != std::string::npos && nextToolCallPos < toolCallEndTagPos) { + this->streamingPosition = nextToolCallPos; + this->currentState = State::ToolCallStarted; + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Detected next tool call at position: {}", nextToolCallPos); + } else if (toolCallEndTagPos != std::string::npos) { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Detected end of tool call at position: {}", toolCallEndTagPos); + this->streamingPosition = toolCallEndTagPos + TOOL_CALL_END_TAG.length(); + this->currentState = State::AfterToolCall; + } else { + this->streamingPosition = toolCallEndTagPos + TOOL_CALL_END_TAG.length(); + this->currentState = State::AfterToolCall; + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Detected end of tool call at position: {}, returning to content state", toolCallEndTagPos); + } + return true; +} + +bool Gemma4ToolParser::parseNewContent() { + switch (this->currentState) { + case State::Content: { + return parseInContentState(); + } + case State::ToolCallStarted: { + return parseInToolCallState(); + } + case State::ToolCallParameters: { + return parseToolCallParametersState(); + } + case State::ToolCallEnded: { + return parseInToolCallEndedState(); + } + case State::AfterToolCall: + break; + } + return false; +} + +rapidjson::Document Gemma4ToolParser::wrapDeltaContent(const std::string& content) { + rapidjson::Document doc(rapidjson::kObjectType); + rapidjson::Value deltaObj(rapidjson::kObjectType); + deltaObj.AddMember("content", rapidjson::Value(content.c_str(), doc.GetAllocator()), doc.GetAllocator()); + doc.AddMember("delta", deltaObj, doc.GetAllocator()); + return doc; +} + +rapidjson::Document Gemma4ToolParser::wrapDeltaArgs(const std::string& argsStr, int toolCallIndex) { + rapidjson::Document doc(rapidjson::kObjectType); + doc.AddMember("arguments", rapidjson::Value(argsStr.c_str(), doc.GetAllocator()), 
doc.GetAllocator()); + + return BaseOutputParser::wrapDelta(doc, toolCallIndex); +} + +std::optional Gemma4ToolParser::parseChunk(const std::string& chunk, ov::genai::GenerationFinishReason finishReason) { + if (chunk.empty()) { + return std::nullopt; + } + + this->streamingContent += chunk; + + if (parseNewContent()) { + if (this->currentState == State::ToolCallParameters) { + return BaseOutputParser::wrapFirstDelta(this->toolCall.name, toolCallIndex); + } + if (this->currentState == State::ToolCallEnded) { + return wrapDeltaArgs(this->toolCall.arguments, toolCallIndex); + } + if (this->currentState == State::Content) { + size_t contentEnd = this->streamingContent.find(TOOL_CALL_START_TAG, this->streamingPosition); + std::string content; + if (contentEnd != std::string::npos) { + content = this->streamingContent.substr(this->streamingPosition, contentEnd - this->streamingPosition); + } else { + content = this->streamingContent.substr(this->streamingPosition); + } + this->streamingPosition += content.size(); + if (!content.empty()) { + return wrapDeltaContent(content); + } + } + if (this->currentState == State::AfterToolCall) { + this->currentState = State::Content; + } + } + + if (finishReason != ov::genai::GenerationFinishReason::NONE) { + if ((this->currentState == State::ToolCallParameters || this->currentState == State::ToolCallEnded) && !this->toolCall.arguments.empty()) { + return wrapDeltaArgs(this->toolCall.arguments, toolCallIndex); + } + + if (this->currentState == State::Content && this->streamingPosition < this->streamingContent.size()) { + auto content = this->streamingContent.substr(this->streamingPosition); + this->streamingPosition += content.size(); + + return wrapDeltaContent(content); + } + } + + return std::nullopt; +} + +bool Gemma4ToolParser::parseSingleToolCall(const std::string& toolStr, ToolCall& toolCall) { + size_t argsPos = toolStr.find(TOOL_ARGS_START_INDICATOR); + if (argsPos != std::string::npos) { + std::string toolNameWithPrefix = 
toolStr.substr(0, argsPos); + if (toolNameWithPrefix.find(TOOL_CALL_NAME_PREFIX) != 0) { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Tool name does not start with expected prefix '{}'. Tool string: {}", TOOL_CALL_NAME_PREFIX, toolStr); + return false; + } + std::string toolName = toolNameWithPrefix.substr(TOOL_CALL_NAME_PREFIX.length()); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed tool name: {}", toolName); + + int argsStrLen = toolStr.length() - argsPos - TOOL_ARGS_START_INDICATOR.length() - TOOL_ARGS_END_INDICATOR.length(); + std::string argsStr = toolStr.substr(argsPos + TOOL_ARGS_START_INDICATOR.length(), argsStrLen); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed args string: {}", argsStr); + std::vector<std::pair<std::string, std::string>> arguments = parseArguments(argsStr); + + toolCall.name = toolName; + rapidjson::Document argsDoc(rapidjson::kObjectType); + rapidjson::StringBuffer sb; + rapidjson::Writer<rapidjson::StringBuffer> argsWriter(sb); + argsWriter.StartObject(); + for (const std::pair<std::string, std::string>& argument : arguments) { + argsWriter.Key(argument.first.c_str()); + writeArgumentToWriter(argument.second, argsWriter); + } + argsWriter.EndObject(); + toolCall.arguments = sb.GetString(); + toolCall.id = generateRandomId(); + return true; + } + return false; +} + +void Gemma4ToolParser::parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) { + std::vector<std::string> tools; + std::vector<std::pair<size_t, size_t>> toolCallPositions; + size_t pos = 0; + + while (pos != std::string::npos) { + size_t start, end; + auto it = std::find(generatedTokens.begin() + pos, generatedTokens.end(), botTokenId); + if (it != generatedTokens.end()) { + start = std::distance(generatedTokens.begin(), it); + } else { + break; + } + auto itArgs = std::find(generatedTokens.begin() + start, generatedTokens.end(), eotTokenId); + if (itArgs != generatedTokens.end()) { + end = std::distance(generatedTokens.begin(), itArgs); + } else { + break; + } + + std::string toolCallStr = tokenizer.decode(std::vector<int64_t>(generatedTokens.begin() + start + 1, 
generatedTokens.begin() + end + 1), ov::AnyMap{ov::genai::skip_special_tokens(false)}); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed tool list string: {}", toolCallStr); + + while (!toolCallStr.empty()) { + size_t nextToolPos = toolCallStr.find(TOOL_CALL_NAME_PREFIX, TOOL_CALL_NAME_PREFIX.length()); + size_t toolEndPos; + if (nextToolPos == std::string::npos) { + toolEndPos = toolCallStr.rfind(TOOL_ARGS_END_INDICATOR); + } else { + toolEndPos = nextToolPos - 1; + } + std::string singleTool; + if (toolEndPos != std::string::npos) { + singleTool = toolCallStr.substr(0, toolEndPos + TOOL_ARGS_END_INDICATOR.length()); + if (toolEndPos + TOOL_ARGS_END_INDICATOR.length() < toolCallStr.length()) { + toolCallStr = toolCallStr.substr(toolEndPos + TOOL_ARGS_END_INDICATOR.length()); + } else { + toolCallStr.clear(); + } + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed single tool string {}", singleTool); + } else { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "No more tool strings found in the decoded string: {}", toolCallStr); + break; + } + + if (!singleTool.empty()) { + tools.push_back(singleTool); + } + } + + pos = end; + toolCallPositions.emplace_back(start, end); + } + + for (const std::string& tool : tools) { + ToolCall toolCall; + auto wasToolCallParsed = parseSingleToolCall(tool, toolCall); + if (wasToolCallParsed) { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed tool call - name: {}, args: {}", toolCall.name, toolCall.arguments); + parsedOutput.toolCalls.push_back(toolCall); + } else { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Failed to parse tool call from string: {}", tool); + } + } + std::vector<int64_t> contentWithoutToolCalls = generatedTokens; + for (auto it = toolCallPositions.rbegin(); it != toolCallPositions.rend(); ++it) { + contentWithoutToolCalls.erase(contentWithoutToolCalls.begin() + it->first, contentWithoutToolCalls.begin() + it->second + 1); + } + parsedOutput.content = tokenizer.decode(contentWithoutToolCalls, 
ov::AnyMap{ov::genai::skip_special_tokens(true)}); +} +} // namespace ovms diff --git a/src/llm/io_processing/gemma4/tool_parser.hpp b/src/llm/io_processing/gemma4/tool_parser.hpp new file mode 100644 index 0000000000..aebe7c5fa3 --- /dev/null +++ b/src/llm/io_processing/gemma4/tool_parser.hpp @@ -0,0 +1,96 @@ +//***************************************************************************** +// Copyright 2026 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#pragma once +#include +#include +#include +#include "src/llm/io_processing/base_output_parser.hpp" + +namespace ovms { +class Gemma4ToolParser : public BaseOutputParser { +protected: + static const std::string TOOL_CALL_START_TAG; + static const std::string TOOL_CALL_END_TAG; + static const std::string TOOL_CALL_NAME_PREFIX; + + static const std::string TOOL_ARGS_START_INDICATOR; + static const std::string TOOL_ARGS_END_INDICATOR; + static const std::string TOOL_ARGS_STRING_INDICATOR; + static const std::string TOOL_ARGS_SEPARATOR_STR; + + static const int64_t botTokenId; + static const int64_t eotTokenId; + + enum class State { + Content, // Content -> ToolCallStarted (on TOOL_CALL_START_TAG) + ToolCallStarted, // ToolCallStarted -> ToolCallParameters (on TOOL_ARGS_START_INDICATOR, emits name) + ToolCallParameters, // ToolCallParameters -> ToolCallEnded (on TOOL_ARGS_END_INDICATOR, emits args) + 
ToolCallEnded, // ToolCallEnded -> ToolCallStarted (on TOOL_CALL_NAME_PREFIX) | AfterToolCall (on end tag) + AfterToolCall // AfterToolCall -> Content + }; + +public: + Gemma4ToolParser() = delete; + explicit Gemma4ToolParser(ov::genai::Tokenizer& tokenizer) : + BaseOutputParser(tokenizer) {} + + void parse(ParsedOutput& parsedOutput, const std::vector& generatedTokens) override; + std::optional parseChunk(const std::string& chunk, ov::genai::GenerationFinishReason finishReason) override; + const std::vector& getParsingStartTags() const override { + static const std::vector parsingStartTags = {TOOL_CALL_START_TAG}; + return parsingStartTags; + } + + const std::vector& getSpecialParsingStartTags() const override { + static const std::vector beginningOnlyTags = {}; + return beginningOnlyTags; + } + + const std::string& getParsingEndTag() const override { + return TOOL_CALL_END_TAG; + } + + bool requiresStreamingWithSpecialTokens() const override { + return true; + } + + static std::string normalizeArgStr(const std::string& arg); + static std::string parseArrayParameter(const std::string& argumentStr); + static std::string parseObjectParameter(std::string argumentStr); + +private: + void writeArgumentToWriter(const std::string& arg, rapidjson::Writer& writer); + + std::pair parseSingleArgument(const std::string& argumentStr); + std::vector> parseArguments(const std::string& argumentsStr); + + bool parseSingleToolCall(const std::string& toolStr, ToolCall& toolCall); + bool parseNewContent(); + bool parseInContentState(); + bool parseInToolCallState(); + bool parseToolCallParametersState(); + bool parseInToolCallEndedState(); + + rapidjson::Document wrapDeltaContent(const std::string& content); + rapidjson::Document wrapDeltaArgs(const std::string& argsStr, int toolCallIndex); + + std::string streamingContent; + size_t streamingPosition{0}; + State currentState{State::Content}; + ToolCall toolCall; + int toolCallIndex{-1}; +}; +} // namespace ovms diff --git 
a/src/llm/io_processing/output_parser.cpp b/src/llm/io_processing/output_parser.cpp index 1c060375df..32a34e207e 100644 --- a/src/llm/io_processing/output_parser.cpp +++ b/src/llm/io_processing/output_parser.cpp @@ -29,6 +29,7 @@ #include "qwen3coder/qwen3coder_tool_parser.hpp" #include "devstral/tool_parser.hpp" #include "gptoss/reasoning_parser.hpp" +#include "gemma4/tool_parser.hpp" namespace ovms { OutputParser::TagLookupStatus OutputParser::StreamOutputCache::lookupTag(const std::string& tag) const { @@ -171,6 +172,8 @@ OutputParser::OutputParser(ov::genai::Tokenizer& tokenizer, const std::string to toolParser = std::make_unique<Qwen3CoderToolParser>(tokenizer, toolNameSchemaMap); } else if (toolParserName == "devstral") { toolParser = std::make_unique<DevstralToolParser>(tokenizer, toolNameSchemaMap); + } else if (toolParserName == "gemma4") { + toolParser = std::make_unique<Gemma4ToolParser>(tokenizer); } else if (!toolParserName.empty()) { throw std::runtime_error("Unsupported tool parser: " + toolParserName); } diff --git a/src/llm/io_processing/utils.cpp b/src/llm/io_processing/utils.cpp index c1f573d49e..1b8075f404 100644 --- a/src/llm/io_processing/utils.cpp +++ b/src/llm/io_processing/utils.cpp @@ -35,4 +35,59 @@ std::string generateRandomId() { } return id; } +void writeArgumentOfAnyType(const rapidjson::Value& arg, rapidjson::Writer<rapidjson::StringBuffer>& writer) { + if (arg.IsString()) { + writer.String(arg.GetString()); + } else if (arg.IsInt64()) { + writer.Int64(arg.GetInt64()); + } else if (arg.IsDouble()) { + writer.Double(arg.GetDouble()); + } else if (arg.IsBool()) { + writer.Bool(arg.GetBool()); + } else if (arg.IsArray()) { + writer.StartArray(); + for (auto& elem : arg.GetArray()) { + writeArgumentOfAnyType(elem, writer); + } + writer.EndArray(); + } else if (arg.IsObject()) { + writer.StartObject(); + for (auto it = arg.MemberBegin(); it != arg.MemberEnd(); ++it) { + writer.Key(it->name.GetString()); + writeArgumentOfAnyType(it->value, writer); + } + writer.EndObject(); + } else { + writer.String(""); + } +} + +size_t 
findInStringRespectingSpecialChars(const std::string& str, const std::string& target, size_t startPos) { + int bracketDepth = 0; + int braceDepth = 0; + int quoteDepth = 0; + int singleQuoteDepth = 0; + + for (size_t i = startPos; i < str.length(); ++i) { + if (bracketDepth == 0 && braceDepth == 0 && quoteDepth == 0 && singleQuoteDepth == 0 && + str.compare(i, target.length(), target) == 0) { + return i; + } + + if (str[i] == '{') { + braceDepth++; + } else if (str[i] == '}') { + braceDepth--; + } else if (str[i] == '[') { + bracketDepth++; + } else if (str[i] == ']') { + bracketDepth--; + } else if (str[i] == '"' && (i == 0 || str[i - 1] != '\\')) { + quoteDepth = 1 - quoteDepth; + } else if (str[i] == '\'' && (i == 0 || str[i - 1] != '\\')) { + singleQuoteDepth = 1 - singleQuoteDepth; + } + } + return std::string::npos; +} } // namespace ovms diff --git a/src/llm/io_processing/utils.hpp b/src/llm/io_processing/utils.hpp index 78abd38166..dfb177b4f5 100644 --- a/src/llm/io_processing/utils.hpp +++ b/src/llm/io_processing/utils.hpp @@ -16,7 +16,17 @@ #pragma once #include <string> +#pragma warning(push) +#pragma warning(disable : 6313) +#include <rapidjson/document.h> +#include <rapidjson/stringbuffer.h> +#include <rapidjson/writer.h> +#pragma warning(pop) + namespace ovms { // Generates random alphanumeric string of length 9 for tool call ID std::string generateRandomId(); + +size_t findInStringRespectingSpecialChars(const std::string& str, const std::string& target, size_t startPos); +void writeArgumentOfAnyType(const rapidjson::Value& arg, rapidjson::Writer<rapidjson::StringBuffer>& writer); } // namespace ovms diff --git a/src/llm/visual_language_model/legacy/servable.cpp b/src/llm/visual_language_model/legacy/servable.cpp index bc3ecaf71f..fe29f7f6d9 100644 --- a/src/llm/visual_language_model/legacy/servable.cpp +++ b/src/llm/visual_language_model/legacy/servable.cpp @@ -110,18 +110,15 @@ absl::Status VisualLanguageModelLegacyServable::parseRequest(std::shared_ptrexecutionInProgress, &mutex = legacyExecutionContext->mutex, &lastStreamerCallbackOutput = 
legacyExecutionContext->lastStreamerCallbackOutput, - &clientDisconnected = legacyExecutionContext->clientDisconnected, - streamMode = legacyExecutionContext->apiHandler->isStream()](std::string text) { + &clientDisconnected = legacyExecutionContext->clientDisconnected](std::string text) { SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Streamer callback executed with text: [{}]", text); if (clientDisconnected.load()) { executionInProgress.notify_one(); return ov::genai::StreamingStatus::CANCEL; } - if (streamMode) { - std::lock_guard lock(mutex); - lastStreamerCallbackOutput += text; - executionInProgress.notify_one(); - } + std::lock_guard lock(mutex); + lastStreamerCallbackOutput += text; + executionInProgress.notify_one(); return ov::genai::StreamingStatus::RUNNING; }; ov::AnyMap streamerConfig; @@ -175,10 +172,40 @@ absl::Status VisualLanguageModelLegacyServable::readCompleteExecutionResults(std absl::Status VisualLanguageModelLegacyServable::prepareCompleteResponse(std::shared_ptr& executionContext) { auto legacyExecutionContext = std::static_pointer_cast(executionContext); - if (legacyExecutionContext->payload.client->isDisconnected()) { - return absl::CancelledError(); + + // temporary workaround to use streaming logic in unary + // to be fixed after require_special_tokens flag implemented + std::string completeText; + auto generationStatus = legacyExecutionContext->finished.wait_for(std::chrono::nanoseconds::zero()); + + while (generationStatus != std::future_status::ready) { + if (legacyExecutionContext->payload.client->isDisconnected()) { + return absl::CancelledError(); + } + std::unique_lock lock(legacyExecutionContext->mutex); + while (executionContext->lastStreamerCallbackOutput.size() == 0 && generationStatus != std::future_status::ready) { + legacyExecutionContext->executionInProgress.wait_for(lock, std::chrono::milliseconds(10)); + generationStatus = legacyExecutionContext->finished.wait_for(std::chrono::nanoseconds::zero()); + } + completeText += 
executionContext->lastStreamerCallbackOutput; + executionContext->lastStreamerCallbackOutput = ""; + generationStatus = legacyExecutionContext->finished.wait_for(std::chrono::nanoseconds::zero()); + } + + if (!legacyExecutionContext->success) { + return absl::InvalidArgumentError("Request processing failed, check its correctness."); } - executionContext->response = executionContext->apiHandler->serializeUnaryResponse(legacyExecutionContext->results); + + executionContext->textStreamer->end(); + { + std::unique_lock lock(legacyExecutionContext->mutex); + completeText += executionContext->lastStreamerCallbackOutput; + executionContext->lastStreamerCallbackOutput = ""; + } + + executionContext->apiHandler->setPromptTokensUsage(legacyExecutionContext->results.perf_metrics.get_num_input_tokens()); + executionContext->apiHandler->setCompletionTokensUsage(legacyExecutionContext->results.perf_metrics.get_num_generated_tokens()); + executionContext->response = executionContext->apiHandler->serializeUnaryResponse(legacyExecutionContext->results, completeText); SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Complete unary response: {}", executionContext->response); return absl::OkStatus(); } diff --git a/src/test/http_openai_handler_test.cpp b/src/test/http_openai_handler_test.cpp index 17813f5a2c..a64aa82a0d 100644 --- a/src/test/http_openai_handler_test.cpp +++ b/src/test/http_openai_handler_test.cpp @@ -929,7 +929,7 @@ TEST_F(HttpOpenAIHandlerParsingTest, serializeUnaryResponseVLMSupportsToolCallsF ov::genai::VLMDecodedResults results; std::string toolCall = R"({"name": "example_tool", "arguments": {"arg1": "value1", "arg2": 42}})"; results.texts = {toolCall}; - std::string serialized = apiHandler->serializeUnaryResponse(results); + std::string serialized = apiHandler->serializeUnaryResponse(results, toolCall); ASSERT_NE(serialized.find("\"finish_reason\":\"tool_calls\""), std::string::npos) << serialized; } @@ -2971,11 +2971,11 @@ TEST_F(HttpOpenAIHandlerParsingTest, 
SerializeUnaryResponseVLMDecodedResultsWith ASSERT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); + std::string toolCallContent = "I will call a tool.{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}"; ov::genai::VLMDecodedResults results; - results.texts.push_back( - "I will call a tool.{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}"); + results.texts.push_back(toolCallContent); - std::string serialized = apiHandler->serializeUnaryResponse(results); + std::string serialized = apiHandler->serializeUnaryResponse(results, toolCallContent); rapidjson::Document responseDoc; responseDoc.Parse(serialized.c_str()); diff --git a/src/test/llm/output_parsers/gemma4_output_parser_test.cpp b/src/test/llm/output_parsers/gemma4_output_parser_test.cpp new file mode 100644 index 0000000000..5ed31005e5 --- /dev/null +++ b/src/test/llm/output_parsers/gemma4_output_parser_test.cpp @@ -0,0 +1,764 @@ +//***************************************************************************** +// Copyright 2026 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../../llm/io_processing/base_output_parser.hpp" +#include "../../../llm/io_processing/output_parser.hpp" +#include "../../platform_utils.hpp" + +using namespace ovms; + +#ifdef _WIN32 +const std::string tokenizerPath = getWindowsRepoRootPath() + "\\src\\test\\llm_testing\\google\\gemma-4-26B-A4B-it"; +#else +// Hardcoded for usage in docker container +const std::string tokenizerPath = "/ovms/src/test/llm_testing/google/gemma-4-26B-A4B-it"; +#endif + +static std::unique_ptr gemma4Tokenizer; +static const ToolsSchemas_t& EMPTY_TOOLS_SCHEMA = {}; // not used in gemma4 + +class Gemma4OutputParserTest : public ::testing::Test { +protected: + std::unique_ptr outputParserWithRegularToolParsing; + + static void SetUpTestSuite() { + try { + gemma4Tokenizer = std::make_unique(tokenizerPath); + } catch (const std::exception& e) { + FAIL() << "Failed to initialize gemma4 tokenizer: " << e.what(); + } catch (...) 
{ + FAIL() << "Failed to initialize gemma4 tokenizer due to unknown error."; + } + } + + static void TearDownTestSuite() { + gemma4Tokenizer.reset(); + } + + void SetUp() override { + // For Gemma4 model there is only tool parser available + outputParserWithRegularToolParsing = std::make_unique(*gemma4Tokenizer, "gemma4", "", EMPTY_TOOLS_SCHEMA); + } + + void assertChunkEqual(const std::optional& doc, const std::optional& expectedDelta, const std::string& chunk) { + if (!expectedDelta.has_value() && !doc.has_value()) { + return; + } + if (expectedDelta.has_value() && doc.has_value()) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + std::string docStr = buffer.GetString(); + std::string expected = expectedDelta.value(); + EXPECT_EQ(docStr, expected) << "Mismatch for chunk: " << chunk; + } else { + FAIL() << "Mismatch between expectedDelta and doc for chunk: " << chunk; + } + } +}; + +TEST_F(Gemma4OutputParserTest, ParseToolCallOutputWithSingleToolCall) { + std::string inputWithProperClosure = "<|tool_call>call:example_tool{arg1:<|\"|>value1<|\"|>,arg2:42}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "example_tool"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"arg1\":\"value1\",\"arg2\":42}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallOutputWithNoToolsInTheRequest) { + std::string inputWithProperClosure = 
"<|tool_call>call:example_tool{arg1:<|\"|>value1<|\"|>,arg2:42}"; + std::string inputWithoutSpecialTokens = "call:example_tool{arg1:value1,arg2:42}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, false); + EXPECT_EQ(parsedOutput.content, inputWithoutSpecialTokens); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 0); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithObjectArguments) { + std::string inputWithProperClosure = "<|tool_call>call:dummy{config:{name:<|\"|>astro_config<|\"|>,value:99}}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "dummy"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"config\":{\"name\":\"astro_config\",\"value\":99}}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArguments) { + std::string inputWithProperClosure = "<|tool_call>call:test1{arg1:<|\"|>data1,data2<|\"|>}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput 
parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "test1"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"arg1\":\"data1,data2\"}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithListOfStringsAsArgument) { + std::string inputWithProperClosure = "<|tool_call>call:generate_DNA_sequence{length:100,preferences:[<|\"|>G<|\"|>,<|\"|>C<|\"|>]}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "generate_DNA_sequence"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"length\":100,\"preferences\":[\"G\",\"C\"]}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, ParserToolCallWithBooleanArgument) { + std::string inputWithProperClosure = "<|tool_call>call:check_status{flag:true}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + 
EXPECT_EQ(parsedOutput.toolCalls[0].name, "check_status"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"flag\":true}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, ParseTwoToolCallsAtOnce) { + std::string inputWithProperClosure = "<|tool_call>call:dummy1{config:{name:<|\"|>astro_config<|\"|>,value:99}}call:dummy2{config:{value:199,name:<|\"|>second_config<|\"|>}}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 2); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "dummy1"); + EXPECT_EQ(parsedOutput.toolCalls[1].name, "dummy2"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"config\":{\"name\":\"astro_config\",\"value\":99}}"); + EXPECT_EQ(parsedOutput.toolCalls[1].arguments, "{\"config\":{\"value\":199,\"name\":\"second_config\"}}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + EXPECT_EQ(parsedOutput.toolCalls[1].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithArrayArguments) { + std::string inputWithProperClosure = "<|tool_call>call:sort{array:[42,17,89,5,33],order:<|\"|>descending<|\"|>}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + 
ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "sort"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"array\":[42,17,89,5,33],\"order\":\"descending\"}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallOutputWithThreeToolCalls) { + std::string inputWithProperClosure = "<|tool_call>call:example_tool{arg1:<|\"|>value1<|\"|>,arg2:42}" + "<|tool_call>call:another_tool{param1:<|\"|>data<|\"|>,param2:true}" + "<|tool_call>call:third_tool{key:<|\"|>value<|\"|>}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 3); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "example_tool"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"arg1\":\"value1\",\"arg2\":42}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + auto firstToolCallId = parsedOutput.toolCalls[0].id; + + EXPECT_EQ(parsedOutput.toolCalls[1].name, "another_tool"); + EXPECT_EQ(parsedOutput.toolCalls[1].arguments, "{\"param1\":\"data\",\"param2\":true}"); + EXPECT_EQ(parsedOutput.toolCalls[1].id.empty(), false); + auto secondToolCallId = parsedOutput.toolCalls[1].id; + EXPECT_NE(firstToolCallId, secondToolCallId); + + EXPECT_EQ(parsedOutput.toolCalls[2].name, "third_tool"); + EXPECT_EQ(parsedOutput.toolCalls[2].arguments, "{\"key\":\"value\"}"); + EXPECT_EQ(parsedOutput.toolCalls[2].id.empty(), false); + auto thirdToolCallId = parsedOutput.toolCalls[2].id; + EXPECT_NE(firstToolCallId, thirdToolCallId); + EXPECT_NE(secondToolCallId, thirdToolCallId); + } +} + 
+TEST_F(Gemma4OutputParserTest, ParseToolCallOutputWithThreeToolCallsWithContentInBetween) { + std::string inputWithProperClosure = "Before tool calls content. " + "<|tool_call>call:example_tool{arg1:<|\"|>value1<|\"|>,arg2:42}" + "This is some content between tool calls." + "<|tool_call>call:another_tool{param1:<|\"|>data<|\"|>,param2:true}" + " This is some content between second and third tool call. " + "<|tool_call>call:third_tool{key:<|\"|>value<|\"|>}" + "After tool calls content."; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, "Before tool calls content. This is some content between tool calls. This is some content between second and third tool call. 
After tool calls content."); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 3); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "example_tool"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"arg1\":\"value1\",\"arg2\":42}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + auto firstToolCallId = parsedOutput.toolCalls[0].id; + + EXPECT_EQ(parsedOutput.toolCalls[1].name, "another_tool"); + EXPECT_EQ(parsedOutput.toolCalls[1].arguments, "{\"param1\":\"data\",\"param2\":true}"); + EXPECT_EQ(parsedOutput.toolCalls[1].id.empty(), false); + auto secondToolCallId = parsedOutput.toolCalls[1].id; + EXPECT_NE(firstToolCallId, secondToolCallId); + + EXPECT_EQ(parsedOutput.toolCalls[2].name, "third_tool"); + EXPECT_EQ(parsedOutput.toolCalls[2].arguments, "{\"key\":\"value\"}"); + EXPECT_EQ(parsedOutput.toolCalls[2].id.empty(), false); + auto thirdToolCallId = parsedOutput.toolCalls[2].id; + EXPECT_NE(firstToolCallId, thirdToolCallId); + EXPECT_NE(secondToolCallId, thirdToolCallId); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithEmptyArguments) { + // Tool call with empty braces (no arguments) + std::string input = "<|tool_call>call:no_args_tool{}"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "no_args_tool"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallOutputWithContentAndNoToolCalls) { + std::string input = "This is a regular model response without tool calls."; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = 
outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, "This is a regular model response without tool calls."); + ASSERT_EQ(parsedOutput.toolCalls.size(), 0); + EXPECT_EQ(parsedOutput.reasoning, ""); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallOutputWithContentAndSingleToolCall) { + std::string input = "This is a content part and next will be a tool call.\n\n<|tool_call>call:example_tool{arg1:<|\"|>value1<|\"|>,arg2:42}"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, "This is a content part and next will be a tool call.\n\n"); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "example_tool"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"arg1\":\"value1\",\"arg2\":42}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); +} + +TEST_F(Gemma4OutputParserTest, HolisticStreaming) { + std::vector>> chunkToDeltaVec{ + {"JUST_SOME_STRING_BEFORE_SPECIAL_STARTING_TAG", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"JUST_SOME_STRING_BEFORE_SPECIAL_STARTING_TAG"}})"}, + {"<|tool_call>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"call:", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"sort", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"{array", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"id":"XXXXXXXXX","type":"function","index":0,"function":{"name":"sort"}}]}})"}, + {":[", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"42", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 17", 
ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 89", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 5", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 33", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"],", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"order", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"desc", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"ending", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"}", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"array\":[42,17,89,5,33],\"order\":\"descending\"}"}}]}})"}, + {"call:d", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"ummy", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"{config", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"id":"XXXXXXXXX","type":"function","index":1,"function":{"name":"dummy"}}]}})"}, + {":{", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"name", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"astro_config", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"value", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"99", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + 
{"}}", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"index":1,"function":{"arguments":"{\"config\":{\"name\":\"astro_config\",\"value\":99}}"}}]}})"}, + {"", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"ANOTHER_CONTENT_AFTER_TOOL_CALL", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"ANOTHER_CONTENT_AFTER_TOOL_CALL"}})"}, + }; + + for (const auto& [chunk, finishReason, expectedDelta] : chunkToDeltaVec) { + std::optional doc = outputParserWithRegularToolParsing->parseChunk(chunk, true, finishReason); + if (!expectedDelta.has_value() && !doc.has_value()) { + continue; // Both are nullopt, OK + } + if (expectedDelta.has_value() && doc.has_value()) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + std::string docStr = buffer.GetString(); + // If both strings contain "id":"...", compare id values by length and alphanumeric, else compare whole strings + std::string expected = expectedDelta.value(); + std::string idKey = "\"id\":\""; + auto docIdPos = docStr.find(idKey); + auto expectedIdPos = expected.find(idKey); + if (docIdPos != std::string::npos && expectedIdPos != std::string::npos) { + auto docIdStart = docIdPos + idKey.size(); + auto docIdEnd = docStr.find("\"", docIdStart); + auto expectedIdStart = expectedIdPos + idKey.size(); + auto expectedIdEnd = expected.find("\"", expectedIdStart); + ASSERT_NE(docIdEnd, std::string::npos); + ASSERT_NE(expectedIdEnd, std::string::npos); + std::string docId = docStr.substr(docIdStart, docIdEnd - docIdStart); + std::string expectedId = expected.substr(expectedIdStart, expectedIdEnd - expectedIdStart); + EXPECT_EQ(docId.size(), expectedId.size()) << "ID length mismatch for chunk: " << chunk; + EXPECT_TRUE(std::all_of(docId.begin(), docId.end(), ::isalnum)) << "ID not alphanumeric for chunk: " << chunk; + // Compare everything except the id value + std::string docStrNoId = docStr; + std::string expectedNoId = expected; + 
docStrNoId.replace(docIdStart, docId.size(), std::string(docId.size(), '*')); + expectedNoId.replace(expectedIdStart, expectedId.size(), std::string(expectedId.size(), '*')); + EXPECT_EQ(docStrNoId, expectedNoId) << "Mismatch for chunk (ignoring id value): " << chunk; + } else { + EXPECT_EQ(docStr, expected) << "Mismatch for chunk: " << chunk; + } + } else { + std::string expectedStr = expectedDelta.has_value() ? expectedDelta.value() : "std::nullopt"; + std::string docStr = doc.has_value() ? [&]() { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + return std::string(buffer.GetString()); + }() + : "std::nullopt"; + FAIL() << "Mismatch between expectedDelta and doc for chunk: " << chunk + << "\nexpectedDelta: " << expectedStr + << "\ndoc: " << docStr; + } + } +} + +TEST_F(Gemma4OutputParserTest, StreamingWithBiggerChunks) { + std::vector>> chunkToDeltaVec{ + {"SOME_CONTENT", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"SOME_CONTENT"}})"}, + {"MORE_CONTENT<|tool_call>", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"MORE_CONTENT"}})"}, + {"call:sort", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"{array:", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"id":"XXXXXXXXX","type":"function","index":0,"function":{"name":"sort"}}]}})"}, + {"[42, 17, 89, 5, 33],order:<|\"|>descending<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"}", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"array\":[42,17,89,5,33],\"order\":\"descending\"}"}}]}})"}, + {"", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"ANOTHER_CONTENT_AFTER_TOOL_CALL", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"ANOTHER_CONTENT_AFTER_TOOL_CALL"}})"}, + }; + + for (const auto& [chunk, finishReason, expectedDelta] : chunkToDeltaVec) { + std::optional doc = 
outputParserWithRegularToolParsing->parseChunk(chunk, true, finishReason); + if (!expectedDelta.has_value() && !doc.has_value()) { + continue; // Both are nullopt, OK + } + if (expectedDelta.has_value() && doc.has_value()) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + std::string docStr = buffer.GetString(); + std::string expected = expectedDelta.value(); + std::string idKey = "\"id\":\""; + auto docIdPos = docStr.find(idKey); + auto expectedIdPos = expected.find(idKey); + if (docIdPos != std::string::npos && expectedIdPos != std::string::npos) { + auto docIdStart = docIdPos + idKey.size(); + auto docIdEnd = docStr.find("\"", docIdStart); + auto expectedIdStart = expectedIdPos + idKey.size(); + auto expectedIdEnd = expected.find("\"", expectedIdStart); + ASSERT_NE(docIdEnd, std::string::npos); + ASSERT_NE(expectedIdEnd, std::string::npos); + std::string docId = docStr.substr(docIdStart, docIdEnd - docIdStart); + std::string expectedId = expected.substr(expectedIdStart, expectedIdEnd - expectedIdStart); + EXPECT_EQ(docId.size(), expectedId.size()) << "ID length mismatch for chunk: " << chunk; + EXPECT_TRUE(std::all_of(docId.begin(), docId.end(), ::isalnum)) << "ID not alphanumeric for chunk: " << chunk; + std::string docStrNoId = docStr; + std::string expectedNoId = expected; + docStrNoId.replace(docIdStart, docId.size(), std::string(docId.size(), '*')); + expectedNoId.replace(expectedIdStart, expectedId.size(), std::string(expectedId.size(), '*')); + EXPECT_EQ(docStrNoId, expectedNoId) << "Mismatch for chunk (ignoring id value): " << chunk; + } else { + EXPECT_EQ(docStr, expected) << "Mismatch for chunk: " << chunk; + } + } else { + std::string expectedStr = expectedDelta.has_value() ? expectedDelta.value() : "std::nullopt"; + std::string docStr = doc.has_value() ? 
[&]() { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + return std::string(buffer.GetString()); + }() + : "std::nullopt"; + FAIL() << "Mismatch between expectedDelta and doc for chunk: " << chunk + << "\nexpectedDelta: " << expectedStr + << "\ndoc: " << docStr; + } + } +} + +TEST_F(Gemma4OutputParserTest, StreamingWithContentBetweenToolCalls) { + std::vector>> chunkToDeltaVec{ + {"JUST_SOME_STRING_BEFORE_SPECIAL_STARTING_TAG", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"JUST_SOME_STRING_BEFORE_SPECIAL_STARTING_TAG"}})"}, + {"<|tool_call>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"call:sort", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"{array", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"id":"XXXXXXXXX","type":"function","index":0,"function":{"name":"sort"}}]}})"}, + {":[", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"42", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 17", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 89", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 5", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 33", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"],", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"order", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"desc", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"ending", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>}", ov::genai::GenerationFinishReason::NONE, 
R"({"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"array\":[42,17,89,5,33],\"order\":\"descending\"}"}}]}})"}, + {"", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"Some ", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"Some "}})"}, + {"content ", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"content "}})"}, + {"between ", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"between "}})"}, + {"tool ", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"tool "}})"}, + {"calls.", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"calls."}})"}, + {"<|tool_call>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"call:d", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"ummy", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"{config", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"id":"XXXXXXXXX","type":"function","index":1,"function":{"name":"dummy"}}]}})"}, + {":{", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"name", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"astro_config", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"value", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"99", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"}}", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"index":1,"function":{"arguments":"{\"config\":{\"name\":\"astro_config\",\"value\":99}}"}}]}})"}, + {"", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"ANOTHER_CONTENT_AFTER_TOOL_CALL", 
ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"ANOTHER_CONTENT_AFTER_TOOL_CALL"}})"}, + {"<|tool_call>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"call:solve", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"{e", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"id":"XXXXXXXXX","type":"function","index":2,"function":{"name":"solve"}}]}})"}, + {"quation", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"2", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"*", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"(", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"x", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"+", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"5)", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" =", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 13", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>}", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"index":2,"function":{"arguments":"{\"equation\":\"2*(x+5) = 13\"}"}}]}})"}, + {"", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"And some content after second tool call", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"And some content after second tool call"}})"}, + }; + + for (const auto& [chunk, finishReason, expectedDelta] : chunkToDeltaVec) { + std::optional doc = outputParserWithRegularToolParsing->parseChunk(chunk, true, finishReason); + if (!expectedDelta.has_value() && !doc.has_value()) { + continue; // Both are nullopt, OK + } + if (expectedDelta.has_value() && doc.has_value()) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + std::string docStr = buffer.GetString(); + std::string expected = expectedDelta.value(); + std::string idKey = "\"id\":\""; + 
auto docIdPos = docStr.find(idKey); + auto expectedIdPos = expected.find(idKey); + if (docIdPos != std::string::npos && expectedIdPos != std::string::npos) { + auto docIdStart = docIdPos + idKey.size(); + auto docIdEnd = docStr.find("\"", docIdStart); + auto expectedIdStart = expectedIdPos + idKey.size(); + auto expectedIdEnd = expected.find("\"", expectedIdStart); + ASSERT_NE(docIdEnd, std::string::npos); + ASSERT_NE(expectedIdEnd, std::string::npos); + std::string docId = docStr.substr(docIdStart, docIdEnd - docIdStart); + std::string expectedId = expected.substr(expectedIdStart, expectedIdEnd - expectedIdStart); + EXPECT_EQ(docId.size(), expectedId.size()) << "ID length mismatch for chunk: " << chunk; + EXPECT_TRUE(std::all_of(docId.begin(), docId.end(), ::isalnum)) << "ID not alphanumeric for chunk: " << chunk; + // Compare everything except the id value + std::string docStrNoId = docStr; + std::string expectedNoId = expected; + docStrNoId.replace(docIdStart, docId.size(), std::string(docId.size(), '*')); + expectedNoId.replace(expectedIdStart, expectedId.size(), std::string(expectedId.size(), '*')); + EXPECT_EQ(docStrNoId, expectedNoId) << "Mismatch for chunk (ignoring id value): " << chunk; + } else { + EXPECT_EQ(docStr, expected) << "Mismatch for chunk: " << chunk; + } + } else { + std::string expectedStr = expectedDelta.has_value() ? expectedDelta.value() : "std::nullopt"; + std::string docStr = doc.has_value() ? 
[&]() { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + return std::string(buffer.GetString()); + }() + : "std::nullopt"; + FAIL() << "Mismatch between expectedDelta and doc for chunk: " << chunk + << "\nexpectedDelta: " << expectedStr + << "\ndoc: " << docStr; + } + } +} + +TEST_F(Gemma4OutputParserTest, ToolCallsWithoutToolsInTheRequestStreaming) { + std::vector>> chunkToDeltaVec{ + {"<|tool_call>", "{\"delta\":{\"content\":\"<|tool_call>\"}}"}, + {"call:super", "{\"delta\":{\"content\":\"call:super\"}}"}, + {"_tool_number_two", "{\"delta\":{\"content\":\"_tool_number_two\"}}"}, + {"{arg1", "{\"delta\":{\"content\":\"{arg1\"}}"}, + {":<|\"|>", "{\"delta\":{\"content\":\":<|\\\"|>\"}}"}, + {"val{{{ue1", "{\"delta\":{\"content\":\"val{{{ue1\"}}"}, + {"<|\"|>}", "{\"delta\":{\"content\":\"<|\\\"|>}\"}}"}, + {"", "{\"delta\":{\"content\":\"\"}}"}, + }; + + for (const auto& [chunk, expectedDelta] : chunkToDeltaVec) { + std::optional doc = outputParserWithRegularToolParsing->parseChunk(chunk, false, ov::genai::GenerationFinishReason::NONE); + assertChunkEqual(doc, expectedDelta, chunk); + } +} + +// Malformed tool calls + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithMissingParentheses) { + std::string input = "<|tool_call>call:broken_tool"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + ASSERT_EQ(parsedOutput.toolCalls.size(), 0); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithMissingClosingParenthesis) { + std::string input = "<|tool_call>call:broken_tool{arg1:<|\"|>value1<|\"|>"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = 
outputParserWithRegularToolParsing->parse(generatedTokens, true); + ASSERT_EQ(parsedOutput.toolCalls.size(), 0); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithArgumentMissingEquals) { + std::string input = "<|tool_call>call:broken{malformed_arg}"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector<int64_t> generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "broken"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsContainingComparison) { + std::string input = R"x(<|tool_call>call:search{query:<|"|>price >= 100, (sale)<|"|>,limit:5})x"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector<int64_t> generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "search"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"x({"query":"price >= 100, (sale)","limit":5})x"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsContainingBracesAndBrackets) { + std::string input = R"(<|tool_call>call:format{template:<|"|>Hello {name}, items: [a, b, c]<|"|>,count:3})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector<int64_t> generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "format"); + 
EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"template":"Hello {name}, items: [a, b, c]","count":3})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsContainingSpecialCharacters) { + std::string impl = "import package\nimport package2\n\ndef func(a, b):\n\td={\"python\": \"dict\"}\n\tl = [\"list \\\"with escaped text\\\"\", 123, []]\n\treturn f\"formatted {a} and {b}\""; + std::string input = R"(<|tool_call>call:execute{code:<|"|>)" + impl + R"(<|"|>})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector<int64_t> generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "execute"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"code":"import package\nimport package2\n\ndef func(a, b):\n\td={\"python\": \"dict\"}\n\tl = [\"list \\\"with escaped text\\\"\", 123, []]\n\treturn f\"formatted {a} and {b}\""})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsContainingEscapedQuotes) { + std::string input = R"x(<|tool_call>call:execute{code:<|"|>print(\"hello world\")<|"|>,verbose:true})x"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector<int64_t> generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "execute"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"x({"code":"print(\"hello world\")","verbose":true})x"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsContainingApostrophes) { + std::string input = 
R"(<|tool_call>call:log{message:<|"|>it's a test, isn't it?<|"|>,level:<|"|>warn<|"|>})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector<int64_t> generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "log"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"message":"it's a test, isn't it?","level":"warn"})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsContainingBackslashes) { + std::string input = R"(<|tool_call>call:read_file{path:<|"|>C:\Users\test\file.txt<|"|>,encoding:<|"|>utf-8<|"|>})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector<int64_t> generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "read_file"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"path":"C:\\Users\\test\\file.txt","encoding":"utf-8"})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsArrayWithStringsContainingQuotes) { + std::string input = R"(<|tool_call>call:save{lines:[<|"|>it's the wonderful day<|"|>,<|"|>He said: "My name's John"<|"|>,<|"|>That's Johns' car.<|"|>]})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector<int64_t> generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + 
EXPECT_EQ(parsedOutput.toolCalls[0].name, "save"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"lines":["it's the wonderful day","He said: \"My name's John\"","That's Johns' car."]})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsObjectWithStringsContainingQuotes) { + std::string input = R"(<|tool_call>call:save{obj:{name:<|"|>it's the wonderful day<|"|>,greeting:<|"|>Hello, my name's Jan<|"|>,note:<|"|>That's Johns' car.<|"|>}})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector<int64_t> generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "save"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"obj":{"name":"it's the wonderful day","greeting":"Hello, my name's Jan","note":"That's Johns' car."}})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsContainingNestedJSON) { + std::string input = R"(<|tool_call>call:send{payload:<|"|>{'key': 'value', 'count': 42}<|"|>,endpoint:<|"|>api<|"|>})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector<int64_t> generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "send"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"payload":"{'key': 'value', 'count': 42}","endpoint":"api"})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithEmptyStringArgument) { + std::string input = R"(<|tool_call>call:create{name:<|"|><|"|>,value:0})"; + auto generatedTensor = 
gemma4Tokenizer->encode(input).input_ids; + std::vector<int64_t> generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "create"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"name":"","value":0})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithUnicodeCharactersInArguments) { + std::string input = R"(<|tool_call>call:translate{text:<|"|>zażółć gęślą jaźń<|"|>,lang:<|"|>pl<|"|>})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector<int64_t> generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "translate"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"text":"zażółć gęślą jaźń","lang":"pl"})"); +} diff --git a/windows_prepare_llm_models.bat b/windows_prepare_llm_models.bat index b0c7e5db26..1a96574318 100644 --- a/windows_prepare_llm_models.bat +++ b/windows_prepare_llm_models.bat @@ -44,6 +44,7 @@ set "PHI4_MODEL=microsoft/Phi-4-mini-instruct" set "MISTRAL_MODEL=mistralai/Mistral-7B-Instruct-v0.3" set "GPTOSS_MODEL=openai/gpt-oss-20b" set "DEVSTRAL_MODEL=unsloth/Devstral-Small-2507" +set "GEMMA4_MODEL=google/gemma-4-26B-A4B-it" echo Downloading LLM testing models to directory %~1 set "PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly" @@ -83,6 +84,7 @@ call :download_tokenizer "%PHI4_MODEL%" "%~1\%PHI4_MODEL%" call :download_tokenizer "%MISTRAL_MODEL%" "%~1\%MISTRAL_MODEL%" call :download_tokenizer "%GPTOSS_MODEL%" "%~1\%GPTOSS_MODEL%" call 
:download_tokenizer "%DEVSTRAL_MODEL%" "%~1\%DEVSTRAL_MODEL%" +call :download_tokenizer "%GEMMA4_MODEL%" "%~1\%GEMMA4_MODEL%" exit /b 0