-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsampling.py
More file actions
117 lines (99 loc) · 3.97 KB
/
sampling.py
File metadata and controls
117 lines (99 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import argparse
import os
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch.distributions import MultivariateNormal
from torch.nn import DataParallel, Module
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image

import Networkv1
import Networkv2
from CustomDataset import TANOCIv2_Dataset
# Number of z batches pushed through the style mapper to estimate the mean style vector.
COMPUTE_MEAN_ITER = 10000


def _parse_args():
    """Parse the sampling script's command-line options.

    ``--interpolate_num`` and the ``--disc_*`` options are not used by this
    script; they are still accepted so existing command lines keep working.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gen_num', type=int, required=True)
    parser.add_argument('--load_index', type=int, required=True)
    parser.add_argument('--load_path', type=str, default='./result')
    parser.add_argument('--save_path', type=str, default='./sample/')
    parser.add_argument('--psi', type=float, default=0.5)
    parser.add_argument('--version', type=int, default=2)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--img_level', type=int, default=7)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--z_cov', type=float, default=1.0)
    parser.add_argument('--interpolate_num', type=int, default=8)
    parser.add_argument('--style_size', type=int, default=128)
    parser.add_argument('--z_size', type=int, default=128)
    parser.add_argument('--texture_size', type=int, default=4)
    parser.add_argument('--gen_channel', type=int, default=400)
    parser.add_argument('--gen_nonlocal_loc', type=int, default=-1)
    parser.add_argument('--disc_first_channel', type=int, default=16)
    parser.add_argument('--disc_last_size', type=int, default=4)
    parser.add_argument('--disc_nonlocal_loc', type=int, default=2)
    return parser.parse_args()


def _validate_device(device):
    """Raise unless *device* is 'cpu' or a string containing 'cuda'.

    'multi' is rejected explicitly because sampling runs on a single device.
    """
    if device == 'multi':
        raise Exception('multi gpu not allowed in sampling!')
    if device != 'cpu' and 'cuda' not in device:
        raise Exception('invalid argument in device (sampling)')


def _load_models(args, device):
    """Build the style mapper and generator, load their weights, and prepare them for inference.

    Weights are read from ``<load_path>_weight/<load_index>S.pt`` and
    ``<load_path>_weight/<load_index>G.pt``.

    Returns:
        ``(S, G)``: the style mapper and the generator selected by ``--version``.

    Raises:
        Exception: if ``--version`` is neither 1 nor 2.
    """
    img_size = 2 ** args.img_level
    S = Networkv1.StyleMapper(args.z_size, args.style_size, 0, device)
    if args.version == 1:
        G = Networkv1.Generator(args.gen_channel, args.texture_size, args.style_size,
                                args.gen_nonlocal_loc, 0, img_size, device)
    elif args.version == 2:
        G = Networkv2.Generator(args.gen_channel, args.texture_size, args.style_size,
                                args.gen_nonlocal_loc, 0, img_size, device)
    else:
        raise Exception('invalid version')

    print('loading', args.load_index, 'models...')
    weight_prefix = args.load_path + '_weight/' + str(args.load_index)
    # map_location lets a checkpoint saved on one GPU be loaded on the CPU or on another GPU.
    S.load_state_dict(torch.load(weight_prefix + 'S.pt', map_location=device))
    G.load_state_dict(torch.load(weight_prefix + 'G.pt', map_location=device))
    # .to() is a no-op if the constructors already placed parameters on `device`;
    # .eval() turns off train-only behaviour (e.g. dropout / batch-norm updates) if present.
    S.to(device).eval()
    G.to(device).eval()
    return S, G


def _compute_style_mean(mapper, dist, device, n_iter=COMPUTE_MEAN_ITER):
    """Estimate the mean output of *mapper* over ``n_iter`` batches drawn from *dist*.

    Assumes the mapper output is (batch, style_size) — TODO confirm; the result
    is the mean over batch and iterations, i.e. a (style_size,) vector.

    Raises:
        ValueError: if ``n_iter`` is not positive.
    """
    if n_iter <= 0:
        raise ValueError('n_iter must be positive')
    total = None
    # no_grad: without it every forward pass is chained into one autograd graph
    # through the running sum, so memory grows with every iteration.
    with torch.no_grad():
        for _ in range(n_iter):
            w = mapper(dist.sample().to(device))
            total = w if total is None else total + w
    return torch.mean(total, dim=0) / n_iter


def main():
    """Sample ``--gen_num`` image grids from a trained model and save them as JPEGs.

    Each grid is generated from a fresh batch of ``z ~ N(0, z_cov * I)``. The
    batch is mapped to style vectors, which are pulled toward the estimated mean
    style by factor ``psi`` before being fed to the generator.
    """
    args = _parse_args()
    if args.load_index <= 0:
        raise Exception('You have to specify load_index')
    device = args.device
    _validate_device(device)

    S, G = _load_models(args, device)

    # Samples are drawn on the CPU as a (batch_size, z_size) batch and then moved
    # to `device` (not `.cuda()`, which would ignore an explicit GPU index).
    dist = MultivariateNormal(loc=torch.zeros(args.batch_size, args.z_size),
                              covariance_matrix=args.z_cov * torch.eye(args.z_size))
    style_mean = _compute_style_mean(S, dist, device)

    # os.path.join keeps the default './sample/' output identical while also
    # handling a --save_path given without a trailing slash.
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)
    with torch.no_grad():
        for e in range(args.gen_num):
            w = S(dist.sample().to(device))
            # psi = 1 leaves the style unchanged; psi = 0 collapses it to the mean.
            w = style_mean + args.psi * (w - style_mean)
            img = G(w)
            img_save_path = os.path.join(args.save_path, str(e) + '.jpg')
            save_image(make_grid(img), img_save_path, normalize=True)


if __name__ == '__main__':
    main()