|
| 1 | +//! Ensemble-method utilities aligned to AFML Chapter 6. |
| 2 | +//! |
| 3 | +//! This module provides: |
| 4 | +//! - Bias/variance/noise diagnostics for ensemble forecasts. |
| 5 | +//! - Bagging mechanics (bootstrap + sequential-bootstrap wrappers). |
| 6 | +//! - Aggregation helpers (majority vote and mean probability). |
| 7 | +//! - Dependency-aware diagnostics for when bagging is likely to underperform. |
| 8 | +//! - A practical bagging-vs-boosting recommendation heuristic. |
| 9 | +
|
| 10 | +use rand::rngs::StdRng; |
| 11 | +use rand::{Rng, SeedableRng}; |
| 12 | + |
| 13 | +use crate::sampling::seq_bootstrap; |
| 14 | + |
/// Ensemble family recommended by the chapter-6 heuristics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnsembleMethod {
    /// Parallel bootstrap aggregation; primarily reduces variance.
    Bagging,
    /// Sequential error-reweighting; primarily reduces bias.
    Boosting,
}
| 20 | + |
/// Additive decomposition of ensemble forecast error
/// (`mse ≈ bias_sq + variance + noise`), as produced by
/// [`bias_variance_noise`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BiasVarianceNoise {
    /// Squared gap between the ensemble-mean forecast and the target,
    /// averaged over samples.
    pub bias_sq: f64,
    /// Average spread of the individual model forecasts around their mean.
    pub variance: f64,
    /// Residual `mse - bias_sq - variance`, clamped at zero to absorb
    /// floating-point round-off.
    pub noise: f64,
    /// Mean squared error over all (model, sample) pairs.
    pub mse: f64,
}
| 28 | + |
/// Output of [`recommend_bagging_vs_boosting`]: the recommended method plus
/// the variance figures that informed the recommendation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BaggingBoostingDecision {
    /// Heuristically recommended ensemble method.
    pub recommended: EnsembleMethod,
    /// Bagged-ensemble variance computed by `bagging_ensemble_variance`.
    pub expected_bagging_variance: f64,
    /// Single-estimator variance minus bagged variance, floored at zero.
    pub expected_variance_reduction: f64,
}
| 35 | + |
| 36 | +pub fn bias_variance_noise( |
| 37 | + y_true: &[f64], |
| 38 | + per_model_predictions: &[Vec<f64>], |
| 39 | +) -> Result<BiasVarianceNoise, String> { |
| 40 | + if y_true.is_empty() { |
| 41 | + return Err("y_true cannot be empty".to_string()); |
| 42 | + } |
| 43 | + if per_model_predictions.is_empty() { |
| 44 | + return Err("per_model_predictions cannot be empty".to_string()); |
| 45 | + } |
| 46 | + if per_model_predictions.iter().any(|row| row.len() != y_true.len()) { |
| 47 | + return Err("prediction length mismatch".to_string()); |
| 48 | + } |
| 49 | + |
| 50 | + let n_models = per_model_predictions.len() as f64; |
| 51 | + let n_samples = y_true.len() as f64; |
| 52 | + |
| 53 | + let mut bias_sq_sum = 0.0; |
| 54 | + let mut var_sum = 0.0; |
| 55 | + let mut mse_sum = 0.0; |
| 56 | + |
| 57 | + for i in 0..y_true.len() { |
| 58 | + let mut mean_pred = 0.0; |
| 59 | + for model in per_model_predictions { |
| 60 | + mean_pred += model[i]; |
| 61 | + let err = model[i] - y_true[i]; |
| 62 | + mse_sum += err * err; |
| 63 | + } |
| 64 | + mean_pred /= n_models; |
| 65 | + |
| 66 | + let bias = mean_pred - y_true[i]; |
| 67 | + bias_sq_sum += bias * bias; |
| 68 | + |
| 69 | + let mut local_var = 0.0; |
| 70 | + for model in per_model_predictions { |
| 71 | + let d = model[i] - mean_pred; |
| 72 | + local_var += d * d; |
| 73 | + } |
| 74 | + local_var /= n_models; |
| 75 | + var_sum += local_var; |
| 76 | + } |
| 77 | + |
| 78 | + let bias_sq = bias_sq_sum / n_samples; |
| 79 | + let variance = var_sum / n_samples; |
| 80 | + let mse = mse_sum / (n_samples * n_models); |
| 81 | + let noise = (mse - bias_sq - variance).max(0.0); |
| 82 | + |
| 83 | + Ok(BiasVarianceNoise { bias_sq, variance, noise, mse }) |
| 84 | +} |
| 85 | + |
| 86 | +pub fn bootstrap_sample_indices( |
| 87 | + n_samples: usize, |
| 88 | + sample_size: usize, |
| 89 | + seed: u64, |
| 90 | +) -> Result<Vec<usize>, String> { |
| 91 | + if n_samples == 0 || sample_size == 0 { |
| 92 | + return Err("n_samples and sample_size must be > 0".to_string()); |
| 93 | + } |
| 94 | + let mut rng = StdRng::seed_from_u64(seed); |
| 95 | + Ok((0..sample_size).map(|_| rng.gen_range(0..n_samples)).collect()) |
| 96 | +} |
| 97 | + |
| 98 | +pub fn sequential_bootstrap_sample_indices( |
| 99 | + ind_mat: &[Vec<u8>], |
| 100 | + sample_size: usize, |
| 101 | + seed: u64, |
| 102 | +) -> Result<Vec<usize>, String> { |
| 103 | + if sample_size == 0 { |
| 104 | + return Err("sample_size must be > 0".to_string()); |
| 105 | + } |
| 106 | + if ind_mat.is_empty() { |
| 107 | + return Err("ind_mat cannot be empty".to_string()); |
| 108 | + } |
| 109 | + let n_labels = ind_mat.first().map(|r| r.len()).unwrap_or(0); |
| 110 | + if n_labels == 0 { |
| 111 | + return Err("ind_mat must include at least one label column".to_string()); |
| 112 | + } |
| 113 | + |
| 114 | + let mut rng = StdRng::seed_from_u64(seed); |
| 115 | + let warmup: Vec<usize> = (0..sample_size).map(|_| rng.gen_range(0..n_labels)).collect(); |
| 116 | + Ok(seq_bootstrap(ind_mat, Some(sample_size), Some(warmup))) |
| 117 | +} |
| 118 | + |
/// Averages per-model regression forecasts element-wise into a single
/// ensemble forecast (bagging's aggregation step for regression).
///
/// # Errors
/// Returns `Err` on empty input, empty rows, or mismatched row lengths.
pub fn aggregate_regression_mean(per_model_predictions: &[Vec<f64>]) -> Result<Vec<f64>, String> {
    let Some(first) = per_model_predictions.first() else {
        return Err("per_model_predictions cannot be empty".to_string());
    };
    let width = first.len();
    if width == 0 {
        return Err("prediction rows cannot be empty".to_string());
    }
    if per_model_predictions.iter().any(|row| row.len() != width) {
        return Err("prediction length mismatch".to_string());
    }

    let scale = per_model_predictions.len() as f64;
    let means = (0..width)
        .map(|i| per_model_predictions.iter().map(|row| row[i]).sum::<f64>() / scale)
        .collect();
    Ok(means)
}
| 143 | + |
/// Majority-votes binary (0/1) class predictions across models. A tie
/// (exactly half the models vote 1) resolves to 1, since it satisfies
/// `ones * 2 >= n_models`.
///
/// # Errors
/// Returns `Err` on empty input, empty rows, mismatched row lengths, or any
/// label outside {0, 1}.
pub fn aggregate_classification_vote(per_model_predictions: &[Vec<u8>]) -> Result<Vec<u8>, String> {
    if per_model_predictions.is_empty() {
        return Err("per_model_predictions cannot be empty".to_string());
    }
    let width = per_model_predictions[0].len();
    if width == 0 {
        return Err("prediction rows cannot be empty".to_string());
    }
    if per_model_predictions.iter().any(|row| row.len() != width) {
        return Err("prediction length mismatch".to_string());
    }
    if per_model_predictions.iter().any(|row| row.iter().any(|&label| label > 1)) {
        return Err("classification vote expects binary labels in {0,1}".to_string());
    }

    let n_models = per_model_predictions.len();
    let voted = (0..width)
        .map(|i| {
            let ones: usize =
                per_model_predictions.iter().map(|row| usize::from(row[i])).sum();
            u8::from(ones * 2 >= n_models)
        })
        .collect();
    Ok(voted)
}
| 166 | + |
/// Soft-vote aggregation: averages per-model class-1 probabilities
/// element-wise, then thresholds the mean into hard 0/1 labels.
///
/// Returns `(mean_probabilities, labels)` where `labels[i] == 1` iff
/// `mean_probabilities[i] >= threshold`.
///
/// # Errors
/// Returns `Err` when `threshold` lies outside [0,1], the input is empty or
/// ragged, or any averaged probability falls outside [0,1].
pub fn aggregate_classification_probability_mean(
    per_model_probabilities: &[Vec<f64>],
    threshold: f64,
) -> Result<(Vec<f64>, Vec<u8>), String> {
    if !(0.0..=1.0).contains(&threshold) {
        return Err("threshold must be in [0,1]".to_string());
    }

    // Element-wise mean across models — same contract and error strings as
    // aggregate_regression_mean, inlined here.
    if per_model_probabilities.is_empty() {
        return Err("per_model_predictions cannot be empty".to_string());
    }
    let width = per_model_probabilities[0].len();
    if width == 0 {
        return Err("prediction rows cannot be empty".to_string());
    }
    if per_model_probabilities.iter().any(|row| row.len() != width) {
        return Err("prediction length mismatch".to_string());
    }
    let n_models = per_model_probabilities.len() as f64;
    let mut probs = vec![0.0; width];
    for row in per_model_probabilities {
        for (slot, value) in probs.iter_mut().zip(row) {
            *slot += *value;
        }
    }
    for slot in &mut probs {
        *slot /= n_models;
    }

    if probs.iter().any(|p| !(0.0..=1.0).contains(p)) {
        return Err("probabilities must be in [0,1]".to_string());
    }
    let labels = probs.iter().map(|p| u8::from(*p >= threshold)).collect();
    Ok((probs, labels))
}
| 181 | + |
| 182 | +pub fn average_pairwise_prediction_correlation( |
| 183 | + per_model_predictions: &[Vec<f64>], |
| 184 | +) -> Result<f64, String> { |
| 185 | + if per_model_predictions.len() < 2 { |
| 186 | + return Err("at least two model prediction rows are required".to_string()); |
| 187 | + } |
| 188 | + let n = per_model_predictions[0].len(); |
| 189 | + if n < 2 { |
| 190 | + return Err("prediction rows must have at least two samples".to_string()); |
| 191 | + } |
| 192 | + if per_model_predictions.iter().any(|row| row.len() != n) { |
| 193 | + return Err("prediction length mismatch".to_string()); |
| 194 | + } |
| 195 | + |
| 196 | + let mut corr_sum = 0.0; |
| 197 | + let mut pairs = 0usize; |
| 198 | + for i in 0..per_model_predictions.len() { |
| 199 | + for j in (i + 1)..per_model_predictions.len() { |
| 200 | + corr_sum += pearson_corr(&per_model_predictions[i], &per_model_predictions[j]); |
| 201 | + pairs += 1; |
| 202 | + } |
| 203 | + } |
| 204 | + Ok(corr_sum / pairs as f64) |
| 205 | +} |
| 206 | + |
/// Variance of a bagged ensemble of `n_estimators` equally weighted
/// estimators with average pairwise correlation `rho`:
/// `sigma^2 * (rho + (1 - rho) / N)` (AFML ch. 6). As `N` grows the value
/// floors at `rho * sigma^2`, which is why correlated learners cap
/// bagging's benefit.
///
/// # Errors
/// Returns `Err` for negative variance, correlation outside [-1,1]
/// (NaN included), or zero estimators.
pub fn bagging_ensemble_variance(
    single_estimator_variance: f64,
    average_correlation: f64,
    n_estimators: usize,
) -> Result<f64, String> {
    if single_estimator_variance < 0.0 {
        return Err("single_estimator_variance must be non-negative".to_string());
    }
    if !(-1.0..=1.0).contains(&average_correlation) {
        return Err("average_correlation must be in [-1,1]".to_string());
    }
    if n_estimators == 0 {
        return Err("n_estimators must be > 0".to_string());
    }

    let count = n_estimators as f64;
    let shrink_factor = average_correlation + (1.0 - average_correlation) / count;
    Ok(single_estimator_variance * shrink_factor)
}
| 226 | + |
| 227 | +pub fn recommend_bagging_vs_boosting( |
| 228 | + base_estimator_accuracy: f64, |
| 229 | + average_prediction_correlation: f64, |
| 230 | + label_redundancy: f64, |
| 231 | + single_estimator_variance: f64, |
| 232 | + n_estimators: usize, |
| 233 | +) -> Result<BaggingBoostingDecision, String> { |
| 234 | + if !(0.0..=1.0).contains(&base_estimator_accuracy) { |
| 235 | + return Err("base_estimator_accuracy must be in [0,1]".to_string()); |
| 236 | + } |
| 237 | + if !(0.0..=1.0).contains(&label_redundancy) { |
| 238 | + return Err("label_redundancy must be in [0,1]".to_string()); |
| 239 | + } |
| 240 | + let bag_var = bagging_ensemble_variance( |
| 241 | + single_estimator_variance, |
| 242 | + average_prediction_correlation, |
| 243 | + n_estimators, |
| 244 | + )?; |
| 245 | + let expected_reduction = (single_estimator_variance - bag_var).max(0.0); |
| 246 | + |
| 247 | + // Heuristic criteria: |
| 248 | + // - weak learners (accuracy near random) favor boosting for bias reduction. |
| 249 | + // - highly correlated learners or high label redundancy reduce bagging gains. |
| 250 | + let weak_learner = base_estimator_accuracy < 0.55; |
| 251 | + let highly_correlated = average_prediction_correlation >= 0.75; |
| 252 | + let redundant_labels = label_redundancy >= 0.70; |
| 253 | + |
| 254 | + let recommended = if weak_learner || highly_correlated || redundant_labels { |
| 255 | + EnsembleMethod::Boosting |
| 256 | + } else { |
| 257 | + EnsembleMethod::Bagging |
| 258 | + }; |
| 259 | + |
| 260 | + Ok(BaggingBoostingDecision { |
| 261 | + recommended, |
| 262 | + expected_bagging_variance: bag_var, |
| 263 | + expected_variance_reduction: expected_reduction, |
| 264 | + }) |
| 265 | +} |
| 266 | + |
| 267 | +fn pearson_corr(x: &[f64], y: &[f64]) -> f64 { |
| 268 | + let mx = x.iter().sum::<f64>() / x.len() as f64; |
| 269 | + let my = y.iter().sum::<f64>() / y.len() as f64; |
| 270 | + |
| 271 | + let mut num = 0.0; |
| 272 | + let mut den_x = 0.0; |
| 273 | + let mut den_y = 0.0; |
| 274 | + for (a, b) in x.iter().zip(y.iter()) { |
| 275 | + let dx = *a - mx; |
| 276 | + let dy = *b - my; |
| 277 | + num += dx * dy; |
| 278 | + den_x += dx * dx; |
| 279 | + den_y += dy * dy; |
| 280 | + } |
| 281 | + if den_x == 0.0 || den_y == 0.0 { |
| 282 | + 0.0 |
| 283 | + } else { |
| 284 | + num / (den_x.sqrt() * den_y.sqrt()) |
| 285 | + } |
| 286 | +} |
0 commit comments