@@ -7,19 +7,23 @@ use tokenizers::tokenizer::Tokenizer;
 
 pub struct TokenizerWrapper {
     tokenizer: Tokenizer,
-    encode_ids: Vec<u32>,
     decode_str: String,
     id_to_token_result: String,
 }
 
 pub type Vocab = HashMap<String, u32>;
 pub type Merges = Vec<(String, String)>;
 
+#[repr(C)]
+pub struct TokenizerEncodeResult {
+    token_ids: *mut u32,
+    len: usize,
+}
+
 impl TokenizerWrapper {
     pub fn from_str(json: &str) -> TokenizerWrapper {
         TokenizerWrapper {
             tokenizer: Tokenizer::from_str(json).unwrap().into(),
-            encode_ids: Vec::new(),
             decode_str: String::new(),
             id_to_token_result: String::new(),
         }
@@ -77,16 +81,22 @@ impl TokenizerWrapper {
             .with_decoder(byte_level);
         TokenizerWrapper {
             tokenizer: tokenizer,
-            encode_ids: Vec::new(),
             decode_str: String::new(),
             id_to_token_result: String::new(),
         }
     }
 
-    pub fn encode(&mut self, text: &str, add_special_tokens: bool) {
+    pub fn encode(&mut self, text: &str, add_special_tokens: bool) -> Vec<u32> {
         let encoded = self.tokenizer.encode(text, add_special_tokens).unwrap();
-        self.encode_ids.resize(encoded.len(), 0);
-        self.encode_ids.copy_from_slice(encoded.get_ids());
+        return encoded.get_ids().to_vec();
+    }
+
+    pub fn encode_batch(&mut self, texts: Vec<&str>, add_special_tokens: bool) -> Vec<Vec<u32>> {
+        let results = self.tokenizer.encode_batch(texts, add_special_tokens).unwrap()
+            .into_iter()
+            .map(|encoded| encoded.get_ids().to_vec())
+            .collect::<Vec<Vec<u32>>>();
+        return results;
     }
 
     pub fn decode(&mut self, ids: &[u32], skip_special_tokens: bool) {
@@ -135,22 +145,53 @@ extern "C" fn tokenizers_encode(
     input_cstr: *const u8,
     len: usize,
     add_special_tokens: i32,
+    out_result: *mut TokenizerEncodeResult,
 ) {
     unsafe {
         let input_data = std::str::from_utf8(std::slice::from_raw_parts(input_cstr, len)).unwrap();
-        (*handle).encode(input_data, add_special_tokens != 0);
+        let encoded = (*handle).encode(input_data, add_special_tokens != 0);
+        let len = encoded.len();
+        *out_result = TokenizerEncodeResult {
+            token_ids: Box::into_raw(encoded.into_boxed_slice()) as *mut u32,
+            len: len,
+        };
     }
 }
 
 #[no_mangle]
-extern "C" fn tokenizers_get_encode_ids(
+extern "C" fn tokenizers_encode_batch(
     handle: *mut TokenizerWrapper,
-    out_data: *mut *mut u32,
-    out_len: *mut usize,
+    input_cstr: *const *const u8,
+    input_len: *const usize,
+    num_seqs: usize,
+    add_special_tokens: i32,
+    out_result: *mut TokenizerEncodeResult,
 ) {
     unsafe {
-        *out_data = (*handle).encode_ids.as_mut_ptr();
-        *out_len = (*handle).encode_ids.len()
+        let input_data = (0..num_seqs)
+            .map(|i| {
+                std::str::from_utf8(std::slice::from_raw_parts(*input_cstr.offset(i as isize), *input_len.offset(i as isize))).unwrap()
+            })
+            .collect::<Vec<&str>>();
+        let encoded_batch = (*handle).encode_batch(input_data, add_special_tokens != 0);
+        for (i, encoded) in encoded_batch.into_iter().enumerate() {
+            let len = encoded.len();
+            let result = TokenizerEncodeResult {
+                token_ids: Box::into_raw(encoded.into_boxed_slice()) as *mut u32,
+                len: len,
+            };
+            *out_result.offset(i as isize) = result;
+        }
+    }
+}
+
+#[no_mangle]
+extern "C" fn tokenizers_free_encode_results(results: *mut TokenizerEncodeResult, num_seqs: usize) {
+    unsafe {
+        let slice = std::slice::from_raw_parts_mut(results, num_seqs);
+        for result in &mut *slice {
+            drop(Box::from_raw(std::slice::from_raw_parts_mut(result.token_ids, result.len)));
+        }
     }
 }
 
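
For reference, a minimal sketch (not part of this commit) of how the reworked out-parameter API could be exercised from Rust inside the same module, since tokenizers_encode, TokenizerEncodeResult, and tokenizers_free_encode_results are crate-private here; the tokenizer JSON contents and the example text are placeholders:

// Hypothetical sketch: drive the new FFI surface from Rust in the same module.
// Assumes `json` holds the contents of a HuggingFace tokenizer.json file.
fn encode_round_trip(json: &str) {
    let mut wrapper = TokenizerWrapper::from_str(json);
    let handle: *mut TokenizerWrapper = &mut wrapper;

    let text = "hello world";
    let mut result = TokenizerEncodeResult {
        token_ids: std::ptr::null_mut(),
        len: 0,
    };
    // The library allocates the id buffer (Box::into_raw) and hands it back
    // through the out parameter.
    tokenizers_encode(handle, text.as_ptr(), text.len(), 1, &mut result);

    // Copy the ids out before releasing the buffer.
    let ids = unsafe { std::slice::from_raw_parts(result.token_ids, result.len) }.to_vec();
    println!("{:?}", ids);

    // Hand ownership back so the boxed slice is dropped exactly once.
    tokenizers_free_encode_results(&mut result, 1);
}

The ownership convention is the point of the change: each encoded sequence is returned as a buffer leaked via Box::into_raw, so every call to tokenizers_encode or tokenizers_encode_batch must eventually be paired with tokenizers_free_encode_results to reclaim it.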