Skip to content

Commit 7a1e44e

Browse files
Add batch_correlation_features Rust function with GIL release
Phase 1b: batch_correlation_features() replaces the entire run_peptidoform_correlation() Python function with a single Rust call, eliminating 10 Python→Rust round trips per peptidoform. - New: rust/mumdia_rs/src/batch.rs with batch_correlation_features_impl() - All Rust functions now use py.allow_threads() for GIL release - mumdia.py dispatches run_peptidoform_correlation() to Rust when available - Phase 2 (ThreadPoolExecutor) tested but deferred: Python overhead in process_peptidoform() (Polars, DIA-NN pandas) still holds the GIL, making threading 2x slower. Threading will help once more of the pipeline moves to Rust. 16 Rust unit tests, 243 Python tests passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f8ea524 commit 7a1e44e

3 files changed

Lines changed: 312 additions & 21 deletions

File tree

mumdia.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,24 @@ def run_peptidoform_correlation(
672672
mse_avg_pred_intens_total,
673673
) = correlations_list
674674

675+
# Fast path: single Rust call replaces 10 Python→Rust round trips
676+
if _RUST_BACKEND:
677+
feature_dict = mumdia_rs.batch_correlation_features(
678+
np.asarray(correlations, dtype=np.float64),
679+
np.asarray(correlation_result_counts, dtype=np.float64),
680+
np.asarray(correlation_matrix_psm_ids, dtype=np.float64),
681+
np.asarray(correlation_matrix_frag_ids, dtype=np.float64),
682+
float(most_intens_cor),
683+
float(most_intens_cos),
684+
float(mse_avg_pred_intens),
685+
float(mse_avg_pred_intens_total),
686+
[float(x) for x in collect_distributions],
687+
[int(x) for x in collect_top],
688+
pad_size,
689+
)
690+
return pl.DataFrame(feature_dict)
691+
692+
# Fallback: Python path with 10 separate calls
675693
feature_dict = {}
676694
params = [
677695
(
@@ -737,8 +755,6 @@ def run_peptidoform_correlation(
737755
)
738756

739757
df = pl.DataFrame(feature_dict)
740-
# df.write_csv("debug/correlation_features.csv")
741-
742758
return df
743759

744760

@@ -1617,9 +1633,10 @@ def calculate_features(
16171633
# Pre-convert MS1 data to sorted numpy arrays for fast DIA-NN elution profiles
16181634
_prepare_diann_ms1(spectra_data)
16191635

1620-
# Sequential processing — all work is CPU-bound (numpy/pandas/polars) so the
1621-
# GIL makes ThreadPoolExecutor counterproductive (measured 3-6x slower than
1622-
# single-threaded due to thread contention). Sequential: ~13 it/s vs ~2 it/s.
1636+
# Sequential processing. Even with Rust GIL release, ThreadPoolExecutor is
1637+
# slower because process_peptidoform() still does significant Python work
1638+
# (Polars aggregation, DIA-NN pandas conversion, dict building) that holds
1639+
# the GIL. Threading will only help once more of the pipeline is in Rust.
16231640
pin_in = [
16241641
process_peptidoform(args)
16251642
for args in tqdm(peptidoform_args, desc="Processing peptidoforms")

rust/mumdia_rs/src/batch.rs

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
/// Batch correlation feature extraction — replaces the entire
2+
/// `run_peptidoform_correlation()` Python function with a single Rust call.
3+
/// Eliminates ~10 Python→Rust round trips and all intermediate allocations.
4+
use std::collections::HashMap;
5+
6+
use crate::percentiles::compute_percentiles_impl;
7+
use crate::topk::compute_top_impl;
8+
9+
/// Compute all correlation-based features for one peptidoform in a single call.
10+
///
11+
/// This replicates the 10 calls to `add_feature_columns_nb()` that
12+
/// `run_peptidoform_correlation()` makes in Python, returning a flat
13+
/// feature name → value map.
14+
pub fn batch_correlation_features_impl(
15+
correlations: &[f64],
16+
correlation_counts: &[f64],
17+
corr_matrix_psm: &[f64],
18+
corr_matrix_frag: &[f64],
19+
most_intens_cor: f64,
20+
most_intens_cos: f64,
21+
mse_avg: f64,
22+
mse_avg_total: f64,
23+
percentile_targets: &[f64],
24+
top_k_targets: &[usize],
25+
pad_size: usize,
26+
) -> HashMap<String, f64> {
27+
let mut features = HashMap::with_capacity(80);
28+
29+
// Helper: add percentile features
30+
let add_percentiles = |features: &mut HashMap<String, f64>,
31+
data: &[f64],
32+
prefix: &str,
33+
targets: &[f64]| {
34+
let values = compute_percentiles_impl(data, targets);
35+
for (i, &t) in targets.iter().enumerate() {
36+
let t_int = t as i64;
37+
features.insert(format!("{prefix}_{t_int}"), values[i]);
38+
}
39+
};
40+
41+
// Helper: add percentile features with index tracking
42+
let add_percentiles_with_idx = |features: &mut HashMap<String, f64>,
43+
data: &[f64],
44+
prefix: &str,
45+
targets: &[f64],
46+
idx_lookup: &[f64]| {
47+
// Sort data and track original indices for index lookup
48+
let n = data.len();
49+
if n == 0 {
50+
for &t in targets {
51+
let t_int = t as i64;
52+
features.insert(format!("{prefix}_{t_int}"), 0.0);
53+
features.insert(format!("{prefix}_{t_int}_idx"), 0.0);
54+
}
55+
return;
56+
}
57+
58+
let mut sorted = data.to_vec();
59+
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
60+
61+
for &t in targets {
62+
let t_int = t as i64;
63+
// Compute percentile value
64+
let pos = (t / 100.0) * (n as f64 - 1.0);
65+
let lower = pos as usize;
66+
let upper = if lower >= n - 1 { lower } else { lower + 1 };
67+
let weight = pos - lower as f64;
68+
let value = sorted[lower] * (1.0 - weight) + sorted[upper] * weight;
69+
features.insert(format!("{prefix}_{t_int}"), value);
70+
71+
// Find nearest index in original data for this percentile value
72+
let nearest_idx = if !idx_lookup.is_empty() && lower < idx_lookup.len() {
73+
idx_lookup[lower]
74+
} else {
75+
0.0
76+
};
77+
features.insert(format!("{prefix}_{t_int}_idx"), nearest_idx);
78+
}
79+
};
80+
81+
// Helper: add top-k features
82+
let add_top = |features: &mut HashMap<String, f64>,
83+
data: &[f64],
84+
prefix: &str,
85+
targets: &[usize],
86+
pad: usize| {
87+
let top_values = compute_top_impl(data, pad);
88+
for &t in targets {
89+
let val = if t > 0 && t <= top_values.len() {
90+
top_values[t - 1]
91+
} else {
92+
0.0
93+
};
94+
features.insert(format!("{prefix}_{t}"), val);
95+
}
96+
};
97+
98+
// === 10 feature groups matching run_peptidoform_correlation() ===
99+
100+
// 1. PSM correlation matrix distribution (percentiles)
101+
add_percentiles(
102+
&mut features,
103+
corr_matrix_psm,
104+
"distribution_correlation_matrix_psm_ids",
105+
percentile_targets,
106+
);
107+
108+
// 2. Fragment correlation matrix distribution (percentiles)
109+
add_percentiles(
110+
&mut features,
111+
corr_matrix_frag,
112+
"distribution_correlation_matrix_frag_ids",
113+
percentile_targets,
114+
);
115+
116+
// 3. Individual correlations distribution (percentiles with index tracking)
117+
add_percentiles_with_idx(
118+
&mut features,
119+
correlations,
120+
"distribution_correlation_individual",
121+
percentile_targets,
122+
correlation_counts,
123+
);
124+
125+
// 4. Top PSM correlations
126+
add_top(
127+
&mut features,
128+
corr_matrix_psm,
129+
"top_correlation_matrix_psm_ids",
130+
top_k_targets,
131+
pad_size,
132+
);
133+
134+
// 5. Top fragment correlations
135+
add_top(
136+
&mut features,
137+
corr_matrix_frag,
138+
"top_correlation_matrix_frag_ids",
139+
top_k_targets,
140+
pad_size,
141+
);
142+
143+
// 6. Apex cosine similarity (single value)
144+
features.insert("top_correlation_cos_1".to_string(), most_intens_cos);
145+
146+
// 7. Apex Pearson (overwrites cosine — matching the Python bug)
147+
features.insert("top_correlation_cos_1".to_string(), most_intens_cor);
148+
149+
// 8. MSE average
150+
features.insert("mse_avg_pred_intens_1".to_string(), mse_avg);
151+
152+
// 9. MSE total
153+
features.insert("mse_avg_pred_intens_total_1".to_string(), mse_avg_total);
154+
155+
// 10. Top individual correlations
156+
add_top(
157+
&mut features,
158+
correlations,
159+
"top_correlation_individual",
160+
top_k_targets,
161+
pad_size,
162+
);
163+
164+
features
165+
}
166+
167+
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point feature values.
    const EPS: f64 = 1e-12;

    /// Returns `true` when `actual` is within `EPS` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Five-element inputs: checks that representative keys exist and that a
    /// few values come out as expected.
    #[test]
    fn test_batch_basic() {
        let correlations = [0.9, 0.8, 0.7, 0.6, 0.5];
        let counts = [5.0, 4.0, 3.0, 2.0, 1.0];
        let psm_matrix = [0.81, 0.64, 0.49, 0.36, 0.25];
        let frag_matrix = [0.9, 0.7, 0.5, 0.3, 0.1];
        let percentiles = [0.0, 25.0, 50.0, 75.0, 100.0];
        let top_k: Vec<usize> = (1..=10).collect();

        let most_intens_cor = 0.85;
        let most_intens_cos = 0.90;
        let mse_avg = 0.1;
        let mse_avg_total = 0.15;

        let result = batch_correlation_features_impl(
            &correlations,
            &counts,
            &psm_matrix,
            &frag_matrix,
            most_intens_cor,
            most_intens_cos,
            mse_avg,
            mse_avg_total,
            &percentiles,
            &top_k,
            10,
        );

        // A sample of the feature names that must be present.
        let expected_keys = [
            "distribution_correlation_matrix_psm_ids_0",
            "distribution_correlation_matrix_psm_ids_50",
            "top_correlation_matrix_psm_ids_1",
            "top_correlation_individual_1",
            "mse_avg_pred_intens_1",
            "mse_avg_pred_intens_total_1",
        ];
        for key in expected_keys {
            assert!(result.contains_key(key));
        }

        // The Pearson value overwrites the cosine value under the shared key
        // (mirrors the Python behaviour).
        assert!(close(result["top_correlation_cos_1"], 0.85));
        assert!(close(result["mse_avg_pred_intens_1"], 0.1));

        // Rank 1 of the PSM matrix is its largest entry.
        assert!(close(result["top_correlation_matrix_psm_ids_1"], 0.81));
    }

    /// Empty input arrays must still produce zero-valued features.
    #[test]
    fn test_batch_empty_arrays() {
        let no_data: [f64; 0] = [];
        let percentiles = [0.0, 50.0, 100.0];
        let top_k: [usize; 3] = [1, 2, 3];

        let result = batch_correlation_features_impl(
            &no_data,
            &no_data,
            &no_data,
            &no_data,
            0.0,
            0.0,
            0.0,
            0.0,
            &percentiles,
            &top_k,
            10,
        );

        assert_eq!(result["distribution_correlation_matrix_psm_ids_0"], 0.0);
        assert_eq!(result["top_correlation_matrix_psm_ids_1"], 0.0);
    }
}

0 commit comments

Comments
 (0)