@@ -2,6 +2,7 @@ package ztoken
22
33import (
44 "fmt"
5+ "math"
56 "strings"
67 "unicode/utf8"
78)
@@ -202,6 +203,8 @@ func (t *BPETokenizer) Decode(ids []int) (string, error) {
202203 return decoded , nil
203204 }
204205 if t .sentencePiece {
206+ // Decode <0xNN> byte tokens back to actual bytes.
207+ result = decodeSentencePieceBytes (result )
205208 // Replace ▁ with space and trim leading space.
206209 result = strings .ReplaceAll (result , "\u2581 " , " " )
207210 result = strings .TrimPrefix (result , " " )
@@ -210,6 +213,43 @@ func (t *BPETokenizer) Decode(ids []int) (string, error) {
210213 return result , nil
211214}
212215
// decodeSentencePieceBytes replaces <0xNN> hex byte tokens with the
// corresponding raw bytes. This reverses the byte fallback encoding
// used by SentencePiece for characters not in the vocabulary.
func decodeSentencePieceBytes(s string) string {
	var out strings.Builder
	out.Grow(len(s))
	pos := 0
	for pos < len(s) {
		// A byte token is exactly 6 characters: '<', '0', 'x', two hex
		// digits, '>'. Anything that does not match is copied through.
		if pos+6 <= len(s) && s[pos] == '<' && s[pos+1] == '0' && s[pos+2] == 'x' && s[pos+5] == '>' {
			if hi, lo := unhex(s[pos+3]), unhex(s[pos+4]); hi >= 0 && lo >= 0 {
				out.WriteByte(byte(hi<<4 | lo))
				pos += 6
				continue
			}
		}
		out.WriteByte(s[pos])
		pos++
	}
	return out.String()
}

// unhex converts a hex digit character to its value, or -1 if invalid.
func unhex(c byte) int {
	if c >= '0' && c <= '9' {
		return int(c - '0')
	}
	if c >= 'a' && c <= 'f' {
		return int(c-'a') + 10
	}
	if c >= 'A' && c <= 'F' {
		return int(c-'A') + 10
	}
	return -1
}
252+
213253// VocabSize returns the number of tokens in the vocabulary.
214254func (t * BPETokenizer ) VocabSize () int {
215255 return len (t .vocab )
@@ -259,59 +299,104 @@ func (t *BPETokenizer) SetScores(scores []float32) {
259299 }
260300}
261301
262- // sentencePieceEncode tokenizes text using greedy longest-match with
263- // score-based selection. For each position, it finds all vocabulary tokens
264- // that match the input at that position, selects the longest match (breaking
265- // ties by highest score), and advances past the matched token.
302+ // sentencePieceEncode tokenizes text using Viterbi dynamic programming to find
303+ // the segmentation that maximizes the sum of log-probability scores.
266304//
267305// This is used for SentencePiece unigram models that provide vocabulary
268- // scores but no BPE merge table (e.g., Mistral 7B GGUF).
306+ // scores but no BPE merge table (e.g., Mistral 7B GGUF). The Viterbi approach
307+ // finds the globally optimal segmentation, unlike greedy longest-match which
308+ // can produce suboptimal splits.
269309func (t * BPETokenizer ) sentencePieceEncode (text string ) []int {
270310 if text == "" {
271311 return nil
272312 }
273- var ids []int
274- pos := 0
275- textBytes := []byte (text )
276- n := len (textBytes )
277313
278- for pos < n {
279- bestLen := 0
280- bestID := t . special . UNK
281- bestScore := float32 ( - 1e30 )
282-
283- // Search for the longest matching token at this position.
284- // Limit search window to maxTokenLen bytes.
285- maxEnd := pos + t . maxTokenLen
286- if maxEnd > n {
287- maxEnd = n
288- }
314+ n := len ( text ) // byte length
315+
316+ // Viterbi forward pass: find best segmentation.
317+ // bestScore[i] = best total score for text[:i]
318+ // bestLen[i] = byte length of the last token in the best path ending at i
319+ bestScore := make ([] float64 , n + 1 )
320+ bestLen := make ([] int , n + 1 )
321+ for i := range bestScore {
322+ bestScore [ i ] = math . Inf ( - 1 )
323+ }
324+ bestScore [ 0 ] = 0
289325
290- for end := pos + 1 ; end <= maxEnd ; end ++ {
291- candidate := string (textBytes [pos :end ])
326+ for i := 0 ; i < n ; i ++ {
327+ if math .IsInf (bestScore [i ], - 1 ) {
328+ continue
329+ }
330+ // Try all possible tokens starting at position i.
331+ maxLen := t .maxTokenLen
332+ if maxLen > n - i {
333+ maxLen = n - i
334+ }
335+ for tokenLen := 1 ; tokenLen <= maxLen ; tokenLen ++ {
336+ candidate := text [i : i + tokenLen ]
292337 if id , ok := t .vocab [candidate ]; ok {
293- candidateLen := end - pos
294- // Prefer longer matches. For equal length, prefer higher score.
295- if candidateLen > bestLen || (candidateLen == bestLen && t .tokenScore (id ) > bestScore ) {
296- bestLen = candidateLen
297- bestID = id
298- bestScore = t .tokenScore (id )
338+ score := bestScore [i ] + float64 (t .tokenScore (id ))
339+ if score > bestScore [i + tokenLen ] {
340+ bestScore [i + tokenLen ] = score
341+ bestLen [i + tokenLen ] = tokenLen
299342 }
300343 }
301344 }
345+ // Byte fallback: if no vocab token covers position i, use <0xNN>.
346+ byteToken := fmt .Sprintf ("<0x%02X>" , text [i ])
347+ if id , ok := t .vocab [byteToken ]; ok {
348+ score := bestScore [i ] + float64 (t .tokenScore (id ))
349+ if score > bestScore [i + 1 ] {
350+ bestScore [i + 1 ] = score
351+ bestLen [i + 1 ] = 1
352+ }
353+ } else {
354+ // Byte token not in vocab; use unknown score as last resort.
355+ score := bestScore [i ] + float64 (t .unknownScore ())
356+ if score > bestScore [i + 1 ] {
357+ bestScore [i + 1 ] = score
358+ bestLen [i + 1 ] = 1
359+ }
360+ }
361+ }
302362
303- if bestLen == 0 {
304- // No matching token found; emit UNK and advance by one byte.
305- ids = append (ids , t .special .UNK )
306- // Advance past one UTF-8 character, not just one byte.
307- _ , size := decodeRune (textBytes [pos :])
308- pos += size
363+ // If we can't reach the end, return nil.
364+ if math .IsInf (bestScore [n ], - 1 ) {
365+ return nil
366+ }
367+
368+ // Backtrack to find token sequence.
369+ var tokens []int
370+ pos := n
371+ for pos > 0 {
372+ tokLen := bestLen [pos ]
373+ candidate := text [pos - tokLen : pos ]
374+ if id , ok := t .vocab [candidate ]; ok {
375+ tokens = append (tokens , id )
309376 } else {
310- ids = append (ids , bestID )
311- pos += bestLen
377+ // Byte fallback for single-byte token.
378+ byteToken := fmt .Sprintf ("<0x%02X>" , text [pos - tokLen ])
379+ if id , ok := t .vocab [byteToken ]; ok {
380+ tokens = append (tokens , id )
381+ } else {
382+ tokens = append (tokens , t .special .UNK )
383+ }
312384 }
385+ pos -= tokLen
386+ }
387+
388+ // Reverse (we built it backwards).
389+ for i , j := 0 , len (tokens )- 1 ; i < j ; i , j = i + 1 , j - 1 {
390+ tokens [i ], tokens [j ] = tokens [j ], tokens [i ]
313391 }
314- return ids
392+
393+ return tokens
394+ }
395+
396+ // unknownScore returns a very negative score used for byte fallback tokens
397+ // when the <0xNN> token is not in the vocabulary.
398+ func (t * BPETokenizer ) unknownScore () float32 {
399+ return - 100.0
315400}
316401
317402// tokenScore returns the score for a token ID, or 0 if scores are not set
0 commit comments