22import os
33from argparse import ArgumentDefaultsHelpFormatter , ArgumentParser
44
5+ import ignite .distributed as idist
56import torch
67import torch .distributed as dist
78from monai .config import print_config
8- from monai .handlers import (
9- CheckpointSaver ,
10- LrScheduleHandler ,
11- MeanDice ,
12- StatsHandler ,
13- ValidationHandler ,
14- from_engine ,
15- )
9+ from monai .handlers import (CheckpointSaver , LrScheduleHandler , MeanDice ,
10+ StatsHandler , ValidationHandler , from_engine )
1611from monai .inferers import SimpleInferer , SlidingWindowInferer
1712from monai .losses import DiceCELoss
1813from monai .utils import set_determinism
@@ -91,6 +86,8 @@ def validation(args):
9186 "mean dice for label {} is {}" .format (i + 1 , results [:, i ].mean ())
9287 )
9388
89+ dist .destroy_process_group ()
90+
9491
9592def train (args ):
9693 # load hyper parameters
@@ -151,12 +148,16 @@ def train(args):
151148 optimizer , lr_lambda = lambda epoch : (1 - epoch / max_epochs ) ** 0.9
152149 )
153150 # produce evaluator
154- val_handlers = [
155- StatsHandler (output_transform = lambda x : None ),
156- CheckpointSaver (
157- save_dir = val_output_dir , save_dict = {"net" : net }, save_key_metric = True
158- ),
159- ]
151+ val_handlers = (
152+ [
153+ StatsHandler (output_transform = lambda x : None ),
154+ CheckpointSaver (
155+ save_dir = val_output_dir , save_dict = {"net" : net }, save_key_metric = True
156+ ),
157+ ]
158+ if idist .get_rank () == 0
159+ else None
160+ )
160161
161162 evaluator = DynUNetEvaluator (
162163 device = device ,
@@ -183,16 +184,18 @@ def train(args):
183184
184185 # produce trainer
185186 loss = DiceCELoss (to_onehot_y = True , softmax = True , batch = batch_dice )
186- train_handlers = []
187+ train_handlers = [
188+ ValidationHandler (validator = evaluator , interval = interval , epoch_level = True )
189+ ]
187190 if lr_decay_flag :
188191 train_handlers += [LrScheduleHandler (lr_scheduler = scheduler , print_lr = True )]
189-
190- train_handlers += [
191- ValidationHandler ( validator = evaluator , interval = interval , epoch_level = True ),
192- StatsHandler (
193- tag_name = "train_loss" , output_transform = from_engine (["loss" ], first = True )
194- ),
195- ]
192+ if idist . get_rank () == 0 :
193+ train_handlers += [
194+ StatsHandler (
195+ tag_name = "train_loss" ,
196+ output_transform = from_engine (["loss" ], first = True ),
197+ )
198+ ]
196199
197200 trainer = DynUNetTrainer (
198201 device = device ,
@@ -212,27 +215,8 @@ def train(args):
212215 evaluator .logger .setLevel (logging .WARNING )
213216 trainer .logger .setLevel (logging .WARNING )
214217
215- logger = logging .getLogger ()
216-
217- formatter = logging .Formatter (
218- "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
219- )
220-
221- # Setup file handler
222- fhandler = logging .FileHandler (log_filename )
223- fhandler .setLevel (logging .INFO )
224- fhandler .setFormatter (formatter )
225-
226- logger .addHandler (fhandler )
227-
228- chandler = logging .StreamHandler ()
229- chandler .setLevel (logging .INFO )
230- chandler .setFormatter (formatter )
231- logger .addHandler (chandler )
232-
233- logger .setLevel (logging .INFO )
234-
235218 trainer .run ()
219+ dist .destroy_process_group ()
236220
237221
238222if __name__ == "__main__" :
0 commit comments