forked from AnandK27/introstyle
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path playground.py
More file actions
52 lines (38 loc) · 2.12 KB
/
playground.py
File metadata and controls
52 lines (38 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# from introstyle import IntroStyleModel
# from PIL import Image
# model = IntroStyleModel(device='mps').eval() # or 'cpu'
# img = Image.open('/Users/traopia/Downloads/ex.jpeg').convert('RGB')
# x = model.preprocess(img).unsqueeze(0)
# # 3rd skip group with moments (default behavior via args below)
# style_vec = model(x, t=25, use_skip=True, skip_ft_index=2, return_moments=True)
# print(style_vec.shape)
# style_vec = model(x, t=25, use_skip=False, up_ft_index=1, return_moments=True)
# print(style_vec.shape) # [1, 2C]
import torch
from introstyle import IntroStyleModel
from PIL import Image
# Prefer Apple's Metal (MPS) backend when this PyTorch build and machine
# support it; otherwise fall back to the CPU so the script still runs.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
model = IntroStyleModel(device=device).eval()
def load(img_path):
    """Load an image file and return it as a batched input for ``model``.

    The image is converted to RGB, passed through ``model.preprocess`` (the
    module-level ``IntroStyleModel`` instance), and given a leading batch
    dimension of size 1 via ``unsqueeze(0)``.

    Args:
        img_path: Path to the image file.

    Returns:
        The preprocessed tensor with a batch dimension prepended
        (presumably ``[1, C, H, W]`` -- TODO confirm against
        ``IntroStyleModel.preprocess``).
    """
    # Open the file as a context manager so its handle is closed
    # deterministically; ``convert`` returns a new, fully loaded image that
    # stays valid after the source file is closed.
    with Image.open(img_path) as img:
        rgb = img.convert('RGB')
    # NOTE(review): nothing here moves the tensor to `device`; confirm the
    # model handles device placement of its inputs.
    return model.preprocess(rgb).unsqueeze(0)
# NOTE(review): hard-coded local paths; point these at images on your machine.
xA = load('/Users/traopia/Documents/GitHub/Reproduction-of-ArtSAGENet/wikiart/Abstract_Expressionism/aaron-siskind_acolman-1-1955.jpg')
xB = load('/Users/traopia/Documents/GitHub/Reproduction-of-ArtSAGENet/wikiart/Art_Nouveau_Modern/a.y.-jackson_hills-at-great-bear-lake-1953.jpg')
# Feature extraction only: disable autograd so each forward pass does not keep
# its computation graph (and intermediate activations) alive alongside the
# returned style vectors.
# NOTE(review): if the model injects noise at timestep t, results may differ
# between runs unless a seed is fixed -- confirm.
with torch.no_grad():
    # 3rd skip group (index 2), mean/var style vectors
    sA_skip = model(xA, t=25, use_skip=True, skip_ft_index=2, return_moments=True)  # [1, 1280]
    sB_skip = model(xB, t=25, use_skip=True, skip_ft_index=2, return_moments=True)
    # 2nd up block (index 1), mean/var style vectors
    sA_up = model(xA, t=25, use_skip=False, up_ft_index=1, return_moments=True)  # [1, 2560]
    sB_up = model(xB, t=25, use_skip=False, up_ft_index=1, return_moments=True)
def w2_diag(style_vec1: torch.Tensor, style_vec2: torch.Tensor, eps: float = 1e-8, squared: bool = True) -> torch.Tensor:
    """Return the 2-Wasserstein distance between two diagonal Gaussians.

    Each style vector is the concatenation ``[mu | var]`` along its last
    dimension: the first half holds per-channel means and the second half
    per-channel variances. For Gaussians with diagonal covariances the
    squared W2 distance has the closed form

        W2^2 = sum((mu1 - mu2)^2) + sum((sigma1 - sigma2)^2),  sigma = sqrt(var)

    Args:
        style_vec1: Tensor of shape ``(..., 2C)``.
        style_vec2: Tensor of shape ``(..., 2C)``; leading (batch) dimensions
            must broadcast against those of ``style_vec1``.
        eps: Added to the clamped variances before the square root, keeping
            the square root away from zero (where its derivative is infinite).
        squared: If True (default) return W2^2, otherwise W2.

    Returns:
        Tensor with the broadcast leading dimensions (the last dimension is
        summed out), i.e. one distance per batch row.

    Raises:
        ValueError: If the last dimensions of the two vectors differ or are
            odd. Either case would misalign the ``[mu | var]`` split and could
            silently broadcast into a wrong result.
    """
    dim1 = style_vec1.shape[-1]
    dim2 = style_vec2.shape[-1]
    if dim1 != dim2:
        raise ValueError(
            f"style vectors must have the same last dimension, got {dim1} and {dim2}"
        )
    if dim1 % 2 != 0:
        raise ValueError(
            f"style vector last dimension must be even ([mu | var]), got {dim1}"
        )
    c = dim1 // 2
    mu1, var1 = style_vec1[..., :c], style_vec1[..., c:]
    mu2, var2 = style_vec2[..., :c], style_vec2[..., c:]
    # Clamp negative variances (presumably only from numerical round-off) so
    # the square root stays real.
    std1 = (var1.clamp_min(0.0) + eps).sqrt()
    std2 = (var2.clamp_min(0.0) + eps).sqrt()
    w2_sq = ((mu1 - mu2) ** 2 + (std1 - std2) ** 2).sum(dim=-1)  # one value per batch row
    return w2_sq if squared else w2_sq.sqrt()
# Score the image pair under each feature choice (skip group first, then up
# block). `.item()` requires each result to be a single-element tensor, which
# holds here because each input batch has one row.
w2_skip, w2_up = (
    w2_diag(style_a, style_b)
    for style_a, style_b in ((sA_skip, sB_skip), (sA_up, sB_up))
)
print('W2 (skip):', w2_skip.item(), 'W2 (up):', w2_up.item())