11use chrono:: NaiveDateTime ;
2+ use itertools:: Itertools ;
3+ use std:: collections:: HashSet ;
24
35/// Simple classifier interface for cross-validation.
46pub trait SimpleClassifier {
@@ -69,10 +71,7 @@ pub fn ml_get_train_times(
6971 for ( start, end) in info_sets {
7072 let mut keep = true ;
7173 for ( test_start, test_end) in test_times {
72- let start_in = * start >= * test_start && * start <= * test_end;
73- let end_in = * end >= * test_start && * end <= * test_end;
74- let envelop = * start <= * test_start && * end >= * test_end;
75- if start_in || end_in || envelop {
74+ if intervals_overlap ( ( * start, * end) , ( * test_start, * test_end) ) {
7675 keep = false ;
7776 break ;
7877 }
@@ -84,6 +83,29 @@ pub fn ml_get_train_times(
8483 out
8584}
8685
86+ #[ derive( Debug , Clone , PartialEq , Eq ) ]
87+ pub struct PurgedSplitDiagnostics {
88+ pub split_id : usize ,
89+ pub test_ranges : Vec < ( usize , usize ) > ,
90+ pub purged_indices : Vec < usize > ,
91+ pub embargo_indices : Vec < usize > ,
92+ pub overlap_count_after_purge : usize ,
93+ }
94+
95+ #[ derive( Debug , Clone , PartialEq , Eq ) ]
96+ pub struct PurgedSplit {
97+ pub train_indices : Vec < usize > ,
98+ pub test_indices : Vec < usize > ,
99+ pub diagnostics : PurgedSplitDiagnostics ,
100+ }
101+
102+ #[ derive( Debug , Clone , PartialEq , Eq ) ]
103+ pub struct CpcvPath {
104+ pub path_id : usize ,
105+ pub test_fold_ids : Vec < usize > ,
106+ pub split : PurgedSplit ,
107+ }
108+
87109pub struct PurgedKFold {
88110 n_splits : usize ,
89111 samples_info_sets : Vec < ( NaiveDateTime , NaiveDateTime ) > ,
@@ -96,64 +118,229 @@ impl PurgedKFold {
96118 samples_info_sets : Vec < ( NaiveDateTime , NaiveDateTime ) > ,
97119 pct_embargo : f64 ,
98120 ) -> Result < Self , String > {
121+ if n_splits < 2 {
122+ return Err ( "n_splits must be at least 2" . into ( ) ) ;
123+ }
99124 if samples_info_sets. is_empty ( ) {
100125 return Err ( "samples_info_sets cannot be empty" . into ( ) ) ;
101126 }
127+ if n_splits > samples_info_sets. len ( ) {
128+ return Err ( "n_splits cannot exceed sample count" . into ( ) ) ;
129+ }
130+ if !( 0.0 ..1.0 ) . contains ( & pct_embargo) {
131+ return Err ( "pct_embargo must be in [0.0, 1.0)" . into ( ) ) ;
132+ }
133+ for ( idx, ( start, end) ) in samples_info_sets. iter ( ) . enumerate ( ) {
134+ if start > end {
135+ return Err ( format ! ( "invalid information set at index {idx}: start > end" ) ) ;
136+ }
137+ }
138+
102139 Ok ( Self { n_splits, samples_info_sets, pct_embargo } )
103140 }
104141
105142 pub fn split ( & self , n_samples : usize ) -> Result < Vec < ( Vec < usize > , Vec < usize > ) > , String > {
143+ let splits = self . split_with_diagnostics ( n_samples) ?;
144+ Ok ( splits. into_iter ( ) . map ( |s| ( s. train_indices , s. test_indices ) ) . collect ( ) )
145+ }
146+
147+ pub fn split_with_diagnostics ( & self , n_samples : usize ) -> Result < Vec < PurgedSplit > , String > {
148+ self . validate_n_samples ( n_samples) ?;
149+ let folds = contiguous_fold_bounds ( n_samples, self . n_splits ) ;
150+
151+ let mut out = Vec :: with_capacity ( folds. len ( ) ) ;
152+ for ( split_id, ( start, stop) ) in folds. iter ( ) . enumerate ( ) {
153+ let test_indices: Vec < usize > = ( * start..* stop) . collect ( ) ;
154+ out. push ( self . build_split ( split_id, test_indices) ?) ;
155+ }
156+ Ok ( out)
157+ }
158+
159+ pub fn cpcv_paths (
160+ & self ,
161+ n_samples : usize ,
162+ n_test_splits : usize ,
163+ ) -> Result < Vec < CpcvPath > , String > {
164+ self . validate_n_samples ( n_samples) ?;
165+ if n_test_splits == 0 || n_test_splits >= self . n_splits {
166+ return Err ( "n_test_splits must be in [1, n_splits)" . into ( ) ) ;
167+ }
168+
169+ let folds = contiguous_fold_bounds ( n_samples, self . n_splits ) ;
170+ let mut paths = Vec :: new ( ) ;
171+
172+ for ( path_id, test_fold_ids) in ( 0 ..self . n_splits ) . combinations ( n_test_splits) . enumerate ( ) {
173+ let mut test_indices = Vec :: new ( ) ;
174+ for fold_id in & test_fold_ids {
175+ let ( start, stop) = folds[ * fold_id] ;
176+ test_indices. extend ( start..stop) ;
177+ }
178+ let split = self . build_split ( path_id, test_indices) ?;
179+ paths. push ( CpcvPath { path_id, test_fold_ids, split } ) ;
180+ }
181+
182+ Ok ( paths)
183+ }
184+
185+ fn validate_n_samples ( & self , n_samples : usize ) -> Result < ( ) , String > {
106186 if n_samples != self . samples_info_sets . len ( ) {
107- return Err ( "Dataset length must match samples_info_sets" . into ( ) ) ;
108- }
109- let n = n_samples;
110- let mut fold_sizes = vec ! [ n / self . n_splits; self . n_splits] ;
111- for i in 0 ..( n % self . n_splits ) {
112- fold_sizes[ i] += 1 ;
113- }
114- let mut current = 0 ;
115- let mut splits = Vec :: new ( ) ;
116- for fold_size in fold_sizes {
117- let start = current;
118- let stop = current + fold_size;
119- let test_indices: Vec < usize > = ( start..stop) . collect ( ) ;
120- let mut train_mask = vec ! [ true ; n] ;
121-
122- // purge overlaps
123- let test_start = self . samples_info_sets [ test_indices[ 0 ] ] . 1 ;
124- let test_end = self . samples_info_sets [ * test_indices. last ( ) . unwrap ( ) ] . 1 ;
125- for ( i, ( s, e) ) in self . samples_info_sets . iter ( ) . enumerate ( ) {
126- let start_in = * s >= test_start && * s <= test_end;
127- let end_in = * e >= test_start && * e <= test_end;
128- let envelop = * s <= test_start && * e >= test_end;
129- if start_in || end_in || envelop {
130- train_mask[ i] = false ;
131- }
187+ return Err ( "dataset length must match samples_info_sets" . into ( ) ) ;
188+ }
189+ Ok ( ( ) )
190+ }
191+
192+ fn build_split (
193+ & self ,
194+ split_id : usize ,
195+ test_indices : Vec < usize > ,
196+ ) -> Result < PurgedSplit , String > {
197+ if test_indices. is_empty ( ) {
198+ return Err ( "test_indices cannot be empty" . into ( ) ) ;
199+ }
200+
201+ let n_samples = self . samples_info_sets . len ( ) ;
202+ let test_set: HashSet < usize > = test_indices. iter ( ) . copied ( ) . collect ( ) ;
203+ let test_intervals: Vec < ( NaiveDateTime , NaiveDateTime ) > =
204+ test_indices. iter ( ) . map ( |idx| self . samples_info_sets [ * idx] ) . collect ( ) ;
205+
206+ let mut purged = Vec :: new ( ) ;
207+ for idx in 0 ..n_samples {
208+ if test_set. contains ( & idx) {
209+ continue ;
210+ }
211+ let candidate = self . samples_info_sets [ idx] ;
212+ let overlaps = test_intervals
213+ . iter ( )
214+ . any ( |test_interval| intervals_overlap ( candidate, * test_interval) ) ;
215+ if overlaps {
216+ purged. push ( idx) ;
132217 }
218+ }
133219
134- // embargo
135- let embargo = ( self . pct_embargo * n as f64 ) . ceil ( ) as isize ;
136- if embargo > 0 {
137- let after = ( stop as isize + embargo ) . min ( n as isize ) ;
138- let before = ( start as isize - embargo) . max ( 0 ) ;
139- for i in start.. ( after as usize ) {
140- if i < n {
141- train_mask [ i ] = false ;
220+ let embargo = ( ( self . pct_embargo * n_samples as f64 ) . ceil ( ) ) as usize ;
221+ let mut embargoed = HashSet :: new ( ) ;
222+ if embargo > 0 {
223+ for ( _start , stop ) in contiguous_ranges ( & test_indices ) {
224+ let embargo_stop = ( stop + embargo) . min ( n_samples ) ;
225+ for idx in stop..embargo_stop {
226+ if !test_set . contains ( & idx ) {
227+ embargoed . insert ( idx ) ;
142228 }
143229 }
144- for i in before as usize ..start {
145- train_mask[ i] = false ;
146- }
147230 }
231+ }
148232
149- let train_indices: Vec < usize > = train_mask
150- . iter ( )
151- . enumerate ( )
152- . filter_map ( |( i, keep) | if * keep { Some ( i) } else { None } )
153- . collect ( ) ;
154- splits. push ( ( train_indices, test_indices) ) ;
155- current = stop;
233+ let purged_set: HashSet < usize > = purged. iter ( ) . copied ( ) . collect ( ) ;
234+ let train_indices: Vec < usize > = ( 0 ..n_samples)
235+ . filter ( |idx| {
236+ !test_set. contains ( idx) && !purged_set. contains ( idx) && !embargoed. contains ( idx)
237+ } )
238+ . collect ( ) ;
239+
240+ let overlap_count_after_purge =
241+ count_train_test_overlaps ( & self . samples_info_sets , & train_indices, & test_indices) ;
242+
243+ let mut embargo_indices: Vec < usize > = embargoed. into_iter ( ) . collect ( ) ;
244+ embargo_indices. sort_unstable ( ) ;
245+
246+ Ok ( PurgedSplit {
247+ train_indices,
248+ test_indices : test_indices. clone ( ) ,
249+ diagnostics : PurgedSplitDiagnostics {
250+ split_id,
251+ test_ranges : contiguous_ranges ( & test_indices) ,
252+ purged_indices : purged,
253+ embargo_indices,
254+ overlap_count_after_purge,
255+ } ,
256+ } )
257+ }
258+ }
259+
260+ pub fn naive_kfold_splits (
261+ n_samples : usize ,
262+ n_splits : usize ,
263+ ) -> Result < Vec < ( Vec < usize > , Vec < usize > ) > , String > {
264+ if n_samples == 0 {
265+ return Err ( "n_samples must be > 0" . into ( ) ) ;
266+ }
267+ if n_splits < 2 {
268+ return Err ( "n_splits must be at least 2" . into ( ) ) ;
269+ }
270+ if n_splits > n_samples {
271+ return Err ( "n_splits cannot exceed n_samples" . into ( ) ) ;
272+ }
273+
274+ let folds = contiguous_fold_bounds ( n_samples, n_splits) ;
275+ let mut out = Vec :: with_capacity ( folds. len ( ) ) ;
276+ for ( start, stop) in folds {
277+ let test_indices: Vec < usize > = ( start..stop) . collect ( ) ;
278+ let train_indices: Vec < usize > =
279+ ( 0 ..n_samples) . filter ( |i| * i < start || * i >= stop) . collect ( ) ;
280+ out. push ( ( train_indices, test_indices) ) ;
281+ }
282+ Ok ( out)
283+ }
284+
285+ pub fn count_train_test_overlaps (
286+ info_sets : & [ ( NaiveDateTime , NaiveDateTime ) ] ,
287+ train_indices : & [ usize ] ,
288+ test_indices : & [ usize ] ,
289+ ) -> usize {
290+ let mut overlaps = 0 ;
291+ for train_idx in train_indices {
292+ let train_interval = info_sets[ * train_idx] ;
293+ if test_indices
294+ . iter ( )
295+ . any ( |test_idx| intervals_overlap ( train_interval, info_sets[ * test_idx] ) )
296+ {
297+ overlaps += 1 ;
298+ }
299+ }
300+ overlaps
301+ }
302+
303+ fn intervals_overlap ( a : ( NaiveDateTime , NaiveDateTime ) , b : ( NaiveDateTime , NaiveDateTime ) ) -> bool {
304+ a. 0 <= b. 1 && b. 0 <= a. 1
305+ }
306+
307+ fn contiguous_fold_bounds ( n_samples : usize , n_splits : usize ) -> Vec < ( usize , usize ) > {
308+ let mut fold_sizes = vec ! [ n_samples / n_splits; n_splits] ;
309+ for size in fold_sizes. iter_mut ( ) . take ( n_samples % n_splits) {
310+ * size += 1 ;
311+ }
312+
313+ let mut current = 0 ;
314+ let mut bounds = Vec :: with_capacity ( n_splits) ;
315+ for fold_size in fold_sizes {
316+ let next = current + fold_size;
317+ bounds. push ( ( current, next) ) ;
318+ current = next;
319+ }
320+ bounds
321+ }
322+
323+ fn contiguous_ranges ( indices : & [ usize ] ) -> Vec < ( usize , usize ) > {
324+ if indices. is_empty ( ) {
325+ return Vec :: new ( ) ;
326+ }
327+
328+ let mut sorted = indices. to_vec ( ) ;
329+ sorted. sort_unstable ( ) ;
330+
331+ let mut ranges = Vec :: new ( ) ;
332+ let mut start = sorted[ 0 ] ;
333+ let mut prev = sorted[ 0 ] ;
334+
335+ for idx in sorted. into_iter ( ) . skip ( 1 ) {
336+ if idx == prev + 1 {
337+ prev = idx;
338+ continue ;
156339 }
157- Ok ( splits)
340+ ranges. push ( ( start, prev + 1 ) ) ;
341+ start = idx;
342+ prev = idx;
158343 }
344+ ranges. push ( ( start, prev + 1 ) ) ;
345+ ranges
159346}
0 commit comments