forked from babbu3682/SMART-Net
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_model.py
More file actions
53 lines (34 loc) · 1.41 KB
/
create_model.py
File metadata and controls
53 lines (34 loc) · 1.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from arch.smart_net import *
# Create Model
def create_model(stream, name):
    """Instantiate the SMART-Net model selected by ``stream`` and ``name``.

    Args:
        stream: Which group of models to choose from; ``'Upstream'`` or
            ``'Downstream'``.
        name: Name of the model to build. It must be one of the names
            registered for ``stream`` in the table below.

    Returns:
        A new instance of the requested model class, constructed with no
        arguments.

    Raises:
        KeyError: If ``stream`` is not a known stream, or if ``name`` is not
            registered for that stream. The stream is validated first.

    Side effects:
        Prints the total element count of the model's parameters that have
        ``requires_grad`` set.
    """
    # Each entry is a zero-argument factory rather than the class itself, so
    # a class name is only resolved when its entry is actually selected.
    factories = {
        'Upstream': {
            'Up_SMART_Net': lambda: Up_SMART_Net(),
            # Dual
            'Up_SMART_Net_Dual_CLS_SEG': lambda: Up_SMART_Net_Dual_CLS_SEG(),
            'Up_SMART_Net_Dual_CLS_REC': lambda: Up_SMART_Net_Dual_CLS_REC(),
            'Up_SMART_Net_Dual_SEG_REC': lambda: Up_SMART_Net_Dual_SEG_REC(),
            # Single
            'Up_SMART_Net_Single_CLS': lambda: Up_SMART_Net_Single_CLS(),
            'Up_SMART_Net_Single_SEG': lambda: Up_SMART_Net_Single_SEG(),
            'Up_SMART_Net_Single_REC': lambda: Up_SMART_Net_Single_REC(),
        },
        'Downstream': {
            'Down_SMART_Net_CLS': lambda: Down_SMART_Net_CLS(),
            'Down_SMART_Net_SEG': lambda: Down_SMART_Net_SEG(),
        },
    }

    # TypeError is caught as well so that unhashable arguments are reported
    # the same way as unknown names instead of leaking a lookup error.
    try:
        stream_factories = factories[stream]
    except (KeyError, TypeError):
        raise KeyError("Wrong stream name `{}`".format(stream)) from None

    try:
        build = stream_factories[name]
    except (KeyError, TypeError):
        raise KeyError("Wrong model name `{}`".format(name)) from None

    model = build()

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Number of Learnable Params:', n_parameters)
    return model