Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/bayesian_mcmc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@
let [a, b, c] = ind.get_betas();
let mut log_likelihood = 0.0;
for (i_sample, z_sample) in z.iter().enumerate() {
let y_sample = self.data.y[i_sample] as f64;
let y_sample = self.data.y[i_sample];
let value = z_sample[0] * a + z_sample[1] * b + z_sample[2] * c;

if y_sample == 1.0 {
Expand Down Expand Up @@ -449,7 +449,7 @@

/// Helper structure for beta optimization (unused, kept for reference).
/// Used to minimize the negative log posterior probability for parameter optimization.
/// Holds references to the Bayesian prediction model, parameter index, current beta values, and feature groups.

Check warning on line 452 in src/bayesian_mcmc.rs

View workflow job for this annotation

GitHub Actions / lint

empty lines after doc comment

/// Implements the cost function for optimization.
/// Returns the negative log posterior probability for the proposed parameter value.
Expand Down Expand Up @@ -833,9 +833,8 @@
}
}

let y_f64: Vec<f64> = data.y.iter().map(|&v| v as f64).collect();
let y_mean = y_f64.iter().sum::<f64>() / n_samples as f64;
let y_centered: Vec<f64> = y_f64.iter().map(|v| v - y_mean).collect();
let y_mean = data.y.iter().sum::<f64>() / n_samples as f64;
let y_centered: Vec<f64> = data.y.iter().map(|v| v - y_mean).collect();

// Run LASSO at moderate alpha to select features
let w = crate::lasso::coordinate_descent_pub(&x_cols, &y_centered, 0.01, 1.0, 500, 1e-4, None);
Expand Down Expand Up @@ -1417,7 +1416,7 @@

// Each Gpredomics function currently handles class 2 as unknown
// As MCMC does not, remove unknown sample before analysis
let mut data = data.remove_class(2);
let mut data = data.remove_class(2.0);

// Selecting features
data.select_features(param);
Expand Down
8 changes: 4 additions & 4 deletions src/csv_report.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,14 +551,14 @@ pub fn export_jury_csv(exp: &Experiment, writer: &mut impl Write) -> Result<(),
// We need to recompute using the jury's predict on train data
let (train_f1, train_mcc, train_ppv, train_npv, train_gmean) = {
let (pred_classes, _scores) = jury.predict(&exp.train_data);
let filtered: Vec<(u8, u8)> = pred_classes
let filtered: Vec<(u8, f64)> = pred_classes
.iter()
.zip(exp.train_data.y.iter())
.filter(|(&p, _)| p != 2)
.map(|(&p, &y)| (p, y))
.collect();
if !filtered.is_empty() {
let (preds, trues): (Vec<u8>, Vec<u8>) = filtered.into_iter().unzip();
let (preds, trues): (Vec<u8>, Vec<f64>) = filtered.into_iter().unzip();
let (_, _, _, add) = compute_metrics_from_classes(&preds, &trues, [true; 5]);
(
add.f1_score.unwrap_or(f64::NAN),
Expand Down Expand Up @@ -586,7 +586,7 @@ pub fn export_jury_csv(exp: &Experiment, writer: &mut impl Write) -> Result<(),
test_rej,
) = if let Some(ref td) = exp.test_data {
let (pred_classes, scores) = jury.predict(td);
let filtered: Vec<(f64, u8, u8)> = scores
let filtered: Vec<(f64, u8, f64)> = scores
.iter()
.zip(pred_classes.iter())
.zip(td.y.iter())
Expand All @@ -603,7 +603,7 @@ pub fn export_jury_csv(exp: &Experiment, writer: &mut impl Write) -> Result<(),
pred_classes.iter().filter(|&&c| c == 2).count() as f64 / pred_classes.len() as f64;

if !filtered.is_empty() {
let (scores_f, preds_f, trues_f): (Vec<f64>, Vec<u8>, Vec<u8>) =
let (scores_f, preds_f, trues_f): (Vec<f64>, Vec<u8>, Vec<f64>) =
filtered.into_iter().fold(
(Vec::new(), Vec::new(), Vec::new()),
|(mut s, mut p, mut t), (sc, pr, tr)| {
Expand Down
74 changes: 39 additions & 35 deletions src/cv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,9 +741,9 @@ mod tests {
assert_eq!(cv.validation_folds[1].X, X2);
assert_eq!(cv.validation_folds[2].X, X3);

assert_eq!(cv.validation_folds[0].y, [0, 1, 1]);
assert_eq!(cv.validation_folds[1].y, [0, 1]);
assert_eq!(cv.validation_folds[2].y, [1]);
assert_eq!(cv.validation_folds[0].y, [0.0, 1.0, 1.0]);
assert_eq!(cv.validation_folds[1].y, [0.0, 1.0]);
assert_eq!(cv.validation_folds[2].y, [1.0]);

assert_eq!(
cv.validation_folds[0].samples,
Expand All @@ -766,7 +766,7 @@ mod tests {
fn test_cv_new_with_single_class() {
let mut rng = ChaCha8Rng::seed_from_u64(42);
let mut data = Data::test();
data.y = vec![0; data.y.len()];
data.y = vec![0.0; data.y.len()];

let outer_folds = 3;
let cv = CV::new(&data, outer_folds, &mut rng);
Expand All @@ -776,7 +776,7 @@ mod tests {

for fold in &cv.validation_folds {
for &label in &fold.y {
assert_eq!(label, 0);
assert_eq!(label, 0.0);
}
}
}
Expand Down Expand Up @@ -920,25 +920,25 @@ mod tests {
// Define 95% of samples as class 0, 5% as class 1
let n = data.y.len();
let n_class1 = (n as f64 * 0.05).round() as usize;
data.y = vec![0; n];
data.y = vec![0.0; n];
for i in 0..n_class1 {
data.y[i] = 1;
data.y[i] = 1.0;
}

let outer_folds = 5;
let cv = CV::new(&data, outer_folds, &mut rng);

// Each fold should contain a few class 1 samples or be empty if there are too few.
for fold in &cv.validation_folds {
let count_class1 = fold.y.iter().filter(|&&y| y == 1).count();
let count_class1 = fold.y.iter().filter(|&&y| y == 1.0).count();
assert!(count_class1 <= n_class1);
}

// Check that the overall distribution is preserved
let total_class1: usize = cv
.validation_folds
.iter()
.map(|fold| fold.y.iter().filter(|&&y| y == 1).count())
.map(|fold| fold.y.iter().filter(|&&y| y == 1.0).count())
.sum();
assert_eq!(total_class1, n_class1);
}
Expand Down Expand Up @@ -1614,13 +1614,13 @@ mod tests {
let cv = CV::new(&data, outer_folds, &mut rng);

// Count the classes in the original data
let original_class0 = data.y.iter().filter(|&&y| y == 0).count();
let original_class1 = data.y.iter().filter(|&&y| y == 1).count();
let original_class0 = data.y.iter().filter(|&&y| y == 0.0).count();
let original_class1 = data.y.iter().filter(|&&y| y == 1.0).count();

// Check the distribution in each fold
for (i, fold) in cv.validation_folds.iter().enumerate() {
let fold_class0 = fold.y.iter().filter(|&&y| y == 0).count();
let fold_class1 = fold.y.iter().filter(|&&y| y == 1).count();
let fold_class0 = fold.y.iter().filter(|&&y| y == 0.0).count();
let fold_class1 = fold.y.iter().filter(|&&y| y == 1.0).count();

// The distribution should be roughly balanced
let expected_class0 = (original_class0 + outer_folds - 1) / outer_folds;
Expand Down Expand Up @@ -1717,8 +1717,8 @@ mod tests {
let mut rng = ChaCha8Rng::seed_from_u64(42);
let mut data = Data::specific_test(30, 10);

let class0_count = data.y.iter().filter(|&&y| y == 0).count();
let _class1_count = data.y.iter().filter(|&&y| y == 1).count();
let class0_count = data.y.iter().filter(|&&y| y == 0.0).count();
let _class1_count = data.y.iter().filter(|&&y| y == 1.0).count();

let annotation_values: Vec<String> = (0..30)
.map(|i| {
Expand All @@ -1739,7 +1739,7 @@ mod tests {
let expected_size_class0 = class0_count / outer_folds;

for (i, fold) in cv.validation_folds.iter().enumerate() {
let fold_class0 = fold.y.iter().filter(|&&y| y == 0).count();
let fold_class0 = fold.y.iter().filter(|&&y| y == 0.0).count();

// Tolerance of +1/-1 due to integer divisions
assert!(
Expand Down Expand Up @@ -1771,8 +1771,8 @@ mod tests {
}

// Create exactly 30 samples of each class
let mut y: Vec<u8> = vec![0; 30];
y.extend(vec![1; 30]);
let mut y: Vec<f64> = vec![0.0; 30];
y.extend(vec![1.0; 30]);

let mut data = Data {
X,
Expand Down Expand Up @@ -1823,7 +1823,7 @@ mod tests {

for i in 0..data.sample_len {
let batch = &annot.sample_tags[&i][col_idx];
if data.y[i] == 0 {
if data.y[i] == 0.0 {
if batch == "A" {
batch_a_class0 += 1;
} else {
Expand Down Expand Up @@ -1864,7 +1864,7 @@ mod tests {
}

// Count by combination (class, batch) for this fold
if fold.y[i] == 0 {
if fold.y[i] == 0.0 {
if batch == "A" {
fold_batch_a_class0 += 1;
} else {
Expand Down Expand Up @@ -2046,23 +2046,23 @@ mod tests {
let class0_standard = cv_standard.validation_folds[i]
.y
.iter()
.filter(|&&y| y == 0)
.filter(|&&y| y == 0.0)
.count();
let class1_standard = cv_standard.validation_folds[i]
.y
.iter()
.filter(|&&y| y == 1)
.filter(|&&y| y == 1.0)
.count();

let class0_stratified = cv_stratified.validation_folds[i]
.y
.iter()
.filter(|&&y| y == 0)
.filter(|&&y| y == 0.0)
.count();
let class1_stratified = cv_stratified.validation_folds[i]
.y
.iter()
.filter(|&&y| y == 1)
.filter(|&&y| y == 1.0)
.count();

assert_eq!(
Expand Down Expand Up @@ -2366,7 +2366,7 @@ mod tests {

// Build class labels: [0,0,...0 (40x), 1,1,...1 (40x)]
data.y = (0..total_samples)
.map(|i| if i < total_samples / 2 { 0 } else { 1 })
.map(|i| if i < total_samples / 2 { 0.0 } else { 1.0 })
.collect();

// Build batch annotations: alternating A/B within each class
Expand Down Expand Up @@ -2407,12 +2407,16 @@ mod tests {
let class = fold.y[sample_idx];
let batch = &fold_annot.sample_tags[&sample_idx][col_idx];

match (class, batch.as_str()) {
(0, "A") => count_0_a += 1,
(0, "B") => count_0_b += 1,
(1, "A") => count_1_a += 1,
(1, "B") => count_1_b += 1,
_ => panic!("Unexpected class/batch combination"),
if class == 0.0 && batch == "A" {
count_0_a += 1;
} else if class == 0.0 && batch == "B" {
count_0_b += 1;
} else if class == 1.0 && batch == "A" {
count_1_a += 1;
} else if class == 1.0 && batch == "B" {
count_1_b += 1;
} else {
panic!("Unexpected class/batch combination");
}
}

Expand Down Expand Up @@ -2492,7 +2496,7 @@ mod tests {
.position(|c| c == "batch")
.unwrap();
(0..fold.sample_len)
.filter(|&i| fold.y[i] == 0 && annot.sample_tags[&i][col_idx] == "A")
.filter(|&i| fold.y[i] == 0.0 && annot.sample_tags[&i][col_idx] == "A")
.count()
})
.sum();
Expand All @@ -2518,7 +2522,7 @@ mod tests {
// Configuration :
// - Class 0: 4 samples (2 batch A, 2 batch B)
// - Class 1: 2 samples (1 batch A, 1 batch B)
data.y = vec![0, 0, 0, 0, 1, 1];
data.y = vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0];

let annotation_values = vec![
"A".to_string(),
Expand Down Expand Up @@ -2651,7 +2655,7 @@ mod tests {
fn test_cv_new_stratified_by_panic_on_missing_annotations() {
let mut data = Data::test_with_these_features(&[0, 1, 2, 3]);
data.sample_len = 5;
data.y = vec![0, 1, 0, 1, 0];
data.y = vec![0.0, 1.0, 0.0, 1.0, 0.0];
data.sample_annotations = None;

let mut rng = ChaCha8Rng::seed_from_u64(42);
Expand All @@ -2663,7 +2667,7 @@ mod tests {
fn test_cv_new_stratified_by_panic_on_incomplete_annotation_line() {
let mut data = Data::test_with_these_features(&[0, 1, 2, 3]);
data.sample_len = 4;
data.y = vec![0, 1, 0, 1];
data.y = vec![0.0, 1.0, 0.0, 1.0];

let mut sample_tags = HashMap::new();
sample_tags.insert(0, vec!["ctrl".to_string()]);
Expand Down
Loading
Loading