-import os
-import time
-import shutil
 import argparse
 import collections.abc
-import gdown
+import os
+import shutil
+import time

+import gdown
 import numpy as np
-from sklearn.metrics import cohen_kappa_score
-
 import torch
-import torch.nn as nn
-from torch.cuda.amp import GradScaler, autocast
-
-from torch.utils.tensorboard import SummaryWriter
-from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.dataloader import default_collate
-
 import torch.distributed as dist
 import torch.multiprocessing as mp
-
+import torch.nn as nn
+from monai.config import KeysCollection
 from monai.data import Dataset, load_decathlon_datalist
-from monai.data.image_reader import WSIReader
+from monai.data.wsi_reader import WSIReader
 from monai.metrics import Cumulative, CumulativeAverage
-from monai.transforms import Transform, Compose, LoadImageD, RandFlipd, RandRotate90d, ScaleIntensityRangeD, ToTensord
-from monai.apps.pathology.transforms import TileOnGridd
 from monai.networks.nets import milmodel
-
-
-def parse_args():
-
-    parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) example of classification from WSI.")
-    parser.add_argument(
-        "--data_root", default="/PandaChallenge2020/train_images/", help="path to root folder of images"
-    )
-    parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
-
-    parser.add_argument("--num_classes", default=5, type=int, help="number of output classes")
-    parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm")
-    parser.add_argument(
-        "--tile_count", default=44, type=int, help="number of patches (instances) to extract from WSI image"
-    )
-    parser.add_argument("--tile_size", default=256, type=int, help="size of square patch (instance) in pixels")
-
-    parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
-    parser.add_argument(
-        "--validate",
-        action="store_true",
-        help="run only inference on the validation set, must specify the checkpoint argument",
-    )
-
-    parser.add_argument("--logdir", default=None, help="path to log directory to store Tensorboard logs")
-
-    parser.add_argument("--epochs", default=50, type=int, help="number of training epochs")
-    parser.add_argument("--batch_size", default=4, type=int, help="batch size, the number of WSI images per gpu")
-    parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
-
-    parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay")
-    parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
-    parser.add_argument(
-        "--val_every",
-        default=1,
-        type=int,
-        help="run validation after this number of epochs, default 1 to run every epoch",
-    )
-    parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
-
-    ###for multigpu
-    parser.add_argument("--distributed", action="store_true", help="use multigpu training, recommended")
-    parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
-    parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
-    parser.add_argument(
-        "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training"
-    )
-    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
-
-    parser.add_argument(
-        "--quick", action="store_true", help="use a small subset of data for debugging"
-    )  # for debugging
-
-    args = parser.parse_args()
-
-    print("Argument values:")
-    for k, v in vars(args).items():
-        print(k, "=>", v)
-    print("-----------------")
-
-    return args
+from monai.transforms import (
+    Compose,
+    GridPatchd,
+    LoadImaged,
+    MapTransform,
+    RandFlipd,
+    RandGridPatchd,
+    RandRotate90d,
+    ScaleIntensityRanged,
+    ToTensord,
+)
+from sklearn.metrics import cohen_kappa_score
+from torch.cuda.amp import GradScaler, autocast
+from torch.utils.data.dataloader import default_collate
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.tensorboard import SummaryWriter


 def train_epoch(model, loader, optimizer, scaler, epoch, args):
@@ -246,22 +191,26 @@ def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0):
     print("Saving checkpoint", filename)


-class LabelEncodeIntegerGraded(Transform):
+class LabelEncodeIntegerGraded(MapTransform):
     """
     Convert an integer label to encoded array representation of length num_classes,
     with 1 filled in up to label index, and 0 otherwise. For example for num_classes=5,
     embedding of 2 -> (1,1,0,0,0)

     Args:
         num_classes: the number of classes to convert to encoded format.
-        keys: keys of the corresponding items to be transformed
-            Defaults to ``['label']``.
+        keys: keys of the corresponding items to be transformed. Defaults to ``'label'``.
+        allow_missing_keys: don't raise exception if key is missing.

     """

-    def __init__(self, num_classes, keys=["label"]):
-        super().__init__()
-        self.keys = keys
+    def __init__(
+        self,
+        num_classes: int,
+        keys: KeysCollection = "label",
+        allow_missing_keys: bool = False,
+    ):
+        super().__init__(keys, allow_missing_keys)
         self.num_classes = num_classes

     def __call__(self, data):
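
The encoding described in the docstring can be exercised on its own; a minimal sketch, assuming the `__call__` body (unchanged between these hunks) fills the first `label` entries with ones as documented:

    encode = LabelEncodeIntegerGraded(num_classes=5, keys="label")
    print(encode({"label": 2})["label"])  # expected: [1. 1. 0. 0. 0.]
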
@@ -278,35 +227,12 @@ def __call__(self, data):
         return d


-def main():
-
-    args = parse_args()
-
-    if args.dataset_json is None:
-        # download default json datalist
-        resource = "https://drive.google.com/uc?id=1L6PtKBlHHyUgTE4rVhRuOLTQKgD4tBRK"
-        dst = "./datalist_panda_0.json"
-        if not os.path.exists(dst):
-            gdown.download(resource, dst, quiet=False)
-        args.dataset_json = dst
-
-    if args.distributed:
-        ngpus_per_node = torch.cuda.device_count()
-        args.optim_lr = ngpus_per_node * args.optim_lr / 2  # heuristic to scale up learning rate in multigpu setup
-        args.world_size = ngpus_per_node * args.world_size
-
-        print("Multigpu", ngpus_per_node, "rescaled lr", args.optim_lr)
-        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(args,))
-    else:
-        main_worker(0, args)


 def list_data_collate(batch: collections.abc.Sequence):
-    '''
-    Combine instances from a list of dicts into a single dict, by stacking them along the first dim:
-    [{'image': 3xHxW}, {'image': 3xHxW}, {'image': 3xHxW}, ...] -> {'image': Nx3xHxW},
-    followed by the default collate, which will form a batch of shape BxNx3xHxW.
-    '''
+    """
+    Combine instances from a list of dicts into a single dict, by stacking them along the first dim:
+    [{'image': 3xHxW}, {'image': 3xHxW}, {'image': 3xHxW}, ...] -> {'image': Nx3xHxW},
+    followed by the default collate, which will form a batch of shape BxNx3xHxW.
+    """

     for i, item in enumerate(batch):
         data = item[0]
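
The collation pattern in the docstring can be reproduced stand-alone with toy tensors (shapes illustrative, not from the diff): each sample's patch images are stacked into Nx3xHxW, then `default_collate` adds the batch dimension:

    import torch
    from torch.utils.data.dataloader import default_collate

    # two toy WSI samples, each a list of four 3x8x8 patch dicts
    batch = [[{"image": torch.zeros(3, 8, 8), "label": torch.tensor(1.0)} for _ in range(4)] for _ in range(2)]
    # stack each sample's patches along a new first dim: 'image' becomes 4x3x8x8
    stacked = [{"image": torch.stack([p["image"] for p in s]), "label": s[0]["label"]} for s in batch]
    print(default_collate(stacked)["image"].shape)  # torch.Size([2, 4, 3, 8, 8]) -> B x N x 3 x H x W
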
@@ -352,37 +278,36 @@ def main_worker(gpu, args):

     train_transform = Compose(
         [
-            LoadImageD(keys=["image"], reader=WSIReader, backend="TiffFile", dtype=np.uint8, level=1, image_only=True),
+            LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True),
             LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
-            TileOnGridd(
+            RandGridPatchd(
                 keys=["image"],
-                tile_count=args.tile_count,
-                tile_size=args.tile_size,
-                random_offset=True,
-                background_val=255,
-                return_list_of_dicts=True,
+                patch_size=(args.tile_size, args.tile_size),
+                num_patches=args.tile_count,
+                sort_fn="min",
+                pad_mode=None,
+                constant_values=255,
             ),
             RandFlipd(keys=["image"], spatial_axis=0, prob=0.5),
             RandFlipd(keys=["image"], spatial_axis=1, prob=0.5),
             RandRotate90d(keys=["image"], prob=0.5),
-            ScaleIntensityRangeD(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
+            ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
             ToTensord(keys=["image", "label"]),
         ]
     )

     valid_transform = Compose(
         [
-            LoadImageD(keys=["image"], reader=WSIReader, backend="TiffFile", dtype=np.uint8, level=1, image_only=True),
+            LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True),
             LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
-            TileOnGridd(
+            GridPatchd(
                 keys=["image"],
-                tile_count=None,
-                tile_size=args.tile_size,
-                random_offset=False,
-                background_val=255,
-                return_list_of_dicts=True,
+                patch_size=(args.tile_size, args.tile_size),
+                threshold=0.999 * 3 * 255 * args.tile_size * args.tile_size,
+                pad_mode=None,
+                constant_values=255,
             ),
-            ScaleIntensityRangeD(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
+            ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
             ToTensord(keys=["image", "label"]),
         ]
     )
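
For reference, a small stand-alone sketch of the `RandGridPatchd` call introduced above on a toy array; the parameters are taken from the diff, while the expected output shape is an assumption (patches stacked along a new first dimension). With `sort_fn="min"` the darkest grid patches are kept, i.e. tissue rather than white background, playing the role of the removed `TileOnGridd(background_val=255)` tile selection:

    import numpy as np
    from monai.transforms import RandGridPatchd

    # toy CHW "slide": white (255) background with a darker block standing in for tissue
    img = np.full((3, 64, 64), 255, dtype=np.uint8)
    img[:, 16:48, 16:48] = 50

    patcher = RandGridPatchd(
        keys=["image"], patch_size=(16, 16), num_patches=4, sort_fn="min", pad_mode=None, constant_values=255
    )
    patches = patcher({"image": img})["image"]
    print(patches.shape)  # expected: (4, 3, 16, 16), i.e. num_patches x C x patch_h x patch_w
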
@@ -540,5 +465,85 @@ def main_worker(gpu, args):
     print("ALL DONE")


+def parse_args():
+
+    parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) example of classification from WSI.")
+    parser.add_argument(
+        "--data_root", default="/PandaChallenge2020/train_images/", help="path to root folder of images"
+    )
+    parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
+
+    parser.add_argument("--num_classes", default=5, type=int, help="number of output classes")
+    parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm")
+    parser.add_argument(
+        "--tile_count", default=44, type=int, help="number of patches (instances) to extract from WSI image"
+    )
+    parser.add_argument("--tile_size", default=256, type=int, help="size of square patch (instance) in pixels")
+
+    parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
+    parser.add_argument(
+        "--validate",
+        action="store_true",
+        help="run only inference on the validation set, must specify the checkpoint argument",
+    )
+
+    parser.add_argument("--logdir", default=None, help="path to log directory to store Tensorboard logs")
+
+    parser.add_argument("--epochs", default=50, type=int, help="number of training epochs")
+    parser.add_argument("--batch_size", default=4, type=int, help="batch size, the number of WSI images per gpu")
+    parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
+
+    parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay")
+    parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
+    parser.add_argument(
+        "--val_every",
+        default=1,
+        type=int,
+        help="run validation after this number of epochs, default 1 to run every epoch",
+    )
+    parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
+
+    ###for multigpu
+    parser.add_argument("--distributed", action="store_true", help="use multigpu training, recommended")
+    parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
+    parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
+    parser.add_argument(
+        "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training"
+    )
+    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
+
+    parser.add_argument(
+        "--quick", action="store_true", help="use a small subset of data for debugging"
+    )  # for debugging
+
+    args = parser.parse_args()
+
+    print("Argument values:")
+    for k, v in vars(args).items():
+        print(k, "=>", v)
+    print("-----------------")
+
+    return args
+
+
 if __name__ == "__main__":
-    main()
+
+    args = parse_args()
+
+    if args.dataset_json is None:
+        # download default json datalist
+        resource = "https://drive.google.com/uc?id=1L6PtKBlHHyUgTE4rVhRuOLTQKgD4tBRK"
+        dst = "./datalist_panda_0.json"
+        if not os.path.exists(dst):
+            gdown.download(resource, dst, quiet=False)
+        args.dataset_json = dst
+
+    if args.distributed:
+        ngpus_per_node = torch.cuda.device_count()
+        args.optim_lr = ngpus_per_node * args.optim_lr / 2  # heuristic to scale up learning rate in multigpu setup
+        args.world_size = ngpus_per_node * args.world_size
+
+        print("Multigpu", ngpus_per_node, "rescaled lr", args.optim_lr)
+        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(args,))
+    else:
+        main_worker(0, args)