Skip to content

Commit 59ac6fa

Browse files
committed
feat(cross_validation): add purged CV diagnostics and CPCV paths
1 parent 43ae006 commit 59ac6fa

2 files changed

Lines changed: 324 additions & 83 deletions

File tree

crates/openquant/src/cross_validation.rs

Lines changed: 235 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
use chrono::NaiveDateTime;
2+
use itertools::Itertools;
3+
use std::collections::HashSet;
24

35
/// Simple classifier interface for cross-validation.
46
pub trait SimpleClassifier {
@@ -69,10 +71,7 @@ pub fn ml_get_train_times(
6971
for (start, end) in info_sets {
7072
let mut keep = true;
7173
for (test_start, test_end) in test_times {
72-
let start_in = *start >= *test_start && *start <= *test_end;
73-
let end_in = *end >= *test_start && *end <= *test_end;
74-
let envelop = *start <= *test_start && *end >= *test_end;
75-
if start_in || end_in || envelop {
74+
if intervals_overlap((*start, *end), (*test_start, *test_end)) {
7675
keep = false;
7776
break;
7877
}
@@ -84,6 +83,29 @@ pub fn ml_get_train_times(
8483
out
8584
}
8685

86+
/// Bookkeeping recorded while a single purged train/test split is built.
#[derive(Debug, Clone, PartialEq, Eq)]
87+
pub struct PurgedSplitDiagnostics {
88+
/// Identifier handed to `build_split`: the fold position for plain k-fold splits,
/// or the path id for CPCV paths.
pub split_id: usize,
89+
/// Half-open `(start, stop)` index ranges covering the test indices, as produced by
/// `contiguous_ranges`.
pub test_ranges: Vec<(usize, usize)>,
90+
/// Non-test sample indices removed because their information interval overlaps at
/// least one test interval, in ascending order.
pub purged_indices: Vec<usize>,
91+
/// Non-test sample indices that fall inside the embargo window following a test range,
/// sorted ascending.
pub embargo_indices: Vec<usize>,
92+
/// Number of remaining training samples whose interval still overlaps a test interval,
/// as reported by `count_train_test_overlaps`.
pub overlap_count_after_purge: usize,
93+
}
94+
95+
/// One train/test partition together with the diagnostics gathered while building it.
#[derive(Debug, Clone, PartialEq, Eq)]
96+
pub struct PurgedSplit {
97+
/// Sample indices left for training: not in the test set, not purged, not embargoed.
pub train_indices: Vec<usize>,
98+
/// Sample indices held out for testing.
pub test_indices: Vec<usize>,
99+
/// Details about what was purged and embargoed for this split.
pub diagnostics: PurgedSplitDiagnostics,
100+
}
101+
102+
/// One combination of test folds produced by `PurgedKFold::cpcv_paths`.
#[derive(Debug, Clone, PartialEq, Eq)]
103+
pub struct CpcvPath {
104+
/// Position of this combination in the enumeration order of fold combinations.
pub path_id: usize,
105+
/// Ids of the folds whose samples make up the test set for this path.
pub test_fold_ids: Vec<usize>,
106+
/// Purged train/test split built from the union of the selected test folds.
pub split: PurgedSplit,
107+
}
108+
87109
pub struct PurgedKFold {
88110
n_splits: usize,
89111
samples_info_sets: Vec<(NaiveDateTime, NaiveDateTime)>,
@@ -96,64 +118,229 @@ impl PurgedKFold {
96118
samples_info_sets: Vec<(NaiveDateTime, NaiveDateTime)>,
97119
pct_embargo: f64,
98120
) -> Result<Self, String> {
121+
if n_splits < 2 {
122+
return Err("n_splits must be at least 2".into());
123+
}
99124
if samples_info_sets.is_empty() {
100125
return Err("samples_info_sets cannot be empty".into());
101126
}
127+
if n_splits > samples_info_sets.len() {
128+
return Err("n_splits cannot exceed sample count".into());
129+
}
130+
if !(0.0..1.0).contains(&pct_embargo) {
131+
return Err("pct_embargo must be in [0.0, 1.0)".into());
132+
}
133+
for (idx, (start, end)) in samples_info_sets.iter().enumerate() {
134+
if start > end {
135+
return Err(format!("invalid information set at index {idx}: start > end"));
136+
}
137+
}
138+
102139
Ok(Self { n_splits, samples_info_sets, pct_embargo })
103140
}
104141

105142
pub fn split(&self, n_samples: usize) -> Result<Vec<(Vec<usize>, Vec<usize>)>, String> {
143+
let splits = self.split_with_diagnostics(n_samples)?;
144+
Ok(splits.into_iter().map(|s| (s.train_indices, s.test_indices)).collect())
145+
}
146+
147+
/// Builds one purged split per contiguous fold, keeping the diagnostics for each.
///
/// The fold position (starting at 0) is used as the split id.
///
/// # Errors
/// Returns an error when `n_samples` does not match the stored information sets, or
/// when building any individual split fails; the first failure stops the iteration.
pub fn split_with_diagnostics(&self, n_samples: usize) -> Result<Vec<PurgedSplit>, String> {
    self.validate_n_samples(n_samples)?;
    contiguous_fold_bounds(n_samples, self.n_splits)
        .into_iter()
        .enumerate()
        .map(|(split_id, (lo, hi))| self.build_split(split_id, (lo..hi).collect()))
        .collect()
}
158+
159+
pub fn cpcv_paths(
160+
&self,
161+
n_samples: usize,
162+
n_test_splits: usize,
163+
) -> Result<Vec<CpcvPath>, String> {
164+
self.validate_n_samples(n_samples)?;
165+
if n_test_splits == 0 || n_test_splits >= self.n_splits {
166+
return Err("n_test_splits must be in [1, n_splits)".into());
167+
}
168+
169+
let folds = contiguous_fold_bounds(n_samples, self.n_splits);
170+
let mut paths = Vec::new();
171+
172+
for (path_id, test_fold_ids) in (0..self.n_splits).combinations(n_test_splits).enumerate() {
173+
let mut test_indices = Vec::new();
174+
for fold_id in &test_fold_ids {
175+
let (start, stop) = folds[*fold_id];
176+
test_indices.extend(start..stop);
177+
}
178+
let split = self.build_split(path_id, test_indices)?;
179+
paths.push(CpcvPath { path_id, test_fold_ids, split });
180+
}
181+
182+
Ok(paths)
183+
}
184+
185+
fn validate_n_samples(&self, n_samples: usize) -> Result<(), String> {
106186
if n_samples != self.samples_info_sets.len() {
107-
return Err("Dataset length must match samples_info_sets".into());
108-
}
109-
let n = n_samples;
110-
let mut fold_sizes = vec![n / self.n_splits; self.n_splits];
111-
for i in 0..(n % self.n_splits) {
112-
fold_sizes[i] += 1;
113-
}
114-
let mut current = 0;
115-
let mut splits = Vec::new();
116-
for fold_size in fold_sizes {
117-
let start = current;
118-
let stop = current + fold_size;
119-
let test_indices: Vec<usize> = (start..stop).collect();
120-
let mut train_mask = vec![true; n];
121-
122-
// purge overlaps
123-
let test_start = self.samples_info_sets[test_indices[0]].1;
124-
let test_end = self.samples_info_sets[*test_indices.last().unwrap()].1;
125-
for (i, (s, e)) in self.samples_info_sets.iter().enumerate() {
126-
let start_in = *s >= test_start && *s <= test_end;
127-
let end_in = *e >= test_start && *e <= test_end;
128-
let envelop = *s <= test_start && *e >= test_end;
129-
if start_in || end_in || envelop {
130-
train_mask[i] = false;
131-
}
187+
return Err("dataset length must match samples_info_sets".into());
188+
}
189+
Ok(())
190+
}
191+
192+
fn build_split(
193+
&self,
194+
split_id: usize,
195+
test_indices: Vec<usize>,
196+
) -> Result<PurgedSplit, String> {
197+
if test_indices.is_empty() {
198+
return Err("test_indices cannot be empty".into());
199+
}
200+
201+
let n_samples = self.samples_info_sets.len();
202+
let test_set: HashSet<usize> = test_indices.iter().copied().collect();
203+
let test_intervals: Vec<(NaiveDateTime, NaiveDateTime)> =
204+
test_indices.iter().map(|idx| self.samples_info_sets[*idx]).collect();
205+
206+
let mut purged = Vec::new();
207+
for idx in 0..n_samples {
208+
if test_set.contains(&idx) {
209+
continue;
210+
}
211+
let candidate = self.samples_info_sets[idx];
212+
let overlaps = test_intervals
213+
.iter()
214+
.any(|test_interval| intervals_overlap(candidate, *test_interval));
215+
if overlaps {
216+
purged.push(idx);
132217
}
218+
}
133219

134-
// embargo
135-
let embargo = (self.pct_embargo * n as f64).ceil() as isize;
136-
if embargo > 0 {
137-
let after = (stop as isize + embargo).min(n as isize);
138-
let before = (start as isize - embargo).max(0);
139-
for i in start..(after as usize) {
140-
if i < n {
141-
train_mask[i] = false;
220+
let embargo = ((self.pct_embargo * n_samples as f64).ceil()) as usize;
221+
let mut embargoed = HashSet::new();
222+
if embargo > 0 {
223+
for (_start, stop) in contiguous_ranges(&test_indices) {
224+
let embargo_stop = (stop + embargo).min(n_samples);
225+
for idx in stop..embargo_stop {
226+
if !test_set.contains(&idx) {
227+
embargoed.insert(idx);
142228
}
143229
}
144-
for i in before as usize..start {
145-
train_mask[i] = false;
146-
}
147230
}
231+
}
148232

149-
let train_indices: Vec<usize> = train_mask
150-
.iter()
151-
.enumerate()
152-
.filter_map(|(i, keep)| if *keep { Some(i) } else { None })
153-
.collect();
154-
splits.push((train_indices, test_indices));
155-
current = stop;
233+
let purged_set: HashSet<usize> = purged.iter().copied().collect();
234+
let train_indices: Vec<usize> = (0..n_samples)
235+
.filter(|idx| {
236+
!test_set.contains(idx) && !purged_set.contains(idx) && !embargoed.contains(idx)
237+
})
238+
.collect();
239+
240+
let overlap_count_after_purge =
241+
count_train_test_overlaps(&self.samples_info_sets, &train_indices, &test_indices);
242+
243+
let mut embargo_indices: Vec<usize> = embargoed.into_iter().collect();
244+
embargo_indices.sort_unstable();
245+
246+
Ok(PurgedSplit {
247+
train_indices,
248+
test_indices: test_indices.clone(),
249+
diagnostics: PurgedSplitDiagnostics {
250+
split_id,
251+
test_ranges: contiguous_ranges(&test_indices),
252+
purged_indices: purged,
253+
embargo_indices,
254+
overlap_count_after_purge,
255+
},
256+
})
257+
}
258+
}
259+
260+
pub fn naive_kfold_splits(
261+
n_samples: usize,
262+
n_splits: usize,
263+
) -> Result<Vec<(Vec<usize>, Vec<usize>)>, String> {
264+
if n_samples == 0 {
265+
return Err("n_samples must be > 0".into());
266+
}
267+
if n_splits < 2 {
268+
return Err("n_splits must be at least 2".into());
269+
}
270+
if n_splits > n_samples {
271+
return Err("n_splits cannot exceed n_samples".into());
272+
}
273+
274+
let folds = contiguous_fold_bounds(n_samples, n_splits);
275+
let mut out = Vec::with_capacity(folds.len());
276+
for (start, stop) in folds {
277+
let test_indices: Vec<usize> = (start..stop).collect();
278+
let train_indices: Vec<usize> =
279+
(0..n_samples).filter(|i| *i < start || *i >= stop).collect();
280+
out.push((train_indices, test_indices));
281+
}
282+
Ok(out)
283+
}
284+
285+
pub fn count_train_test_overlaps(
286+
info_sets: &[(NaiveDateTime, NaiveDateTime)],
287+
train_indices: &[usize],
288+
test_indices: &[usize],
289+
) -> usize {
290+
let mut overlaps = 0;
291+
for train_idx in train_indices {
292+
let train_interval = info_sets[*train_idx];
293+
if test_indices
294+
.iter()
295+
.any(|test_idx| intervals_overlap(train_interval, info_sets[*test_idx]))
296+
{
297+
overlaps += 1;
298+
}
299+
}
300+
overlaps
301+
}
302+
303+
/// Reports whether two closed `(start, end)` intervals share at least one instant.
///
/// Endpoints are inclusive, so intervals that merely touch count as overlapping.
fn intervals_overlap(a: (NaiveDateTime, NaiveDateTime), b: (NaiveDateTime, NaiveDateTime)) -> bool {
    let (a_start, a_end) = a;
    let (b_start, b_end) = b;
    // Disjoint only when one interval begins strictly after the other has ended.
    !(a_start > b_end || b_start > a_end)
}
306+
307+
/// Splits `0..n_samples` into `n_splits` contiguous half-open `(start, stop)` ranges.
///
/// Every fold gets `n_samples / n_splits` samples and the first `n_samples % n_splits`
/// folds get one extra, so fold sizes differ by at most one. Folds may be empty when
/// `n_splits > n_samples`. Returns an empty vector when `n_splits` is zero instead of
/// dividing by zero.
fn contiguous_fold_bounds(n_samples: usize, n_splits: usize) -> Vec<(usize, usize)> {
    if n_splits == 0 {
        return Vec::new();
    }

    let base = n_samples / n_splits;
    let remainder = n_samples % n_splits;

    let mut start = 0;
    (0..n_splits)
        .map(|fold| {
            // The leading `remainder` folds absorb the leftover samples.
            let size = base + usize::from(fold < remainder);
            let bounds = (start, start + size);
            start = bounds.1;
            bounds
        })
        .collect()
}
322+
323+
/// Collapses a set of indices into sorted, non-overlapping half-open `(start, stop)` runs.
///
/// The input may be unsorted and may contain duplicates; it is sorted and deduplicated
/// first so that a repeated index cannot open a second, overlapping range. An empty
/// input yields an empty vector.
fn contiguous_ranges(indices: &[usize]) -> Vec<(usize, usize)> {
    let mut sorted = indices.to_vec();
    sorted.sort_unstable();
    sorted.dedup();

    let mut ranges: Vec<(usize, usize)> = Vec::new();
    for idx in sorted {
        // Extend the current run when `idx` is exactly one past its last element.
        if let Some(last) = ranges.last_mut() {
            if last.1 == idx {
                last.1 = idx + 1;
                continue;
            }
        }
        ranges.push((idx, idx + 1));
    }
    ranges
}

0 commit comments

Comments
 (0)