forked from qsong1012/COBRE_HE
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·119 lines (95 loc) · 3.31 KB
/
train.py
File metadata and controls
executable file
·119 lines (95 loc) · 3.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import pickle
import time
from pathlib import Path
from torch.utils.data import DataLoader
import torch
from data.dataset import get_dataloader
from data.transform import get_transformation
from model.fitter import HybridFitter
from model.helper import compose_logging
from model.loss import FlexLoss
from model.models import HybridModel
from options.train_options import TrainOptions
### Read parameter configuration
"""
Data loading process allows two input flows:
- Table based
> Use commandline args
- List based
> Use yaml config file
TODO: should be merged later. Issues are algorithms are slightly different so
that's why we separate them, for now.
Currently, it checks
args.config is None
to see if the input is
Table-based (Shuai)or List-based (Diana)
"""
# Parse command-line configuration.
opt = TrainOptions()
opt.initialize()
args = opt.parse()

### Channel normalization stats: use the user-supplied mean/std only when
### BOTH are present; otherwise fall back to the generic defaults.
if args.data_stats_mean is None or args.data_stats_std is None:
    data_stats = {'mean': [0.5, 0.5, 0.5], 'std': [0.25, 0.25, 0.25]}
else:
    data_stats = {'mean': args.data_stats_mean, 'std': args.data_stats_std}

# Run relative to this script's directory so relative paths resolve consistently.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
if __name__ == '__main__':
    # Prefer the first CUDA device when available; otherwise run on CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Data transformation built from the (possibly user-supplied) channel stats.
    transform = get_transformation(mean=data_stats['mean'], std=data_stats['std'])
    loader = get_dataloader(args, transform)

    ### Name the logs with the pre-defined tag, or fall back to the current time.
    # NOTE: truthiness instead of len() so a None timestr also falls through
    # to the timestamp rather than raising TypeError.
    if args.timestr:
        TIMESTR = args.timestr
    else:
        TIMESTR = time.strftime("%Y%m%d_%H%M%S")
    writer = compose_logging(TIMESTR)

    # Record every parsed argument so the run is reproducible from the logs.
    for arg, value in sorted(vars(args).items()):
        writer['meta'].info("Argument %s: %r", arg, value)

    # Output size: 1 for regression/survival, pre-defined for classification.
    if args.outcome_type in ['survival', 'regression']:
        num_classes = 1
        class_weights = None
    elif args.outcome_type == 'classification':
        num_classes = args.num_classes
        if args.class_weights is not None:
            class_weights = list(map(float, args.class_weights.split(',')))
        else:
            class_weights = None
    else:
        # Fail fast with a clear message instead of hitting a NameError on
        # num_classes when the model is constructed below.
        raise ValueError(
            "Unsupported outcome_type: %r (expected 'survival', 'regression' "
            "or 'classification')" % args.outcome_type)

    # Initialize model.
    model = HybridModel(
        backbone=args.backbone,
        pretrain=args.pretrain,
        outcome_dim=num_classes,
        outcome_type=args.outcome_type,
        random_seed=args.random_seed,
        dropout=args.dropout,
        device=device
    )
    writer['meta'].info(model)

    # FlexLoss selects the criterion appropriate for the outcome type.
    criterion = FlexLoss(
        outcome_type=args.outcome_type,
        class_weights=class_weights,
        device=device)

    # Initialize trainer.
    hf = HybridFitter(
        model=model,
        writer=writer,
        dataloader=loader,
        checkpoint_to_resume=args.resume,
        timestr=TIMESTR,
        args=args,
        model_name=TIMESTR,
        loss_function=criterion
    )

    # Fit the model or evaluate.
    if args.mode == 'test':
        pass  # will implement later or in a separate script
        # info_str = hf.evaluate(df_test, epoch=0)
        # writer['meta'].info(info_str)
    elif args.mode == 'train':
        hf.fit(checkpoints_folder=Path(args.checkpoint_dir) / TIMESTR)
    else:
        print("This mode has not been implemented!")