Skip to content

Commit a9dc5cf

Browse files
Davide-Miottindem0
authored andcommitted
add plugins to cross_val and loo
1 parent 97987c1 commit a9dc5cf

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

ezyrb/reducedordermodel.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,8 @@ def kfold_cv_error(self, n_splits, *args, norm=np.linalg.norm, **kwargs):
451451
for train_index, test_index in kf.split(self.database):
452452
new_db = self.database[train_index]
453453
rom = type(self)(new_db, copy.deepcopy(self.reduction),
454-
copy.deepcopy(self.approximation)).fit(
454+
copy.deepcopy(self.approximation),
455+
plugins=[copy.deepcopy(p) for p in self.plugins]).fit(
455456
*args, **kwargs)
456457

457458
error.append(rom.test_error(self.database[test_index], norm))
@@ -487,7 +488,8 @@ def loo_error(self, *args, norm=np.linalg.norm, **kwargs):
487488
new_db = self.database[indeces]
488489
test_db = self.database[~indeces]
489490
rom = type(self)(new_db, copy.deepcopy(self.reduction),
490-
copy.deepcopy(self.approximation)).fit()
491+
copy.deepcopy(self.approximation),
492+
plugins=[copy.deepcopy(p) for p in self.plugins]).fit()
491493

492494
error[j] = rom.test_error(test_db)
493495

@@ -860,6 +862,7 @@ def kfold_cv_error(self, n_splits, *args, norm=np.linalg.norm, relative=True,
860862
kf = KFold(n_splits=n_splits)
861863
for train_index, test_index in kf.split(self.database):
862864
new_db = self.database[train_index]
865+
# TODO: Fix plugins handling - should pass: plugins=[copy.deepcopy(p) for p in self.plugins]
863866
rom = type(self)(new_db, copy.deepcopy(self.reduction),
864867
copy.deepcopy(self.approximation)).fit(
865868
*args, **kwargs)
@@ -896,6 +899,7 @@ def loo_error(self, *args, norm=np.linalg.norm, **kwargs):
896899

897900
new_db = self.database[indeces]
898901
test_db = self.database[~indeces]
902+
# TODO: Fix plugins handling - should pass: plugins=[copy.deepcopy(p) for p in self.plugins]
899903
rom = type(self)(new_db, copy.deepcopy(self.reduction),
900904
copy.deepcopy(self.approximation)).fit()
901905

0 commit comments

Comments
 (0)