-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
104 lines (84 loc) · 3.14 KB
/
main.py
File metadata and controls
104 lines (84 loc) · 3.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import utils
import os
import logging
from datasets import BasicDataset, BasicDataset_dyn
import time
from torch.utils.tensorboard import SummaryWriter
join = os.path.join
args = utils.arg_parser()
# get_experiment_name also returns a checkpoint path, which replaces the
# parsed value on `args` for the rest of the run.
experiment_name, args.checkpoint_path = utils.get_experiment_name(args)
log_dir = join(args.checkpoint_path, experiment_name, "logs")
os.makedirs(log_dir, exist_ok=True)
################ LOGGER SETUP ################
# Never overwrite a previous run's log: if train.log is taken, use the first
# free train_<k>.log (k = 1, 2, ...).
log_path = join(log_dir, "train.log")
if os.path.exists(log_path):
    print("Log file already exists, creating a new one with a different name.")
    suffix = 1
    while os.path.exists(join(log_dir, f"train_{suffix}.log")):
        suffix += 1
    log_path = join(log_dir, f"train_{suffix}.log")
else:
    print("Log file does not exist, creating a new one.")
utils.setup_logging(log_path)
utils.print_args(args)
# Seeding is skipped when args.seed is None or not strictly positive.
if args.seed is not None and args.seed > 0:
    utils.set_random_seed(args.seed)
    logging.info("Set random seed to %s", args.seed)
#################### dataset setup ####################
logging.info("Setting up datasets...")
# Order in which the datasets are handed to the training strategy, selected
# by args.sequence.
_DATASET_SEQUENCES = {
    0: ["BRATS", "ATLAS", "MSSEG", "ISLES", "WMH"],
    1: ["MSSEG", "BRATS", "ISLES", "WMH", "ATLAS"],
}
if args.sequence not in _DATASET_SEQUENCES:
    raise ValueError(
        f"Sequence should be either 0 or 1, got {args.sequence!r}"
    )
# Copy so later code can never mutate the shared table entry.
dataset_names = list(_DATASET_SEQUENCES[args.sequence])
train_datasets = []
test_datasets = []
DatasetCls = BasicDataset_dyn if args.dynamic_modalities else BasicDataset


def _build_split(name, *, is_test, prev_modalities_set):
    """Instantiate the train or test split of dataset *name*.

    Args:
        name: dataset identifier passed to ``utils.build_dataset_kwargs``.
        is_test: build the test split when True, the train split otherwise.
        prev_modalities_set: modalities set reported by the previous dataset
            of the same split (``None`` for the first one).

    Returns:
        ``(dataset, next_prev)``. ``next_prev`` is the value to pass as
        *prev_modalities_set* for the next dataset of this split:
        - *prev_modalities_set* unchanged if the dataset has no
          ``get_current_modalities_set`` method;
        - ``None`` if that method raises (the failure is logged with its
          traceback, and construction continues best-effort);
        - otherwise the method's return value.
    """
    kwargs = utils.build_dataset_kwargs(
        name,
        args=args,
        is_test=is_test,
        prev_modalities_set=prev_modalities_set,
    )
    dataset = utils.instantiate_dataset(DatasetCls, kwargs)
    if not hasattr(dataset, "get_current_modalities_set"):
        return dataset, prev_modalities_set
    try:
        return dataset, dataset.get_current_modalities_set()
    except Exception:
        # Stay resilient, but don't hide the problem.
        logging.warning(
            "get_current_modalities_set() failed for %s (%s split); "
            "resetting previous modalities set to None",
            name,
            "test" if is_test else "train",
            exc_info=True,
        )
        return dataset, None


prev_train_set = None
prev_test_set = None
start_time = time.perf_counter()
for name in dataset_names:
    # Train and test splits each carry their own modalities chain.
    train_ds, prev_train_set = _build_split(
        name, is_test=False, prev_modalities_set=prev_train_set
    )
    train_datasets.append(train_ds)
    test_ds, prev_test_set = _build_split(
        name, is_test=True, prev_modalities_set=prev_test_set
    )
    test_datasets.append(test_ds)
logging.info(
    "Datasets loaded in %.2f seconds !!!!!!!!!!!!",
    time.perf_counter() - start_time,
)
writer = SummaryWriter(join(args.checkpoint_path, experiment_name, "tensorboard"))
##################### strategy setup #####################
try:
    strategy = utils.get_strategy(
        args, train_datasets, test_datasets, writer, experiment_name
    )
    strategy.start_training()
finally:
    # Flush and release the TensorBoard event file even if training raises or
    # is interrupted. SummaryWriter.close() ignores a second call, so this is
    # safe even if the strategy closes the writer itself.
    writer.close()