diff --git a/database.h5 b/database.h5 new file mode 100644 index 0000000..18726fb Binary files /dev/null and b/database.h5 differ diff --git a/rs_jaccard/src/jaccard.rs b/rs_jaccard/src/jaccard.rs index d2025fc..99073d7 100644 --- a/rs_jaccard/src/jaccard.rs +++ b/rs_jaccard/src/jaccard.rs @@ -3,9 +3,10 @@ use std::cmp::Ordering; use std::io::{stdout, Write}; use anyhow::Result; use csv; -use log::{info}; +use log::info; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; +use std::collections::HashMap; pub type CoordType = u32; pub type BoundType = usize; pub type ScoreType = f32; @@ -701,6 +702,92 @@ pub fn precompute_similarity_candidates( .collect() } +/// Complete a partial distance matrix by calculating missing values +/// Missing values are indicated by NaN -- Given they will be concatenated via numpy +pub fn complete_partial_matrix( + mut matrix: Vec>, + matrix_ids: &[String], + coords: &[CoordType], + bounds: &[BoundType], + sig_ids: &[String], +) -> Result<(Vec>, usize)> { + let n = matrix.len(); + + if matrix.iter().any(|row| row.len() != n) { + anyhow::bail!("Matrix must be square"); + } + + // Dimensions need to match + if matrix_ids.len() != n { + anyhow::bail!("Matrix ID count ({}) doesn't match matrix size ({})", matrix_ids.len(), n); + } + + //And so do sigs + if sig_ids.len() != bounds.len() - 1 { + anyhow::bail!("Signature ID count ({}) doesn't match bounds length ({})", sig_ids.len(), bounds.len() - 1); + } + + // Map matrix IDs to signature indices + let sig_id_to_index: HashMap = sig_ids + .iter() + .enumerate() + .map(|(i, id)| (id.clone(), i)) + .collect(); + + // Find all missing pairs + let mut missing_pairs = Vec::new(); + + for i in 0..n { + for j in 0..n { + if i != j && matrix[i][j].is_nan() { + if let (Some(&sig_i), Some(&sig_j)) = ( + sig_id_to_index.get(&matrix_ids[i]), + sig_id_to_index.get(&matrix_ids[j]) + ) { + missing_pairs.push((i, j, sig_i, sig_j)); + } + } + } + } + + let missing_count = missing_pairs.len(); + info!("Found {} missing distance pairs to calculate", missing_count); + + if missing_count > 0 { + info!("Computing missing distances..."); + let start_time = std::time::Instant::now(); + + let distances: Vec<(usize, usize, ScoreType)> = missing_pairs + .into_par_iter() + .map(|(matrix_i, matrix_j, sig_i, sig_j)| { + let begin_i = bounds[sig_i]; + let end_i = bounds[sig_i + 1]; + let coords_i = &coords[begin_i..end_i]; + + let begin_j = bounds[sig_j]; + let end_j = bounds[sig_j + 1]; + let coords_j = &coords[begin_j..end_j]; + + let distance = jaccard_distance_core(coords_i, coords_j); + (matrix_i, matrix_j, distance) + }) + .collect(); + + // Update matrix with calculated distances + for (i, j, distance) in distances { + matrix[i][j] = distance; + // Enforce symmetry + if i != j { + matrix[j][i] = distance; + } + } + + info!("Completed {} distance calculations", missing_count); + } + + Ok((matrix, missing_count)) +} + #[cfg(test)] mod tests { use super::*; diff --git a/rs_jaccard/src/main.rs b/rs_jaccard/src/main.rs index b4f0a12..2daa00e 100644 --- a/rs_jaccard/src/main.rs +++ b/rs_jaccard/src/main.rs @@ -14,7 +14,7 @@ use jaccard::*; use signatures::*; use matrix_io::{ detect_format_from_extension, - hdf5::{save_matrix as save_matrix_hdf5, StreamWriter as Hdf5StreamWriter}, + hdf5::{save_matrix as save_matrix_hdf5, load_matrix as load_matrix_hdf5, StreamWriter as Hdf5StreamWriter}, csv::{save_distances, save_similar_pairs, save_matrix_with_ids as save_matrix_csv_with_ids, save_query_ref_matrix as save_query_ref_matrix_csv, write_lsh_results}, binary::save_matrix_with_ids as save_matrix_binary_with_ids, @@ -179,6 +179,19 @@ enum Commands { #[arg(long)] file: PathBuf, }, + + /// Complete partial distance matrix by calculating missing values + CompleteMatrix { + /// Partial distance matrix with intra-species distances + #[arg(long)] + partial_matrix: PathBuf, + #[arg(long)] + signatures: PathBuf, + #[arg(short, long)] + output: PathBuf, + #[arg(short, long)] + threads: Option, + }, } fn main() -> Result<()> { @@ -186,7 +199,7 @@ fn main() -> Result<()> { let cli = Cli::parse(); match &cli.command { - Commands::Query { query, reference, bounds, output, threads, query_idx } => { + Commands::Query { query, reference, bounds, output, threads, query_idx: _ } => { if let Some(t) = threads { rayon::ThreadPoolBuilder::new().num_threads(*t).build_global()?; } @@ -405,6 +418,27 @@ fn main() -> Result<()> { Commands::DebugH5 { file } => { debug_hdf5_ids(file)?; }, + + Commands::CompleteMatrix { partial_matrix, signatures, output, threads } => { + if let Some(t) = threads { + rayon::ThreadPoolBuilder::new().num_threads(*t).build_global()?; + } + + info!("Loading partial matrix from {}", partial_matrix.display()); + let (matrix, matrix_ids) = load_matrix_hdf5(partial_matrix)?; + + info!("Loading signatures from {}", signatures.display()); + let sig_data = read_signatures(signatures)?; + let (coords, bounds, sig_ids) = load_signatures_for_jaccard(&sig_data)?; + + let (completed_matrix, missing_count) = complete_partial_matrix( + matrix, &matrix_ids, &coords, &bounds, &sig_ids + )?; + + info!("Saving completed matrix to {}", output.display()); + save_matrix_hdf5(&completed_matrix, &matrix_ids, output)?; + info!("Matrix completion finished successfully. Calculated {} missing distances.", missing_count); + }, } Ok(()) diff --git a/rs_jaccard/src/matrix_io/hdf5.rs b/rs_jaccard/src/matrix_io/hdf5.rs index c4fc601..f89080e 100644 --- a/rs_jaccard/src/matrix_io/hdf5.rs +++ b/rs_jaccard/src/matrix_io/hdf5.rs @@ -87,8 +87,6 @@ pub fn save_matrix( /// Load matrix from HDF5 format /// Returns (matrix, genome_ids) -/// Use only for testing, don't want compiler to complain -#[cfg(test)] pub fn load_matrix(input_path: &Path) -> Result<(Vec>, Vec)> { info!("Loading matrix and genome IDs from HDF5 file: {}", input_path.display()); let file = Hdf5File::open(input_path) @@ -101,18 +99,22 @@ pub fn load_matrix(input_path: &Path) -> Result<(Vec>, Vec() - .context("Failed to read genome IDs")?; // Convert to Vec> format let matrix: Vec> = matrix_array.outer_iter() .map(|row| row.to_vec()) .collect(); - // Convert ASCII IDs to String - let ids: Vec = ascii_ids.iter() - .map(|ascii_id| ascii_id.to_string()) - .collect(); + // Read genome IDs, trying both ASCII and Unicode string formats + let ids: Vec = if let Ok(ascii_ids) = ids_dataset.read_1d::() { + ascii_ids.iter().map(|ascii_id| { + String::from_utf8_lossy(ascii_id.as_bytes()).to_string() + }).collect() + } else if let Ok(unicode_ids) = ids_dataset.read_1d::() { + unicode_ids.iter().map(|unicode_id| unicode_id.to_string()).collect() + } else { + return Err(anyhow::anyhow!("Failed to read genome IDs in any supported format")); + }; // Validate consistency let n = matrix.len(); diff --git a/rs_jaccard/test_data/Ecoli_Strep_signatures.h5 b/rs_jaccard/test_data/Ecoli_Strep_signatures.h5 new file mode 100644 index 0000000..28ee5a6 Binary files /dev/null and b/rs_jaccard/test_data/Ecoli_Strep_signatures.h5 differ diff --git a/rs_jaccard/test_data/Ecoli_Strept_intra_species_only_matrix.h5 b/rs_jaccard/test_data/Ecoli_Strept_intra_species_only_matrix.h5 new file mode 100644 index 0000000..5d99460 Binary files /dev/null and b/rs_jaccard/test_data/Ecoli_Strept_intra_species_only_matrix.h5 differ diff --git a/rs_jaccard/test_data/completed_matrix.h5 b/rs_jaccard/test_data/completed_matrix.h5 new file mode 100644 index 0000000..e2ab054 Binary files /dev/null and b/rs_jaccard/test_data/completed_matrix.h5 differ diff --git a/rs_jaccard/test_data/completed_matrix_clean.h5 b/rs_jaccard/test_data/completed_matrix_clean.h5 new file mode 100644 index 0000000..e2ab054 Binary files /dev/null and b/rs_jaccard/test_data/completed_matrix_clean.h5 differ diff --git a/rs_jaccard/test_data/full_matrix_from_scratch.h5 b/rs_jaccard/test_data/full_matrix_from_scratch.h5 new file mode 100644 index 0000000..7e0c8bb Binary files /dev/null and b/rs_jaccard/test_data/full_matrix_from_scratch.h5 differ