diff --git a/src/cv.rs b/src/cv.rs index 3d01d20..5d8627d 100644 --- a/src/cv.rs +++ b/src/cv.rs @@ -12,7 +12,8 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; -use log::info; +use log::{info, warn}; +use std::collections::HashSet; use std::sync::atomic::AtomicBool; /// Cross-validation dataset implementation for machine learning workflows. @@ -42,19 +43,33 @@ impl CV { /// * `folds` - Number of validation folds to create /// * `rng` - Random number generator for stratified sampling pub fn new(data: &Data, folds: usize, rng: &mut ChaCha8Rng) -> CV { - let (indices_class1, indices_class0) = utils::stratify_indices_by_class(&data.y); - - let indices_class0_folds = - utils::split_into_balanced_random_chunks(indices_class0, folds, rng); - let indices_class1_folds = - utils::split_into_balanced_random_chunks(indices_class1, folds, rng); - - let validation_folds: Vec<Data> = indices_class0_folds - .into_iter() - .zip(indices_class1_folds.into_iter()) - .map(|(i1, i2)| i1.into_iter().chain(i2).collect::<Vec<usize>>()) - .map(|i| data.subset(i)) - .collect(); + // Detect regression mode: if y has more than 3 unique values, use random splitting + let unique_y: HashSet<u64> = data.y.iter().map(|v| v.to_bits()).collect(); + let is_regression = unique_y.len() > 3; + + let validation_folds: Vec<Data> = if is_regression { + warn!( + "Continuous y detected ({} unique values) — using random (non-stratified) CV splitting", + unique_y.len() + ); + let all_indices: Vec<usize> = (0..data.sample_len).collect(); + let all_folds = utils::split_into_balanced_random_chunks(all_indices, folds, rng); + all_folds.into_iter().map(|i| data.subset(i)).collect() + } else { + let (indices_class1, indices_class0) = utils::stratify_indices_by_class(&data.y); + + let indices_class0_folds = + utils::split_into_balanced_random_chunks(indices_class0, folds, rng); + let indices_class1_folds = + utils::split_into_balanced_random_chunks(indices_class1, folds, rng); + + indices_class0_folds + 
.into_iter() + .zip(indices_class1_folds.into_iter()) + .map(|(i1, i2)| i1.into_iter().chain(i2).collect::<Vec<usize>>()) + .map(|i| data.subset(i)) + .collect() + }; let mut training_sets: Vec<Data> = Vec::new(); for i in 0..folds { diff --git a/src/individual.rs b/src/individual.rs index 732babd..e6e7872 100644 --- a/src/individual.rs +++ b/src/individual.rs @@ -5,6 +5,7 @@ use crate::param::FitFunction; use crate::utils::serde_json_hashmap_numeric; use crate::utils::{compute_auc_from_value, compute_roc_and_metrics_from_value}; use crate::utils::{compute_metrics_from_value, generate_random_vector, shuffle_row}; +use crate::utils::{mutual_information, neg_rmse, pearson_correlation, spearman_correlation}; use crate::Population; use log::{debug, error}; use rand::seq::SliceRandom; // Provides the `choose_multiple` method @@ -535,6 +536,239 @@ impl Individual { formatted_string } + /// Generates a display string that is regression-aware. + /// + /// For regression fit functions (spearman, pearson, rmse, mutual_information), shows + /// the regression metric instead of AUC/accuracy/sensitivity/specificity. + /// For classification fit functions, delegates to the standard `display()` method. + /// + /// # Arguments + /// + /// * `data` - The training data used for evaluation. + /// * `data_to_test` - Optional test data for additional evaluation. + /// * `algo` - The algorithm name string. + /// * `ci_alpha` - Confidence interval alpha for threshold CI display. + /// * `fit_function` - The fit function used for model evaluation. 
+ pub fn display_with_fit( + &self, + data: &Data, + data_to_test: Option<&Data>, + algo: &String, + ci_alpha: f64, + fit_function: &FitFunction, + ) -> String { + let is_regression = matches!( + fit_function, + FitFunction::spearman + | FitFunction::pearson + | FitFunction::rmse + | FitFunction::mutual_information + ); + + if !is_regression { + return self.display(data, data_to_test, algo, ci_alpha); + } + + let fit_name = format!("{:?}", fit_function); + + let algo_str = match algo.as_str() { + "ga" => format!(" [gen:{}] ", self.epoch), + "beam" => " ".to_string(), + "mcmc" => format!(" [MCMC step: {}] ", self.epoch), + _ => " [unknown] ".to_string(), + }; + + let regression_fn: fn(&[f64], &[f64]) -> f64 = match fit_function { + FitFunction::spearman => spearman_correlation, + FitFunction::pearson => pearson_correlation, + FitFunction::rmse => neg_rmse, + FitFunction::mutual_information => mutual_information, + _ => unreachable!(), + }; + + let metrics = match data_to_test { + Some(test_data) => { + let test_scores = self.evaluate(test_data); + let test_fit = regression_fn(&test_scores, &test_data.y); + format!( + "{}:{} [k={}]{}[fit:{:.3}] {} {:.3}/{:.3}", + self.get_language(), + self.get_data_type(), + self.features.len(), + algo_str, + self.fit, + fit_name, + self.fit, + test_fit + ) + } + None => { + format!( + "{}:{} [k={}]{}[fit:{:.3}] {} {:.3}", + self.get_language(), + self.get_data_type(), + self.features.len(), + algo_str, + self.fit, + fit_name, + self.fit + ) + } + }; + + // Sort features by index + let mut sorted_features: Vec<_> = self.features.iter().collect(); + sorted_features.sort_by(|a, b| a.0.cmp(b.0)); + + let mut positive_features: Vec<_> = sorted_features + .iter() + .filter(|&&(_, &coef)| coef > 0) + .collect(); + let mut negative_features: Vec<_> = sorted_features + .iter() + .filter(|&&(_, &coef)| coef < 0) + .collect(); + + positive_features.sort_by(|a, b| b.1.cmp(a.1)); + negative_features.sort_by(|a, b| a.1.cmp(b.1)); + + let 
positive_str: Vec<String> = positive_features + .iter() + .enumerate() + .map(|(_i, &&(index, coef))| { + let mut str = format!("\x1b[96m{}\x1b[0m", data.features[*index]); + if self.data_type == PREVALENCE_TYPE { + str = format!("{}⁰", str); + } + if self.language == POW2_LANG && !(*coef == 1_i8) && self.data_type != LOG_TYPE { + str = format!("{}*{}", coef, str); + } else if self.language == POW2_LANG + && !(*coef == 1_i8) + && self.data_type == LOG_TYPE + { + str = format!("{}^{}", str, coef); + } + str + }) + .collect(); + + let negative_str: Vec<String> = negative_features + .iter() + .enumerate() + .map(|(_i, &&(index, coef))| { + let mut str = format!("\x1b[95m{}\x1b[0m", data.features[*index]); + if self.data_type == PREVALENCE_TYPE { + str = format!("{}⁰", str); + } + if self.language == POW2_LANG && !(*coef == -1_i8) && self.data_type != LOG_TYPE { + str = format!("{}*{}", coef.abs(), str); + } else if self.language == POW2_LANG + && !(*coef == -1_i8) + && self.data_type == LOG_TYPE + { + str = format!("{}^{}", str, coef.abs()); + } + str + }) + .collect(); + + let mut negative_str_owned = negative_str.clone(); + if self.language == RATIO_LANG && self.data_type != LOG_TYPE { + negative_str_owned.push(format!("{:2e}", self.epsilon)); + } + + let positive_str_final = if positive_str.is_empty() { + vec!["0".to_string()] + } else { + positive_str + }; + let negative_str_final = if negative_str_owned.is_empty() { + vec!["0".to_string()] + } else { + negative_str_owned + }; + + let (positive_joined, negative_joined) = if self.data_type == LOG_TYPE { + let pos_str: Vec<String> = positive_str_final + .iter() + .map(|f| { + if f != "0" { + format!("{}⁺", f) + } else { + "1".to_string() + } + }) + .collect(); + let neg_str: Vec<String> = negative_str_final + .iter() + .map(|f| { + if f != "0" { + format!("{}⁺", f) + } else { + "1".to_string() + } + }) + .collect(); + ( + format!("ln({})", pos_str.join(" × ")), + format!("ln({})", neg_str.join(" × ")), + ) + } else { + ( + format!("({})", 
positive_str_final.join(" + ")), + format!("({})", negative_str_final.join(" + ")), + ) + }; + + // For regression, show "score =" instead of class-based threshold + let second_line_first_part = "score =".to_string(); + let second_line_second_part = String::new(); + + let formatted_string = if self.language == BINARY_LANG { + format!( + "{}\n{} {} {}", + metrics, second_line_first_part, positive_joined, second_line_second_part + ) + } else if self.language == TERNARY_LANG || self.language == POW2_LANG { + format!( + "{}\n{} {} - {} {}", + metrics, + second_line_first_part, + positive_joined, + negative_joined, + second_line_second_part + ) + } else if self.language == RATIO_LANG { + if self.data_type == LOG_TYPE { + format!( + "{}\n{} {} - {} - {} {}", + metrics, + second_line_first_part, + positive_joined, + negative_joined, + self.epsilon, + second_line_second_part + ) + } else { + format!( + "{}\n{} {} / {} {}", + metrics, + second_line_first_part, + positive_joined, + negative_joined, + second_line_second_part + ) + } + } else { + format!( + "{}\n{} {:?} {}", + metrics, second_line_first_part, self, second_line_second_part + ) + }; + + formatted_string + } + /// Computes the hash of the Individual based on its features and betas /// /// # Examples diff --git a/src/lib.rs b/src/lib.rs index d81dc28..31da7e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -99,7 +99,8 @@ use population::Population; use rand::prelude::*; use rand_chacha::ChaCha8Rng; -use log::{debug, error, warn}; +use log::{debug, error, info, warn}; +use param::FitFunction; use std::sync::atomic::AtomicBool; use std::sync::Arc; @@ -293,7 +294,17 @@ pub fn run(param: &Param, running: Arc) -> Experiment { // Voting if param.voting.vote { - exp.compute_voting(); + if matches!( + param.general.fit, + FitFunction::spearman + | FitFunction::pearson + | FitFunction::rmse + | FitFunction::mutual_information + ) { + info!("Voting is not applicable for regression — skipping"); + } else { + 
exp.compute_voting(); + } } else { cinfo!( param.general.display_colorful, @@ -608,7 +629,17 @@ // Voting if param.voting.vote && exp.parameters.general.algo != "mcmc" { - exp.compute_voting(); + if matches!( + param.general.fit, + FitFunction::spearman + | FitFunction::pearson + | FitFunction::rmse + | FitFunction::mutual_information + ) { + info!("Voting is not applicable for regression — skipping"); + } else { + exp.compute_voting(); + } } else { cinfo!( param.general.display_colorful, diff --git a/src/population.rs b/src/population.rs index 8bb0a7a..ed7635f 100644 --- a/src/population.rs +++ b/src/population.rs @@ -14,6 +14,7 @@ use crate::utils::{ conf_inter_binomial_method, precompute_bootstrap_indices, PrecomputedBootstrap, }; use crate::utils::{mad, mean_and_std, median}; +use crate::utils::{mutual_information, neg_rmse, pearson_correlation, spearman_correlation}; use log::{error, info, warn}; use rand::prelude::SliceRandom; use rand::RngCore; @@ -141,48 +142,98 @@ impl Population { let fbm = self.select_best_population(0.05); + let is_regression = matches!( + param.general.fit, + FitFunction::spearman + | FitFunction::pearson + | FitFunction::rmse + | FitFunction::mutual_information + ); + let fit_name = format!("{:?}", param.general.fit); + if !fbm.individuals.is_empty() { - let train_aucs: Vec<f64> = fbm.individuals.iter().map(|i| i.cls.auc).collect(); - let train_accuracies: Vec<f64> = - fbm.individuals.iter().map(|i| i.cls.accuracy).collect(); - let train_sensitivities: Vec<f64> = - 
fbm.individuals.iter().map(|i| i.cls.sensitivity).collect(); - let train_specificities: Vec<f64> = - fbm.individuals.iter().map(|i| i.cls.specificity).collect(); - - let (train_auc_mean, _) = mean_and_std(&train_aucs); - let (train_acc_mean, _) = mean_and_std(&train_accuracies); - let (train_sens_mean, _) = mean_and_std(&train_sensitivities); - let (train_spec_mean, _) = mean_and_std(&train_specificities); - - let fbm_text = if let Some(data_to_test) = data_to_test { - let test_aucs: Vec<f64> = fbm - .individuals - .iter() - .map(|i| i.compute_new_auc(data_to_test)) - .collect(); - let mut test_acc_vec = Vec::new(); - let mut test_sens_vec = Vec::new(); - let mut test_spec_vec = Vec::new(); - - for individual in &fbm.individuals { - let (acc, sens, spec, _, _) = individual.compute_metrics(data_to_test); - test_acc_vec.push(acc); - test_sens_vec.push(sens); - test_spec_vec.push(spec); + let fbm_text = if is_regression { + // Regression mode: display the fit metric + let train_fits: Vec<f64> = fbm.individuals.iter().map(|i| i.fit).collect(); + let (train_fit_mean, _) = mean_and_std(&train_fits); + + if let Some(test_data) = data_to_test { + let regression_fn: fn(&[f64], &[f64]) -> f64 = match param.general.fit { + FitFunction::spearman => spearman_correlation, + FitFunction::pearson => pearson_correlation, + FitFunction::rmse => neg_rmse, + FitFunction::mutual_information => mutual_information, + _ => unreachable!(), + }; + let test_fits: Vec<f64> = fbm + .individuals + .iter() + .map(|i| { + let scores = i.evaluate(test_data); + regression_fn(&scores, &test_data.y) + }) + .collect(); + let (test_fit_mean, _) = mean_and_std(&test_fits); + + format!( + "\n\x1b[1;33mFBM mean (n={}) - {} {:.3}/{:.3}\x1b[0m\n", + fbm.individuals.len(), + fit_name, + train_fit_mean, + test_fit_mean + ) + } else { + format!( + "\n\x1b[1;33mFBM mean (n={}) - {} {:.3}\x1b[0m\n", + fbm.individuals.len(), + fit_name, + train_fit_mean + ) } + } else { + // Classification mode: display 
AUC/accuracy/sensitivity/specificity + let train_aucs: Vec<f64> = fbm.individuals.iter().map(|i| i.cls.auc).collect(); + let train_accuracies: Vec<f64> = + fbm.individuals.iter().map(|i| i.cls.accuracy).collect(); + let train_sensitivities: Vec<f64> = + fbm.individuals.iter().map(|i| i.cls.sensitivity).collect(); + let train_specificities: Vec<f64> = + fbm.individuals.iter().map(|i| i.cls.specificity).collect(); + + let (train_auc_mean, _) = mean_and_std(&train_aucs); + let (train_acc_mean, _) = mean_and_std(&train_accuracies); + let (train_sens_mean, _) = mean_and_std(&train_sensitivities); + let (train_spec_mean, _) = mean_and_std(&train_specificities); + + if let Some(data_to_test) = data_to_test { + let test_aucs: Vec<f64> = fbm + .individuals + .iter() + .map(|i| i.compute_new_auc(data_to_test)) + .collect(); + let mut test_acc_vec = Vec::new(); + let mut test_sens_vec = Vec::new(); + let mut test_spec_vec = Vec::new(); + + for individual in &fbm.individuals { + let (acc, sens, spec, _, _) = individual.compute_metrics(data_to_test); + test_acc_vec.push(acc); + test_sens_vec.push(sens); + test_spec_vec.push(spec); + } - let (test_auc_mean, _) = mean_and_std(&test_aucs); - let (test_acc_mean, _) = mean_and_std(&test_acc_vec); - let (test_sens_mean, _) = mean_and_std(&test_sens_vec); - let (test_spec_mean, _) = mean_and_std(&test_spec_vec); + let (test_auc_mean, _) = mean_and_std(&test_aucs); + let (test_acc_mean, _) = mean_and_std(&test_acc_vec); + let (test_sens_mean, _) = mean_and_std(&test_sens_vec); + let (test_spec_mean, _) = mean_and_std(&test_spec_vec); - format!("\n\x1b[1;33mFBM mean (n={}) - AUC {:.3}/{:.3} | accuracy {:.3}/{:.3} | sensitivity {:.3}/{:.3} | specificity {:.3}/{:.3}\x1b[0m\n", - fbm.individuals.len(), train_auc_mean, test_auc_mean, train_acc_mean, test_acc_mean, - train_sens_mean, test_sens_mean, train_spec_mean, test_spec_mean) - } else { - format!("\n\x1b[1;33mFBM mean (n={}) - AUC {:.3} | accuracy {:.3} | sensitivity {:.3} | specificity {:.3}\x1b[0m\n", - 
fbm.individuals.len(), train_auc_mean, train_acc_mean, train_sens_mean, train_spec_mean) + format!("\n\x1b[1;33mFBM mean (n={}) - AUC {:.3}/{:.3} | accuracy {:.3}/{:.3} | sensitivity {:.3}/{:.3} | specificity {:.3}/{:.3}\x1b[0m\n", + fbm.individuals.len(), train_auc_mean, test_auc_mean, train_acc_mean, test_acc_mean, + train_sens_mean, test_sens_mean, train_spec_mean, test_spec_mean) + } else { + format!("\n\x1b[1;33mFBM mean (n={}) - AUC {:.3} | accuracy {:.3} | sensitivity {:.3} | specificity {:.3}\x1b[0m\n", + fbm.individuals.len(), train_auc_mean, train_acc_mean, train_sens_mean, train_spec_mean) + } }; str = format!("{}{}", str, fbm_text); @@ -201,11 +252,12 @@ impl Population { }; for i in 0..limit as usize { - let model_display = self.individuals[i].display( + let model_display = self.individuals[i].display_with_fit( data, data_to_test, ¶m.general.algo, param.general.threshold_ci_alpha, + ¶m.general.fit, ); let model_text = format!("\x1b[1;93m#{:?}\x1b[0m", i + 1);