Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added database.h5
Binary file not shown.
89 changes: 88 additions & 1 deletion rs_jaccard/src/jaccard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ use std::cmp::Ordering;
use std::io::{stdout, Write};
use anyhow::Result;
use csv;
use log::{info};
use log::info;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::collections::HashMap;
pub type CoordType = u32;
pub type BoundType = usize;
pub type ScoreType = f32;
Expand Down Expand Up @@ -701,6 +702,92 @@ pub fn precompute_similarity_candidates(
.collect()
}

/// Complete a partial distance matrix by calculating missing values
/// Missing values are indicated by NaN -- Given they will be concatenated via numpy
pub fn complete_partial_matrix(
mut matrix: Vec<Vec<ScoreType>>,
matrix_ids: &[String],
coords: &[CoordType],
bounds: &[BoundType],
sig_ids: &[String],
) -> Result<(Vec<Vec<ScoreType>>, usize)> {
let n = matrix.len();

if matrix.iter().any(|row| row.len() != n) {
anyhow::bail!("Matrix must be square");
}

// Dimensions need to match
if matrix_ids.len() != n {
anyhow::bail!("Matrix ID count ({}) doesn't match matrix size ({})", matrix_ids.len(), n);
}

//And so do sigs
if sig_ids.len() != bounds.len() - 1 {
anyhow::bail!("Signature ID count ({}) doesn't match bounds length ({})", sig_ids.len(), bounds.len() - 1);
}

// Map matrix IDs to signature indices
let sig_id_to_index: HashMap<String, usize> = sig_ids
.iter()
.enumerate()
.map(|(i, id)| (id.clone(), i))
.collect();

// Find all missing pairs
let mut missing_pairs = Vec::new();

for i in 0..n {
for j in 0..n {
if i != j && matrix[i][j].is_nan() {
if let (Some(&sig_i), Some(&sig_j)) = (
sig_id_to_index.get(&matrix_ids[i]),
sig_id_to_index.get(&matrix_ids[j])
) {
missing_pairs.push((i, j, sig_i, sig_j));
}
}
}
}

let missing_count = missing_pairs.len();
info!("Found {} missing distance pairs to calculate", missing_count);

if missing_count > 0 {
info!("Computing missing distances...");
let start_time = std::time::Instant::now();

let distances: Vec<(usize, usize, ScoreType)> = missing_pairs
.into_par_iter()
.map(|(matrix_i, matrix_j, sig_i, sig_j)| {
let begin_i = bounds[sig_i];
let end_i = bounds[sig_i + 1];
let coords_i = &coords[begin_i..end_i];

let begin_j = bounds[sig_j];
let end_j = bounds[sig_j + 1];
let coords_j = &coords[begin_j..end_j];

let distance = jaccard_distance_core(coords_i, coords_j);
(matrix_i, matrix_j, distance)
})
.collect();

// Update matrix with calculated distances
for (i, j, distance) in distances {
matrix[i][j] = distance;
// Enforce symmetry
if i != j {
matrix[j][i] = distance;
}
}

info!("Completed {} distance calculations", missing_count);
}

Ok((matrix, missing_count))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
38 changes: 36 additions & 2 deletions rs_jaccard/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use jaccard::*;
use signatures::*;
use matrix_io::{
detect_format_from_extension,
hdf5::{save_matrix as save_matrix_hdf5, StreamWriter as Hdf5StreamWriter},
hdf5::{save_matrix as save_matrix_hdf5, load_matrix as load_matrix_hdf5, StreamWriter as Hdf5StreamWriter},
csv::{save_distances, save_similar_pairs, save_matrix_with_ids as save_matrix_csv_with_ids,
save_query_ref_matrix as save_query_ref_matrix_csv, write_lsh_results},
binary::save_matrix_with_ids as save_matrix_binary_with_ids,
Expand Down Expand Up @@ -179,14 +179,27 @@ enum Commands {
#[arg(long)]
file: PathBuf,
},

/// Complete partial distance matrix by calculating missing values
CompleteMatrix {
/// Partial distance matrix with intra-species distances
#[arg(long)]
partial_matrix: PathBuf,
#[arg(long)]
signatures: PathBuf,
#[arg(short, long)]
output: PathBuf,
#[arg(short, long)]
threads: Option<usize>,
},
}

fn main() -> Result<()> {
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
let cli = Cli::parse();

match &cli.command {
Commands::Query { query, reference, bounds, output, threads, query_idx } => {
Commands::Query { query, reference, bounds, output, threads, query_idx: _ } => {
if let Some(t) = threads {
rayon::ThreadPoolBuilder::new().num_threads(*t).build_global()?;
}
Expand Down Expand Up @@ -405,6 +418,27 @@ fn main() -> Result<()> {
Commands::DebugH5 { file } => {
debug_hdf5_ids(file)?;
},

Commands::CompleteMatrix { partial_matrix, signatures, output, threads } => {
if let Some(t) = threads {
rayon::ThreadPoolBuilder::new().num_threads(*t).build_global()?;
}

info!("Loading partial matrix from {}", partial_matrix.display());
let (matrix, matrix_ids) = load_matrix_hdf5(partial_matrix)?;

info!("Loading signatures from {}", signatures.display());
let sig_data = read_signatures(signatures)?;
let (coords, bounds, sig_ids) = load_signatures_for_jaccard(&sig_data)?;

let (completed_matrix, missing_count) = complete_partial_matrix(
matrix, &matrix_ids, &coords, &bounds, &sig_ids
)?;

info!("Saving completed matrix to {}", output.display());
save_matrix_hdf5(&completed_matrix, &matrix_ids, output)?;
info!("Matrix completion finished successfully. Calculated {} missing distances.", missing_count);
},
}

Ok(())
Expand Down
18 changes: 10 additions & 8 deletions rs_jaccard/src/matrix_io/hdf5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ pub fn save_matrix(

/// Load matrix from HDF5 format
/// Returns (matrix, genome_ids)
/// Use only for testing, don't want compiler to complain
#[cfg(test)]
pub fn load_matrix(input_path: &Path) -> Result<(Vec<Vec<ScoreType>>, Vec<String>)> {
info!("Loading matrix and genome IDs from HDF5 file: {}", input_path.display());
let file = Hdf5File::open(input_path)
Expand All @@ -101,18 +99,22 @@ pub fn load_matrix(input_path: &Path) -> Result<(Vec<Vec<ScoreType>>, Vec<String

let ids_dataset = file.dataset("genome_ids")
.context("Failed to open 'genome_ids' dataset")?;
let ascii_ids = ids_dataset.read_1d::<hdf5::types::VarLenAscii>()
.context("Failed to read genome IDs")?;

// Convert to Vec<Vec<f32>> format
let matrix: Vec<Vec<ScoreType>> = matrix_array.outer_iter()
.map(|row| row.to_vec())
.collect();

// Convert ASCII IDs to String
let ids: Vec<String> = ascii_ids.iter()
.map(|ascii_id| ascii_id.to_string())
.collect();
// Read genome IDs, trying both ASCII and Unicode string formats
let ids: Vec<String> = if let Ok(ascii_ids) = ids_dataset.read_1d::<hdf5::types::VarLenAscii>() {
ascii_ids.iter().map(|ascii_id| {
String::from_utf8_lossy(ascii_id.as_bytes()).to_string()
}).collect()
} else if let Ok(unicode_ids) = ids_dataset.read_1d::<hdf5::types::VarLenUnicode>() {
unicode_ids.iter().map(|unicode_id| unicode_id.to_string()).collect()
} else {
return Err(anyhow::anyhow!("Failed to read genome IDs in any supported format"));
};

// Validate consistency
let n = matrix.len();
Expand Down
Binary file added rs_jaccard/test_data/Ecoli_Strep_signatures.h5
Binary file not shown.
Binary file not shown.
Binary file added rs_jaccard/test_data/completed_matrix.h5
Binary file not shown.
Binary file added rs_jaccard/test_data/completed_matrix_clean.h5
Binary file not shown.
Binary file added rs_jaccard/test_data/full_matrix_from_scratch.h5
Binary file not shown.
Loading