
Commit 59c06d0

feat: add SentencePiece unigram encoding for models without merges
SentencePiece unigram models (e.g., Mistral 7B GGUF) provide vocabulary scores but no BPE merge table. Without this change, encoding such models fails silently, producing wrong token IDs and garbage output.

Add SetScores() to BPETokenizer and a greedy longest-match encoder that selects tokens by length first, then by score. When merges are empty but scores are present, encodeSegment automatically uses this path instead of BPE merging.

Also extend the gguf.Metadata interface with GetFloat32Array and extract tokenizer.ggml.scores in ExtractTokenizer, so GGUF-loaded tokenizers automatically use unigram encoding when appropriate.
1 parent a91bf94 commit 59c06d0

4 files changed: 398 additions & 8 deletions
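Before the diffs, a minimal usage sketch of the new path. The tiny vocabulary, scores, and IDs below are illustrative, not from a real model; the calls mirror the ones made in the tests added by this commit:

    vocab := map[string]int{
        "<unk>": 0, "<s>": 1, "</s>": 2,
        "\u2581": 3, "\u2581Hi": 4, "H": 5, "i": 6,
    }
    scores := []float32{-100, -100, -100, -5.0, -1.0, -4.0, -4.0} // negative log probabilities
    special := SpecialTokens{BOS: 1, EOS: 2, PAD: 0, UNK: 0}

    tok := NewBPETokenizer(vocab, nil, special, false) // nil merges: unigram model
    tok.SetSentencePiece(true)                         // ▁-style pre-tokenization
    tok.SetScores(scores)                              // empty merges + scores => greedy unigram path

    ids, _ := tok.Encode("Hi") // "Hi" -> "▁Hi" -> longest match is id 4 -> [4]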

bpe.go

Lines changed: 111 additions & 4 deletions
@@ -3,6 +3,7 @@ package ztoken
 import (
 	"fmt"
 	"strings"
+	"unicode/utf8"
 )
 
 // MergePair represents an adjacent token pair used in BPE merging.
@@ -15,6 +16,8 @@ type MergePair struct {
 
 // BPETokenizer implements the Tokenizer interface using byte-pair encoding.
 // It loads vocabulary and merge rules from HuggingFace tokenizer.json format.
+// When scores are set and merges are empty, it falls back to SentencePiece
+// unigram encoding using greedy longest-match with score-based selection.
 //
 // Stable.
 type BPETokenizer struct {
@@ -38,6 +41,13 @@ type BPETokenizer struct {
 	specialTokens map[string]int
 	// normalizer is an optional text normalization function applied before tokenization.
 	normalizer NormalizerFunc
+	// scores holds SentencePiece unigram scores (negative log probabilities)
+	// indexed by token ID. When scores are set and merges are empty, the
+	// tokenizer uses greedy longest-match encoding instead of BPE merging.
+	scores []float32
+	// maxTokenLen caches the length (in bytes) of the longest token in vocab,
+	// used to bound the search window in sentencePieceEncode.
+	maxTokenLen int
 }
 
 // NewBPETokenizer creates a BPETokenizer from vocabulary, merge rules, and special tokens.
@@ -94,13 +104,21 @@ func (t *BPETokenizer) encodeSegment(text string, addLeadingSpace bool) ([]int,
 	} else {
 		words = strings.Fields(text)
 	}
+	// When merges are empty but scores are available, use SentencePiece
+	// unigram encoding (greedy longest-match) instead of BPE merging.
+	useUnigram := len(t.mergeRanks) == 0 && len(t.scores) > 0
+
 	var ids []int
 	for _, word := range words {
-		wordIDs, err := t.encodeWord(word)
-		if err != nil {
-			return nil, err
+		if useUnigram {
+			ids = append(ids, t.sentencePieceEncode(word)...)
+		} else {
+			wordIDs, err := t.encodeWord(word)
+			if err != nil {
+				return nil, err
+			}
+			ids = append(ids, wordIDs...)
 		}
-		ids = append(ids, wordIDs...)
 	}
 	return ids, nil
 }
@@ -226,6 +244,85 @@ func (t *BPETokenizer) SetSpecialTokenStrings(tokens map[string]int) {
 	t.specialTokens = tokens
 }
 
+// SetScores sets token scores for SentencePiece unigram encoding.
+// When scores are set and merges are empty, the tokenizer uses
+// score-based greedy encoding instead of BPE merge-based encoding.
+// Scores are indexed by token ID (negative log probabilities).
+func (t *BPETokenizer) SetScores(scores []float32) {
+	t.scores = scores
+	// Precompute max token length in bytes for search window bounding.
+	t.maxTokenLen = 0
+	for tok := range t.vocab {
+		if len(tok) > t.maxTokenLen {
+			t.maxTokenLen = len(tok)
+		}
+	}
+}
+
+// sentencePieceEncode tokenizes text using greedy longest-match with
+// score-based selection. For each position, it finds all vocabulary tokens
+// that match the input at that position, selects the longest match (breaking
+// ties by highest score), and advances past the matched token.
+//
+// This is used for SentencePiece unigram models that provide vocabulary
+// scores but no BPE merge table (e.g., Mistral 7B GGUF).
+func (t *BPETokenizer) sentencePieceEncode(text string) []int {
+	if text == "" {
+		return nil
+	}
+	var ids []int
+	pos := 0
+	textBytes := []byte(text)
+	n := len(textBytes)
+
+	for pos < n {
+		bestLen := 0
+		bestID := t.special.UNK
+		bestScore := float32(-1e30)
+
+		// Search for the longest matching token at this position.
+		// Limit the search window to maxTokenLen bytes.
+		maxEnd := pos + t.maxTokenLen
+		if maxEnd > n {
+			maxEnd = n
+		}
+
+		for end := pos + 1; end <= maxEnd; end++ {
+			candidate := string(textBytes[pos:end])
+			if id, ok := t.vocab[candidate]; ok {
+				candidateLen := end - pos
+				// Prefer longer matches. For equal length, prefer higher score.
+				if candidateLen > bestLen || (candidateLen == bestLen && t.tokenScore(id) > bestScore) {
+					bestLen = candidateLen
+					bestID = id
+					bestScore = t.tokenScore(id)
+				}
+			}
+		}
+
+		if bestLen == 0 {
+			// No matching token found; emit UNK and advance past one rune.
+			ids = append(ids, t.special.UNK)
+			// Advance past one UTF-8 character, not just one byte.
+			_, size := decodeRune(textBytes[pos:])
+			pos += size
+		} else {
+			ids = append(ids, bestID)
+			pos += bestLen
+		}
+	}
+	return ids
+}
+
+// tokenScore returns the score for a token ID, or 0 if scores are not set
+// or the ID is out of range.
+func (t *BPETokenizer) tokenScore(id int) float32 {
+	if id >= 0 && id < len(t.scores) {
+		return t.scores[id]
+	}
+	return 0
+}
+
 // sentencePiecePreTokenize implements SentencePiece-style pre-tokenization.
 // Text is split on whitespace boundaries. Words that follow a space get ▁
 // (U+2581) prepended. Newlines are emitted as separate tokens.
@@ -404,5 +501,15 @@ func isPrintableGPT2Byte(b byte) bool {
 	return false
 }
 
+// decodeRune decodes the first UTF-8 rune from b and returns it with its byte length.
+// If b is empty or invalid, it returns utf8.RuneError and 1 to ensure forward progress.
+func decodeRune(b []byte) (rune, int) {
+	r, size := utf8.DecodeRune(b)
+	if size == 0 {
+		return utf8.RuneError, 1
+	}
+	return r, size
+}
+
 // Statically assert BPETokenizer implements Tokenizer.
 var _ Tokenizer = (*BPETokenizer)(nil)
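A note on the matching loop in sentencePieceEncode: candidates are looked up by exact string and end only grows, so any later hit in the window is strictly longer than the previous one, and the scan amounts to "keep the last vocabulary hit". Below is a standalone, runnable sketch of that selection rule; it is a simplification, not the committed code (scores, the UTF-8 rune fallback, and the maxTokenLen cache are omitted, and the vocabulary is a toy):

    package main

    import "fmt"

    // greedyLongestMatch sketches the selection rule used by sentencePieceEncode:
    // scan forward from pos, remember the last (and therefore longest) vocabulary
    // hit, and emit -1 (standing in for UNK) when nothing matches.
    func greedyLongestMatch(text string, vocab map[string]int, maxTokenLen int) []int {
        var ids []int
        b := []byte(text)
        for pos := 0; pos < len(b); {
            bestLen, bestID := 0, -1
            end := pos + maxTokenLen
            if end > len(b) {
                end = len(b)
            }
            for e := pos + 1; e <= end; e++ {
                if id, ok := vocab[string(b[pos:e])]; ok {
                    bestLen, bestID = e-pos, id // later hits are always longer
                }
            }
            if bestLen == 0 {
                ids = append(ids, -1) // UNK; this sketch advances one byte
                pos++
            } else {
                ids = append(ids, bestID)
                pos += bestLen
            }
        }
        return ids
    }

    func main() {
        vocab := map[string]int{"he": 0, "hell": 1, "hello": 2, "o": 3}
        fmt.Println(greedyLongestMatch("helloo", vocab, 5)) // [2 3]
        fmt.Println(greedyLongestMatch("hex", vocab, 5))    // [0 -1] ("he", then no match for "x")
    }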

bpe_test.go

Lines changed: 207 additions & 0 deletions
@@ -398,3 +398,210 @@ func TestBPETokenizer_ByteLevelBPE(t *testing.T) {
 		t.Errorf("Decode(%v) = %q, want \"hi\"", ids, decoded)
 	}
 }
+
+// makeTestSentencePieceUnigram creates a SentencePiece unigram tokenizer
+// with vocabulary and scores but no merges, simulating Mistral 7B GGUF.
+func makeTestSentencePieceUnigram() *BPETokenizer {
+	vocab := map[string]int{
+		"<unk>":       0,
+		"<s>":         1,
+		"</s>":        2,
+		"\u2581":      3, // ▁
+		"\u2581H":     4,
+		"\u2581He":    5,
+		"\u2581Hel":   6,
+		"\u2581Hell":  7,
+		"\u2581Hello": 8,
+		"\u2581w":     9,
+		"\u2581wo":    10,
+		"\u2581wor":   11,
+		"\u2581worl":  12,
+		"\u2581world": 13,
+		"H":           14,
+		"e":           15,
+		"l":           16,
+		"o":           17,
+		"w":           18,
+		"r":           19,
+		"d":           20,
+		"\u2581the":   21,
+		"\u2581is":    22,
+		"\u2581a":     23,
+		"\u2581test":  24,
+		"t":           25,
+		"s":           26,
+	}
+
+	// Scores: higher (less negative) = more likely. Longer tokens get better scores.
+	scores := make([]float32, 27)
+	scores[0] = -100  // <unk>
+	scores[1] = -100  // <s>
+	scores[2] = -100  // </s>
+	scores[3] = -5.0  // ▁
+	scores[4] = -3.0  // ▁H
+	scores[5] = -2.5  // ▁He
+	scores[6] = -2.0  // ▁Hel
+	scores[7] = -1.5  // ▁Hell
+	scores[8] = -1.0  // ▁Hello (best for "Hello")
+	scores[9] = -3.0  // ▁w
+	scores[10] = -2.5 // ▁wo
+	scores[11] = -2.0 // ▁wor
+	scores[12] = -1.5 // ▁worl
+	scores[13] = -1.0 // ▁world (best for "world")
+	scores[14] = -4.0 // H
+	scores[15] = -4.0 // e
+	scores[16] = -4.0 // l
+	scores[17] = -4.0 // o
+	scores[18] = -4.0 // w
+	scores[19] = -4.0 // r
+	scores[20] = -4.0 // d
+	scores[21] = -1.0 // ▁the
+	scores[22] = -1.0 // ▁is
+	scores[23] = -1.5 // ▁a
+	scores[24] = -1.0 // ▁test
+	scores[25] = -4.0 // t
+	scores[26] = -4.0 // s
+
+	special := SpecialTokens{BOS: 1, EOS: 2, PAD: 0, UNK: 0}
+	// No merges — this is a unigram model.
+	tok := NewBPETokenizer(vocab, nil, special, false)
+	tok.SetSentencePiece(true)
+	tok.SetScores(scores)
+	return tok
+}
+
+func TestSentencePieceUnigram_Encode(t *testing.T) {
+	tok := makeTestSentencePieceUnigram()
+
+	tests := []struct {
+		name    string
+		input   string
+		wantIDs []int
+	}{
+		{"single word", "Hello", []int{8}},                             // ▁Hello
+		{"two words", "Hello world", []int{8, 13}},                     // ▁Hello ▁world
+		{"sentence", "the world is a test", []int{21, 13, 22, 23, 24}}, // ▁the ▁world ▁is ▁a ▁test
+		{"empty string", "", []int{}},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			ids, err := tok.Encode(tc.input)
+			if err != nil {
+				t.Fatalf("Encode(%q) error: %v", tc.input, err)
+			}
+			if len(ids) != len(tc.wantIDs) {
+				t.Fatalf("Encode(%q) = %v (len=%d), want %v (len=%d)", tc.input, ids, len(ids), tc.wantIDs, len(tc.wantIDs))
+			}
+			for i, id := range ids {
+				if id != tc.wantIDs[i] {
+					t.Errorf("Encode(%q)[%d] = %d, want %d", tc.input, i, id, tc.wantIDs[i])
+				}
+			}
+		})
+	}
+}
+
+func TestSentencePieceUnigram_Decode(t *testing.T) {
+	tok := makeTestSentencePieceUnigram()
+
+	tests := []struct {
+		name     string
+		ids      []int
+		wantText string
+		wantErr  bool
+	}{
+		{"single token", []int{8}, "Hello", false},
+		{"multiple tokens", []int{8, 13}, "Hello world", false},
+		{"empty", []int{}, "", false},
+		{"unknown ID", []int{999}, "", true},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			got, err := tok.Decode(tc.ids)
+			if tc.wantErr {
+				if err == nil {
+					t.Fatalf("Decode(%v) expected error, got %q", tc.ids, got)
+				}
+				return
+			}
+			if err != nil {
+				t.Fatalf("Decode(%v) error: %v", tc.ids, err)
+			}
+			if got != tc.wantText {
+				t.Errorf("Decode(%v) = %q, want %q", tc.ids, got, tc.wantText)
+			}
+		})
+	}
+}
+
+func TestSentencePieceUnigram_RoundTrip(t *testing.T) {
+	tok := makeTestSentencePieceUnigram()
+
+	tests := []string{"Hello", "Hello world", "the world is a test"}
+	for _, text := range tests {
+		ids, err := tok.Encode(text)
+		if err != nil {
+			t.Fatalf("Encode(%q) error: %v", text, err)
+		}
+		decoded, err := tok.Decode(ids)
+		if err != nil {
+			t.Fatalf("Decode(%v) error: %v", ids, err)
+		}
+		if decoded != text {
+			t.Errorf("round-trip failed: %q -> %v -> %q", text, ids, decoded)
+		}
+	}
+}
+
+func TestSentencePieceUnigram_UnknownChars(t *testing.T) {
+	tok := makeTestSentencePieceUnigram()
+
+	// Characters not in vocab should produce UNK tokens.
+	ids, err := tok.Encode("xyz")
+	if err != nil {
+		t.Fatalf("Encode error: %v", err)
+	}
+	// "xyz" -> pre-tokenized as "▁xyz". Since ▁x, ▁y, ▁z are not in vocab,
+	// the greedy matcher will match ▁ first, then x, y, z individually → all UNK.
+	for _, id := range ids {
+		if id != tok.special.UNK {
+			// Either ▁ (id=3) or UNK (id=0) is acceptable since ▁ is in vocab.
+			if id != 3 {
+				t.Errorf("expected UNK or ▁ token for unknown chars, got id=%d", id)
+			}
+		}
+	}
+}
+
+func TestSentencePieceUnigram_PrefersLongestMatch(t *testing.T) {
+	tok := makeTestSentencePieceUnigram()
+
+	// "Hello" should encode as one token ▁Hello (id=8), not ▁H + e + l + l + o.
+	ids, err := tok.Encode("Hello")
+	if err != nil {
+		t.Fatalf("Encode error: %v", err)
+	}
+	if len(ids) != 1 {
+		t.Errorf("expected 1 token for 'Hello', got %d: %v", len(ids), ids)
+	}
+	if ids[0] != 8 {
+		t.Errorf("expected token id 8 (▁Hello), got %d", ids[0])
+	}
+}
+
+func TestSentencePieceUnigram_WithBPEFallback(t *testing.T) {
+	// When merges ARE present, unigram encoding should NOT be used
+	// even if scores are also set.
+	tok := makeTestBPE()
+	tok.SetScores([]float32{0, 0, 0, 0}) // set scores but merges exist
+	ids, err := tok.Encode("hello")
+	if err != nil {
+		t.Fatalf("Encode error: %v", err)
+	}
+	// Should still use BPE merging, producing "hello" (id=17).
+	if len(ids) != 1 || ids[0] != 17 {
+		t.Errorf("with merges present, expected BPE encoding [17], got %v", ids)
+	}
+}

gguf/gguf.go

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,7 @@ type Metadata interface {
 	GetStringArray(key string) ([]string, bool)
 	GetUint32(key string) (uint32, bool)
 	GetInt32Array(key string) ([]int32, bool)
+	GetFloat32Array(key string) ([]float32, bool)
 }
 
 // ExtractTokenizer builds a BPETokenizer from GGUF metadata. GGUF files store
@@ -72,6 +73,13 @@ func ExtractTokenizer(m Metadata) (*ztoken.BPETokenizer, error) {
 		tok.SetSentencePiece(true)
 	}
 
+	// Extract token scores for SentencePiece unigram models. When scores
+	// are present but merges are absent, the tokenizer uses greedy
+	// longest-match encoding instead of BPE merge-based encoding.
+	if scores, ok := m.GetFloat32Array("tokenizer.ggml.scores"); ok {
+		tok.SetScores(scores)
+	}
+
 	// Extract control/special tokens (token_type == 3) for exact matching
 	// during encoding. Without this, tokens like <start_of_turn> would be
 	// split into characters by BPE.
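For illustration, one way a Metadata implementation could satisfy the extended interface and feed scores through ExtractTokenizer. mapMetadata and its kv field are hypothetical, not part of this commit; the remaining interface methods would follow the same type-asserted lookup pattern:

    // mapMetadata is a hypothetical map-backed Metadata, e.g. for tests.
    type mapMetadata struct {
        kv map[string]interface{}
    }

    func (m *mapMetadata) GetFloat32Array(key string) ([]float32, bool) {
        v, ok := m.kv[key].([]float32)
        return v, ok
    }

    // GetStringArray, GetUint32, GetInt32Array, and the rest of the
    // interface are elided; each does the same lookup against m.kv.

    // Usage sketch: with tokenizer.ggml.scores present and no merge table
    // stored, ExtractTokenizer ends up on the unigram path via SetScores.
    //
    //	meta := &mapMetadata{kv: map[string]interface{}{
    //		"tokenizer.ggml.scores": []float32{-100, -5.0, -1.0},
    //		// ... vocabulary and model keys as stored in the GGUF file ...
    //	}}
    //	tok, err := ExtractTokenizer(meta)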
