-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path loader.go
More file actions
285 lines (250 loc) · 7.81 KB
/
loader.go
File metadata and controls
285 lines (250 loc) · 7.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
package ztoken
import (
"encoding/json"
"fmt"
"os"
"strings"
"unicode"
"golang.org/x/text/unicode/norm"
)
// tokenizerJSON represents the HuggingFace tokenizer.json schema.
// Only the sections this package consumes are declared; any other JSON
// keys are ignored by encoding/json.
type tokenizerJSON struct {
	Model        modelJSON         `json:"model"`        // vocabulary, merges, and model type
	AddedTokens  []addedTokenJSON  `json:"added_tokens"` // extra tokens (special tokens live here)
	PreTokenizer *preTokenizerJSON `json:"pre_tokenizer"` // nil when the file has no pre-tokenizer section
	Normalizer   *normalizerJSON   `json:"normalizer"`    // nil when the file has no normalizer section
	Decoder      *decoderJSON      `json:"decoder"`       // nil when the file has no decoder section
}
// modelJSON is the "model" section of tokenizer.json.
type modelJSON struct {
	Type                    string          `json:"type"`  // e.g. "BPE" or "WordPiece"; may be empty
	Vocab                   map[string]int  `json:"vocab"` // token string -> token ID
	RawMerges               json.RawMessage `json:"merges"` // kept raw: format varies (see parseMerges)
	ContinuingSubwordPrefix string          `json:"continuing_subword_prefix"`
}
// addedTokenJSON is one entry of the "added_tokens" array.
type addedTokenJSON struct {
	ID      int    `json:"id"`      // token ID assigned to this added token
	Content string `json:"content"` // literal token text, e.g. "<s>" or "[CLS]"
	Special bool   `json:"special"` // true for special tokens (BOS/EOS/PAD/UNK, …)
}
// preTokenizerJSON is the "pre_tokenizer" section. A "Sequence" type
// nests its children in PreTokenizers.
type preTokenizerJSON struct {
	Type          string             `json:"type"`          // e.g. "ByteLevel" or "Sequence"
	PreTokenizers []preTokenizerJSON `json:"pretokenizers"` // children when Type == "Sequence"
}
// normalizerJSON is the "normalizer" section. A "Sequence" type nests
// its children in Normalizers.
type normalizerJSON struct {
	Type        string           `json:"type"`        // e.g. "NFC", "Lowercase", "Sequence"
	Normalizers []normalizerJSON `json:"normalizers"` // children when Type == "Sequence"
}
// decoderJSON is the "decoder" section. A "Sequence" type nests its
// children in Decoders; a "Replace" type carries Pattern and Content.
type decoderJSON struct {
	Type     string              `json:"type"`    // e.g. "Metaspace", "Replace", "Sequence"
	Pattern  *decoderPatternJSON `json:"pattern"` // match pattern for "Replace" decoders
	Content  string              `json:"content"` // replacement text for "Replace" decoders
	Decoders []decoderJSON       `json:"decoders"` // children when Type == "Sequence"
}
// decoderPatternJSON is the pattern object of a "Replace" decoder.
// The JSON key is capitalized ("String") in HuggingFace's output.
type decoderPatternJSON struct {
	String string `json:"String"` // literal string to be replaced
}
// Load reads a HuggingFace tokenizer.json file and returns the appropriate
// Tokenizer implementation based on the model type (BPE or WordPiece).
func Load(path string) (Tokenizer, error) {
	raw, err := os.ReadFile(path) //nolint:gosec // user-provided path
	if err != nil {
		return nil, fmt.Errorf("read tokenizer.json: %w", err)
	}
	var cfg tokenizerJSON
	if err := json.Unmarshal(raw, &cfg); err != nil {
		return nil, fmt.Errorf("parse tokenizer.json: %w", err)
	}
	// An absent model type is treated as BPE, the most common case.
	switch cfg.Model.Type {
	case "WordPiece":
		return loadWordPiece(cfg)
	case "", "BPE":
		return loadBPE(cfg)
	}
	return nil, fmt.Errorf("unsupported model type: %q (supported: BPE, WordPiece)", cfg.Model.Type)
}
// LoadFromJSON reads a HuggingFace tokenizer.json file and returns a BPETokenizer.
// For loading any tokenizer type, use [Load] instead.
func LoadFromJSON(path string) (*BPETokenizer, error) {
	raw, err := os.ReadFile(path) //nolint:gosec // user-provided path
	if err != nil {
		return nil, fmt.Errorf("read tokenizer.json: %w", err)
	}
	var cfg tokenizerJSON
	if err := json.Unmarshal(raw, &cfg); err != nil {
		return nil, fmt.Errorf("parse tokenizer.json: %w", err)
	}
	// Only BPE (or an unspecified type, treated as BPE) is accepted here.
	switch cfg.Model.Type {
	case "", "BPE":
		return loadBPE(cfg)
	default:
		return nil, fmt.Errorf("unsupported model type: %q (only BPE supported)", cfg.Model.Type)
	}
}
// loadBPE constructs a BPETokenizer from parsed JSON.
func loadBPE(cfg tokenizerJSON) (*BPETokenizer, error) {
	// Merges may be ["a b", …] or [["a","b"], …]; parseMerges accepts both.
	pairs, err := parseMerges(cfg.Model.RawMerges)
	if err != nil {
		return nil, fmt.Errorf("parse merges: %w", err)
	}
	tok := NewBPETokenizer(
		cfg.Model.Vocab,
		pairs,
		extractSpecialTokens(cfg.AddedTokens),          // BOS/EOS/PAD/UNK IDs
		isByteLevelPreTokenizer(cfg.PreTokenizer),      // byte-level BPE detection
	)
	tok.normalizer = buildNormalizer(cfg.Normalizer) // may be nil (no normalization config)
	// A Metaspace or Replace(U+2581) decoder marks a SentencePiece-style model.
	if isSentencePieceDecoder(cfg.Decoder) {
		tok.SetSentencePiece(true)
	}
	return tok, nil
}
// loadWordPiece constructs a WordPieceTokenizer from parsed JSON.
func loadWordPiece(cfg tokenizerJSON) (*WordPieceTokenizer, error) {
	tok := NewWordPieceTokenizer(cfg.Model.Vocab, extractSpecialTokens(cfg.AddedTokens))
	tok.normalizer = buildNormalizer(cfg.Normalizer) // may be nil (no normalization config)
	// Map each special token's literal text to its ID for exact matching.
	byContent := make(map[string]int)
	for _, tkn := range cfg.AddedTokens {
		if !tkn.Special {
			continue
		}
		byContent[tkn.Content] = tkn.ID
	}
	tok.specialTokens = byContent
	return tok, nil
}
// isByteLevelPreTokenizer returns true if the pre-tokenizer config uses ByteLevel,
// either directly or as one step of a Sequence pre-tokenizer.
func isByteLevelPreTokenizer(pt *preTokenizerJSON) bool {
	switch {
	case pt == nil:
		return false
	case pt.Type == "ByteLevel":
		return true
	case pt.Type == "Sequence":
		// ByteLevel may be nested one level deep inside a Sequence.
		for i := range pt.PreTokenizers {
			if pt.PreTokenizers[i].Type == "ByteLevel" {
				return true
			}
		}
	}
	return false
}
// isSentencePieceDecoder returns true if the decoder config indicates a
// SentencePiece tokenizer. This is detected by a Metaspace decoder type or
// a Replace rule that converts U+2581 (▁) to a space.
func isSentencePieceDecoder(d *decoderJSON) bool {
	if d == nil {
		return false
	}
	switch d.Type {
	case "Metaspace":
		return true
	case "Replace":
		// A Replace rule matching U+2581 is the SentencePiece word marker.
		if d.Pattern != nil && d.Pattern.String == "\u2581" {
			return true
		}
	case "Sequence":
		// Recurse into each child; take the address to avoid copying.
		for i := range d.Decoders {
			if isSentencePieceDecoder(&d.Decoders[i]) {
				return true
			}
		}
	}
	return false
}
// extractSpecialTokens finds BOS, EOS, PAD, UNK from added_tokens.
// Recognizes both GPT-style (<s>, </s>, <pad>, <unk>) and BERT-style
// ([CLS], [SEP], [PAD], [UNK]) special token conventions.
func extractSpecialTokens(tokens []addedTokenJSON) SpecialTokens {
	var out SpecialTokens
	for _, tkn := range tokens {
		if !tkn.Special {
			continue
		}
		c := tkn.Content
		switch {
		case strings.Contains(c, "bos") || c == "<s>" || c == "[CLS]":
			out.BOS = tkn.ID
		case strings.Contains(c, "eos") || c == "</s>" || c == "[SEP]":
			out.EOS = tkn.ID
		case strings.Contains(c, "pad") || c == "<pad>" || c == "[PAD]":
			out.PAD = tkn.ID
		case strings.Contains(c, "unk") || c == "<unk>" || c == "[UNK]":
			out.UNK = tkn.ID
		}
	}
	return out
}
// NormalizerFunc transforms text before tokenization.
// NOTE(review): a nil NormalizerFunc appears to mean "no normalization"
// (loadBPE/loadWordPiece assign buildNormalizer's possibly-nil result
// directly) — confirm against the tokenizer implementations.
//
// Stable.
type NormalizerFunc func(string) string
// buildNormalizer creates a normalizer function from the JSON config.
//
// Supported types: NFC, NFD, NFKC, NFKD (Unicode normalization forms),
// Lowercase, Strip, and Sequence (an ordered chain of the above).
// Unrecognized types — and empty Sequences — yield nil, meaning text
// passes through unchanged.
func buildNormalizer(n *normalizerJSON) NormalizerFunc {
	if n == nil {
		return nil
	}
	switch n.Type {
	case "NFC":
		return func(s string) string { return norm.NFC.String(s) }
	case "NFD":
		return func(s string) string { return norm.NFD.String(s) }
	case "NFKC":
		// Compatibility forms also appear in HF configs (e.g. NFKC for
		// some SentencePiece-derived tokenizers); norm provides them too.
		return func(s string) string { return norm.NFKC.String(s) }
	case "NFKD":
		return func(s string) string { return norm.NFKD.String(s) }
	case "Lowercase":
		return strings.ToLower
	case "Strip":
		return func(s string) string { return strings.TrimFunc(s, unicode.IsSpace) }
	case "Sequence":
		// Compose the child normalizers in order, skipping unsupported ones.
		var chain []NormalizerFunc
		for i := range n.Normalizers {
			if fn := buildNormalizer(&n.Normalizers[i]); fn != nil {
				chain = append(chain, fn)
			}
		}
		if len(chain) == 0 {
			return nil
		}
		return func(s string) string {
			for _, fn := range chain {
				s = fn(s)
			}
			return s
		}
	default:
		return nil
	}
}
// parseMerges decodes merges from JSON, accepting either space-separated
// strings (["a b", …]) or two-element arrays ([["a","b"], …]).
func parseMerges(raw json.RawMessage) ([]MergePair, error) {
	if len(raw) == 0 {
		return nil, nil
	}
	// The common encoding is a flat list of "left right" strings.
	var flat []string
	if json.Unmarshal(raw, &flat) == nil {
		out := make([]MergePair, 0, len(flat))
		for i, entry := range flat {
			left, right, ok := strings.Cut(entry, " ")
			if !ok {
				return nil, fmt.Errorf("invalid merge at index %d: %q", i, entry)
			}
			out = append(out, MergePair{Left: left, Right: right})
		}
		return out, nil
	}
	// Fall back to the pair-array encoding (Gemma 3 format).
	var nested [][]string
	if err := json.Unmarshal(raw, &nested); err != nil {
		return nil, fmt.Errorf("unsupported merges format: %w", err)
	}
	out := make([]MergePair, 0, len(nested))
	for i, pair := range nested {
		if len(pair) != 2 {
			return nil, fmt.Errorf("invalid merge at index %d: expected 2 elements, got %d", i, len(pair))
		}
		out = append(out, MergePair{Left: pair[0], Right: pair[1]})
	}
	return out, nil
}