Skip to content

Commit aaa38c6

Browse files
authored
Merge pull request #4 from Open-Quant/feat/oq-mef-1-ensemble-methods
feat: implement AFML Ch.6 ensemble_methods module (OQ-mef.1)
2 parents ec40d90 + 95e44c9 commit aaa38c6

5 files changed

Lines changed: 437 additions & 0 deletions

File tree

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
//! Ensemble-method utilities aligned to AFML Chapter 6.
2+
//!
3+
//! This module provides:
4+
//! - Bias/variance/noise diagnostics for ensemble forecasts.
5+
//! - Bagging mechanics (bootstrap + sequential-bootstrap wrappers).
6+
//! - Aggregation helpers (majority vote and mean probability).
7+
//! - Dependency-aware diagnostics for when bagging is likely to underperform.
8+
//! - A practical bagging-vs-boosting recommendation heuristic.
9+
10+
use rand::rngs::StdRng;
11+
use rand::{Rng, SeedableRng};
12+
13+
use crate::sampling::seq_bootstrap;
14+
15+
/// Ensemble strategies compared by the AFML Chapter 6 heuristics in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnsembleMethod {
    /// Bootstrap aggregation of independently trained learners (variance reduction).
    Bagging,
    /// Sequential reweighting of weak learners (bias reduction).
    Boosting,
}
20+
21+
/// Output of [`bias_variance_noise`]: an error decomposition averaged over samples.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BiasVarianceNoise {
    /// Mean squared bias of the ensemble-mean forecast against the targets.
    pub bias_sq: f64,
    /// Mean variance of individual model forecasts around the ensemble mean.
    pub variance: f64,
    /// Residual `mse - bias_sq - variance`, clipped at zero.
    pub noise: f64,
    /// Mean squared error averaged over both samples and models.
    pub mse: f64,
}
28+
29+
/// Result of [`recommend_bagging_vs_boosting`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BaggingBoostingDecision {
    /// The heuristically preferred ensemble method.
    pub recommended: EnsembleMethod,
    /// Bagged-ensemble variance predicted by [`bagging_ensemble_variance`].
    pub expected_bagging_variance: f64,
    /// Non-negative drop from the single-estimator variance to the bagged variance.
    pub expected_variance_reduction: f64,
}
35+
36+
pub fn bias_variance_noise(
37+
y_true: &[f64],
38+
per_model_predictions: &[Vec<f64>],
39+
) -> Result<BiasVarianceNoise, String> {
40+
if y_true.is_empty() {
41+
return Err("y_true cannot be empty".to_string());
42+
}
43+
if per_model_predictions.is_empty() {
44+
return Err("per_model_predictions cannot be empty".to_string());
45+
}
46+
if per_model_predictions.iter().any(|row| row.len() != y_true.len()) {
47+
return Err("prediction length mismatch".to_string());
48+
}
49+
50+
let n_models = per_model_predictions.len() as f64;
51+
let n_samples = y_true.len() as f64;
52+
53+
let mut bias_sq_sum = 0.0;
54+
let mut var_sum = 0.0;
55+
let mut mse_sum = 0.0;
56+
57+
for i in 0..y_true.len() {
58+
let mut mean_pred = 0.0;
59+
for model in per_model_predictions {
60+
mean_pred += model[i];
61+
let err = model[i] - y_true[i];
62+
mse_sum += err * err;
63+
}
64+
mean_pred /= n_models;
65+
66+
let bias = mean_pred - y_true[i];
67+
bias_sq_sum += bias * bias;
68+
69+
let mut local_var = 0.0;
70+
for model in per_model_predictions {
71+
let d = model[i] - mean_pred;
72+
local_var += d * d;
73+
}
74+
local_var /= n_models;
75+
var_sum += local_var;
76+
}
77+
78+
let bias_sq = bias_sq_sum / n_samples;
79+
let variance = var_sum / n_samples;
80+
let mse = mse_sum / (n_samples * n_models);
81+
let noise = (mse - bias_sq - variance).max(0.0);
82+
83+
Ok(BiasVarianceNoise { bias_sq, variance, noise, mse })
84+
}
85+
86+
pub fn bootstrap_sample_indices(
87+
n_samples: usize,
88+
sample_size: usize,
89+
seed: u64,
90+
) -> Result<Vec<usize>, String> {
91+
if n_samples == 0 || sample_size == 0 {
92+
return Err("n_samples and sample_size must be > 0".to_string());
93+
}
94+
let mut rng = StdRng::seed_from_u64(seed);
95+
Ok((0..sample_size).map(|_| rng.gen_range(0..n_samples)).collect())
96+
}
97+
98+
pub fn sequential_bootstrap_sample_indices(
99+
ind_mat: &[Vec<u8>],
100+
sample_size: usize,
101+
seed: u64,
102+
) -> Result<Vec<usize>, String> {
103+
if sample_size == 0 {
104+
return Err("sample_size must be > 0".to_string());
105+
}
106+
if ind_mat.is_empty() {
107+
return Err("ind_mat cannot be empty".to_string());
108+
}
109+
let n_labels = ind_mat.first().map(|r| r.len()).unwrap_or(0);
110+
if n_labels == 0 {
111+
return Err("ind_mat must include at least one label column".to_string());
112+
}
113+
114+
let mut rng = StdRng::seed_from_u64(seed);
115+
let warmup: Vec<usize> = (0..sample_size).map(|_| rng.gen_range(0..n_labels)).collect();
116+
Ok(seq_bootstrap(ind_mat, Some(sample_size), Some(warmup)))
117+
}
118+
119+
/// Averages predictions element-wise across model rows — the bagging
/// aggregation rule for regression outputs.
///
/// # Errors
/// Fails when no model rows are supplied, rows are empty, or rows disagree in
/// length.
pub fn aggregate_regression_mean(per_model_predictions: &[Vec<f64>]) -> Result<Vec<f64>, String> {
    let width = match per_model_predictions.first() {
        Some(row) => row.len(),
        None => return Err("per_model_predictions cannot be empty".to_string()),
    };
    if width == 0 {
        return Err("prediction rows cannot be empty".to_string());
    }
    if per_model_predictions.iter().any(|row| row.len() != width) {
        return Err("prediction length mismatch".to_string());
    }

    let denom = per_model_predictions.len() as f64;
    // Column-wise mean: sum each position over the model rows, then scale.
    let means = (0..width)
        .map(|col| per_model_predictions.iter().map(|row| row[col]).sum::<f64>() / denom)
        .collect();
    Ok(means)
}
143+
144+
/// Majority vote over binary (0/1) per-model label rows; an exact tie resolves
/// to the positive label 1.
///
/// # Errors
/// Fails when no rows are supplied, rows are empty, rows disagree in length,
/// or any label is outside `{0, 1}`.
pub fn aggregate_classification_vote(per_model_predictions: &[Vec<u8>]) -> Result<Vec<u8>, String> {
    let width = match per_model_predictions.first() {
        Some(row) => row.len(),
        None => return Err("per_model_predictions cannot be empty".to_string()),
    };
    if width == 0 {
        return Err("prediction rows cannot be empty".to_string());
    }
    if per_model_predictions.iter().any(|row| row.len() != width) {
        return Err("prediction length mismatch".to_string());
    }
    if per_model_predictions.iter().any(|row| row.iter().any(|&label| label > 1)) {
        return Err("classification vote expects binary labels in {0,1}".to_string());
    }

    let n_models = per_model_predictions.len();
    let voted = (0..width)
        .map(|col| {
            let ones: usize = per_model_predictions.iter().map(|row| row[col] as usize).sum();
            // At least half the models voting 1 (ties included) wins the column.
            u8::from(2 * ones >= n_models)
        })
        .collect();
    Ok(voted)
}
166+
167+
pub fn aggregate_classification_probability_mean(
168+
per_model_probabilities: &[Vec<f64>],
169+
threshold: f64,
170+
) -> Result<(Vec<f64>, Vec<u8>), String> {
171+
if !(0.0..=1.0).contains(&threshold) {
172+
return Err("threshold must be in [0,1]".to_string());
173+
}
174+
let probs = aggregate_regression_mean(per_model_probabilities)?;
175+
if probs.iter().any(|p| !(0.0..=1.0).contains(p)) {
176+
return Err("probabilities must be in [0,1]".to_string());
177+
}
178+
let labels = probs.iter().map(|p| if *p >= threshold { 1 } else { 0 }).collect();
179+
Ok((probs, labels))
180+
}
181+
182+
pub fn average_pairwise_prediction_correlation(
183+
per_model_predictions: &[Vec<f64>],
184+
) -> Result<f64, String> {
185+
if per_model_predictions.len() < 2 {
186+
return Err("at least two model prediction rows are required".to_string());
187+
}
188+
let n = per_model_predictions[0].len();
189+
if n < 2 {
190+
return Err("prediction rows must have at least two samples".to_string());
191+
}
192+
if per_model_predictions.iter().any(|row| row.len() != n) {
193+
return Err("prediction length mismatch".to_string());
194+
}
195+
196+
let mut corr_sum = 0.0;
197+
let mut pairs = 0usize;
198+
for i in 0..per_model_predictions.len() {
199+
for j in (i + 1)..per_model_predictions.len() {
200+
corr_sum += pearson_corr(&per_model_predictions[i], &per_model_predictions[j]);
201+
pairs += 1;
202+
}
203+
}
204+
Ok(corr_sum / pairs as f64)
205+
}
206+
207+
/// Variance of an equally-weighted bagged ensemble of `n_estimators` learners
/// sharing variance `single_estimator_variance` and average pairwise
/// correlation `average_correlation`:
/// `sigma^2 * (rho + (1 - rho) / N)` (AFML Ch. 6).
///
/// # Errors
/// Fails on a negative variance, a correlation outside `[-1, 1]`, or zero
/// estimators.
pub fn bagging_ensemble_variance(
    single_estimator_variance: f64,
    average_correlation: f64,
    n_estimators: usize,
) -> Result<f64, String> {
    if single_estimator_variance < 0.0 {
        return Err("single_estimator_variance must be non-negative".to_string());
    }
    if !(-1.0..=1.0).contains(&average_correlation) {
        return Err("average_correlation must be in [-1,1]".to_string());
    }
    if n_estimators == 0 {
        return Err("n_estimators must be > 0".to_string());
    }

    let count = n_estimators as f64;
    // Correlation floor plus the 1/N shrinkage on the uncorrelated part.
    let shrink_factor = average_correlation + (1.0 - average_correlation) / count;
    Ok(single_estimator_variance * shrink_factor)
}
226+
227+
pub fn recommend_bagging_vs_boosting(
228+
base_estimator_accuracy: f64,
229+
average_prediction_correlation: f64,
230+
label_redundancy: f64,
231+
single_estimator_variance: f64,
232+
n_estimators: usize,
233+
) -> Result<BaggingBoostingDecision, String> {
234+
if !(0.0..=1.0).contains(&base_estimator_accuracy) {
235+
return Err("base_estimator_accuracy must be in [0,1]".to_string());
236+
}
237+
if !(0.0..=1.0).contains(&label_redundancy) {
238+
return Err("label_redundancy must be in [0,1]".to_string());
239+
}
240+
let bag_var = bagging_ensemble_variance(
241+
single_estimator_variance,
242+
average_prediction_correlation,
243+
n_estimators,
244+
)?;
245+
let expected_reduction = (single_estimator_variance - bag_var).max(0.0);
246+
247+
// Heuristic criteria:
248+
// - weak learners (accuracy near random) favor boosting for bias reduction.
249+
// - highly correlated learners or high label redundancy reduce bagging gains.
250+
let weak_learner = base_estimator_accuracy < 0.55;
251+
let highly_correlated = average_prediction_correlation >= 0.75;
252+
let redundant_labels = label_redundancy >= 0.70;
253+
254+
let recommended = if weak_learner || highly_correlated || redundant_labels {
255+
EnsembleMethod::Boosting
256+
} else {
257+
EnsembleMethod::Bagging
258+
};
259+
260+
Ok(BaggingBoostingDecision {
261+
recommended,
262+
expected_bagging_variance: bag_var,
263+
expected_variance_reduction: expected_reduction,
264+
})
265+
}
266+
267+
/// Sample Pearson correlation of two series; returns 0.0 when either series
/// is constant (zero denominator). Callers in this module guarantee matching,
/// non-empty lengths.
fn pearson_corr(x: &[f64], y: &[f64]) -> f64 {
    let mean_x = x.iter().sum::<f64>() / x.len() as f64;
    let mean_y = y.iter().sum::<f64>() / y.len() as f64;

    let mut covariance = 0.0;
    let mut var_x = 0.0;
    let mut var_y = 0.0;
    for (&a, &b) in x.iter().zip(y) {
        let dx = a - mean_x;
        let dy = b - mean_y;
        covariance += dx * dy;
        var_x += dx * dx;
        var_y += dy * dy;
    }
    // A degenerate (constant) series has no defined correlation; report 0.
    if var_x == 0.0 || var_y == 0.0 {
        return 0.0;
    }
    covariance / (var_x.sqrt() * var_y.sqrt())
}

crates/openquant/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pub mod codependence;
55
pub mod cross_validation;
66
pub mod data_structures;
77
pub mod ef3m;
8+
pub mod ensemble_methods;
89
pub mod etf_trick;
910
pub mod feature_importance;
1011
pub mod filters;
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
use openquant::ensemble_methods::{
2+
aggregate_classification_probability_mean, aggregate_classification_vote,
3+
aggregate_regression_mean, average_pairwise_prediction_correlation, bagging_ensemble_variance,
4+
bias_variance_noise, bootstrap_sample_indices, recommend_bagging_vs_boosting,
5+
sequential_bootstrap_sample_indices, EnsembleMethod,
6+
};
7+
8+
#[test]
fn test_bias_variance_noise_decomposition() {
    // Three models' forecasts for four binary-target samples.
    let y = vec![1.0, 0.0, 1.0, 0.0];
    let preds = vec![vec![0.9, 0.1, 0.8, 0.2], vec![0.8, 0.2, 0.7, 0.3], vec![1.0, 0.0, 0.9, 0.1]];

    let out = bias_variance_noise(&y, &preds).unwrap();
    // Every component of the decomposition is non-negative by construction.
    assert!(out.bias_sq >= 0.0);
    assert!(out.variance >= 0.0);
    assert!(out.noise >= 0.0);
    assert!(out.mse >= 0.0);

    // The components must sum back to the total MSE (noise absorbs the rest).
    let lhs = out.bias_sq + out.variance + out.noise;
    assert!((lhs - out.mse).abs() < 1e-10);
}
22+
23+
#[test]
fn test_bootstrap_and_sequential_bootstrap_shapes() {
    // Plain bootstrap: returns the requested number of in-range indices.
    let b = bootstrap_sample_indices(10, 6, 7).unwrap();
    assert_eq!(b.len(), 6);
    assert!(b.iter().all(|v| *v < 10));

    // Sequential bootstrap on a 3x4 indicator matrix: same shape guarantees —
    // indices are bounded by the label-column count, not the row count.
    let ind_mat = vec![vec![1, 0, 1, 0], vec![0, 1, 0, 1], vec![1, 1, 0, 0]];
    let sb = sequential_bootstrap_sample_indices(&ind_mat, 8, 11).unwrap();
    assert_eq!(sb.len(), 8);
    assert!(sb.iter().all(|v| *v < ind_mat[0].len()));
}
34+
35+
#[test]
fn test_aggregation_helpers() {
    // Element-wise mean across model rows.
    let reg = aggregate_regression_mean(&[vec![1.0, 3.0], vec![3.0, 1.0]]).unwrap();
    assert_eq!(reg, vec![2.0, 2.0]);

    // Majority vote over binary labels: each column here has a 2-of-3 majority.
    let vote =
        aggregate_classification_vote(&[vec![1, 0, 1], vec![1, 1, 0], vec![0, 1, 1]]).unwrap();
    assert_eq!(vote, vec![1, 1, 1]);

    // Mean probability, then thresholding at 0.5 into hard 0/1 labels.
    let (prob, labels) = aggregate_classification_probability_mean(
        &[vec![0.9, 0.2], vec![0.7, 0.4], vec![0.8, 0.3]],
        0.5,
    )
    .unwrap();
    assert!((prob[0] - 0.8).abs() < 1e-12);
    assert!((prob[1] - 0.3).abs() < 1e-12);
    assert_eq!(labels, vec![1, 0]);
}
53+
54+
#[test]
fn test_variance_reduction_and_redundancy_failure_mode() {
    // Uncorrelated learners: variance shrinks by the full 1/N factor.
    let low_corr = bagging_ensemble_variance(1.0, 0.0, 10).unwrap();
    assert!((low_corr - 0.1).abs() < 1e-12);

    // Highly correlated learners: bagging barely helps — the AFML failure mode.
    let high_corr = bagging_ensemble_variance(1.0, 0.95, 10).unwrap();
    assert!(high_corr > 0.9);
    assert!(high_corr > low_corr);
}
63+
64+
#[test]
fn test_pairwise_correlation_and_strategy_recommendation() {
    // Near-identical weak forecasts => pairwise correlation close to 1.
    let weak_preds = vec![
        vec![0.50, 0.52, 0.48, 0.50],
        vec![0.51, 0.53, 0.49, 0.51],
        vec![0.49, 0.51, 0.47, 0.49],
    ];
    let corr = average_pairwise_prediction_correlation(&weak_preds).unwrap();
    assert!(corr > 0.95);

    // Weak, correlated, redundant setup should point to boosting.
    let weak = recommend_bagging_vs_boosting(0.53, corr, 0.8, 1.0, 16).unwrap();
    assert_eq!(weak.recommended, EnsembleMethod::Boosting);

    // Strong, diverse learners should point to bagging with real variance gains.
    let strong_diverse = recommend_bagging_vs_boosting(0.68, 0.15, 0.25, 1.0, 16).unwrap();
    assert_eq!(strong_diverse.recommended, EnsembleMethod::Bagging);
    assert!(strong_diverse.expected_variance_reduction > 0.0);
}

0 commit comments

Comments
 (0)