-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
89 lines (76 loc) · 3.33 KB
/
train.py
File metadata and controls
89 lines (76 loc) · 3.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import json
import argparse
import torch
import dataloaders
import models
import math
from trainer import Trainer
import torch.nn.functional as F
from utils.losses import CE_loss, Alg_loss
# Cap the number of CPU threads PyTorch uses for intra-op parallelism.
torch.set_num_threads(4)


def main(config, method, resume, gpu, percent):
    """Configure and launch a training run from a parsed config dict.

    Mutates ``config`` in place:
    - stores ``method`` and ``percent``;
    - replaces the literal ``'method'`` / ``'percent'`` placeholders in
      ``experim_name`` and the ``'backbone'`` placeholder in
      ``trainer.save_dir``;
    - sets ``percnt_lbl`` on both training loader configs.

    It then builds the supervised, unsupervised and validation loaders, the
    model and a ``Trainer``, and calls ``trainer.train()``.

    Args:
        config: Parsed JSON configuration. Must contain ``model.backbone``,
            ``experim_name``, ``trainer.save_dir``, ``train_supervised``,
            ``train_unsupervised`` and ``val_loader``.
        method: Method name; stored in ``config['model']['method']`` and
            substituted for ``'method'`` in ``experim_name``.
        resume: Checkpoint path forwarded to ``Trainer`` (may be ``None``).
        gpu: String written to ``CUDA_VISIBLE_DEVICES`` (e.g. ``'0'`` or
            ``'0,1'``).
        percent: Percentage of labelled training data; stored in ``config``.

    Raises:
        ValueError: If ``config['model']['backbone']`` is not supported.
    """
    # Only takes effect if set before CUDA is first initialised.
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    torch.manual_seed(42)

    config['model']['method'] = method
    backbone = config['model']['backbone']
    # Fail fast, before any data is loaded. Otherwise an unknown backbone
    # would surface much later as an UnboundLocalError on `model`.
    supported_backbones = ('ResNet50',)
    if backbone not in supported_backbones:
        raise ValueError(
            f"Unsupported backbone {backbone!r}; expected one of {supported_backbones}")

    config['percent'] = percent
    # Fill the literal placeholder words in the templated config strings.
    config['experim_name'] = config['experim_name'].replace('method', method)
    config['experim_name'] = config['experim_name'].replace('percent', str(config['percent']))
    config['trainer']['save_dir'] = config['trainer']['save_dir'].replace('backbone', backbone)
    print(config)

    # DATA LOADERS
    config['train_supervised']['percnt_lbl'] = config["percent"]
    config['train_unsupervised']['percnt_lbl'] = config["percent"]
    print('Sup Loader: ')
    supervised_loader = dataloaders.CDDataset(config['train_supervised'])
    print('Unsup Loader: ')
    unsupervised_loader = dataloaders.CDDataset(config['train_unsupervised'])
    print('supervised: ', len(supervised_loader))
    print('unsupervised: ', len(unsupervised_loader))
    val_loader = dataloaders.CDDataset(config['val_loader'])
    # The iteration count per epoch is taken from the unsupervised loader's length.
    iter_per_epoch = len(unsupervised_loader)

    # MODEL
    if backbone == 'ResNet50':
        model = models.FPA_ResNet50_CD(num_classes=val_loader.dataset.num_classes,
                                       conf=config['model'],
                                       len_unsper=len(unsupervised_loader))
    print(f'\n{model}\n')

    # TRAINING
    trainer = Trainer(
        model=model,
        resume=resume,
        config=config,
        supervised_loader=supervised_loader,
        unsupervised_loader=unsupervised_loader,
        val_loader=val_loader,
        iter_per_epoch=iter_per_epoch)
    trainer.train()
if __name__=='__main__':
# PARSE THE ARGS
parser = argparse.ArgumentParser(description='PyTorch Training')
parser.add_argument('-c', '--config', default='configs/config_dataset.json',type=str,
help='Path to the config file')
parser.add_argument('-d', '--dataset', default='LEVIR',type=str,
help='Path to the config file')
parser.add_argument('-m', '--method', default='Base3+RA+RRP',type=str,
help='test method')
parser.add_argument('-r', '--resume', default=None, type=str,
help='Path to the .pth model checkpoint to resume training')
parser.add_argument('-g', '--gpu', default=1, type=int,
help='indices of GPUs to enable (default: all)')
parser.add_argument('-p', '--percent', default=5, type=int,
help='percent of the labeled training set')
args = parser.parse_args()
torch.backends.cudnn.benchmark = True
gpu = str(args.gpu)
method = args.method
percent = args.percent
config = json.load(open(args.config.replace('dataset', args.dataset)))
if args.resume != None:
args.resume = args.resume.replace('method', args.method)\
.replace('percent', str(args.percent))\
.replace('dataset', args.dataset)
main(config, method, args.resume, gpu, percent)