diff --git a/src/main/java/org/beehive/gpullama3/Options.java b/src/main/java/org/beehive/gpullama3/Options.java index 919f9751..274428b7 100644 --- a/src/main/java/org/beehive/gpullama3/Options.java +++ b/src/main/java/org/beehive/gpullama3/Options.java @@ -11,8 +11,8 @@ public record Options(Path modelPath, String prompt, String systemPrompt, String public Options { require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\""); - require(0 <= temperature, "Invalid argument: --temperature must be non-negative"); - require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]"); + require(Float.isNaN(temperature) || 0 <= temperature, "Invalid argument: --temperature must be non-negative"); + require(Float.isNaN(topp) || 0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]"); require(batchPrefillSize >= 1, "Invalid argument: --batch-prefill-size must be >= 1"); require(batchPrefillSize == 1 || withPrefillDecode, "Invalid argument: --batch-prefill-size requires --with-prefill-decode"); // Publish to system properties so TornadoVMMasterPlan and Llama read the right values @@ -44,8 +44,8 @@ public static void printUsage(PrintStream out) { out.println(" --prompt, -p input prompt"); out.println(" --system-prompt, -sp (optional) system prompt (Llama models)"); out.println(" --suffix suffix for fill-in-the-middle request (Codestral)"); - out.println(" --temperature, -temp temperature in [0,inf], default 0.1"); - out.println(" --top-p p value in top-p (nucleus) sampling in [0,1] default 0.95"); + out.println(" --temperature, -temp temperature in [0,inf], default: auto-detected from model family"); + out.println(" --top-p p value in top-p (nucleus) sampling in [0,1], default: auto-detected from model family"); out.println(" --seed random seed, default System.nanoTime()"); out.println(" --max-tokens, -n number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS); out.println(" --stream print tokens during generation; may cause encoding artifacts for non ASCII text, default true"); @@ -59,8 +59,8 @@ public static Options getDefaultOptions() { String prompt = "Tell me a story with Java"; // Hardcoded for testing String systemPrompt = null; String suffix = null; - float temperature = 0.1f; - float topp = 0.95f; + float temperature = Float.NaN; // resolved from model family after loading + float topp = Float.NaN; // resolved from model family after loading Path modelPath = null; long seed = System.nanoTime(); int maxTokens = DEFAULT_MAX_TOKENS; @@ -76,8 +76,8 @@ public static Options parseOptions(String[] args) { String prompt = "Tell me a story with Java"; // Hardcoded for testing String systemPrompt = null; String suffix = null; - float temperature = 0.1f; - float topp = 0.95f; + float temperature = Float.NaN; // resolved from model family after loading + float topp = Float.NaN; // resolved from model family after loading Path modelPath = null; long seed = System.nanoTime(); int maxTokens = DEFAULT_MAX_TOKENS; diff --git a/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java b/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java index 496d0761..bbb7c5ad 100644 --- a/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java +++ b/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java @@ -122,7 +122,13 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp, } static Sampler createSampler(Model model, Options options) { - return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed()); + float temperature = Float.isNaN(options.temperature()) + ? (float) model.chatFormat().defaultTemperature() + : options.temperature(); + float topp = Float.isNaN(options.topp()) + ? (float) model.chatFormat().defaultTopP() + : options.topp(); + return selectSampler(model.configuration().vocabularySize(), temperature, topp, options.seed()); } /** diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index 827ad625..9f7121d4 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -7,7 +7,9 @@ import org.beehive.gpullama3.tokenizer.Phi3Tokenizer; import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer; +import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.Set; public interface ChatFormat { @@ -36,6 +38,155 @@ default ChatTokens chatTokens() { Set getStopTokens(); + /** + * Returns {@code true} when this chat format supports tool calling. + * Formats that implement tool-calling methods must override this to return {@code true}. + * Callers should check this before passing tool specifications to avoid hitting the + * default {@link UnsupportedOperationException} deep inside a format method. + */ + default boolean supportsToolCalling() { + return false; + } + + /** + * Returns plain text to append to the system message content when tools are available. + * Used by formats that inject tool definitions into the system message. + * + *

Formats that inject tools into the user message instead should override + * {@link #injectsToolsInUserMessage()}, {@link #toolSystemMessagePrefix()}, and + * {@link #toolFirstUserMessagePrefix(String)} rather than this method. + * + * @param toolsJson JSON array of tool definitions + */ + default String toolSystemPromptSuffix(String toolsJson) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Returns {@code true} when this format injects tool definitions into the + * first user message instead of the system message. + * + *

When this returns {@code true}, callers should: + *

    + *
  1. Prepend {@link #toolSystemMessagePrefix()} to the system message content.
  2. + *
  3. Prepend {@link #toolFirstUserMessagePrefix(String)} to the first user message.
  4. + *
+ * When {@code false} (default), callers should append {@link #toolSystemPromptSuffix} to + * the system message as before. + */ + default boolean injectsToolsInUserMessage() { + return false; + } + + /** + * Returns text to prepend to the system message content when tools are active + * and {@link #injectsToolsInUserMessage()} is {@code true}. + * Default: empty string (no prefix). + */ + default String toolSystemMessagePrefix() { + return ""; + } + + /** + * Returns the preamble to prepend to the first user message when + * {@link #injectsToolsInUserMessage()} is {@code true}. + * The preamble should include the tool definitions and usage instructions. + * + * @param toolsJson JSON array of tool definitions + */ + default String toolFirstUserMessagePrefix(String toolsJson) { + return ""; + } + + /** + * Re-encodes a prior assistant tool-call turn into the conversation token stream. + * Used when replaying multi-turn history that contains a previous tool call. + * + * @param toolCall the tool call to encode (name + raw arguments JSON) + */ + default List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Re-encodes a prior assistant turn that contained one or more tool calls as a + * single assistant message. Implementations must emit all calls inside one + * header/footer pair so the model does not see spurious assistant turn boundaries. + * + *

The default delegates to {@link #encodeToolCallAssistantTurn(ToolCallExtract)} + * for single-element lists and naively concatenates individual encodings for larger + * lists — formats that support batch tool calls should override this method. + * + * @param toolCalls the ordered list of tool calls from a single assistant turn + */ + default List encodeToolCallAssistantTurn(List toolCalls) { + if (toolCalls.isEmpty()) return List.of(); + if (toolCalls.size() == 1) return encodeToolCallAssistantTurn(toolCalls.get(0)); + List tokens = new ArrayList<>(); + for (ToolCallExtract tc : toolCalls) { + tokens.addAll(encodeToolCallAssistantTurn(tc)); + } + return tokens; + } + + /** + * Encodes a tool execution result message in the model-native format. + * + * @param toolCallId the ID of the originating tool call (may be ignored by some formats) + * @param toolName the name of the tool that was called + * @param result the result content string + */ + default List encodeToolResultTurn(String toolCallId, String toolName, String result) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Detects and extracts a tool call from fully decoded model response text. + * Returns {@link Optional#empty()} when the response is a plain text answer. + * + * @param responseText the fully decoded response from the model + */ + default Optional extractToolCall(String responseText) { + return Optional.empty(); + } + + /** + * Extracts ALL tool calls from a response. Models may emit multiple + * {@code } blocks in a single turn (batch tool calls). + * The default delegates to {@link #extractToolCall} for formats that + * do not support batch calls. + * + * @param responseText the fully decoded response from the model + */ + default List extractAllToolCalls(String responseText) { + return extractToolCall(responseText).map(List::of).orElse(List.of()); + } + + /** + * Returns the recommended default temperature for this chat format. + * Used when the caller has not explicitly configured a temperature. + */ + default double defaultTemperature() { + return 0.7; + } + + /** + * Returns the recommended default top-p for this chat format. + * Used when the caller has not explicitly configured a top-p value. + */ + default double defaultTopP() { + return 0.9; + } + + /** + * Stop tokens to use when tool calling is enabled. + * Some models (LLaMA 3.1+) use a different end-of-turn token ({@code <|eom_id|>}) + * when emitting a tool call instead of a regular response. + */ + default Set getToolAwareStopTokens() { + return getStopTokens(); + } + record ChatTokens(String tStartHeader, String tEndHeader, String tEndOfTurn, String tEndOfText, String tEndOfTextFim) { } diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index c98a72c9..f23e3c26 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -6,6 +6,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; public class LlamaChatFormat implements ChatFormat { @@ -17,6 +18,7 @@ public class LlamaChatFormat implements ChatFormat { protected final int endOfTurn; protected final int endOfText; protected final int endOfMessage; + protected final int pythonTag; protected final Set stopTokens; public LlamaChatFormat(Tokenizer tokenizer) { @@ -28,6 +30,7 @@ public LlamaChatFormat(Tokenizer tokenizer) { this.endOfTurn = specialTokens.get("<|eot_id|>"); this.endOfText = specialTokens.get("<|end_of_text|>"); this.endOfMessage = specialTokens.getOrDefault("<|eom_id|>", -1); // only in 3.1 + this.pythonTag = specialTokens.getOrDefault("<|python_tag|>", -1); // only in 3.1 this.stopTokens = Set.of(endOfText, endOfTurn); } @@ -71,4 +74,155 @@ public List encodeDialogPrompt(boolean appendAssistantTurn, Listfirst user message + * (the GGUF-embedded chat template has {@code tools_in_user_message = true} by default). + * The system message receives only an environment prefix; the tools and usage instructions + * go in the user turn. + */ + @Override + public boolean injectsToolsInUserMessage() { + return true; + } + + /** + * System-message prefix that signals tool availability to Llama 3.2. + * Matches the template's {@code "Environment: ipython\n"} line. + */ + @Override + public String toolSystemMessagePrefix() { + return "Environment: ipython\n\n"; + } + + /** + * Prepends tool definitions and usage instructions to the first user message, + * matching the Llama 3.2 GGUF chat template ({@code tools_in_user_message = true}). + * + *

Format mirrors: + *

+     * Given the following functions, please respond with a JSON for a function call
+     * with its proper arguments that best answers the given prompt.
+     *
+     * Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.
+     * Do not use variables.
+     *
+     * {toolsJson}
+     *
+     * 
+ */ + @Override + public String toolFirstUserMessagePrefix(String toolsJson) { + return "Given the following functions, please respond with a JSON for a function call " + + "with its proper arguments that best answers the given prompt.\n\n" + + "Respond in the format {\"name\": function name, \"parameters\": dictionary of " + + "argument name and its value}. Do not use variables.\n\n" + + toolsJson + "\n\n"; + } + + /** + * Re-encodes a prior assistant tool-call turn for multi-turn history using the + * Llama 3.2 native JSON format: {@code {"name":"…","parameters":{…}}<|eot_id|>}. + */ + @Override + public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); + // Preserve the <|python_tag|> prefix used by LLaMA 3.1/3.2 for tool calls so that + // replayed history looks identical to what the model originally generated. + if (pythonTag != -1) { + tokens.add(pythonTag); + } + String json = "{\"name\": \"" + toolCall.name() + "\", \"parameters\": " + toolCall.argumentsJson() + "}"; + tokens.addAll(tokenizer.encodeAsList(json)); + // LLaMA 3.1 ends tool-call turns with <|eom_id|>; fall back to <|eot_id|> for 3.2. + tokens.add(endOfMessage != -1 ? endOfMessage : endOfTurn); + return tokens; + } + + /** + * Encodes a tool result using the LLaMA "ipython" role. + * Format: {@code <|start_header_id|>ipython<|end_header_id|>\nresult<|eot_id|>} + */ + @Override + public List encodeToolResultTurn(String toolCallId, String toolName, String result) { + List tokens = new ArrayList<>(); + tokens.add(startHeader); + tokens.addAll(tokenizer.encodeAsList("ipython")); + tokens.add(endHeader); + tokens.addAll(tokenizer.encodeAsList("\n")); + tokens.addAll(tokenizer.encodeAsList(result)); + tokens.add(endOfTurn); + return tokens; + } + + /** + * Encodes multiple tool calls as a single assistant turn. + * For a single call, delegates to the existing single-call method (preserving the + * {@code <|python_tag|>} prefix on LLaMA 3.1). + * For multiple calls, LLaMA 3.1 prefixes each with {@code <|python_tag|>}; + * LLaMA 3.2 (no python_tag) uses {@code } blocks. + */ + @Override + public List encodeToolCallAssistantTurn(List toolCalls) { + if (toolCalls.isEmpty()) return List.of(); + if (toolCalls.size() == 1) return encodeToolCallAssistantTurn(toolCalls.get(0)); + List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); + for (ToolCallExtract tc : toolCalls) { + String json = "{\"name\": \"" + tc.name() + "\", \"parameters\": " + tc.argumentsJson() + "}"; + if (pythonTag != -1) { + tokens.add(pythonTag); + tokens.addAll(tokenizer.encodeAsList(json + "\n")); + } else { + tokens.addAll(tokenizer.encodeAsList("\n" + json + "\n\n")); + } + } + tokens.add(endOfMessage != -1 ? endOfMessage : endOfTurn); + return tokens; + } + + /** + * Detects a tool call in the decoded response text. + * Supports LLaMA 3.1 (native {@code <|python_tag|>} + {@code "parameters"} key), + * LLaMA 3.2 ({@code "arguments"} key, tag often absent), and a raw-JSON fallback + * for smaller models. Delegates to {@link ToolCallParserUtils#parseToolCallResponse}. + */ + @Override + public Optional extractToolCall(String responseText) { + return ToolCallParserUtils.parseToolCallResponse(responseText); + } + + @Override + public List extractAllToolCalls(String responseText) { + return ToolCallParserUtils.parseAllToolCalls(responseText); + } + + /** + * Adds {@code <|eom_id|>} to the stop tokens when tools are enabled. + * LLaMA 3.1 ends tool-call turns with {@code <|eom_id|>} instead of {@code <|eot_id|>}. + */ + @Override + public Set getToolAwareStopTokens() { + if (endOfMessage != -1) { + return Set.of(endOfText, endOfTurn, endOfMessage); + } + return stopTokens; + } + } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index b6d2e798..8fc761c1 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -14,7 +14,6 @@ public class Qwen3ChatFormat implements ChatFormat { protected final int endHeader; protected final int endOfTurn; protected final int endOfText; - protected final int endOfMessage; protected final int endOfTextFim; protected final int imStart; // beginOfText protected final int imEnd; // endOfText @@ -28,13 +27,12 @@ public Qwen3ChatFormat(Qwen3Tokenizer tokenizer, ChatTokens chatTokens) { this.tokenizer = tokenizer; this.chatTokens = chatTokens; Map specialTokens = tokenizer.getSpecialTokens(); - this.beginOfText = specialTokens.getOrDefault("", -1); + this.beginOfText = -1; // Qwen3 has no BOS token; getBeginOfText() falls back to startHeader this.startHeader = specialTokens.getOrDefault(chatTokens.tStartHeader(), -1); this.endHeader = specialTokens.getOrDefault(chatTokens.tEndHeader(), -1); this.endOfTurn = specialTokens.getOrDefault(chatTokens.tEndOfTurn(), -1); this.endOfText = specialTokens.getOrDefault(chatTokens.tEndOfText(), -1); this.endOfTextFim = specialTokens.getOrDefault(chatTokens.tEndOfTextFim(), -1); - this.endOfMessage = specialTokens.getOrDefault("", -1); // Use default value if key not found this.imStart = startHeader; this.imEnd = endHeader; @@ -129,4 +127,110 @@ public Set getStopTokens() { return stopTokens; } + + @Override + public double defaultTemperature() { + return 0.8; + } + + @Override + public double defaultTopP() { + return 0.9; + } + + // ── Tool calling ────────────────────────────────────────────────────────── + + @Override + public boolean supportsToolCalling() { + return true; + } + + /** + * Qwen3 tool calling system prompt suffix. + * Appended to the system message; instructs the model to wrap tool calls in + * {@code } XML tags. + */ + @Override + public String toolSystemPromptSuffix(String toolsJson) { + return "\n\n# Tools\n\n" + + "You may call one or more functions to assist with the user query.\n\n" + + "You are provided with function signatures within XML tags:\n" + + "\n" + + toolsJson + + "\n\n\n" + + "For each function call, return a json object with function name and arguments " + + "within XML tags:\n" + + "\n" + + "{\"name\": , \"arguments\": }\n" + + ""; + } + + /** + * Re-encodes a prior assistant tool-call turn for multi-turn history. + * Format: {@code <|im_start|>assistant\n\nJSON\n<|im_end|>} + */ + @Override + public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + List tokens = new ArrayList<>(); + tokens.add(imStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("assistant\n")); + String json = "{\"name\":\"" + toolCall.name() + "\",\"arguments\":" + toolCall.argumentsJson() + "}"; + tokens.addAll(tokenizer.encodeOrdinaryAsList("\n" + json + "\n")); + if (imEnd != -1) { + tokens.add(imEnd); + } + return tokens; + } + + /** + * Encodes multiple tool calls as a single assistant turn: one {@code <|im_start|>assistant} + * header, all {@code } blocks concatenated, then {@code <|im_end|>}. + * For a single call, delegates to the existing single-call method. + */ + @Override + public List encodeToolCallAssistantTurn(List toolCalls) { + if (toolCalls.isEmpty()) return List.of(); + if (toolCalls.size() == 1) return encodeToolCallAssistantTurn(toolCalls.get(0)); + List tokens = new ArrayList<>(); + tokens.add(imStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("assistant\n")); + for (ToolCallExtract tc : toolCalls) { + String json = "{\"name\":\"" + tc.name() + "\",\"arguments\":" + tc.argumentsJson() + "}"; + tokens.addAll(tokenizer.encodeOrdinaryAsList("\n" + json + "\n")); + } + if (imEnd != -1) { + tokens.add(imEnd); + } + return tokens; + } + + /** + * Encodes a tool result using the Qwen3 "tool" role. + * Format: {@code <|im_start|>tool\nresult<|im_end|>} + */ + @Override + public List encodeToolResultTurn(String toolCallId, String toolName, String result) { + List tokens = new ArrayList<>(); + tokens.add(imStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("tool\n")); + tokens.addAll(tokenizer.encodeOrdinaryAsList(result)); + if (imEnd != -1) { + tokens.add(imEnd); + } + return tokens; + } + + /** + * Detects a tool call enclosed in {@code } tags. + * Delegates to {@link ToolCallParserUtils#parseToolCallResponse}. + */ + @Override + public Optional extractToolCall(String responseText) { + return ToolCallParserUtils.parseToolCallResponse(responseText); + } + + @Override + public List extractAllToolCalls(String responseText) { + return ToolCallParserUtils.parseAllToolCalls(responseText); + } } diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java new file mode 100644 index 00000000..b5f82c51 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java @@ -0,0 +1,20 @@ +package org.beehive.gpullama3.model.format; + +import java.util.Optional; + +/** + * Represents a single tool call extracted from a model response. + * Contains the raw strings — JSON parsing of arguments is left to the caller. + * + * @param name the tool/function name to invoke + * @param argumentsJson the arguments as a JSON object string, e.g. {"location":"Boston"} + * @param id optional tool call ID parsed from the model response; callers that + * generate IDs themselves (e.g. Ollama-style "call_XXXXXXXX") may pass + * {@link Optional#empty()} and let the consumer generate one + */ +public record ToolCallExtract(String name, String argumentsJson, Optional id) { + + public ToolCallExtract(String name, String argumentsJson) { + this(name, argumentsJson, Optional.empty()); + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java new file mode 100644 index 00000000..a0a856a0 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java @@ -0,0 +1,195 @@ +package org.beehive.gpullama3.model.format; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +/** + * Pure-string tool-call extraction for Llama and Qwen3 response formats. + * + * All methods are stateless and do not require any model or tokenizer instance, + * making them directly unit-testable. + */ +public final class ToolCallParserUtils { + + private ToolCallParserUtils() {} + + /** + * Extracts a single tool call from a model response text. + * + * Recognised formats (in priority order): + * 1. {@code <|python_tag|>{…}} — LLaMA 3.1 native + * 2. {@code } — LLaMA 3.2 and Qwen3 (closed or unclosed) + * 3. Raw JSON object optionally inside markdown code fences — fallback + * + * Both {@code "parameters"} and {@code "arguments"} are tried as the argument key, + * covering LLaMA 3.1/3.2 and Qwen3 variants transparently. + */ + public static Optional parseToolCallResponse(String responseText) { + // 1. Native LLaMA 3.1 format: <|python_tag|>{...} + int idx = responseText.indexOf("<|python_tag|>"); + if (idx != -1) { + String json = responseText.substring(idx + "<|python_tag|>".length()).strip(); + return parseToolCallJson(json); + } + + // 2. LLaMA 3.2 format: ... + int tcStart = responseText.indexOf(""); + int tcEnd = responseText.lastIndexOf(""); + if (tcStart != -1 && tcEnd != -1 && tcEnd > tcStart) { + String json = responseText.substring(tcStart + "".length(), tcEnd).strip(); + return parseToolCallJson(json); + } + // 2b. Unclosed — model stopped (eot_id / eom_id) before writing the closing tag + if (tcStart != -1 && tcEnd == -1) { + String json = responseText.substring(tcStart + "".length()).strip(); + return parseToolCallJson(json); + } + + // 3. Fallback: raw JSON, possibly inside markdown code fences + String stripped = stripMarkdownFences(responseText.strip()); + if (stripped.startsWith("{")) { + return parseToolCallJson(stripped); + } + + return Optional.empty(); + } + + /** + * Parses a tool call JSON object extracted from a {@code } block or raw JSON. + * Accepts {@code {"name":…,"parameters":{…}}}, {@code {"function":…,"parameters":{…}}}, + * and {@code {"name":…,"arguments":{…}}} — covering both LLaMA and Qwen3 variants. + */ + private static Optional parseToolCallJson(String json) { + String name = extractStringValue(json, "name"); + if (name == null) { + name = extractStringValue(json, "function"); + } + if (name == null) return Optional.empty(); + + String argsJson = extractNestedObject(json, "parameters"); + if (argsJson == null) argsJson = extractNestedObject(json, "arguments"); + if (argsJson == null) argsJson = "{}"; + + return Optional.of(new ToolCallExtract(name, argsJson)); + } + + // Batch extraction + + /** + * Extracts ALL tool calls from a response that may contain multiple + * {@code } blocks (Llama 3.2 and Qwen3 batch calls). + * + * Falls back to the raw-JSON single-call path if no tags are found. + * Returns an empty list when the response contains no tool calls. + */ + public static List parseAllToolCalls(String responseText) { + List calls = new ArrayList<>(); + + // <|python_tag|> (Llama 3.1) — single call by definition + int pythonIdx = responseText.indexOf("<|python_tag|>"); + if (pythonIdx != -1) { + parseToolCallJson(responseText.substring(pythonIdx + "<|python_tag|>".length()).strip()) + .ifPresent(calls::add); + return calls; + } + + // Scan for all blocks + int searchFrom = 0; + while (true) { + int start = responseText.indexOf("", searchFrom); + if (start == -1) break; + int end = responseText.indexOf("", start); + String json; + if (end != -1) { + json = responseText.substring(start + "".length(), end).strip(); + searchFrom = end + "".length(); + } else { + // Unclosed tag — model stopped before writing the closing tag + json = responseText.substring(start + "".length()).strip(); + searchFrom = responseText.length(); + } + parseToolCallJson(json).ifPresent(calls::add); + if (end == -1) break; + } + + // Raw JSON fallback (no tags at all) + if (calls.isEmpty()) { + String stripped = stripMarkdownFences(responseText.strip()); + if (stripped.startsWith("{")) { + parseToolCallJson(stripped).ifPresent(calls::add); + } + } + + return calls; + } + + // Shared helpers + + /** Strips surrounding markdown code fences (```…```) if present. */ + public static String stripMarkdownFences(String text) { + if (!text.startsWith("```")) return text; + int firstNewline = text.indexOf('\n'); + if (firstNewline == -1) return text; + String body = text.substring(firstNewline + 1); + if (body.endsWith("```")) body = body.substring(0, body.length() - 3).stripTrailing(); + return body.strip(); + } + + /** + * Extracts the string value for {@code "key": ""} from a JSON object. + * Tolerates whitespace around {@code :} and correctly skips escaped quotes ({@code \"}) + * inside the value, so multi-line code strings with embedded {@code "} are returned intact. + */ + private static String extractStringValue(String json, String key) { + String marker = "\"" + key + "\""; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int colonIdx = json.indexOf(':', markerIdx + marker.length()); + if (colonIdx == -1) return null; + int quoteStart = json.indexOf('"', colonIdx + 1); + if (quoteStart == -1) return null; + // Scan for the closing quote, honouring backslash escapes + int i = quoteStart + 1; + while (i < json.length()) { + char c = json.charAt(i); + if (c == '\\') { + i += 2; // skip escape sequence (e.g. \", \\, \n) + } else if (c == '"') { + break; + } else { + i++; + } + } + if (i >= json.length()) return null; + return json.substring(quoteStart + 1, i); + } + + /** + * Extracts the JSON object value for {@code "key": {…}} using brace-counting. + * Handles nested objects and tolerates whitespace around {@code :}. + * Array brackets {@code […]} are tracked so that {@code {}/{}} characters inside + * array elements do not affect the outer brace depth counter. + */ + private static String extractNestedObject(String json, String key) { + String marker = "\"" + key + "\""; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int colonIdx = json.indexOf(':', markerIdx + marker.length()); + if (colonIdx == -1) return null; + int braceStart = json.indexOf('{', colonIdx + 1); + if (braceStart == -1) return null; + int depth = 0; + int arrayDepth = 0; + for (int i = braceStart; i < json.length(); i++) { + char c = json.charAt(i); + if (c == '[') arrayDepth++; + else if (c == ']') arrayDepth--; + else if (arrayDepth == 0 && c == '{') depth++; + else if (arrayDepth == 0 && c == '}') { + if (--depth == 0) return json.substring(braceStart, i + 1); + } + } + return null; // unbalanced + } +}