-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
39 lines (26 loc) · 1.22 KB
/
train.py
File metadata and controls
39 lines (26 loc) · 1.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import tools.helpers as helpers
from tools.custom_datasets import name_to_dataset, make_dataset
from robustness import train
import cox.store
def main():
    """Run one training / transfer-learning job from the runtime inputs.

    Obtains the run configuration ``var_dict`` from
    ``helpers.get_runtime_inputs()``, calls ``helpers.set_seeds``, builds the
    dataset and model (optionally re-initialising/freezing blocks when
    transferring), creates the train/test loaders and the output store, and
    hands everything to ``robustness.train.train_model``.
    """
    var_dict = helpers.get_runtime_inputs()
    helpers.set_seeds(var_dict)

    # Any source_eps > -1 means we start from a source model and transfer;
    # -1 presumably is the "no transfer / train from scratch" sentinel.
    is_transfer = var_dict['source_eps'] > -1
    # source_eps == 0 selects a pretrained model (per the original author's
    # comment, an ImageNet-pretrained one). Since 0 > -1, this implies
    # is_transfer is also True.
    pretrained = var_dict['source_eps'] == 0

    dataset = make_dataset(var_dict)
    model = helpers.load_model(var_dict, is_transfer, pretrained, dataset)
    model = helpers.change_linear_layer_out_features(
        model, var_dict, dataset, is_transfer)
    if is_transfer:
        model = helpers.re_init_and_freeze_blocks(model, var_dict)

    # num_training_images == -1 is mapped to subset=None; presumably this
    # means "use every training image" -- TODO confirm against make_loaders.
    subset = var_dict['num_training_images']
    if subset == -1:
        subset = None
    train_loader, test_loader = dataset.make_loaders(
        workers=var_dict['num_workers'],
        batch_size=var_dict['batch_size'],
        subset=subset,
        subset_seed=var_dict['seed'],
    )

    out_store = helpers.make_out_store(var_dict)
    train_args = helpers.make_train_args(var_dict)
    helpers.print_details(model, var_dict, train_args)
    train.train_model(train_args, model, (train_loader, test_loader),
                      store=out_store)


# The guard keeps importing this module side-effect free. It is also required
# when loader worker processes are started with the "spawn" method (Windows,
# macOS default): each worker re-imports the main module, and without the
# guard it would re-run the whole training script.
if __name__ == "__main__":
    main()