-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathmain_sda.py
More file actions
39 lines (32 loc) · 870 Bytes
/
main_sda.py
File metadata and controls
39 lines (32 loc) · 870 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
Entry-point script for supervised domain adaptation (SDA) on image classification.

Parses options with ``parse_args_sda``, builds the model selected by the
``method`` option (v0, v1, ccsa or dsne) and then trains or tests it depending
on the ``training`` option.
"""
from utils.parse_args import parse_args_sda
from train_val.training_sda import ClsModel, CCSA, dSNE
def main(args=None):
    """
    Build the model selected by ``args.method`` and either train or test it.

    Supported values of ``args.method``:
        v0:   ClsModel trained on the source domain only, tested on target
        v1:   ClsModel trained on both source and target domains
        ccsa: CCSA model (ICCV 2017)
        dsne: dSNE model
    Any other value raises NotImplementedError. This includes 'dsnet' (the
    dSNE-Triplet variant), which this entry point does not wire up.

    :param args: parsed options exposing at least ``method`` and ``training``.
        If omitted, the module-level ``args`` bound by the ``__main__`` guard
        is used when present; otherwise the command line is parsed with
        ``parse_args_sda()``.
    :return: None
    :raises NotImplementedError: if ``args.method`` is not a supported value.
    """
    if args is None:
        # Backward compatibility: the script entry point binds a module-level
        # ``args`` and then calls ``main()`` with no arguments. Reuse that
        # object instead of parsing the command line a second time.
        args = globals().get('args')
        if args is None:
            args = parse_args_sda()

    # Constructors are wrapped in lambdas so only the selected model is built.
    builders = {
        'v0': lambda: ClsModel(args, train_tgt=False),
        'v1': lambda: ClsModel(args, train_tgt=True),
        'ccsa': lambda: CCSA(args),
        'dsne': lambda: dSNE(args),
    }
    build = builders.get(args.method)
    if build is None:
        raise NotImplementedError(
            'Unsupported method: {!r} (expected one of: {})'.format(
                args.method, ', '.join(sorted(builders))))

    model = build()
    if args.training:
        model.train()
    else:
        model.test()
if __name__ == '__main__':
    # Parse the command-line options into the module-level name ``args``;
    # main() reads this module-level ``args``, so it must be bound first.
    args = parse_args_sda()
    main()