Commit c2a8dad
fix: add local-model concurrency guard and expand model tests
- Serialize causal model predictions to avoid concurrent Qwen crashes
- Add FFI concurrency tests for Qwen and other local models
- Auto-download required models for integration tests
- Document test cache env vars in README
Parent: 9a5dd2f

4 files changed, 165 additions and 0 deletions

embeddings/README.md

Lines changed: 9 additions & 0 deletions
````diff
@@ -14,3 +14,12 @@ cargo build --lib --release
 g++ -o test examples/test.cpp -Ltarget/release -lmanticore_knn_embeddings -I. -lpthread -ldl -std=c++17
 ```
 
+## Testing
+
+Some integration tests download model files into a cache directory if they are missing. You can
+override the cache location with environment variables:
+
+- `MANTICORE_TEST_CACHE`: preferred cache path for tests
+- `MANTICORE_CACHE_PATH`: fallback cache path for tests
+
+If neither is set, tests use `./.cache/manticore` under the repo.
````
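For example, to run the test suite against a shared cache outside the repo (the path here is just an illustration):

```sh
MANTICORE_TEST_CACHE=/tmp/manticore-models cargo test --release
```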

embeddings/src/model/ffi_test.rs

Lines changed: 114 additions & 0 deletions
```diff
@@ -6,6 +6,10 @@ use std::ptr;
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::model::local::build_model_info;
+    use std::collections::HashSet;
+    use std::path::PathBuf;
+    use std::sync::{Mutex, OnceLock};
 
     // Helper function to create a C string from Rust string
     fn to_c_string(s: &str) -> CString {
@@ -20,6 +24,97 @@ mod tests {
         }
     }
 
+    fn test_cache_root() -> String {
+        std::env::var("MANTICORE_TEST_CACHE")
+            .or_else(|_| std::env::var("MANTICORE_CACHE_PATH"))
+            .unwrap_or_else(|_| format!("{}/.cache/manticore", env!("CARGO_MANIFEST_DIR")))
+    }
+
+    fn ensure_model_cached(model_id: &str, cache_path: &PathBuf) {
+        static DOWNLOADED: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
+        let downloaded = DOWNLOADED.get_or_init(|| Mutex::new(HashSet::new()));
+        let mut set = downloaded.lock().expect("model cache lock poisoned");
+        if set.contains(model_id) {
+            return;
+        }
+        std::fs::create_dir_all(cache_path).expect("failed to create model cache directory");
+        build_model_info(cache_path.clone(), model_id, "main")
+            .expect("failed to download model into cache");
+        set.insert(model_id.to_string());
+    }
+
+    fn run_concurrent_ffi_embeddings(model_id: &str) {
+        use std::sync::Arc;
+        use std::thread;
+
+        let model_id = model_id.to_string();
+        let cache_root = test_cache_root();
+        let cache_path_buf = PathBuf::from(&cache_root);
+        ensure_model_cached(&model_id, &cache_path_buf);
+
+        let model_name = to_c_string(&model_id);
+        let cache_path = to_c_string(&cache_root);
+        let api_key = to_c_string("");
+
+        let result = TextModelWrapper::load_model(
+            model_name.as_ptr(),
+            model_name.as_bytes().len(),
+            cache_path.as_ptr(),
+            cache_path.as_bytes().len(),
+            api_key.as_ptr(),
+            api_key.as_bytes().len(),
+            false,
+        );
+
+        if result.model.is_null() {
+            let error_message = if result.error.is_null() {
+                "unknown error".to_string()
+            } else {
+                unsafe {
+                    CStr::from_ptr(result.error)
+                        .to_str()
+                        .unwrap_or("unknown error")
+                        .to_string()
+                }
+            };
+            TextModelWrapper::free_model_result(result);
+            panic!("failed to load model {}: {}", model_id, error_message);
+        }
+
+        let model_ptr = result.model as usize;
+        let start = Arc::new(std::sync::Barrier::new(4));
+        let handles: Vec<_> = (0..3)
+            .map(|i| {
+                let start = Arc::clone(&start);
+                let model_ptr = model_ptr;
+                let model_id = model_id.clone();
+                thread::spawn(move || {
+                    start.wait();
+                    let text = format!("Concurrent embedding test {} - {}", model_id, i);
+                    let item = create_string_item(&text);
+                    let items = [item];
+
+                    // Safety: emulate FFI callers that share a model pointer across threads.
+                    let wrapper = unsafe {
+                        std::mem::transmute::<*mut std::ffi::c_void, TextModelWrapper>(
+                            model_ptr as *mut std::ffi::c_void,
+                        )
+                    };
+                    let vec_result =
+                        TextModelWrapper::make_vect_embeddings(&wrapper, items.as_ptr(), 1);
+                    TextModelWrapper::free_vec_result(vec_result);
+                })
+            })
+            .collect();
+
+        start.wait();
+        for handle in handles {
+            handle.join().unwrap();
+        }
+
+        TextModelWrapper::free_model_result(result);
+    }
+
     #[test]
     fn test_text_model_result_structure() {
         // Test that TextModelResult has the expected structure
@@ -367,4 +462,23 @@ mod tests {
         assert_eq!(options2.api_key, Some("sk-test456".to_string()));
         assert_eq!(options2.use_gpu, None);
     }
+
+    #[test]
+    fn test_concurrent_qwen_embeddings_via_ffi() {
+        run_concurrent_ffi_embeddings("Qwen/Qwen3-Embedding-0.6B");
+    }
+
+    #[test]
+    fn test_concurrent_other_models_via_ffi() {
+        let model_ids = [
+            "sentence-transformers/all-MiniLM-L6-v2",
+            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+            "Locutusque/TinyMistral-248M-v2",
+            "h2oai/embeddinggemma-300m",
+        ];
+
+        for model_id in model_ids {
+            run_concurrent_ffi_embeddings(model_id);
+        }
+    }
 }
```
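The test helper releases its three worker threads and the main thread from the same barrier (hence `Barrier::new(4)`), so the embedding calls genuinely overlap rather than run back-to-back. Below is a minimal, self-contained sketch of that rendezvous pattern, with a `println!` standing in for the real FFI call:

```rust
use std::sync::{Arc, Barrier};
use std::thread;

fn main() {
    // Three workers plus the main thread rendezvous here, so all
    // "requests" are issued at (almost) the same moment.
    let start = Arc::new(Barrier::new(4));
    let handles: Vec<_> = (0..3)
        .map(|i| {
            let start = Arc::clone(&start);
            thread::spawn(move || {
                start.wait(); // block until every participant has arrived
                println!("worker {i}: issuing concurrent request");
            })
        })
        .collect();

    start.wait(); // main thread joins the rendezvous, releasing the workers
    for handle in handles {
        handle.join().unwrap();
    }
}
```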

embeddings/src/model/local.rs

Lines changed: 12 additions & 0 deletions
```diff
@@ -20,6 +20,7 @@ use serde_json::Value;
 use std::cell::RefCell;
 use std::error::Error;
 use std::path::PathBuf;
+use std::sync::Mutex;
 use tokenizers::Tokenizer;
 
 /// Model architecture type - determines pooling strategy
@@ -193,6 +194,7 @@ pub struct CausalEmbeddingModel {
     max_input_len: usize,
     hidden_size: usize,
     device: Device,
+    predict_lock: Mutex<()>,
 }
 
 impl CausalEmbeddingModel {
@@ -282,6 +284,7 @@ impl CausalEmbeddingModel {
            max_input_len,
            hidden_size,
            device,
+           predict_lock: Mutex::new(()),
        })
    }
 }
@@ -345,6 +348,15 @@ impl LocalModel {
 
 impl TextModel for LocalModel {
     fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
+        let _predict_guard = match self {
+            LocalModel::Causal(m) => Some(
+                m.predict_lock
+                    .lock()
+                    .map_err(|_| LibError::ModelLoadFailed)?,
+            ),
+            LocalModel::Bert(_) => None,
+        };
+
         let (device, max_input_len) = match self {
             LocalModel::Bert(m) => (m.device.clone(), m.max_input_len),
             LocalModel::Causal(m) => (m.device.clone(), m.max_input_len),
```
embeddings/src/model/local_test.rs

Lines changed: 30 additions & 0 deletions
```diff
@@ -6,15 +6,36 @@ mod tests {
     use crate::model::TextModel;
     use crate::utils::{get_hidden_size, get_max_input_length};
     use approx::assert_abs_diff_eq;
+    use std::collections::HashSet;
     use std::path::PathBuf;
+    use std::sync::{Mutex, OnceLock};
 
     fn check_embedding_properties(embedding: &[f32], expected_len: usize) {
         assert_eq!(embedding.len(), expected_len);
         let norm: f32 = embedding.iter().map(|&x| x * x).sum::<f32>().sqrt();
         assert_abs_diff_eq!(norm, 1.0, epsilon = 1e-6);
     }
 
+    fn ensure_model_cached(model_id: &str, cache_path: &PathBuf) {
+        static DOWNLOADED: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
+        let downloaded = DOWNLOADED.get_or_init(|| Mutex::new(HashSet::new()));
+        let mut set = downloaded.lock().expect("model cache lock poisoned");
+        if set.contains(model_id) {
+            return;
+        }
+        std::fs::create_dir_all(cache_path).expect("failed to create model cache directory");
+        build_model_info(cache_path.clone(), model_id, "main")
+            .expect("failed to download model into cache");
+        set.insert(model_id.to_string());
+    }
+
     fn test_cache_path() -> PathBuf {
+        if let Ok(path) = std::env::var("MANTICORE_TEST_CACHE") {
+            return PathBuf::from(path);
+        }
+        if let Ok(path) = std::env::var("MANTICORE_CACHE_PATH") {
+            return PathBuf::from(path);
+        }
         PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(".cache/manticore")
     }
 
@@ -325,6 +346,7 @@
     fn test_all_minilm_l6_v2() {
         let model_id = "sentence-transformers/all-MiniLM-L6-v2";
         let cache_path = test_cache_path();
+        ensure_model_cached(model_id, &cache_path);
 
         let test_sentences = [
             "This is a test sentence.",
@@ -343,6 +365,7 @@
     fn test_embedding_consistency() {
         let model_id = "sentence-transformers/all-MiniLM-L6-v2";
         let cache_path = test_cache_path();
+        ensure_model_cached(model_id, &cache_path);
         let local_model = LocalModel::new(model_id, cache_path, false).unwrap();
 
         let sentence = &["This is a test sentence."];
@@ -358,6 +381,7 @@
     fn test_hidden_size() {
         let model_id = "sentence-transformers/all-MiniLM-L6-v2";
         let cache_path = test_cache_path();
+        ensure_model_cached(model_id, &cache_path);
         let local_model = LocalModel::new(model_id, cache_path, false).unwrap();
         assert_eq!(local_model.get_hidden_size(), 384);
     }
@@ -366,6 +390,7 @@
     fn test_max_input_len() {
         let model_id = "sentence-transformers/all-MiniLM-L6-v2";
         let cache_path = test_cache_path();
+        ensure_model_cached(model_id, &cache_path);
         let local_model = LocalModel::new(model_id, cache_path, false).unwrap();
         assert_eq!(local_model.get_max_input_len(), 512);
     }
@@ -375,6 +400,7 @@
         // Integration test for Qwen embedding models
         let model_id = "Qwen/Qwen3-Embedding-0.6B";
         let cache_path = test_cache_path();
+        ensure_model_cached(model_id, &cache_path);
 
         let local_model = LocalModel::new(model_id, cache_path.clone(), false)
             .expect("Qwen model should load successfully");
@@ -395,6 +421,7 @@
         // Integration test for Llama-based embedding models.
         let model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0";
         let cache_path = test_cache_path();
+        ensure_model_cached(model_id, &cache_path);
 
         let local_model =
             LocalModel::new(model_id, cache_path.clone(), false).expect("Llama model should load");
@@ -410,6 +437,7 @@
         // Integration test for Mistral-based embedding models.
         let model_id = "Locutusque/TinyMistral-248M-v2";
         let cache_path = test_cache_path();
+        ensure_model_cached(model_id, &cache_path);
 
         let local_model = LocalModel::new(model_id, cache_path.clone(), false)
             .expect("Mistral model should load");
@@ -424,6 +452,7 @@
         // Integration test for Gemma-based embedding models.
         let model_id = "h2oai/embeddinggemma-300m";
         let cache_path = test_cache_path();
+        ensure_model_cached(model_id, &cache_path);
 
         let local_model =
             LocalModel::new(model_id, cache_path.clone(), false).expect("Gemma model should load");
@@ -438,6 +467,7 @@
         // Test batch processing with Qwen model
         let model_id = "Qwen/Qwen3-Embedding-0.6B";
         let cache_path = test_cache_path();
+        ensure_model_cached(model_id, &cache_path);
 
         let result = LocalModel::new(model_id, cache_path.clone(), false);
 
```
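Both test modules use the same memoization idiom: a process-wide `OnceLock<Mutex<HashSet<String>>>` ensures each model is downloaded at most once per test binary, even when many tests request it concurrently. A minimal sketch of the idiom, with a `println!` standing in for the real `build_model_info` download:

```rust
use std::collections::HashSet;
use std::sync::{Mutex, OnceLock};

fn ensure_downloaded(model_id: &str) {
    // One set for the whole process; OnceLock makes initialization
    // race-free and the Mutex serializes the check-then-download step.
    static DONE: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
    let done = DONE.get_or_init(|| Mutex::new(HashSet::new()));
    let mut set = done.lock().expect("download registry poisoned");
    if set.insert(model_id.to_string()) {
        // First caller wins; later callers find the id already present.
        println!("downloading {model_id}...");
    }
}

fn main() {
    ensure_downloaded("Qwen/Qwen3-Embedding-0.6B");
    ensure_downloaded("Qwen/Qwen3-Embedding-0.6B"); // no second download
}
```

Because the lock is held across the download itself, concurrent callers block until the first fetch finishes instead of re-downloading the same model.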