Commit 52635d9

feat: Add from_owned/from_borrowed constructors, improve unk_token handling (#33)
1 parent c7d323d commit 52635d9

2 files changed: 144 additions & 29 deletions
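This commit factors StaticModel construction into two new public constructors: from_owned, which takes ownership of the embedding data, and from_borrowed, which borrows 'static slices with no copy. It also makes unk_token_id optional, so tokenizers that declare no unk_token (common for BPE models) now load instead of failing. A minimal sketch of the owned path, based on the signatures in this diff; the crate path, tokenizer path, and 2 x 3 table are illustrative assumptions, not part of the commit:

    // Sketch only: module path and tokenizer path are assumptions.
    use model2vec_rs::StaticModel;
    use tokenizers::Tokenizer;

    fn demo_owned() -> anyhow::Result<()> {
        let tokenizer = Tokenizer::from_file("tokenizer.json")
            .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;

        // Owned path: the Vec moves into the model and is stored as Cow::Owned.
        let table: Vec<f32> = vec![0.0; 2 * 3]; // illustrative 2 x 3 table
        let model = StaticModel::from_owned(tokenizer, table, 2, 3, true, None, None)?;
        let _embedding = model.encode_single("hello");
        Ok(())
    }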

src/model.rs

Lines changed: 117 additions & 29 deletions
@@ -1,19 +1,20 @@
 use anyhow::{anyhow, Context, Result};
 use half::f16;
 use hf_hub::api::sync::Api;
-use ndarray::Array2;
+use ndarray::{Array2, ArrayView2, CowArray, Ix2};
 use safetensors::{tensor::Dtype, SafeTensors};
 use serde_json::Value;
+use std::borrow::Cow;
 use std::{env, fs, path::Path};
 use tokenizers::Tokenizer;
 
 /// Static embedding model for Model2Vec
 #[derive(Debug, Clone)]
 pub struct StaticModel {
     tokenizer: Tokenizer,
-    embeddings: Array2<f32>,
-    weights: Option<Vec<f32>>,
-    token_mapping: Option<Vec<usize>>,
+    embeddings: CowArray<'static, f32, Ix2>,
+    weights: Option<Cow<'static, [f32]>>,
+    token_mapping: Option<Cow<'static, [usize]>>,
     normalize: bool,
     median_token_length: usize,
     unk_token_id: Option<usize>,
@@ -64,32 +65,12 @@ impl StaticModel {
         // Load the tokenizer
         let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
 
-        // Median-token-length hack for pre-truncation
-        let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
-        lens.sort_unstable();
-        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
-
         // Read normalize default from config.json
         let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
         let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
         let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
         let normalize = normalize.unwrap_or(cfg_norm);
 
-        // Serialize the tokenizer to JSON, then parse it and get the unk_token
-        let spec_json = tokenizer
-            .to_string(false)
-            .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
-        let spec: Value = serde_json::from_str(&spec_json)?;
-        let unk_token = spec
-            .get("model")
-            .and_then(|m| m.get("unk_token"))
-            .and_then(Value::as_str)
-            .unwrap_or("[UNK]");
-        let unk_token_id = tokenizer
-            .token_to_id(unk_token)
-            .ok_or_else(|| anyhow!("tokenizer claims unk_token='{unk_token}' but it isn't in the vocab"))?
-            as usize;
-
         // Load the safetensors
         let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
         let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
@@ -115,7 +96,6 @@ impl StaticModel {
             Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
             other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
         };
-        let embeddings = Array2::from_shape_vec((rows, cols), floats).context("failed to build embeddings array")?;
 
         // Load optional weights for vocabulary quantization
         let weights = match safet.tensor("weights") {
@@ -154,17 +134,125 @@ impl StaticModel {
             Err(_) => None,
         };
 
+        Self::from_owned(tokenizer, floats, rows, cols, normalize, weights, token_mapping)
+    }
+
+    /// Construct from owned data.
+    ///
+    /// # Arguments
+    /// * `tokenizer` - Pre-deserialized tokenizer
+    /// * `embeddings` - Owned f32 embedding data
+    /// * `rows` - Number of vocabulary entries
+    /// * `cols` - Embedding dimension
+    /// * `normalize` - Whether to L2-normalize output embeddings
+    /// * `weights` - Optional per-token weights for quantized models
+    /// * `token_mapping` - Optional token ID mapping for quantized models
+    pub fn from_owned(
+        tokenizer: Tokenizer,
+        embeddings: Vec<f32>,
+        rows: usize,
+        cols: usize,
+        normalize: bool,
+        weights: Option<Vec<f32>>,
+        token_mapping: Option<Vec<usize>>,
+    ) -> Result<Self> {
+        if embeddings.len() != rows * cols {
+            return Err(anyhow!(
+                "embeddings length {} != rows {} * cols {}",
+                embeddings.len(),
+                rows,
+                cols
+            ));
+        }
+
+        let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
+
+        let embeddings =
+            Array2::from_shape_vec((rows, cols), embeddings).context("failed to build embeddings array")?;
+
+        Ok(Self {
+            tokenizer,
+            embeddings: CowArray::from(embeddings),
+            weights: weights.map(Cow::Owned),
+            token_mapping: token_mapping.map(Cow::Owned),
+            normalize,
+            median_token_length,
+            unk_token_id,
+        })
+    }
+
+    /// Construct from static slices (zero-copy for embedded binary data).
+    ///
+    /// # Arguments
+    /// * `tokenizer` - Pre-deserialized tokenizer
+    /// * `embeddings` - Static f32 embedding data (borrowed, no copy)
+    /// * `rows` - Number of vocabulary entries
+    /// * `cols` - Embedding dimension
+    /// * `normalize` - Whether to L2-normalize output embeddings
+    /// * `weights` - Optional static per-token weights for quantized models
+    /// * `token_mapping` - Optional static token ID mapping for quantized models
+    #[allow(dead_code)] // Public API for external crates
+    pub fn from_borrowed(
+        tokenizer: Tokenizer,
+        embeddings: &'static [f32],
+        rows: usize,
+        cols: usize,
+        normalize: bool,
+        weights: Option<&'static [f32]>,
+        token_mapping: Option<&'static [usize]>,
+    ) -> Result<Self> {
+        if embeddings.len() != rows * cols {
+            return Err(anyhow!(
+                "embeddings length {} != rows {} * cols {}",
+                embeddings.len(),
+                rows,
+                cols
+            ));
+        }
+
+        let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
+
+        let embeddings = ArrayView2::from_shape((rows, cols), embeddings).context("failed to build embeddings view")?;
+
         Ok(Self {
             tokenizer,
-            embeddings,
-            weights,
-            token_mapping,
+            embeddings: CowArray::from(embeddings),
+            weights: weights.map(Cow::Borrowed),
+            token_mapping: token_mapping.map(Cow::Borrowed),
             normalize,
             median_token_length,
-            unk_token_id: Some(unk_token_id),
+            unk_token_id,
         })
     }
 
+    /// Compute median token length and unk_token_id from tokenizer.
+    fn compute_metadata(tokenizer: &Tokenizer) -> Result<(usize, Option<usize>)> {
+        // Median-token-length hack for pre-truncation
+        let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
+        lens.sort_unstable();
+        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
+
+        // Get unk_token from tokenizer (optional - BPE tokenizers may not have one)
+        let spec_json = tokenizer
+            .to_string(false)
+            .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
+        let spec: Value = serde_json::from_str(&spec_json)?;
+        let unk_token = spec
+            .get("model")
+            .and_then(|m| m.get("unk_token"))
+            .and_then(Value::as_str);
+        let unk_token_id = if let Some(tok) = unk_token {
+            let id = tokenizer
+                .token_to_id(tok)
+                .ok_or_else(|| anyhow!("tokenizer declares unk_token='{tok}' but it isn't in the vocab"))?;
+            Some(id as usize)
+        } else {
+            None
+        };
+
+        Ok((median_token_length, unk_token_id))
+    }
+
     /// Char-level truncation to max_tokens * median_token_length
     fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
         let max_chars = max_tokens.saturating_mul(median_len);
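from_borrowed targets embedding tables that live for the whole program, e.g. data compiled into the binary: the CowArray<'static, f32, Ix2> field holds a Cow::Borrowed view, so nothing is copied. A companion sketch under the same assumptions as the example above; the static table here is a tiny stand-in for real embedded data:

    use model2vec_rs::StaticModel; // assumed module path, as above
    use tokenizers::Tokenizer;

    // Stand-in for a real embedded table; a build script could generate this
    // so the data has a 'static lifetime and f32 alignment for free.
    static EMBEDDINGS: [f32; 6] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6];

    fn demo_borrowed(tokenizer: Tokenizer) -> anyhow::Result<StaticModel> {
        // Zero-copy: the model stores Cow::Borrowed pointing at EMBEDDINGS.
        StaticModel::from_borrowed(tokenizer, &EMBEDDINGS, 2, 3, true, None, None)
    }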

tests/test_model.rs

Lines changed: 27 additions & 0 deletions
@@ -70,3 +70,30 @@ fn test_normalization_flag_override() {
         "Without normalization override, norm should be larger"
     );
 }
+
+/// Test from_borrowed constructor (zero-copy path)
+#[test]
+fn test_from_borrowed() {
+    use safetensors::SafeTensors;
+    use std::fs;
+    use tokenizers::Tokenizer;
+
+    let path = "tests/fixtures/test-model-float32";
+    let tokenizer = Tokenizer::from_file(format!("{path}/tokenizer.json")).unwrap();
+    let bytes = fs::read(format!("{path}/model.safetensors")).unwrap();
+    let tensors = SafeTensors::deserialize(&bytes).unwrap();
+    let tensor = tensors.tensor("embeddings").unwrap();
+    let [rows, cols]: [usize; 2] = tensor.shape().try_into().unwrap();
+    let floats: Vec<f32> = tensor
+        .data()
+        .chunks_exact(4)
+        .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
+        .collect();
+
+    // Leak to get 'static lifetime (fine for tests)
+    let floats: &'static [f32] = Box::leak(floats.into_boxed_slice());
+
+    let model = StaticModel::from_borrowed(tokenizer, floats, rows, cols, true, None, None).unwrap();
+    let emb = model.encode_single("hello");
+    assert!(!emb.is_empty());
+}
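The Box::leak in this test is just the quickest way to turn runtime-loaded data into a &'static [f32]; the allocation is never freed, which is acceptable in a short-lived test process. Callers with runtime-loaded data should prefer from_owned, reserving from_borrowed for data that genuinely lives as long as the program.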
