From be85f1e58b0526cf491e8ab395465d70628399d6 Mon Sep 17 00:00:00 2001 From: Brianna Mueller Date: Tue, 31 Mar 2026 00:11:15 +0000 Subject: [PATCH] Add optional validation split with val-based early stopping Split training data into train/val subsets when --use_val is enabled. Validation accuracy is used for early stopping instead of test accuracy, preventing data leakage from the test set into the training loop. Disabled by default (--use_val False) so existing behavior is unchanged. --- system/flcore/clients/clientbase.py | 60 ++++++++++++++++++++++++- system/flcore/servers/serverala.py | 3 +- system/flcore/servers/serveramp.py | 3 +- system/flcore/servers/serverapfl.py | 3 +- system/flcore/servers/serverapple.py | 3 +- system/flcore/servers/serveras.py | 3 +- system/flcore/servers/serveravg.py | 3 +- system/flcore/servers/serverbabu.py | 3 +- system/flcore/servers/serverbase.py | 34 +++++++++++++- system/flcore/servers/serverbn.py | 3 +- system/flcore/servers/servercac.py | 3 +- system/flcore/servers/servercross.py | 3 +- system/flcore/servers/serverda.py | 3 +- system/flcore/servers/serverdbe.py | 3 +- system/flcore/servers/serverditto.py | 3 +- system/flcore/servers/serverdyn.py | 3 +- system/flcore/servers/serverfd.py | 3 +- system/flcore/servers/serverfml.py | 3 +- system/flcore/servers/serverfomo.py | 3 +- system/flcore/servers/servergc.py | 3 +- system/flcore/servers/servergen.py | 3 +- system/flcore/servers/servergh.py | 3 +- system/flcore/servers/servergpfl.py | 3 +- system/flcore/servers/serverkd.py | 3 +- system/flcore/servers/serverlc.py | 3 +- system/flcore/servers/serverlg.py | 3 +- system/flcore/servers/serverlocal.py | 3 +- system/flcore/servers/servermoon.py | 3 +- system/flcore/servers/servermtl.py | 3 +- system/flcore/servers/serverntd.py | 3 +- system/flcore/servers/serverpac.py | 3 +- system/flcore/servers/serverpcl.py | 3 +- system/flcore/servers/serverper.py | 3 +- system/flcore/servers/serverperavg.py | 3 +- 
system/flcore/servers/serverphp.py | 3 +- system/flcore/servers/serverproto.py | 3 +- system/flcore/servers/serverprox.py | 3 +- system/flcore/servers/serverrep.py | 3 +- system/flcore/servers/serverrod.py | 3 +- system/flcore/servers/serverscaffold.py | 3 +- system/main.py | 7 +++ 41 files changed, 173 insertions(+), 42 deletions(-) diff --git a/system/flcore/clients/clientbase.py b/system/flcore/clients/clientbase.py index 82c2aae0a..076ce3307 100644 --- a/system/flcore/clients/clientbase.py +++ b/system/flcore/clients/clientbase.py @@ -3,7 +3,7 @@ import torch.nn as nn import numpy as np import os -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, random_split from sklearn.preprocessing import label_binarize from sklearn import metrics from utils.data_utils import read_client_data @@ -43,6 +43,10 @@ def __init__(self, args, id, train_samples, test_samples, **kwargs): self.train_time_cost = {'num_rounds': 0, 'total_cost': 0.0} self.send_time_cost = {'num_rounds': 0, 'total_cost': 0.0} + self.use_val = getattr(args, 'use_val', False) + self.val_ratio = getattr(args, 'val_ratio', 0.2) + self.split_seed = getattr(args, 'split_seed', 0) + self.loss = nn.CrossEntropyLoss() self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate) self.learning_rate_scheduler = torch.optim.lr_scheduler.ExponentialLR( @@ -56,8 +60,23 @@ def load_train_data(self, batch_size=None): if batch_size == None: batch_size = self.batch_size train_data = read_client_data(self.dataset, self.id, is_train=True, few_shot=self.few_shot) + if self.use_val: + val_size = max(1, int(len(train_data) * self.val_ratio)) + train_size = len(train_data) - val_size + generator = torch.Generator().manual_seed(self.split_seed) + train_data, _ = random_split(train_data, [train_size, val_size], generator=generator) return DataLoader(train_data, batch_size, drop_last=True, shuffle=True) + def load_val_data(self, batch_size=None): + if batch_size == None: + 
batch_size = self.batch_size + full_data = read_client_data(self.dataset, self.id, is_train=True, few_shot=self.few_shot) + val_size = max(1, int(len(full_data) * self.val_ratio)) + train_size = len(full_data) - val_size + generator = torch.Generator().manual_seed(self.split_seed) + _, val_data = random_split(full_data, [train_size, val_size], generator=generator) + return DataLoader(val_data, batch_size, drop_last=False, shuffle=False) + def load_test_data(self, batch_size=None): if batch_size == None: batch_size = self.batch_size @@ -116,9 +135,46 @@ def test_metrics(self): y_true = np.concatenate(y_true, axis=0) auc = metrics.roc_auc_score(y_true, y_prob, average='micro') - + return test_acc, test_num, auc + def val_metrics(self): + valloaderfull = self.load_val_data() + self.model.eval() + + val_acc = 0 + val_num = 0 + y_prob = [] + y_true = [] + + with torch.no_grad(): + for x, y in valloaderfull: + if type(x) == type([]): + x[0] = x[0].to(self.device) + else: + x = x.to(self.device) + y = y.to(self.device) + output = self.model(x) + + val_acc += (torch.sum(torch.argmax(output, dim=1) == y)).item() + val_num += y.shape[0] + + y_prob.append(output.detach().cpu().numpy()) + nc = self.num_classes + if self.num_classes == 2: + nc += 1 + lb = label_binarize(y.detach().cpu().numpy(), classes=np.arange(nc)) + if self.num_classes == 2: + lb = lb[:, :2] + y_true.append(lb) + + y_prob = np.concatenate(y_prob, axis=0) + y_true = np.concatenate(y_true, axis=0) + + auc = metrics.roc_auc_score(y_true, y_prob, average='micro') + + return val_acc, val_num, auc + def train_metrics(self): trainloader = self.load_train_data() # self.model = self.load_model('model') diff --git a/system/flcore/servers/serverala.py b/system/flcore/servers/serverala.py index 97ba11263..84329a733 100644 --- a/system/flcore/servers/serverala.py +++ b/system/flcore/servers/serverala.py @@ -46,7 +46,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, 
self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serveramp.py b/system/flcore/servers/serveramp.py index bb9b37a7b..435c18683 100644 --- a/system/flcore/servers/serveramp.py +++ b/system/flcore/servers/serveramp.py @@ -55,7 +55,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverapfl.py b/system/flcore/servers/serverapfl.py index af4399567..45ca1fb28 100644 --- a/system/flcore/servers/serverapfl.py +++ b/system/flcore/servers/serverapfl.py @@ -46,7 +46,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverapple.py b/system/flcore/servers/serverapple.py index 3d8fe0858..ee1151a76 100644 --- a/system/flcore/servers/serverapple.py +++ b/system/flcore/servers/serverapple.py @@ -57,7 +57,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*50, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc 
if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serveras.py b/system/flcore/servers/serveras.py index 5dc6ab709..9a3fe1d4a 100644 --- a/system/flcore/servers/serveras.py +++ b/system/flcore/servers/serveras.py @@ -100,7 +100,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serveravg.py b/system/flcore/servers/serveravg.py index e4cb08670..93eeb619a 100644 --- a/system/flcore/servers/serveravg.py +++ b/system/flcore/servers/serveravg.py @@ -46,7 +46,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverbabu.py b/system/flcore/servers/serverbabu.py index f5c0a9c5d..5e286757d 100644 --- a/system/flcore/servers/serverbabu.py +++ b/system/flcore/servers/serverbabu.py @@ -47,7 +47,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], 
top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverbase.py b/system/flcore/servers/serverbase.py index 435362d6c..c28d97419 100644 --- a/system/flcore/servers/serverbase.py +++ b/system/flcore/servers/serverbase.py @@ -44,9 +44,13 @@ def __init__(self, args, times): self.uploaded_ids = [] self.uploaded_models = [] + self.use_val = getattr(args, 'use_val', False) + self.rs_test_acc = [] self.rs_test_auc = [] self.rs_train_loss = [] + self.rs_val_acc = [] + self.rs_val_auc = [] self.times = times self.eval_gap = args.eval_gap @@ -182,6 +186,9 @@ def save_results(self): hf.create_dataset('rs_test_acc', data=self.rs_test_acc) hf.create_dataset('rs_test_auc', data=self.rs_test_auc) hf.create_dataset('rs_train_loss', data=self.rs_train_loss) + if self.rs_val_acc: + hf.create_dataset('rs_val_acc', data=self.rs_val_acc) + hf.create_dataset('rs_val_auc', data=self.rs_val_auc) def save_item(self, item, item_name): if not os.path.exists(self.save_folder_name): @@ -224,6 +231,20 @@ def train_metrics(self): return ids, num_samples, losses + def val_metrics(self): + num_samples = [] + tot_correct = [] + tot_auc = [] + for c in self.clients: + ct, ns, auc = c.val_metrics() + tot_correct.append(ct*1.0) + tot_auc.append(auc*ns) + num_samples.append(ns) + + ids = [c.id for c in self.clients] + + return ids, num_samples, tot_correct, tot_auc + # evaluate selected clients def evaluate(self, acc=None, loss=None): stats = self.test_metrics() @@ -234,12 +255,12 @@ def evaluate(self, acc=None, loss=None): train_loss = sum(stats_train[2])*1.0 / sum(stats_train[1]) accs = [a / n for a, n in zip(stats[2], stats[1])] aucs = [a / n for a, n in zip(stats[3], stats[1])] - + if acc == None: self.rs_test_acc.append(test_acc) else: acc.append(test_acc) - + if loss == None: self.rs_train_loss.append(train_loss) else: @@ -252,6 +273,15 @@ def evaluate(self, acc=None, loss=None): print("Std Test Accuracy: {:.4f}".format(np.std(accs))) print("Std Test AUC: 
{:.4f}".format(np.std(aucs))) + if self.use_val: + stats_val = self.val_metrics() + val_acc = sum(stats_val[2])*1.0 / sum(stats_val[1]) + val_auc = sum(stats_val[3])*1.0 / sum(stats_val[1]) + self.rs_val_acc.append(val_acc) + self.rs_val_auc.append(val_auc) + print("Averaged Val Accuracy: {:.4f}".format(val_acc)) + print("Averaged Val AUC: {:.4f}".format(val_auc)) + def print_(self, test_acc, test_auc, train_loss): print("Average Test Accuracy: {:.4f}".format(test_acc)) print("Average Test AUC: {:.4f}".format(test_auc)) diff --git a/system/flcore/servers/serverbn.py b/system/flcore/servers/serverbn.py index 7afb24030..795b19310 100644 --- a/system/flcore/servers/serverbn.py +++ b/system/flcore/servers/serverbn.py @@ -45,7 +45,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/servercac.py b/system/flcore/servers/servercac.py index 8876c2287..c0fe0432f 100644 --- a/system/flcore/servers/servercac.py +++ b/system/flcore/servers/servercac.py @@ -50,7 +50,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/servercross.py b/system/flcore/servers/servercross.py index 698c42a85..b0d7e424e 100644 --- a/system/flcore/servers/servercross.py +++ b/system/flcore/servers/servercross.py @@ -69,7 +69,8 @@ def 
train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverda.py b/system/flcore/servers/serverda.py index 7ded3781c..f37d29e65 100644 --- a/system/flcore/servers/serverda.py +++ b/system/flcore/servers/serverda.py @@ -47,7 +47,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverdbe.py b/system/flcore/servers/serverdbe.py index 1a6b9bf06..8593b5b7e 100644 --- a/system/flcore/servers/serverdbe.py +++ b/system/flcore/servers/serverdbe.py @@ -68,7 +68,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverditto.py b/system/flcore/servers/serverditto.py index af8d861e7..84454e74d 100644 --- a/system/flcore/servers/serverditto.py +++ b/system/flcore/servers/serverditto.py @@ -52,7 +52,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break 
and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverdyn.py b/system/flcore/servers/serverdyn.py index 856218b1b..9a986796a 100644 --- a/system/flcore/servers/serverdyn.py +++ b/system/flcore/servers/serverdyn.py @@ -55,7 +55,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*50, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverfd.py b/system/flcore/servers/serverfd.py index ab2bd3a25..511d32ae5 100644 --- a/system/flcore/servers/serverfd.py +++ b/system/flcore/servers/serverfd.py @@ -47,7 +47,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*50, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverfml.py b/system/flcore/servers/serverfml.py index 796329726..7f302f291 100644 --- a/system/flcore/servers/serverfml.py +++ b/system/flcore/servers/serverfml.py @@ -48,7 +48,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if 
self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverfomo.py b/system/flcore/servers/serverfomo.py index 69e7e25bc..7648de56a 100644 --- a/system/flcore/servers/serverfomo.py +++ b/system/flcore/servers/serverfomo.py @@ -53,7 +53,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/servergc.py b/system/flcore/servers/servergc.py index c3fb40e1a..e8a4fab4b 100644 --- a/system/flcore/servers/servergc.py +++ b/system/flcore/servers/servergc.py @@ -77,7 +77,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/servergen.py b/system/flcore/servers/servergen.py index c2b0701bb..abadc389e 100644 --- a/system/flcore/servers/servergen.py +++ b/system/flcore/servers/servergen.py @@ -81,7 +81,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git 
a/system/flcore/servers/servergh.py b/system/flcore/servers/servergh.py index 7e4ce3383..30ad295df 100644 --- a/system/flcore/servers/servergh.py +++ b/system/flcore/servers/servergh.py @@ -54,7 +54,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/servergpfl.py b/system/flcore/servers/servergpfl.py index d37cbc966..391669938 100644 --- a/system/flcore/servers/servergpfl.py +++ b/system/flcore/servers/servergpfl.py @@ -56,7 +56,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverkd.py b/system/flcore/servers/serverkd.py index 2da5fd635..52f47208a 100644 --- a/system/flcore/servers/serverkd.py +++ b/system/flcore/servers/serverkd.py @@ -54,7 +54,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break self.energy = self.T_start + ((1 + i) / self.global_rounds) * (self.T_end - self.T_start) diff --git a/system/flcore/servers/serverlc.py 
b/system/flcore/servers/serverlc.py index 68c7aab1b..29f4793b3 100644 --- a/system/flcore/servers/serverlc.py +++ b/system/flcore/servers/serverlc.py @@ -59,7 +59,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverlg.py b/system/flcore/servers/serverlg.py index 0e8cc409d..783b85bb7 100644 --- a/system/flcore/servers/serverlg.py +++ b/system/flcore/servers/serverlg.py @@ -48,7 +48,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverlocal.py b/system/flcore/servers/serverlocal.py index 3176021c8..ab5aa28a1 100644 --- a/system/flcore/servers/serverlocal.py +++ b/system/flcore/servers/serverlocal.py @@ -41,7 +41,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break diff --git a/system/flcore/servers/servermoon.py b/system/flcore/servers/servermoon.py index d6c0deb66..1eaab6687 100644 --- a/system/flcore/servers/servermoon.py +++ 
b/system/flcore/servers/servermoon.py @@ -47,7 +47,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*50, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/servermtl.py b/system/flcore/servers/servermtl.py index c9257446d..8952a2023 100644 --- a/system/flcore/servers/servermtl.py +++ b/system/flcore/servers/servermtl.py @@ -56,7 +56,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break diff --git a/system/flcore/servers/serverntd.py b/system/flcore/servers/serverntd.py index 493ef5820..3701eef13 100644 --- a/system/flcore/servers/serverntd.py +++ b/system/flcore/servers/serverntd.py @@ -46,7 +46,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverpac.py b/system/flcore/servers/serverpac.py index 92dfc9662..e83058cfd 100644 --- a/system/flcore/servers/serverpac.py +++ b/system/flcore/servers/serverpac.py @@ -68,7 +68,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*50, self.Budget[-1]) - if self.auto_break and 
self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverpcl.py b/system/flcore/servers/serverpcl.py index b3872b988..66d417942 100644 --- a/system/flcore/servers/serverpcl.py +++ b/system/flcore/servers/serverpcl.py @@ -50,7 +50,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*50, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverper.py b/system/flcore/servers/serverper.py index 357d3a588..24e2592b3 100644 --- a/system/flcore/servers/serverper.py +++ b/system/flcore/servers/serverper.py @@ -47,7 +47,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverperavg.py b/system/flcore/servers/serverperavg.py index ed4ff0158..5dc7be8f7 100644 --- a/system/flcore/servers/serverperavg.py +++ b/system/flcore/servers/serverperavg.py @@ -48,7 +48,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else 
self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverphp.py b/system/flcore/servers/serverphp.py index 26d11d64e..298c3d4b5 100644 --- a/system/flcore/servers/serverphp.py +++ b/system/flcore/servers/serverphp.py @@ -43,7 +43,8 @@ def train(self): self.call_dlg(i) self.aggregate_parameters() - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverproto.py b/system/flcore/servers/serverproto.py index c89333f45..0b4de70c7 100644 --- a/system/flcore/servers/serverproto.py +++ b/system/flcore/servers/serverproto.py @@ -48,7 +48,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*50, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverprox.py b/system/flcore/servers/serverprox.py index 3d89cf361..62a3d5b02 100644 --- a/system/flcore/servers/serverprox.py +++ b/system/flcore/servers/serverprox.py @@ -47,7 +47,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverrep.py 
b/system/flcore/servers/serverrep.py index 6fa869ba5..d68a3cd0a 100644 --- a/system/flcore/servers/serverrep.py +++ b/system/flcore/servers/serverrep.py @@ -47,7 +47,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverrod.py b/system/flcore/servers/serverrod.py index b83643e9b..b128882a1 100644 --- a/system/flcore/servers/serverrod.py +++ b/system/flcore/servers/serverrod.py @@ -46,7 +46,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverscaffold.py b/system/flcore/servers/serverscaffold.py index 38941db10..bcecd4a60 100644 --- a/system/flcore/servers/serverscaffold.py +++ b/system/flcore/servers/serverscaffold.py @@ -54,7 +54,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/main.py b/system/main.py index cbc28b56a..c127bfd40 100644 --- a/system/main.py +++ b/system/main.py @@ -431,6 +431,13 @@ def run(args): 
help="Set this for text tasks. 80 for Shakespeare. 32000 for AG_News and SogouNews.") parser.add_argument('-ml', "--max_len", type=int, default=200) parser.add_argument('-fs', "--few_shot", type=int, default=0) + # validation split; --use_val is parsed from its text because argparse's type=bool maps every non-empty string (even "False") to True + parser.add_argument('-uv', "--use_val", type=lambda s: str(s).strip().lower() in ('1', 'true', 't', 'yes', 'y'), default=False, + help="Use a validation split from training data") + parser.add_argument('-vr', "--val_ratio", type=float, default=0.2, + help="Ratio of training data to use for validation") + parser.add_argument('-ss', "--split_seed", type=int, default=0, + help="Random seed for train/val split reproducibility") # practical parser.add_argument('-cdr', "--client_drop_rate", type=float, default=0.0, help="Rate for clients that train but drop out")