@@ -799,6 +799,145 @@ func TestSentencePieceUnigram_LongText(t *testing.T) {
799799 }
800800}
801801
802+ func TestSentencePieceUnigram_ByteFallbackNeverBeatsVocab (t * testing.T ) {
803+ // Regression test: byte fallback tokens must never be preferred over
804+ // multi-character vocab tokens, even when byte token scores are higher.
805+ // This was the original bug — byte tokens like <0xE2> had scores of 0.0
806+ // which beat multi-character tokens with negative scores, producing 43
807+ // byte-level tokens instead of 7 word tokens.
808+ vocab := map [string ]int {
809+ "<unk>" : 0 ,
810+ "<s>" : 1 ,
811+ "</s>" : 2 ,
812+ "\u2581 " : 3 ,
813+ "\u2581 What" : 4 ,
814+ "\u2581 is" : 5 ,
815+ "\u2581 the" : 6 ,
816+ "\u2581 capital" : 7 ,
817+ "\u2581 of" : 8 ,
818+ "\u2581 France" : 9 ,
819+ "?" : 10 ,
820+ }
821+ // Add byte fallback tokens for all 256 bytes.
822+ nextID := 11
823+ for b := 0 ; b < 256 ; b ++ {
824+ tok := fmt .Sprintf ("<0x%02X>" , b )
825+ vocab [tok ] = nextID
826+ nextID ++
827+ }
828+
829+ scores := make ([]float32 , nextID )
830+ scores [0 ] = - 100 // <unk>
831+ scores [1 ] = - 100 // <s>
832+ scores [2 ] = - 100 // </s>
833+ scores [3 ] = - 5.0 // ▁
834+ scores [4 ] = - 8.0 // ▁What
835+ scores [5 ] = - 7.0 // ▁is
836+ scores [6 ] = - 6.0 // ▁the
837+ scores [7 ] = - 9.0 // ▁capital
838+ scores [8 ] = - 6.0 // ▁of
839+ scores [9 ] = - 9.0 // ▁France
840+ scores [10 ] = - 4.0 // ?
841+ // Byte fallback tokens get HIGH scores (the bug scenario).
842+ // Before the fix, these would win over multi-character vocab tokens.
843+ for i := 11 ; i < nextID ; i ++ {
844+ scores [i ] = 0.0
845+ }
846+
847+ special := SpecialTokens {BOS : 1 , EOS : 2 , PAD : 0 , UNK : 0 }
848+ tok := NewBPETokenizer (vocab , nil , special , false )
849+ tok .SetSentencePiece (true )
850+ tok .SetScores (scores )
851+
852+ ids , err := tok .Encode ("What is the capital of France?" )
853+ if err != nil {
854+ t .Fatalf ("Encode error: %v" , err )
855+ }
856+ // Must produce word-level tokens, not byte-level tokens.
857+ // "What is the capital of France?" -> [▁What, ▁is, ▁the, ▁capital, ▁of, ▁France, ?]
858+ want := []int {4 , 5 , 6 , 7 , 8 , 9 , 10 }
859+ if len (ids ) != len (want ) {
860+ t .Fatalf ("Encode produced %d tokens %v, want %d tokens %v" , len (ids ), ids , len (want ), want )
861+ }
862+ for i , id := range ids {
863+ if id != want [i ] {
864+ t .Errorf ("[%d] = %d, want %d" , i , id , want [i ])
865+ }
866+ }
867+
868+ // Verify round-trip.
869+ decoded , err := tok .Decode (ids )
870+ if err != nil {
871+ t .Fatalf ("Decode error: %v" , err )
872+ }
873+ if decoded != "What is the capital of France?" {
874+ t .Errorf ("Decode = %q, want %q" , decoded , "What is the capital of France?" )
875+ }
876+ }
877+
878+ func TestSentencePieceUnigram_ByteFallbackStillWorksForUnknownChars (t * testing.T ) {
879+ // Byte fallback must still be used for characters that have no
880+ // vocabulary coverage (e.g., emoji, rare Unicode).
881+ vocab := map [string ]int {
882+ "<unk>" : 0 ,
883+ "<s>" : 1 ,
884+ "</s>" : 2 ,
885+ "\u2581 " : 3 ,
886+ "\u2581 hi" : 4 ,
887+ }
888+ nextID := 5
889+ for b := 0 ; b < 256 ; b ++ {
890+ tok := fmt .Sprintf ("<0x%02X>" , b )
891+ vocab [tok ] = nextID
892+ nextID ++
893+ }
894+
895+ scores := make ([]float32 , nextID )
896+ scores [0 ] = - 100
897+ scores [1 ] = - 100
898+ scores [2 ] = - 100
899+ scores [3 ] = - 5.0
900+ scores [4 ] = - 1.0 // ▁hi
901+ for i := 5 ; i < nextID ; i ++ {
902+ scores [i ] = - 2.0 // byte scores
903+ }
904+
905+ special := SpecialTokens {BOS : 1 , EOS : 2 , PAD : 0 , UNK : 0 }
906+ tok := NewBPETokenizer (vocab , nil , special , false )
907+ tok .SetSentencePiece (true )
908+ tok .SetScores (scores )
909+
910+ // "hi" has a vocab token; should use it.
911+ ids , err := tok .Encode ("hi" )
912+ if err != nil {
913+ t .Fatalf ("Encode(\" hi\" ) error: %v" , err )
914+ }
915+ if len (ids ) != 1 || ids [0 ] != 4 {
916+ t .Errorf ("Encode(\" hi\" ) = %v, want [4] (▁hi)" , ids )
917+ }
918+
919+ // "hi\xc3\xa9" — é (U+00E9) is not in vocab, must use byte fallback.
920+ ids , err = tok .Encode ("hi\xc3 \xa9 " )
921+ if err != nil {
922+ t .Fatalf ("Encode error: %v" , err )
923+ }
924+ // Should be: ▁hi + <0xC3> + <0xA9>
925+ if len (ids ) != 3 {
926+ t .Fatalf ("Encode(\" hi\\ xc3\\ xa9\" ) = %v (len=%d), want 3 tokens" , ids , len (ids ))
927+ }
928+ if ids [0 ] != 4 {
929+ t .Errorf ("[0] = %d, want 4 (▁hi)" , ids [0 ])
930+ }
931+ // Verify round-trip through decode.
932+ decoded , err := tok .Decode (ids )
933+ if err != nil {
934+ t .Fatalf ("Decode error: %v" , err )
935+ }
936+ if decoded != "hi\xc3 \xa9 " {
937+ t .Errorf ("Decode = %q, want %q" , decoded , "hi\xc3 \xa9 " )
938+ }
939+ }
940+
802941func TestDecodeSentencePieceBytes (t * testing.T ) {
803942 tests := []struct {
804943 name string
0 commit comments