Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/main/java/org/beehive/gpullama3/Options.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ public record Options(Path modelPath, String prompt, String systemPrompt, String

public Options {
require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
require(Float.isNaN(temperature) || 0 <= temperature, "Invalid argument: --temperature must be non-negative");
require(Float.isNaN(topp) || 0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
require(batchPrefillSize >= 1, "Invalid argument: --batch-prefill-size must be >= 1");
require(batchPrefillSize == 1 || withPrefillDecode, "Invalid argument: --batch-prefill-size requires --with-prefill-decode");
// Publish to system properties so TornadoVMMasterPlan and Llama read the right values
Expand Down Expand Up @@ -44,8 +44,8 @@ public static void printUsage(PrintStream out) {
out.println(" --prompt, -p <string> input prompt");
out.println(" --system-prompt, -sp <string> (optional) system prompt (Llama models)");
out.println(" --suffix <string> suffix for fill-in-the-middle request (Codestral)");
out.println(" --temperature, -temp <float> temperature in [0,inf], default 0.1");
out.println(" --top-p <float> p value in top-p (nucleus) sampling in [0,1] default 0.95");
out.println(" --temperature, -temp <float> temperature in [0,inf], default: auto-detected from model family");
out.println(" --top-p <float> p value in top-p (nucleus) sampling in [0,1], default: auto-detected from model family");
out.println(" --seed <long> random seed, default System.nanoTime()");
out.println(" --max-tokens, -n <int> number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS);
out.println(" --stream <boolean> print tokens during generation; may cause encoding artifacts for non ASCII text, default true");
Expand All @@ -59,8 +59,8 @@ public static Options getDefaultOptions() {
String prompt = "Tell me a story with Java"; // Hardcoded for testing
String systemPrompt = null;
String suffix = null;
float temperature = 0.1f;
float topp = 0.95f;
float temperature = Float.NaN; // resolved from model family after loading
float topp = Float.NaN; // resolved from model family after loading
Path modelPath = null;
long seed = System.nanoTime();
int maxTokens = DEFAULT_MAX_TOKENS;
Expand All @@ -76,8 +76,8 @@ public static Options parseOptions(String[] args) {
String prompt = "Tell me a story with Java"; // Hardcoded for testing
String systemPrompt = null;
String suffix = null;
float temperature = 0.1f;
float topp = 0.95f;
float temperature = Float.NaN; // resolved from model family after loading
float topp = Float.NaN; // resolved from model family after loading
Path modelPath = null;
long seed = System.nanoTime();
int maxTokens = DEFAULT_MAX_TOKENS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,13 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
}

/**
 * Builds the sampler for {@code model}, resolving unset sampling parameters first.
 *
 * <p>A {@code NaN} temperature or top-p in {@code options} means the value was not
 * configured explicitly; in that case the default advertised by the model's chat
 * format is used instead, so {@code selectSampler} never receives {@code NaN}.
 *
 * @param model   the loaded model, providing the vocabulary size and chat format
 * @param options the options holding temperature, top-p and seed
 * @return the sampler chosen by {@code selectSampler} for the resolved parameters
 */
static Sampler createSampler(Model model, Options options) {
    float temperature = options.temperature();
    if (Float.isNaN(temperature)) {
        // The chat format reports its default as a double; the sampler works with floats.
        temperature = (float) model.chatFormat().defaultTemperature();
    }

    float topp = options.topp();
    if (Float.isNaN(topp)) {
        topp = (float) model.chatFormat().defaultTopP();
    }

    return selectSampler(model.configuration().vocabularySize(), temperature, topp, options.seed());
}

/**
Expand Down
151 changes: 151 additions & 0 deletions src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import org.beehive.gpullama3.tokenizer.Phi3Tokenizer;
import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;

public interface ChatFormat {
Expand Down Expand Up @@ -36,6 +38,155 @@ default ChatTokens chatTokens() {

Set<Integer> getStopTokens();

/**
* Returns {@code true} when this chat format supports tool calling.
* Formats that implement tool-calling methods must override this to return {@code true}.
* Callers should check this before passing tool specifications to avoid hitting the
* default {@link UnsupportedOperationException} deep inside a format method.
*/
default boolean supportsToolCalling() {
    // Opt-in capability: the base contract reports no tool-calling support.
    final boolean toolCallingAvailable = false;
    return toolCallingAvailable;
}

/**
* Returns plain text to append to the system message content when tools are available.
* Used by formats that inject tool definitions into the <em>system</em> message.
*
* <p>Formats that inject tools into the <em>user</em> message instead should override
* {@link #injectsToolsInUserMessage()}, {@link #toolSystemMessagePrefix()}, and
* {@link #toolFirstUserMessagePrefix(String)} rather than this method.
*
* @param toolsJson JSON array of tool definitions
*/
default String toolSystemPromptSuffix(String toolsJson) {
    // Name the concrete format class in the message so the unsupported format is identifiable.
    final String formatName = getClass().getSimpleName();
    throw new UnsupportedOperationException("Tool calling not supported for: " + formatName);
}

/**
* Returns {@code true} when this format injects tool definitions into the
* <em>first user message</em> instead of the system message.
*
* <p>When this returns {@code true}, callers should:
* <ol>
* <li>Prepend {@link #toolSystemMessagePrefix()} to the system message content.</li>
* <li>Prepend {@link #toolFirstUserMessagePrefix(String)} to the first user message.</li>
* </ol>
* When {@code false} (default), callers should append {@link #toolSystemPromptSuffix} to
* the system message as before.
*/
default boolean injectsToolsInUserMessage() {
    // Base behaviour: tool definitions are not placed in the first user message.
    final boolean viaUserMessage = false;
    return viaUserMessage;
}

/**
* Returns text to <em>prepend</em> to the system message content when tools are active
* and {@link #injectsToolsInUserMessage()} is {@code true}.
* Default: empty string (no prefix).
*/
default String toolSystemMessagePrefix() {
    // No prefix unless a format overrides this.
    final String noPrefix = "";
    return noPrefix;
}

/**
* Returns the preamble to <em>prepend</em> to the first user message when
* {@link #injectsToolsInUserMessage()} is {@code true}.
* The preamble should include the tool definitions and usage instructions.
*
* @param toolsJson JSON array of tool definitions
*/
default String toolFirstUserMessagePrefix(String toolsJson) {
    // The base implementation ignores toolsJson and contributes no preamble.
    final String noPreamble = "";
    return noPreamble;
}

/**
* Re-encodes a prior assistant tool-call turn into the conversation token stream.
* Used when replaying multi-turn history that contains a previous tool call.
*
* @param toolCall the tool call to encode (name + raw arguments JSON)
*/
default List<Integer> encodeToolCallAssistantTurn(ToolCallExtract toolCall) {
    // Name the concrete format class in the message so the unsupported format is identifiable.
    final String formatName = getClass().getSimpleName();
    throw new UnsupportedOperationException("Tool calling not supported for: " + formatName);
}

/**
* Re-encodes a prior assistant turn that contained one or more tool calls as a
* <em>single</em> assistant message. Implementations must emit all calls inside one
* header/footer pair so the model does not see spurious assistant turn boundaries.
*
* <p>The default delegates to {@link #encodeToolCallAssistantTurn(ToolCallExtract)}
* for single-element lists and naively concatenates individual encodings for larger
* lists — formats that support batch tool calls should override this method.
*
* @param toolCalls the ordered list of tool calls from a single assistant turn
*/
default List<Integer> encodeToolCallAssistantTurn(List<ToolCallExtract> toolCalls) {
    // Nothing to encode.
    if (toolCalls.isEmpty()) {
        return List.of();
    }

    // A single call is handed straight to the single-call overload.
    if (toolCalls.size() == 1) {
        ToolCallExtract onlyCall = toolCalls.get(0);
        return encodeToolCallAssistantTurn(onlyCall);
    }

    // Several calls: append each call's individual encoding in list order.
    List<Integer> combined = new ArrayList<>();
    toolCalls.forEach(call -> combined.addAll(encodeToolCallAssistantTurn(call)));
    return combined;
}

/**
* Encodes a tool execution result message in the model-native format.
*
* @param toolCallId the ID of the originating tool call (may be ignored by some formats)
* @param toolName the name of the tool that was called
* @param result the result content string
*/
default List<Integer> encodeToolResultTurn(String toolCallId, String toolName, String result) {
    // Name the concrete format class in the message so the unsupported format is identifiable.
    final String formatName = getClass().getSimpleName();
    throw new UnsupportedOperationException("Tool calling not supported for: " + formatName);
}

/**
* Detects and extracts a tool call from fully decoded model response text.
* Returns {@link Optional#empty()} when the response is a plain text answer.
*
* @param responseText the fully decoded response from the model
*/
default Optional<ToolCallExtract> extractToolCall(String responseText) {
    // The base implementation never reports a tool call, whatever the response text is.
    final Optional<ToolCallExtract> noToolCall = Optional.empty();
    return noToolCall;
}

/**
* Extracts ALL tool calls from a response. Models may emit multiple
* {@code <tool_call>} blocks in a single turn (batch tool calls).
* The default delegates to {@link #extractToolCall} for formats that
* do not support batch calls.
*
* @param responseText the fully decoded response from the model
*/
default List<ToolCallExtract> extractAllToolCalls(String responseText) {
    // Reuse the single-call extractor and wrap its result: zero or one element.
    Optional<ToolCallExtract> single = extractToolCall(responseText);
    if (single.isPresent()) {
        return List.of(single.get());
    }
    return List.of();
}

/**
* Returns the recommended default temperature for this chat format.
* Used when the caller has not explicitly configured a temperature.
*/
default double defaultTemperature() {
    // Fallback value for formats that do not override this method.
    final double fallbackTemperature = 0.7;
    return fallbackTemperature;
}

/**
* Returns the recommended default top-p for this chat format.
* Used when the caller has not explicitly configured a top-p value.
*/
default double defaultTopP() {
    // Fallback value for formats that do not override this method.
    final double fallbackTopP = 0.9;
    return fallbackTopP;
}

/**
* Stop tokens to use when tool calling is enabled.
* Some models (LLaMA 3.1+) use a different end-of-turn token ({@code <|eom_id|>})
* when emitting a tool call instead of a regular response.
*/
default Set<Integer> getToolAwareStopTokens() {
    // Unless overridden, tool-enabled generation stops on the regular stop tokens.
    final Set<Integer> regularStopTokens = getStopTokens();
    return regularStopTokens;
}

record ChatTokens(
        String tStartHeader,
        String tEndHeader,
        String tEndOfTurn,
        String tEndOfText,
        String tEndOfTextFim) {
    // Plain data carrier: only the generated accessors, equals, hashCode and toString.
}

Expand Down
Loading
Loading