import os
import time
import multiprocessing as mp
from argparse import Namespace
from datetime import timedelta

import torch
import torch.distributed as dist
import hostlist
import hydra
import omegaconf

from client_group_swin import ClientGroupSwin
from utils.logs_utils import (
    create_dict_result,
    save_result,
    create_id_run,
    log,
)


def get_config_name():
    # The Hydra config name is read from the environment when the module is
    # imported, so CONFIG_NAME must be set before launching the script.
    return os.getenv("CONFIG_NAME")
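

# Example invocation (a sketch, not part of the original file): assuming a
# Hydra config at config/swin_configs/<name>.yaml, a launch looks like
#
#   CONFIG_NAME=<name> srun python main_fl_swin.py
#
# SLURM starts one process per GPU; each process then calls run() below.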

def run(rank_gpu, gpu_id_in_node, n_gpus, n_nodes, master_addr, master_port, args):
    #### INITIALIZATION ####
    # TCP store shared by all processes, used to exchange the global run id
    TCP_IP = master_addr
    TCP_port = master_port + 3  # offset to avoid colliding with the main port
    # init path dir for the logs
    tensorboard_dir = args.path_logs + "/tensorboard/"
    slurm_logs_dir = args.path_logs + "/logs/"
    if rank_gpu == 0:
        # rank 0 hosts the store
        filestore = dist.TCPStore(
            TCP_IP, port=TCP_port, is_master=True, timeout=timedelta(seconds=2000)
        )
        # create a unique id_run and publish it together with the start time
        filestore.set("id_run", str(create_id_run()))
        filestore.set("t0", str(time.time()))
    else:
        # the other ranks connect to the store as clients
        filestore = dist.TCPStore(
            TCP_IP, port=TCP_port, is_master=False, timeout=timedelta(seconds=2000)
        )
    # every rank reads the common id_run back from the store
    id_run = filestore.get("id_run").decode()
    t0 = float(filestore.get("t0"))
    path_tensorboard = tensorboard_dir + id_run
    # initialization of the group of clients hosted on this GPU
    client_group = ClientGroupSwin(
        rank_gpu,
        n_gpus,
        gpu_id_in_node,
        filestore,
        path_tensorboard,
        args,
    )
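    # Note on the rendezvous above: TCPStore is a simple key-value store hosted
    # by rank 0; get() blocks until the key has been set (up to the timeout),
    # which is what synchronizes every process on the same id_run and t0.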
    #### COMMUNICATIONS & GRAD STEPS ####
    t_begin = time.time()
    client_group.launch_training_routine()
    delta_t = 30
    # poll until the requested number of training rounds has been reached
    while client_group.continue_training.value:
        t_train = time.time() - t_begin
        # print progress information every 30 seconds
        if t_train >= delta_t:
            list_losses = [mp_var.value for mp_var in client_group.list_train_losses]
            if args.dataset_name != "ImageNet":
                log.info("GPU group {} finished round {}/{} of local-SGD: \n losses: {} \n time = {} min and {} s".format(
                    rank_gpu, client_group.count_rounds.value, client_group.total_rounds.value,
                    str(list_losses), int(t_train // 60), int(t_train % 60)))
            delta_t += 30
    # log the final status (the values are recomputed here so they are defined
    # even if training finished before the first 30 s report)
    t_train = time.time() - t_begin
    list_losses = [mp_var.value for mp_var in client_group.list_train_losses]
    log.info("GPU group {} finished round {}/{} of local-SGD: \n losses: {} \n time = {} min and {} s".format(
        rank_gpu, client_group.count_rounds.value, client_group.total_rounds.value,
        str(list_losses), int(t_train // 60), int(t_train % 60)))
    client_group.finish_training_routine()
    t_end = time.time()
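    # NB: the polling loop above only monitors progress; the local-SGD steps
    # themselves run in the processes managed by ClientGroupSwin, which report
    # back through the multiprocessing shared values read above.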
    #### END OF TRAINING ####
    total_time = t_end - t_begin
    list_test_acc_clients, percent_avg = client_group.evaluate()
    # write the test accuracies in the filestore
    for id_client, test_acc in zip(client_group.group_ranks, list_test_acc_clients):
        filestore.set("client {} test_acc".format(id_client), str(test_acc))
    # GPU 0 gathers and saves the results
    if rank_gpu == 0:
        # print the training time
        print("TOTAL TRAINING TIME : {} min and {} s".format(int(total_time // 60), int(total_time % 60)))
        # save results
        dict_result = create_dict_result(
            args,
            filestore,
            n_gpus,  # world size: one process per GPU
            n_nodes,
            torch.cuda.get_device_name(),
            total_time,
            percent_avg,
            id_run,
        )
        save_result(args.path_logs + "/results.csv", dict_result)


@hydra.main(config_path="config/swin_configs", config_name=get_config_name(), version_base=None)
def main(args):
    # convert the hydra/omegaconf config into a plain Namespace
    args = omegaconf.OmegaConf.to_container(args)
    args = Namespace(**args)
    print(args)
    os.environ["WANDB_API_KEY"] = args.WANDB_API_KEY
    os.environ["WANDB_USERNAME"] = args.WANDB_USERNAME
    mp.set_start_method("spawn", force=True)
    # get the distributed configuration from the Slurm environment
    NODE_ID = os.environ["SLURM_NODEID"]
    rank = int(os.environ["SLURM_PROCID"])
    local_rank = int(os.environ["SLURM_LOCALID"])
    world_size = int(os.environ["SLURM_NTASKS"])
    # get the node list from Slurm
    hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])
    n_nodes = len(hostnames)
    # get the IDs of the reserved GPUs
    try:
        gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",")
    except KeyError:
        gpu_ids = os.environ["SLURM_JOB_GPUS"].split(",")
    # define MASTER_ADDR & MASTER_PORT, used to set up the distributed communication environment
    master_addr = hostnames[0]
    master_port = 12346 + min(int(gpu_id) for gpu_id in gpu_ids)  # avoid port conflicts on a shared node
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    # disable mpi4py thread support, to avoid problems with MPI in multi-node settings
    os.environ["MPI4PY_RC_THREADS"] = str(0)
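    # For reference (a sketch, not from the original repository): the variables
    # read above are the ones srun exports for a multi-node GPU job, e.g. with
    #
    #   #SBATCH --nodes=2
    #   #SBATCH --ntasks-per-node=4
    #   #SBATCH --gres=gpu:4
    #   srun python main_fl_swin.py
    #
    # giving SLURM_NTASKS = 8 processes, one per GPU.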
    # display info
    if rank == 0:
        print(">>> Training on", n_nodes, "nodes with", world_size, "processes")
        print("Arguments:")
        print(args)
    print(
        "- Process {} corresponds to GPU {} of node {}".format(
            rank, local_rank, NODE_ID
        )
    )
    print('-' * 40)
    print(f'SLURM_NODEID: {NODE_ID}')
    print(f'SLURM_PROCID: {rank}')
    print(f'SLURM_LOCALID: {local_rank}')
    print(f'SLURM_NTASKS: {world_size}')
    print(f'SLURM_JOB_NODELIST: {os.environ["SLURM_JOB_NODELIST"]}')
    print(f'SLURM_STEP_GPUS: {gpu_ids}')
    print(f'MASTER_ADDR: {master_addr}')
    print(f'MASTER_PORT: {master_port}')
    print(f'MPI4PY_RC_THREADS: {os.environ["MPI4PY_RC_THREADS"]}')
    print('-' * 40)
    run(rank, local_rank, world_size, n_nodes, master_addr, master_port, args)


if __name__ == "__main__":
    main()