1313
1414namespace tokenizers {
1515
16+ struct TrieTree {
17+ std::unordered_map<int , std::unique_ptr<TrieTree>> children;
18+ std::string word;
19+ std::optional<int > token_id;
20+
21+ TrieTree (const std::unordered_map<std::string, int >& word2id) {
22+ for (auto & pair : word2id) {
23+ add_word (pair.first , pair.second );
24+ }
25+ }
26+
27+ std::pair<std::string, int > find_longest_prefix (const std::string& str) const {
28+ std::string prefix;
29+ int token_id = -1 ;
30+ const TrieTree* node = this ;
31+ for (int i = 0 ; i < str.size (); ++i) {
32+ auto it = node->children .find (str[i]);
33+ if (it == node->children .end ()) {
34+ break ;
35+ }
36+ node = it->second .get ();
37+ RV_CHECK (node != nullptr );
38+ if (node->token_id .has_value ()) {
39+ prefix = node->word ;
40+ token_id = node->token_id .value ();
41+ }
42+ }
43+ RV_CHECK (!prefix.empty ());
44+ RV_CHECK (token_id != -1 );
45+ return {prefix, token_id};
46+ }
47+
48+ private:
49+ TrieTree () = default ;
50+ void add_word (const std::string& word, int token_id) {
51+ return _add_word (word, token_id, 0 );
52+ }
53+ void _add_word (const std::string& word, int token_id, int idx) {
54+ if (idx == word.size ()) {
55+ this ->word = word;
56+ this ->token_id = token_id;
57+ return ;
58+ }
59+ auto & child = children[word[idx]];
60+ if (!child) {
61+ child = std::unique_ptr<TrieTree>(new TrieTree ());
62+ }
63+ child->_add_word (word, token_id, idx + 1 );
64+ }
65+ };
66+
1667RWKVWorldToolTokenizer::RWKVWorldToolTokenizer (const std::string &path) {
1768 std::ifstream infile;
1869 infile.open (path, std::ios::binary | std::ios::in);
@@ -29,28 +80,17 @@ RWKVWorldToolTokenizer::RWKVWorldToolTokenizer(const std::string &path) {
2980 for (auto &pair : _idx2word) {
3081 _word2idx[pair.second ] = pair.first ;
3182 }
83+ _tree = std::make_unique<TrieTree>(_word2idx);
3284}
3385
34- std::vector<int > RWKVWorldToolTokenizer::encode (std::string_view str) const {
86+ std::vector<int > RWKVWorldToolTokenizer::encode (const std::string & str) const {
3587 std::vector<int > ids;
3688 int str_idx = 0 ;
37- int word_len = 1 ;
38- int id = 0 ;
89+
3990 while (str_idx < str.size ()) {
40- if (str_idx + word_len > str.size ()) {
41- ids.push_back (id);
42- break ;
43- }
44- auto substr = str.substr (str_idx, word_len);
45- auto it = _word2idx.find (std::string (substr));
46- if (it == _word2idx.end ()) {
47- ids.push_back (id);
48- str_idx += (word_len - 1 );
49- word_len = 1 ;
50- } else {
51- id = it->second ;
52- word_len++;
53- }
91+ auto [prefix, token_id] = _tree->find_longest_prefix (str.substr (str_idx));
92+ ids.push_back (token_id);
93+ str_idx += prefix.size ();
5494 }
5595 return ids;
5696}
0 commit comments