@@ -451,7 +451,8 @@ def kfold_cv_error(self, n_splits, *args, norm=np.linalg.norm, **kwargs):
451451 for train_index , test_index in kf .split (self .database ):
452452 new_db = self .database [train_index ]
453453 rom = type (self )(new_db , copy .deepcopy (self .reduction ),
454- copy .deepcopy (self .approximation )).fit (
454+ copy .deepcopy (self .approximation ),
455+ plugins = [copy .deepcopy (p ) for p in self .plugins ]).fit (
455456 * args , ** kwargs )
456457
457458 error .append (rom .test_error (self .database [test_index ], norm ))
@@ -487,7 +488,8 @@ def loo_error(self, *args, norm=np.linalg.norm, **kwargs):
487488 new_db = self .database [indeces ]
488489 test_db = self .database [~ indeces ]
489490 rom = type (self )(new_db , copy .deepcopy (self .reduction ),
490- copy .deepcopy (self .approximation )).fit ()
491+ copy .deepcopy (self .approximation ),
492+ plugins = [copy .deepcopy (p ) for p in self .plugins ]).fit ()
491493
492494 error [j ] = rom .test_error (test_db )
493495
@@ -860,6 +862,7 @@ def kfold_cv_error(self, n_splits, *args, norm=np.linalg.norm, relative=True,
860862 kf = KFold (n_splits = n_splits )
861863 for train_index , test_index in kf .split (self .database ):
862864 new_db = self .database [train_index ]
865+ # TODO: Fix plugins handling - should pass: plugins=[copy.deepcopy(p) for p in self.plugins]
863866 rom = type (self )(new_db , copy .deepcopy (self .reduction ),
864867 copy .deepcopy (self .approximation )).fit (
865868 * args , ** kwargs )
@@ -896,6 +899,7 @@ def loo_error(self, *args, norm=np.linalg.norm, **kwargs):
896899
897900 new_db = self .database [indeces ]
898901 test_db = self .database [~ indeces ]
902+ # TODO: Fix plugins handling - should pass: plugins=[copy.deepcopy(p) for p in self.plugins]
899903 rom = type (self )(new_db , copy .deepcopy (self .reduction ),
900904 copy .deepcopy (self .approximation )).fit ()
901905
0 commit comments