26 commits
08b18f0
Add support for Diabetes dataset
faildeny Nov 13, 2025
4d0bc62
Add Balanced Random Forest selection as separate model
faildeny Nov 13, 2025
5aefbd6
Minor xgb fix for results compatibility
faildeny Nov 13, 2025
4da5e2e
Update tests with Diabetes dataset and unlock models testing
faildeny Nov 13, 2025
891c7d7
Fix for config independent tests
faildeny Nov 13, 2025
df5d480
Update github actions to meet latest github system changes
faildeny Nov 13, 2025
3e3a0ef
Add AUROC calculation and switch to predict_proba in models
faildeny Feb 3, 2026
fa7191d
Improve Random Forest hyperparameters
faildeny Feb 3, 2026
157fd8c
Minor changes with report generation
faildeny Feb 3, 2026
185789a
Move to one common federated preprocessing function for public datas…
faildeny Feb 3, 2026
af74c35
Add scripts for automated benchmarking
faildeny Feb 3, 2026
5094455
Add notebook for results visualisation.
faildeny Feb 3, 2026
6678ff3
Fix for AUROC in LSVC models since they do not output probabilities
faildeny Feb 4, 2026
848f627
Add much faster tests with lower config parameters
faildeny Feb 4, 2026
2aadb58
Add tree number parameter in config for RF
faildeny Feb 4, 2026
c4935c0
Update preprocessing aggregation method in config
faildeny Feb 4, 2026
7d65045
Remove legacy dataset preparation function
faildeny Feb 4, 2026
5099403
Add minimum num of samples in dirichlet partitioning
faildeny Feb 5, 2026
9005e97
Add threshold fine tuning on validation set for all models
faildeny Feb 5, 2026
67f20da
Fix missing metrics from XGBoost model in distributed fit
faildeny Feb 5, 2026
183ab75
Fix xgb training with device argument
faildeny Feb 9, 2026
d594349
Add usage of seed from config for reproducibility
faildeny Feb 9, 2026
3aca054
Update benchmarking parameters
faildeny Feb 10, 2026
ad77488
Rename old XGBoost implementation to xgblr
faildeny Feb 10, 2026
90a42cd
Rename old XGBoost implementation to xgblr tests
faildeny Feb 10, 2026
5e80ec5
Update gitignore
faildeny Feb 10, 2026
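
Two of the commits above touch model scoring: 9005e97 adds threshold fine-tuning on a validation set for all models, and 6678ff3 fixes AUROC for LSVC models, which expose decision_function scores rather than probabilities. As a rough illustration of what such tuning can look like — the helper below is a hypothetical sketch, not the repository's implementation, and the balanced-accuracy criterion is an assumption based on the checkpoint metric in config.yaml:

# Hypothetical sketch of validation-set threshold tuning; not the repository's code.
import numpy as np
from sklearn.metrics import balanced_accuracy_score

def tune_threshold(y_val: np.ndarray, val_scores: np.ndarray) -> float:
    """val_scores may be predict_proba[:, 1] or an LSVC decision_function output."""
    candidates = np.unique(val_scores)
    scores = [balanced_accuracy_score(y_val, (val_scores >= t).astype(int)) for t in candidates]
    return float(candidates[int(np.argmax(scores))])

The same scores array works for both model families, which is one way a single tuning step can cover probabilistic and margin-based classifiers alike.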
16 changes: 8 additions & 8 deletions .github/workflows/python-ci.yml
@@ -26,12 +26,12 @@ jobs:

     steps:
       - name: Checkout code
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           ref: ${{ github.event.pull_request.head.ref }} # ${{ github.event.pull_request.head.sha }}

       - name: Setup Python 3.10
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v5
         with:
           python-version: "3.10"
           cache: 'pip'
@@ -70,13 +70,13 @@ jobs:

     steps:
       - name: Checkout to latest changes
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           ref: ${{ needs.formatting.outputs.new_sha }}
           fetch-depth: 0

      - name: Set up Python 3.10
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v5
         with:
           python-version: "3.10"
           cache: 'pip'
@@ -94,13 +94,13 @@ jobs:

     steps:
       - name: Checkout to latest changes
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           ref: ${{ needs.formatting.outputs.new_sha }}
           fetch-depth: 0

      - name: Set up Python 3.10
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v5
         with:
           python-version: "3.10"
           cache: 'pip'
@@ -125,13 +125,13 @@ jobs:

     steps:
       - name: Checkout to latest changes
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           ref: ${{ needs.formatting.outputs.new_sha }}
           fetch-depth: 0

      - name: Set up Python 3.10
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v5
         with:
           python-version: "3.10"
           cache: 'pip'
3 changes: 3 additions & 0 deletions .gitignore
@@ -8,6 +8,9 @@ results/
 data/
 logs/
 external/
+benchmark*/
+*.png
+*.csv
 other/
 # C extensions
 *.so
159 changes: 159 additions & 0 deletions benchmark.py
@@ -0,0 +1,159 @@
import subprocess
import time
import os
import yaml
import sys
from itertools import product

experiment_name = "experiment_all_10percent"
benchmark_dir = "benchmark_results"


model_names = [
    "logistic_regression",
    "elastic_net",
    "lsvc",
    "random_forest",
    "balanced_random_forest",
    # "weighted_random_forest",
    "xgb"
]

datasets = [
    # "kaggle_hf",
    "diabetes",
    # "ukbb_cvd",
    # "cvd"
]

num_clients = [
    3,
    5,
    10,
    20
]

dirichlet_alpha = [
    None,
    # 1.0,
    # 0.7
]

data_normalization = ["global"]
n_features = [None]

# Normalization experiment
# experiment_name = "normalization"
# benchmark_dir = "benchmark_results_normalization"
# model_names = ["logistic_regression"]
# datasets = ["diabetes", "ukbb_cvd"]
# num_clients = [10]
# dirichlet_alpha = [0.7, None]
# data_normalization = ["global", "local", None]

# Feature selection experiment
# experiment_name = "feature_selection"
# benchmark_dir = "benchmark_results_feature_selection"
# model_names = ["balanced_random_forest"]
# datasets = ["ukbb_cvd"]
# num_clients = [5, 10]
# dirichlet_alpha = [0.7, None]
# data_normalization = ["global"]
# n_features = [10, 20, 35, 40, None]

# Number of clients ablation experiment (active; overrides the defaults above)
experiment_name = "num_clients_ablation"
benchmark_dir = "benchmark_results_num_clients_ablation"
model_names = [
    "logistic_regression",
    "elastic_net",
    "lsvc",
    "random_forest",
    "balanced_random_forest",
    "xgb"
]
datasets = ["diabetes"]
num_clients = [3, 5, 10, 20]
dirichlet_alpha = [0.7, 1.0, None]
data_normalization = ["global"]
n_features = [None]

os.makedirs(benchmark_dir, exist_ok=True)

with open("config.yaml", "r") as f:
config = yaml.safe_load(f)


config_path = os.path.join(benchmark_dir, "config.yaml")
log_file_path = os.path.join(benchmark_dir, "run_log.txt")

with open(config_path, "w") as f:
yaml.dump(config, f)

config['data_path'] = 'dataset/'
config['experiment']['log_path'] = benchmark_dir

start_time = time.time()

# Flatten the nested loops into a single iterator
parameters = product(datasets, num_clients, dirichlet_alpha, model_names, data_normalization, n_features)

try:
    for ds_name, n_client, alpha, m_name, norm, n_feat in parameters:
        print(f"Running benchmark: {ds_name}, {m_name}, clients: {n_client}, alpha: {alpha}, normalization: {norm}, features: {n_feat}")

        # Update config dictionary
        config.update({
            'model': m_name,
            'dataset': ds_name,
            'num_clients': n_client,
            'dirichlet_alpha': alpha,
            'data_normalization': norm,
            'n_features': n_feat
        })
        if "forest" in m_name:
            config['num_rounds'] = 1  # random forests are fitted in a single federated round

        config['experiment']['name'] = f"{experiment_name}_{ds_name}_{m_name}_c{n_client}_a{alpha}_norm{norm}_feat{n_feat}"

        with open(config_path, "w") as f:
            yaml.dump(config, f)

        # subprocess.run blocks until the run finishes; shell=True is required here
        # for the pipe through tee (a shell-free alternative is sketched after this file)
        cmd = f"python repeated.py {config_path} | tee {log_file_path}"
        subprocess.run(cmd, shell=True, check=True)

except KeyboardInterrupt:
    print("\nBenchmark interrupted by user. Exiting...")
    sys.exit(1)



# # Run benchmark experiments
# # Iterate over datasets and models
# for dataset_name in datasets:
# for num_client in num_clients:
# for alpha in dirichlet_alpha:
# for model_name in model_names:
# print(f"Running benchmark for dataset: {dataset_name}, model: {model_name}")
# config['experiment']['name'] = f"{experiment_name}_{dataset_name}_{model_name}_clients_{num_client}_alpha_{alpha}"
# config['model'] = model_name
# config['dataset'] = dataset_name
# config['num_clients'] = num_client
# config['dirichlet_alpha'] = alpha

# with open(config_path, "w") as f:
# yaml.dump(config, f)

# try:
# run_process = subprocess.Popen(f"python repeated.py {config_path} | tee {log_file_path}", shell=True)
# run_process.wait()

# except KeyboardInterrupt:
# run_process.terminate()
# run_process.wait()
# break

total_time = time.time() - start_time
print("Benchmark experiments finished in", total_time/60, " minutes")
48 changes: 39 additions & 9 deletions config.yaml
@@ -10,8 +10,10 @@
 ################################################################################

 ############## Dataset type to use
-# Possible values: , kaggle_hf, mnist, dt4h_format
-dataset: dt4h_format
+# Possible values: , kaggle_hf, diabetes, mnist, dt4h_format
+dataset: kaggle_hf
+# dataset: ukbb_cvd
+# dataset: diabetes
 #custom
 #libsvm
 #kaggle_hf
@@ -33,30 +35,54 @@ train_size: 0.7
 # ****** * * * * * * * * * * * * * * * * * * * * *******************

 ############## Number of clients (data centers) to use for training
-num_clients: 1
+num_clients: 4

 ############## Model type
-# Possible values: logistic_regression, lsvc, elastic_net, random_forest, weighted_random_forest, xgb
-model: random_forest
+# See README.md for a full list of supported models
+# model: xgb
+model: logistic_regression
+# model: random_forest
 #logistic_regression
 #random_forest

 ############## Training length
-num_rounds: 50
+num_rounds: 10

 ############## Metric to select the best model
 # Possible values: accuracy, balanced_accuracy, f1, precision, recall
-checkpoint_selection_metric: precision
+# checkpoint_selection_metric: precision
+checkpoint_selection_metric: balanced_accuracy
 #balanced_accuracy

 ############## Experiment logging
 experiment:
-  name: experiment_1
+  name: experiment_kaggle_standard
   log_path: logs
   debug: true

+
+################################################################################
+# Federated Data Preprocessing
+################################################################################
+
+# Strategy to calculate data preprocessing parameters across clients.
+# It covers missing data imputation, label encoding, normalization and feature selection.
+# It can be one of:
+# "reference" - use a reference center to calculate all parameters (largest or random)
+# "equal_aggregate" - aggregate parameters from all clients by mean and voting, disregarding center size
+# "weighted_aggregate" - aggregate parameters from all clients by weighted mean and voting
+
+data_preprocessing_method: "equal_aggregate"
+# data_preprocessing_method: "reference"
+
+# Compute normalization (standard scaler) statistics on the largest center ("global") or on each local client
+data_normalization: "global"
+
+# Target number of features to keep after feature selection (Null keeps all)
+n_features: Null
+
+
 ################################################################################
 # Aggregation methods
 ################################################################################
@@ -87,9 +113,13 @@ smoothWeights:
 linear_models:
   n_features: 9

+
+dirichlet_alpha: Null
+
 # Random Forest
 random_forest:
+  balanced_rf: true
+  tree_num: 300

 # Weighted Random Forest
 weighted_random_forest:
@@ -101,7 +131,7 @@ xgb:
   batch_size: 32
   num_iterations: 100
   task_type: BINARY
-  tree_num: 500
+  tree_num: 300


 held_out_center_id: -1
@@ -113,6 +143,6 @@ seed: 42

 local_port: 8081

-data_path: dataset/icrc-dataset/
+data_path: dataset/

 production_mode: False # Turn on to use environment variables such as data path, server address, certificates etc.
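
To make the new data_preprocessing_method options concrete, here is a minimal sketch of how the two aggregate strategies could combine per-client feature means; the function name and signature are illustrative assumptions, not the repository's API:

# Illustrative only: combine per-client feature means under the two
# aggregation strategies named in config.yaml.
import numpy as np

def aggregate_means(client_means, client_sizes, method="equal_aggregate"):
    means = np.asarray(client_means, dtype=float)   # shape (n_clients, n_features)
    if method == "equal_aggregate":
        return means.mean(axis=0)                   # every center counts equally
    if method == "weighted_aggregate":
        weights = np.asarray(client_sizes, dtype=float)
        return (means * weights[:, None]).sum(axis=0) / weights.sum()
    raise ValueError(f"Unknown method: {method}")

Under "reference", the same statistics would instead be taken verbatim from a single designated center.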
8 changes: 4 additions & 4 deletions flcore/client_selector.py
@@ -1,7 +1,7 @@
 import numpy as np

 import flcore.models.linear_models as linear_models
-import flcore.models.xgb as xgb
+import flcore.models.xgblr as xgblr
 import flcore.models.random_forest as random_forest
 import flcore.models.weighted_random_forest as weighted_random_forest

@@ -11,14 +11,14 @@ def get_model_client(config, data, client_id):
     if model in ("logistic_regression", "elastic_net", "lsvc"):
         client = linear_models.client.get_client(config,data,client_id)

-    elif model == "random_forest":
+    elif model in ("random_forest", "balanced_random_forest"):
         client = random_forest.client.get_client(config,data,client_id)

     elif model == "weighted_random_forest":
         client = weighted_random_forest.client.get_client(config,data,client_id)

-    elif model == "xgb":
-        client = xgb.client.get_client(config, data, client_id)
+    elif model == "xgblr":
+        client = xgblr.client.get_client(config, data, client_id)

     else:
         raise ValueError(f"Unknown model: {model}")
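
For context, a hypothetical call site for the updated dispatcher; the toy arrays and the shape of the data argument are assumptions for illustration, and the real config is loaded from config.yaml:

# Hypothetical usage of get_model_client; data layout is an assumption.
import numpy as np
from flcore.client_selector import get_model_client

rng = np.random.default_rng(42)
X, y = rng.normal(size=(100, 9)), rng.integers(0, 2, size=100)

config = {"model": "balanced_random_forest"}  # "xgblr" now selects the renamed implementation
client = get_model_client(config, data=(X, y), client_id=0)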