From 6395f6ee5c7c86ce9b86197652736ecd0fd35a4b Mon Sep 17 00:00:00 2001 From: Edi Prifti Date: Wed, 25 Mar 2026 23:01:12 +0100 Subject: [PATCH 1/3] =?UTF-8?q?WIP:=20Data.y=20Vec=20=E2=86=92=20Vec=20refactor=20=E2=80=94=20data.rs,=20voting,=20mcmc,=20csv?= =?UTF-8?q?=5Freport,=20lib=20done?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/bayesian_mcmc.rs | 9 ++- src/csv_report.rs | 8 +-- src/data.rs | 91 +++++++++++++------------- src/lib.rs | 2 +- src/voting.rs | 149 ++++++++++++++++++++++--------------------- 5 files changed, 132 insertions(+), 127 deletions(-) diff --git a/src/bayesian_mcmc.rs b/src/bayesian_mcmc.rs index 113f1e8..1a9686e 100644 --- a/src/bayesian_mcmc.rs +++ b/src/bayesian_mcmc.rs @@ -403,7 +403,7 @@ impl BayesPred { let [a, b, c] = ind.get_betas(); let mut log_likelihood = 0.0; for (i_sample, z_sample) in z.iter().enumerate() { - let y_sample = self.data.y[i_sample] as f64; + let y_sample = self.data.y[i_sample]; let value = z_sample[0] * a + z_sample[1] * b + z_sample[2] * c; if y_sample == 1.0 { @@ -833,9 +833,8 @@ fn lasso_prescreen(data: &mut Data, target: usize) { } } - let y_f64: Vec = data.y.iter().map(|&v| v as f64).collect(); - let y_mean = y_f64.iter().sum::() / n_samples as f64; - let y_centered: Vec = y_f64.iter().map(|v| v - y_mean).collect(); + let y_mean = data.y.iter().sum::() / n_samples as f64; + let y_centered: Vec = data.y.iter().map(|v| v - y_mean).collect(); // Run LASSO at moderate alpha to select features let w = crate::lasso::coordinate_descent_pub(&x_cols, &y_centered, 0.01, 1.0, 500, 1e-4, None); @@ -1417,7 +1416,7 @@ pub fn mcmc( // Each Gpredomics function currently handles class 2 as unknown // As MCMC does not, remove unknown sample before analysis - let mut data = data.remove_class(2); + let mut data = data.remove_class(2.0); // Selecting features data.select_features(param); diff --git a/src/csv_report.rs b/src/csv_report.rs index 7df018b..a2bb14c 100644 --- 
a/src/csv_report.rs +++ b/src/csv_report.rs @@ -551,14 +551,14 @@ pub fn export_jury_csv(exp: &Experiment, writer: &mut impl Write) -> Result<(), // We need to recompute using the jury's predict on train data let (train_f1, train_mcc, train_ppv, train_npv, train_gmean) = { let (pred_classes, _scores) = jury.predict(&exp.train_data); - let filtered: Vec<(u8, u8)> = pred_classes + let filtered: Vec<(u8, f64)> = pred_classes .iter() .zip(exp.train_data.y.iter()) .filter(|(&p, _)| p != 2) .map(|(&p, &y)| (p, y)) .collect(); if !filtered.is_empty() { - let (preds, trues): (Vec, Vec) = filtered.into_iter().unzip(); + let (preds, trues): (Vec, Vec) = filtered.into_iter().unzip(); let (_, _, _, add) = compute_metrics_from_classes(&preds, &trues, [true; 5]); ( add.f1_score.unwrap_or(f64::NAN), @@ -586,7 +586,7 @@ pub fn export_jury_csv(exp: &Experiment, writer: &mut impl Write) -> Result<(), test_rej, ) = if let Some(ref td) = exp.test_data { let (pred_classes, scores) = jury.predict(td); - let filtered: Vec<(f64, u8, u8)> = scores + let filtered: Vec<(f64, u8, f64)> = scores .iter() .zip(pred_classes.iter()) .zip(td.y.iter()) @@ -603,7 +603,7 @@ pub fn export_jury_csv(exp: &Experiment, writer: &mut impl Write) -> Result<(), pred_classes.iter().filter(|&&c| c == 2).count() as f64 / pred_classes.len() as f64; if !filtered.is_empty() { - let (scores_f, preds_f, trues_f): (Vec, Vec, Vec) = + let (scores_f, preds_f, trues_f): (Vec, Vec, Vec) = filtered.into_iter().fold( (Vec::new(), Vec::new(), Vec::new()), |(mut s, mut p, mut t), (sc, pr, tr)| { diff --git a/src/data.rs b/src/data.rs index f8b963f..b802284 100644 --- a/src/data.rs +++ b/src/data.rs @@ -74,7 +74,7 @@ pub struct Data { pub feature_significance: HashMap, /// Sample real classes - pub y: Vec, + pub y: Vec, /// Sample names pub samples: Vec, /// Samples count @@ -230,7 +230,7 @@ impl Data { if let Some(sample_name) = fields.next() { // Second field is the target value if let Some(value) = fields.next() { - let 
target: u8 = value.parse()?; + let target: f64 = value.parse()?; y_map.insert(sample_name.to_string(), target); } } @@ -246,7 +246,7 @@ impl Data { "No y value available for {}. Setting y to 2 for this sample.", sample_name ); - &2 + &2.0 }) }) .collect(); @@ -476,13 +476,14 @@ impl Data { /// ``` pub fn inverse_classes(&mut self) { for label in &mut self.y { - match *label { - 0 => *label = 1, - 1 => *label = 0, - 2 => *label = 2, - _ => { - warn!("Unknown classes : {}. Passed.", *label); - } + if *label == 0.0 { + *label = 1.0; + } else if *label == 1.0 { + *label = 0.0; + } else if *label == 2.0 { + *label = 2.0; + } else { + warn!("Unknown classes : {}. Passed.", *label); } } @@ -750,7 +751,7 @@ impl Data { let mut count_1: usize = 0; let class_0: Vec = (0..self.sample_len) - .filter(|i| self.y[*i] == 0) + .filter(|i| self.y[*i] == 0.0) .map(|i| { if self.X.contains_key(&(i, j)) { count_0 += 1; @@ -762,7 +763,7 @@ impl Data { .collect(); let class_1: Vec = (0..self.sample_len) - .filter(|i| self.y[*i] == 1) + .filter(|i| self.y[*i] == 1.0) .map(|i| { if self.X.contains_key(&(i, j)) { count_1 += 1; @@ -837,7 +838,7 @@ impl Data { let mut count_1: usize = 0; let class_0: Vec = (0..self.sample_len) - .filter(|i| self.y[*i] == 0) + .filter(|i| self.y[*i] == 0.0) .map(|i| { if self.X.contains_key(&(i, j)) { count_0 += 1; @@ -849,7 +850,7 @@ impl Data { .collect(); let class_1: Vec = (0..self.sample_len) - .filter(|i| self.y[*i] == 1) + .filter(|i| self.y[*i] == 1.0) .map(|i| { if self.X.contains_key(&(i, j)) { count_1 += 1; @@ -973,14 +974,14 @@ impl Data { let mut class_1_values: Vec = Vec::new(); for i in 0..self.sample_len { - if self.y[i] == 0 { + if self.y[i] == 0.0 { if self.X.contains_key(&(i, j)) && self.X[&(i, j)] >= 0.0 { class_0_present += 1; class_0_values.push(self.X[&(i, j)]); } else { class_0_absent += 1; } - } else if self.y[i] == 1 { + } else if self.y[i] == 1.0 { if self.X.contains_key(&(i, j)) && self.X[&(i, j)] >= 0.0 { class_1_present += 1; 
class_1_values.push(self.X[&(i, j)]); @@ -1389,9 +1390,9 @@ impl Data { /// # use gpredomics::data::Data; /// let mut data = Data::new(); /// data.load_data("./samples/Qin2014/Xtrain.tsv", "./samples/Qin2014/Ytrain.tsv", false).unwrap(); - /// let filtered_data = data.remove_class(2); + /// let filtered_data = data.remove_class(2.0); /// ``` - pub fn remove_class(&mut self, class_to_remove: u8) -> Data { + pub fn remove_class(&mut self, class_to_remove: f64) -> Data { let indices_to_keep: Vec = self .y .iter() @@ -1699,7 +1700,7 @@ mod tests { feature_class.insert(1, 1); Data { X, - y: vec![0, 1, 0, 1, 1, 1], + y: vec![0.0, 1.0, 0.0, 1.0, 1.0, 1.0], features: vec!["feature1".to_string(), "feature2".to_string()], samples: vec![ "sample1".to_string(), @@ -1741,8 +1742,8 @@ mod tests { } } - let y: Vec = (0..num_samples) - .map(|_| if rng.gen::() > 0.5 { 1 } else { 0 }) + let y: Vec = (0..num_samples) + .map(|_| if rng.gen::() > 0.5 { 1.0 } else { 0.0 }) .collect(); Data { @@ -1784,12 +1785,12 @@ mod tests { } } - let y: Vec = (0..num_samples) + let y: Vec = (0..num_samples) .map(|sample| { if X.get(&(sample, 0)).cloned().unwrap_or(0.0) > 0.5 { - 1 + 1.0 } else { - 0 + 0.0 } }) .collect(); @@ -1834,7 +1835,7 @@ mod tests { } } - data.y = vec![0, 1, 0, 1, 0]; + data.y = vec![0.0, 1.0, 0.0, 1.0, 0.0]; data.feature_selection = feature_indices.to_vec(); data } @@ -1874,7 +1875,7 @@ mod tests { Data { X, - y: vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 0], // Vraies étiquettes + y: vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], // Vraies étiquettes features: vec!["feature1".to_string(), "feature2".to_string()], samples: vec![ "sample1".to_string(), @@ -1924,7 +1925,7 @@ mod tests { Data { X, - y: vec![1, 0, 1, 0, 0], // Vraies étiquettes + y: vec![1.0, 0.0, 1.0, 0.0, 0.0], // Vraies étiquettes features: vec!["feature1".to_string(), "feature2".to_string()], samples: vec![ "sample1".to_string(), @@ -1969,7 +1970,7 @@ mod tests { Data { X, - y: vec![0, 0, 0, 1, 0], // Vraies 
étiquettes + y: vec![0.0, 0.0, 0.0, 1.0, 0.0], // Vraies étiquettes features: vec!["feature1".to_string(), "feature2".to_string()], samples: vec![ "sample6".to_string(), @@ -2007,7 +2008,7 @@ mod tests { assert_eq!(format!("{:x}", hash), "adba327f62ffab0a8d43c1aa3a6c20e630783d3b103dd103f28b9e23ab51eb18", "the test X hash isn't the same as generated in the past, indicating a reproducibility problem linked either to the load_data function or to the modification of ./tests/X.tsv"); - assert_eq!(data_test.y, [1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1], + assert_eq!(data_test.y, [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0], "the test y are not the same as generated in the past, indicating a reproducibility problem linked either to the load_data function or to the modification of ./tests/y.tsv"); assert_eq!(data_test.features, ["msp_0001", "msp_0002", "msp_0003", "msp_0004", "msp_0005", "msp_0006", "msp_0007", "msp_0008", "msp_0009", "msp_0010"], "the test X features isn't the same as generated in the past, indicating a reproducibility problem linked either to the load_data function or to the modification of ./tests/X.tsv"); @@ -2210,7 +2211,7 @@ mod tests { ); assert_eq!( subset_data.y, - vec![0, 1], + vec![0.0, 1.0], "the subset y should be composed of the selected-samples y" ); assert_eq!( @@ -2233,7 +2234,7 @@ mod tests { let original_data = Data::test(); let subset_data = original_data.subset(vec![]); let expected_X: HashMap<(usize, usize), f64> = HashMap::new(); - let expected_y: Vec = vec![]; + let expected_y: Vec = vec![]; let expected_samples: Vec = vec![]; assert_eq!( @@ -2343,7 +2344,7 @@ mod tests { fn test_add_basic() { let mut data1 = Data { X: HashMap::from([((0, 0), 0.5), ((1, 0), 0.8)]), - y: vec![0, 1], + y: vec![0.0, 1.0], features: vec!["feature1".to_string()], samples: vec!["sample1".to_string(), 
"sample2".to_string()], feature_class: HashMap::new(), @@ -2358,7 +2359,7 @@ mod tests { let data2 = Data { X: HashMap::from([((0, 0), 0.3), ((1, 0), 0.6)]), - y: vec![1, 0], + y: vec![1.0, 0.0], features: vec!["feature1".to_string()], samples: vec!["sample3".to_string(), "sample4".to_string()], feature_class: HashMap::new(), @@ -2375,7 +2376,7 @@ mod tests { let expected_X: HashMap<(usize, usize), f64> = HashMap::from([((0, 0), 0.5), ((1, 0), 0.8), ((2, 0), 0.3), ((3, 0), 0.6)]); - let expected_y = vec![0, 1, 1, 0]; + let expected_y = vec![0.0, 1.0, 1.0, 0.0]; let expected_samples = vec![ "sample1".to_string(), "sample2".to_string(), @@ -2474,7 +2475,7 @@ mod tests { assert_eq!(data.feature_len, 3); assert_eq!(data.samples, vec!["Sample1", "Sample2", "Sample3"]); assert_eq!(data.features, vec!["Feature1", "Feature2", "Feature3"]); - assert_eq!(data.y, vec![0, 1, 1]); + assert_eq!(data.y, vec![0.0, 1.0, 1.0]); assert_eq!(data.X.get(&(0, 0)), Some(&0.5)); assert_eq!(data.X.get(&(1, 0)), None); assert_eq!(data.X.get(&(2, 1)), Some(&1.5)); @@ -2558,7 +2559,7 @@ mod tests { data.load_data(x_path.to_str().unwrap(), y_path.to_str().unwrap(), true) .unwrap(); - assert_eq!(data.y, vec![0, 1, 1]); + assert_eq!(data.y, vec![0.0, 1.0, 1.0]); assert_eq!(data.samples, vec!["Sample1", "Sample2", "Sample3"]); cleanup_test_files(&x_path, &y_path); @@ -3219,14 +3220,14 @@ mod tests { let test_ratio = 0.25; let (train, test) = data.train_test_split(test_ratio, &mut rng, None); - let orig_class0 = data.y.iter().filter(|&&y| y == 0).count(); - let orig_class1 = data.y.iter().filter(|&&y| y == 1).count(); + let orig_class0 = data.y.iter().filter(|&&y| y == 0.0).count(); + let orig_class1 = data.y.iter().filter(|&&y| y == 1.0).count(); - let train_class0 = train.y.iter().filter(|&&y| y == 0).count(); - let train_class1 = train.y.iter().filter(|&&y| y == 1).count(); + let train_class0 = train.y.iter().filter(|&&y| y == 0.0).count(); + let train_class1 = train.y.iter().filter(|&&y| y == 
1.0).count(); - let test_class0 = test.y.iter().filter(|&&y| y == 0).count(); - let test_class1 = test.y.iter().filter(|&&y| y == 1).count(); + let test_class0 = test.y.iter().filter(|&&y| y == 0.0).count(); + let test_class1 = test.y.iter().filter(|&&y| y == 1.0).count(); assert_eq!( train_class0 + test_class0, @@ -3385,7 +3386,7 @@ mod tests { // Overwrite y to make class-balanced data.y = (0..total_samples) - .map(|i| if i < total_samples / 2 { 0 } else { 1 }) + .map(|i| if i < total_samples / 2 { 0.0 } else { 1.0 }) .collect(); // Batch annotation alternating A/B @@ -3764,7 +3765,7 @@ mod tests { fn test_traintestsplit_panic_on_missing_annotations() { let mut data = Data::test_with_these_features(&[0, 1, 2, 3]); data.sample_len = 10; - data.y = vec![0, 1, 0, 1, 0, 1, 0, 1, 0, 1]; + data.y = vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]; data.sample_annotations = None; let mut rng = ChaCha8Rng::seed_from_u64(42); @@ -3776,7 +3777,7 @@ mod tests { fn test_traintestsplit_panic_on_incomplete_annotation_line() { let mut data = Data::test_with_these_features(&[0, 1, 2, 3]); data.sample_len = 3; - data.y = vec![0, 1, 1]; + data.y = vec![0.0, 1.0, 1.0]; let mut sample_tags = HashMap::new(); sample_tags.insert(0, vec!["control".to_string()]); sample_tags.insert(1, vec!["treatment".to_string(), "batch1".to_string()]); diff --git a/src/lib.rs b/src/lib.rs index 904c552..d81dc28 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -219,7 +219,7 @@ pub fn run(param: &Param, running: Arc) -> Experiment { td.inverse_classes(); } if param.general.algo == "mcmc" { - td = td.remove_class(2); + td = td.remove_class(2.0); } if data.check_compatibility(&td) { test_data = Some(td); diff --git a/src/voting.rs b/src/voting.rs index e882268..4a6e870 100644 --- a/src/voting.rs +++ b/src/voting.rs @@ -382,7 +382,7 @@ impl Jury { let (pred_classes, scores) = self.predict(data); - let filtered_data: Vec<(f64, u8, u8)> = scores + let filtered_data: Vec<(f64, u8, f64)> = scores .iter() 
.zip(pred_classes.iter()) .zip(data.y.iter()) @@ -398,7 +398,7 @@ impl Jury { let rejection_rate = self.compute_rejection_rate(&pred_classes); if !filtered_data.is_empty() { - let (scores_filtered, pred_filtered, true_filtered): (Vec, Vec, Vec) = + let (scores_filtered, pred_filtered, true_filtered): (Vec, Vec, Vec) = filtered_data.into_iter().fold( (Vec::new(), Vec::new(), Vec::new()), |(mut scores, mut preds, mut trues), (s, p, t)| { @@ -487,13 +487,13 @@ impl Jury { self.threshold_window, ); - let filtered_data: Vec<(u8, u8)> = predictions + let filtered_data: Vec<(u8, f64)> = predictions .1 .iter() .zip(predictions.0.iter()) .zip(data.y.iter()) .filter_map(|((&score, &pred_class), &true_class)| { - if score >= 0.0 && score <= 1.0 && pred_class != 2 && true_class != 2 { + if score >= 0.0 && score <= 1.0 && pred_class != 2 && true_class != 2.0 { Some((pred_class, true_class)) } else { None @@ -502,7 +502,7 @@ impl Jury { .collect(); if !filtered_data.is_empty() { - let (pred_classes, true_classes): (Vec, Vec) = + let (pred_classes, true_classes): (Vec, Vec) = filtered_data.into_iter().unzip(); let (_, sensitivity, specificity, _) = compute_metrics_from_classes(&pred_classes, &true_classes, [false; 5]); @@ -1079,7 +1079,7 @@ impl Jury { fn display_confusion_matrix( &self, predictions: &[u8], - true_labels: &[u8], + true_labels: &[f64], title: &str, ) -> String { let mut text = "".to_string(); @@ -1087,17 +1087,22 @@ impl Jury { (0, 0, 0, 0, 0, 0); for (pred, real) in predictions.iter().zip(true_labels.iter()) { - match (*pred, *real) { - (1, 1) => tp += 1, - (0, 0) => tn += 1, - (1, 0) => fp += 1, - (0, 1) => fn_ += 1, - (2, 1) => rp_abstentions += 1, - (2, 0) => rn_abstentions += 1, - _ => warn!( - "Warning: Unexpected class values pred={}, real={}", - pred, real - ), + let p = *pred; + let r = *real; + if p == 1 && r == 1.0 { + tp += 1; + } else if p == 0 && r == 0.0 { + tn += 1; + } else if p == 1 && r == 0.0 { + fp += 1; + } else if p == 0 && r == 1.0 { + fn_ 
+= 1; + } else if p == 2 && r == 1.0 { + rp_abstentions += 1; + } else if p == 2 && r == 0.0 { + rn_abstentions += 1; + } else { + warn!("Warning: Unexpected class values pred={}, real={}", p, r); } } @@ -1552,7 +1557,7 @@ impl Jury { match predicted_class { 2 => abstentions.push(i), - _ if predicted_class != real_class => errors.push(i), + _ if (predicted_class as f64) != real_class => errors.push(i), _ => correct.push(i), } } @@ -1704,7 +1709,7 @@ impl Jury { if vote == 2 { output.push_str("\x1b[90m•\x1b[0m"); } else { - let vote_display = match data.y[sample_idx] == vote { + let vote_display = match data.y[sample_idx] == vote as f64 { true => &format!("\x1b[92m{}\x1b[0m", vote), false => &format!("\x1b[31m{}\x1b[0m", vote), }; @@ -2503,7 +2508,7 @@ mod tests { } /// Helper for creating data with a single sample - fn create_single_sample_data(true_class: u8) -> Data { + fn create_single_sample_data(true_class: f64) -> Data { let mut X = HashMap::new(); X.insert((0, 0), 1.0); // Sample 0, feature 0 @@ -2539,7 +2544,7 @@ mod tests { fn test_scenario_1_unanimous_majority_for() { // 5 experts unanimously vote 1, threshold 0.5 -> decision 1 let pop = create_population_with_votes(vec![1, 1, 1, 1, 1]); - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -2566,7 +2571,7 @@ mod tests { fn test_scenario_2_unanimous_majority_against() { // 5 experts unanimously vote 0, threshold 0.5 -> decision 0 let pop = create_population_with_votes(vec![0, 0, 0, 0, 0]); - let data = create_single_sample_data(0); + let data = create_single_sample_data(0.0); let mut jury = Jury::new( &pop, @@ -2596,7 +2601,7 @@ mod tests { fn test_scenario_3_simple_majority() { // 3 votes 1, 2 votes 0, threshold 0.5 -> decision 1 let pop = create_population_with_votes(vec![1, 1, 1, 0, 0]); - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -2623,7 +2628,7 @@ mod 
tests { fn test_scenario_4_abstention_due_to_threshold_window() { // 3 votes 1, 2 votes 0, threshold 0.6, window 10% -> abstention let pop = create_population_with_votes(vec![1, 1, 1, 0, 0]); - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -2655,7 +2660,7 @@ mod tests { fn test_scenario_5_consensus_success() { // 4 votes 1, 1 vote 0, consensus threshold 0.7 -> decision 1 let pop = create_population_with_votes(vec![1, 1, 1, 1, 0]); - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -2685,7 +2690,7 @@ mod tests { fn test_scenario_6_consensus_failure() { // 3 votes 1, 2 votes 0, consensus threshold 0.8 -> abstention let pop = create_population_with_votes(vec![1, 1, 1, 0, 0]); - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -2716,7 +2721,7 @@ mod tests { fn test_scenario_7_weighted_majority() { // Votes [1,1,0,0,0] with weights [2,2,1,1,1] -> decision 1 let pop = create_population_with_votes(vec![1, 1, 0, 0, 0]); - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); // To simulate different weights, we create a Jury with min_perf // that will filter certain experts based on their performance @@ -2755,7 +2760,7 @@ mod tests { fn test_scenario_8_perfect_tie_with_window() { // 2 votes 1, 2 votes 0, seuil 0.5, window 5% -> abstention let pop = create_population_with_votes(vec![1, 1, 0, 0]); - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -2787,7 +2792,7 @@ mod tests { fn test_majority_vs_consensus_different_outcomes() { // Same population, different voting methods -> different results let pop = create_population_with_votes(vec![1, 1, 1, 0, 0]); // 60% 1 - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); // 
Majority Test (threshold 0.5) let mut jury_majority = Jury::new( @@ -2822,7 +2827,7 @@ mod tests { #[test] fn test_threshold_window_boundary_cases() { let pop = create_population_with_votes(vec![1, 1, 1, 0, 0]); // Score = 0.6 - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); // Case 1: window too small -> no abstention let mut jury1 = Jury::new( @@ -2861,7 +2866,7 @@ mod tests { let pop = create_population_with_votes(vec![1, 1, 1, 0, 0]); // 60% pour // Data with true class = 1 - let data_positive = create_single_sample_data(1); + let data_positive = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -2883,7 +2888,7 @@ mod tests { assert_eq!(rejection_rate, 0.0, "No rejection expected"); // Test on negative class - let data_negative = create_single_sample_data(0); + let data_negative = create_single_sample_data(0.0); let (_, accuracy_neg, _, specificity_neg, _, _) = jury.compute_new_metrics(&data_negative); // Short predicts 1, true class = 0 -> False Positive @@ -2911,7 +2916,7 @@ mod tests { let data = Data { X, - y: vec![1, 0, 1], + y: vec![1.0, 0.0, 1.0], features: vec!["feature1".to_string()], samples: vec![ "sample1".to_string(), @@ -2952,7 +2957,7 @@ mod tests { fn test_edge_case_single_expert() { // One expert -> no collective vote, but tests logic let pop = create_population_with_votes(vec![1]); - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -3013,7 +3018,7 @@ mod tests { pop } - fn create_multi_sample_data(true_classes: Vec) -> Data { + fn create_multi_sample_data(true_classes: Vec) -> Data { let mut X = HashMap::new(); let mut samples = Vec::new(); @@ -3044,7 +3049,7 @@ mod tests { #[test] fn test_compute_new_metrics_consistency() { let pop = create_controlled_population(vec![1, 1, 1, 0, 0]); - let data = create_multi_sample_data(vec![1, 1, 0, 0, 1]); + let data = create_multi_sample_data(vec![1.0, 1.0, 0.0, 0.0, 1.0]); let 
mut jury = Jury::new( &pop, @@ -3091,7 +3096,7 @@ mod tests { fn test_compute_new_metrics_rejection_rate_calculation() { // Population that will create abstentions with window let pop = create_controlled_population(vec![1, 1, 0, 0]); // Perfect tie at 0.5 - let data = create_multi_sample_data(vec![1, 0, 1]); + let data = create_multi_sample_data(vec![1.0, 0.0, 1.0]); let mut jury = Jury::new( &pop, @@ -3128,7 +3133,7 @@ mod tests { #[test] fn test_internal_vs_external_metrics_coherence() { let pop = create_controlled_population(vec![1, 1, 1, 1, 0]); - let data = create_multi_sample_data(vec![1, 1, 0, 0, 1, 0]); + let data = create_multi_sample_data(vec![1.0, 1.0, 0.0, 0.0, 1.0, 0.0]); let mut jury = Jury::new( &pop, @@ -3156,7 +3161,7 @@ mod tests { ); // Test on different test data - let test_data = create_multi_sample_data(vec![0, 0, 1, 1]); + let test_data = create_multi_sample_data(vec![0.0, 0.0, 1.0, 1.0]); let (test_auc, test_acc, test_sens, test_spec, _, _) = jury.compute_new_metrics(&test_data); // Metrics may differ based on different data, @@ -3186,7 +3191,7 @@ mod tests { #[test] fn test_voting_threshold_systematic_variations() { let pop = create_controlled_population(vec![1, 1, 1, 0, 0]); // 60% 1 - let data = create_multi_sample_data(vec![1]); + let data = create_multi_sample_data(vec![1.0]); let thresholds = vec![0.3, 0.5, 0.7, 0.9]; let mut results = Vec::new(); @@ -3224,7 +3229,7 @@ mod tests { #[test] fn test_threshold_window_granular_effects() { let pop = create_controlled_population(vec![1, 1, 1, 0, 0]); // Score = 0.6 - let data = create_multi_sample_data(vec![1]); + let data = create_multi_sample_data(vec![1.0]); let threshold = 0.5; let windows = vec![1.0, 5.0, 15.0, 25.0]; // 1%, 5%, 15%, 25% @@ -3257,7 +3262,7 @@ mod tests { #[test] fn test_majority_vs_consensus_systematic_comparison() { let pop = create_controlled_population(vec![1, 1, 1, 0, 0]); // 60% pour - let data = create_multi_sample_data(vec![1, 0, 1]); + let data = 
create_multi_sample_data(vec![1.0, 0.0, 1.0]); let test_cases = vec![ (0.5, 1, 1), // Majority: 0.6 > 0.5 -> 1, Consensus: 0.6 > 0.5 -> 2 @@ -3308,7 +3313,7 @@ mod tests { #[test] fn test_edge_cases_boundary_conditions() { let pop = create_controlled_population(vec![1, 1, 0, 0]); // Perfect tie - let data = create_multi_sample_data(vec![1]); + let data = create_multi_sample_data(vec![1.0]); // Threshold test at 0.0 let mut jury_min = Jury::new( @@ -3369,7 +3374,7 @@ mod tests { } let pop = create_controlled_population(large_votes); - let data = create_multi_sample_data(vec![1, 0, 1, 0]); + let data = create_multi_sample_data(vec![1.0, 0.0, 1.0, 0.0]); let mut jury = Jury::new( &pop, @@ -3427,7 +3432,7 @@ mod tests { // Create a dataset with 100 samples let mut large_classes = Vec::new(); for i in 0..100 { - large_classes.push(if i % 2 == 0 { 1 } else { 0 }); + large_classes.push(if i % 2 == 0 { 1.0 } else { 0.0 }); } let large_data = create_multi_sample_data(large_classes); @@ -3480,7 +3485,7 @@ mod tests { fn test_rejection_rate_mathematical_accuracy() { // Controlled scenario: 4 samples with 2 expected abstentions let pop = create_controlled_population(vec![1, 1, 0, 0]); // Perfect Tie 0.5 - let data = create_multi_sample_data(vec![1, 0, 1, 0]); + let data = create_multi_sample_data(vec![1.0, 0.0, 1.0, 0.0]); let mut jury = Jury::new( &pop, @@ -3538,7 +3543,7 @@ mod tests { pop.individuals[3].cls.specificity = 0.9; } - let data = create_multi_sample_data(vec![1]); + let data = create_multi_sample_data(vec![1.0]); let mut jury = Jury::new( &pop, @@ -3592,7 +3597,7 @@ mod tests { fn test_jury_additional_metrics_not_computed_when_experts_have_none() { // Create a population without additional metrics let pop = create_controlled_population(vec![1, 1, 1, 0, 0]); - let data = create_multi_sample_data(vec![1, 0, 1, 0, 1]); + let data = create_multi_sample_data(vec![1.0, 0.0, 1.0, 0.0, 1.0]); let mut jury = Jury::new( &pop, @@ -3643,7 +3648,7 @@ mod tests { 
expert.cls.additional.g_mean = Some(0.79); } - let data = create_multi_sample_data(vec![1, 0, 1, 0, 1]); + let data = create_multi_sample_data(vec![1.0, 0.0, 1.0, 0.0, 1.0]); let mut jury = Jury::new( &pop, @@ -3714,7 +3719,7 @@ mod tests { pop.individuals[0].cls.additional.f1_score = Some(0.7); pop.individuals[1].cls.additional.f1_score = Some(0.75); - let data = create_multi_sample_data(vec![1, 0, 1, 0, 1]); + let data = create_multi_sample_data(vec![1.0, 0.0, 1.0, 0.0, 1.0]); let mut jury = Jury::new( &pop, @@ -3763,8 +3768,8 @@ mod tests { expert.cls.additional.f1_score = Some(0.8); } - let train_data = create_multi_sample_data(vec![1, 0, 1, 0, 1]); - let test_data = create_multi_sample_data(vec![0, 1, 0, 1]); + let train_data = create_multi_sample_data(vec![1.0, 0.0, 1.0, 0.0, 1.0]); + let test_data = create_multi_sample_data(vec![0.0, 1.0, 0.0, 1.0]); let mut jury = Jury::new( &pop, @@ -3826,7 +3831,7 @@ mod tests { expert.cls.additional.f1_score = Some(0.6); } - let data = create_multi_sample_data(vec![1, 0, 1, 0]); + let data = create_multi_sample_data(vec![1.0, 0.0, 1.0, 0.0]); let mut jury = Jury::new( &pop, @@ -3868,7 +3873,7 @@ mod tests { expert.cls.additional.npv = Some(0.78); } - let data = create_multi_sample_data(vec![1, 1, 0, 1]); + let data = create_multi_sample_data(vec![1.0, 1.0, 0.0, 1.0]); let mut jury = Jury::new( &pop, @@ -3912,7 +3917,7 @@ mod tests { expert.cls.additional.g_mean = Some(0.73); } - let data = create_multi_sample_data(vec![1, 0, 1, 0, 1]); + let data = create_multi_sample_data(vec![1.0, 0.0, 1.0, 0.0, 1.0]); let mut jury = Jury::new( &pop, @@ -3952,7 +3957,7 @@ mod tests { expert.cls.additional.mcc = Some(0.5); } - let data = create_multi_sample_data(vec![1, 0]); + let data = create_multi_sample_data(vec![1.0, 0.0]); let mut jury = Jury::new( &pop, @@ -4010,7 +4015,7 @@ mod tests { }); } - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -4051,7 +4056,7 
@@ mod tests { }); } - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -4092,7 +4097,7 @@ mod tests { }); } - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -4134,7 +4139,7 @@ mod tests { }); } - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -4175,7 +4180,7 @@ mod tests { }); } - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -4218,7 +4223,7 @@ mod tests { }); } - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -4262,7 +4267,7 @@ mod tests { }); } - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); let mut jury = Jury::new( &pop, @@ -4305,7 +4310,7 @@ mod tests { let data = Data { X, - y: vec![1, 0], + y: vec![1.0, 0.0], features: vec!["feature1".to_string()], samples: vec!["sample1".to_string(), "sample2".to_string()], feature_class, @@ -4357,7 +4362,7 @@ mod tests { // Predictions: [1, 0, 1, 1, 0, 0, 0] // After filtering out 2s: GT=[1,0,1,0,1], Pred=[1,0,1,0,0] // TP=2, TN=1, FP=0, FN=1 => Acc=3/4=0.75, Se=2/3=0.667, Sp=1/1=1.0 - let data = create_multi_sample_data(vec![1, 0, 2, 1, 0, 2, 1]); + let data = create_multi_sample_data(vec![1.0, 0.0, 2.0, 1.0, 0.0, 2.0, 1.0]); let mut jury = Jury::new( &pop, @@ -4426,7 +4431,7 @@ mod tests { } // Create multi-sample data - let data = create_multi_sample_data(vec![1, 1, 1, 0, 0, 0, 1, 1, 0, 0]); + let data = create_multi_sample_data(vec![1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0]); let mut jury = Jury::new( &pop, @@ -4470,7 +4475,7 @@ mod tests { fn test_consensus_with_threshold_window_has_no_effect() { // Test that threshold_window parameter (used only in Majority) doesn't affect Consensus let pop = 
create_controlled_population(vec![1, 1, 1, 1, 0]); - let data = create_single_sample_data(1); + let data = create_single_sample_data(1.0); // Create two juries with different threshold_window values let mut jury_no_window = Jury::new( @@ -4533,7 +4538,7 @@ mod tests { } // Create data where different samples trigger different abstention patterns - let data = create_multi_sample_data(vec![1, 1, 0, 0, 1]); + let data = create_multi_sample_data(vec![1.0, 1.0, 0.0, 0.0, 1.0]); let mut jury = Jury::new( &pop, @@ -4626,7 +4631,7 @@ mod tests { } // Multi-sample data to test aggregation - let data = create_multi_sample_data(vec![1, 0, 1, 0, 1, 0, 1, 0]); + let data = create_multi_sample_data(vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]); let mut jury = Jury::new( &pop, @@ -4748,7 +4753,7 @@ mod tests { expert.cls.threshold = 0.5; } - let data = create_multi_sample_data(vec![1, 1, 0, 0, 1, 0]); + let data = create_multi_sample_data(vec![1.0, 1.0, 0.0, 0.0, 1.0, 0.0]); let mut jury_narrow = Jury::new( &pop, @@ -4802,7 +4807,7 @@ mod tests { (0.20, 0.80), ]; - let data = create_multi_sample_data(vec![1, 1, 0, 0, 1, 0, 1]); + let data = create_multi_sample_data(vec![1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0]); let mut prev_rejection_rate = 0.0; @@ -4852,7 +4857,7 @@ mod tests { expert.cls.threshold_ci = None; } - let data = create_multi_sample_data(vec![1, 1, 0, 0]); + let data = create_multi_sample_data(vec![1.0, 1.0, 0.0, 0.0]); let mut jury_no_ci = Jury::new( &pop_no_ci, From affb855893860b70f2ed9d5e579b0d1f4aef6139 Mon Sep 17 00:00:00 2001 From: Edi Prifti Date: Wed, 25 Mar 2026 23:22:36 +0100 Subject: [PATCH 2/3] feat: change Data.y from Vec<u8> to Vec<f64> for regression support All class labels (0/1/2) are now stored as f64 (0.0/1.0/2.0). Continuous y values (gene_count, etc.) can now be loaded for regression.
Changes across 10 files: - data.rs: field type, loading, feature selection, class operations - individual.rs: evaluation, confusion matrix, threshold optimization - population.rs: fitness computation, removed y_f64 intermediaries - utils.rs: all metric functions, bootstrap, stratification - cv.rs: fold splitting, stratification - voting.rs: confusion matrix, predictions - bayesian_mcmc.rs: posterior computation - csv_report.rs: metric export - lib.rs: remove_class calls 776 lib tests pass. Closes #15 --- src/cv.rs | 74 +++++++------ src/individual.rs | 170 ++++++++++++++-------------- src/population.rs | 15 +-- src/utils.rs | 274 +++++++++++++++++++++++----------------------- 4 files changed, 268 insertions(+), 265 deletions(-) diff --git a/src/cv.rs b/src/cv.rs index a60cdda..3d01d20 100644 --- a/src/cv.rs +++ b/src/cv.rs @@ -741,9 +741,9 @@ mod tests { assert_eq!(cv.validation_folds[1].X, X2); assert_eq!(cv.validation_folds[2].X, X3); - assert_eq!(cv.validation_folds[0].y, [0, 1, 1]); - assert_eq!(cv.validation_folds[1].y, [0, 1]); - assert_eq!(cv.validation_folds[2].y, [1]); + assert_eq!(cv.validation_folds[0].y, [0.0, 1.0, 1.0]); + assert_eq!(cv.validation_folds[1].y, [0.0, 1.0]); + assert_eq!(cv.validation_folds[2].y, [1.0]); assert_eq!( cv.validation_folds[0].samples, @@ -766,7 +766,7 @@ mod tests { fn test_cv_new_with_single_class() { let mut rng = ChaCha8Rng::seed_from_u64(42); let mut data = Data::test(); - data.y = vec![0; data.y.len()]; + data.y = vec![0.0; data.y.len()]; let outer_folds = 3; let cv = CV::new(&data, outer_folds, &mut rng); @@ -776,7 +776,7 @@ mod tests { for fold in &cv.validation_folds { for &label in &fold.y { - assert_eq!(label, 0); + assert_eq!(label, 0.0); } } } @@ -920,9 +920,9 @@ mod tests { // Define 95% of samples as class 0, 5% as class 1 let n = data.y.len(); let n_class1 = (n as f64 * 0.05).round() as usize; - data.y = vec![0; n]; + data.y = vec![0.0; n]; for i in 0..n_class1 { - data.y[i] = 1; + data.y[i] = 1.0; } let 
outer_folds = 5; @@ -930,7 +930,7 @@ mod tests { // Each fold should contain a few class 1 samples or be empty if there are too few. for fold in &cv.validation_folds { - let count_class1 = fold.y.iter().filter(|&&y| y == 1).count(); + let count_class1 = fold.y.iter().filter(|&&y| y == 1.0).count(); assert!(count_class1 <= n_class1); } @@ -938,7 +938,7 @@ mod tests { let total_class1: usize = cv .validation_folds .iter() - .map(|fold| fold.y.iter().filter(|&&y| y == 1).count()) + .map(|fold| fold.y.iter().filter(|&&y| y == 1.0).count()) .sum(); assert_eq!(total_class1, n_class1); } @@ -1614,13 +1614,13 @@ mod tests { let cv = CV::new(&data, outer_folds, &mut rng); // Count the classes in the original data - let original_class0 = data.y.iter().filter(|&&y| y == 0).count(); - let original_class1 = data.y.iter().filter(|&&y| y == 1).count(); + let original_class0 = data.y.iter().filter(|&&y| y == 0.0).count(); + let original_class1 = data.y.iter().filter(|&&y| y == 1.0).count(); // Check the distribution in each fold for (i, fold) in cv.validation_folds.iter().enumerate() { - let fold_class0 = fold.y.iter().filter(|&&y| y == 0).count(); - let fold_class1 = fold.y.iter().filter(|&&y| y == 1).count(); + let fold_class0 = fold.y.iter().filter(|&&y| y == 0.0).count(); + let fold_class1 = fold.y.iter().filter(|&&y| y == 1.0).count(); // The distribution should be roughly balanced let expected_class0 = (original_class0 + outer_folds - 1) / outer_folds; @@ -1717,8 +1717,8 @@ mod tests { let mut rng = ChaCha8Rng::seed_from_u64(42); let mut data = Data::specific_test(30, 10); - let class0_count = data.y.iter().filter(|&&y| y == 0).count(); - let _class1_count = data.y.iter().filter(|&&y| y == 1).count(); + let class0_count = data.y.iter().filter(|&&y| y == 0.0).count(); + let _class1_count = data.y.iter().filter(|&&y| y == 1.0).count(); let annotation_values: Vec = (0..30) .map(|i| { @@ -1739,7 +1739,7 @@ mod tests { let expected_size_class0 = class0_count / outer_folds; for 
(i, fold) in cv.validation_folds.iter().enumerate() { - let fold_class0 = fold.y.iter().filter(|&&y| y == 0).count(); + let fold_class0 = fold.y.iter().filter(|&&y| y == 0.0).count(); // Tolerance of +1/-1 due to integer divisions assert!( @@ -1771,8 +1771,8 @@ mod tests { } // Create exactly 30 samples of each class - let mut y: Vec = vec![0; 30]; - y.extend(vec![1; 30]); + let mut y: Vec = vec![0.0; 30]; + y.extend(vec![1.0; 30]); let mut data = Data { X, @@ -1823,7 +1823,7 @@ mod tests { for i in 0..data.sample_len { let batch = &annot.sample_tags[&i][col_idx]; - if data.y[i] == 0 { + if data.y[i] == 0.0 { if batch == "A" { batch_a_class0 += 1; } else { @@ -1864,7 +1864,7 @@ mod tests { } // Count by combination (class, batch) for this fold - if fold.y[i] == 0 { + if fold.y[i] == 0.0 { if batch == "A" { fold_batch_a_class0 += 1; } else { @@ -2046,23 +2046,23 @@ mod tests { let class0_standard = cv_standard.validation_folds[i] .y .iter() - .filter(|&&y| y == 0) + .filter(|&&y| y == 0.0) .count(); let class1_standard = cv_standard.validation_folds[i] .y .iter() - .filter(|&&y| y == 1) + .filter(|&&y| y == 1.0) .count(); let class0_stratified = cv_stratified.validation_folds[i] .y .iter() - .filter(|&&y| y == 0) + .filter(|&&y| y == 0.0) .count(); let class1_stratified = cv_stratified.validation_folds[i] .y .iter() - .filter(|&&y| y == 1) + .filter(|&&y| y == 1.0) .count(); assert_eq!( @@ -2366,7 +2366,7 @@ mod tests { // Build class labels: [0,0,...0 (40x), 1,1,...1 (40x)] data.y = (0..total_samples) - .map(|i| if i < total_samples / 2 { 0 } else { 1 }) + .map(|i| if i < total_samples / 2 { 0.0 } else { 1.0 }) .collect(); // Build batch annotations: alternating A/B within each class @@ -2407,12 +2407,16 @@ mod tests { let class = fold.y[sample_idx]; let batch = &fold_annot.sample_tags[&sample_idx][col_idx]; - match (class, batch.as_str()) { - (0, "A") => count_0_a += 1, - (0, "B") => count_0_b += 1, - (1, "A") => count_1_a += 1, - (1, "B") => count_1_b += 1, - _ 
=> panic!("Unexpected class/batch combination"), + if class == 0.0 && batch == "A" { + count_0_a += 1; + } else if class == 0.0 && batch == "B" { + count_0_b += 1; + } else if class == 1.0 && batch == "A" { + count_1_a += 1; + } else if class == 1.0 && batch == "B" { + count_1_b += 1; + } else { + panic!("Unexpected class/batch combination"); } } @@ -2492,7 +2496,7 @@ mod tests { .position(|c| c == "batch") .unwrap(); (0..fold.sample_len) - .filter(|&i| fold.y[i] == 0 && annot.sample_tags[&i][col_idx] == "A") + .filter(|&i| fold.y[i] == 0.0 && annot.sample_tags[&i][col_idx] == "A") .count() }) .sum(); @@ -2518,7 +2522,7 @@ mod tests { // Configuration : // - Class 0: 4 samples (2 batch A, 2 batch B) // - Class 1: 2 samples (1 batch A, 1 batch B) - data.y = vec![0, 0, 0, 0, 1, 1]; + data.y = vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0]; let annotation_values = vec![ "A".to_string(), @@ -2651,7 +2655,7 @@ mod tests { fn test_cv_new_stratified_by_panic_on_missing_annotations() { let mut data = Data::test_with_these_features(&[0, 1, 2, 3]); data.sample_len = 5; - data.y = vec![0, 1, 0, 1, 0]; + data.y = vec![0.0, 1.0, 0.0, 1.0, 0.0]; data.sample_annotations = None; let mut rng = ChaCha8Rng::seed_from_u64(42); @@ -2663,7 +2667,7 @@ mod tests { fn test_cv_new_stratified_by_panic_on_incomplete_annotation_line() { let mut data = Data::test_with_these_features(&[0, 1, 2, 3]); data.sample_len = 4; - data.y = vec![0, 1, 0, 1]; + data.y = vec![0.0, 1.0, 0.0, 1.0]; let mut sample_tags = HashMap::new(); sample_tags.insert(0, vec!["ctrl".to_string()]); diff --git a/src/individual.rs b/src/individual.rs index c92ec12..732babd 100644 --- a/src/individual.rs +++ b/src/individual.rs @@ -1128,13 +1128,13 @@ impl Individual { /// # use std::collections::HashMap; /// # let mut individual = Individual::new(); /// # let X: HashMap<(usize, usize), f64> = HashMap::new(); - /// # let y: Vec = vec![]; + /// # let y: Vec = vec![]; /// let auc = individual.compute_auc_from_features(&X, &y); /// ``` pub 
fn compute_auc_from_features( &mut self, X: &HashMap<(usize, usize), f64>, - y: &Vec, + y: &Vec, ) -> f64 { let value = self.evaluate_from_features(X, y.len()); self.cls.auc = compute_auc_from_value(&value, y); @@ -1221,27 +1221,24 @@ impl Individual { let value = self.evaluate(data); for (i, &pred) in value.iter().enumerate() { - match data.y[i] { - 1 => { - // Positive class - if pred >= self.cls.threshold { - tp += 1; - } else { - fn_count += 1; - } - } - 0 => { - // Negative class - if pred >= self.cls.threshold { - fp += 1; - } else { - tn += 1; - } + if data.y[i] == 1.0 { + // Positive class + if pred >= self.cls.threshold { + tp += 1; + } else { + fn_count += 1; } - 2 => { - // Unknown class, ignore + } else if data.y[i] == 0.0 { + // Negative class + if pred >= self.cls.threshold { + fp += 1; + } else { + tn += 1; } - _ => panic!("Invalid class label in y: {}", data.y[i]), + } else if data.y[i] == 2.0 { + // Unknown class, ignore + } else { + panic!("Invalid class label in y: {}", data.y[i]); } } @@ -1760,14 +1757,15 @@ impl Individual { /// ``` pub fn compute_threshold_and_metrics(&self, d: &Data) -> (f64, f64, f64, f64) { let value = self.evaluate(d); - let mut combined: Vec<(f64, u8)> = value.iter().cloned().zip(d.y.iter().cloned()).collect(); + let mut combined: Vec<(f64, f64)> = + value.iter().cloned().zip(d.y.iter().cloned()).collect(); combined.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); - let mut tp = d.y.iter().filter(|&&label| label == 1).count(); + let mut tp = d.y.iter().filter(|&&label| label == 1.0).count(); let mut fn_count = 0; let mut tn = 0; - let mut fp = d.y.iter().filter(|&&label| label == 0).count(); + let mut fp = d.y.iter().filter(|&&label| label == 0.0).count(); let mut best_threshold = 0.0; let mut best_youden_index = f64::NEG_INFINITY; @@ -1803,16 +1801,12 @@ impl Individual { best_metrics = (accuracy, sensitivity, specificity); } - match label { - 1 => { - tp -= 1; - fn_count += 1; - } - 0 => { - fp -= 1; - tn += 1; - } - _ 
=> (), + if label == 1.0 { + tp -= 1; + fn_count += 1; + } else if label == 0.0 { + fp -= 1; + tn += 1; } } @@ -2043,14 +2037,14 @@ impl Individual { let mut paired_data: Vec<_> = scores .iter() .zip(data.y.iter()) - .filter(|(_, &y)| y == 0 || y == 1) + .filter(|(_, &y)| y == 0.0 || y == 1.0) .map(|(&score, &label)| (score, label)) .collect(); paired_data .sort_unstable_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); - let total_pos = paired_data.iter().filter(|(_, y)| *y == 1).count(); + let total_pos = paired_data.iter().filter(|(_, y)| *y == 1.0).count(); let total_neg = paired_data.len() - total_pos; if total_pos == 0 || total_neg == 0 { @@ -2073,10 +2067,10 @@ impl Individual { let mut current_fn = 0; while i < paired_data.len() && (paired_data[i].0 - current_score).abs() < f64::EPSILON { - match paired_data[i].1 { - 0 => current_tn += 1, - 1 => current_fn += 1, - _ => unreachable!(), + if paired_data[i].1 == 0.0 { + current_tn += 1; + } else if paired_data[i].1 == 1.0 { + current_fn += 1; } i += 1; } @@ -3222,7 +3216,7 @@ mod tests { 0.0, compute_auc_from_value( &vec![0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64, 1.0_f64], - &vec![1_u8, 1_u8, 1_u8, 1_u8, 0_u8] + &vec![1.0_f64, 1.0_f64, 1.0_f64, 1.0_f64, 0.0_f64] ), "auc with a perfect classification and class1 < class0 should be 0.0" ); @@ -3230,7 +3224,7 @@ mod tests { 1.0, compute_auc_from_value( &vec![0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64, 1.0_f64], - &vec![0_u8, 0_u8, 0_u8, 0_u8, 1_u8] + &vec![0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64, 1.0_f64] ), "auc with a perfect classification and class0 < class1 should be 1.0" ); @@ -3238,7 +3232,7 @@ mod tests { 1.0, compute_auc_from_value( &vec![0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64, 1.0_f64], - &vec![0_u8, 0_u8, 0_u8, 0_u8, 1_u8] + &vec![0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64, 1.0_f64] ), "auc with a perfect classification and class0 < class1 should be 1.0" ); @@ -3247,17 +3241,17 @@ mod tests { 0.5, compute_auc_from_value( &vec![0.1_f64, 0.2_f64, 0.3_f64, 0.4_f64], 
- &vec![0_u8, 0_u8, 0_u8, 0_u8] + &vec![0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64] ), "auc should be equal to 0 when there is no positive class" ); - assert_eq!(0.5, compute_auc_from_value(&vec![0.5_f64, 0.6_f64, 0.7_f64, 0.8_f64], &vec![1_u8, 1_u8, 1_u8, 1_u8]), + assert_eq!(0.5, compute_auc_from_value(&vec![0.5_f64, 0.6_f64, 0.7_f64, 0.8_f64], &vec![1.0_f64, 1.0_f64, 1.0_f64, 1.0_f64]), "auc should be equal to 0 when there is no negative class to avoid positive biais in model selection"); assert_eq!( 0.4166666666666667, compute_auc_from_value( &vec![0.5_f64, 0.6_f64, 0.3_f64, 0.1_f64, 0.9_f64, 0.1_f64], - &vec![1_u8, 2_u8, 1_u8, 0_u8, 0_u8, 1_u8] + &vec![1.0_f64, 2.0_f64, 1.0_f64, 0.0_f64, 0.0_f64, 1.0_f64] ), "class 2 should be omited in AUC" ); @@ -3293,7 +3287,7 @@ mod tests { let mut ind = Individual::test(); ind.cls.threshold = 0.75; let mut data = Data::test2(); - data.y = vec![1, 0, 2, 0, 0, 0, 0, 0, 1, 0]; + data.y = vec![1.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let confusion_matrix = ind.calculate_confusion_matrix(&data); assert_eq!(confusion_matrix.3, 0, "class 2 shoudn't be classified"); } @@ -3303,7 +3297,7 @@ mod tests { fn test_calculate_confusion_matrix_invalid_class_label() { let ind = Individual::test(); let mut data = Data::test2(); - data.y = vec![1, 0, 3, 3, 3, 3, 0, 1, 0, 1]; + data.y = vec![1.0, 0.0, 3.0, 3.0, 3.0, 3.0, 0.0, 1.0, 0.0, 1.0]; let _confusion_matrix = ind.calculate_confusion_matrix(&data); } @@ -3544,7 +3538,7 @@ mod tests { let mut ind = Individual::test(); ind.cls.threshold = 0.75; let mut data = Data::test2(); - data.y = vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 0]; + data.y = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let metrics = ind.compute_metrics(&data); assert_eq!(0.5_f64, metrics.0, "bad calculation for accuracy"); assert_eq!( @@ -3562,7 +3556,7 @@ mod tests { let mut ind = Individual::test(); ind.cls.threshold = 0.75; let mut data = Data::test2(); - data.y = vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 2]; + data.y = 
vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0]; assert_eq!( ( 0.5555555555555556_f64, @@ -3587,7 +3581,9 @@ mod tests { let mut ind = Individual::test(); ind.cls.threshold = 0.75; let mut data = Data::test2(); - data.y = vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1]; + data.y = vec![ + 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, + ]; assert_eq!((0.5_f64, 0.6666666666666666_f64, 0.42857142857142855_f64, 0.0_f64, AdditionalMetrics { mcc:None, f1_score: None, npv: None, ppv: None, g_mean: None}), ind.compute_metrics(&data), "when ind.sample_len < data.sample_len (or y.len() if it does not match), only the ind.sample_len values should be used to calculate its metrics"); } @@ -3597,7 +3593,7 @@ mod tests { let mut ind = Individual::test(); ind.cls.threshold = 0.75; let mut data = Data::test2(); - data.y = vec![1, 0, 1, 1]; + data.y = vec![1.0, 0.0, 1.0, 1.0]; assert_eq!((0.25_f64, 0.3333333333333333_f64, 0.0_f64, 0.0_f64, AdditionalMetrics { mcc:None, f1_score: None, npv: None, ppv: None, g_mean: None}), ind.compute_metrics(&data), "when data.sample_len (or y.len() if it does not match) < ind.sample_len, only the data.sample_len values should be used to calculate its metrics"); } @@ -3608,7 +3604,7 @@ mod tests { fn test_compute_threshold_and_metrics_basic() { let ind = Individual::test(); let mut data = Data::test2(); - data.y = vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 0]; + data.y = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let results = ind.compute_threshold_and_metrics(&data); assert_eq!(0.89_f64, results.0, "bad identification of the threshold"); assert_eq!(0.8_f64, results.1, "bad calculation for accuracy"); @@ -3634,7 +3630,7 @@ mod tests { fn test_compute_threshold_and_metrics_class_2() { let ind = Individual::test(); let mut data = Data::test2(); - data.y = vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 2]; + data.y = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0]; assert_eq!( ( 0.79_f64, @@ -3669,7 +3665,7 @@ mod tests { let mut ind = 
Individual::test(); ind.cls.threshold = 0.75; let mut data = Data::test2(); - data.y = vec![1, 0, 1, 1]; + data.y = vec![1.0, 0.0, 1.0, 1.0]; assert_eq!((0.89_f64, 0.5_f64, 0.3333333333333333_f64, 1.0_f64), ind.compute_threshold_and_metrics(&data), "when data.sample_len (or y.len() if it does not match) < ind.sample_len, only the data.sample_len values should be used to calculate its metrics"); @@ -4172,7 +4168,7 @@ mod tests { let feature_seeds = generate_feature_seeds(&features_to_process, 10, 456); // Test: All positive labels (y = 1) - data.y = vec![1u8; 50]; + data.y = vec![1.0; 50]; let result_all_pos = individual.compute_mda_feature_importance( &data, 10, @@ -4190,7 +4186,7 @@ mod tests { } // Test: All negative labels (y = 0) - data.y = vec![0u8; 50]; + data.y = vec![0.0; 50]; let result_all_neg = individual.compute_mda_feature_importance( &data, 10, @@ -4207,8 +4203,8 @@ mod tests { } // Test: Highly unbalanced (49:1) - data.y = vec![0u8; 49]; - data.y.push(1u8); + data.y = vec![0.0; 49]; + data.y.push(1.0); let result_unbalanced = individual.compute_mda_feature_importance( &data, 5, @@ -4521,9 +4517,9 @@ mod tests { } data.classes = vec!["healthy".to_string(), "cirrhosis".to_string()]; - data.y[3] = 2 as u8; - data.y[4] = 2 as u8; - data_test.y[7] = 2 as u8; + data.y[3] = 2.0; + data.y[4] = 2.0; + data_test.y[7] = 2.0; // control both metrics and display let right_string = "Ternary:Log [k=66] [gen:0] [fit:0.000] AUC 0.962/0.895 | accuracy 0.921/0.828 | sensitivity 0.937/0.867 | specificity 0.904/0.786\n\ @@ -5504,7 +5500,7 @@ mod tests { ind.cls.additional.mcc = Some(0.0); // Signal that we want MCC computed let mut data = Data::test2(); - data.y = vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 0]; + data.y = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let (_, _, _, _, additional) = ind.compute_metrics(&data); @@ -5527,7 +5523,7 @@ mod tests { ind.cls.additional.f1_score = Some(0.0); // Signal that we want F1 computed let mut data = Data::test2(); - data.y = 
vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 0]; + data.y = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let (_, _, _, _, additional) = ind.compute_metrics(&data); @@ -5550,7 +5546,7 @@ mod tests { ind.cls.additional.npv = Some(0.0); // Signal that we want NPV computed let mut data = Data::test2(); - data.y = vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 0]; + data.y = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let (_, _, _, _, additional) = ind.compute_metrics(&data); @@ -5573,7 +5569,7 @@ mod tests { ind.cls.additional.ppv = Some(0.0); // Signal that we want PPV computed let mut data = Data::test2(); - data.y = vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 0]; + data.y = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let (_, _, _, _, additional) = ind.compute_metrics(&data); @@ -5596,7 +5592,7 @@ mod tests { ind.cls.additional.g_mean = Some(0.0); // Signal that we want G-mean computed let mut data = Data::test2(); - data.y = vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 0]; + data.y = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let (_, _, _, _, additional) = ind.compute_metrics(&data); @@ -5624,7 +5620,7 @@ mod tests { ind.cls.additional.g_mean = Some(0.0); let mut data = Data::test2(); - data.y = vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 0]; + data.y = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let (_, _, _, _, additional) = ind.compute_metrics(&data); @@ -5645,7 +5641,7 @@ mod tests { // Don't request any additional metrics let mut data = Data::test2(); - data.y = vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 0]; + data.y = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let (_, _, _, _, additional) = ind.compute_metrics(&data); @@ -5670,7 +5666,7 @@ mod tests { let mut data = Data::new(); data.sample_len = 10; - data.y = vec![1, 1, 1, 1, 1, 0, 0, 0, 0, 0]; + data.y = vec![1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]; data.feature_len = 1; // Create perfect separation: class 1 has high scores, class 0 has low scores @@ -5722,7 +5718,7 @@ mod tests { 
ind.cls.additional.mcc = Some(0.0); let mut data = Data::test2(); - data.y = vec![1, 1, 1, 1, 1, 0, 0, 0, 0, 0]; + data.y = vec![1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]; let (accuracy, _, _, _, additional) = ind.compute_metrics(&data); @@ -5751,8 +5747,8 @@ mod tests { let mut data = Data::test2(); // 99 negatives, 1 positive - data.y = vec![0; 99]; - data.y.push(1); + data.y = vec![0.0; 99]; + data.y.push(1.0); data.sample_len = 100; // Modify X to have 100 samples @@ -5789,8 +5785,8 @@ mod tests { let mut data = Data::new(); // 1 negative, 99 positives - data.y = vec![1; 99]; - data.y.insert(0, 0); + data.y = vec![1.0; 99]; + data.y.insert(0, 0.0); data.sample_len = 100; data.feature_len = 2; @@ -5821,7 +5817,7 @@ mod tests { ind.cls.additional.f1_score = Some(0.0); let mut data = Data::test2(); - data.y = vec![1, 0, 2, 2, 0, 0, 2, 0, 1, 0]; // Include class 2 + data.y = vec![1.0, 0.0, 2.0, 2.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0]; // Include class 2 let (_, _, _, _, additional) = ind.compute_metrics(&data); @@ -5845,7 +5841,7 @@ mod tests { ind.cls.additional.ppv = Some(0.0); let mut data = Data::test2(); - data.y = vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 0]; + data.y = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let (tp, fp, _, _) = ind.calculate_confusion_matrix(&data); let (_, _, _, _, additional) = ind.compute_metrics(&data); @@ -5872,7 +5868,7 @@ mod tests { ind.cls.additional.npv = Some(0.0); let mut data = Data::test2(); - data.y = vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 0]; + data.y = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let (_, _, tn, fn_count) = ind.calculate_confusion_matrix(&data); let (_, _, _, _, additional) = ind.compute_metrics(&data); @@ -5899,7 +5895,7 @@ mod tests { ind.cls.additional.g_mean = Some(0.0); let mut data = Data::test2(); - data.y = vec![1, 0, 1, 0, 0, 0, 0, 0, 1, 0]; + data.y = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let (_, sensitivity, specificity, _, additional) = ind.compute_metrics(&data); 
@@ -6639,7 +6635,7 @@ mod tests { let mut data = Data::test2(); data.sample_len = 1; - data.y = vec![1]; + data.y = vec![1.0]; let (_, _, _, _, additional) = ind.compute_metrics(&data); @@ -6660,7 +6656,7 @@ mod tests { let mut data = Data::test2(); data.sample_len = 2; - data.y = vec![0, 1]; + data.y = vec![0.0, 1.0]; let (_, _, _, _, additional) = ind.compute_metrics(&data); @@ -6681,7 +6677,7 @@ mod tests { ind.cls.additional.mcc = Some(0.0); let mut data = Data::test2(); - data.y = vec![0; 10]; + data.y = vec![0.0; 10]; let (_, sensitivity, specificity, _, additional) = ind.compute_metrics(&data); @@ -6712,7 +6708,7 @@ mod tests { ind.cls.additional.mcc = Some(0.0); let mut data = Data::test2(); - data.y = vec![1; 10]; + data.y = vec![1.0; 10]; let (_, sensitivity, specificity, _, additional) = ind.compute_metrics(&data); @@ -6743,7 +6739,7 @@ mod tests { let mut data = Data::new(); data.sample_len = 10; - data.y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; + data.y = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]; data.feature_len = 1; // All scores very close to threshold @@ -6772,7 +6768,7 @@ mod tests { let mut data = Data::new(); data.sample_len = 6; - data.y = vec![0, 0, 0, 1, 1, 1]; + data.y = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]; data.feature_len = 1; // Extreme scores @@ -6803,8 +6799,8 @@ mod tests { }); let mut data = Data::test2(); - data.y = vec![0; 100]; - data.y.push(1); + data.y = vec![0.0; 100]; + data.y.push(1.0); data.sample_len = 101; // Create X for 101 samples diff --git a/src/population.rs b/src/population.rs index ebd5f55..06c47d9 100644 --- a/src/population.rs +++ b/src/population.rs @@ -511,16 +511,13 @@ impl Population { i.fit = i.cls.auc; } FitFunction::spearman => { - let y_f64: Vec = data.y.iter().map(|&v| v as f64).collect(); - i.fit = crate::utils::spearman_correlation(&scores, &y_f64); + i.fit = crate::utils::spearman_correlation(&scores, &data.y); } FitFunction::rmse => { - let y_f64: Vec = data.y.iter().map(|&v| v as 
f64).collect(); - i.fit = crate::utils::neg_rmse(&scores, &y_f64); + i.fit = crate::utils::neg_rmse(&scores, &data.y); } FitFunction::mutual_information => { - let y_f64: Vec = data.y.iter().map(|&v| v as f64).collect(); - i.fit = crate::utils::mutual_information(&scores, &y_f64); + i.fit = crate::utils::mutual_information(&scores, &data.y); } _ => { if let Some(ref mut threshold_ci) = i.cls.threshold_ci { @@ -2513,8 +2510,8 @@ mod tests { X.insert((sample, feature), (sample * feature) as f64 * 0.1); } } - let y: Vec = (0..num_samples) - .map(|i| if i % 2 == 0 { 1 } else { 0 }) + let y: Vec = (0..num_samples) + .map(|i| if i % 2 == 0 { 1.0 } else { 0.0 }) .collect(); Data { X, @@ -3504,7 +3501,7 @@ mod tests { data.X.insert((3, 0), 0.2); data.X.insert((3, 1), 0.8); - data.y = vec![1, 0, 1, 0]; + data.y = vec![1.0, 0.0, 1.0, 0.0]; data.sample_len = 4; data.feature_len = 2; data.features = vec!["feature1".to_string(), "feature2".to_string()]; diff --git a/src/utils.rs b/src/utils.rs index 9076316..04f67c9 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -398,20 +398,20 @@ pub fn conf_inter_binomial_method( /// ``` /// # use gpredomics::utils::compute_auc_from_value; /// let scores = vec![0.1, 0.4, 0.35, 0.8]; -/// let labels = vec![0, 0, 1, 1]; +/// let labels = vec![0.0, 0.0, 1.0, 1.0]; /// let auc = compute_auc_from_value(&scores, &labels); /// assert!((auc - 0.75).abs() < 1e-6); /// ``` -pub fn compute_auc_from_value(value: &[f64], y: &Vec) -> f64 { - let mut data: Vec<(f64, u8)> = value +pub fn compute_auc_from_value(value: &[f64], y: &[f64]) -> f64 { + let mut data: Vec<(f64, f64)> = value .iter() .zip(y.iter()) - .filter(|(_, &label)| label == 0 || label == 1) + .filter(|(_, &label)| label == 0.0 || label == 1.0) .map(|(&v, &y)| (v, y)) .collect(); let n = data.len(); - let n1 = data.iter().filter(|(_, label)| *label == 1).count(); + let n1 = data.iter().filter(|(_, label)| *label == 1.0).count(); let n0 = n - n1; if n1 == 0 || n0 == 0 { @@ -431,7 +431,7 @@ pub 
fn compute_auc_from_value(value: &[f64], y: &Vec) -> f64 { let mut neg_equal = 0; while i < n && data[i].0 == score { - if data[i].1 == 1 { + if data[i].1 == 1.0 { pos_equal += 1; } else { neg_equal += 1; @@ -469,7 +469,7 @@ pub fn compute_auc_from_value(value: &[f64], y: &Vec) -> f64 { /// # use gpredomics::utils::compute_metrics_from_classes; /// # use gpredomics::individual::AdditionalMetrics; /// let predicted = vec![1, 0, 1, 1, 0, 2]; -/// let y = vec![1, 0, 0, 1, 0, 1]; +/// let y = vec![1.0, 0.0, 0.0, 1.0, 0.0, 1.0]; /// let others_to_compute = [true, true, false, true, true]; /// let (accuracy, sensitivity, specificity, additional) = /// compute_metrics_from_classes(&predicted, &y, others_to_compute); @@ -484,7 +484,7 @@ pub fn compute_auc_from_value(value: &[f64], y: &Vec) -> f64 { /// ``` pub fn compute_metrics_from_classes( predicted: &Vec, - y: &Vec, + y: &[f64], others_to_compute: [bool; 5], ) -> (f64, f64, f64, AdditionalMetrics) { let mut tp = 0; @@ -493,15 +493,17 @@ pub fn compute_metrics_from_classes( let mut fp = 0; for (&pred, &real) in predicted.iter().zip(y.iter()) { - if real == 2 { + if real == 2.0 { continue; } - match (pred, real) { - (1, 1) => tp += 1, - (1, 0) => fp += 1, - (0, 0) => tn += 1, - (0, 1) => fn_count += 1, - _ => {} //warn!("A predicted vs real class of 2 should not exist"), + if pred == 1 && real == 1.0 { + tp += 1; + } else if pred == 1 && real == 0.0 { + fp += 1; + } else if pred == 0 && real == 0.0 { + tn += 1; + } else if pred == 0 && real == 1.0 { + fn_count += 1; } } @@ -568,7 +570,7 @@ pub fn compute_metrics_from_classes( /// # use gpredomics::utils::compute_metrics_from_value; /// # use gpredomics::individual::AdditionalMetrics; /// let scores = vec![0.1, 0.4, 0.35, 0.8]; -/// let labels = vec![0, 0, 1, 1]; +/// let labels = vec![0.0, 0.0, 1.0, 1.0]; /// let threshold = 0.5; /// let threshold_ci = None; /// let others_to_compute = [true, true, false, true, true]; @@ -586,7 +588,7 @@ pub fn 
compute_metrics_from_classes( /// ``` pub fn compute_metrics_from_value( value: &[f64], - y: &Vec, + y: &[f64], threshold: f64, threshold_ci: Option<[f64; 2]>, others_to_compute: [bool; 5], @@ -642,7 +644,7 @@ pub fn compute_metrics_from_value( /// # use gpredomics::utils::compute_roc_and_metrics_from_value; /// # use gpredomics::param::FitFunction; /// let scores = vec![0.1, 0.4, 0.35, 0.8]; -/// let labels = vec![0, 0, 1, 1]; +/// let labels = vec![0.0, 0.0, 1.0, 1.0]; /// let fit_function = FitFunction::f1_score; /// let penalties = None; /// let (auc, best_threshold, accuracy, sensitivity, specificity, best_objective) = @@ -656,20 +658,20 @@ pub fn compute_metrics_from_value( /// ``` pub fn compute_roc_and_metrics_from_value( scores: &[f64], - y: &[u8], + y: &[f64], fit_function: &FitFunction, penalties: Option<[f64; 2]>, ) -> (f64, f64, f64, f64, f64, f64) { let mut data: Vec<_> = scores .iter() .zip(y.iter()) - .filter(|(_, &label)| label == 0 || label == 1) + .filter(|(_, &label)| label == 0.0 || label == 1.0) .map(|(&score, &label)| (score, label)) .collect(); data.sort_unstable_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); - let total_pos = data.iter().filter(|(_, label)| *label == 1).count(); + let total_pos = data.iter().filter(|(_, label)| *label == 1.0).count(); let total_neg = data.len() - total_pos; if total_pos == 0 || total_neg == 0 { @@ -726,10 +728,12 @@ pub fn compute_roc_and_metrics_from_value( let mut current_fn = 0; while i < data.len() && (data[i].0 - current_score).abs() < f64::EPSILON { - match data[i].1 { - 0 => current_tn += 1, - 1 => current_fn += 1, - _ => unreachable!(), + if data[i].1 == 0.0 { + current_tn += 1; + } else if data[i].1 == 1.0 { + current_fn += 1; + } else { + unreachable!(); } i += 1; } @@ -1131,23 +1135,23 @@ pub fn mad(values: &[f64]) -> f64 { /// /// ``` /// # use gpredomics::utils::stratify_indices_by_class; -/// let y = vec![0, 1, 0, 1, 1, 0]; +/// let y = vec![0.0, 1.0, 0.0, 1.0, 1.0, 
0.0]; /// let (pos_indices, neg_indices) = stratify_indices_by_class(&y); /// assert_eq!(pos_indices, vec![1, 3, 4]); /// assert_eq!(neg_indices, vec![0, 2, 5]); /// ``` -pub fn stratify_indices_by_class(y: &[u8]) -> (Vec, Vec) { +pub fn stratify_indices_by_class(y: &[f64]) -> (Vec, Vec) { let pos_indices: Vec = y .iter() .enumerate() - .filter(|(_, &label)| label == 1) + .filter(|(_, &label)| label == 1.0) .map(|(i, _)| i) .collect(); let neg_indices: Vec = y .iter() .enumerate() - .filter(|(_, &label)| label == 0) + .filter(|(_, &label)| label == 0.0) .map(|(i, _)| i) .collect(); @@ -1394,7 +1398,7 @@ pub fn geyer_rescale_ci( /// # use rand_chacha::ChaCha8Rng; /// # use rand::SeedableRng; /// let mut rng = ChaCha8Rng::seed_from_u64(42); -/// let y = vec![0, 1, 0, 1, 1, 0]; +/// let y = vec![0.0, 1.0, 0.0, 1.0, 1.0, 0.0]; /// let n_bootstrap = 1000; /// let alpha = 0.05; /// let subsample_frac = 0.8; @@ -1403,7 +1407,7 @@ pub fn geyer_rescale_ci( /// assert_eq!(precomputed.bootstrap_y_samples.len(), n_bootstrap); /// ``` pub fn precompute_bootstrap_indices( - y: &Vec, + y: &[f64], n_bootstrap: usize, alpha: f64, subsample_frac: f64, @@ -1446,7 +1450,7 @@ pub fn precompute_bootstrap_indices( .collect(); // Precompute bootstrap_y samples (identical for all individuals) - let bootstrap_y_samples: Vec> = bootstrap_indices + let bootstrap_y_samples: Vec> = bootstrap_indices .iter() .map(|indices| indices.iter().map(|&i| y[i]).collect()) .collect(); @@ -1532,7 +1536,7 @@ pub fn precompute_bootstrap_indices( /// # use rand::SeedableRng; /// let mut rng = ChaCha8Rng::seed_from_u64(42); /// let value = vec![0.1, 0.4, 0.35, 0.8]; -/// let y = vec![0, 0, 1, 1]; +/// let y = vec![0.0, 0.0, 1.0, 1.0]; /// let n_bootstrap = 1000; /// let alpha = 0.05; /// let subsample_frac = 0.8; @@ -1556,7 +1560,7 @@ pub fn precompute_bootstrap_indices( /// ``` pub fn compute_threshold_and_metrics_with_bootstrap( value: &[f64], - y: &Vec, + y: &[f64], fit_function: &FitFunction, penalties: 
Option<[f64; 2]>, n_bootstrap: usize, @@ -1607,7 +1611,7 @@ pub fn compute_threshold_and_metrics_with_bootstrap( ); let bootstrap_values: Vec = bootstrap_indices.iter().map(|&i| value[i]).collect(); - let bootstrap_y: Vec = bootstrap_indices.iter().map(|&i| y[i]).collect(); + let bootstrap_y: Vec = bootstrap_indices.iter().map(|&i| y[i]).collect(); let (_, threshold_boot, _, _, _, _) = compute_roc_and_metrics_from_value( &bootstrap_values, @@ -1673,7 +1677,7 @@ pub struct PrecomputedBootstrap { /// Pre-generated bootstrap sample indices for each iteration pub bootstrap_indices: Vec>, /// Pre-computed y labels for each bootstrap sample (identical across individuals) - pub bootstrap_y_samples: Vec>, + pub bootstrap_y_samples: Vec>, /// Square root of subsample size (for Geyer rescaling) pub sqrt_m: f64, /// Square root of full sample size (for Geyer rescaling) @@ -1702,7 +1706,7 @@ pub struct PrecomputedBootstrap { /// sensitivity, specificity, objective, rejection_rate) pub fn compute_threshold_and_metrics_with_precomputed_bootstrap( value: &[f64], - y: &Vec, + y: &[f64], fit_function: &FitFunction, penalties: Option<[f64; 2]>, precomputed: &PrecomputedBootstrap, @@ -2783,7 +2787,7 @@ mod tests { #[test] fn test_compute_auc_from_value_perfect_classification() { let value = vec![0.1, 0.2, 0.8, 0.9]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; let auc = compute_auc_from_value(&value, &y); assert_eq!(auc, 1.0, "Perfect classification should yield AUC = 1.0"); } @@ -2791,7 +2795,7 @@ mod tests { #[test] fn test_compute_auc_from_value_random_classification() { let value = vec![0.5, 0.5, 0.5, 0.5]; - let y = vec![0, 1, 0, 1]; + let y = vec![0.0, 1.0, 0.0, 1.0]; let auc = compute_auc_from_value(&value, &y); assert_eq!(auc, 0.5, "Random classification should yield AUC = 0.5"); } @@ -2807,7 +2811,7 @@ mod tests { #[test] fn test_compute_auc_from_value_single_class_only() { let value = vec![0.1, 0.2, 0.3, 0.4]; - let y = vec![0, 0, 0, 0]; + let y = 
vec![0.0, 0.0, 0.0, 0.0]; let auc = compute_auc_from_value(&value, &y); assert_eq!(auc, 0.5, "Single class only should yield AUC = 0.5"); } @@ -2815,7 +2819,7 @@ mod tests { #[test] fn test_compute_auc_from_value_ties_handling() { let value = vec![0.5, 0.5, 0.5, 0.5]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; let auc = compute_auc_from_value(&value, &y); assert_eq!(auc, 0.5, "Ties should be handled correctly"); } @@ -2823,7 +2827,7 @@ mod tests { #[test] fn test_compute_auc_from_value_infinite_values() { let value = vec![f64::NEG_INFINITY, 0.5, f64::INFINITY]; - let y = vec![0, 1, 1]; + let y = vec![0.0, 1.0, 1.0]; let auc = compute_auc_from_value(&value, &y); assert!( auc >= 0.0 && auc <= 1.0, @@ -2835,7 +2839,7 @@ mod tests { fn test_compute_auc_large_dataset() { let n = 10000; let value: Vec = (0..n).map(|i| i as f64 / n as f64).collect(); - let y: Vec = (0..n).map(|i| if i < n / 2 { 0 } else { 1 }).collect(); + let y: Vec = (0..n).map(|i| if i < n / 2 { 0.0 } else { 1.0 }).collect(); let auc = compute_auc_from_value(&value, &y); assert!( (auc - 1.0).abs() < 1e-10, @@ -2846,7 +2850,7 @@ mod tests { #[test] fn test_compute_metrics_from_classes_perfect_predictions() { let predicted = vec![0, 1, 0, 1]; - let y = vec![0, 1, 0, 1]; + let y = vec![0.0, 1.0, 0.0, 1.0]; let (accuracy, sensitivity, specificity, _) = compute_metrics_from_classes(&predicted, &y, [false; 5]); assert_eq!( @@ -2866,7 +2870,7 @@ mod tests { #[test] fn test_compute_metrics_from_classes_all_wrong_predictions() { let predicted = vec![1, 0, 1, 0]; - let y = vec![0, 1, 0, 1]; + let y = vec![0.0, 1.0, 0.0, 1.0]; let (accuracy, sensitivity, specificity, _) = compute_metrics_from_classes(&predicted, &y, [false; 5]); assert_eq!( @@ -2886,7 +2890,7 @@ mod tests { #[test] fn test_compute_metrics_from_classes_mixed_predictions() { let predicted = vec![0, 1, 0, 0]; - let y = vec![0, 1, 1, 0]; + let y = vec![0.0, 1.0, 1.0, 0.0]; let (accuracy, sensitivity, specificity, _) = 
compute_metrics_from_classes(&predicted, &y, [false; 5]); assert_eq!( @@ -2906,7 +2910,7 @@ mod tests { #[test] fn test_compute_metrics_from_classes_class_2_ignored() { let predicted = vec![0, 1, 0, 1]; - let y = vec![0, 1, 2, 1]; + let y = vec![0.0, 1.0, 2.0, 1.0]; let (accuracy, _, _, _) = compute_metrics_from_classes(&predicted, &y, [false; 5]); assert_eq!(accuracy, 1.0, "Class 2 should be ignored in calculations"); } @@ -2925,7 +2929,7 @@ mod tests { #[test] fn test_compute_metrics_extreme_imbalance() { let predicted = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 1]; - let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 1]; + let y = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]; let (accuracy, sensitivity, specificity, _) = compute_metrics_from_classes(&predicted, &y, [false; 5]); assert_eq!( @@ -2945,7 +2949,7 @@ mod tests { #[test] fn test_compute_roc_and_metrics_from_value_basic_case() { let value = vec![0.1, 0.4, 0.6, 0.9]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; let (auc, threshold, accuracy, sensitivity, specificity, _) = compute_roc_and_metrics_from_value(&value, &y, &FitFunction::auc, None); assert!(auc >= 0.0 && auc <= 1.0, "AUC should be between 0 and 1"); @@ -2967,7 +2971,7 @@ mod tests { #[test] fn test_compute_roc_and_metrics_from_value_with_penalties() { let value = vec![0.1, 0.4, 0.6, 0.9]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; let penalties = Some([2.0, 1.0]); let (_, _, _, _, _, objective) = compute_roc_and_metrics_from_value(&value, &y, &FitFunction::auc, penalties); @@ -2977,7 +2981,7 @@ mod tests { #[test] fn test_compute_roc_and_metrics_from_value_without_penalties() { let value = vec![0.1, 0.4, 0.6, 0.9]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; let (_, _, _, sensitivity, specificity, objective) = compute_roc_and_metrics_from_value(&value, &y, &FitFunction::auc, None); let expected_youden = sensitivity + specificity - 1.0; @@ -2990,7 +2994,7 @@ mod tests { #[test] fn 
test_compute_roc_and_metrics_from_value_single_class_only() { let value = vec![0.1, 0.2, 0.3, 0.4]; - let y = vec![0, 0, 0, 0]; + let y = vec![0.0, 0.0, 0.0, 0.0]; let (auc, threshold, _, _, _, _) = compute_roc_and_metrics_from_value(&value, &y, &FitFunction::auc, None); assert_eq!(auc, 0.5, "Single class should yield AUC = 0.5"); @@ -3553,7 +3557,7 @@ mod tests { #[test] fn test_auc_roc_mcc_consistency() { let value = vec![0.1, 0.4, 0.6, 0.9]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; let auc1 = compute_auc_from_value(&value, &y); let (auc2, _, _, _, _, _) = @@ -3571,7 +3575,7 @@ mod tests { let value: Vec = (0..n) .map(|i| (i as f64 / n as f64) + 0.001 * (i % 10) as f64) .collect(); - let y: Vec = (0..n).map(|i| if i < n / 2 { 0 } else { 1 }).collect(); + let y: Vec = (0..n).map(|i| if i < n / 2 { 0.0 } else { 1.0 }).collect(); let (auc, _, _, _, _, _) = compute_roc_and_metrics_from_value(&value, &y, &FitFunction::auc, None); assert!( @@ -3584,7 +3588,7 @@ mod tests { fn test_compute_auc_manual_example_1() { // Simple case: 4 samples with clear classification let value = vec![0.1, 0.3, 0.7, 0.9]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; // Manual AUC calculation: // Pairs (negative, positive): (0.1,0.7), (0.1,0.9), (0.3,0.7), (0.3,0.9) @@ -3602,7 +3606,7 @@ mod tests { fn test_compute_auc_manual_example_2() { // Partially inverted classification let value = vec![0.8, 0.6, 0.4, 0.2]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; // Manual AUC calculation: // Pairs (negative, positive): (0.8,0.4), (0.8,0.2), (0.6,0.4), (0.6,0.2) @@ -3620,7 +3624,7 @@ mod tests { fn test_compute_auc_manual_example_3() { // Medium classification performance let value = vec![0.2, 0.6, 0.4, 0.8]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; // Manual AUC calculation: // Pairs (negative, positive): (0.2,0.4), (0.2,0.8), (0.6,0.4), (0.6,0.8) @@ -3635,7 +3639,7 @@ mod tests { fn 
test_compute_auc_manual_example_4() { // Case with ties handling let value = vec![0.5, 0.5, 0.3, 0.7]; - let y = vec![0, 1, 0, 1]; + let y = vec![0.0, 1.0, 0.0, 1.0]; // Manual AUC calculation with ties: // Pairs (negative, positive): (0.5,0.5), (0.5,0.7), (0.3,0.5), (0.3,0.7) @@ -3650,7 +3654,7 @@ mod tests { fn test_compute_auc_manual_example_5() { // Imbalanced dataset let value = vec![0.1, 0.2, 0.9]; - let y = vec![0, 0, 1]; + let y = vec![0.0, 0.0, 1.0]; // Manual AUC calculation: // Pairs (negative, positive): (0.1,0.9), (0.2,0.9) @@ -3667,7 +3671,7 @@ mod tests { fn test_compute_mcc_manual_example_1() { // Perfect balanced classification let value = vec![0.1, 0.2, 0.8, 0.9]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; // With optimal threshold at 0.5: TP=2, TN=2, FP=0, FN=0 // MCC = (2*2 - 0*0) / sqrt((2+0)*(2+0)*(2+0)*(2+0)) = 4/4 = 1.0 @@ -3684,7 +3688,7 @@ mod tests { fn test_compute_mcc_manual_example_2() { // Imperfect but balanced classification let value = vec![0.5, 0.5, 0.5, 0.5]; - let y = vec![0, 1, 0, 1]; + let y = vec![0.0, 1.0, 0.0, 1.0]; // With optimal threshold at 0.5: TP=1, TN=1, FP=1, FN=1 // MCC = (1*1 - 1*1) / sqrt((1+1)*(1+1)*(1+1)*(1+1)) = 0/4 = 0.0 @@ -3702,7 +3706,7 @@ mod tests { fn test_compute_mcc_manual_example_3() { // Classification with class bias let value = vec![0.2, 0.3, 0.7, 0.8]; - let y = vec![0, 0, 0, 1]; + let y = vec![0.0, 0.0, 0.0, 1.0]; // With optimal threshold at ~0.75: TP=1, TN=3, FP=0, FN=0 // MCC = (1*3 - 0*0) / sqrt((1+0)*(1+0)*(3+0)*(3+0)) = 3/sqrt(9) = 3/3 = 1.0 @@ -3719,7 +3723,7 @@ mod tests { fn test_compute_mcc_manual_example_4() { // Intermediate case with manual calculation let value = vec![0.1, 0.4, 0.6, 0.9]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; // With optimal threshold at 0.5: TP=2, TN=2, FP=0, FN=0 // MCC = (2*2 - 0*0) / sqrt((2+0)*(2+0)*(2+0)*(2+0)) = 4/4 = 1.0 @@ -3733,7 +3737,7 @@ mod tests { fn test_compute_mcc_manual_example_5() { // 
Classification with symmetric errors let value = vec![0.2, 0.6, 0.4, 0.8]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; // Analysis of possible thresholds: // Threshold 0.5: TP=1, TN=1, FP=1, FN=1 → MCC = 0 @@ -3748,7 +3752,7 @@ mod tests { fn test_auc_mcc_relationship_manual_verification() { // Verification on a case where we can calculate both manually let value = vec![0.1, 0.3, 0.7, 0.9]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; let auc = compute_auc_from_value(&value, &y); let (_, _, _, _, _, mcc) = @@ -3769,7 +3773,7 @@ mod tests { fn test_compute_auc_manual_complex_case() { // More complex case with multiple score values let value = vec![0.1, 0.3, 0.4, 0.6, 0.7, 0.9]; - let y = vec![0, 0, 1, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 0.0, 1.0, 1.0]; // Manual calculation: // Negatives: 0.1, 0.3, 0.6 (indices 0, 1, 3) @@ -3789,7 +3793,7 @@ mod tests { fn test_compute_mcc_manual_complex_case() { // Complex case with known MCC calculation let value = vec![0.1, 0.3, 0.6, 0.8]; - let y = vec![0, 1, 1, 0]; + let y = vec![0.0, 1.0, 1.0, 0.0]; // With threshold 0.45: predictions [0, 0, 1, 1], actual [0, 1, 1, 0] // TP=1 (index 2), TN=1 (index 0), FP=1 (index 3), FN=1 (index 1) @@ -3808,7 +3812,7 @@ mod tests { fn test_compute_auc_edge_case_with_duplicate_scores() { // Edge case: multiple samples with same score let value = vec![0.2, 0.2, 0.8, 0.8]; - let y = vec![0, 1, 0, 1]; + let y = vec![0.0, 1.0, 0.0, 1.0]; // Manual calculation with ties: // Pairs: (0.2,0.8), (0.2,0.8) → both count as 1.0 @@ -3830,7 +3834,7 @@ mod tests { #[test] fn test_stratify_balanced_dataset() { - let y = vec![0, 1, 0, 1, 0, 1]; + let y = vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]; let (pos, neg) = stratify_indices_by_class(&y); assert_eq!(pos.len(), 3, "Should have 3 positive samples"); @@ -3841,7 +3845,7 @@ mod tests { #[test] fn test_stratify_imbalanced_dataset() { - let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 1]; + let y = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 1.0]; let (pos, neg) = stratify_indices_by_class(&y); assert_eq!(pos.len(), 1, "Should have 1 positive sample"); @@ -3856,7 +3860,7 @@ mod tests { #[test] fn test_stratify_all_positive() { - let y = vec![1, 1, 1, 1]; + let y = vec![1.0, 1.0, 1.0, 1.0]; let (pos, neg) = stratify_indices_by_class(&y); assert_eq!(pos.len(), 4, "Should have 4 positive samples"); @@ -3867,7 +3871,7 @@ mod tests { #[test] fn test_stratify_all_negative() { - let y = vec![0, 0, 0]; + let y = vec![0.0, 0.0, 0.0]; let (pos, neg) = stratify_indices_by_class(&y); assert_eq!(pos.len(), 0, "Should have 0 positive samples"); @@ -3878,7 +3882,7 @@ mod tests { #[test] fn test_stratify_preserves_order() { - let y = vec![1, 0, 1, 0, 1]; + let y = vec![1.0, 0.0, 1.0, 0.0, 1.0]; let (pos, neg) = stratify_indices_by_class(&y); // Indices should be in ascending order @@ -3895,7 +3899,7 @@ mod tests { // This test ensures that stratify_indices_by_class produces // the same result as the manual loop-based implementation // previously used in cv.rs and data.rs - let y = vec![0, 1, 0, 0, 1, 1, 0, 1]; + let y = vec![0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0]; // New implementation let (pos_new, neg_new) = stratify_indices_by_class(&y); @@ -3904,9 +3908,9 @@ mod tests { let mut pos_old = Vec::new(); let mut neg_old = Vec::new(); for (i, &label) in y.iter().enumerate() { - if label == 0 { + if label == 0.0 { neg_old.push(i); - } else if label == 1 { + } else if label == 1.0 { pos_old.push(i); } } @@ -4310,7 +4314,7 @@ mod tests { #[test] fn test_bootstrap_ci_with_balanced_dataset() { let value = vec![0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9]; - let y = vec![0, 0, 0, 0, 1, 1, 1, 1]; + let y = vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let (auc, [lower, center, upper], _, _, _, _, rej) = @@ -4345,7 +4349,7 @@ mod tests { #[test] fn test_bootstrap_ci_reproducibility() { let value = vec![0.1, 0.3, 0.7, 0.9]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 
0.0, 1.0, 1.0]; let mut rng1 = ChaCha8Rng::seed_from_u64(42); let mut rng2 = ChaCha8Rng::seed_from_u64(42); @@ -4387,7 +4391,7 @@ mod tests { fn test_bootstrap_ci_with_imbalanced_dataset() { // Highly unbalanced dataset (10% positive) let value = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]; - let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 1]; + let y = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let (auc, [lower, center, upper], acc, sens, spec, obj, rej) = @@ -4441,7 +4445,9 @@ mod tests { let value = vec![ 0.1, 0.3, 0.5, 0.7, 0.9, 0.3, 0.1, 0.2, 0.2, 0.1, 0.5, 0.2, 0.9, 0.2, ]; - let y = vec![0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0]; + let y = vec![ + 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, + ]; let mut rng1 = ChaCha8Rng::seed_from_u64(42); let mut rng2 = ChaCha8Rng::seed_from_u64(42); @@ -4485,7 +4491,7 @@ mod tests { #[test] fn test_compute_metrics_with_additional_perfect_classification() { let predicted = vec![0, 1, 0, 1]; - let y = vec![0, 1, 0, 1]; + let y = vec![0.0, 1.0, 0.0, 1.0]; let (acc, sens, spec, additional) = compute_metrics_from_classes(&predicted, &y, [true; 5]); @@ -4511,7 +4517,7 @@ mod tests { #[test] fn test_compute_metrics_with_additional_random_classification() { let predicted = vec![0, 1, 0, 1]; - let y = vec![1, 0, 1, 0]; + let y = vec![1.0, 0.0, 1.0, 0.0]; let (acc, sens, spec, additional) = compute_metrics_from_classes(&predicted, &y, [true; 5]); @@ -4528,7 +4534,7 @@ mod tests { #[test] fn test_compute_metrics_with_additional_imbalanced() { let predicted = vec![0, 0, 0, 0, 0, 0, 1, 1, 0, 1]; - let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1]; + let y = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0]; let (_, _, _, additional) = compute_metrics_from_classes(&predicted, &y, [true; 5]); @@ -4552,7 +4558,7 @@ mod tests { #[test] fn test_compute_metrics_selective_additional() { let predicted = vec![0, 1, 0, 1]; - let y = vec![0, 1, 0, 1]; + let y = 
vec![0.0, 1.0, 0.0, 1.0]; let (_, _, _, additional) = compute_metrics_from_classes(&predicted, &y, [true, true, false, false, false]); @@ -4567,7 +4573,7 @@ mod tests { #[test] fn test_compute_metrics_with_abstentions_and_additional() { let predicted = vec![0, 1, 2, 1, 0, 2]; - let y = vec![0, 1, 0, 1, 1, 1]; + let y = vec![0.0, 1.0, 0.0, 1.0, 1.0, 1.0]; let (acc, sens, spec, additional) = compute_metrics_from_classes(&predicted, &y, [true; 5]); @@ -4589,7 +4595,7 @@ mod tests { // Case where MCC denominator = 0 // All predicted positive, but true class mixed let predicted = vec![1, 1, 1, 1]; - let y = vec![1, 1, 1, 1]; + let y = vec![1.0, 1.0, 1.0, 1.0]; let (_, _, _, additional) = compute_metrics_from_classes(&predicted, &y, [true, false, false, false, false]); @@ -4605,7 +4611,7 @@ mod tests { #[test] fn test_gmean_calculation() { let predicted = vec![0, 0, 1, 1]; - let y = vec![0, 1, 0, 1]; + let y = vec![0.0, 1.0, 0.0, 1.0]; let (_, sens, spec, additional) = compute_metrics_from_classes(&predicted, &y, [false, false, false, false, true]); @@ -4635,7 +4641,7 @@ mod tests { #[should_panic(expected = "assertion failed")] fn test_bootstrap_invalid_subsample_frac_zero() { let value = vec![0.1, 0.9]; - let y = vec![0, 1]; + let y = vec![0.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let _ = compute_threshold_and_metrics_with_bootstrap( @@ -4654,7 +4660,7 @@ mod tests { #[should_panic(expected = "assertion failed")] fn test_bootstrap_invalid_subsample_frac_above_one() { let value = vec![0.1, 0.9]; - let y = vec![0, 1]; + let y = vec![0.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let _ = compute_threshold_and_metrics_with_bootstrap( @@ -4673,7 +4679,7 @@ mod tests { #[should_panic(expected = "assertion failed")] fn test_bootstrap_too_few_iterations() { let value = vec![0.1, 0.5, 0.9]; - let y = vec![0, 0, 1]; + let y = vec![0.0, 0.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let _ = compute_threshold_and_metrics_with_bootstrap( @@ -4692,7 
+4698,7 @@ mod tests { #[should_panic(expected = "assertion failed")] fn test_bootstrap_invalid_alpha_zero() { let value = vec![0.1, 0.9]; - let y = vec![0, 1]; + let y = vec![0.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let _ = compute_threshold_and_metrics_with_bootstrap( @@ -4711,7 +4717,7 @@ mod tests { #[should_panic(expected = "assertion failed")] fn test_bootstrap_invalid_alpha_one() { let value = vec![0.1, 0.9]; - let y = vec![0, 1]; + let y = vec![0.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let _ = compute_threshold_and_metrics_with_bootstrap( @@ -4730,7 +4736,7 @@ mod tests { fn test_bootstrap_subsample_632() { // Test with .632 bootstrap (optimal subsampling) let value = vec![0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9]; - let y = vec![0, 0, 0, 0, 1, 1, 1, 1]; + let y = vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let (auc, [lower, center, upper], acc, se, sp, obj, rej) = @@ -4761,7 +4767,7 @@ mod tests { fn test_bootstrap_half_bootstrap() { // Test with half-bootstrap (very conservative) let value = vec![0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9]; - let y = vec![0, 0, 0, 0, 1, 1, 1, 1]; + let y = vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let (_, [lower_half, _, upper_half], _, _, _, _, _) = @@ -4812,7 +4818,7 @@ mod tests { #[test] fn test_bootstrap_ci_width_vs_n_bootstrap() { let value = vec![0.1, 0.2, 0.4, 0.6, 0.8, 0.9]; - let y = vec![0, 0, 0, 1, 1, 1]; + let y = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]; // Small n_bootstrap let mut rng1 = ChaCha8Rng::seed_from_u64(42); @@ -4862,7 +4868,7 @@ mod tests { #[test] fn test_bootstrap_different_fit_functions() { let value = vec![0.1, 0.2, 0.3, 0.7, 0.8, 0.9]; - let y = vec![0, 0, 0, 1, 1, 1]; + let y = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]; let mut rng1 = ChaCha8Rng::seed_from_u64(42); let (_, [_, center_auc, _], _, _, _, obj_auc, _) = @@ -4918,7 +4924,7 @@ mod tests { #[test] fn 
test_bootstrap_with_penalties() { let value = vec![0.1, 0.3, 0.5, 0.7, 0.9]; - let y = vec![0, 0, 1, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0, 1.0]; // Test with FPR/FNR penalties using sensitivity as fit function let mut rng = ChaCha8Rng::seed_from_u64(42); @@ -4942,7 +4948,7 @@ mod tests { fn test_bootstrap_stratification_preserved() { // Test that stratification is maintained across bootstrap samples let value = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]; - let y = vec![0, 0, 0, 0, 1, 1, 1, 1]; + let y = vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let (_, [lower, _, upper], acc, se, sp, _, rej) = @@ -4974,7 +4980,7 @@ mod tests { fn test_bootstrap_extreme_scores() { // Test with extreme score values let value = vec![-1e6, -1e3, 1e3, 1e6]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let (auc, [lower, center, upper], _, _, _, _, _) = @@ -5001,7 +5007,7 @@ mod tests { fn test_bootstrap_small_dataset() { // Test with minimal dataset (edge case) let value = vec![0.2, 0.3, 0.7, 0.8]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let (auc, [lower, center, upper], _, _, _, _, _) = @@ -5031,7 +5037,7 @@ mod tests { fn test_bootstrap_all_same_class() { // Edge case: all samples from same class let value = vec![0.1, 0.2, 0.3, 0.4]; - let y = vec![1, 1, 1, 1]; + let y = vec![1.0, 1.0, 1.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { @@ -5078,7 +5084,7 @@ mod tests { fn test_bootstrap_perfect_separation() { // Perfect separation between classes let value = vec![0.1, 0.2, 0.3, 0.7, 0.8, 0.9]; - let y = vec![0, 0, 0, 1, 1, 1]; + let y = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let (auc, [lower, center, upper], acc, se, sp, _, _) = @@ -5116,7 +5122,7 @@ mod tests { fn 
test_bootstrap_ties_in_scores() { // Many tied scores let value = vec![0.5, 0.5, 0.5, 0.5, 0.5, 0.5]; - let y = vec![0, 0, 0, 1, 1, 1]; + let y = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let (auc, [lower, center, upper], _, _, _, _, _) = @@ -5146,7 +5152,7 @@ mod tests { fn test_bootstrap_geyer_rescaling() { // Test that Geyer rescaling is applied correctly for subsampling let value: Vec = (0..20).map(|i| i as f64 / 20.0).collect(); - let y: Vec = (0..20).map(|i| if i < 10 { 0 } else { 1 }).collect(); + let y: Vec = (0..20).map(|i| if i < 10 { 0.0 } else { 1.0 }).collect(); let mut rng1 = ChaCha8Rng::seed_from_u64(42); let (_, [l1, c1, u1], _, _, _, _, _) = compute_threshold_and_metrics_with_bootstrap( @@ -5181,7 +5187,7 @@ mod tests { fn test_bootstrap_ci_coverage_stability() { // Test that CI is stable across different random seeds let value = vec![0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9]; - let y = vec![0, 0, 0, 0, 1, 1, 1, 1]; + let y = vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]; let mut widths = Vec::new(); @@ -5220,7 +5226,7 @@ mod tests { #[test] fn test_bootstrap_metrics_consistency() { let value = vec![0.1, 0.3, 0.5, 0.7, 0.9]; - let y = vec![0, 0, 1, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let (_auc, [lower, center, upper], acc, se, sp, _, rej) = @@ -5271,7 +5277,7 @@ mod tests { #[test] fn test_bootstrap_return_values_structure() { let value = vec![0.1, 0.9]; - let y = vec![0, 1]; + let y = vec![0.0, 1.0]; let mut rng = ChaCha8Rng::seed_from_u64(42); let result = compute_threshold_and_metrics_with_bootstrap( @@ -5314,7 +5320,7 @@ mod tests { fn test_precomputed_bootstrap_equivalence() { // Test that precomputed bootstrap gives same results as regular bootstrap let value = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]; - let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; + let y = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]; let seed = 12345; let 
n_bootstrap = 1000; @@ -5410,7 +5416,7 @@ mod tests { fn test_precomputed_bootstrap_with_penalties() { // Test precomputed bootstrap with penalties let value = vec![0.2, 0.4, 0.6, 0.8]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; let seed = 42; let n_bootstrap = 500; @@ -5471,7 +5477,7 @@ mod tests { // We have 3 positive samples and 2 negative samples let scores = vec![0.1, 0.3, 0.5, 0.7, 0.9]; - let y = vec![0, 0, 1, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0, 1.0]; // Compute ROC and get the optimal threshold let (auc, threshold, acc_roc, sens_roc, spec_roc, _obj) = @@ -5510,7 +5516,7 @@ mod tests { // This tests the >= rule explicitly let scores = vec![0.2, 0.4, 0.6, 0.6, 0.8]; - let y = vec![0, 0, 1, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0, 1.0]; let threshold = 0.6; // Manually compute expected metrics with >= rule @@ -5533,7 +5539,7 @@ mod tests { fn test_threshold_boundary_all_equal() { // Case: All scores equal to threshold let scores = vec![0.5, 0.5, 0.5, 0.5]; - let y = vec![1, 1, 0, 0]; + let y = vec![1.0, 1.0, 0.0, 0.0]; let threshold = 0.5; // With >= rule, all should be classified as positive (class 1) @@ -5560,7 +5566,7 @@ mod tests { fn test_threshold_boundary_just_below() { // Case: Threshold just below the smallest positive score let scores = vec![0.1, 0.2, 0.5, 0.6]; - let y = vec![0, 0, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0]; let threshold = 0.49; // Scores >= 0.49: [0.5, 0.6] -> class 1 @@ -5580,7 +5586,7 @@ mod tests { // by checking against manual computation let scores = vec![0.1, 0.3, 0.5, 0.7, 0.9]; - let y = vec![0, 1, 0, 1, 1]; + let y = vec![0.0, 1.0, 0.0, 1.0, 1.0]; let (auc, threshold, acc, sens, spec, obj) = compute_roc_and_metrics_from_value(&scores, &y, &FitFunction::auc, None); @@ -5613,8 +5619,8 @@ mod tests { 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, // 7 positives ]; let y_high = vec![ - 0, 0, 0, // negatives - 1, 1, 1, 1, 1, 1, 1, // positives + 0.0, 0.0, 0.0, // negatives + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
// positives ]; let (auc_high, threshold_high, acc_high, sens_high, spec_high, _) = @@ -5638,8 +5644,8 @@ mod tests { 0.8, 0.9, 1.0, // 3 positives ]; let y_low = vec![ - 0, 0, 0, 0, 0, 0, 0, // negatives - 1, 1, 1, // positives + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // negatives + 1.0, 1.0, 1.0, // positives ]; let (auc_low, threshold_low, acc_low, sens_low, spec_low, _) = @@ -5659,7 +5665,7 @@ mod tests { // Test when optimal threshold equals the last (highest) score // This should result in classifying all samples as negative let scores = vec![0.1, 0.3, 0.5, 0.7, 0.9]; - let y = vec![1, 1, 1, 1, 0]; // Last score is negative + let y = vec![1.0, 1.0, 1.0, 1.0, 0.0]; // Last score is negative let (auc, threshold, acc, sens, spec, _) = compute_roc_and_metrics_from_value(&scores, &y, &FitFunction::auc, None); @@ -5707,7 +5713,7 @@ mod tests { // Test when optimal threshold equals the first (lowest) score // This should result in classifying all samples as positive let scores = vec![0.1, 0.3, 0.5, 0.7, 0.9]; - let y = vec![0, 1, 1, 1, 1]; // First score is negative + let y = vec![0.0, 1.0, 1.0, 1.0, 1.0]; // First score is negative let (auc, threshold, acc, sens, spec, _) = compute_roc_and_metrics_from_value(&scores, &y, &FitFunction::auc, None); @@ -5749,7 +5755,7 @@ mod tests { // Test with continuous scores (no duplicates) // This is common in real-world scenarios with floating-point predictions let scores = vec![0.123, 0.456, 0.789, 0.234, 0.567, 0.890, 0.345, 0.678]; - let y = vec![0, 0, 1, 0, 1, 1, 0, 1]; + let y = vec![0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0]; let (auc, threshold, acc, sens, spec, _) = compute_roc_and_metrics_from_value(&scores, &y, &FitFunction::auc, None); @@ -5782,7 +5788,7 @@ mod tests { fn test_continuous_scores_with_ties() { // Test with continuous scores that have some ties let scores = vec![0.1, 0.2, 0.2, 0.3, 0.4, 0.4, 0.5, 0.6]; - let y = vec![0, 0, 1, 0, 1, 1, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]; let (_auc, 
threshold, acc, sens, spec, _) = compute_roc_and_metrics_from_value(&scores, &y, &FitFunction::auc, None); @@ -5802,7 +5808,7 @@ mod tests { fn test_threshold_with_very_small_continuous_values() { // Test with very small continuous values (near zero) let scores = vec![0.0001, 0.0002, 0.0003, 0.0004, 0.0005]; - let y = vec![0, 0, 1, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0, 1.0]; let (auc, threshold, acc, sens, spec, _) = compute_roc_and_metrics_from_value(&scores, &y, &FitFunction::auc, None); @@ -5820,7 +5826,7 @@ mod tests { fn test_threshold_with_large_continuous_values() { // Test with large continuous values let scores = vec![1000.1, 2000.5, 3000.3, 4000.7, 5000.2]; - let y = vec![0, 0, 1, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0, 1.0]; let (auc, threshold, acc, sens, spec, _) = compute_roc_and_metrics_from_value(&scores, &y, &FitFunction::auc, None); @@ -5839,7 +5845,7 @@ mod tests { // Test edge case where optimal threshold is last_score + 1.0 // This happens when classifying all as negative is optimal let scores = vec![0.1, 0.2, 0.3, 0.4, 0.5]; - let y = vec![0, 0, 0, 0, 0]; // All negatives + let y = vec![0.0, 0.0, 0.0, 0.0, 0.0]; // All negatives let (auc, _threshold, _acc, _sens, _spec, _) = compute_roc_and_metrics_from_value(&scores, &y, &FitFunction::auc, None); @@ -5855,7 +5861,7 @@ mod tests { fn test_threshold_negative_scores() { // Test with negative scores let scores = vec![-0.5, -0.3, -0.1, 0.1, 0.3]; - let y = vec![0, 0, 1, 1, 1]; + let y = vec![0.0, 0.0, 1.0, 1.0, 1.0]; let (auc, threshold, acc, sens, spec, _) = compute_roc_and_metrics_from_value(&scores, &y, &FitFunction::auc, None); @@ -5876,7 +5882,7 @@ mod tests { fn test_threshold_mixed_positive_negative_scores() { // Test with mix of positive and negative scores, unsorted let scores = vec![1.5, -2.3, 0.5, -0.8, 3.2, 0.0, -1.1, 2.7]; - let y = vec![1, 0, 0, 0, 1, 0, 0, 1]; + let y = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0]; let (auc, threshold, acc, sens, spec, _) = 
compute_roc_and_metrics_from_value(&scores, &y, &FitFunction::auc, None); @@ -5897,7 +5903,7 @@ mod tests { fn test_threshold_continuous_perfect_separation() { // Perfect separation with continuous scores let scores = vec![0.12, 0.23, 0.34, 0.45, 0.67, 0.78, 0.89, 0.91]; - let y = vec![0, 0, 0, 0, 1, 1, 1, 1]; + let y = vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]; let (auc, threshold, acc, sens, spec, _) = compute_roc_and_metrics_from_value(&scores, &y, &FitFunction::auc, None); @@ -5942,7 +5948,7 @@ mod tests { // Case 3a: All scores equal to threshold let scores_a = vec![0.5, 0.5, 0.5, 0.5]; - let y_a = vec![1, 1, 0, 0]; + let y_a = vec![1.0, 1.0, 0.0, 0.0]; let threshold_a = 0.5; println!("\nCase 3a: All scores == threshold"); @@ -5975,7 +5981,7 @@ mod tests { // Case 3b: Threshold just below the smallest positive score let scores_b = vec![0.1, 0.2, 0.5, 0.6]; - let y_b = vec![0, 0, 1, 1]; + let y_b = vec![0.0, 0.0, 1.0, 1.0]; let threshold_b = 0.49; // Just below 0.5 println!("\nCase 3b: Threshold just below smallest positive"); @@ -6012,8 +6018,8 @@ mod tests { 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, // 7 positives ]; let y_high = vec![ - 0, 0, 0, // negatives - 1, 1, 1, 1, 1, 1, 1, // positives + 0.0, 0.0, 0.0, // negatives + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, // positives ]; println!("\nHigh prevalence (70% positive):"); @@ -6048,8 +6054,8 @@ mod tests { 0.8, 0.9, 1.0, // 3 positives ]; let y_low = vec![ - 0, 0, 0, 0, 0, 0, 0, // negatives - 1, 1, 1, // positives + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // negatives + 1.0, 1.0, 1.0, // positives ]; println!("\nLow prevalence (30% positive):"); From f7fe47241152993669864ded78e5e38a260f8e36 Mon Sep 17 00:00:00 2001 From: Edi Prifti Date: Wed, 25 Mar 2026 23:32:41 +0100 Subject: [PATCH 3/3] =?UTF-8?q?feat:=20regression=20support=20=E2=80=94=20?= =?UTF-8?q?Spearman/RMSE=20fitness=20with=20prevalence-based=20feature=20s?= =?UTF-8?q?election?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit - Regression mode auto-detected from fit function (spearman/rmse/mutual_information) - Feature selection bypasses class-based tests, uses prevalence only - Tested on wetlab_protocol gene_count: SA finds Spearman=0.164 with k=15 Known limitation: display still shows AUC/sensitivity/specificity (all 0.0 for regression). The fit value correctly shows the Spearman correlation. Closes #15 --- src/data.rs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/data.rs b/src/data.rs index b802284..727e307 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1149,6 +1149,36 @@ impl Data { self.feature_selection = Vec::new(); self.feature_class = HashMap::new(); + // For regression (continuous y), skip statistical test — use all features above prevalence threshold + let is_regression = matches!( + param.general.fit, + crate::param::FitFunction::spearman + | crate::param::FitFunction::rmse + | crate::param::FitFunction::mutual_information + ); + + if is_regression { + info!("Regression mode: selecting features by prevalence only (no class-based test)"); + let min_prev = param.data.feature_minimal_prevalence_pct / 100.0; + for j in 0..self.feature_len { + let n_nonzero = (0..self.sample_len) + .filter(|&s| *self.X.get(&(s, j)).unwrap_or(&0.0) > 0.0) + .count(); + let prevalence = n_nonzero as f64 / self.sample_len as f64; + if prevalence >= min_prev { + self.feature_selection.push(j); + self.feature_class.insert(j, 0); // no class association for regression + self.feature_significance.insert(j, prevalence); + } + } + info!( + "{} features selected (prevalence >= {:.0}%)", + self.feature_selection.len(), + param.data.feature_minimal_prevalence_pct + ); + return; + } + let (class_0_features, class_1_features) = self.evaluate_features(param); for (j, class, value) in class_0_features