Skip to content

Commit 4456fa8

Browse files
authored
fix rwkv world tokenizer (mlc-ai#19)
* support rwkv world tokenizer * refine * rename * refine * switch msgpack version * refine * refine * fix rwkv world tokenizer bug * fix comment
1 parent 22b8932 commit 4456fa8

File tree

2 files changed

+85
-19
lines changed

2 files changed

+85
-19
lines changed

include/rwkv_world_tokenizer.h

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,44 @@
77
#include <exception>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
13+
14+
/**
 * Exception type thrown by RV_CHECK.
 *
 * The text returned by what() is held in `msg` and is extended by streaming
 * values into the exception with operator<<; the std::runtime_error base is
 * always constructed with an empty string.
 */
struct FRException : public std::runtime_error {
  // Accumulated message; this is what what() exposes.
  std::string msg;

  FRException() : std::runtime_error("") {}

  // Formats `s` with its stream insertion operator and appends the result to
  // the message. Returns *this so several values can be chained.
  template <typename T>
  FRException &operator<<(const T &s) {
    std::ostringstream formatted;
    formatted << s;
    msg.append(formatted.str());
    return *this;
  }

  const char *what() const noexcept override { return msg.c_str(); }
};

// Two-level stringification so that macro arguments are expanded before being
// turned into a string literal.
#define STRINGIFY_(...) #__VA_ARGS__
#define STRINGIFY(...) STRINGIFY_(__VA_ARGS__)

// RV_CHECK(cond) throws an FRException when `cond` is false. The message
// starts with the stringified condition, the line number and the file name;
// extra context can be streamed after the macro: RV_CHECK(ok) << "details".
// The single-pass `for` keeps the macro usable as one statement while leaving
// the throw expression open for the trailing `<<` operands.
#define RV_CHECK(...)                                                \
  for (bool rv_check_passed_ = (__VA_ARGS__); !rv_check_passed_;)    \
  throw FRException() << ("Check \"" STRINGIFY(__VA_ARGS__)          \
                          "\" failed at " +                          \
                          std::to_string(__LINE__) +                 \
                          " in " __FILE__ "\n > Error msg: ")
1032

1133
namespace tokenizers {
12-
// Defined in src/rwkv_world_tokenizer.cc; only an opaque handle is needed here.
struct TrieTree;

/**
 * Tokenizer backed by a word <-> id vocabulary loaded from a file.
 *
 * encode() splits its input by repeatedly taking the longest vocabulary entry
 * that prefixes the remaining text (looked up through a trie built from the
 * vocabulary).
 */
class RWKVWorldToolTokenizer {
 public:
  // Loads the vocabulary from the file at `path` (opened in binary mode) and
  // builds the lookup structures.
  RWKVWorldToolTokenizer(const std::string &path);

  // Converts `str` into token ids by greedy longest-prefix matching. An
  // FRException is thrown when no vocabulary entry matches at some position.
  std::vector<int> encode(const std::string &str) const;

  // Inverse direction: token id(s) back to text. The definitions are not part
  // of this view; presumably they look ids up in _idx2word -- TODO confirm.
  std::string decode(const std::vector<int> &ids) const;
  std::string decode(int id) const;

 private:
  std::unordered_map<std::string, int> _word2idx;
  std::unordered_map<int, std::string> _idx2word;
  // TrieTree is incomplete in this header. A std::unique_ptr member would make
  // the implicitly defined destructor ill-formed in every translation unit
  // that destroys a tokenizer without seeing TrieTree's definition.
  // std::shared_ptr captures its deleter where the object is created, so the
  // header stays usable on its own; it also accepts assignment from the
  // std::unique_ptr produced in the constructor. The trie is only assigned
  // during construction, so copies of the tokenizer share the same trie.
  std::shared_ptr<TrieTree> _tree;
};
48+
2349
} // namespace tokenizers
2450

src/rwkv_world_tokenizer.cc

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,57 @@
1313

1414
namespace tokenizers {
1515

16+
struct TrieTree {
17+
std::unordered_map<int, std::unique_ptr<TrieTree>> children;
18+
std::string word;
19+
std::optional<int> token_id;
20+
21+
TrieTree(const std::unordered_map<std::string, int>& word2id) {
22+
for (auto& pair : word2id) {
23+
add_word(pair.first, pair.second);
24+
}
25+
}
26+
27+
std::pair<std::string, int> find_longest_prefix(const std::string& str) const {
28+
std::string prefix;
29+
int token_id = -1;
30+
const TrieTree* node = this;
31+
for (int i = 0; i < str.size(); ++i) {
32+
auto it = node->children.find(str[i]);
33+
if (it == node->children.end()) {
34+
break;
35+
}
36+
node = it->second.get();
37+
RV_CHECK(node != nullptr);
38+
if (node->token_id.has_value()) {
39+
prefix = node->word;
40+
token_id = node->token_id.value();
41+
}
42+
}
43+
RV_CHECK(!prefix.empty());
44+
RV_CHECK(token_id != -1);
45+
return {prefix, token_id};
46+
}
47+
48+
private:
49+
TrieTree() = default;
50+
void add_word(const std::string& word, int token_id) {
51+
return _add_word(word, token_id, 0);
52+
}
53+
void _add_word(const std::string& word, int token_id, int idx) {
54+
if (idx == word.size()) {
55+
this->word = word;
56+
this->token_id = token_id;
57+
return;
58+
}
59+
auto& child = children[word[idx]];
60+
if (!child) {
61+
child = std::unique_ptr<TrieTree>(new TrieTree());
62+
}
63+
child->_add_word(word, token_id, idx + 1);
64+
}
65+
};
66+
1667
RWKVWorldToolTokenizer::RWKVWorldToolTokenizer(const std::string &path) {
1768
std::ifstream infile;
1869
infile.open(path, std::ios::binary | std::ios::in);
@@ -29,28 +80,17 @@ RWKVWorldToolTokenizer::RWKVWorldToolTokenizer(const std::string &path) {
2980
for (auto &pair : _idx2word) {
3081
_word2idx[pair.second] = pair.first;
3182
}
83+
_tree = std::make_unique<TrieTree>(_word2idx);
3284
}
3385

34-
std::vector<int> RWKVWorldToolTokenizer::encode(std::string_view str) const {
86+
std::vector<int> RWKVWorldToolTokenizer::encode(const std::string &str) const {
3587
std::vector<int> ids;
3688
int str_idx = 0;
37-
int word_len = 1;
38-
int id = 0;
89+
3990
while (str_idx < str.size()) {
40-
if (str_idx + word_len > str.size()) {
41-
ids.push_back(id);
42-
break;
43-
}
44-
auto substr = str.substr(str_idx, word_len);
45-
auto it = _word2idx.find(std::string(substr));
46-
if (it == _word2idx.end()) {
47-
ids.push_back(id);
48-
str_idx += (word_len - 1);
49-
word_len = 1;
50-
} else {
51-
id = it->second;
52-
word_len++;
53-
}
91+
auto [prefix, token_id] = _tree->find_longest_prefix(str.substr(str_idx));
92+
ids.push_back(token_id);
93+
str_idx += prefix.size();
5494
}
5595
return ids;
5696
}

0 commit comments

Comments
 (0)