| title | Models |
|---|---|
| description | Model loading, supported architectures, tokenisation, and chat templates. |
go-mlx loads transformer models from either HuggingFace safetensors shards or GGUF checkpoints. For safetensors directories, architecture is detected from the model_type field in config.json; for direct .gguf loads via inference.LoadModel, it is detected from checkpoint metadata.
import (
"log"
"dappco.re/go/inference"
_ "dappco.re/go/mlx"
)
m, err := inference.LoadModel("/path/to/model/")
if err != nil {
log.Fatal(err)
}
defer m.Close()

The model path may be either a model directory or an explicit .gguf file path.
When loading a directory, it must contain:

- config.json -- model configuration (architecture, dimensions, quantisation)
- tokenizer.json -- HuggingFace BPE tokeniser
- Weights in exactly one of these formats:
  - One or more *.safetensors files (multi-shard supported)
  - Exactly one *.gguf file
m, err := inference.LoadModel("/path/to/model/",
inference.WithContextLen(262144), // larger Qwen-class context; default is 131072
inference.WithParallelSlots(1), // default: one foreground native request
inference.WithAdapterPath("/path/to/lora/"), // load LoRA adapter at init
)

| Option | Effect |
|---|---|
| WithContextLen(n) | Replaces unbounded KV caches with RotatingKVCache(n); Metal defaults to 131072 |
| WithParallelSlots(n) | Caps concurrent native inference calls per loaded model; Metal defaults to 1 |
| WithAdapterPath(dir) | Loads a trained LoRA adapter from the given directory |
| WithGPULayers(n) | Ignored with a warning -- Metal always uses full GPU offload |
The direct mlx.LoadModel API also enables exact token-prefix prompt caching by
default. Use WarmPromptCache(prefix) for stable system prompts, tool schemas,
repository summaries, and AGENTS/policy prefixes before appending volatile user
turns.
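A minimal sketch of warming the cache with a stable prefix is shown below. It assumes mlx is imported as a named package (not blank), that mlx.LoadModel mirrors the inference.LoadModel signature, and that WarmPromptCache takes the prefix string and returns an error; check the package for the exact forms.

```go
// Sketch only: warm the exact-prefix prompt cache with a stable system prompt
// before appending volatile user turns. Signatures are assumed, not confirmed.
m, err := mlx.LoadModel("/path/to/model/")
if err != nil {
	log.Fatal(err)
}
defer m.Close()

stablePrefix := "You are a concise coding assistant.\n" // system prompt, tool schema, etc.
if err := m.WarmPromptCache(stablePrefix); err != nil {
	log.Fatal(err)
}
// Later prompts that start with stablePrefix reuse the cached prefix tokens.
```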
Config values: gemma3, gemma3_text, gemma2
Decoder structure per layer (pre-norm with four norm points):
input -> InputNorm -> Attention -> PostAttnNorm -> residual add
-> PreFFNorm -> MLP -> PostFFNorm -> residual add
Key features:
- Q/K RMS normalisation -- separate QNorm and KNorm modules per attention layer
- Alternating attention -- sliding window and global attention alternate based on sliding_window_pattern (default 6). Sliding layers use theta 10000; global layers use theta 1000000.
- Grouped-query attention (GQA) -- K/V heads repeated via RepeatKV when num_kv_heads < num_attention_heads
- Gemma-style normalisation -- weights are (1 + weight) scaled, precomputed at load time to avoid repeated addition during inference (see the sketch after this list)
- Embedding scale -- hidden states are multiplied by sqrt(hidden_size) after embedding lookup
- MLP -- GELU-based gate with tanh approximation, compiled via CompileShapeless as a singleton
- Output head -- typically tied to embed_tokens; uses a separate lm_head.weight if present in the safetensors
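To make the Gemma-style normalisation point concrete, here is an illustrative fold of the +1 into the stored norm weight at load time. It uses plain float32 slices rather than the library's array type, so it is a sketch of the idea, not the loader's code.

```go
// Illustrative only: Gemma RMSNorm scales by (1 + weight), so the loader can
// fold the +1 into the stored weight once instead of adding it every forward pass.
func foldGemmaNormWeight(w []float32) []float32 {
	out := make([]float32, len(w))
	for i, v := range w {
		out[i] = 1 + v // precomputed (1 + weight)
	}
	return out
}
```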
Config values: gemma4, gemma4_text
Gemma 4 uses a dedicated loader (LoadGemma4) with several architecture-specific behaviours:
- Mixed attention head sizes -- sliding layers use head_dim, full-attention layers use global_head_dim
- Per-layer RoPE -- sliding attention defaults to theta 10000 and full attention to theta 1000000 with partial rotary
- Shared KV cache -- the tail of the network can reuse KV state from earlier same-type layers to reduce memory use
- K-equals-V support -- full-attention layers can reuse the K projection for V
- Value normalisation -- values pass through RMSNormNoScale before caching
- MoE routing -- router projections stay quantised at 8-bit and sparse experts dispatch through gather_mm / gather_qmm
- Weight sanitisation -- multimodal tower weights are stripped and experts.gate_up_proj tensors are split into separate gate/up weights (see the sketch below)
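As referenced in the weight-sanitisation bullet, the following is a hedged sketch of splitting a fused gate/up tensor. The real sanitiser operates on the library's array type, and the assumption that gate rows precede up rows is illustrative only.

```go
// Illustrative only: split a fused experts.gate_up_proj matrix into gate and up
// halves. Assumes gate rows are stored first; the split axis in the real loader
// may differ.
func splitGateUp(fused [][]float32) (gate, up [][]float32) {
	half := len(fused) / 2
	return fused[:half], fused[half:]
}
```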
Gemma 4 chat formatting follows the same turn template as Gemma 3.
Config values: qwen3, qwen2, llama
These three architectures share one loader (LoadQwen3) and one decoder implementation. Decoder structure per layer (standard pre-norm):
input -> InputNorm -> Attention -> residual add
-> PostAttnNorm -> MLP -> residual add
| Feature | Qwen 3 | Qwen 2 | Llama 3 |
|---|---|---|---|
| Q/K norm | Yes | No | No |
| MLP type | SwiGLU | SwiGLU | SwiGLU |
| Output head | Separate lm_head | Separate lm_head | Separate lm_head |
MLP: SwiGLU gate -- down(silu(gate(x)) * up(x)).
Qwen 2 vs Qwen 3 detection: if model_type is absent, the presence of model.layers.0.self_attn.q_norm.weight in the weights distinguishes Qwen 3 (present) from Qwen 2 (absent).
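A sketch of that fallback check, assuming the loader has built a name-to-tensor map from the safetensors shards (the map type here is hypothetical):

```go
// Sketch: distinguish Qwen 3 from Qwen 2 when model_type is missing by probing
// for the layer-0 Q-norm weight. weights is a hypothetical name -> tensor map.
func detectQwenVariant(weights map[string]any) string {
	if _, ok := weights["model.layers.0.self_attn.q_norm.weight"]; ok {
		return "qwen3" // Q/K norm present
	}
	return "qwen2"
}
```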
The loader performs these steps:
- Reads config.json for model configuration
- Loads tokenizer.json for the tokeniser
- Loads weights from either all *.safetensors shards or a single .gguf file
- Resolves weights by name, with automatic language_model. prefix fallback
- Constructs Linear layers as quantised or dense based on presence of scales tensors
- Calls Materialize() on all weight arrays to commit them to GPU memory
Quantisation is transparent. Quantised models store packed weights alongside scales and biases tensors. The Linear.Forward() method dispatches to QuantizedMatmul (MLX grouped quantisation kernel) when these tensors are present. Quantisation parameters (bits, group_size) are read from top-level config.json.
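A sketch of reading those parameters, assuming the common MLX checkpoint layout where bits and group_size sit under a top-level quantization object in config.json (the key name and struct layout are assumptions):

```go
import (
	"encoding/json"
	"os"
	"path/filepath"
)

// quantConfig mirrors the assumed top-level "quantization" object in config.json.
type quantConfig struct {
	Bits      int `json:"bits"`
	GroupSize int `json:"group_size"`
}

type modelConfig struct {
	ModelType    string       `json:"model_type"`
	Quantization *quantConfig `json:"quantization"`
}

// readQuant returns nil when the checkpoint is dense (no quantization block).
func readQuant(dir string) (*quantConfig, error) {
	raw, err := os.ReadFile(filepath.Join(dir, "config.json"))
	if err != nil {
		return nil, err
	}
	var cfg modelConfig
	if err := json.Unmarshal(raw, &cfg); err != nil {
		return nil, err
	}
	return cfg.Quantization, nil
}
```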
If head_dim is absent from config.json (common in some Gemma 3 variants), it is computed from the Q projection weight shape: q_proj.weight[0] / num_attention_heads.
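For example, with the Q projection shape given as [out_features, in_features] (hypothetical helper):

```go
// Sketch: derive head_dim when config.json omits it, using the Q projection's
// output dimension divided by the attention head count.
func resolveHeadDim(cfgHeadDim, numHeads, qProjOutFeatures int) int {
	if cfgHeadDim > 0 {
		return cfgHeadDim
	}
	return qProjOutFeatures / numHeads // q_proj.weight shape[0] / num_attention_heads
}
```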
Tokenizer reads a tokenizer.json file and supports two BPE variants, auto-detected at load time.
SentencePiece-style BPE:

- Prefixes each segment with ▁ (Unicode U+2581, the SentencePiece space marker)
- Splits into characters
- Applies BPE merges via a rank-sorted lookup table
- Looks up merged symbols in the vocabulary
Detection: absence of Ġthe (the GPT-2 space+the encoding) in the vocabulary. The check uses Ġthe rather than bare Ġ because large SentencePiece vocabularies may contain the character incidentally.
Byte-level (GPT-2-style) BPE:

- Maps all 256 bytes to printable Unicode characters (GPT-2 convention)
- Printable ASCII (33-126) and Latin-1 Supplement (161-172, 174-255) map to themselves
- Control characters, space, DEL, and gaps map to U+0100 onwards
- Applies BPE merges in this Unicode representation
- Decodes back to raw bytes via the inverse map
Detection: presence of Ġthe in the vocabulary.
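A sketch of the byte-to-unicode table described above, built from exactly the ranges listed:

```go
// Sketch: GPT-2 byte-to-unicode table. Bytes in the printable ranges map to
// themselves; every other byte is assigned a code point from U+0100 upwards.
func bytesToUnicode() map[byte]rune {
	printable := func(b int) bool {
		return (b >= 33 && b <= 126) || (b >= 161 && b <= 172) || (b >= 174 && b <= 255)
	}
	m := make(map[byte]rune, 256)
	next := 0
	for b := 0; b < 256; b++ {
		if printable(b) {
			m[byte(b)] = rune(b)
		} else {
			m[byte(b)] = rune(0x100 + next) // control chars, space, DEL, gaps
			next++
		}
	}
	return m
}
```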
Standard greedy algorithm:
- Build merge rank table from the merges field (O(1) lookup by "a b" key)
- Scan all adjacent pairs; find the pair with the lowest rank
- Merge that pair into a single symbol
- Repeat until no more merges apply
Merges are parsed from both ["a b", ...] and [["a","b"], ...] JSON formats.
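A minimal sketch of the greedy loop, with ranks keyed by the "a b" form described above:

```go
// Sketch: repeatedly merge the adjacent pair with the lowest merge rank until
// no adjacent pair appears in the rank table.
func applyBPE(symbols []string, ranks map[string]int) []string {
	for {
		best, bestRank := -1, 0
		for i := 0; i+1 < len(symbols); i++ {
			r, ok := ranks[symbols[i]+" "+symbols[i+1]]
			if ok && (best < 0 || r < bestRank) {
				best, bestRank = i, r
			}
		}
		if best < 0 {
			return symbols // no merge applies
		}
		symbols[best] += symbols[best+1]
		symbols = append(symbols[:best+1], symbols[best+2:]...)
	}
}
```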
Special tokens are matched before BPE encoding. Each architecture uses different stop tokens:
| Family | BOS | EOS / Stop |
|---|---|---|
| Gemma 3 / 4 | `<bos>` | `<end_of_turn>` |
| Qwen 2/3 | `<\|im_start\|>` | `<\|im_end\|>` |
| Llama 3 | `<\|begin_of_text\|>` | `<\|eot_id\|>` |
Generate returns iter.Seq[Token] (Go 1.23+ range-over-func):
ctx := context.Background()
for tok := range m.Generate(ctx, "What is 2+2?", inference.WithMaxTokens(64)) {
fmt.Print(tok.Text)
}
if err := m.Err(); err != nil {
log.Fatal(err)
}

Chat applies the model's native template before generating:
for tok := range m.Chat(ctx, []inference.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Translate 'hello' to French."},
}, inference.WithMaxTokens(64)) {
fmt.Print(tok.Text)
}

Chat templates by architecture:
| Architecture | Format |
|---|---|
| Gemma 3 / 4 | `<start_of_turn>role\ncontent<end_of_turn>\n` |
| Qwen 2/3 | `<\|im_start\|>role\ncontent<\|im_end\|>\n` |
| Llama 3 | `<\|start_header_id\|>role<\|end_header_id\|>\n\ncontent<\|eot_id\|>` |
inference.WithMaxTokens(128) // maximum tokens to generate
inference.WithTemperature(0.7) // sampling temperature (0 = greedy)
inference.WithTopK(40) // top-K sampling
inference.WithTopP(0.9) // nucleus sampling
inference.WithRepeatPenalty(1.1) // repetition penalty
inference.WithStopTokens(1, 2)     // additional stop token IDs

The direct root API adds mlx.WithMinP(0.05) for minimum-probability sampling.
When combined, sampling options are applied in this order: temperature, then top-p, then top-k, then min-p.
Pass a cancellable context to stop generation early:
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
for tok := range m.Generate(ctx, prompt, inference.WithMaxTokens(1000)) {
fmt.Print(tok.Text)
}
// m.Err() returns context.Canceled or context.DeadlineExceeded

Single forward pass per batch -- no decode loop. Prompts are right-padded to the longest length:
results, err := m.Classify(ctx, []string{
"Bonjour, comment allez-vous?",
"The quarterly report shows growth.",
}, inference.WithTemperature(0))
for i, r := range results {
fmt.Printf("prompt %d -> token %d %q\n", i, r.Token.ID, r.Token.Text)
}

Use inference.WithLogits() to return the full logit vector alongside the sampled token.
Full autoregressive generation for multiple prompts in parallel:
results, err := m.BatchGenerate(ctx, []string{
"The capital of France is",
"Water boils at",
}, inference.WithMaxTokens(32))
for i, r := range results {
for _, tok := range r.Tokens {
fmt.Print(tok.Text)
}
fmt.Println()
}

Scan a directory tree for available models:
models, err := inference.Discover("/path/to/models/")
for _, d := range models {
fmt.Printf("%s (%s, %d-bit, %d files)\n",
d.Path, d.ModelType, d.QuantBits, d.NumFiles)
}

info := m.Info()
fmt.Printf("%s: %d layers, %d vocab, %d hidden, %d-bit quant\n",
info.Architecture, info.NumLayers, info.VocabSize,
info.HiddenSize, info.QuantBits)
// After generation:
met := m.Metrics()
fmt.Printf("prefill: %.0f tok/s, decode: %.1f tok/s, peak GPU: %d MB\n",
met.PrefillTokensPerSec, met.DecodeTokensPerSec,
met.PeakMemoryBytes/1024/1024)

Use the first-party harness when comparing planner changes, quantisation choices, prompt-cache settings, or model builds on a local Apple machine:
go-mlx bench -json -prompt "Summarize local inference in one sentence." /path/to/model

The report includes prefill/decode tok/s, peak and active Metal memory, prompt-cache hit rate, KV restore latency, state-bundle round-trip size/time, probe event counts, probe overhead, and small deterministic quality checks.
From Go:
report, err := mlx.RunFastEvalBench(ctx, model, mlx.DefaultFastEvalConfig())
fmt.Printf("decode %.1f tok/s, cache %.0f%%\n",
report.Generation.DecodeTokensPerSec,
report.PromptCache.HitRate*100)

The Metal adapter implements inference.AttentionInspector, enabling extraction of post-RoPE K vectors from the KV cache:
inspector, ok := m.(inference.AttentionInspector)
snap, err := inspector.InspectAttention(ctx, "What is kindness?")
// snap.Keys[layer][head] -> flat float32 of len seq_len * head_dim
fmt.Printf("layers=%d heads=%d seq=%d dim=%d\n",
snap.NumLayers, snap.NumHeads, snap.SeqLen, snap.HeadDim)

The K tensors are post-RoPE -- rotary position embeddings have already been applied. For GQA models, NumHeads reflects the KV head count, not the query head count.
To add support for a new architecture:

- Create internal/metal/{name}.go with //go:build darwin && arm64
- Implement the InternalModel interface (Forward, ForwardMasked, NewCache, NumLayers, Tokenizer, ModelType, ApplyLoRA)
- Add a case in model.go:loadModel for the new model_type value
- Add a close{Name} helper in close.go for deterministic weight cleanup
- Add format{Name}Chat in generate.go for the chat template
- Add BOS/EOS detection in tokenizer.go:LoadTokenizer
- Write tests: config parsing, missing weights, end-to-end inference