Skip to content

Commit 8f43e44

Browse files
committed
fix: implement Viterbi SentencePiece encoding (replaces greedy)
The greedy longest-match approach in sentencePieceEncode produced suboptimal tokenization for SentencePiece unigram models (e.g., Mistral 7B). Replace it with Viterbi dynamic programming that finds the segmentation maximizing the sum of log-probability scores.

Also adds:
- Byte fallback encoding/decoding via <0xNN> tokens for chars not in vocab
- decodeSentencePieceBytes for proper round-trip of byte fallback tokens
- Tests: Viterbi vs greedy, byte fallback, sentence round-trip, edge cases
1 parent 59c06d0 commit 8f43e44

File tree

2 files changed

+347
-47
lines changed

2 files changed

+347
-47
lines changed

bpe.go

Lines changed: 122 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package ztoken
22

33
import (
44
"fmt"
5+
"math"
56
"strings"
67
"unicode/utf8"
78
)
@@ -202,6 +203,8 @@ func (t *BPETokenizer) Decode(ids []int) (string, error) {
202203
return decoded, nil
203204
}
204205
if t.sentencePiece {
206+
// Decode <0xNN> byte tokens back to actual bytes.
207+
result = decodeSentencePieceBytes(result)
205208
// Replace ▁ with space and trim leading space.
206209
result = strings.ReplaceAll(result, "\u2581", " ")
207210
result = strings.TrimPrefix(result, " ")
@@ -210,6 +213,43 @@ func (t *BPETokenizer) Decode(ids []int) (string, error) {
210213
return result, nil
211214
}
212215

216+
// decodeSentencePieceBytes expands SentencePiece byte-fallback tokens of the
// form <0xNN> (exactly two hex digits, so six characters total) into the raw
// bytes they encode. Anything that does not match the pattern is copied
// through unchanged, which makes the function a safe no-op on ordinary text.
func decodeSentencePieceBytes(s string) string {
	var out strings.Builder
	for i := 0; i < len(s); {
		// A candidate token needs six bytes: '<', '0', 'x', two digits, '>'.
		if i+6 <= len(s) && s[i] == '<' && s[i+1] == '0' && s[i+2] == 'x' && s[i+5] == '>' {
			if hi, lo := unhex(s[i+3]), unhex(s[i+4]); hi >= 0 && lo >= 0 {
				out.WriteByte(byte(hi<<4 | lo))
				i += 6
				continue
			}
		}
		out.WriteByte(s[i])
		i++
	}
	return out.String()
}

// unhex maps an ASCII hex digit (either case) to its numeric value,
// returning -1 for any non-hex byte.
func unhex(c byte) int {
	if c >= '0' && c <= '9' {
		return int(c - '0')
	}
	if c >= 'A' && c <= 'F' {
		return int(c-'A') + 10
	}
	if c >= 'a' && c <= 'f' {
		return int(c-'a') + 10
	}
	return -1
}
252+
213253
// VocabSize returns the number of tokens in the vocabulary.
214254
func (t *BPETokenizer) VocabSize() int {
215255
return len(t.vocab)
@@ -259,59 +299,104 @@ func (t *BPETokenizer) SetScores(scores []float32) {
259299
}
260300
}
261301

262-
// sentencePieceEncode tokenizes text using greedy longest-match with
263-
// score-based selection. For each position, it finds all vocabulary tokens
264-
// that match the input at that position, selects the longest match (breaking
265-
// ties by highest score), and advances past the matched token.
302+
// sentencePieceEncode tokenizes text using Viterbi dynamic programming to find
303+
// the segmentation that maximizes the sum of log-probability scores.
266304
//
267305
// This is used for SentencePiece unigram models that provide vocabulary
268-
// scores but no BPE merge table (e.g., Mistral 7B GGUF).
306+
// scores but no BPE merge table (e.g., Mistral 7B GGUF). The Viterbi approach
307+
// finds the globally optimal segmentation, unlike greedy longest-match which
308+
// can produce suboptimal splits.
269309
func (t *BPETokenizer) sentencePieceEncode(text string) []int {
270310
if text == "" {
271311
return nil
272312
}
273-
var ids []int
274-
pos := 0
275-
textBytes := []byte(text)
276-
n := len(textBytes)
277313

278-
for pos < n {
279-
bestLen := 0
280-
bestID := t.special.UNK
281-
bestScore := float32(-1e30)
282-
283-
// Search for the longest matching token at this position.
284-
// Limit search window to maxTokenLen bytes.
285-
maxEnd := pos + t.maxTokenLen
286-
if maxEnd > n {
287-
maxEnd = n
288-
}
314+
n := len(text) // byte length
315+
316+
// Viterbi forward pass: find best segmentation.
317+
// bestScore[i] = best total score for text[:i]
318+
// bestLen[i] = byte length of the last token in the best path ending at i
319+
bestScore := make([]float64, n+1)
320+
bestLen := make([]int, n+1)
321+
for i := range bestScore {
322+
bestScore[i] = math.Inf(-1)
323+
}
324+
bestScore[0] = 0
289325

290-
for end := pos + 1; end <= maxEnd; end++ {
291-
candidate := string(textBytes[pos:end])
326+
for i := 0; i < n; i++ {
327+
if math.IsInf(bestScore[i], -1) {
328+
continue
329+
}
330+
// Try all possible tokens starting at position i.
331+
maxLen := t.maxTokenLen
332+
if maxLen > n-i {
333+
maxLen = n - i
334+
}
335+
for tokenLen := 1; tokenLen <= maxLen; tokenLen++ {
336+
candidate := text[i : i+tokenLen]
292337
if id, ok := t.vocab[candidate]; ok {
293-
candidateLen := end - pos
294-
// Prefer longer matches. For equal length, prefer higher score.
295-
if candidateLen > bestLen || (candidateLen == bestLen && t.tokenScore(id) > bestScore) {
296-
bestLen = candidateLen
297-
bestID = id
298-
bestScore = t.tokenScore(id)
338+
score := bestScore[i] + float64(t.tokenScore(id))
339+
if score > bestScore[i+tokenLen] {
340+
bestScore[i+tokenLen] = score
341+
bestLen[i+tokenLen] = tokenLen
299342
}
300343
}
301344
}
345+
// Byte fallback: if no vocab token covers position i, use <0xNN>.
346+
byteToken := fmt.Sprintf("<0x%02X>", text[i])
347+
if id, ok := t.vocab[byteToken]; ok {
348+
score := bestScore[i] + float64(t.tokenScore(id))
349+
if score > bestScore[i+1] {
350+
bestScore[i+1] = score
351+
bestLen[i+1] = 1
352+
}
353+
} else {
354+
// Byte token not in vocab; use unknown score as last resort.
355+
score := bestScore[i] + float64(t.unknownScore())
356+
if score > bestScore[i+1] {
357+
bestScore[i+1] = score
358+
bestLen[i+1] = 1
359+
}
360+
}
361+
}
302362

303-
if bestLen == 0 {
304-
// No matching token found; emit UNK and advance by one byte.
305-
ids = append(ids, t.special.UNK)
306-
// Advance past one UTF-8 character, not just one byte.
307-
_, size := decodeRune(textBytes[pos:])
308-
pos += size
363+
// If we can't reach the end, return nil.
364+
if math.IsInf(bestScore[n], -1) {
365+
return nil
366+
}
367+
368+
// Backtrack to find token sequence.
369+
var tokens []int
370+
pos := n
371+
for pos > 0 {
372+
tokLen := bestLen[pos]
373+
candidate := text[pos-tokLen : pos]
374+
if id, ok := t.vocab[candidate]; ok {
375+
tokens = append(tokens, id)
309376
} else {
310-
ids = append(ids, bestID)
311-
pos += bestLen
377+
// Byte fallback for single-byte token.
378+
byteToken := fmt.Sprintf("<0x%02X>", text[pos-tokLen])
379+
if id, ok := t.vocab[byteToken]; ok {
380+
tokens = append(tokens, id)
381+
} else {
382+
tokens = append(tokens, t.special.UNK)
383+
}
312384
}
385+
pos -= tokLen
386+
}
387+
388+
// Reverse (we built it backwards).
389+
for i, j := 0, len(tokens)-1; i < j; i, j = i+1, j-1 {
390+
tokens[i], tokens[j] = tokens[j], tokens[i]
313391
}
314-
return ids
392+
393+
return tokens
394+
}
395+
396+
// unknownScore returns a very negative score used for byte fallback tokens
397+
// when the <0xNN> token is not in the vocabulary.
398+
func (t *BPETokenizer) unknownScore() float32 {
399+
return -100.0
315400
}
316401

317402
// tokenScore returns the score for a token ID, or 0 if scores are not set

0 commit comments

Comments
 (0)