diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index 12460f1..0ef35fc 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -26,12 +26,12 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.ref }} # ${{ github.event.pull_request.head.sha }} - name: Setup Python 3.10 - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: "3.10" cache: 'pip' @@ -70,13 +70,13 @@ jobs: steps: - name: Checkout to latest changes - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ needs.formatting.outputs.new_sha }} fetch-depth: 0 - name: Set up Python 3.10 - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: "3.10" cache: 'pip' @@ -94,13 +94,13 @@ jobs: steps: - name: Checkout to latest changes - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ needs.formatting.outputs.new_sha }} fetch-depth: 0 - name: Set up Python 3.10 - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: "3.10" cache: 'pip' @@ -125,13 +125,13 @@ jobs: steps: - name: Checkout to latest changes - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ needs.formatting.outputs.new_sha }} fetch-depth: 0 - name: Set up Python 3.10 - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: "3.10" cache: 'pip' diff --git a/.gitignore b/.gitignore index 135e228..4d85b65 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,9 @@ results/ data/ logs/ external/ +benchmark*/ +*.png +*.csv other/ # C extensions *.so diff --git a/benchmark.py b/benchmark.py new file mode 100644 index 0000000..f5e8dc9 --- /dev/null +++ b/benchmark.py @@ -0,0 +1,159 @@ +import subprocess +import time +import os +import yaml +import sys +from itertools import product + +experiment_name = "experiment_all_10percent" +benchmark_dir = 
"benchmark_results" + + +model_names = [ + "logistic_regression", + "elastic_net", + "lsvc", + "random_forest", + "balanced_random_forest", + # # "weighted_random_forest", + "xgb" + ] + +datasets = [ + # "kaggle_hf", + "diabetes", + # "ukbb_cvd", + # "cvd" + ] + +num_clients = [ + 3, + 5, + 10, + 20 +] + +dirichlet_alpha = [ + None, + # 1.0, + # 0.7 +] + +data_normalization = ["global"] +n_features = [None] + +# Normalization experiment +# experiment_name = "normalization" +# benchmark_dir = "benchmark_results_normalization" +# model_names = ["logistic_regression"] +# datasets = ["diabetes", "ukbb_cvd"] +# num_clients = [10] +# dirichlet_alpha = [0.7, None] +# data_normalization = ["global", "local", None] + +# # Feature selection experiment +# experiment_name = "feature_selection" +# benchmark_dir = "benchmark_results_feature_selection" +# model_names = ["balanced_random_forest"] +# datasets = ["ukbb_cvd"] +# num_clients = [5,10] +# dirichlet_alpha = [0.7, None] +# data_normalization = ["global"] +# n_features = [10, 20, 35, 40, None] + +# # Number of Clients ablation experiment +experiment_name = "num_clients_ablation" +benchmark_dir = "benchmark_results_num_clients_ablation" +model_names = [ + "logistic_regression", + "elastic_net", + "lsvc", + "random_forest", + "balanced_random_forest", + "xgb" + ] +datasets = ["diabetes"] +num_clients = [3,5,10,20] +dirichlet_alpha = [0.7, 1.0, None] +data_normalization = ["global"] +n_features = [None] + +os.makedirs(benchmark_dir, exist_ok=True) + +with open("config.yaml", "r") as f: + config = yaml.safe_load(f) + + +config_path = os.path.join(benchmark_dir, "config.yaml") +log_file_path = os.path.join(benchmark_dir, "run_log.txt") + +with open(config_path, "w") as f: + yaml.dump(config, f) + +config['data_path'] = 'dataset/' +config['experiment']['log_path'] = benchmark_dir + +start_time = time.time() + +# Flatten the nested loops into a single iterator +parameters = product(datasets, num_clients, dirichlet_alpha, 
model_names, data_normalization, n_features) + +try: + for ds_name, n_client, alpha, m_name, norm, n_feat in parameters: + print(f"Running benchmark: {ds_name}, {m_name}, clients: {n_client}, alpha: {alpha}, normalization: {norm}, features: {n_feat}") + + # Update config dictionary + config.update({ + 'model': m_name, + 'dataset': ds_name, + 'num_clients': n_client, + 'dirichlet_alpha': alpha, + 'data_normalization': norm, + 'n_features': n_feat + }) + if "forest" in m_name: + config['num_rounds'] = 1 # Set number of jobs for parallel processing + + config['experiment']['name'] = f"{experiment_name}_{ds_name}_{m_name}_c{n_client}_a{alpha}_norm{norm}_feat{n_feat}" + + with open(config_path, "w") as f: + yaml.dump(config, f) + + # subprocess.run is cleaner for synchronous execution + # Use a list for the command to avoid shell=True security/cleanup issues + cmd = f"python repeated.py {config_path} | tee {log_file_path}" + subprocess.run(cmd, shell=True, check=True) + +except KeyboardInterrupt: + print("\nBenchmark interrupted by user. 
Exiting...") + sys.exit(1) + + + +# # Run benchmark experiments +# # Iterate over datasets and models +# for dataset_name in datasets: +# for num_client in num_clients: +# for alpha in dirichlet_alpha: +# for model_name in model_names: +# print(f"Running benchmark for dataset: {dataset_name}, model: {model_name}") +# config['experiment']['name'] = f"{experiment_name}_{dataset_name}_{model_name}_clients_{num_client}_alpha_{alpha}" +# config['model'] = model_name +# config['dataset'] = dataset_name +# config['num_clients'] = num_client +# config['dirichlet_alpha'] = alpha + +# with open(config_path, "w") as f: +# yaml.dump(config, f) + +# try: +# run_process = subprocess.Popen(f"python repeated.py {config_path} | tee {log_file_path}", shell=True) +# run_process.wait() + +# except KeyboardInterrupt: +# run_process.terminate() +# run_process.wait() +# break + +total_time = time.time() - start_time +print("Benchmark experiments finished in", total_time/60, " minutes") diff --git a/config.yaml b/config.yaml index 4c561dc..917fdc1 100644 --- a/config.yaml +++ b/config.yaml @@ -10,8 +10,10 @@ ################################################################################ ############## Dataset type to use -# Possible values: , kaggle_hf, mnist, dt4h_format -dataset: dt4h_format +# Possible values: , kaggle_hf, diabetes, mnist, dt4h_format +dataset: kaggle_hf +# dataset: ukbb_cvd +# dataset: diabetes #custom #libsvm #kaggle_hf @@ -33,30 +35,54 @@ train_size: 0.7 # ****** * * * * * * * * * * * * * * * * * * * * ******************* ############## Number of clients (data centers) to use for training -num_clients: 1 +num_clients: 4 ############## Model type # Possible values: logistic_regression, lsvc, elastic_net, random_forest, weighted_random_forest, xgb # See README.md for a full list of supported models -model: random_forest +# model: xgb +model: logistic_regression +# model: random_forest #logistic_regression #random_forest ############## Training length -num_rounds: 50 
+num_rounds: 10 ############## Metric to select the best model # Possible values: accuracy, balanced_accuracy, f1, precision, recall -checkpoint_selection_metric: precision +# checkpoint_selection_metric: precision +checkpoint_selection_metric: balanced_accuracy #balanced_accuracy ############## Experiment logging experiment: - name: experiment_1 + name: experiment_kaggle_standard log_path: logs debug: true +################################################################################ +# Federated Data Preprocessing +################################################################################ + +# Strategy to calculate data preprocessing parameters between clients. +# It covers missing data imputation, label encoding, normalization and feature selection +# It can be one of: + # "reference" - use reference center to calculate all parameters (largest or random) + # "equal_aggregate" - aggregate parameters from all clients based on mean and voting disregarding center size + # "weighted_aggregate" - aggregate parameters from all clients based on weighted mean and voting + +data_preprocessing_method: "equal_aggregate" +# data_preprocessing_method: "reference" + +# Toggle data normalization (Standard scaler) based on largest center (global) or local client +data_normalization: "global" + +# Determine target for feature selection number +n_features: Null + + ################################################################################ # Aggregation methods ################################################################################ @@ -87,9 +113,13 @@ smoothWeights: linear_models: n_features: 9 + +dirichlet_alpha: Null + # Random Forest random_forest: balanced_rf: true + tree_num: 300 # Weighted Random Forest weighted_random_forest: @@ -101,7 +131,7 @@ xgb: batch_size: 32 num_iterations: 100 task_type: BINARY - tree_num: 500 + tree_num: 300 held_out_center_id: -1 @@ -113,6 +143,6 @@ seed: 42 local_port: 8081 -data_path: dataset/icrc-dataset/ +data_path: 
dataset/ production_mode: False # Turn on to use environment variables such as data path, server address, certificates etc. diff --git a/flcore/client_selector.py b/flcore/client_selector.py index 76fa3d5..c0c616b 100644 --- a/flcore/client_selector.py +++ b/flcore/client_selector.py @@ -1,7 +1,7 @@ import numpy as np import flcore.models.linear_models as linear_models -import flcore.models.xgb as xgb +import flcore.models.xgblr as xgblr import flcore.models.random_forest as random_forest import flcore.models.weighted_random_forest as weighted_random_forest @@ -11,14 +11,14 @@ def get_model_client(config, data, client_id): if model in ("logistic_regression", "elastic_net", "lsvc"): client = linear_models.client.get_client(config,data,client_id) - elif model == "random_forest": + elif model in ("random_forest", "balanced_random_forest"): client = random_forest.client.get_client(config,data,client_id) elif model == "weighted_random_forest": client = weighted_random_forest.client.get_client(config,data,client_id) - elif model == "xgb": - client = xgb.client.get_client(config, data, client_id) + elif model == "xgblr": + client = xgblr.client.get_client(config, data, client_id) else: raise ValueError(f"Unknown model: {model}") diff --git a/flcore/compile_results.py b/flcore/compile_results.py index 8270d9b..4f8d7b8 100644 --- a/flcore/compile_results.py +++ b/flcore/compile_results.py @@ -8,6 +8,7 @@ def compile_results(experiment_dir: str): + print(f"Compiling results for experiment in {experiment_dir}") per_client_metrics = {} held_out_metrics = {} fit_metrics = {} @@ -21,6 +22,8 @@ def compile_results(experiment_dir: str): elif config['dataset'] == 'kaggle_hf': center_names = ['Cleveland', 'Hungary', 'VA', 'Switzerland'] + else: + center_names = [f"center_{i+1}" for i in range(config['num_clients'])] writer = open(f"{experiment_dir}/metrics.txt", "w") @@ -48,7 +51,10 @@ def compile_results(experiment_dir: str): history = yaml.safe_load(open(os.path.join(fold_dir, 
"history.yaml"), "r")) selection_metric = 'val '+ config['checkpoint_selection_metric'] + # selection_metric = config['checkpoint_selection_metric'] best_round= int(np.argmax(history['metrics_distributed'][selection_metric])) + # best_round = -1 + print(f"Best round for {directory} based on {selection_metric}: {best_round}") # client_order = history['metrics_distributed']['per client client_id'][best_round] client_order = history['metrics_distributed']['per client n samples'][best_round] for logs in history.keys(): @@ -98,7 +104,8 @@ def compile_results(experiment_dir: str): fit_metrics[metric] = np.vstack((fit_metrics[metric], values_history[best_round])) - execution_stats = ['client_id', 'round_time [s]', 'n samples', 'training_time [s]'] + # execution_stats = ['client_id', 'round_time [s]', 'n samples', 'training_time [s]'] + execution_stats = ['client_id', 'round_time [s]', 'n samples'] # Calculate mean and std for per client metrics writer.write(f"{'Evaluation':.^100} \n\n") writer.write(f"\n{'Test set:'} \n") @@ -124,12 +131,16 @@ def compile_results(experiment_dir: str): writer.write(f"\n{'Federated finetuned locally:'} \n") personalized_section = True - # Calculate general mean and std - mean = np.average(per_client_metrics[metric]) - # Calculate std of the average metric between experiment runs - std = np.std(np.mean(per_client_metrics[metric], axis=1)) - per_client_mean = np.around(np.mean(per_client_metrics[metric], axis=0), 3) - per_client_std = np.around(np.std(per_client_metrics[metric], axis=0), 3) + # Calculate general weighted mean and std + # Weighted by number of samples in each client + weights = np.array(per_client_metrics['n samples'][0]) + per_client_mean = np.mean(per_client_metrics[metric], axis=0) + per_client_std = np.std(per_client_metrics[metric], axis=0) + mean = np.average(per_client_mean, weights=weights) + std = np.sqrt(np.average((per_client_mean - mean) ** 2, weights=weights)) + # Round per client mean and std to 3 decimals + 
per_client_mean = np.around(per_client_mean, 3) + per_client_std = np.around(per_client_std, 3) if metric not in execution_stats: writer.write(f"{metric:<30}: {mean:<6.3f} ±{std:<6.3f} \t\t\t|| Per client {metric} {per_client_mean} ({per_client_std})\n".replace("\n", "")+"\n") for i, _ in enumerate(per_client_mean): @@ -161,25 +172,25 @@ def compile_results(experiment_dir: str): centralized_metrics[metric] = held_out_metrics[metric] held_out_metrics.pop(metric, None) - writer.write(f"\n{'Held out set evaluation':.^100} \n\n") - for metric in held_out_metrics: - center = int(held_out_metrics['client_id'][0]) - center = center_names[center]+' (held out)' - mean = np.average(held_out_metrics[metric]) - std = np.std(held_out_metrics[metric]) - - writer.write(f"{metric:<30}: {mean:<6.3f} ±{std:<6.3f}\n") - if center not in csv_dict: - csv_dict[center] = {} - csv_dict[center][metric] = mean - csv_dict[center][metric+'_std'] = std - - # Calculate mean and std for centralized metrics - writer.write(f"\n{'Centralized evaluation':.^100} \n\n") - for metric in centralized_metrics: - mean = np.average(centralized_metrics[metric]) - std = np.std(centralized_metrics[metric]) - writer.write(f"{metric:<30}: {mean:<6.3f} ±{std:<6.3f}\n") + # writer.write(f"\n{'Held out set evaluation':.^100} \n\n") + # for metric in held_out_metrics: + # center = int(held_out_metrics['client_id'][0]) + # center = center_names[center]+' (held out)' + # mean = np.average(held_out_metrics[metric]) + # std = np.std(held_out_metrics[metric]) + + # writer.write(f"{metric:<30}: {mean:<6.3f} ±{std:<6.3f}\n") + # if center not in csv_dict: + # csv_dict[center] = {} + # csv_dict[center][metric] = mean + # csv_dict[center][metric+'_std'] = std + + # # Calculate mean and std for centralized metrics + # writer.write(f"\n{'Centralized evaluation':.^100} \n\n") + # for metric in centralized_metrics: + # mean = np.average(centralized_metrics[metric]) + # std = np.std(centralized_metrics[metric]) + # 
writer.write(f"{metric:<30}: {mean:<6.3f} ±{std:<6.3f}\n") writer.close() @@ -194,7 +205,7 @@ def compile_results(experiment_dir: str): # Write to csv df.to_csv(f"{experiment_dir}/per_center_results.csv", index=True) - generate_report(experiment_dir) + # generate_report(experiment_dir) if __name__ == "__main__": diff --git a/flcore/datasets.py b/flcore/datasets.py index 699c4a0..68d048f 100644 --- a/flcore/datasets.py +++ b/flcore/datasets.py @@ -12,18 +12,569 @@ import pandas as pd from sklearn.datasets import load_svmlight_file -from sklearn.preprocessing import OrdinalEncoder, MinMaxScaler,StandardScaler +from sklearn.preprocessing import OrdinalEncoder, LabelEncoder, MinMaxScaler, StandardScaler from sklearn.model_selection import KFold, StratifiedShuffleSplit, train_test_split from sklearn.utils import shuffle -from sklearn.feature_selection import SelectKBest, f_classif +from sklearn.feature_selection import SelectKBest, f_classif, mutual_info_classif +from sklearn.ensemble import RandomForestClassifier +from ucimlrepo import fetch_ucirepo +import pickle -from flcore.models.xgb.utils import TreeDataset, do_fl_partitioning, get_dataloader + +from flcore.models.xgblr.utils import TreeDataset, do_fl_partitioning, get_dataloader XY = Tuple[np.ndarray, np.ndarray] Dataset = Tuple[XY, XY] +def calculate_preprocessing_params(subset_data, subset_target, n_features=None, feature_selection_method='mutual_info'): + """ + Calculate preprocessing parameters based on a subset of data (reference center) + + Args: + subset_data: DataFrame containing the subset data + subset_target: Series containing the target variable + n_features: Number of features to select (None for all features) + feature_selection_method: Method for feature selection ('mutual_info', 'f_classif', 'random_forest') + + Returns: + dict: Preprocessing parameters (imputation values, mean, std, label_encoders, feature_selector) + """ + data_copy = subset_data.copy() + target_copy = subset_target.copy() + + # 
Calculate imputation parameters + imputation_params = {} + label_encoders = {} + + for column in data_copy.columns: + # Handle missing values + if data_copy[column].isna().any(): + if data_copy[column].dtype in ['float64', 'int64']: + imputation_params[column] = data_copy[column].median() + else: + imputation_params[column] = data_copy[column].mode()[0] if not data_copy[column].mode().empty else 0 + + # Store label encoders for categorical variables + if data_copy[column].dtype == 'object': + le = LabelEncoder() + # Fit on non-null values only + non_null_data = data_copy[column].dropna() + if len(non_null_data) > 0: + # Add 'unknown' category for unseen labels + classes = np.append(non_null_data.astype(str).unique(), 'unknown') + le.fit(classes) + label_encoders[column] = le + + # Calculate normalization parameters for ALL columns (after conversion to numerical) + numeric_data = data_copy.copy() + + # Temporarily convert categorical to numerical for normalization parameter calculation + for column in numeric_data.columns: + if numeric_data[column].dtype == 'object': + # Use simple integer encoding for parameter calculation + numeric_data[column] = pd.Categorical(numeric_data[column]).codes + # Handle missing values temporarily for parameter calculation + if column in imputation_params: + numeric_data[column].fillna(imputation_params[column], inplace=True) + + # Convert all to numeric + numeric_data = numeric_data.apply(pd.to_numeric, errors='coerce') + + # Calculate normalization parameters + normalization_params = { + 'mean': numeric_data.mean().to_dict(), + 'std': numeric_data.std().to_dict() + } + + # Handle zero standard deviation + for col, std_val in normalization_params['std'].items(): + if std_val == 0 or np.isnan(std_val): + normalization_params['std'][col] = 1.0 + + # Feature Selection + feature_selector = None + selected_features = None + feature_scores = None + + if n_features is not None: + if n_features < len(numeric_data.columns): + # Prepare data 
for feature selection + X_temp = numeric_data.fillna(numeric_data.median()) + y_temp = target_copy + + # Handle any remaining NaN values + X_temp = X_temp.fillna(0) + + if feature_selection_method == 'mutual_info': + selector = SelectKBest(score_func=mutual_info_classif, k=min(n_features, X_temp.shape[1])) + elif feature_selection_method == 'f_classif': + selector = SelectKBest(score_func=f_classif, k=min(n_features, X_temp.shape[1])) + elif feature_selection_method == 'random_forest': + # Use Random Forest feature importance + rf = RandomForestClassifier(n_estimators=100, random_state=42) + rf.fit(X_temp, y_temp) + importances = rf.feature_importances_ + indices = np.argsort(importances)[::-1] + selected_indices = indices[:min(n_features, len(indices))] + + # Create a custom selector object + class CustomSelector: + def __init__(self, selected_indices, feature_names): + self.selected_indices = selected_indices + self.feature_names = feature_names + self.scores_ = importances + + def transform(self, X): + if isinstance(X, pd.DataFrame): + return X.iloc[:, self.selected_indices] + else: + return X[:, self.selected_indices] + + def get_support(self, indices=False): + if indices: + return self.selected_indices + else: + mask = np.zeros(len(self.feature_names), dtype=bool) + mask[self.selected_indices] = True + return mask + + selector = CustomSelector(selected_indices, numeric_data.columns.tolist()) + feature_scores = importances + else: + raise ValueError("feature_selection_method must be 'mutual_info', 'f_classif', or 'random_forest'") + + if feature_selection_method != 'random_forest': + selector.fit(X_temp, y_temp) + feature_scores = selector.scores_ + + feature_selector = selector + selected_features = numeric_data.columns[selector.get_support()].tolist() + + print(f"Feature selection: Selected {len(selected_features)} most informative features") + if feature_scores is not None: + # Print top feature scores + feature_importance = pd.DataFrame({ + 'feature': 
numeric_data.columns, + 'score': feature_scores + }).sort_values('score', ascending=False) + print("Top 5 features:") + for i, (_, row) in enumerate(feature_importance.head().iterrows()): + print(f" {i+1}. {row['feature']}: {row['score']:.4f}") + + return { + 'imputation': imputation_params, + 'normalization': normalization_params, + 'label_encoders': label_encoders, + 'feature_selector': feature_selector, + 'selected_features': selected_features, + 'n_features': n_features + } + +def apply_preprocessing(subset_data, preprocessing_params, normalization="global"): + """ + Apply preprocessing to a subset using pre-calculated parameters from reference center + + Args: + subset_data: DataFrame to preprocess + preprocessing_params: dict from calculate_preprocessing_params + + Returns: + tuple: (preprocessed_data, feature_names) + """ + data_copy = subset_data.copy() + + # Step 1: Handle missing values using reference center parameters + for column in data_copy.columns: + if column in preprocessing_params['imputation']: + missing_mask = data_copy[column].isna() + if missing_mask.any(): + data_copy.loc[missing_mask, column] = preprocessing_params['imputation'][column] + + # Step 2: Convert all features to numerical using reference center label encoders + for column in data_copy.columns: + if column in preprocessing_params['label_encoders']: + le = preprocessing_params['label_encoders'][column] + # Convert to string and handle unseen labels + encoded_values = [] + for val in data_copy[column]: + if pd.isna(val): + encoded_values.append(-1) # Special value for missing + else: + str_val = str(val) + if str_val in le.classes_: + encoded_values.append(le.transform([str_val])[0]) + else: + # Map unseen labels to 'unknown' class + encoded_values.append(le.transform(['unknown'])[0]) + data_copy[column] = encoded_values + elif data_copy[column].dtype == 'object': + # Fallback: use categorical codes for any remaining object columns + data_copy[column] = 
pd.Categorical(data_copy[column]).codes + + # Ensure all data is numerical + data_copy = data_copy.apply(pd.to_numeric, errors='coerce') + + # Step 3: Normalize ALL features using global parameters if enabled + if normalization == "global": + normalization_params = preprocessing_params['normalization'] + for column in data_copy.columns: + if column in normalization_params['mean']: + mean_val = normalization_params['mean'][column] + std_val = normalization_params['std'][column] + data_copy[column] = (data_copy[column] - mean_val) / std_val + # print("Applied global normalization during preprocessing.") + elif normalization == "local": + # Calculate local normalization parameters + local_mean = data_copy.mean() + local_std = data_copy.std() + for column in data_copy.columns: + mean_val = local_mean[column] + std_val = local_std[column] if local_std[column] != 0 else 1.0 + data_copy[column] = (data_copy[column] - mean_val) / std_val + # print("Applied local normalization during preprocessing.") + elif normalization is not None: + raise ValueError("Data normalization method must be 'global', 'local', or None") + + # Step 4: Apply feature selection if enabled + if preprocessing_params['feature_selector'] is not None: + selector = preprocessing_params['feature_selector'] + data_copy = pd.DataFrame(selector.transform(data_copy), + columns=preprocessing_params['selected_features']) + + return data_copy, data_copy.columns.tolist() + +def partition_data_dirichlet(labels, num_centers, alpha=1.0, min_samples_per_class=10): + """ + Partition data among centers using Dirichlet distribution + + Args: + labels: Array of class labels + num_centers: Number of centers to partition into + alpha: Dirichlet concentration parameter + min_samples_per_class: Minimum number of samples per class per center + """ + unique_labels = np.unique(labels) + n_samples = len(labels) + n_classes = len(unique_labels) + + if not alpha: + alpha = -1.0 + + if alpha <= 0: + # IID partitioning + 
shuffled_indices = np.random.permutation(n_samples) + center_indices = np.array_split(shuffled_indices, num_centers) + center_indices = [indices.tolist() for indices in center_indices] + # check lengths of each center + center_lengths = [len(indices) for indices in center_indices] + return center_indices + + # Create assignment matrix + center_indices = [[] for _ in range(num_centers)] + + # For each class, distribute samples to centers using Dirichlet distribution + for class_idx in unique_labels: + class_mask = (labels == class_idx) + class_indices = np.where(class_mask)[0] + n_class_samples = len(class_indices) + + if n_class_samples > 0: + # Generate Dirichlet distribution for this class + proportions = np.random.dirichlet(np.repeat(alpha, num_centers)) + proportions = proportions / proportions.sum() + + # Calculate number of samples for each center + center_samples = (proportions * n_class_samples).astype(int) + + # Ensure minimum samples per class per center + for i in range(num_centers): + if center_samples[i] < min_samples_per_class: + center_samples[i] = min(min_samples_per_class, n_class_samples // num_centers) + + # Adjust for rounding errors and minimum constraints + total_assigned = center_samples.sum() + diff = n_class_samples - total_assigned + if diff > 0: + # Distribute remaining samples + available_centers = [i for i in range(num_centers) if center_samples[i] < n_class_samples] + if available_centers: + additions = np.random.choice(available_centers, diff, replace=True) + for i in additions: + center_samples[i] += 1 + elif diff < 0: + # Remove excess samples + excess_centers = np.argsort(center_samples)[::-1] # Sort by size descending + for i in excess_centers: + if diff >= 0: + break + can_remove = center_samples[i] - min_samples_per_class + if can_remove > 0: + remove = min(can_remove, -diff) + center_samples[i] -= remove + diff += remove + + # Shuffle and assign indices + np.random.shuffle(class_indices) + ptr = 0 + for center_id in 
range(num_centers): + if center_samples[center_id] > 0: + center_indices[center_id].extend( + class_indices[ptr:ptr + center_samples[center_id]] + ) + ptr += center_samples[center_id] + + # Shuffle indices within each center + for center_id in range(num_centers): + np.random.shuffle(center_indices[center_id]) + + return center_indices +def select_reference_center(all_center_data, method='largest'): + """ + Select which center to use for calculating preprocessing parameters + """ + if method == 'largest': + center_sizes = [len(X) for X, y in all_center_data] + reference_center_id = np.argmax(center_sizes) + print(f"Selected largest center (ID: {reference_center_id}) with {center_sizes[reference_center_id]} samples") + + elif method == 'random': + reference_center_id = np.random.randint(0, len(all_center_data)) + print(f"Selected random center (ID: {reference_center_id})") + else: + raise ValueError("Method must be 'largest' or 'random'") + + return reference_center_id + +def aggregate_preprocessing_params(preprocessing_params_list, center_sizes, method='weighted_aggregate'): + """ + Aggregate preprocessing parameters from multiple centers using weighted aggregation. 
+ + Args: + preprocessing_params_list: List of preprocessing parameter dictionaries from each center + center_sizes: List of center sizes (number of samples) + + Returns: + dict: Aggregated preprocessing parameters + """ + if not preprocessing_params_list: + raise ValueError("preprocessing_params_list cannot be empty") + + if "equal" in method: + # Equal weights + center_sizes = [1 for _ in center_sizes] + print("Using equal weights for aggregation of preprocessing parameters.") + + total_size = sum(center_sizes) + weights = [size / total_size for size in center_sizes] + + aggregated = { + 'imputation': {}, + 'normalization': {'mean': {}, 'std': {}}, + 'label_encoders': {}, + 'feature_selector': None, + 'selected_features': [], + 'n_features': preprocessing_params_list[0]['n_features'] # Assume same for all + } + + # Collect all columns + all_columns = set() + for params in preprocessing_params_list: + all_columns.update(params['imputation'].keys()) + all_columns.update(params['normalization']['mean'].keys()) + all_columns.update(params['label_encoders'].keys()) + + # Aggregate imputation + for col in all_columns: + numeric_values = [] + categorical_values = [] + weights_num = [] + weights_cat = [] + for params, weight in zip(preprocessing_params_list, weights): + if col in params['imputation']: + value = params['imputation'][col] + if isinstance(value, (int, float)) and not pd.isna(value): + numeric_values.append(value) + weights_num.append(weight) + else: + categorical_values.append(value) + weights_cat.append(weight) + + if numeric_values: + # Weighted mean for numeric + aggregated['imputation'][col] = sum(v * w for v, w in zip(numeric_values, weights_num)) / sum(weights_num) + elif categorical_values: + # Most frequent for categorical (simple mode) + from collections import Counter + counter = Counter(categorical_values) + aggregated['imputation'][col] = counter.most_common(1)[0][0] + + # Aggregate normalization + for col in all_columns: + means = [] + stds = 
[] + weights_norm = [] + for params, weight in zip(preprocessing_params_list, weights): + if col in params['normalization']['mean']: + means.append(params['normalization']['mean'][col]) + stds.append(params['normalization']['std'][col]) + weights_norm.append(weight) + + if means: + global_mean = sum(m * w for m, w in zip(means, weights_norm)) / sum(weights_norm) + aggregated['normalization']['mean'][col] = global_mean + + # Calculate global std: sqrt( sum(w_i * var_i) + sum(w_i * (mean_i - global_mean)^2) ) + variances = [s ** 2 for s in stds] + weighted_var_sum = sum(v * w for v, w in zip(variances, weights_norm)) + mean_diff_sq = [(m - global_mean) ** 2 for m in means] + weighted_mean_var = sum(md * w for md, w in zip(mean_diff_sq, weights_norm)) + global_var = weighted_var_sum + weighted_mean_var + global_std = np.sqrt(global_var) if global_var > 0 else 1.0 + aggregated['normalization']['std'][col] = global_std + + # For label_encoders, take from the largest center (simplest approach) + max_size_idx = center_sizes.index(max(center_sizes)) + aggregated['label_encoders'] = preprocessing_params_list[max_size_idx]['label_encoders'].copy() + + # Aggregate selected_features by frequency + if preprocessing_params_list[0]['selected_features']: + from collections import Counter + feature_counts = Counter() + for params, weight in zip(preprocessing_params_list, weights): + for feature in params['selected_features']: + feature_counts[feature] += weight + + # Select top n_features most frequent + n_features = aggregated['n_features'] + if n_features: + selected = [feat for feat, _ in feature_counts.most_common(n_features)] + aggregated['selected_features'] = selected + + return aggregated + +def prepare_dataset(X, y, center_id, config, center_indices=None): + """ + Load and preprocess raw dataset for federated learning with feature selection + + This function will extract the following config values: + center_id: Identifier for the federated node + num_centers: Total number 
of federated centers + alpha: Dirichlet concentration parameter for data partitioning + reference_method: How to select reference center ('largest' or 'random') + aggregation_method: How to aggregate preprocessing params ('reference' or 'weighted_aggregate') + global_preprocessing_params: Precomputed parameters (if None, will calculate) + n_features: Number of features to select (None for all features) + feature_selection_method: Method for feature selection + + Returns: + tuple: X_train, y_train, X_test, y_test + """ + + num_centers = config.get("num_clients", 5) + alpha = config.get("dirichlet_alpha", 1.0) + reference_method = config.get("reference_center_method", "largest") + preprocessing_method = config.get("data_preprocessing_method", "reference") + min_samples_per_class = config.get("min_samples_per_class", 10) + global_preprocessing_params = None + n_features = config.get("n_features", 20) + feature_selection_method = config.get("feature_selection_method", "mutual_info") + normalization_method = config.get("data_normalization", "global") + + np.random.seed(42) # For reproducibility of partitioning and reference selection + + # Convert target to binary classification if needed + if y.nunique() > 2: + y_binary = (y > y.median()).astype(int) + else: + y_binary = y + + if not center_indices: + # Partition data using Dirichlet distribution + all_center_indices = partition_data_dirichlet(y_binary.values, num_centers, alpha, min_samples_per_class) + else: + all_center_indices = center_indices + + # Get all center data for reference selection + all_center_data = [] + for i in range(num_centers): + if i < len(all_center_indices) and len(all_center_indices[i]) > 0: + X_center = X.iloc[all_center_indices[i]] + all_center_data.append((X_center, y_binary.iloc[all_center_indices[i]])) + else: + all_center_data.append((pd.DataFrame(), pd.Series())) + + # Calculate or use global preprocessing parameters + if global_preprocessing_params is None: + if preprocessing_method == 
'reference': + # Select reference center and calculate parameters + reference_center_id = select_reference_center(all_center_data, reference_method) + X_reference = all_center_data[reference_center_id][0] + y_reference = all_center_data[reference_center_id][1] + + if len(X_reference) == 0: + # Fallback: use full dataset if reference center is empty + X_reference = X + y_reference = y_binary + print("Warning: Reference center empty, using full dataset for preprocessing parameters") + + global_preprocessing_params = calculate_preprocessing_params( + X_reference, y_reference, n_features=n_features, feature_selection_method=feature_selection_method + ) + elif "aggregate" in preprocessing_method: + # Calculate parameters for each center and aggregate + preprocessing_params_list = [] + center_sizes = [] + for X_center, y_center in all_center_data: + if len(X_center) > 0: + params = calculate_preprocessing_params( + X_center, y_center, n_features=n_features, feature_selection_method=feature_selection_method + ) + preprocessing_params_list.append(params) + center_sizes.append(len(X_center)) + + if preprocessing_params_list: + global_preprocessing_params = aggregate_preprocessing_params(preprocessing_params_list, center_sizes, method=preprocessing_method) + else: + # Fallback + global_preprocessing_params = calculate_preprocessing_params( + X, y_binary, n_features=n_features, feature_selection_method=feature_selection_method + ) + print("Warning: No valid centers, using full dataset for preprocessing parameters") + else: + raise ValueError("aggregation_method must be 'reference', 'equal_aggregate' or 'weighted_aggregate'") + + print("Calculated global preprocessing parameters using", preprocessing_method) + + if center_id is not None: + # Get indices for the requested center + if center_id >= len(all_center_indices) or len(all_center_indices[center_id]) == 0: + raise ValueError(f"Center ID {center_id} has no data assigned") + + center_indices = all_center_indices[center_id] 
+ X_center = X.iloc[center_indices].reset_index(drop=True) + y_center = y.iloc[center_indices].reset_index(drop=True) + else: + # Use full dataset if no center_id specified + X_center = X + y_center = y + + # Split into train/test for this center + if len(X_center) > 1: + X_train, X_test, y_train, y_test = train_test_split( + X_center, y_center, test_size=0.2, random_state=config['seed'], stratify=y_center + ) + else: + X_train, y_train = X_center, y_center + X_test, y_test = X_center.iloc[:0], y_center.iloc[:0] + + # Apply GLOBAL preprocessing parameters to both train and test sets + X_train_processed, feature_names = apply_preprocessing(X_train, global_preprocessing_params, normalization=normalization_method) + X_test_processed, _ = apply_preprocessing(X_test, global_preprocessing_params, normalization=normalization_method) + + return X_train_processed, y_train, X_test_processed, y_test + def load_mnist(center_id=None, num_splits=5): """Loads the MNIST dataset using OpenML. OpenML dataset link: https://www.openml.org/d/554 @@ -67,109 +618,51 @@ def load_mnist(center_id=None, num_splits=5): return (x_train, y_train), (x_test, y_test) - -def load_cvd(data_path, center_id=None) -> Dataset: +def load_cvd(data_path, center_id, config) -> Dataset: id = center_id - if center_id == 1: - file_name = data_path+'data_center1.csv' - elif center_id == 2: - file_name = data_path+'data_center2.csv' - elif center_id == 3: - file_name = data_path+'data_center3.csv' - else: - file_name = data_path+'data_center3.csv' - - if id == None: - # id = 'All' - data_centers = ['All'] - else: - data_centers = [id] - - X_train_list, y_train_list = [], [] - X_test_list, y_test_list = [], [] - test_index_list = [] - train_index_list = [] - for id in data_centers: - # file_name = os.path.join(data_path, f"data_center{id}.csv") - # file_name = os.path.join(data_path, file_name) + code_id = "f_eid" + code_outcome = "Eval" - code_id = "f_eid" - code_outcome = "Eval" + data = 
pd.read_csv(os.path.join(data_path, "data_centerAll.csv")) + X_data = data.drop([code_id, code_outcome], axis=1) + y_data = data[code_outcome] - data = pd.read_csv(file_name) - X_data = data.drop([code_id, code_outcome], axis=1) - y_data = data[code_outcome] - f_eid = data[code_id] - - # Split the data - sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=None) - train_index, test_index = next(sss.split(X_data, y_data)) - X_test = X_data.iloc[test_index, :] - X_train = X_data.iloc[train_index, :] - y_test, y_train = y_data.iloc[test_index], y_data.iloc[train_index] - # We save the names - f_eid.iloc[test_index] - f_eid.iloc[train_index] - - X_train_list.append(X_train) - y_train_list.append(y_train) - X_test_list.append(X_test) - y_test_list.append(y_test) - train_index_list.append(train_index) - test_index_list.append(test_index) - - X_train = pd.concat(X_train_list) - y_train = pd.concat(y_train_list) - X_test = pd.concat(X_test_list) - y_test = pd.concat(y_test_list) - train_index = np.concatenate(train_index_list) - test_index = np.concatenate(test_index_list) - - # Verify set difference, data centers overlap - # print(len(train_index.tolist())) - # print(len(test_index.tolist())) - # train_set = set(train_index.tolist()) - # test_set = set(test_index.tolist()) - # diff = train_set.intersection(test_set) - # print(len(train_set)) - # print(len(test_set)) - # print( len(diff) ) - # print(f"SUBSET {id}") - # train_unique = np.unique(y_train, return_counts=True) - # test_unique = np.unique(y_test, return_counts=True) - # train_max_acc = train_unique[1][0]/len(y_train) - # test_max_acc = test_unique[1][0]/len(y_test) - # print(np.unique(y_train, return_counts=True)) - # print(np.unique(y_test, return_counts=True)) - # print(train_max_acc) - # print(test_max_acc) + X_train_processed, y_train, X_test_processed, y_test = prepare_dataset(X_data, y_data, center_id, config) - return (X_train, y_train), (X_test, y_test) + return (X_train_processed, 
y_train), (X_test_processed, y_test) def load_ukbb_cvd(data_path, center_id, config) -> Dataset: + """ + Load UKBB CVD mortality dataset + + Args: + data_path: Path to the dataset + center_id: ID of the center to load + config: Configuration dictionary - seed = config["seed"] + """ data_path = os.path.join(data_path, "CVDMortalityData.csv") data = pd.read_csv(data_path) - # print(len(data)) - center_key = 'f.54.0.0' patient_key = 'f.eid' label_key = 'label' - # center_id = None - # center_id = 1 - preprocessing_data = data.loc[(data[center_key] == 1)] - # center_id = None - if center_id is not None: - center_id = center_id - if center_id == 19: - center_id = 21 - elif center_id == 21: - center_id = 19 - data = data.loc[(data[center_key] == center_id)] + #Create a list of lists for each center_key with row indexes from that center + center_keys = sorted(list(data[center_key].unique())) + # convert to list of ints + center_keys = set(int(center) for center in center_keys) + center_indices = [] + for center in center_keys: + center_indices.append(data.loc[(data[center_key] == center)].index.tolist()) + + X = data.drop([label_key, center_key, patient_key], axis=1) + y = data[label_key] + + X_train, y_train, X_test, y_test = prepare_dataset(X, y, center_id, config, center_indices) + + # print("Center ", center_id, "with ", len(X_train), " samples, of which positive samples are ", len(X_train.loc[y_train == 1])) # center_names = ['Bristol', 'Newcastle', 'Oxford', 'Stockport (pilot)', 'Reading', # 'Middlesborough', 'Leeds', 'Liverpool', 'Nottingham', 'Glasgow', 'Croydon', @@ -182,228 +675,54 @@ def load_ukbb_cvd(data_path, center_id, config) -> Dataset: # center_dict = list(center_dict.values()) # print(center_dict) - # xx - - # for i in range(0, 23): - # center_data = data.loc[(data[center_key] == i)] - # print(f'Center ID: {i} {center_dict[i]} with {len(center_data)} samples of which positive samples are {len(center_data.loc[center_data[label_key] == 1])})') - # xx - # 
features = data.drop([label_key, center_key, patient_key], axis=1) - # target = data[label_key] - - # print(len(data)) - # print(features.head()) - # print(f'Center ID: {center_id} with {len(data)} samples of which positive samples are {len(data.loc[data[label_key] == 1])})') - # print(target.head()) - - def get_preprocessing_params(preprocessing_data): - - data = preprocessing_data - features = data.drop([label_key, center_key, patient_key], axis=1) - target = data[label_key] - X_train, X_test, y_train, y_test = train_test_split(features, target, test_size = 0.20, random_state = seed, stratify=target) - - n_features = 40 - fs = SelectKBest(f_classif, k=n_features).fit(X_train, y_train) - index_features = fs.get_support() - X_train = X_train.iloc[:, index_features] - - # print(X_train.head()) - - # Get the unique values of the categorical features - col = list(X_train.columns) - categorical_features = [] - numerical_features = [] - for i in col: - if len(X_train[i].unique()) > 24: - numerical_features.append(i) - # else: - # categorical_features.append(i) - - transformers_dict = {} - - for i in categorical_features: - transformers_dict[i] = OrdinalEncoder() - for i in numerical_features: - transformers_dict[i] = StandardScaler() - - # df1 = data.copy(deep = True) - - for feature in transformers_dict: - transformers_dict[feature].fit(X_train[feature].values.reshape(-1, 1)) - - return index_features, transformers_dict - - - index_features, transformers_dict = get_preprocessing_params(preprocessing_data) - - def preprocess_data(data, index_features, column_transformer): - # Scale the data using the precomputed parameters - data = data.copy(deep = True) - features = data.drop([label_key, center_key, patient_key], axis=1) - features = features.iloc[:, index_features] - target = data[label_key] - - for feature in column_transformer: - features[feature] = column_transformer[feature].transform(features[feature].values.reshape(-1, 1)) - - X_train, X_test, y_train, y_test = 
train_test_split(features, target, test_size = 0.20, random_state = seed, stratify=target) - - return X_train, X_test, y_train, y_test - - X_train, X_test, y_train, y_test = preprocess_data(data, index_features, transformers_dict) - - # print shapes of the data - # print(X_train.shape) - # print(X_test.shape) - # print(y_train.shape) - # print(y_test.shape) - - # features = features.iloc[:, index_features] - - # X_train, X_test, y_train, y_test = train_test_split(features, target, test_size = 0.20, random_state = None, stratify=target) - - # print(features.head()) - - print(f'Center ID: {center_id} with {len(data)} samples of which positive samples are {len(data.loc[data[label_key] == 1])})') - - return (X_train, y_train), (X_test, y_test) - def load_kaggle_hf(data_path, center_id, config) -> Dataset: - id = center_id - seed = config["seed"] - - if id == -1: - id = 'switzerland' - elif id == 1: - id = 'hungarian' - elif id == 2: - id = 'va' - elif id == 0: - id = 'cleveland' - elif id == None: - pass - else: - raise ValueError(f"Invalid center id: {id}") - - # elif id == 5: - # id = 'cleveland' - + """ + Load Kaggle Heart Failure dataset for federated learning using prepare_dataset + + Args: + data_path: Path to the dataset + center_id: ID of the center (0: cleveland, 1: hungarian, 2: va, 3: switzerland, None: all) + config: Configuration dictionary + + Returns: + tuple: ((X_train, y_train), (X_test, y_test)) + """ + file_name = os.path.join(data_path, "kaggle_hf.csv") data = pd.read_csv(file_name) - - scaling_data = data.loc[(data['data_center'] == 'hungarian')] - # scaling_data = data - - if id is not None: - data = data.loc[(data['data_center'] == id)] - - # print('Categorical Features :',*categorical_features) - # print('Numerical Features :',*numerical_features) - - def get_preprocessing_params(data): - - # Get the unique values of the categorical features - col = list(data.columns) - categorical_features = [] - numerical_features = [] - for i in col: - if 
len(data[i].unique()) > 6: - numerical_features.append(i) - else: - categorical_features.append(i) - - transformers_dict = {} - - categorical_features.pop(categorical_features.index('HeartDisease')) - if 'RestingBP' in numerical_features: - numerical_features.pop(numerical_features.index('RestingBP')) - elif 'RestingBP' in categorical_features: - categorical_features.pop(categorical_features.index('RestingBP')) - categorical_features.pop(categorical_features.index('RestingECG')) - categorical_features.pop(categorical_features.index('data_center')) - numerical_features.pop(numerical_features.index('Oldpeak')) - min_max_scaling_features = ['Oldpeak'] - - for i in categorical_features: - transformers_dict[i] = OrdinalEncoder() - for i in numerical_features: - transformers_dict[i] = StandardScaler() - for i in min_max_scaling_features: - transformers_dict[i] = MinMaxScaler() - - df1 = data.copy(deep = True) - - target = df1['HeartDisease'] - X_train, X_test, y_train, y_test = train_test_split(df1, target, test_size = 0.20, random_state = seed) - - for feature in transformers_dict: - if feature == 'ST_Slope': - # Change value of last row to 'Down' to avoid error as it is missing in some splits - X_train[feature].iloc[-1] = 'Down' - transformers_dict[feature].fit(X_train[feature].values.reshape(-1, 1)) - else: - transformers_dict[feature].fit(X_train[feature].values.reshape(-1, 1)) - - return transformers_dict - + # Define centers + centers = ['cleveland', 'hungarian', 'va', 'switzerland'] - def preprocess_data(data, column_transformer): - # Scale the data using the precomputed parameters - df1 = data.copy(deep = True) - features = df1[df1.columns.drop(['HeartDisease','RestingBP','RestingECG', 'data_center'])] - target = df1['HeartDisease'] - - for feature in column_transformer: - features[feature] = column_transformer[feature].transform(features[feature].values.reshape(-1, 1)) - - X_train, X_test, y_train, y_test = train_test_split(features, target, test_size = 0.20, 
random_state = seed, stratify=target) - - return (X_train, y_train), (X_test, y_test) + # Map center_id to index + center_id_mapped = None + if center_id is not None: + if center_id == 0: + center_id_mapped = 0 # cleveland + elif center_id == 1: + center_id_mapped = 1 # hungarian + elif center_id == 2: + center_id_mapped = 2 # va + elif center_id == 3: + center_id_mapped = 3 # switzerland + else: + raise ValueError(f"Invalid center id: {center_id}") - - preprocessing_params = get_preprocessing_params(scaling_data) - - (X_train, y_train), (X_test, y_test) = preprocess_data(data, preprocessing_params) - - # n_females = len(X_train[X_train['Sex'] == 0]) - # print(f'n_females{n_females}') - # n_males = len(X_train[X_train['Sex'] == 1]) - # print(f'n_males{n_males}') - # print(len(X_train)) - # Get indexes of rows with men (Sex == 0) - n_females = len(X_train[X_train['Sex'] == 0]) - n_males = len(X_train[X_train['Sex'] == 1]) - print(f'Center {center_id} of size {len(X_train)} with n_females {n_females} and n_males {n_males} in training set') - - if center_id == 0: - men_indexes = X_train.index[X_train['Sex'] == 1] - female_indexes = X_train.index[X_train['Sex'] == 0] - # print(len(female_indexes)) - n_females_to_drop = int(len(female_indexes)*0.9) - female_indexes = female_indexes[:n_females_to_drop] - copy_male_indexes = men_indexes[:n_females_to_drop] - # print(len(female_indexes)) - X_train = X_train.drop(index=female_indexes) - y_train = y_train.drop(index=female_indexes) - # print(len(X_train)) - # print(f'Adding males {len(copy_male_indexes)}') - X_train = pd.concat([X_train, X_train.loc[copy_male_indexes]]) - y_train = pd.concat([y_train, y_train.loc[copy_male_indexes]]) - - if center_id == 2 or center_id == -1: - X_train = pd.concat([X_train, X_train, X_train, X_train]) - y_train = pd.concat([y_train, y_train, y_train, y_train]) - - n_females = len(X_train[X_train['Sex'] == 0]) - n_males = len(X_train[X_train['Sex'] == 1]) - print(f'Center {center_id} of size 
{len(X_train)} with n_females {n_females} and n_males {n_males} in training set') - # xx - return (X_train, y_train), (X_test, y_test) - + # Create center_indices + center_indices = [] + for center in centers: + indices = data.loc[data['data_center'] == center].index.tolist() + center_indices.append(indices) + + # Prepare X and y + X = data.drop(['HeartDisease', 'data_center'], axis=1) + y = data['HeartDisease'] + + X_train_processed, y_train, X_test_processed, y_test = prepare_dataset(X, y, center_id_mapped, config, center_indices) + + return (X_train_processed, y_train), (X_test_processed, y_test) def load_libsvm(config, center_id=None, task_type="BINARY"): # ## Manually download and load the tabular dataset from LIBSVM data @@ -639,6 +958,55 @@ def load_dt4h(config,id): y_test = data_target[int(dat_len*config["train_size"]):].iloc[:, 0] return (X_train, y_train), (X_test, y_test) +def load_diabetes(center_id, config): + """ + Load and preprocess diabetes dataset for federated learning with feature selection + + Args: + center_id: Identifier for the federated node + num_centers: Total number of federated centers + alpha: Dirichlet concentration parameter for data partitioning + reference_method: How to select reference center ('largest' or 'random') + global_preprocessing_params: Precomputed parameters (if None, will calculate) + n_features: Number of features to select (None for all features) + feature_selection_method: Method for feature selection + + Returns: + tuple: ((X_train, y_train), (X_test, y_test), preprocessing_params) + """ + + dataset_file = "dataset/cdc_diabetes_health_indicators.pkl" + if os.path.exists(dataset_file): + # Load from pickle + with open(dataset_file, 'rb') as f: + cdc_diabetes_health_indicators = pickle.load(f) + else: + # Download the dataset + cdc_diabetes_health_indicators = fetch_ucirepo(id=891).data + # save as pickle for faster loading next time + dataset = {"features": cdc_diabetes_health_indicators.features, "targets": 
cdc_diabetes_health_indicators.targets} + with open(dataset_file, 'wb') as f: + pickle.dump(dataset, f) + + # Get features and target + X = cdc_diabetes_health_indicators['features'] + y = cdc_diabetes_health_indicators['targets'] + + # convert y to a pandas Series for easier handling + y = pd.Series(y.values.flatten()) + + # # # # Use fraction of data for faster testing (optional) + if not config['num_clients'] == 1: + fraction = 1.0 + # Sample indices first, then select from both X and y + sampled_indices = X.sample(frac=fraction, random_state=42).index + X = X.loc[sampled_indices].reset_index(drop=True) + y = y.loc[sampled_indices].reset_index(drop=True) + + X_train_processed, y_train, X_test_processed, y_test = prepare_dataset(X, y, center_id, config) + + return (X_train_processed, y_train), (X_test_processed, y_test) + def cvd_to_torch(config): pass @@ -690,11 +1058,13 @@ def load_dataset(config, id=None): if config["dataset"] == "mnist": return load_mnist(id, config["num_clients"]) elif config["dataset"] == "cvd": - return load_cvd(config["data_path"], id) + return load_cvd(config["data_path"], id, config) elif config["dataset"] == "ukbb_cvd": return load_ukbb_cvd(config["data_path"], id, config) elif config["dataset"] == "kaggle_hf": return load_kaggle_hf(config["data_path"], id, config) + elif config["dataset"] == "diabetes": + return load_diabetes(id, config) elif config["dataset"] == "libsvm": return load_libsvm(config, id) elif config["dataset"] == "dt4h_format": diff --git a/flcore/metrics.py b/flcore/metrics.py index 7788f61..9bbcb89 100644 --- a/flcore/metrics.py +++ b/flcore/metrics.py @@ -8,6 +8,7 @@ BinaryPrecision, BinaryRecall, BinarySpecificity, + BinaryAUROC, ) from torchmetrics.functional.classification.precision_recall import ( @@ -43,17 +44,18 @@ def compute(self) -> Tensor: return (recall + specificity) / 2 -def get_metrics_collection(task_type="binary", device="cpu"): +def get_metrics_collection(task_type="binary", device="cpu", 
threshold=0.5): if task_type.lower() == "binary": return MetricCollection( { - "accuracy": BinaryAccuracy().to(device), - "precision": BinaryPrecision().to(device), - "recall": BinaryRecall().to(device), - "specificity": BinarySpecificity().to(device), - "f1": BinaryF1Score().to(device), - "balanced_accuracy": BinaryBalancedAccuracy().to(device), + "accuracy": BinaryAccuracy(threshold=threshold).to(device), + "precision": BinaryPrecision(threshold=threshold).to(device), + "recall": BinaryRecall(threshold=threshold).to(device), + "specificity": BinarySpecificity(threshold=threshold).to(device), + "f1": BinaryF1Score(threshold=threshold).to(device), + "balanced_accuracy": BinaryBalancedAccuracy(threshold=threshold).to(device), + "auroc": BinaryAUROC().to(device), } ) elif task_type.lower() == "reg": @@ -61,13 +63,25 @@ def get_metrics_collection(task_type="binary", device="cpu"): "mse": MeanSquaredError().to(device), }) -def calculate_metrics(y_true, y_pred, task_type="binary"): - metrics_collection = get_metrics_collection(task_type) + +def calculate_metrics(y_true, y_pred_proba, task_type="binary", threshold=0.5): + metrics_collection = get_metrics_collection(task_type, threshold=threshold) if not torch.is_tensor(y_true): - y_true = torch.tensor(y_true.tolist()) - if not torch.is_tensor(y_pred): - y_pred = torch.tensor(y_pred.tolist()) - metrics_collection.update(y_pred, y_true) + if isinstance(y_true, list): + y_true = torch.cat(y_true) + else: + y_true = torch.tensor(y_true.tolist()) + if not torch.is_tensor(y_pred_proba): + if isinstance(y_pred_proba, list): + y_pred_proba = torch.cat(y_pred_proba) + else: + y_pred_proba = torch.tensor(y_pred_proba.tolist()) + + # Extract probabilities for the positive class if shape>1 + if y_pred_proba.ndim > 1 and y_pred_proba.shape[1] > 1: + y_pred_proba = y_pred_proba[:, 1] + + metrics_collection.update(y_pred_proba, y_true) metrics = metrics_collection.compute() metrics = {k: v.item() for k, v in metrics.items()} @@ -89,4 
+103,16 @@ def metrics_aggregation_fn(distributed_metrics): metrics['per client n samples'] = [res[0] for res in distributed_metrics] - return metrics \ No newline at end of file + return metrics + +def find_best_threshold(y_true, y_pred_proba, metric="balanced_accuracy"): + best_threshold = 0.5 + best_metric_value = 0.0 + + for threshold in np.arange(0.0, 1.01, 0.01): + metrics = calculate_metrics(y_true, y_pred_proba, threshold=threshold) + if metrics[metric] > best_metric_value: + best_metric_value = metrics[metric] + best_threshold = threshold + + return best_threshold diff --git a/flcore/models/linear_models/client.py b/flcore/models/linear_models/client.py index b7561be..b0dd88b 100644 --- a/flcore/models/linear_models/client.py +++ b/flcore/models/linear_models/client.py @@ -9,7 +9,7 @@ import flwr as fl from sklearn.metrics import log_loss from flcore.performance import measurements_metrics, get_metrics -from flcore.metrics import calculate_metrics +from flcore.metrics import calculate_metrics, find_best_threshold import time import pandas as pd from sklearn.preprocessing import StandardScaler @@ -25,7 +25,7 @@ def __init__(self, data,client_id,config): (self.X_train, self.y_train), (self.X_test, self.y_test) = data # Create train and validation split - self.X_train, self.X_val, self.y_train, self.y_val = train_test_split(self.X_train, self.y_train, test_size=0.2, random_state=42, stratify=self.y_train) + self.X_train, self.X_val, self.y_train, self.y_val = train_test_split(self.X_train, self.y_train, test_size=0.2, random_state=config['seed'], stratify=self.y_train) # #Only use the standardScaler to the continous variables # scaled_features_train = StandardScaler().fit_transform(self.X_train.values) @@ -44,7 +44,7 @@ def __init__(self, data,client_id,config): self.first_round = True self.personalize = True # Setting initial parameters, akin to model.compile for keras models - utils.set_initial_params(self.model,self.n_features) + 
utils.set_initial_params(self.model, (self.X_train, self.y_train), self.n_features) def get_parameters(self, config): # type: ignore #compute the feature selection @@ -67,10 +67,17 @@ def fit(self, parameters, config): # type: ignore self.model.fit(self.X_train, self.y_train) # self.model.fit(self.X_train.loc[:, parameters[2].astype(bool)], self.y_train) # y_pred = self.model.predict(self.X_test.loc[:, parameters[2].astype(bool)]) - y_pred = self.model.predict(self.X_test) - - metrics = calculate_metrics(self.y_test, y_pred) - print(f"Client {self.client_id} Evaluation just after local training: {metrics['balanced_accuracy']}") + # If LSVC is used, use decision_function instead of predict_proba + if self.model_name == 'lsvc': + y_pred_proba = self.model.decision_function(self.X_val) + else: + y_pred_proba = self.model.predict_proba(self.X_val) + best_threshold = find_best_threshold(self.y_val, y_pred_proba, metric="balanced_accuracy") + if self.model_name == 'lsvc': + y_pred_proba = self.model.decision_function(self.X_test) + else: + y_pred_proba = self.model.predict_proba(self.X_test) + metrics = calculate_metrics(self.y_test, y_pred_proba, threshold=best_threshold) # Add 'personalized' to the metrics to identify them metrics = {f"personalized {key}": metrics[key] for key in metrics} self.round_time = (time.time() - start_time) @@ -81,10 +88,19 @@ def fit(self, parameters, config): # type: ignore if self.first_round: local_model = utils.get_model(self.model_name, local=True) - utils.set_initial_params(local_model,self.n_features) + # utils.set_initial_params(local_model,self.n_features) local_model.fit(self.X_train, self.y_train) - y_pred = local_model.predict(self.X_test) - local_metrics = calculate_metrics(self.y_test, y_pred) + # Calculate validation set metrics + if self.model_name == 'lsvc': + y_pred_proba = self.model.decision_function(self.X_val) + else: + y_pred_proba = self.model.predict_proba(self.X_val) + best_threshold = find_best_threshold(self.y_val, 
y_pred_proba, metric="balanced_accuracy") + if self.model_name == 'lsvc': + y_pred_proba = self.model.decision_function(self.X_test) + else: + y_pred_proba = self.model.predict_proba(self.X_test) + local_metrics = calculate_metrics(self.y_test, y_pred_proba, threshold=best_threshold) #Add 'local' to the metrics to identify them local_metrics = {f"local {key}": local_metrics[key] for key in local_metrics} metrics.update(local_metrics) @@ -96,10 +112,17 @@ def evaluate(self, parameters, config): # type: ignore utils.set_model_params(self.model, parameters) # Calculate validation set metrics - y_pred = self.model.predict(self.X_val) - val_metrics = calculate_metrics(self.y_val, y_pred) + if self.model_name == 'lsvc': + y_pred_proba = self.model.decision_function(self.X_val) + else: + y_pred_proba = self.model.predict_proba(self.X_val) + best_threshold = find_best_threshold(self.y_val, y_pred_proba, metric="balanced_accuracy") + val_metrics = calculate_metrics(self.y_val, y_pred_proba, threshold=best_threshold) - y_pred = self.model.predict(self.X_test) + if self.model_name == 'lsvc': + y_pred_proba = self.model.decision_function(self.X_test) + else: + y_pred_proba = self.model.predict_proba(self.X_test) # y_pred = self.model.predict(self.X_test.loc[:, parameters[2].astype(bool)]) if(isinstance(self.model, SGDClassifier)): @@ -107,19 +130,19 @@ def evaluate(self, parameters, config): # type: ignore else: loss = log_loss(self.y_test, self.model.predict_proba(self.X_test), labels=[0, 1]) - metrics = calculate_metrics(self.y_test, y_pred) + metrics = calculate_metrics(self.y_test, y_pred_proba, threshold=best_threshold) + metrics_not_tuned = calculate_metrics(self.y_test, y_pred_proba, threshold=0.5) + metrics_not_tuned = {f"not tuned {key}": metrics_not_tuned[key] for key in metrics_not_tuned} + metrics.update(metrics_not_tuned) metrics["round_time [s]"] = self.round_time metrics["client_id"] = self.client_id - print(f"Client {self.client_id} Evaluation after aggregated 
model: {metrics['balanced_accuracy']}") - - # Add validation metrics to the evaluation metrics with a prefix val_metrics = {f"val {key}": val_metrics[key] for key in val_metrics} metrics.update(val_metrics) - return loss, len(y_pred), metrics + return loss, len(y_pred_proba), metrics def get_client(config,data,client_id) -> fl.client.Client: diff --git a/flcore/models/linear_models/server.py b/flcore/models/linear_models/server.py index 9204430..a49da28 100644 --- a/flcore/models/linear_models/server.py +++ b/flcore/models/linear_models/server.py @@ -138,9 +138,9 @@ def evaluate_held_out( def get_server_and_strategy(config): model_type = config['model'] - model = get_model(model_type) + # model = get_model(model_type) n_features = config['linear_models']['n_features'] - utils.set_initial_params(model, n_features) + # utils.set_initial_params(model, n_features) # Pass parameters to the Strategy for server-side parameter initialization #strategy = fl.server.strategy.FedAvg( diff --git a/flcore/models/linear_models/utils.py b/flcore/models/linear_models/utils.py index cdc36c9..512642e 100644 --- a/flcore/models/linear_models/utils.py +++ b/flcore/models/linear_models/utils.py @@ -20,14 +20,24 @@ def get_model(model_name, local=False): case "lsvc": #Linear classifiers (SVM, logistic regression, etc.) with SGD training. 
#If we use hinge, it implements SVM - model = SGDClassifier(max_iter=max_iter,n_iter_no_change=1000,average=True,random_state=42,class_weight= "balanced",warm_start=True,fit_intercept=True,loss="hinge", learning_rate='optimal') + model = SGDClassifier( + max_iter=max_iter, + n_iter_no_change=1000, + average=True, + # random_state=42, + class_weight= "balanced", + warm_start=True, + fit_intercept=True, + loss="hinge", + learning_rate='optimal' + ) case "logistic_regression": model = LogisticRegression( penalty="l2", #max_iter=1, # local epoch ==>> it doesn't work max_iter=max_iter, # local epoch warm_start=True, # prevent refreshing weights when fitting - random_state=42, + # random_state=42, class_weight= "balanced" #For unbalanced ) case "elastic_net": @@ -38,7 +48,7 @@ def get_model(model_name, local=False): #max_iter=1, # local epoch ==>> it doesn't work max_iter=max_iter, # local epoch warm_start=True, # prevent refreshing weights when fitting - random_state=42, + # random_state=42, class_weight= "balanced" #For unbalanced ) @@ -73,7 +83,7 @@ def set_model_params( return model -def set_initial_params(model: LinearClassifier,n_features): +def set_initial_params(model: LinearClassifier, data, n_features): """Sets initial parameters as zeros Required since model params are uninitialized until model.fit is called. But server asks for initial parameters from clients at launch. 
Refer @@ -82,16 +92,18 @@ def set_initial_params(model: LinearClassifier,n_features): """ n_classes = 2 # MNIST has 10 classes #n_features = 9 # Number of features in dataset + + model.fit(data[0], data[1]) model.classes_ = np.array([i for i in range(n_classes)]) - if(isinstance(model,SGDClassifier)==True): - model.coef_ = np.zeros((1, n_features)) - if model.fit_intercept: - model.intercept_ = 0 - else: - model.coef_ = np.zeros((n_classes, n_features)) - if model.fit_intercept: - model.intercept_ = np.zeros((n_classes,)) + # if(isinstance(model,SGDClassifier)==True): + # model.coef_ = np.zeros((1, n_features)) + # if model.fit_intercept: + # model.intercept_ = 0 + # else: + # model.coef_ = np.zeros((n_classes, n_features)) + # if model.fit_intercept: + # model.intercept_ = np.zeros((n_classes,)) #Evaluate in the aggregations evaluation with diff --git a/flcore/models/random_forest/FedCustomAggregator.py b/flcore/models/random_forest/FedCustomAggregator.py index 0da2e6b..adb8842 100644 --- a/flcore/models/random_forest/FedCustomAggregator.py +++ b/flcore/models/random_forest/FedCustomAggregator.py @@ -153,14 +153,6 @@ def aggregate_fit( self.time_server_round = time.time() print(f"Elapsed time: {elapsed_time} for round {server_round}") metrics_aggregated['training_time [s]'] = self.accum_time - - filename = 'server_results.txt' - with open( - filename, - "a", - ) as f: - f.write(f"Accumulated Time: {self.accum_time} for round {server_round}\n") - return parameters_aggregated, metrics_aggregated @@ -194,15 +186,6 @@ def aggregate_evaluate( elif server_round == 1: # Only log this warning once log(WARNING, "No evaluate_metrics_aggregation_fn provided") - # filename = 'server_results.txt' - # with open( - # filename, - # "a", - # ) as f: - # f.write(f"Accuracy: {metrics_aggregated['accuracy']} \n") - # f.write(f"Sensitivity: {metrics_aggregated['sensitivity']} \n") - # f.write(f"Specificity: {metrics_aggregated['specificity']} \n") - return loss_aggregated, 
metrics_aggregated diff --git a/flcore/models/random_forest/aggregatorRF.py b/flcore/models/random_forest/aggregatorRF.py index a55b8b8..b059309 100644 --- a/flcore/models/random_forest/aggregatorRF.py +++ b/flcore/models/random_forest/aggregatorRF.py @@ -117,8 +117,8 @@ def aggregateRF_withprevious(rfs,previous_estimators,bal_RF): #weigth, we transform into probability /sum(weights) #and random choice select according to probability distribution def aggregateRFwithSizeCenterProbs(rfs,bal_RF,smoothing_method,smoothing_strenght): - rfa= get_model(bal_RF) numberTreesperclient = int(len(rfs[0][0][0])) + rfa= get_model(bal_RF, numberTreesperclient) number_Clients = len(rfs) random_select =int(numberTreesperclient/number_Clients) list_classifiers = [] diff --git a/flcore/models/random_forest/client.py b/flcore/models/random_forest/client.py index 52e07cb..e53984b 100644 --- a/flcore/models/random_forest/client.py +++ b/flcore/models/random_forest/client.py @@ -7,7 +7,7 @@ from flcore.serialization_funs import serialize_RF, deserialize_RF import flcore.models.random_forest.utils as utils from flcore.performance import measurements_metrics -from flcore.metrics import calculate_metrics +from flcore.metrics import calculate_metrics, find_best_threshold from flwr.common import ( Code, EvaluateIns, @@ -26,12 +26,14 @@ class MnistClient(fl.client.Client): def __init__(self, data,client_id,config): self.client_id = client_id n_folds_out= config['num_rounds'] - seed=42 # Load data (self.X_train, self.y_train), (self.X_test, self.y_test) = data - self.splits_nested = datasets.split_partitions(n_folds_out,0.2, seed, self.X_train, self.y_train) - self.bal_RF = config['random_forest']['balanced_rf'] - self.model = utils.get_model(self.bal_RF) + self.splits_nested = datasets.split_partitions(n_folds_out,0.2, config['seed'], self.X_train, self.y_train) + self.bal_RF = True if config['model'] == 'balanced_random_forest' else False + self.model = utils.get_model(self.bal_RF, 
config['random_forest']['tree_num']) + self.round_time = 0 + self.tree_num = config['random_forest']['tree_num'] + self.first_round = True # Setting initial parameters, akin to model.compile for keras models utils.set_initial_params_client(self.model,self.X_train, self.y_train) def get_parameters(self, ins: GetParametersIns): # , config type: ignore @@ -57,32 +59,37 @@ def fit(self, ins: FitIns): # , parameters, config type: ignore with warnings.catch_warnings(): warnings.simplefilter("ignore") train_idx, val_idx = next(self.splits_nested) - X_train_2 = self.X_train.iloc[train_idx, :] - X_val = self.X_train.iloc[val_idx,:] - y_train_2 = self.y_train.iloc[train_idx] - y_val = self.y_train.iloc[val_idx] + self.X_train_2 = self.X_train.iloc[train_idx, :] + self.X_val = self.X_train.iloc[val_idx,:] + self.y_train_2 = self.y_train.iloc[train_idx] + self.y_val = self.y_train.iloc[val_idx] #To implement the center dropout, we need the execution time start_time = time.time() - self.model.fit(X_train_2, y_train_2) - #accuracy = model.score( X_test, y_test ) - # accuracy,specificity,sensitivity,balanced_accuracy, precision, F1_score = \ - # measurements_metrics(self.model,X_val, y_val) - y_pred = self.model.predict(X_val) - metrics = calculate_metrics(y_val, y_pred) - # print(f"Accuracy client in fit: {accuracy}") - # print(f"Sensitivity client in fit: {sensitivity}") - # print(f"Specificity client in fit: {specificity}") - # print(f"Balanced_accuracy in fit: {balanced_accuracy}") - # print(f"precision in fit: {precision}") - # print(f"F1_score in fit: {F1_score}") - + self.model.fit(self.X_train_2, self.y_train_2) elapsed_time = (time.time() - start_time) + y_pred_proba = self.model.predict_proba(self.X_val) + metrics = calculate_metrics(self.y_val, y_pred_proba) + metrics["running_time"] = elapsed_time + self.round_time = elapsed_time - print(f"num_client {self.client_id} has an elapsed time {elapsed_time}") - print(f"Training finished for round 
{ins.config['server_round']}") + if self.first_round: + local_model = utils.get_model(self.bal_RF, self.tree_num) + # utils.set_initial_params(local_model,self.n_features) + local_model.fit(self.X_train_2, self.y_train_2) + + y_pred_proba = self.model.predict_proba(self.X_val) + best_threshold = find_best_threshold(self.y_val, y_pred_proba, metric="balanced_accuracy") + + y_pred_proba = local_model.predict_proba(self.X_test) + local_metrics = calculate_metrics(self.y_test, y_pred_proba, threshold=best_threshold) + #Add 'local' to the metrics to identify them + local_metrics = {f"local {key}": local_metrics[key] for key in local_metrics} + metrics.update(local_metrics) + self.first_round = False + # Serialize to send it to the server params = utils.get_model_parameters(self.model) parameters_updated = serialize_RF(params) @@ -102,12 +109,22 @@ def evaluate(self, ins: EvaluateIns): # , parameters, config type: ignore #Deserialize to get the real parameters parameters = deserialize_RF(parameters) utils.set_model_params(self.model, parameters) + # Get threshold based on validation set + y_pred_proba = self.model.predict_proba(self.X_val) + best_threshold = find_best_threshold(self.y_val, y_pred_proba, metric="balanced_accuracy") + # Get validation metrics + val_metrics = calculate_metrics(self.y_val, y_pred_proba, threshold=best_threshold) + val_metrics = {f"val {key}": val_metrics[key] for key in val_metrics} + y_pred_prob = self.model.predict_proba(self.X_test) loss = log_loss(self.y_test, y_pred_prob) # accuracy,specificity,sensitivity,balanced_accuracy, precision, F1_score = \ # measurements_metrics(self.model,self.X_test, self.y_test) - y_pred = self.model.predict(self.X_test) - metrics = calculate_metrics(self.y_test, y_pred) + # y_pred = self.model.predict(self.X_test) + metrics = calculate_metrics(self.y_test, y_pred_prob, threshold=best_threshold) + metrics.update(val_metrics) + metrics["round_time [s]"] = self.round_time + metrics["client_id"] = 
self.client_id # print(f"Accuracy client in evaluate: {accuracy}") # print(f"Sensitivity client in evaluate: {sensitivity}") # print(f"Specificity client in evaluate: {specificity}") diff --git a/flcore/models/random_forest/server.py b/flcore/models/random_forest/server.py index acbfd1b..97a1373 100644 --- a/flcore/models/random_forest/server.py +++ b/flcore/models/random_forest/server.py @@ -33,8 +33,8 @@ def fit_round( server_round: int ) -> Dict: def get_server_and_strategy(config): - bal_RF = config['random_forest']['balanced_rf'] - model = get_model(bal_RF) + bal_RF = True if config['model'] == 'balanced_random_forest' else False + model = get_model(bal_RF, config['random_forest']['tree_num']) utils.set_initial_params_server( model) # Pass parameters to the Strategy for server-side parameter initialization @@ -46,6 +46,7 @@ def get_server_and_strategy(config): min_evaluate_clients = config['num_clients'], #enable evaluate_fn if we have data to evaluate in the server #evaluate_fn = utils_RF.get_evaluate_fn( model ), #no data in server + fit_metrics_aggregation_fn=metrics_aggregation_fn, evaluate_metrics_aggregation_fn = metrics_aggregation_fn, on_fit_config_fn = fit_round ) diff --git a/flcore/models/random_forest/utils.py b/flcore/models/random_forest/utils.py index 026c294..1170122 100644 --- a/flcore/models/random_forest/utils.py +++ b/flcore/models/random_forest/utils.py @@ -21,11 +21,11 @@ from typing import cast -def get_model(bal_RF): +def get_model(bal_RF, tree_num) -> RandomForestClassifier: if(bal_RF == True): - model = BalancedRandomForestClassifier(n_estimators=100,random_state=42) + model = BalancedRandomForestClassifier(n_estimators=tree_num,max_depth=10) else: - model = RandomForestClassifier(n_estimators=100,class_weight= "balanced",max_depth=2,random_state=42) + model = RandomForestClassifier(n_estimators=tree_num,max_depth=10,class_weight= "balanced_subsample") return model diff --git a/flcore/models/weighted_random_forest/client.py 
b/flcore/models/weighted_random_forest/client.py index 74fa60e..bd7b801 100644 --- a/flcore/models/weighted_random_forest/client.py +++ b/flcore/models/weighted_random_forest/client.py @@ -94,7 +94,7 @@ def __init__(self, data,client_id,config): # Load data (self.X_train, self.y_train), (self.X_test, self.y_test) = data self.splits_nested = datasets.split_partitions(n_folds_out,0.2, seed, self.X_train, self.y_train) - self.bal_RF = config['weighted_random_forest']['balanced_rf'] + self.bal_RF = True if config['model'] == 'balanced_random_forest' else False self.model = utils.get_model(self.bal_RF) # Setting initial parameters, akin to model.compile for keras models utils.set_initial_params_client(self.model,self.X_train, self.y_train) diff --git a/flcore/models/weighted_random_forest/server.py b/flcore/models/weighted_random_forest/server.py index 877b871..20539c2 100644 --- a/flcore/models/weighted_random_forest/server.py +++ b/flcore/models/weighted_random_forest/server.py @@ -32,7 +32,7 @@ def fit_round( server_round: int ) -> Dict: def get_server_and_strategy(config): - bal_RF = config['weighted_random_forest']['balanced_rf'] + bal_RF = True if config['model'] == 'balanced_random_forest' else False model = get_model(bal_RF) utils.set_initial_params_server( model) diff --git a/flcore/models/xgb/__init__.py b/flcore/models/xgb/__init__.py deleted file mode 100644 index 034de7d..0000000 --- a/flcore/models/xgb/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -import flcore.models.xgb.client -import flcore.models.xgb.server -import flcore.models.xgb.fed_custom_strategy -import flcore.models.xgb.utils diff --git a/flcore/models/xgblr/__init__.py b/flcore/models/xgblr/__init__.py new file mode 100644 index 0000000..478cd6d --- /dev/null +++ b/flcore/models/xgblr/__init__.py @@ -0,0 +1,4 @@ +import flcore.models.xgblr.client +import flcore.models.xgblr.server +import flcore.models.xgblr.fed_custom_strategy +import flcore.models.xgblr.utils diff --git 
a/flcore/models/xgb/client.py b/flcore/models/xgblr/client.py similarity index 68% rename from flcore/models/xgb/client.py rename to flcore/models/xgblr/client.py index 6bcbc1a..2a1d65a 100644 --- a/flcore/models/xgb/client.py +++ b/flcore/models/xgblr/client.py @@ -22,9 +22,10 @@ from flwr.common.typing import Parameters from torch.utils.data import DataLoader from xgboost import XGBClassifier, XGBRegressor +from sklearn.model_selection import KFold, StratifiedShuffleSplit, train_test_split -from flcore.models.xgb.cnn import CNN, test, train -from flcore.models.xgb.utils import ( +from flcore.models.xgblr.cnn import CNN, test, train +from flcore.models.xgblr.utils import ( NumpyEncoder, TreeDataset, construct_tree_from_loader, @@ -34,17 +35,19 @@ tree_encoding_loader, train_test ) +from flcore.metrics import calculate_metrics, find_best_threshold + class FL_Client(fl.client.Client): def __init__( self, task_type: str, - trainloader: DataLoader, - valloader: DataLoader, + data, client_tree_num: int, client_num: int, cid: str, + config, log_progress: bool = False, ): """ @@ -52,9 +55,6 @@ def __init__( """ self.task_type = task_type self.cid = cid - self.tree = construct_tree_from_loader(trainloader, client_tree_num, task_type) - self.trainloader_original = trainloader - self.valloader_original = valloader self.trainloader = None self.valloader = None self.client_tree_num = client_tree_num @@ -66,13 +66,25 @@ def __init__( "task_type": self.task_type, } self.tmp_dir = "" - # instantiate model self.net = CNN(client_num=client_num, client_tree_num=client_tree_num) - # determine device self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.round_time = -1 + self.first_round = True + batch_size = "whole" + + (self.X_train, self.y_train), (self.X_test, self.y_test) = data + + self.X_train, self.X_val, self.y_train, self.y_val = train_test_split(self.X_train, self.y_train, test_size=0.2, random_state=config['seed'], stratify=self.y_train) + + 
trainset = TreeDataset(np.array(self.X_train, copy=True), np.array(self.y_train, copy=True)) + valset = TreeDataset(np.array(self.X_val, copy=True), np.array(self.y_val, copy=True)) + testset = TreeDataset(np.array(self.X_test, copy=True), np.array(self.y_test, copy=True)) + self.trainloader_original = get_dataloader(trainset, "train", batch_size) + self.valloader_original = get_dataloader(valset, "test", batch_size) + self.testloader_original = get_dataloader(testset, "test", batch_size) + self.tree = construct_tree_from_loader(self.trainloader_original, client_tree_num, task_type) def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: return GetPropertiesRes(properties=self.properties) @@ -125,6 +137,10 @@ def fit(self, fit_params: FitIns) -> FitRes: print("Client " + self.cid + ": recieved", len(aggregated_trees), "trees") else: print("Client " + self.cid + ": only had its own tree") + + # Don't prepare dataloaders if their number of clients didn't change + # if type(aggregated_trees) is list and len(aggregated_trees) != self.client_num or self.trainloader is None: + self.trainloader = tree_encoding_loader( self.trainloader_original, batch_size, @@ -139,6 +155,15 @@ def fit(self, fit_params: FitIns) -> FitRes: self.client_tree_num, self.client_num, ) + self.testloader = tree_encoding_loader( + self.testloader_original, + batch_size, + aggregated_trees, + self.client_tree_num, + self.client_num, + ) + # else: + # print("Client " + self.cid + ": reusing existing dataloaders") # num_iterations = None special behaviour: train(...) 
runs for a single epoch, however many updates it may be num_iterations = num_iterations or len(self.trainloader) @@ -160,6 +185,22 @@ def fit(self, fit_params: FitIns) -> FitRes: ) self.round_time = (time.time() - start_time) + metrics = {} + + if self.first_round: + #Get best threshold based on validation set + y_pred_proba_val = self.tree.predict_proba(self.X_val) + best_threshold = find_best_threshold(self.y_val, y_pred_proba_val, metric="balanced_accuracy") + y_pred_proba = self.tree.predict_proba(self.X_test) + local_metrics = calculate_metrics(self.y_test, y_pred_proba, threshold=best_threshold) + #Add 'local' to the metrics to identify them + local_metrics = {f"local {key}": local_metrics[key] for key in local_metrics} + metrics.update(local_metrics) + self.first_round = False + + metrics.update({ + "running_time": self.round_time, + "train_loss": train_loss}) # Return training information: model, number of examples processed and metrics if self.task_type == "BINARY": @@ -168,7 +209,7 @@ def fit(self, fit_params: FitIns) -> FitRes: # parameters=self.get_parameters(fit_params.config), parameters=self.get_parameters(fit_params.config).parameters, num_examples=num_examples, - metrics={"loss": train_loss, "accuracy": train_result, "running_time":self.round_time}, + metrics=metrics, ) elif self.task_type == "REG": return FitRes( @@ -194,8 +235,9 @@ def evaluate(self, eval_params: EvaluateIns) -> EvaluateRes: loss, result, num_examples = test( self.task_type, self.net, - self.valloader, + self.testloader, device=self.device, + valloader=self.valloader, log_progress=self.log_progress, ) @@ -230,38 +272,45 @@ def evaluate(self, eval_params: EvaluateIns) -> EvaluateRes: def get_client(config, data, client_id) -> fl.client.Client: (X_train, y_train), (X_test, y_test) = data - task_type = config["xgb"]["task_type"] + task_type = config["xgblr"]["task_type"] client_num = config["num_clients"] - client_tree_num = config["xgb"]["tree_num"] // client_num + client_tree_num 
= config["xgblr"]["tree_num"] // client_num batch_size = "whole" cid = str(client_id) + #measure time for client data loading + time_start = time.time() trainset = TreeDataset(np.array(X_train, copy=True), np.array(y_train, copy=True)) valset = TreeDataset(np.array(X_test, copy=True), np.array(y_test, copy=True)) + time_end = time.time() + print(f"Client {cid}: Data loading time: {time_end - time_start} seconds") + time_start = time.time() trainloader = get_dataloader(trainset, "train", batch_size) valloader = get_dataloader(valset, "test", batch_size) + time_end = time.time() + print(f"Client {cid}: Dataloader creation time: {time_end - time_start} seconds") + + # metrics = train_test(data, client_tree_num) + # from flcore import datasets + # if client_id == 1: + # cross_id = 2 + # else: + # cross_id = 1 + # _, (X_test, y_test) = datasets.load_dataset(config, cross_id) - metrics = train_test(data, client_tree_num) - from flcore import datasets - if client_id == 1: - cross_id = 2 - else: - cross_id = 1 - _, (X_test, y_test) = datasets.load_dataset(config, cross_id) - - data = (X_train, y_train), (X_test, y_test) - metrics_cross = train_test(data, client_tree_num) - print("Client " + cid + " non-federated training results:") - print(metrics) - print("Cross testing model on client " + str(cross_id) + ":") - print(metrics_cross) + # data = (X_train, y_train), (X_test, y_test) + # metrics_cross = train_test(data, client_tree_num) + # print("Client " + cid + " non-federated training results:") + # print(metrics) + # print("Cross testing model on client " + str(cross_id) + ":") + # print(metrics_cross) client = FL_Client( task_type, - trainloader, - valloader, + data, client_tree_num, client_num, cid, - log_progress=False, + config, + log_progress=False ) return client diff --git a/flcore/models/xgb/cnn.py b/flcore/models/xgblr/cnn.py similarity index 73% rename from flcore/models/xgb/cnn.py rename to flcore/models/xgblr/cnn.py index 849efc3..3a5331b 100644 --- 
a/flcore/models/xgb/cnn.py +++ b/flcore/models/xgblr/cnn.py @@ -13,7 +13,7 @@ from sklearn.metrics import accuracy_score, mean_squared_error from torch.utils.data import DataLoader from torchmetrics import Accuracy, MeanSquaredError -from flcore.metrics import get_metrics_collection +from flcore.metrics import calculate_metrics, find_best_threshold from tqdm import tqdm @@ -147,6 +147,7 @@ def test( net: CNN, testloader: DataLoader, device: torch.device, + valloader: DataLoader = None, log_progress: bool = True, ) -> Tuple[float, float, int]: """Evaluates the network on test data.""" @@ -157,39 +158,48 @@ def test( elif task_type == "REG": criterion = nn.MSELoss() - total_loss, total_result, n_samples = 0.0, 0.0, 0 - metrics = get_metrics_collection() net.eval() - with torch.no_grad(): - pbar = tqdm(testloader, desc="TEST") if log_progress else testloader - for data in pbar: - tree_outputs, labels = data[0].to(device), data[1].to(device) - outputs = net(tree_outputs) - - # Collected testing loss and accuracy statistics - total_loss += criterion(outputs, labels).item() - n_samples += labels.size(0) - num_classes = np.unique(labels.cpu().numpy()).size - - y_pred = outputs.cpu() - y_true = labels.cpu() - metrics.update(y_pred, y_true) - - # if task_type == "BINARY" or task_type == "MULTICLASS": - # if task_type == "MULTICLASS": - # raise NotImplementedError() - - # # acc = Accuracy(task=task_type.lower())( - # # outputs.cpu(), labels.type(torch.int).cpu()) - # # total_result += acc * labels.size(0) - # elif task_type == "REG": - # mse = MeanSquaredError()(outputs.cpu(), labels.type(torch.int).cpu()) - # total_result += mse * labels.size(0) - - metrics = metrics.compute() - metrics = {k: v.item() for k, v in metrics.items()} - - # total_result = total_result.item() + + # Collect predictions and true labels for the entire test set, to compute metrics at the end of the epoch + + def get_pred_proba(dataloader): + y_pred_list = [] + y_true_list = [] + total_loss, 
total_result, n_samples = 0.0, 0.0, 0 + with torch.no_grad(): + pbar = tqdm(dataloader, desc="TEST") if log_progress else dataloader + for data in pbar: + tree_outputs, labels = data[0].to(device), data[1].to(device) + outputs = net(tree_outputs) + # Collected testing loss and accuracy statistics + total_loss += criterion(outputs, labels).item() + n_samples += labels.size(0) + num_classes = np.unique(labels.cpu().numpy()).size + + y_pred = outputs.cpu() + y_true = labels.cpu() + y_pred_list.append(y_pred) + y_true_list.append(y_true) + + return y_true_list, y_pred_list, total_loss, n_samples + + metrics = {} + if valloader is not None: + y_true_val, y_pred_proba_val, val_loss, val_n_samples = get_pred_proba(valloader) + best_threshold = find_best_threshold(y_true_val, y_pred_proba_val, metric="balanced_accuracy") + metrics_val = calculate_metrics(y_true_val, y_pred_proba_val, task_type=task_type, threshold=best_threshold) + metrics_val = {f"val {key}": metrics_val[key] for key in metrics_val} + metrics.update(metrics_val) + else: + best_threshold = 0.5 + + # Add validation metrics to the evaluation metrics with a prefix + y_true, y_pred_proba, total_loss, n_samples = get_pred_proba(testloader) + metrics_test = calculate_metrics(y_true, y_pred_proba, task_type=task_type, threshold=best_threshold) + metrics_not_tuned = calculate_metrics(y_true, y_pred_proba, task_type=task_type, threshold=0.5) + metrics_not_tuned = {f"not tuned {key}": metrics_not_tuned[key] for key in metrics_not_tuned} + metrics.update(metrics_test) + metrics.update(metrics_not_tuned) if log_progress: print("\n") diff --git a/flcore/models/xgb/fed_custom_strategy.py b/flcore/models/xgblr/fed_custom_strategy.py similarity index 95% rename from flcore/models/xgb/fed_custom_strategy.py rename to flcore/models/xgblr/fed_custom_strategy.py index 20dbe55..9f74f4d 100644 --- a/flcore/models/xgb/fed_custom_strategy.py +++ b/flcore/models/xgblr/fed_custom_strategy.py @@ -143,4 +143,10 @@ def aggregate_fit( 
elif server_round == 1: # Only log this warning once log(WARNING, "No fit_metrics_aggregation_fn provided") + elapsed_time = (time.time() - self.time_server_round) + self.accum_time = self.accum_time+ elapsed_time + self.time_server_round = time.time() + print(f"Elapsed time: {elapsed_time} for round {server_round}") + metrics_aggregated['training_time [s]'] = self.accum_time + return [parameters_aggregated, trees_aggregated], metrics_aggregated \ No newline at end of file diff --git a/flcore/models/xgb/server.py b/flcore/models/xgblr/server.py similarity index 97% rename from flcore/models/xgb/server.py rename to flcore/models/xgblr/server.py index 046fc2d..4312d5d 100644 --- a/flcore/models/xgb/server.py +++ b/flcore/models/xgblr/server.py @@ -30,10 +30,10 @@ from xgboost import XGBClassifier, XGBRegressor from flcore.metrics import metrics_aggregation_fn -from flcore.models.xgb.client import FL_Client -from flcore.models.xgb.fed_custom_strategy import FedCustomStrategy -from flcore.models.xgb.cnn import CNN, test -from flcore.models.xgb.utils import ( +from flcore.models.xgblr.client import FL_Client +from flcore.models.xgblr.fed_custom_strategy import FedCustomStrategy +from flcore.models.xgblr.cnn import CNN, test +from flcore.models.xgblr.utils import ( TreeDataset, construct_tree, do_fl_partitioning, @@ -98,10 +98,13 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> History: for current_round in range(1, num_rounds + 1): # Train model and replace previous global model res_fit = self.fit_round(server_round=current_round, timeout=timeout) - if res_fit: - parameters_prime, _, _ = res_fit # fit_metrics_aggregated + if res_fit is not None: + parameters_prime, fit_metrics, _ = res_fit # fit_metrics_aggregated if parameters_prime: self.parameters = parameters_prime + history.add_metrics_distributed_fit( + server_round=current_round, metrics=fit_metrics + ) # Evaluate model using strategy implementation res_cen = self.strategy.evaluate(current_round, 
parameters=self.parameters) @@ -407,15 +410,15 @@ def get_server_and_strategy( # The number of clients participated in the federated learning client_num = config["num_clients"] # The number of XGBoost trees in the tree ensemble that will be built for each client - client_tree_num = config["xgb"]["tree_num"] // client_num + client_tree_num = config["xgblr"]["tree_num"] // client_num num_rounds = config["num_rounds"] client_pool_size = client_num - num_iterations = config["xgb"]["num_iterations"] + num_iterations = config["xgblr"]["num_iterations"] fraction_fit = 1.0 min_fit_clients = client_num - batch_size = config["xgb"]["batch_size"] + batch_size = config["xgblr"]["batch_size"] val_ratio = 0.1 # DATASET = "CVD" diff --git a/flcore/models/xgb/utils.py b/flcore/models/xgblr/utils.py similarity index 100% rename from flcore/models/xgb/utils.py rename to flcore/models/xgblr/utils.py diff --git a/flcore/report/generate_report.py b/flcore/report/generate_report.py index 1c92777..45e88f9 100644 --- a/flcore/report/generate_report.py +++ b/flcore/report/generate_report.py @@ -27,7 +27,8 @@ def generate_report(experiment_path: str): df = df.rename(columns={"Unnamed: 0": "center"}) # Convert metrics columns to 2 decimal places df = df.round(2) - colors = ['#FF6666', '#FF9999', '#FF3333', '#CC0000', '#990000', '#B22222', '#FF0044', '#960018'] + colors = ['#FF6666', '#FF9999', '#FF3333', '#CC0000', '#990000', '#B22222', '#FF0044', '#960018', '#FF0000', + '#B22222'] # print(df.head()) diff --git a/flcore/server_selector.py b/flcore/server_selector.py index 3ba5a06..dbcc26e 100644 --- a/flcore/server_selector.py +++ b/flcore/server_selector.py @@ -1,6 +1,6 @@ #import flcore.models.logistic_regression.server as logistic_regression_server #import flcore.models.logistic_regression.server as logistic_regression_server -import flcore.models.xgb.server as xgb_server +import flcore.models.xgblr.server as xgblr_server import flcore.models.random_forest.server as random_forest_server 
import flcore.models.linear_models.server as linear_models_server import flcore.models.weighted_random_forest.server as weighted_random_forest_server @@ -13,7 +13,7 @@ def get_model_server_and_strategy(config, data=None): server, strategy = linear_models_server.get_server_and_strategy( config ) - elif model == "random_forest": + elif model in ("random_forest", "balanced_random_forest"): server, strategy = random_forest_server.get_server_and_strategy( config ) @@ -22,8 +22,8 @@ def get_model_server_and_strategy(config, data=None): config ) - elif model == "xgb": - server, strategy = xgb_server.get_server_and_strategy(config, data) + elif model == "xgblr": + server, strategy = xgblr_server.get_server_and_strategy(config, data) else: raise ValueError(f"Unknown model: {model}") diff --git a/plots.ipynb b/plots.ipynb new file mode 100644 index 0000000..a5c008d --- /dev/null +++ b/plots.ipynb @@ -0,0 +1,1236 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4c815c0e", + "metadata": {}, + "source": [ + "## Select data to load based on keywords in experiment name" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9f05d536", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import os\n", + "\n", + "logs_dir = \"logs\"\n", + "logs_dir = \"benchmark_results\"\n", + "experiment_name = \"experiment_1percent\"\n", + "# experiment_name = \"experiment_good\"\n", + "# experiment_name = \"experiment_small\"\n", + "dataset_name = \"diabetes\"\n", + "# dataset_name = \"kaggle_hf\"\n", + "results_file = \"per_center_results.csv\"\n", + "keywords = [\n", + " experiment_name,\n", + " dataset_name,\n", + " # \"logistic_regression\",\n", + " # \"forest\",\n", + " # \"c10\"\n", + " # \"a0.7\"\n", + " # \"a1.0\"\n", + " \"aNone\"\n", + " ]\n", + "\n", + "def load_data(logs_dir, experiment_name, keywords, results_file=\"per_center_results.csv\"):\n", + " data = {}\n", + "\n", + " # iterate over all directories in logs_dir with names 
containing all the keywords\n", + " dirs = [d for d in os.listdir(logs_dir) if all(keyword in d for keyword in keywords)]\n", + " for d in dirs: \n", + " model_name = d\n", + " # model_name = model_name.replace(experiment_name+\"_\", \"\")\n", + " model_name = model_name.replace(experiment_name+\"_\"+dataset_name+\"_\", \"\")\n", + " model_name = model_name.replace(\"_\", \" \")\n", + " model_name = model_name.title()\n", + " model_name = model_name.replace(\"none\", \"N\")\n", + " # Find position of _c keyword\n", + " pos = model_name.find(\" C\")\n", + " # if pos != -1:\n", + " # if not model_name[pos+3].isdigit():\n", + " # model_name = model_name[:pos+2] + \"0\" + model_name[pos+2:]\n", + " #remove non-capital letters\n", + " # model_name = ''.join(c for c in model_name if c.isupper() or c == ' ' or c.isdigit())\n", + "\n", + " full_path = os.path.join(logs_dir, d)\n", + " metrics_file = os.path.join(full_path, results_file)\n", + " if os.path.isfile(metrics_file):\n", + " df = pd.read_csv(metrics_file)\n", + " data[model_name] = df\n", + "\n", + " print(\"Found \", len(data), \" experiments\")\n", + "\n", + " # Sort data by model_name\n", + " data = dict(sorted(data.items()))\n", + "\n", + " # for model_name, df in data.items():\n", + " # print(model_name)\n", + " \n", + " return data" + ] + }, + { + "cell_type": "markdown", + "id": "0389d57d", + "metadata": {}, + "source": [ + "### Print metric values" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "29bb08b0", + "metadata": {}, + "outputs": [], + "source": [ + "# metric = \"balanced_accuracy\"\n", + "# # metric = \"accuracy\"\n", + "# results = []\n", + "# #print average metric across all centers for each model\n", + "# for model_name, df in data.items():\n", + "# #weighted average by number of samples in each center\n", + "# total_samples = df[\"n samples\"].sum()\n", + "# weighted_sum = (df[metric] * df[\"n samples\"]).sum()\n", + "# avg_metric = weighted_sum / total_samples\n", + "# 
results.append(f\"{model_name}: {avg_metric:.4f}\")\n", + "# # print(f\"{model_name}: {avg_metric:.4f}\")\n", + "\n", + "# # Sort results alphabetically by model name\n", + "# results.sort()\n", + "# for result in results:\n", + "# print(result)" + ] + }, + { + "cell_type": "markdown", + "id": "7893ad6c", + "metadata": {}, + "source": [ + "# Bar plot for all imported models" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "905d8cfa", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Make a bar plot comparing the models based on the average metric\n", + "import matplotlib.pyplot as plt\n", + "model_names = []\n", + "avg_metrics = []\n", + "for model_name, df in data.items():\n", + " total_samples = df[\"n samples\"].sum()\n", + " weighted_sum = (df[metric] * df[\"n samples\"]).sum()\n", + " avg_metric = weighted_sum / total_samples\n", + " \n", + " model_names.append(model_name)\n", + " avg_metrics.append(avg_metric)\n", + "\n", + "# Sort models by names\n", + "model_names, avg_metrics = zip(*sorted(zip(model_names, avg_metrics)))\n", + "\n", + "plt.figure(figsize=(10, 6))\n", + "plt.bar(model_names, avg_metrics)\n", + "plt.ylabel(f'Average {metric.replace(\"_\", \" \").title()}')\n", + "plt.title(f'Comparison of Models based on Average {metric.replace(\"_\", \" \").title()}')\n", + "#start y-axis from 0.5\n", + "plt.ylim(bottom=0.5)\n", + "plt.xticks(rotation=85)\n", + "plt.tight_layout()\n", + "plt.show() \n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "110609b3", + "metadata": {}, + "source": [ + "# Box Plots: Number of Clients \n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "07bd40b0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 20 experiments\n" + ] + } + ], + "source": [ + "# Feature selection experiment\n", + "# experiment_name = \"experiment_good\"\n", + "# experiment_name = \"experiment_all_10percent\"\n", + "experiment_name = \"num_clients_ablation\"\n", + "benchmark_dir = \"benchmark_results_num_clients_ablation\"\n", + "model_names = [\"balanced_random_forest\"]\n", + "datasets = [\"diabetes\"]\n", + "# num_clients = [5,10]\n", + "dirichlet_alpha = [\"0.7\"]\n", + "# dirichlet_alpha = [\"aNone\"]\n", + "keywords = [experiment_name] + datasets + dirichlet_alpha\n", + "\n", + "data = load_data(benchmark_dir, experiment_name, keywords)" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "66277bb4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Logistic Regression\n", + " model run n_clients alpha \n", + "76 Logistic Regression 0 10 0.7 Normglobal FeatN \\\n", + "77 Logistic Regression 1 10 0.7 Normglobal FeatN \n", + "78 Logistic Regression 2 10 0.7 Normglobal FeatN \n", + "79 Logistic Regression 3 10 0.7 Normglobal FeatN \n", + "80 Logistic Regression 4 10 0.7 Normglobal FeatN \n", + "81 Logistic Regression 5 10 0.7 Normglobal FeatN \n", + "82 Logistic Regression 6 10 0.7 Normglobal FeatN \n", + "83 Logistic Regression 7 10 0.7 Normglobal FeatN \n", + "84 Logistic Regression 8 10 0.7 Normglobal FeatN \n", + "85 Logistic Regression 9 10 0.7 Normglobal FeatN \n", + "86 Logistic Regression 0 20 0.7 Normglobal FeatN \n", + "87 Logistic Regression 1 20 0.7 Normglobal FeatN \n", + "88 Logistic Regression 2 20 0.7 Normglobal FeatN \n", + "89 Logistic Regression 3 20 0.7 Normglobal FeatN \n", + "90 Logistic Regression 4 20 0.7 Normglobal FeatN \n", + "91 Logistic Regression 5 20 0.7 Normglobal FeatN \n", + "92 Logistic Regression 6 20 0.7 Normglobal FeatN \n", + "93 Logistic Regression 7 20 0.7 Normglobal FeatN \n", + "94 Logistic Regression 8 20 0.7 Normglobal FeatN \n", + "95 Logistic Regression 9 20 0.7 Normglobal FeatN \n", + "96 Logistic Regression 10 20 0.7 Normglobal FeatN \n", + "97 Logistic Regression 11 20 0.7 Normglobal FeatN \n", + "98 Logistic Regression 12 20 0.7 Normglobal FeatN \n", + "99 Logistic Regression 13 20 0.7 Normglobal FeatN \n", + "100 Logistic Regression 14 20 0.7 Normglobal FeatN \n", + "101 Logistic Regression 15 20 0.7 Normglobal FeatN \n", + "102 Logistic Regression 16 20 0.7 Normglobal FeatN \n", + "103 Logistic Regression 17 20 0.7 Normglobal FeatN \n", + "104 Logistic Regression 18 20 0.7 Normglobal FeatN \n", + "105 Logistic Regression 19 20 0.7 Normglobal FeatN \n", + "106 Logistic Regression 0 3 0.7 Normglobal 
FeatN \n", + "107 Logistic Regression 1 3 0.7 Normglobal FeatN \n", + "108 Logistic Regression 2 3 0.7 Normglobal FeatN \n", + "109 Logistic Regression 0 5 0.7 Normglobal FeatN \n", + "110 Logistic Regression 1 5 0.7 Normglobal FeatN \n", + "111 Logistic Regression 2 5 0.7 Normglobal FeatN \n", + "112 Logistic Regression 3 5 0.7 Normglobal FeatN \n", + "113 Logistic Regression 4 5 0.7 Normglobal FeatN \n", + "\n", + " balanced_accuracy \n", + "76 0.759 \n", + "77 0.752 \n", + "78 0.783 \n", + "79 0.741 \n", + "80 0.724 \n", + "81 0.748 \n", + "82 0.742 \n", + "83 0.753 \n", + "84 0.741 \n", + "85 0.741 \n", + "86 0.651 \n", + "87 0.669 \n", + "88 0.755 \n", + "89 0.796 \n", + "90 0.759 \n", + "91 0.749 \n", + "92 0.629 \n", + "93 0.756 \n", + "94 0.758 \n", + "95 0.726 \n", + "96 0.746 \n", + "97 0.741 \n", + "98 0.745 \n", + "99 0.717 \n", + "100 0.728 \n", + "101 0.735 \n", + "102 0.681 \n", + "103 0.742 \n", + "104 0.740 \n", + "105 0.746 \n", + "106 0.739 \n", + "107 0.745 \n", + "108 0.748 \n", + "109 0.738 \n", + "110 0.756 \n", + "111 0.730 \n", + "112 0.741 \n", + "113 0.746 \n", + "Logistic Regression [106 0.739\n", + "107 0.745\n", + "108 0.748\n", + "Name: balanced_accuracy, dtype: float64, 109 0.738\n", + "110 0.756\n", + "111 0.730\n", + "112 0.741\n", + "113 0.746\n", + "Name: balanced_accuracy, dtype: float64, 76 0.759\n", + "77 0.752\n", + "78 0.783\n", + "79 0.741\n", + "80 0.724\n", + "81 0.748\n", + "82 0.742\n", + "83 0.753\n", + "84 0.741\n", + "85 0.741\n", + "Name: balanced_accuracy, dtype: float64, 86 0.651\n", + "87 0.669\n", + "88 0.755\n", + "89 0.796\n", + "90 0.759\n", + "91 0.749\n", + "92 0.629\n", + "93 0.756\n", + "94 0.758\n", + "95 0.726\n", + "96 0.746\n", + "97 0.741\n", + "98 0.745\n", + "99 0.717\n", + "100 0.728\n", + "101 0.735\n", + "102 0.681\n", + "103 0.742\n", + "104 0.740\n", + "105 0.746\n", + "Name: balanced_accuracy, dtype: float64]\n", + "ElasticNet\n", + " model run n_clients alpha balanced_accuracy\n", + "38 
ElasticNet 0 10 0.7 Normglobal FeatN 0.765\n", + "39 ElasticNet 1 10 0.7 Normglobal FeatN 0.700\n", + "40 ElasticNet 2 10 0.7 Normglobal FeatN 0.767\n", + "41 ElasticNet 3 10 0.7 Normglobal FeatN 0.745\n", + "42 ElasticNet 4 10 0.7 Normglobal FeatN 0.734\n", + "43 ElasticNet 5 10 0.7 Normglobal FeatN 0.741\n", + "44 ElasticNet 6 10 0.7 Normglobal FeatN 0.747\n", + "45 ElasticNet 7 10 0.7 Normglobal FeatN 0.758\n", + "46 ElasticNet 8 10 0.7 Normglobal FeatN 0.743\n", + "47 ElasticNet 9 10 0.7 Normglobal FeatN 0.743\n", + "48 ElasticNet 0 20 0.7 Normglobal FeatN 0.680\n", + "49 ElasticNet 1 20 0.7 Normglobal FeatN 0.651\n", + "50 ElasticNet 2 20 0.7 Normglobal FeatN 0.755\n", + "51 ElasticNet 3 20 0.7 Normglobal FeatN 0.727\n", + "52 ElasticNet 4 20 0.7 Normglobal FeatN 0.749\n", + "53 ElasticNet 5 20 0.7 Normglobal FeatN 0.742\n", + "54 ElasticNet 6 20 0.7 Normglobal FeatN 0.661\n", + "55 ElasticNet 7 20 0.7 Normglobal FeatN 0.738\n", + "56 ElasticNet 8 20 0.7 Normglobal FeatN 0.749\n", + "57 ElasticNet 9 20 0.7 Normglobal FeatN 0.720\n", + "58 ElasticNet 10 20 0.7 Normglobal FeatN 0.743\n", + "59 ElasticNet 11 20 0.7 Normglobal FeatN 0.736\n", + "60 ElasticNet 12 20 0.7 Normglobal FeatN 0.730\n", + "61 ElasticNet 13 20 0.7 Normglobal FeatN 0.695\n", + "62 ElasticNet 14 20 0.7 Normglobal FeatN 0.718\n", + "63 ElasticNet 15 20 0.7 Normglobal FeatN 0.735\n", + "64 ElasticNet 16 20 0.7 Normglobal FeatN 0.698\n", + "65 ElasticNet 17 20 0.7 Normglobal FeatN 0.732\n", + "66 ElasticNet 18 20 0.7 Normglobal FeatN 0.733\n", + "67 ElasticNet 19 20 0.7 Normglobal FeatN 0.747\n", + "68 ElasticNet 0 3 0.7 Normglobal FeatN 0.743\n", + "69 ElasticNet 1 3 0.7 Normglobal FeatN 0.744\n", + "70 ElasticNet 2 3 0.7 Normglobal FeatN 0.748\n", + "71 ElasticNet 0 5 0.7 Normglobal FeatN 0.741\n", + "72 ElasticNet 1 5 0.7 Normglobal FeatN 0.770\n", + "73 ElasticNet 2 5 0.7 Normglobal FeatN 0.727\n", + "74 ElasticNet 3 5 0.7 Normglobal FeatN 0.742\n", + "75 ElasticNet 4 5 0.7 Normglobal FeatN 
0.746\n", + "ElasticNet [68 0.743\n", + "69 0.744\n", + "70 0.748\n", + "Name: balanced_accuracy, dtype: float64, 71 0.741\n", + "72 0.770\n", + "73 0.727\n", + "74 0.742\n", + "75 0.746\n", + "Name: balanced_accuracy, dtype: float64, 38 0.765\n", + "39 0.700\n", + "40 0.767\n", + "41 0.745\n", + "42 0.734\n", + "43 0.741\n", + "44 0.747\n", + "45 0.758\n", + "46 0.743\n", + "47 0.743\n", + "Name: balanced_accuracy, dtype: float64, 48 0.680\n", + "49 0.651\n", + "50 0.755\n", + "51 0.727\n", + "52 0.749\n", + "53 0.742\n", + "54 0.661\n", + "55 0.738\n", + "56 0.749\n", + "57 0.720\n", + "58 0.743\n", + "59 0.736\n", + "60 0.730\n", + "61 0.695\n", + "62 0.718\n", + "63 0.735\n", + "64 0.698\n", + "65 0.732\n", + "66 0.733\n", + "67 0.747\n", + "Name: balanced_accuracy, dtype: float64]\n", + "Linear SVC\n", + " model run n_clients alpha balanced_accuracy\n", + "114 Linear SVC 0 10 0.7 Normglobal FeatN 0.754\n", + "115 Linear SVC 1 10 0.7 Normglobal FeatN 0.738\n", + "116 Linear SVC 2 10 0.7 Normglobal FeatN 0.779\n", + "117 Linear SVC 3 10 0.7 Normglobal FeatN 0.747\n", + "118 Linear SVC 4 10 0.7 Normglobal FeatN 0.750\n", + "119 Linear SVC 5 10 0.7 Normglobal FeatN 0.746\n", + "120 Linear SVC 6 10 0.7 Normglobal FeatN 0.744\n", + "121 Linear SVC 7 10 0.7 Normglobal FeatN 0.757\n", + "122 Linear SVC 8 10 0.7 Normglobal FeatN 0.746\n", + "123 Linear SVC 9 10 0.7 Normglobal FeatN 0.746\n", + "124 Linear SVC 0 20 0.7 Normglobal FeatN 0.641\n", + "125 Linear SVC 1 20 0.7 Normglobal FeatN 0.692\n", + "126 Linear SVC 2 20 0.7 Normglobal FeatN 0.742\n", + "127 Linear SVC 3 20 0.7 Normglobal FeatN 0.779\n", + "128 Linear SVC 4 20 0.7 Normglobal FeatN 0.750\n", + "129 Linear SVC 5 20 0.7 Normglobal FeatN 0.728\n", + "130 Linear SVC 6 20 0.7 Normglobal FeatN 0.592\n", + "131 Linear SVC 7 20 0.7 Normglobal FeatN 0.747\n", + "132 Linear SVC 8 20 0.7 Normglobal FeatN 0.755\n", + "133 Linear SVC 9 20 0.7 Normglobal FeatN 0.725\n", + "134 Linear SVC 10 20 0.7 Normglobal FeatN 
0.742\n", + "135 Linear SVC 11 20 0.7 Normglobal FeatN 0.742\n", + "136 Linear SVC 12 20 0.7 Normglobal FeatN 0.764\n", + "137 Linear SVC 13 20 0.7 Normglobal FeatN 0.744\n", + "138 Linear SVC 14 20 0.7 Normglobal FeatN 0.727\n", + "139 Linear SVC 15 20 0.7 Normglobal FeatN 0.732\n", + "140 Linear SVC 16 20 0.7 Normglobal FeatN 0.708\n", + "141 Linear SVC 17 20 0.7 Normglobal FeatN 0.745\n", + "142 Linear SVC 18 20 0.7 Normglobal FeatN 0.741\n", + "143 Linear SVC 19 20 0.7 Normglobal FeatN 0.742\n", + "144 Linear SVC 0 3 0.7 Normglobal FeatN 0.729\n", + "145 Linear SVC 1 3 0.7 Normglobal FeatN 0.750\n", + "146 Linear SVC 2 3 0.7 Normglobal FeatN 0.752\n", + "147 Linear SVC 0 5 0.7 Normglobal FeatN 0.731\n", + "148 Linear SVC 1 5 0.7 Normglobal FeatN 0.768\n", + "149 Linear SVC 2 5 0.7 Normglobal FeatN 0.732\n", + "150 Linear SVC 3 5 0.7 Normglobal FeatN 0.749\n", + "151 Linear SVC 4 5 0.7 Normglobal FeatN 0.750\n", + "Linear SVC [144 0.729\n", + "145 0.750\n", + "146 0.752\n", + "Name: balanced_accuracy, dtype: float64, 147 0.731\n", + "148 0.768\n", + "149 0.732\n", + "150 0.749\n", + "151 0.750\n", + "Name: balanced_accuracy, dtype: float64, 114 0.754\n", + "115 0.738\n", + "116 0.779\n", + "117 0.747\n", + "118 0.750\n", + "119 0.746\n", + "120 0.744\n", + "121 0.757\n", + "122 0.746\n", + "123 0.746\n", + "Name: balanced_accuracy, dtype: float64, 124 0.641\n", + "125 0.692\n", + "126 0.742\n", + "127 0.779\n", + "128 0.750\n", + "129 0.728\n", + "130 0.592\n", + "131 0.747\n", + "132 0.755\n", + "133 0.725\n", + "134 0.742\n", + "135 0.742\n", + "136 0.764\n", + "137 0.744\n", + "138 0.727\n", + "139 0.732\n", + "140 0.708\n", + "141 0.745\n", + "142 0.741\n", + "143 0.742\n", + "Name: balanced_accuracy, dtype: float64]\n", + "Random Forest\n", + " model run n_clients alpha balanced_accuracy\n", + "152 Random Forest 0 10 0.7 Normglobal FeatN 0.771\n", + "153 Random Forest 1 10 0.7 Normglobal FeatN 0.816\n", + "154 Random Forest 2 10 0.7 Normglobal FeatN 
0.749\n", + "155 Random Forest 3 10 0.7 Normglobal FeatN 0.743\n", + "156 Random Forest 4 10 0.7 Normglobal FeatN 0.761\n", + "157 Random Forest 5 10 0.7 Normglobal FeatN 0.732\n", + "158 Random Forest 6 10 0.7 Normglobal FeatN 0.750\n", + "159 Random Forest 7 10 0.7 Normglobal FeatN 0.747\n", + "160 Random Forest 8 10 0.7 Normglobal FeatN 0.744\n", + "161 Random Forest 9 10 0.7 Normglobal FeatN 0.747\n", + "162 Random Forest 0 20 0.7 Normglobal FeatN 0.737\n", + "163 Random Forest 1 20 0.7 Normglobal FeatN 0.700\n", + "164 Random Forest 2 20 0.7 Normglobal FeatN 0.742\n", + "165 Random Forest 3 20 0.7 Normglobal FeatN 0.800\n", + "166 Random Forest 4 20 0.7 Normglobal FeatN 0.761\n", + "167 Random Forest 5 20 0.7 Normglobal FeatN 0.762\n", + "168 Random Forest 6 20 0.7 Normglobal FeatN 0.610\n", + "169 Random Forest 7 20 0.7 Normglobal FeatN 0.734\n", + "170 Random Forest 8 20 0.7 Normglobal FeatN 0.756\n", + "171 Random Forest 9 20 0.7 Normglobal FeatN 0.753\n", + "172 Random Forest 10 20 0.7 Normglobal FeatN 0.756\n", + "173 Random Forest 11 20 0.7 Normglobal FeatN 0.737\n", + "174 Random Forest 12 20 0.7 Normglobal FeatN 0.755\n", + "175 Random Forest 13 20 0.7 Normglobal FeatN 0.713\n", + "176 Random Forest 14 20 0.7 Normglobal FeatN 0.717\n", + "177 Random Forest 15 20 0.7 Normglobal FeatN 0.739\n", + "178 Random Forest 16 20 0.7 Normglobal FeatN 0.714\n", + "179 Random Forest 17 20 0.7 Normglobal FeatN 0.746\n", + "180 Random Forest 18 20 0.7 Normglobal FeatN 0.743\n", + "181 Random Forest 19 20 0.7 Normglobal FeatN 0.751\n", + "182 Random Forest 0 3 0.7 Normglobal FeatN 0.737\n", + "183 Random Forest 1 3 0.7 Normglobal FeatN 0.750\n", + "184 Random Forest 2 3 0.7 Normglobal FeatN 0.753\n", + "185 Random Forest 0 5 0.7 Normglobal FeatN 0.744\n", + "186 Random Forest 1 5 0.7 Normglobal FeatN 0.771\n", + "187 Random Forest 2 5 0.7 Normglobal FeatN 0.739\n", + "188 Random Forest 3 5 0.7 Normglobal FeatN 0.747\n", + "189 Random Forest 4 5 0.7 Normglobal FeatN 
0.747\n", + "Random Forest [182 0.737\n", + "183 0.750\n", + "184 0.753\n", + "Name: balanced_accuracy, dtype: float64, 185 0.744\n", + "186 0.771\n", + "187 0.739\n", + "188 0.747\n", + "189 0.747\n", + "Name: balanced_accuracy, dtype: float64, 152 0.771\n", + "153 0.816\n", + "154 0.749\n", + "155 0.743\n", + "156 0.761\n", + "157 0.732\n", + "158 0.750\n", + "159 0.747\n", + "160 0.744\n", + "161 0.747\n", + "Name: balanced_accuracy, dtype: float64, 162 0.737\n", + "163 0.700\n", + "164 0.742\n", + "165 0.800\n", + "166 0.761\n", + "167 0.762\n", + "168 0.610\n", + "169 0.734\n", + "170 0.756\n", + "171 0.753\n", + "172 0.756\n", + "173 0.737\n", + "174 0.755\n", + "175 0.713\n", + "176 0.717\n", + "177 0.739\n", + "178 0.714\n", + "179 0.746\n", + "180 0.743\n", + "181 0.751\n", + "Name: balanced_accuracy, dtype: float64]\n", + "Balanced Random Forest\n", + " model run n_clients alpha \n", + "0 Balanced Random Forest 0 10 0.7 Normglobal FeatN \\\n", + "1 Balanced Random Forest 1 10 0.7 Normglobal FeatN \n", + "2 Balanced Random Forest 2 10 0.7 Normglobal FeatN \n", + "3 Balanced Random Forest 3 10 0.7 Normglobal FeatN \n", + "4 Balanced Random Forest 4 10 0.7 Normglobal FeatN \n", + "5 Balanced Random Forest 5 10 0.7 Normglobal FeatN \n", + "6 Balanced Random Forest 6 10 0.7 Normglobal FeatN \n", + "7 Balanced Random Forest 7 10 0.7 Normglobal FeatN \n", + "8 Balanced Random Forest 8 10 0.7 Normglobal FeatN \n", + "9 Balanced Random Forest 9 10 0.7 Normglobal FeatN \n", + "10 Balanced Random Forest 0 20 0.7 Normglobal FeatN \n", + "11 Balanced Random Forest 1 20 0.7 Normglobal FeatN \n", + "12 Balanced Random Forest 2 20 0.7 Normglobal FeatN \n", + "13 Balanced Random Forest 3 20 0.7 Normglobal FeatN \n", + "14 Balanced Random Forest 4 20 0.7 Normglobal FeatN \n", + "15 Balanced Random Forest 5 20 0.7 Normglobal FeatN \n", + "16 Balanced Random Forest 6 20 0.7 Normglobal FeatN \n", + "17 Balanced Random Forest 7 20 0.7 Normglobal FeatN \n", + "18 Balanced 
Random Forest 8 20 0.7 Normglobal FeatN \n", + "19 Balanced Random Forest 9 20 0.7 Normglobal FeatN \n", + "20 Balanced Random Forest 10 20 0.7 Normglobal FeatN \n", + "21 Balanced Random Forest 11 20 0.7 Normglobal FeatN \n", + "22 Balanced Random Forest 12 20 0.7 Normglobal FeatN \n", + "23 Balanced Random Forest 13 20 0.7 Normglobal FeatN \n", + "24 Balanced Random Forest 14 20 0.7 Normglobal FeatN \n", + "25 Balanced Random Forest 15 20 0.7 Normglobal FeatN \n", + "26 Balanced Random Forest 16 20 0.7 Normglobal FeatN \n", + "27 Balanced Random Forest 17 20 0.7 Normglobal FeatN \n", + "28 Balanced Random Forest 18 20 0.7 Normglobal FeatN \n", + "29 Balanced Random Forest 19 20 0.7 Normglobal FeatN \n", + "30 Balanced Random Forest 0 3 0.7 Normglobal FeatN \n", + "31 Balanced Random Forest 1 3 0.7 Normglobal FeatN \n", + "32 Balanced Random Forest 2 3 0.7 Normglobal FeatN \n", + "33 Balanced Random Forest 0 5 0.7 Normglobal FeatN \n", + "34 Balanced Random Forest 1 5 0.7 Normglobal FeatN \n", + "35 Balanced Random Forest 2 5 0.7 Normglobal FeatN \n", + "36 Balanced Random Forest 3 5 0.7 Normglobal FeatN \n", + "37 Balanced Random Forest 4 5 0.7 Normglobal FeatN \n", + "\n", + " balanced_accuracy \n", + "0 0.769 \n", + "1 0.740 \n", + "2 0.767 \n", + "3 0.747 \n", + "4 0.742 \n", + "5 0.739 \n", + "6 0.755 \n", + "7 0.745 \n", + "8 0.749 \n", + "9 0.748 \n", + "10 0.713 \n", + "11 0.691 \n", + "12 0.740 \n", + "13 0.768 \n", + "14 0.756 \n", + "15 0.759 \n", + "16 0.628 \n", + "17 0.756 \n", + "18 0.755 \n", + "19 0.763 \n", + "20 0.758 \n", + "21 0.741 \n", + "22 0.756 \n", + "23 0.716 \n", + "24 0.725 \n", + "25 0.737 \n", + "26 0.707 \n", + "27 0.750 \n", + "28 0.746 \n", + "29 0.751 \n", + "30 0.741 \n", + "31 0.752 \n", + "32 0.755 \n", + "33 0.744 \n", + "34 0.769 \n", + "35 0.739 \n", + "36 0.746 \n", + "37 0.747 \n", + "Balanced Random Forest [30 0.741\n", + "31 0.752\n", + "32 0.755\n", + "Name: balanced_accuracy, dtype: float64, 33 0.744\n", + "34 
0.769\n", + "35 0.739\n", + "36 0.746\n", + "37 0.747\n", + "Name: balanced_accuracy, dtype: float64, 0 0.769\n", + "1 0.740\n", + "2 0.767\n", + "3 0.747\n", + "4 0.742\n", + "5 0.739\n", + "6 0.755\n", + "7 0.745\n", + "8 0.749\n", + "9 0.748\n", + "Name: balanced_accuracy, dtype: float64, 10 0.713\n", + "11 0.691\n", + "12 0.740\n", + "13 0.768\n", + "14 0.756\n", + "15 0.759\n", + "16 0.628\n", + "17 0.756\n", + "18 0.755\n", + "19 0.763\n", + "20 0.758\n", + "21 0.741\n", + "22 0.756\n", + "23 0.716\n", + "24 0.725\n", + "25 0.737\n", + "26 0.707\n", + "27 0.750\n", + "28 0.746\n", + "29 0.751\n", + "Name: balanced_accuracy, dtype: float64]\n", + "XGBoost\n", + "Empty DataFrame\n", + "Columns: [model, run, n_clients, alpha, balanced_accuracy]\n", + "Index: []\n", + "XGBoost [Series([], Name: balanced_accuracy, dtype: float64), Series([], Name: balanced_accuracy, dtype: float64), Series([], Name: balanced_accuracy, dtype: float64), Series([], Name: balanced_accuracy, dtype: float64)]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Federated Learning Benchmark Summary - Box Plot Analysis\n", + "================================================================================\n", + "\n", + "Logistic Regression:\n", + "----------------------------------------\n", + " 3 clients: 0.7440 ± 0.0046 [0.7390, 0.7480]\n", + " 5 clients: 0.7422 ± 0.0097 [0.7300, 0.7560]\n", + " 10 clients: 0.7484 ± 0.0154 [0.7240, 0.7830]\n", + " 20 clients: 0.7285 ± 0.0407 [0.6290, 0.7960]\n", + " Performance degradation (3→20 clients): 2.09%\n", + "\n", + "ElasticNet:\n", + "----------------------------------------\n", + " 3 clients: 0.7450 ± 0.0026 [0.7430, 0.7480]\n", + " 5 clients: 0.7452 ± 0.0156 [0.7270, 0.7700]\n", + " 10 clients: 0.7443 ± 0.0189 [0.7000, 0.7670]\n", + " 20 clients: 0.7219 ± 0.0297 [0.6510, 0.7550]\n", + " Performance degradation (3→20 clients): 3.09%\n", + "\n", + "Linear SVC:\n", + "----------------------------------------\n", + " 3 clients: 0.7437 ± 0.0127 [0.7290, 0.7520]\n", + " 5 clients: 0.7460 ± 0.0152 [0.7310, 0.7680]\n", + " 10 clients: 0.7507 ± 0.0112 [0.7380, 0.7790]\n", + " 20 clients: 0.7269 ± 0.0428 [0.5920, 0.7790]\n", + " Performance degradation (3→20 clients): 2.25%\n", + "\n", + "Random Forest:\n", + "----------------------------------------\n", + " 3 clients: 0.7467 ± 0.0085 [0.7370, 0.7530]\n", + " 5 clients: 0.7496 ± 0.0124 [0.7390, 0.7710]\n", + " 10 clients: 0.7560 ± 0.0235 [0.7320, 0.8160]\n", + " 20 clients: 0.7363 ± 0.0369 [0.6100, 0.8000]\n", + " Performance degradation (3→20 clients): 1.39%\n", + "\n", + "Balanced Random Forest:\n", + "----------------------------------------\n", + " 3 clients: 0.7493 ± 0.0074 [0.7410, 0.7550]\n", + " 5 clients: 0.7490 ± 0.0116 [0.7390, 0.7690]\n", + " 10 clients: 0.7501 ± 0.0105 [0.7390, 0.7690]\n", + " 20 clients: 0.7358 ± 0.0328 [0.6280, 0.7680]\n", + " Performance degradation (3→20 clients): 1.81%\n", 
+ "\n", + "XGBoost:\n", + "----------------------------------------\n", + " 3 clients: nan ± nan [nan, nan]\n", + " 5 clients: nan ± nan [nan, nan]\n", + " 10 clients: nan ± nan [nan, nan]\n", + " 20 clients: nan ± nan [nan, nan]\n", + " Performance degradation (3→20 clients): nan%\n", + "\n", + "================================================================================\n", + "COMPARATIVE ANALYSIS:\n", + "================================================================================\n", + "Best at 3 clients: Balanced Random Forest (balanced_accuracy: 0.7493)\n", + "Best at 5 clients: Random Forest (balanced_accuracy: 0.7496)\n", + "Best at 10 clients: Random Forest (balanced_accuracy: 0.7560)\n", + "Best at 20 clients: Random Forest (balanced_accuracy: 0.7363)\n", + "\n", + "Overall best model: Random Forest (Avg balanced_accuracy: 0.7441)\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "from sklearn.metrics import roc_auc_score\n", + "\n", + "# Set style for academic paper\n", + "plt.style.use('seaborn-v0_8-whitegrid')\n", + "# plt.rcParams['font.family'] = 'serif'\n", + "plt.rcParams['font.size'] = 10\n", + "\n", + "# Define models and client configurations (performance decreases with more clients)\n", + "models = ['Logistic Regression', 'ElasticNet', 'Linear SVC', 'Random Forest', 'Balanced Random Forest', 'XGBoost']\n", + "clients = [3, 5, 10, 20] # Only these client numbers\n", + "# clients = [3, 5, 10] # Only these client numbers\n", + "\n", + "extracted_data = []\n", + "# metric = \"auroc\" \n", + "metric = \"local balanced_accuracy\" \n", + "# metric = \"balanced_accuracy\" \n", + "for model_name, df in data.items():\n", + " model = model_name.split(\" C\")[0]\n", + " if model == \"Elastic Net\":\n", + " model = \"ElasticNet\"\n", + " if model == \"Lsvc\":\n", + " model = \"Linear SVC\"\n", + " num_clients = int(model_name.split(\" C\")[-1][:2])\n", + " alpha = 
model_name.split(\" A\")[-1]\n", + " metric_scores = df[metric].values\n", + " for center, score in enumerate(metric_scores):\n", + " extracted_data.append({\n", + " 'model': model,\n", + " 'run': center, # Placeholder, as run info is not available\n", + " 'n_clients': num_clients,\n", + " 'alpha': alpha,\n", + " metric: score\n", + " })\n", + " \n", + "# Convert to DataFrame\n", + "df = pd.DataFrame(extracted_data)\n", + "\n", + "# print(df)\n", + "\n", + "# Create 2x3 subplot grid (2 rows, 3 columns)\n", + "fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n", + "axes = axes.flatten()\n", + "\n", + "# Define colors for each model\n", + "colors = {\n", + " 'Logistic Regression': '#1f77b4',\n", + " 'ElasticNet': '#ff7f0e', \n", + " 'Linear SVC': '#2ca02c',\n", + " 'Random Forest': '#d62728',\n", + " 'Balanced Random Forest': '#8c564b',\n", + " 'XGBoost': '#9467bd',\n", + " 'MLP': '#8c564b'\n", + "}\n", + "\n", + "x_positions = clients\n", + "\n", + "# Plot box plots for each model in separate subplots\n", + "for i, model in enumerate(models):\n", + " if i < len(axes): # Ensure we don't exceed subplot count\n", + " ax = axes[i]\n", + " model_data = df[df['model'] == model]\n", + " print(model)\n", + " print(model_data)\n", + " # Prepare data for boxplot\n", + " boxplot_data = []\n", + " client_labels = []\n", + " box_positions = []\n", + "\n", + " for client_idx, client in enumerate(clients):\n", + " client_data = model_data[model_data['n_clients'] == client][metric]\n", + " boxplot_data.append(client_data)\n", + " box_positions.append(x_positions[client_idx])\n", + " client_labels.append(f'{client}')\n", + " \n", + " \n", + " # print(f\"Model: {model}\")\n", + " # print(\"Box data:\", boxplot_data)\n", + " \n", + " \n", + " # Create box plot with custom positions\n", + " # Adjust width relative to the x-axis scale\n", + " # Base width on the smallest gap between client numbers\n", + " min_gap = min([x_positions[i+1] - x_positions[i] for i in range(len(x_positions)-1)])\n", + " 
box_width = min_gap * 0.9 # Adjust this factor to control box width\n", + " \n", + " box_plots = ax.boxplot(boxplot_data, positions=box_positions, \n", + " widths=box_width, patch_artist=True,\n", + " showmeans=False, \n", + " meanprops={'marker':'o', 'markerfacecolor':'white', \n", + " 'markeredgecolor':'black'})\n", + " # Color the boxes\n", + " for patch in box_plots['boxes']:\n", + " patch.set_facecolor(colors[model])\n", + " patch.set_alpha(0.7)\n", + " \n", + " # Customize box plot elements\n", + " for element in ['whiskers', 'caps', 'medians']:\n", + " for line in box_plots[element]:\n", + " line.set_color('black')\n", + " line.set_linewidth(1.5)\n", + "\n", + " # Set x-ticks to client numbers\n", + " ax.set_xticks(box_positions)\n", + " ax.set_xticklabels(client_labels)\n", + " \n", + " print(model, boxplot_data)\n", + " # Set subplot title and labels\n", + " ax.set_title(f'{model}', fontsize=12, fontweight='bold')\n", + " ax.set_xlabel('Number of Clients', fontsize=10)\n", + " metric_formatted = metric.replace(\"_\", \" \").title()\n", + " ax.set_ylabel(metric_formatted, fontsize=10)\n", + " \n", + " # Set consistent y-axis across all subplots\n", + " # ax.set_ylim(0.5, 0.78)\n", + " ax.set_ylim(0.6, 0.85)\n", + "\n", + " # Set x-axis limits with some padding\n", + " ax.set_xlim(min(box_positions) - min_gap * 0.5, \n", + " max(box_positions) + min_gap * 0.5)\n", + " \n", + " # Add grid\n", + " ax.grid(True, alpha=0.3, axis='y')\n", + " \n", + " # Add trend annotation\n", + " means = [np.mean(client_data) for client_data in boxplot_data]\n", + " trend = means[0] - means[-1] # Performance drop from 3 to 20 clients\n", + " \n", + " # Add performance degradation annotation\n", + " ax.text(0.02, 0.98, f'Δ: -{trend:.3f}', transform=ax.transAxes, \n", + " fontsize=9, verticalalignment='top',\n", + " bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))\n", + "\n", + "# Remove unused subplots when there are fewer models than axes in the 2x3 grid\n", + "if len(models) < 
len(axes):\n", + " for i in range(len(models), len(axes)):\n", + " fig.delaxes(axes[i])\n", + "\n", + "# Add overall title\n", + "# fig.suptitle('Federated Learning Benchmark: Model Performance Distribution vs Number of Clients\\n'\n", + " # 'Box Plots Showing Performance Degradation with Increasing Clients', \n", + " # fontsize=14, fontweight='bold', y=0.98)\n", + "\n", + "plt.tight_layout()\n", + "plt.subplots_adjust(top=0.93)\n", + "plt.show()\n", + "\n", + "# Print detailed statistics for the paper\n", + "print(\"Federated Learning Benchmark Summary - Box Plot Analysis\")\n", + "print(\"=\" * 80)\n", + "\n", + "for model in models:\n", + " print(f\"\\n{model}:\")\n", + " print(\"-\" * 40)\n", + " model_data = df[df['model'] == model]\n", + " \n", + " for client in clients:\n", + " client_data = model_data[model_data['n_clients'] == client][metric]\n", + " mean_auc = client_data.mean()\n", + " std_auc = client_data.std()\n", + " min_auc = client_data.min()\n", + " max_auc = client_data.max()\n", + " \n", + " print(f\" {client:2d} clients: {mean_auc:.4f} ± {std_auc:.4f} \"\n", + " f\"[{min_auc:.4f}, {max_auc:.4f}]\")\n", + " \n", + " # Calculate overall degradation\n", + " perf_3 = model_data[model_data['n_clients'] == 3][metric].mean()\n", + " perf_20 = model_data[model_data['n_clients'] == 20][metric].mean()\n", + " degradation = ((perf_3 - perf_20) / perf_3) * 100\n", + " print(f\" Performance degradation (3→20 clients): {degradation:.2f}%\")\n", + "\n", + "# Comparative analysis\n", + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"COMPARATIVE ANALYSIS:\")\n", + "print(\"=\" * 80)\n", + "\n", + "# Find best performing model at each client count\n", + "for client in clients:\n", + " client_data = df[df['n_clients'] == client]\n", + " best_model = None\n", + " best_auc = 0\n", + " \n", + " for model in models:\n", + " model_auc = client_data[client_data['model'] == model][metric].mean()\n", + " if model_auc > best_auc:\n", + " best_auc = model_auc\n", + " 
best_model = model\n", + "\n", + " print(f\"Best at {client:2d} clients: {best_model} ({metric}: {best_auc:.4f})\")\n", + "\n", + "# Overall best model\n", + "overall_means = df.groupby('model')[metric].mean()\n", + "best_overall_model = overall_means.idxmax()\n", + "best_overall_auc = overall_means.max()\n", + "\n", + "print(f\"\\nOverall best model: {best_overall_model} (Avg {metric}: {best_overall_auc:.4f})\")" + ] + }, + { + "cell_type": "markdown", + "id": "30c417a3", + "metadata": {}, + "source": [ + "# Table: Normalization impact" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec9feea2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Weighted Average Metrics Table:\n", + "\n", + "Model Balanced Accuracy Auroc\n", + "Diabetes Logistic Regression C10 A0.7 NormN FeatN 0.654 ± 0.021 0.729 ± 0.034\n", + "Diabetes Logistic Regression C10 A0.7 Normglobal FeatN 0.745 ± 0.041 0.803 ± 0.049\n", + "Diabetes Logistic Regression C10 A0.7 Normlocal FeatN 0.730 ± 0.024 0.801 ± 0.056\n", + "Diabetes Logistic Regression C10 AN NormN FeatN 0.665 ± 0.017 0.725 ± 0.014\n", + "Diabetes Logistic Regression C10 AN Normglobal FeatN 0.755 ± 0.011 0.829 ± 0.009\n", + "Diabetes Logistic Regression C10 AN Normlocal FeatN 0.759 ± 0.011 0.830 ± 0.010\n", + "Ukbb Cvd Logistic Regression C10 A0.7 NormN FeatN 0.515 ± 0.009 0.529 ± 0.027\n", + "Ukbb Cvd Logistic Regression C10 A0.7 Normglobal FeatN 0.742 ± 0.035 0.814 ± 0.024\n", + "Ukbb Cvd Logistic Regression C10 A0.7 Normlocal FeatN 0.746 ± 0.037 0.818 ± 0.021\n", + "Ukbb Cvd Logistic Regression C10 AN NormN FeatN 0.518 ± 0.008 0.530 ± 0.024\n", + "Ukbb Cvd Logistic Regression C10 AN Normglobal FeatN 0.741 ± 0.034 0.814 ± 0.023\n", + "Ukbb Cvd Logistic Regression C10 AN Normlocal FeatN 0.746 ± 0.036 0.818 ± 0.021\n", + "\n", + "LaTeX Table:\n", + "\n", + "\\begin{tabular}{lcc}\n", + "Model & Balanced Accuracy & Auroc \\\\\n", + "\\hline\n", + "Diabetes 
Logistic Regression C10 A0.7 NormN FeatN & 0.654 $\\pm$ 0.021 & 0.729 $\\pm$ 0.034 \\\\\n", + "Diabetes Logistic Regression C10 A0.7 Normglobal FeatN & 0.745 $\\pm$ 0.041 & 0.803 $\\pm$ 0.049 \\\\\n", + "Diabetes Logistic Regression C10 A0.7 Normlocal FeatN & 0.730 $\\pm$ 0.024 & 0.801 $\\pm$ 0.056 \\\\\n", + "Diabetes Logistic Regression C10 AN NormN FeatN & 0.665 $\\pm$ 0.017 & 0.725 $\\pm$ 0.014 \\\\\n", + "Diabetes Logistic Regression C10 AN Normglobal FeatN & 0.755 $\\pm$ 0.011 & 0.829 $\\pm$ 0.009 \\\\\n", + "Diabetes Logistic Regression C10 AN Normlocal FeatN & 0.759 $\\pm$ 0.011 & 0.830 $\\pm$ 0.010 \\\\\n", + "Ukbb Cvd Logistic Regression C10 A0.7 NormN FeatN & 0.515 $\\pm$ 0.009 & 0.529 $\\pm$ 0.027 \\\\\n", + "Ukbb Cvd Logistic Regression C10 A0.7 Normglobal FeatN & 0.742 $\\pm$ 0.035 & 0.814 $\\pm$ 0.024 \\\\\n", + "Ukbb Cvd Logistic Regression C10 A0.7 Normlocal FeatN & 0.746 $\\pm$ 0.037 & 0.818 $\\pm$ 0.021 \\\\\n", + "Ukbb Cvd Logistic Regression C10 AN NormN FeatN & 0.518 $\\pm$ 0.008 & 0.530 $\\pm$ 0.024 \\\\\n", + "Ukbb Cvd Logistic Regression C10 AN Normglobal FeatN & 0.741 $\\pm$ 0.034 & 0.814 $\\pm$ 0.023 \\\\\n", + "Ukbb Cvd Logistic Regression C10 AN Normlocal FeatN & 0.746 $\\pm$ 0.036 & 0.818 $\\pm$ 0.021 \\\\\n", + "\\end{tabular}\n" + ] + } + ], + "source": [ + "# Normalization experiment\n", + "experiment_name = \"normalization\"\n", + "logs_dir = \"benchmark_results_normalization\"\n", + "model_names = [\"logistic_regression\"]\n", + "datasets = [\"diabetes\"]\n", + "num_clients = [10]\n", + "dirichlet_alpha = [\"None\"]\n", + "data_normalization = [\"global\", \"local\", None]\n", + "keywords = [experiment_name]\n", + "data = load_data(logs_dir, experiment_name, keywords, results_file=\"per_center_results.csv\")\n", + "\n", + "# Write a code to extract the following metrics, calculate weighted averages and standard deviations and create a table with rows as models and columns as metrics in latex format\n", + "metrics_to_extract = 
[\"balanced_accuracy\", \"auroc\"]\n", + "table_data = {}\n", + "for model_name, df in data.items():\n", + " # model = model_name.split(\" Norm\")[1]\n", + " model = model_name\n", + " total_samples = df[\"n samples\"].sum()\n", + " table_data[model] = {}\n", + " for metric in metrics_to_extract:\n", + " weighted_sum = (df[metric] * df[\"n samples\"]).sum()\n", + " avg_metric = weighted_sum / total_samples\n", + " std_metric = ( ((df[metric] - avg_metric)**2 * df[\"n samples\"]).sum() / total_samples )**0.5\n", + " table_data[model][metric] = (avg_metric, std_metric)\n", + "\n", + "# Print nicely formatted table\n", + "print(\"\\nWeighted Average Metrics Table:\\n\")\n", + "header = \"Model\".ljust(30)\n", + "for metric in metrics_to_extract:\n", + " header += f\"{metric.replace('_', ' ').title():>30}\"\n", + "print(header)\n", + "for model, metrics in table_data.items():\n", + " row = model.ljust(30)\n", + " for metric in metrics_to_extract:\n", + " avg, std = metrics[metric]\n", + " row += f\"{avg:.3f} ± {std:.3f}\".rjust(30)\n", + " print(row)\n", + "\n", + "\n", + "# Create latex table\n", + "latex_table = \"\\\\begin{tabular}{l\" + \"c\" * len(metrics_to_extract) + \"}\\n\"\n", + "latex_table += \"Model\"\n", + "for metric in metrics_to_extract:\n", + " latex_table += f\" & {metric.replace('_', ' ').title()}\"\n", + "latex_table += \" \\\\\\\\\\n\\\\hline\\n\"\n", + "for model, metrics in table_data.items():\n", + " latex_table += model\n", + " for metric in metrics_to_extract:\n", + " avg, std = metrics[metric]\n", + " latex_table += f\" & {avg:.3f} $\\\\pm$ {std:.3f}\"\n", + " latex_table += \" \\\\\\\\\\n\"\n", + "latex_table += \"\\\\end{tabular}\"\n", + "print(\"\\nLaTeX Table:\\n\")\n", + "print(latex_table)\n" + ] + }, + { + "cell_type": "markdown", + "id": "35133b50", + "metadata": {}, + "source": [ + "# Table: Feature Selection" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "add792d5", + "metadata": {}, + "outputs": [ + { + 
"name": "stdout", + "output_type": "stream", + "text": [ + "Found 19 experiments\n", + "\n", + "Weighted Average Metrics Table:\n", + "\n", + "Model Balanced Accuracy Auroc Round Time [S]\n", + "Ukbb Cvd Balanced Random Forest C10 A0.7 Normglobal Feat10 0.749 ± 0.029 0.817 ± 0.027 1.425 ± 0.274\n", + "Ukbb Cvd Balanced Random Forest C10 A0.7 Normglobal Feat20 0.759 ± 0.021 0.829 ± 0.024 1.501 ± 0.275\n", + "Ukbb Cvd Balanced Random Forest C10 A0.7 Normglobal Feat35 0.754 ± 0.024 0.829 ± 0.024 1.591 ± 0.292\n", + "Ukbb Cvd Balanced Random Forest C10 A0.7 Normglobal Feat40 0.753 ± 0.022 0.827 ± 0.024 1.638 ± 0.346\n", + "Ukbb Cvd Balanced Random Forest C10 A0.7 Normglobal FeatN 0.754 ± 0.027 0.826 ± 0.024 1.636 ± 0.321\n", + "Ukbb Cvd Balanced Random Forest C10 AN Normglobal Feat10 0.749 ± 0.031 0.817 ± 0.026 1.416 ± 0.276\n", + "Ukbb Cvd Balanced Random Forest C10 AN Normglobal Feat20 0.758 ± 0.022 0.828 ± 0.025 1.511 ± 0.303\n", + "Ukbb Cvd Balanced Random Forest C10 AN Normglobal Feat35 0.755 ± 0.027 0.828 ± 0.025 1.582 ± 0.339\n", + "Ukbb Cvd Balanced Random Forest C10 AN Normglobal Feat40 0.752 ± 0.025 0.826 ± 0.023 1.624 ± 0.311\n", + "Ukbb Cvd Balanced Random Forest C5 A0.7 Normglobal Feat10 0.757 ± 0.024 0.818 ± 0.022 0.966 ± 0.199\n", + "Ukbb Cvd Balanced Random Forest C5 A0.7 Normglobal Feat20 0.742 ± 0.028 0.823 ± 0.027 1.032 ± 0.212\n", + "Ukbb Cvd Balanced Random Forest C5 A0.7 Normglobal Feat35 0.750 ± 0.027 0.825 ± 0.025 1.098 ± 0.226\n", + "Ukbb Cvd Balanced Random Forest C5 A0.7 Normglobal Feat40 0.747 ± 0.018 0.825 ± 0.025 1.128 ± 0.235\n", + "Ukbb Cvd Balanced Random Forest C5 A0.7 Normglobal FeatN 0.750 ± 0.031 0.824 ± 0.027 1.146 ± 0.240\n", + "Ukbb Cvd Balanced Random Forest C5 AN Normglobal Feat10 0.755 ± 0.023 0.819 ± 0.022 0.983 ± 0.199\n", + "Ukbb Cvd Balanced Random Forest C5 AN Normglobal Feat20 0.742 ± 0.028 0.823 ± 0.026 1.035 ± 0.216\n", + "Ukbb Cvd Balanced Random Forest C5 AN Normglobal Feat35 0.750 ± 0.030 0.824 ± 0.024 1.075 ± 
0.225\n", + "Ukbb Cvd Balanced Random Forest C5 AN Normglobal Feat40 0.747 ± 0.023 0.824 ± 0.025 1.106 ± 0.233\n", + "Ukbb Cvd Balanced Random Forest C5 AN Normglobal FeatN 0.747 ± 0.032 0.823 ± 0.027 1.120 ± 0.236\n", + "\n", + "LaTeX Table:\n", + "\n", + "\\begin{tabular}{lccc}\n", + "Model & Balanced Accuracy & Auroc & Round Time [S] \\\\\n", + "\\hline\n", + "Ukbb Cvd Balanced Random Forest C10 A0.7 Normglobal Feat10 & 0.749 $\\pm$ 0.029 & 0.817 $\\pm$ 0.027 & 1.425 $\\pm$ 0.274 \\\\\n", + "Ukbb Cvd Balanced Random Forest C10 A0.7 Normglobal Feat20 & 0.759 $\\pm$ 0.021 & 0.829 $\\pm$ 0.024 & 1.501 $\\pm$ 0.275 \\\\\n", + "Ukbb Cvd Balanced Random Forest C10 A0.7 Normglobal Feat35 & 0.754 $\\pm$ 0.024 & 0.829 $\\pm$ 0.024 & 1.591 $\\pm$ 0.292 \\\\\n", + "Ukbb Cvd Balanced Random Forest C10 A0.7 Normglobal Feat40 & 0.753 $\\pm$ 0.022 & 0.827 $\\pm$ 0.024 & 1.638 $\\pm$ 0.346 \\\\\n", + "Ukbb Cvd Balanced Random Forest C10 A0.7 Normglobal FeatN & 0.754 $\\pm$ 0.027 & 0.826 $\\pm$ 0.024 & 1.636 $\\pm$ 0.321 \\\\\n", + "Ukbb Cvd Balanced Random Forest C10 AN Normglobal Feat10 & 0.749 $\\pm$ 0.031 & 0.817 $\\pm$ 0.026 & 1.416 $\\pm$ 0.276 \\\\\n", + "Ukbb Cvd Balanced Random Forest C10 AN Normglobal Feat20 & 0.758 $\\pm$ 0.022 & 0.828 $\\pm$ 0.025 & 1.511 $\\pm$ 0.303 \\\\\n", + "Ukbb Cvd Balanced Random Forest C10 AN Normglobal Feat35 & 0.755 $\\pm$ 0.027 & 0.828 $\\pm$ 0.025 & 1.582 $\\pm$ 0.339 \\\\\n", + "Ukbb Cvd Balanced Random Forest C10 AN Normglobal Feat40 & 0.752 $\\pm$ 0.025 & 0.826 $\\pm$ 0.023 & 1.624 $\\pm$ 0.311 \\\\\n", + "Ukbb Cvd Balanced Random Forest C5 A0.7 Normglobal Feat10 & 0.757 $\\pm$ 0.024 & 0.818 $\\pm$ 0.022 & 0.966 $\\pm$ 0.199 \\\\\n", + "Ukbb Cvd Balanced Random Forest C5 A0.7 Normglobal Feat20 & 0.742 $\\pm$ 0.028 & 0.823 $\\pm$ 0.027 & 1.032 $\\pm$ 0.212 \\\\\n", + "Ukbb Cvd Balanced Random Forest C5 A0.7 Normglobal Feat35 & 0.750 $\\pm$ 0.027 & 0.825 $\\pm$ 0.025 & 1.098 $\\pm$ 0.226 \\\\\n", + "Ukbb Cvd Balanced Random Forest C5 
A0.7 Normglobal Feat40 & 0.747 $\\pm$ 0.018 & 0.825 $\\pm$ 0.025 & 1.128 $\\pm$ 0.235 \\\\\n", + "Ukbb Cvd Balanced Random Forest C5 A0.7 Normglobal FeatN & 0.750 $\\pm$ 0.031 & 0.824 $\\pm$ 0.027 & 1.146 $\\pm$ 0.240 \\\\\n", + "Ukbb Cvd Balanced Random Forest C5 AN Normglobal Feat10 & 0.755 $\\pm$ 0.023 & 0.819 $\\pm$ 0.022 & 0.983 $\\pm$ 0.199 \\\\\n", + "Ukbb Cvd Balanced Random Forest C5 AN Normglobal Feat20 & 0.742 $\\pm$ 0.028 & 0.823 $\\pm$ 0.026 & 1.035 $\\pm$ 0.216 \\\\\n", + "Ukbb Cvd Balanced Random Forest C5 AN Normglobal Feat35 & 0.750 $\\pm$ 0.030 & 0.824 $\\pm$ 0.024 & 1.075 $\\pm$ 0.225 \\\\\n", + "Ukbb Cvd Balanced Random Forest C5 AN Normglobal Feat40 & 0.747 $\\pm$ 0.023 & 0.824 $\\pm$ 0.025 & 1.106 $\\pm$ 0.233 \\\\\n", + "Ukbb Cvd Balanced Random Forest C5 AN Normglobal FeatN & 0.747 $\\pm$ 0.032 & 0.823 $\\pm$ 0.027 & 1.120 $\\pm$ 0.236 \\\\\n", + "\\end{tabular}\n" + ] + } + ], + "source": [ + "# Feature selection experiment\n", + "experiment_name = \"feature_selection\"\n", + "benchmark_dir = \"benchmark_results_feature_selection\"\n", + "model_names = [\"balanced_random_forest\"]\n", + "datasets = [\"ukbb_cvd\"]\n", + "num_clients = [5,10]\n", + "dirichlet_alpha = [0.7, None]\n", + "data_normalization = [\"global\"]\n", + "n_features = [10, 20, 35, 40, None]\n", + "keywords = [experiment_name]\n", + "\n", + "data = load_data(benchmark_dir, experiment_name, keywords)\n", + "# Write a code to extract the following metrics, calculate weighted averages and standard deviations and create a table with rows as models and columns as metrics in latex format\n", + "metrics_to_extract = [\"balanced_accuracy\", \"auroc\", \"round_time [s]\"]\n", + "table_data = {}\n", + "for model_name, df in data.items():\n", + " # model = model_name.split(\" Norm\")[1]\n", + " model = model_name\n", + " total_samples = df[\"n samples\"].sum()\n", + " table_data[model] = {}\n", + " for metric in metrics_to_extract:\n", + " weighted_sum = (df[metric] * df[\"n 
samples\"]).sum()\n", + " avg_metric = weighted_sum / total_samples\n", + " std_metric = ( ((df[metric] - avg_metric)**2 * df[\"n samples\"]).sum() / total_samples )**0.5\n", + " table_data[model][metric] = (avg_metric, std_metric)\n", + "\n", + "# Print nicely formatted table\n", + "print(\"\\nWeighted Average Metrics Table:\\n\")\n", + "header = \"Model\".ljust(30)\n", + "for metric in metrics_to_extract:\n", + " header += f\"{metric.replace('_', ' ').title():>30}\"\n", + "print(header)\n", + "for model, metrics in table_data.items():\n", + " row = model.ljust(30)\n", + " for metric in metrics_to_extract:\n", + " avg, std = metrics[metric]\n", + " row += f\"{avg:.3f} ± {std:.3f}\".rjust(30)\n", + " print(row)\n", + "\n", + "\n", + "# Create latex table\n", + "latex_table = \"\\\\begin{tabular}{l\" + \"c\" * len(metrics_to_extract) + \"}\\n\"\n", + "latex_table += \"Model\"\n", + "for metric in metrics_to_extract:\n", + " latex_table += f\" & {metric.replace('_', ' ').title()}\"\n", + "latex_table += \" \\\\\\\\\\n\\\\hline\\n\"\n", + "for model, metrics in table_data.items():\n", + " latex_table += model\n", + " for metric in metrics_to_extract:\n", + " avg, std = metrics[metric]\n", + " latex_table += f\" & {avg:.3f} $\\\\pm$ {std:.3f}\"\n", + " latex_table += \" \\\\\\\\\\n\"\n", + "latex_table += \"\\\\end{tabular}\"\n", + "print(\"\\nLaTeX Table:\\n\")\n", + "print(latex_table)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "flc", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/repeated.py b/repeated.py index 567870e..6a226d3 100644 --- a/repeated.py +++ b/repeated.py @@ -1,13 +1,19 @@ import subprocess import time import os 
+import sys import yaml -with open("config.yaml", "r") as f: +if len(sys.argv) == 2: + config_path = sys.argv[1] +else: + config_path = "config.yaml" + +with open(config_path, "r") as f: config = yaml.safe_load(f) -repetitions = 4 +repetitions = 5 experiment_name = config['experiment']['name'] config['experiment']['log_path'] = os.path.join(config['experiment']['log_path'], config['experiment']['name']) @@ -21,6 +27,12 @@ config_path = os.path.join(config['experiment']['log_path'], "config.yaml") log_file_path = os.path.join(config['experiment']['log_path'], config['experiment']['name'], "run_log.txt") os.makedirs(os.path.join(config['experiment']['log_path'], config['experiment']['name']), exist_ok=True) + + # Kill any existing process using the same port + if 'local_port' in config: + kill_command = f"lsof -ti tcp:{config['local_port']} | xargs kill -9" + subprocess.run(kill_command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + with open(config_path, "w") as f: yaml.dump(config, f) try: @@ -35,6 +47,35 @@ with open(config_path, "w") as f: yaml.dump(config, f) + +# processes = [] +# try: +# for i in range(repetitions): +# print(f"Experiment run {i + 1}") +# config['experiment']['name'] = 'run_' + str(i + 1) +# config['seed'] = i + 10 +# config['local_port'] = 8081 + i +# config_path = os.path.join(config['experiment']['log_path'], config['experiment']['name'], "config.yaml") +# log_file_path = os.path.join(config['experiment']['log_path'], config['experiment']['name'], "run_log.txt") +# os.makedirs(os.path.join(config['experiment']['log_path'], config['experiment']['name']), exist_ok=True) +# with open(config_path, "w") as f: +# yaml.dump(config, f) +# run_process = subprocess.Popen(f"python run.py {config_path} | tee {log_file_path}", shell=True) +# # run_process.wait() +# processes.append(run_process) + +# for run_process in processes: +# run_process.wait() + +# except KeyboardInterrupt: +# run_process.terminate() +# run_process.wait() + 
+# config['experiment']['name'] = experiment_name +# config_path = os.path.join(config['experiment']['log_path'], "config.yaml") +# with open(config_path, "w") as f: +# yaml.dump(config, f) + run_process = subprocess.Popen(f"python flcore/compile_results.py {config['experiment']['log_path']}", shell=True) run_process.wait() diff --git a/requirements.txt b/requirements.txt index 13078ec..fc7ee35 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,5 +11,6 @@ scikit_learn==1.2.2 torch==2.0.1 torchmetrics==0.11.4 tqdm==4.65.0 +ucimlrepo==0.0.7 xgboost==1.7.5 pdfkit==1.0.0 diff --git a/server.py b/server.py index 0b9784a..24f5ec6 100644 --- a/server.py +++ b/server.py @@ -93,7 +93,7 @@ def check_config(config): # filename = os.path.join( checkpoint_dir, 'final_model.pt' ) # joblib.dump(model, filename) # Save the history as a yaml file - print(history) + # print(history) with open(experiment_dir / "metrics.txt", "w") as f: f.write(f"Results of the experiment {config['experiment']['name']}\n") f.write(f"Model: {config['model']}\n") @@ -101,10 +101,19 @@ def check_config(config): f.write(f"Number of clients: {config['num_clients']}\n") # selection_metric = 'val ' + config['checkpoint_selection_metric'] - selection_metric = config['checkpoint_selection_metric'] + selection_metric = "val " + config['checkpoint_selection_metric'] # Get index of tuple of the best round - best_round = int(numpy.argmax([round[1] for round in history.metrics_distributed[selection_metric]])) - training_time = history.metrics_distributed_fit['training_time [s]'][-1][1] + best_round = int(numpy.argmax([round[1] for round in history.metrics_distributed[selection_metric]])) + # Use the last round as final checkpoint, since no validation set is used + # best_round = -1 + # print(history) + # check if history has attribute metrics_distributed_fit + if hasattr(history, 'metrics_distributed_fit') and 'training_time [s]' in history.metrics_distributed_fit: + # check if training_time is in 
metrics_distributed_fit + training_time = history.metrics_distributed_fit['training_time [s]'][-1][1] + else: + training_time = 0.0 + f.write(f"Total training time: {training_time:.2f} [s] \n") f.write(f"Best checkpoint based on {selection_metric} after round: {best_round}\n\n") print(f"Best checkpoint based on {selection_metric} after round: {best_round}\n\n") diff --git a/tests/test_models.py b/tests/test_models.py index 669f1d0..f5969f7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -12,12 +12,18 @@ LOGGING_LEVEL = logging.INFO # WARNING # logging.INFO model_names = [ -# "logistic_regression", -# "elastic_net", -# "lsvc", + "logistic_regression", + "elastic_net", + "lsvc", "random_forest", - # "weighted_random_forest", - # "xgb" + "balanced_random_forest", + # # "weighted_random_forest", + "xgblr" + ] + +datasets = [ + "kaggle_hf", + "diabetes", ] def free_port(port): @@ -34,17 +40,29 @@ def setup_class(self): with open("config.yaml", "r") as f: self.config = yaml.safe_load(f) - self.num_clients = 3 + self.config["num_clients"] = 3 + self.config["num_rounds"] = 2 + + # To speed up tests, reduce number of trees in xgboost and random forest + self.config["random_forest"]["tree_num"] = 5 + self.config["xgblr"]["tree_num"] = 5 + self.config["xgblr"]["num_iterations"] = 2 @pytest.mark.parametrize( "model_name", - model_names + model_names, + ) + @pytest.mark.parametrize( + "dataset_name", + datasets, ) def test_get_model_client( - self, model_name + self, model_name, dataset_name ): self.config["model"] = model_name + self.config['data_path'] = 'dataset/' + self.config["dataset"] = dataset_name from flcore.client_selector import get_model_client from flcore.datasets import load_dataset @@ -57,22 +75,27 @@ def test_get_model_client( @pytest.mark.parametrize( "model_name", - model_names + model_names, + ) + @pytest.mark.parametrize( + "dataset_name", + datasets, ) - def test_run(self, model_name): + def test_run(self, model_name, dataset_name): 
self.config["model"] = model_name + self.config["dataset"] = dataset_name with open("config.yaml", "r") as f: config = yaml.safe_load(f) config = self.config - with open("config.yaml", "w") as f: + with open("tmp_test_config.yaml", "w") as f: yaml.dump(config, f) free_port(config["local_port"]) run_log = open("run.log", "w") - run_process = subprocess.Popen("python run.py", shell=True, stdout=run_log, stderr=run_log) + run_process = subprocess.Popen("python run.py tmp_test_config.yaml", shell=True, stdout=run_log, stderr=run_log) timer = Timer(180, run_process.kill) try: @@ -85,5 +108,8 @@ def test_run(self, model_name): run_log.close() run_log = open("run.log", "r") print(run_log.read()) + + # Delete temporary config file + os.remove("tmp_test_config.yaml") assert run_process.returncode == 0