Skip to content

Commit 45a5ae0

Browse files
committed
fix: use large penalty for byte fallback in SentencePiece Viterbi
Byte fallback tokens (<0xNN>) were competing with multi-character vocab tokens in the Viterbi DP using their actual vocabulary scores. When byte token scores happened to be higher than vocab token scores, the Viterbi algorithm preferred the byte-level path, producing 43 byte-level tokens instead of 7 word-level tokens. Fix: assign each byte fallback token a fixed score of -1e6 instead of its vocabulary score, so that byte tokens are used only as a last resort when no vocab token covers a position. This matches llama.cpp behavior.
1 parent 8f43e44 commit 45a5ae0

2 files changed

Lines changed: 146 additions & 2 deletions

File tree

bpe.go

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -342,10 +342,15 @@ func (t *BPETokenizer) sentencePieceEncode(text string) []int {
342342
}
343343
}
344344
}
345-
// Byte fallback: if no vocab token covers position i, use <0xNN>.
345+
// Byte fallback: use <0xNN> as last resort when no vocab token covers
346+
// position i. Byte tokens get a fixed penalty of -1e6 so they never
347+
// beat real vocabulary tokens in the Viterbi DP. This matches
348+
// llama.cpp / SentencePiece behavior where byte fallback is only
349+
// used for characters that have no vocabulary coverage.
346350
byteToken := fmt.Sprintf("<0x%02X>", text[i])
347351
if id, ok := t.vocab[byteToken]; ok {
348-
score := bestScore[i] + float64(t.tokenScore(id))
352+
_ = id // byte token exists but we ignore its vocab score
353+
score := bestScore[i] + (-1e6)
349354
if score > bestScore[i+1] {
350355
bestScore[i+1] = score
351356
bestLen[i+1] = 1

bpe_test.go

Lines changed: 139 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -799,6 +799,145 @@ func TestSentencePieceUnigram_LongText(t *testing.T) {
799799
}
800800
}
801801

802+
// TestSentencePieceUnigram_ByteFallbackNeverBeatsVocab builds a vocabulary of
// word-level pieces plus all 256 <0xNN> byte tokens, gives every byte token a
// score of 0.0 (higher than every word-level piece), and checks that Encode
// still returns the seven word-level IDs and that Decode round-trips the text.
func TestSentencePieceUnigram_ByteFallbackNeverBeatsVocab(t *testing.T) {
803+
// Regression test: byte fallback tokens must never be preferred over
804+
// multi-character vocab tokens, even when byte token scores are higher.
805+
// This was the original bug — byte tokens like <0xE2> had scores of 0.0
806+
// which beat multi-character tokens with negative scores, producing 43
807+
// byte-level tokens instead of 7 word tokens.
808+
vocab := map[string]int{
809+
"<unk>": 0,
810+
"<s>": 1,
811+
"</s>": 2,
812+
"\u2581": 3,
813+
"\u2581What": 4,
814+
"\u2581is": 5,
815+
"\u2581the": 6,
816+
"\u2581capital": 7,
817+
"\u2581of": 8,
818+
"\u2581France": 9,
819+
"?": 10,
820+
}
821+
// Add byte fallback tokens for all 256 bytes; they take IDs 11 through 266.
822+
nextID := 11
823+
for b := 0; b < 256; b++ {
824+
tok := fmt.Sprintf("<0x%02X>", b)
825+
vocab[tok] = nextID
826+
nextID++
827+
}
828+
829+
// One score slot per vocabulary ID (nextID is now one past the last ID).
scores := make([]float32, nextID)
830+
scores[0] = -100 // <unk>
831+
scores[1] = -100 // <s>
832+
scores[2] = -100 // </s>
833+
scores[3] = -5.0 // ▁
834+
scores[4] = -8.0 // ▁What
835+
scores[5] = -7.0 // ▁is
836+
scores[6] = -6.0 // ▁the
837+
scores[7] = -9.0 // ▁capital
838+
scores[8] = -6.0 // ▁of
839+
scores[9] = -9.0 // ▁France
840+
scores[10] = -4.0 // ?
841+
// Byte fallback tokens get HIGH scores (the bug scenario).
842+
// Before the fix, these would win over multi-character vocab tokens.
843+
for i := 11; i < nextID; i++ {
844+
scores[i] = 0.0
845+
}
846+
847+
// PAD and UNK both use ID 0, which is the <unk> entry in vocab above.
special := SpecialTokens{BOS: 1, EOS: 2, PAD: 0, UNK: 0}
848+
tok := NewBPETokenizer(vocab, nil, special, false)
849+
tok.SetSentencePiece(true)
850+
tok.SetScores(scores)
851+
852+
ids, err := tok.Encode("What is the capital of France?")
853+
if err != nil {
854+
t.Fatalf("Encode error: %v", err)
855+
}
856+
// Must produce word-level tokens, not byte-level tokens.
857+
// "What is the capital of France?" -> [▁What, ▁is, ▁the, ▁capital, ▁of, ▁France, ?]
858+
want := []int{4, 5, 6, 7, 8, 9, 10}
859+
// Check the length first so the per-element loop below cannot index past want.
if len(ids) != len(want) {
860+
t.Fatalf("Encode produced %d tokens %v, want %d tokens %v", len(ids), ids, len(want), want)
861+
}
862+
for i, id := range ids {
863+
if id != want[i] {
864+
t.Errorf("[%d] = %d, want %d", i, id, want[i])
865+
}
866+
}
867+
868+
// Verify round-trip: decoding the IDs must give back the original input string.
869+
decoded, err := tok.Decode(ids)
870+
if err != nil {
871+
t.Fatalf("Decode error: %v", err)
872+
}
873+
if decoded != "What is the capital of France?" {
874+
t.Errorf("Decode = %q, want %q", decoded, "What is the capital of France?")
875+
}
876+
}
877+
878+
// TestSentencePieceUnigram_ByteFallbackStillWorksForUnknownChars checks that
// byte fallback is still used for input with no vocabulary coverage: "hi" must
// encode to the single ▁hi token, while "hi\xc3\xa9" must encode to three
// tokens starting with ▁hi and decode back to the original string.
func TestSentencePieceUnigram_ByteFallbackStillWorksForUnknownChars(t *testing.T) {
879+
// Byte fallback must still be used for characters that have no
880+
// vocabulary coverage (e.g., emoji, rare Unicode).
881+
vocab := map[string]int{
882+
"<unk>": 0,
883+
"<s>": 1,
884+
"</s>": 2,
885+
"\u2581": 3,
886+
"\u2581hi": 4,
887+
}
888+
// Add byte fallback tokens for all 256 bytes; they take IDs 5 through 260.
nextID := 5
889+
for b := 0; b < 256; b++ {
890+
tok := fmt.Sprintf("<0x%02X>", b)
891+
vocab[tok] = nextID
892+
nextID++
893+
}
894+
895+
// One score slot per vocabulary ID (nextID is now one past the last ID).
scores := make([]float32, nextID)
896+
scores[0] = -100
897+
scores[1] = -100
898+
scores[2] = -100
899+
scores[3] = -5.0
900+
scores[4] = -1.0 // ▁hi
901+
for i := 5; i < nextID; i++ {
902+
scores[i] = -2.0 // byte scores: below ▁hi (-1.0), above ▁ (-5.0)
903+
}
904+
905+
// PAD and UNK both use ID 0, which is the <unk> entry in vocab above.
special := SpecialTokens{BOS: 1, EOS: 2, PAD: 0, UNK: 0}
906+
tok := NewBPETokenizer(vocab, nil, special, false)
907+
tok.SetSentencePiece(true)
908+
tok.SetScores(scores)
909+
910+
// "hi" has a vocab token; should use it.
911+
ids, err := tok.Encode("hi")
912+
if err != nil {
913+
t.Fatalf("Encode(\"hi\") error: %v", err)
914+
}
915+
if len(ids) != 1 || ids[0] != 4 {
916+
t.Errorf("Encode(\"hi\") = %v, want [4] (▁hi)", ids)
917+
}
918+
919+
// "hi\xc3\xa9" — é (U+00E9) is not in vocab, must use byte fallback.
920+
ids, err = tok.Encode("hi\xc3\xa9")
921+
if err != nil {
922+
t.Fatalf("Encode error: %v", err)
923+
}
924+
// Should be: ▁hi + <0xC3> + <0xA9>
// Only the token count and the first ID are asserted directly here; the two
// byte-token IDs are not compared individually and are exercised through the
// Decode round-trip below.
925+
if len(ids) != 3 {
926+
t.Fatalf("Encode(\"hi\\xc3\\xa9\") = %v (len=%d), want 3 tokens", ids, len(ids))
927+
}
928+
if ids[0] != 4 {
929+
t.Errorf("[0] = %d, want 4 (▁hi)", ids[0])
930+
}
931+
// Verify round-trip through decode: the IDs must decode back to the input bytes.
932+
decoded, err := tok.Decode(ids)
933+
if err != nil {
934+
t.Fatalf("Decode error: %v", err)
935+
}
936+
if decoded != "hi\xc3\xa9" {
937+
t.Errorf("Decode = %q, want %q", decoded, "hi\xc3\xa9")
938+
}
939+
}
940+
802941
func TestDecodeSentencePieceBytes(t *testing.T) {
803942
tests := []struct {
804943
name string

0 commit comments

Comments
 (0)