diff --git a/system/flcore/clients/clientbase.py b/system/flcore/clients/clientbase.py index 82c2aae0..076ce330 100644 --- a/system/flcore/clients/clientbase.py +++ b/system/flcore/clients/clientbase.py @@ -3,7 +3,7 @@ import torch.nn as nn import numpy as np import os -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, random_split from sklearn.preprocessing import label_binarize from sklearn import metrics from utils.data_utils import read_client_data @@ -43,6 +43,10 @@ def __init__(self, args, id, train_samples, test_samples, **kwargs): self.train_time_cost = {'num_rounds': 0, 'total_cost': 0.0} self.send_time_cost = {'num_rounds': 0, 'total_cost': 0.0} + self.use_val = getattr(args, 'use_val', False) + self.val_ratio = getattr(args, 'val_ratio', 0.2) + self.split_seed = getattr(args, 'split_seed', 0) + self.loss = nn.CrossEntropyLoss() self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate) self.learning_rate_scheduler = torch.optim.lr_scheduler.ExponentialLR( @@ -56,8 +60,23 @@ def load_train_data(self, batch_size=None): if batch_size == None: batch_size = self.batch_size train_data = read_client_data(self.dataset, self.id, is_train=True, few_shot=self.few_shot) + if self.use_val: + val_size = max(1, int(len(train_data) * self.val_ratio)) + train_size = len(train_data) - val_size + generator = torch.Generator().manual_seed(self.split_seed) + train_data, _ = random_split(train_data, [train_size, val_size], generator=generator) return DataLoader(train_data, batch_size, drop_last=True, shuffle=True) + def load_val_data(self, batch_size=None): + if batch_size == None: + batch_size = self.batch_size + full_data = read_client_data(self.dataset, self.id, is_train=True, few_shot=self.few_shot) + val_size = max(1, int(len(full_data) * self.val_ratio)) + train_size = len(full_data) - val_size + generator = torch.Generator().manual_seed(self.split_seed) + _, val_data = random_split(full_data, [train_size, val_size], generator=generator) + return DataLoader(val_data, batch_size, drop_last=False, shuffle=False) + def load_test_data(self, batch_size=None): if batch_size == None: batch_size = self.batch_size @@ -116,9 +135,46 @@ def test_metrics(self): y_true = np.concatenate(y_true, axis=0) auc = metrics.roc_auc_score(y_true, y_prob, average='micro') - + return test_acc, test_num, auc + def val_metrics(self): + valloaderfull = self.load_val_data() + self.model.eval() + + val_acc = 0 + val_num = 0 + y_prob = [] + y_true = [] + + with torch.no_grad(): + for x, y in valloaderfull: + if type(x) == type([]): + x[0] = x[0].to(self.device) + else: + x = x.to(self.device) + y = y.to(self.device) + output = self.model(x) + + val_acc += (torch.sum(torch.argmax(output, dim=1) == y)).item() + val_num += y.shape[0] + + y_prob.append(output.detach().cpu().numpy()) + nc = self.num_classes + if self.num_classes == 2: + nc += 1 + lb = label_binarize(y.detach().cpu().numpy(), classes=np.arange(nc)) + if self.num_classes == 2: + lb = lb[:, :2] + y_true.append(lb) + + y_prob = np.concatenate(y_prob, axis=0) + y_true = np.concatenate(y_true, axis=0) + + auc = metrics.roc_auc_score(y_true, y_prob, average='micro') + + return val_acc, val_num, auc + def train_metrics(self): trainloader = self.load_train_data() # self.model = self.load_model('model') diff --git a/system/flcore/servers/serverala.py b/system/flcore/servers/serverala.py index 97ba1126..84329a73 100644 --- a/system/flcore/servers/serverala.py +++ b/system/flcore/servers/serverala.py @@ -46,7 +46,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serveramp.py b/system/flcore/servers/serveramp.py index bb9b37a7..435c1868 100644 --- a/system/flcore/servers/serveramp.py +++ b/system/flcore/servers/serveramp.py @@ -55,7 +55,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverapfl.py b/system/flcore/servers/serverapfl.py index af439956..45ca1fb2 100644 --- a/system/flcore/servers/serverapfl.py +++ b/system/flcore/servers/serverapfl.py @@ -46,7 +46,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverapple.py b/system/flcore/servers/serverapple.py index 3d8fe085..ee1151a7 100644 --- a/system/flcore/servers/serverapple.py +++ b/system/flcore/servers/serverapple.py @@ -57,7 +57,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*50, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serveras.py b/system/flcore/servers/serveras.py index 5dc6ab70..9a3fe1d4 100644 --- a/system/flcore/servers/serveras.py +++ b/system/flcore/servers/serveras.py @@ -100,7 +100,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serveravg.py b/system/flcore/servers/serveravg.py index e4cb0867..93eeb619 100644 --- a/system/flcore/servers/serveravg.py +++ b/system/flcore/servers/serveravg.py @@ -46,7 +46,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverbabu.py b/system/flcore/servers/serverbabu.py index f5c0a9c5..5e286757 100644 --- a/system/flcore/servers/serverbabu.py +++ b/system/flcore/servers/serverbabu.py @@ -47,7 +47,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverbase.py b/system/flcore/servers/serverbase.py index 435362d6..c28d9741 100644 --- a/system/flcore/servers/serverbase.py +++ b/system/flcore/servers/serverbase.py @@ -44,9 +44,13 @@ def __init__(self, args, times): self.uploaded_ids = [] self.uploaded_models = [] + self.use_val = getattr(args, 'use_val', False) + self.rs_test_acc = [] self.rs_test_auc = [] self.rs_train_loss = [] + self.rs_val_acc = [] + self.rs_val_auc = [] self.times = times self.eval_gap = args.eval_gap @@ -182,6 +186,9 @@ def save_results(self): hf.create_dataset('rs_test_acc', data=self.rs_test_acc) hf.create_dataset('rs_test_auc', data=self.rs_test_auc) hf.create_dataset('rs_train_loss', data=self.rs_train_loss) + if self.rs_val_acc: + hf.create_dataset('rs_val_acc', data=self.rs_val_acc) + hf.create_dataset('rs_val_auc', data=self.rs_val_auc) def save_item(self, item, item_name): if not os.path.exists(self.save_folder_name): @@ -224,6 +231,20 @@ def train_metrics(self): return ids, num_samples, losses + def val_metrics(self): + num_samples = [] + tot_correct = [] + tot_auc = [] + for c in self.clients: + ct, ns, auc = c.val_metrics() + tot_correct.append(ct*1.0) + tot_auc.append(auc*ns) + num_samples.append(ns) + + ids = [c.id for c in self.clients] + + return ids, num_samples, tot_correct, tot_auc + # evaluate selected clients def evaluate(self, acc=None, loss=None): stats = self.test_metrics() @@ -234,12 +255,12 @@ def evaluate(self, acc=None, loss=None): train_loss = sum(stats_train[2])*1.0 / sum(stats_train[1]) accs = [a / n for a, n in zip(stats[2], stats[1])] aucs = [a / n for a, n in zip(stats[3], stats[1])] - + if acc == None: self.rs_test_acc.append(test_acc) else: acc.append(test_acc) - + if loss == None: self.rs_train_loss.append(train_loss) else: @@ -252,6 +273,15 @@ def evaluate(self, acc=None, loss=None): print("Std Test Accuracy: {:.4f}".format(np.std(accs))) print("Std Test AUC: {:.4f}".format(np.std(aucs))) + if self.use_val: + stats_val = self.val_metrics() + val_acc = sum(stats_val[2])*1.0 / sum(stats_val[1]) + val_auc = sum(stats_val[3])*1.0 / sum(stats_val[1]) + self.rs_val_acc.append(val_acc) + self.rs_val_auc.append(val_auc) + print("Averaged Val Accuracy: {:.4f}".format(val_acc)) + print("Averaged Val AUC: {:.4f}".format(val_auc)) + def print_(self, test_acc, test_auc, train_loss): print("Average Test Accuracy: {:.4f}".format(test_acc)) print("Average Test AUC: {:.4f}".format(test_auc)) diff --git a/system/flcore/servers/serverbn.py b/system/flcore/servers/serverbn.py index 7afb2403..795b1931 100644 --- a/system/flcore/servers/serverbn.py +++ b/system/flcore/servers/serverbn.py @@ -45,7 +45,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/servercac.py b/system/flcore/servers/servercac.py index 8876c228..c0fe0432 100644 --- a/system/flcore/servers/servercac.py +++ b/system/flcore/servers/servercac.py @@ -50,7 +50,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/servercross.py b/system/flcore/servers/servercross.py index 698c42a8..b0d7e424 100644 --- a/system/flcore/servers/servercross.py +++ b/system/flcore/servers/servercross.py @@ -69,7 +69,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverda.py b/system/flcore/servers/serverda.py index 7ded3781..f37d29e6 100644 --- a/system/flcore/servers/serverda.py +++ b/system/flcore/servers/serverda.py @@ -47,7 +47,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverdbe.py b/system/flcore/servers/serverdbe.py index 1a6b9bf0..8593b5b7 100644 --- a/system/flcore/servers/serverdbe.py +++ b/system/flcore/servers/serverdbe.py @@ -68,7 +68,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverditto.py b/system/flcore/servers/serverditto.py index af8d861e..84454e74 100644 --- a/system/flcore/servers/serverditto.py +++ b/system/flcore/servers/serverditto.py @@ -52,7 +52,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverdyn.py b/system/flcore/servers/serverdyn.py index 856218b1..9a986796 100644 --- a/system/flcore/servers/serverdyn.py +++ b/system/flcore/servers/serverdyn.py @@ -55,7 +55,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*50, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverfd.py b/system/flcore/servers/serverfd.py index ab2bd3a2..511d32ae 100644 --- a/system/flcore/servers/serverfd.py +++ b/system/flcore/servers/serverfd.py @@ -47,7 +47,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*50, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverfml.py b/system/flcore/servers/serverfml.py index 79632972..7f302f29 100644 --- a/system/flcore/servers/serverfml.py +++ b/system/flcore/servers/serverfml.py @@ -48,7 +48,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverfomo.py b/system/flcore/servers/serverfomo.py index 69e7e25b..7648de56 100644 --- a/system/flcore/servers/serverfomo.py +++ b/system/flcore/servers/serverfomo.py @@ -53,7 +53,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/servergc.py b/system/flcore/servers/servergc.py index c3fb40e1..e8a4fab4 100644 --- a/system/flcore/servers/servergc.py +++ b/system/flcore/servers/servergc.py @@ -77,7 +77,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/servergen.py b/system/flcore/servers/servergen.py index c2b0701b..abadc389 100644 --- a/system/flcore/servers/servergen.py +++ b/system/flcore/servers/servergen.py @@ -81,7 +81,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/servergh.py b/system/flcore/servers/servergh.py index 7e4ce338..30ad295d 100644 --- a/system/flcore/servers/servergh.py +++ b/system/flcore/servers/servergh.py @@ -54,7 +54,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/servergpfl.py b/system/flcore/servers/servergpfl.py index d37cbc96..39166993 100644 --- a/system/flcore/servers/servergpfl.py +++ b/system/flcore/servers/servergpfl.py @@ -56,7 +56,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverkd.py b/system/flcore/servers/serverkd.py index 2da5fd63..52f47208 100644 --- a/system/flcore/servers/serverkd.py +++ b/system/flcore/servers/serverkd.py @@ -54,7 +54,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break self.energy = self.T_start + ((1 + i) / self.global_rounds) * (self.T_end - self.T_start) diff --git a/system/flcore/servers/serverlc.py b/system/flcore/servers/serverlc.py index 68c7aab1..29f4793b 100644 --- a/system/flcore/servers/serverlc.py +++ b/system/flcore/servers/serverlc.py @@ -59,7 +59,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverlg.py b/system/flcore/servers/serverlg.py index 0e8cc409..783b85bb 100644 --- a/system/flcore/servers/serverlg.py +++ b/system/flcore/servers/serverlg.py @@ -48,7 +48,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverlocal.py b/system/flcore/servers/serverlocal.py index 3176021c..ab5aa28a 100644 --- a/system/flcore/servers/serverlocal.py +++ b/system/flcore/servers/serverlocal.py @@ -41,7 +41,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break diff --git a/system/flcore/servers/servermoon.py b/system/flcore/servers/servermoon.py index d6c0deb6..1eaab668 100644 --- a/system/flcore/servers/servermoon.py +++ b/system/flcore/servers/servermoon.py @@ -47,7 +47,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*50, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/servermtl.py b/system/flcore/servers/servermtl.py index c9257446..8952a202 100644 --- a/system/flcore/servers/servermtl.py +++ b/system/flcore/servers/servermtl.py @@ -56,7 +56,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break diff --git a/system/flcore/servers/serverntd.py b/system/flcore/servers/serverntd.py index 493ef582..3701eef1 100644 --- a/system/flcore/servers/serverntd.py +++ b/system/flcore/servers/serverntd.py @@ -46,7 +46,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverpac.py b/system/flcore/servers/serverpac.py index 92dfc966..e83058cf 100644 --- a/system/flcore/servers/serverpac.py +++ b/system/flcore/servers/serverpac.py @@ -68,7 +68,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*50, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverpcl.py b/system/flcore/servers/serverpcl.py index b3872b98..66d41794 100644 --- a/system/flcore/servers/serverpcl.py +++ b/system/flcore/servers/serverpcl.py @@ -50,7 +50,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*50, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverper.py b/system/flcore/servers/serverper.py index 357d3a58..24e2592b 100644 --- a/system/flcore/servers/serverper.py +++ b/system/flcore/servers/serverper.py @@ -47,7 +47,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverperavg.py b/system/flcore/servers/serverperavg.py index ed4ff015..5dc7be8f 100644 --- a/system/flcore/servers/serverperavg.py +++ b/system/flcore/servers/serverperavg.py @@ -48,7 +48,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverphp.py b/system/flcore/servers/serverphp.py index 26d11d64..298c3d4b 100644 --- a/system/flcore/servers/serverphp.py +++ b/system/flcore/servers/serverphp.py @@ -43,7 +43,8 @@ def train(self): self.call_dlg(i) self.aggregate_parameters() - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverproto.py b/system/flcore/servers/serverproto.py index c89333f4..0b4de70c 100644 --- a/system/flcore/servers/serverproto.py +++ b/system/flcore/servers/serverproto.py @@ -48,7 +48,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*50, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverprox.py b/system/flcore/servers/serverprox.py index 3d89cf36..62a3d5b0 100644 --- a/system/flcore/servers/serverprox.py +++ b/system/flcore/servers/serverprox.py @@ -47,7 +47,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverrep.py b/system/flcore/servers/serverrep.py index 6fa869ba..d68a3cd0 100644 --- a/system/flcore/servers/serverrep.py +++ b/system/flcore/servers/serverrep.py @@ -47,7 +47,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverrod.py b/system/flcore/servers/serverrod.py index b83643e9..b128882a 100644 --- a/system/flcore/servers/serverrod.py +++ b/system/flcore/servers/serverrod.py @@ -46,7 +46,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/flcore/servers/serverscaffold.py b/system/flcore/servers/serverscaffold.py index 38941db1..bcecd4a6 100644 --- a/system/flcore/servers/serverscaffold.py +++ b/system/flcore/servers/serverscaffold.py @@ -54,7 +54,8 @@ def train(self): self.Budget.append(time.time() - s_t) print('-'*25, 'time cost', '-'*25, self.Budget[-1]) - if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + es_metric = self.rs_val_acc if self.use_val and self.rs_val_acc else self.rs_test_acc + if self.auto_break and self.check_done(acc_lss=[es_metric], top_cnt=self.top_cnt): break print("\nBest accuracy.") diff --git a/system/main.py b/system/main.py index cbc28b56..c127bfd4 100644 --- a/system/main.py +++ b/system/main.py @@ -431,6 +431,13 @@ def run(args): help="Set this for text tasks. 80 for Shakespeare. 32000 for AG_News and SogouNews.") parser.add_argument('-ml', "--max_len", type=int, default=200) parser.add_argument('-fs', "--few_shot", type=int, default=0) + # validation split + parser.add_argument('-uv', "--use_val", type=bool, default=False, + help="Use a validation split from training data") + parser.add_argument('-vr', "--val_ratio", type=float, default=0.2, + help="Ratio of training data to use for validation") + parser.add_argument('-ss', "--split_seed", type=int, default=0, + help="Random seed for train/val split reproducibility") # practical parser.add_argument('-cdr', "--client_drop_rate", type=float, default=0.0, help="Rate for clients that train but drop out")