-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathsetupmodel.py
More file actions
executable file
·115 lines (109 loc) · 3.2 KB
/
setupmodel.py
File metadata and controls
executable file
·115 lines (109 loc) · 3.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gc
from tensorflow.keras.optimizers.legacy import Adam
#from tensorflow.keras.optimizers import Adam
from model import Deterministic, WGANGP, VAE, generator, discriminator
# Architectures accepted for each mode. Every entry currently resolves to the
# same `generator` / `discriminator` builders; `arch` is forwarded to them.
# "det" deliberately lists only the architectures the original accepted.
_SUPPORTED_ARCHS = {
    "GAN": ("normal", "forceconv", "forceconv-long"),
    "VAEGAN": ("normal", "forceconv", "forceconv-long"),
    "det": ("normal", "forceconv"),
}


def setup_model(
    *,
    mode=None,
    arch=None,
    downscaling_steps=None,
    input_channels=None,
    constant_fields=None,
    filters_gen=None,
    filters_disc=None,
    noise_channels=None,
    latent_variables=None,
    padding=None,
    kl_weight=None,
    ensemble_size=None,
    CLtype=None,
    content_loss_weight=None,
    lr_disc=None,
    lr_gen=None,
):
    """Build and return the model selected by ``mode`` and ``arch``.

    Parameters
    ----------
    mode : str
        ``"GAN"``: ``generator`` + ``discriminator`` wrapped in ``WGANGP``.
        ``"VAEGAN"``: encoder/decoder from ``generator`` wrapped in ``VAE``,
        plus a ``discriminator``, wrapped in ``WGANGP`` (``kl_weight`` is
        forwarded only in this mode).
        ``"det"``: ``generator`` alone wrapped in ``Deterministic`` with
        ``loss="mse"`` and the legacy Keras ``Adam`` optimizer class.
    arch : str
        Architecture name forwarded to the builders. ``"normal"``,
        ``"forceconv"`` and ``"forceconv-long"`` are accepted for
        ``"GAN"``/``"VAEGAN"``; ``"det"`` accepts ``"normal"`` and
        ``"forceconv"`` only.
    **other keyword arguments
        Forwarded unchanged to ``generator``, ``discriminator``, ``WGANGP``
        or ``Deterministic`` as the selected mode requires; arguments the
        mode does not use are ignored (e.g. ``noise_channels`` outside
        ``"GAN"``, ``latent_variables`` outside ``"VAEGAN"``).

    Returns
    -------
    WGANGP or Deterministic
        The constructed model wrapper.

    Raises
    ------
    ValueError
        If ``mode`` is not one of ``"GAN"``, ``"VAEGAN"``, ``"det"``.
    KeyError
        If ``arch`` is not supported for ``mode`` (same exception type the
        previous dict-lookup implementation raised).
    """
    # Validate before touching any builder, so a bad config fails fast with a
    # clear message instead of an UnboundLocalError at the final return.
    if mode not in _SUPPORTED_ARCHS:
        raise ValueError(
            f"Unknown mode {mode!r}; expected one of {list(_SUPPORTED_ARCHS)}"
        )
    allowed_archs = _SUPPORTED_ARCHS[mode]
    if arch not in allowed_archs:
        raise KeyError(
            f"Unknown arch {arch!r} for mode {mode!r}; "
            f"expected one of {list(allowed_archs)}"
        )

    # Keyword arguments passed identically to both generator and discriminator.
    shared = {
        "arch": arch,
        "downscaling_steps": downscaling_steps,
        "input_channels": input_channels,
        "constant_fields": constant_fields,
        "padding": padding,
    }
    # Keyword arguments passed to WGANGP in both adversarial modes.
    wgan_kwargs = {
        "lr_disc": lr_disc,
        "lr_gen": lr_gen,
        "ensemble_size": ensemble_size,
        "CLtype": CLtype,
        "content_loss_weight": content_loss_weight,
    }

    # Construction order (generator, then discriminator, then VAE wrapper) is
    # kept as before: Keras auto-generated layer names depend on it.
    if mode == "GAN":
        gen = generator(
            mode=mode,
            filters_gen=filters_gen,
            noise_channels=noise_channels,
            **shared,
        )
        disc = discriminator(filters_disc=filters_disc, **shared)
        model = WGANGP(gen, disc, mode, **wgan_kwargs)
    elif mode == "VAEGAN":
        # In this mode the generator builder returns an (encoder, decoder) pair.
        encoder, decoder = generator(
            mode=mode,
            filters_gen=filters_gen,
            latent_variables=latent_variables,
            **shared,
        )
        disc = discriminator(filters_disc=filters_disc, **shared)
        gen = VAE(encoder, decoder)
        model = WGANGP(gen, disc, mode, kl_weight=kl_weight, **wgan_kwargs)
    else:  # mode == "det" (guaranteed by the validation above)
        gen = generator(mode=mode, filters_gen=filters_gen, **shared)
        model = Deterministic(gen, lr=lr_gen, loss="mse", optimizer=Adam)

    # NOTE(review): presumably reclaims memory from temporaries created while
    # building the networks; confirm it is still needed.
    gc.collect()
    return model