Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions src/cv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;

use log::info;
use log::{info, warn};
use std::collections::HashSet;
use std::sync::atomic::AtomicBool;

/// Cross-validation dataset implementation for machine learning workflows.
Expand Down Expand Up @@ -42,19 +43,33 @@ impl CV {
/// * `folds` - Number of validation folds to create
/// * `rng` - Random number generator for stratified sampling
pub fn new(data: &Data, folds: usize, rng: &mut ChaCha8Rng) -> CV {
let (indices_class1, indices_class0) = utils::stratify_indices_by_class(&data.y);

let indices_class0_folds =
utils::split_into_balanced_random_chunks(indices_class0, folds, rng);
let indices_class1_folds =
utils::split_into_balanced_random_chunks(indices_class1, folds, rng);

let validation_folds: Vec<Data> = indices_class0_folds
.into_iter()
.zip(indices_class1_folds.into_iter())
.map(|(i1, i2)| i1.into_iter().chain(i2).collect::<Vec<usize>>())
.map(|i| data.subset(i))
.collect();
// Detect regression mode: if y has more than 3 unique values, use random splitting
let unique_y: HashSet<u64> = data.y.iter().map(|v| v.to_bits()).collect();
let is_regression = unique_y.len() > 3;

let validation_folds: Vec<Data> = if is_regression {
warn!(
"Continuous y detected ({} unique values) — using random (non-stratified) CV splitting",
unique_y.len()
);
let all_indices: Vec<usize> = (0..data.sample_len).collect();
let all_folds = utils::split_into_balanced_random_chunks(all_indices, folds, rng);
all_folds.into_iter().map(|i| data.subset(i)).collect()
} else {
let (indices_class1, indices_class0) = utils::stratify_indices_by_class(&data.y);

let indices_class0_folds =
utils::split_into_balanced_random_chunks(indices_class0, folds, rng);
let indices_class1_folds =
utils::split_into_balanced_random_chunks(indices_class1, folds, rng);

indices_class0_folds
.into_iter()
.zip(indices_class1_folds.into_iter())
.map(|(i1, i2)| i1.into_iter().chain(i2).collect::<Vec<usize>>())
.map(|i| data.subset(i))
.collect()
};

let mut training_sets: Vec<Data> = Vec::new();
for i in 0..folds {
Expand Down
234 changes: 234 additions & 0 deletions src/individual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::param::FitFunction;
use crate::utils::serde_json_hashmap_numeric;
use crate::utils::{compute_auc_from_value, compute_roc_and_metrics_from_value};
use crate::utils::{compute_metrics_from_value, generate_random_vector, shuffle_row};
use crate::utils::{mutual_information, neg_rmse, pearson_correlation, spearman_correlation};
use crate::Population;
use log::{debug, error};
use rand::seq::SliceRandom; // Provides the `choose_multiple` method
Expand Down Expand Up @@ -535,6 +536,239 @@ impl Individual {
formatted_string
}

/// Generates a display string that is regression-aware.
///
/// For regression fit functions (`spearman`, `pearson`, `rmse`,
/// `mutual_information`) this shows the regression metric instead of the
/// AUC/accuracy/sensitivity/specificity line; for classification fit
/// functions it delegates to the standard `display()` method.
///
/// # Arguments
///
/// * `data` - The training data used for evaluation (feature names are
///   resolved against `data.features`).
/// * `data_to_test` - Optional test data; when present the fit is re-computed
///   on it and shown as `train/test`.
/// * `algo` - The algorithm name ("ga", "beam" or "mcmc"), controlling the
///   epoch/step annotation.
/// * `ci_alpha` - Confidence interval alpha, forwarded to `display()` in the
///   classification case (unused for regression output).
/// * `fit_function` - The fit function used for model evaluation.
pub fn display_with_fit(
    &self,
    data: &Data,
    data_to_test: Option<&Data>,
    algo: &String,
    ci_alpha: f64,
    fit_function: &FitFunction,
) -> String {
    // Resolve the regression metric once; any classification fit keeps the
    // legacy rendering untouched (no separate is_regression flag + unreachable!).
    let regression_fn: fn(&[f64], &[f64]) -> f64 = match fit_function {
        FitFunction::spearman => spearman_correlation,
        FitFunction::pearson => pearson_correlation,
        FitFunction::rmse => neg_rmse,
        FitFunction::mutual_information => mutual_information,
        _ => return self.display(data, data_to_test, algo, ci_alpha),
    };

    let fit_name = format!("{:?}", fit_function);

    let algo_str = match algo.as_str() {
        "ga" => format!(" [gen:{}] ", self.epoch),
        "beam" => " ".to_string(),
        "mcmc" => format!(" [MCMC step: {}] ", self.epoch),
        _ => " [unknown] ".to_string(),
    };

    // First line: model shape + fit on train (and, if provided, on test).
    let metrics = match data_to_test {
        Some(test_data) => {
            let test_scores = self.evaluate(test_data);
            let test_fit = regression_fn(&test_scores, &test_data.y);
            format!(
                "{}:{} [k={}]{}[fit:{:.3}] {} {:.3}/{:.3}",
                self.get_language(),
                self.get_data_type(),
                self.features.len(),
                algo_str,
                self.fit,
                fit_name,
                self.fit,
                test_fit
            )
        }
        None => format!(
            "{}:{} [k={}]{}[fit:{:.3}] {} {:.3}",
            self.get_language(),
            self.get_data_type(),
            self.features.len(),
            algo_str,
            self.fit,
            fit_name,
            self.fit
        ),
    };

    // Render one feature term. `positive` selects the ANSI colour (96 = cyan
    // for positive coefficients, 95 = magenta for negative) and which unit
    // coefficient is elided; negative coefficients display their magnitude.
    // This replaces two near-identical 30-line map bodies from before.
    let render_feature = |index: &usize, coef: &i8, positive: bool| -> String {
        let color = if positive { "96" } else { "95" };
        let unit: i8 = if positive { 1 } else { -1 };
        let mut s = format!("\x1b[{}m{}\x1b[0m", color, data.features[*index]);
        if self.data_type == PREVALENCE_TYPE {
            s = format!("{}⁰", s);
        }
        if self.language == POW2_LANG && *coef != unit {
            if self.data_type == LOG_TYPE {
                // In log space a power-of-two coefficient is an exponent.
                s = format!("{}^{}", s, coef.abs());
            } else {
                s = format!("{}*{}", coef.abs(), s);
            }
        }
        s
    };

    // Sort by feature index first so coefficient ties keep a stable order,
    // then order by effect size (most positive first / most negative first).
    let mut sorted_features: Vec<_> = self.features.iter().collect();
    sorted_features.sort_by(|a, b| a.0.cmp(b.0));

    let mut positive_features: Vec<_> = sorted_features
        .iter()
        .filter(|&&(_, &coef)| coef > 0)
        .collect();
    let mut negative_features: Vec<_> = sorted_features
        .iter()
        .filter(|&&(_, &coef)| coef < 0)
        .collect();

    positive_features.sort_by(|a, b| b.1.cmp(a.1));
    negative_features.sort_by(|a, b| a.1.cmp(b.1));

    let positive_str: Vec<String> = positive_features
        .iter()
        .map(|&&(index, coef)| render_feature(index, coef, true))
        .collect();
    // Mutable so the epsilon term can be appended in place (the previous
    // version cloned the whole Vec just to push one element).
    let mut negative_str: Vec<String> = negative_features
        .iter()
        .map(|&&(index, coef)| render_feature(index, coef, false))
        .collect();

    if self.language == RATIO_LANG && self.data_type != LOG_TYPE {
        // NOTE(review): "{:2e}" sets a minimum *width* of 2, not a precision;
        // "{:.2e}" may have been intended — kept as-is for output parity with
        // display(). Confirm against the classification formatter.
        negative_str.push(format!("{:2e}", self.epsilon));
    }

    // An empty side is shown as the neutral element "0" (or "1" in log space).
    let positive_str_final = if positive_str.is_empty() {
        vec!["0".to_string()]
    } else {
        positive_str
    };
    let negative_str_final = if negative_str.is_empty() {
        vec!["0".to_string()]
    } else {
        negative_str
    };

    let (positive_joined, negative_joined) = if self.data_type == LOG_TYPE {
        // Log space: each side is the log of a product of shifted terms.
        let log_terms = |parts: &[String]| -> String {
            let terms: Vec<String> = parts
                .iter()
                .map(|f| {
                    if f != "0" {
                        format!("{}⁺", f)
                    } else {
                        "1".to_string()
                    }
                })
                .collect();
            format!("ln({})", terms.join(" × "))
        };
        (
            log_terms(&positive_str_final),
            log_terms(&negative_str_final),
        )
    } else {
        (
            format!("({})", positive_str_final.join(" + ")),
            format!("({})", negative_str_final.join(" + ")),
        )
    };

    // Regression has no class threshold: the second line is just the score
    // expression. The empty trailing part mirrors display()'s layout.
    let second_line_first_part = "score =".to_string();
    let second_line_second_part = String::new();

    if self.language == BINARY_LANG {
        format!(
            "{}\n{} {} {}",
            metrics, second_line_first_part, positive_joined, second_line_second_part
        )
    } else if self.language == TERNARY_LANG || self.language == POW2_LANG {
        format!(
            "{}\n{} {} - {} {}",
            metrics,
            second_line_first_part,
            positive_joined,
            negative_joined,
            second_line_second_part
        )
    } else if self.language == RATIO_LANG {
        if self.data_type == LOG_TYPE {
            format!(
                "{}\n{} {} - {} - {} {}",
                metrics,
                second_line_first_part,
                positive_joined,
                negative_joined,
                self.epsilon,
                second_line_second_part
            )
        } else {
            format!(
                "{}\n{} {} / {} {}",
                metrics,
                second_line_first_part,
                positive_joined,
                negative_joined,
                second_line_second_part
            )
        }
    } else {
        // Unknown language: fall back to the Debug representation.
        format!(
            "{}\n{} {:?} {}",
            metrics, second_line_first_part, self, second_line_second_part
        )
    }
}

/// Computes the hash of the Individual based on its features and betas
///
/// # Examples
Expand Down
39 changes: 35 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ use population::Population;
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;

use log::{debug, error, warn};
use log::{debug, error, info, warn};
use param::FitFunction;

use std::sync::atomic::AtomicBool;
use std::sync::Arc;
Expand Down Expand Up @@ -293,7 +294,17 @@ pub fn run(param: &Param, running: Arc<AtomicBool>) -> Experiment {

// Voting
if param.voting.vote {
exp.compute_voting();
if matches!(
param.general.fit,
FitFunction::spearman
| FitFunction::pearson
| FitFunction::rmse
| FitFunction::mutual_information
) {
info!("Voting is not applicable for regression — skipping");
} else {
exp.compute_voting();
}
} else {
cinfo!(
param.general.display_colorful,
Expand Down Expand Up @@ -447,7 +458,17 @@ pub fn run_on_data(

// Voting
if param.voting.vote && exp.parameters.general.algo != "mcmc" {
exp.compute_voting();
if matches!(
param.general.fit,
FitFunction::spearman
| FitFunction::pearson
| FitFunction::rmse
| FitFunction::mutual_information
) {
info!("Voting is not applicable for regression — skipping");
} else {
exp.compute_voting();
}
} else {
cinfo!(
param.general.display_colorful,
Expand Down Expand Up @@ -608,7 +629,17 @@ pub fn run_pop_and_data(

// Voting
if param.voting.vote && exp.parameters.general.algo != "mcmc" {
exp.compute_voting();
if matches!(
param.general.fit,
FitFunction::spearman
| FitFunction::pearson
| FitFunction::rmse
| FitFunction::mutual_information
) {
info!("Voting is not applicable for regression — skipping");
} else {
exp.compute_voting();
}
} else {
cinfo!(
param.general.display_colorful,
Expand Down
Loading
Loading