Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 37 additions & 48 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions spnl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ tokio = { version = "1.44.1", features = ["io-std", "io-util", "signal"], option
tokio-stream = { version = "0.1.18", features = ["net"], optional = true }
tokio-util = { version = "0.7.16", optional = true }
anyhow = { version = "1.0.98" }
leann-core = { version = "0.1.1", optional = true }
ndarray = { version = "0.16", optional = true }
leann-core = { version = "0.1.3", optional = true }
ndarray = { version = "0.17", optional = true }
either = { version = "1.13", optional = true }
indexmap = { version = "2.7.0", optional = true }
itertools = { version = "0.14.0", optional = true }
Expand Down
17 changes: 7 additions & 10 deletions spnl/src/augment/index/raptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,8 @@ async fn cross_index(
let params = SearchParams::default();

let summary_futures = (0..num_passages).map(|idx| {
let passage_id = &id_map[idx];
let _passage_text = passage_mgr
.get_passage(passage_id)
.get_passage_by_index(idx)
.map(|p| p.text.clone())
.unwrap_or_default();

Expand All @@ -137,12 +136,10 @@ async fn cross_index(
let similar_texts: Vec<Query> = labels
.into_iter()
.filter_map(|label| {
id_map.get(label).and_then(|id| {
passage_mgr
.get_passage(id)
.ok()
.map(|p| Query::Message(User(p.text.clone())))
})
passage_mgr
.get_passage_by_index(label)
.ok()
.map(|p| Query::Message(User(p.text.clone())))
})
.collect();

Expand Down Expand Up @@ -180,8 +177,8 @@ async fn cross_index(
.with_compact(false);

// Re-add original passages
for id in &id_map {
if let Ok(p) = passage_mgr.get_passage(id) {
for (idx, id) in id_map.iter().enumerate() {
if let Ok(p) = passage_mgr.get_passage_by_index(idx) {
let mut metadata = HashMap::new();
metadata.insert("id".to_string(), serde_json::Value::String(id.clone()));
builder.add_text(&p.text, metadata);
Expand Down
15 changes: 7 additions & 8 deletions spnl/src/augment/retrieve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,12 @@ pub async fn retrieve(

// Search for each query vector
let params = SearchParams::default();
let matching_ids: Vec<String> = body_vectors
let matching_labels: Vec<usize> = body_vectors
.into_iter()
.flat_map(|query_vec| {
let (labels, _distances) =
search_hnsw(&graph, &query_vec, max_matches, &stored_vectors, &params);
labels
.into_iter()
.filter_map(|label| id_map.get(label).cloned())
labels.into_iter()
})
.unique()
.collect();
Expand All @@ -114,16 +112,17 @@ pub async fn retrieve(
eprintln!(
"RAG fragments total_passages {} relevant_fragments {}",
passages.len(),
matching_ids.len()
matching_labels.len()
);
}

let mut d: Vec<String> = matching_ids
let mut d: Vec<String> = matching_labels
.into_iter()
.rev() // reverse so most relevant is closest to query (at end)
.filter_map(|id| {
.filter_map(|label| {
let id = id_map.get(label).map(|s| s.as_str()).unwrap_or("?");
passages
.get_passage(&id)
.get_passage_by_index(label)
.ok()
.map(|p| format!("Relevant Document {id}: {}", p.text))
})
Expand Down
6 changes: 6 additions & 0 deletions spnl/src/optimizer/hlo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,12 @@ mod tests {
return Ok(());
}

// Remove stale index sentinel so the index is always rebuilt
// with the current leann-core format.
let _ = std::fs::remove_file(
"data/spnl/default.local_google_embeddinggemma-300m.path_to_doc.txt.SimpleEmbedRetrieve.ok",
);

let model = "spnl/m"; // This should work, because we use SimpleEmbedRetrieve which won't do any generation
let q = Message(User("Hello".to_string()));
let d = "I know all about Hello and stuff";
Expand Down
Loading