use anyhow::{anyhow, Context, Result};
use half::f16;
use hf_hub::api::sync::Api;
- use ndarray::Array2;
+ use ndarray::{Array2, ArrayView2, CowArray, Ix2};
use safetensors::{tensor::Dtype, SafeTensors};
use serde_json::Value;
+ use std::borrow::Cow;
use std::{env, fs, path::Path};
use tokenizers::Tokenizer;

/// Static embedding model for Model2Vec
#[derive(Debug, Clone)]
pub struct StaticModel {
    tokenizer: Tokenizer,
-     embeddings: Array2<f32>,
-     weights: Option<Vec<f32>>,
-     token_mapping: Option<Vec<usize>>,
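+     // Cow-backed fields let the model either own its tables (from_owned) or
+     // borrow them from 'static data embedded in the binary (from_borrowed).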
+     embeddings: CowArray<'static, f32, Ix2>,
+     weights: Option<Cow<'static, [f32]>>,
+     token_mapping: Option<Cow<'static, [usize]>>,
    normalize: bool,
    median_token_length: usize,
    unk_token_id: Option<usize>,
@@ -64,32 +65,12 @@ impl StaticModel {
        // Load the tokenizer
        let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;

-         // Median-token-length hack for pre-truncation
-         let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
-         lens.sort_unstable();
-         let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
-
        // Read normalize default from config.json
        let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
        let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
        let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
        let normalize = normalize.unwrap_or(cfg_norm);

-         // Serialize the tokenizer to JSON, then parse it and get the unk_token
-         let spec_json = tokenizer
-             .to_string(false)
-             .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
-         let spec: Value = serde_json::from_str(&spec_json)?;
-         let unk_token = spec
-             .get("model")
-             .and_then(|m| m.get("unk_token"))
-             .and_then(Value::as_str)
-             .unwrap_or("[UNK]");
-         let unk_token_id = tokenizer
-             .token_to_id(unk_token)
-             .ok_or_else(|| anyhow!("tokenizer claims unk_token='{unk_token}' but it isn't in the vocab"))?
-             as usize;
-
        // Load the safetensors
        let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
        let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
@@ -115,7 +96,6 @@ impl StaticModel {
            Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
            other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
        };
-         let embeddings = Array2::from_shape_vec((rows, cols), floats).context("failed to build embeddings array")?;

        // Load optional weights for vocabulary quantization
        let weights = match safet.tensor("weights") {
@@ -154,17 +134,125 @@ impl StaticModel {
            Err(_) => None,
        };

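+         // Delegate to the shared owned-data constructor, which validates the
+         // shape and derives tokenizer metadata.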
+         Self::from_owned(tokenizer, floats, rows, cols, normalize, weights, token_mapping)
+     }
+
+     /// Construct from owned data.
+     ///
+     /// # Arguments
+     /// * `tokenizer` - Pre-deserialized tokenizer
+     /// * `embeddings` - Owned f32 embedding data
+     /// * `rows` - Number of vocabulary entries
+     /// * `cols` - Embedding dimension
+     /// * `normalize` - Whether to L2-normalize output embeddings
+     /// * `weights` - Optional per-token weights for quantized models
+     /// * `token_mapping` - Optional token ID mapping for quantized models
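+     ///
+     /// # Example
+     /// A minimal sketch; the tokenizer path and the 2x3 embedding table are
+     /// hypothetical stand-ins for real model data:
+     /// ```ignore
+     /// let tokenizer = Tokenizer::from_file("tokenizer.json").unwrap();
+     /// let embeddings = vec![0.0_f32; 2 * 3]; // rows * cols entries, row-major
+     /// let model = StaticModel::from_owned(tokenizer, embeddings, 2, 3, true, None, None)?;
+     /// ```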
+     pub fn from_owned(
+         tokenizer: Tokenizer,
+         embeddings: Vec<f32>,
+         rows: usize,
+         cols: usize,
+         normalize: bool,
+         weights: Option<Vec<f32>>,
+         token_mapping: Option<Vec<usize>>,
+     ) -> Result<Self> {
+         if embeddings.len() != rows * cols {
+             return Err(anyhow!(
+                 "embeddings length {} != rows {} * cols {}",
+                 embeddings.len(),
+                 rows,
+                 cols
+             ));
+         }
+
+         let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
+
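+         // from_shape_vec interprets the Vec as row-major: one embedding row per token.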
+         let embeddings =
+             Array2::from_shape_vec((rows, cols), embeddings).context("failed to build embeddings array")?;
+
+         Ok(Self {
+             tokenizer,
+             embeddings: CowArray::from(embeddings),
+             weights: weights.map(Cow::Owned),
+             token_mapping: token_mapping.map(Cow::Owned),
+             normalize,
+             median_token_length,
+             unk_token_id,
+         })
+     }
+
+     /// Construct from static slices (zero-copy for embedded binary data).
+     ///
+     /// # Arguments
+     /// * `tokenizer` - Pre-deserialized tokenizer
+     /// * `embeddings` - Static f32 embedding data (borrowed, no copy)
+     /// * `rows` - Number of vocabulary entries
+     /// * `cols` - Embedding dimension
+     /// * `normalize` - Whether to L2-normalize output embeddings
+     /// * `weights` - Optional static per-token weights for quantized models
+     /// * `token_mapping` - Optional static token ID mapping for quantized models
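+     ///
+     /// # Example
+     /// A minimal sketch; `EMBEDDINGS` stands in for a table compiled into the
+     /// binary, and the values here are hypothetical:
+     /// ```ignore
+     /// static EMBEDDINGS: [f32; 6] = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]; // 2 rows x 3 cols
+     /// let tokenizer = Tokenizer::from_file("tokenizer.json").unwrap();
+     /// let model = StaticModel::from_borrowed(tokenizer, &EMBEDDINGS, 2, 3, true, None, None)?;
+     /// ```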
+     #[allow(dead_code)] // Public API for external crates
+     pub fn from_borrowed(
+         tokenizer: Tokenizer,
+         embeddings: &'static [f32],
+         rows: usize,
+         cols: usize,
+         normalize: bool,
+         weights: Option<&'static [f32]>,
+         token_mapping: Option<&'static [usize]>,
+     ) -> Result<Self> {
+         if embeddings.len() != rows * cols {
+             return Err(anyhow!(
+                 "embeddings length {} != rows {} * cols {}",
+                 embeddings.len(),
+                 rows,
+                 cols
+             ));
+         }
+
+         let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
+
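+         // Zero-copy: view the 'static slice as a (rows, cols) matrix without allocating.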
+         let embeddings = ArrayView2::from_shape((rows, cols), embeddings).context("failed to build embeddings view")?;
+
        Ok(Self {
            tokenizer,
-             embeddings,
-             weights,
-             token_mapping,
+             embeddings: CowArray::from(embeddings),
+             weights: weights.map(Cow::Borrowed),
+             token_mapping: token_mapping.map(Cow::Borrowed),
            normalize,
            median_token_length,
-             unk_token_id: Some(unk_token_id),
+             unk_token_id,
        })
    }

+     /// Compute median token length and unk_token_id from the tokenizer.
+     fn compute_metadata(tokenizer: &Tokenizer) -> Result<(usize, Option<usize>)> {
+         // Median-token-length hack for pre-truncation:
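+         // truncating input to max_tokens * median_token_length chars before
+         // encoding bounds tokenizer work; see truncate_str below.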
+         let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
+         lens.sort_unstable();
+         let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
+
+         // Get unk_token from the tokenizer (optional - BPE tokenizers may not have one)
+         let spec_json = tokenizer
+             .to_string(false)
+             .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
+         let spec: Value = serde_json::from_str(&spec_json)?;
+         let unk_token = spec
+             .get("model")
+             .and_then(|m| m.get("unk_token"))
+             .and_then(Value::as_str);
+         let unk_token_id = if let Some(tok) = unk_token {
+             let id = tokenizer
+                 .token_to_id(tok)
+                 .ok_or_else(|| anyhow!("tokenizer declares unk_token='{tok}' but it isn't in the vocab"))?;
+             Some(id as usize)
+         } else {
+             None
+         };
+
+         Ok((median_token_length, unk_token_id))
+     }
+
    /// Char-level truncation to max_tokens * median_token_length
    fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
        let max_chars = max_tokens.saturating_mul(median_len);