-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathsupervised_decoder.py
More file actions
84 lines (66 loc) · 3.08 KB
/
supervised_decoder.py
File metadata and controls
84 lines (66 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
from algos.simba_algo import SimbaAttack
from models.image_decoder import Decoder
from utils.metrics import MetricLoader
class SupervisedDecoder(SimbaAttack):
    """Supervised inversion adversary: trains a Decoder on intermediate
    activations ``z`` to reconstruct either the raw input image
    (``attribute == "data"``) or a protected attribute, logging metrics
    through the shared ``utils.logger``.
    """

    def __init__(self, config, utils):
        super(SupervisedDecoder, self).__init__(utils)
        self.initialize(config)

    def initialize(self, config):
        """Set up metric loader, logger tags, decoder model, optimizer and
        the training loss selected by ``config["loss_fn"]``.
        """
        self.attribute = config["attribute"]
        self.metric = MetricLoader()
        self.img_size = config["img_size"]

        if self.attribute == "data":
            self.loss_tag = "recons_loss"
            # Reconstruction-quality metrics logged during validation.
            self.ssim_tag = "ssim"
            self.utils.logger.register_tag("val/" + self.ssim_tag)
            self.l1_tag = "l1"
            self.utils.logger.register_tag("val/" + self.l1_tag)
            self.l2_tag = "l2"
            self.utils.logger.register_tag("val/" + self.l2_tag)
            self.psnr_tag = "psnr"
            self.utils.logger.register_tag("val/" + self.psnr_tag)
            # Fix: forward() logs "<mode>/<loss_tag>" in both modes, but this
            # branch never registered the loss tag (the else-branch always
            # did). Register both so the data path matches.
            self.utils.logger.register_tag("train/" + self.loss_tag)
            self.utils.logger.register_tag("val/" + self.loss_tag)
        else:
            self.loss_tag = "attribute_loss"
            self.utils.logger.register_tag("train/" + self.loss_tag)
            self.utils.logger.register_tag("val/" + self.loss_tag)

        self.model = Decoder(config)
        self.utils.model_on_gpus(self.model)
        self.utils.register_model("adv_model", self.model)
        self.optim = self.init_optim(config, self.model)

        if config["loss_fn"] == "ssim":
            self.loss_fn = self.metric.ssim
            self.sign = -1  # maximize ssim by minimizing its negation
        elif config["loss_fn"] == "l1":
            self.loss_fn = self.metric.l1
            self.sign = 1  # minimize l1
        elif config["loss_fn"] == "lpips":
            self.loss_fn = self.metric.lpips
            self.sign = 1  # minimize lpips
        else:
            # Fix: an unknown loss_fn previously left self.loss_fn/self.sign
            # unset and crashed later with an opaque AttributeError.
            raise ValueError(
                "unsupported loss_fn: {}".format(config["loss_fn"]))

    def forward(self, items):
        """Decode ``items["z"]``, compute the loss against ``items["x"]``,
        log metrics for the current mode, and return the reconstruction.
        """
        z = items["z"]
        self.reconstruction = self.model(z)
        # Fix: the original line interpolated the undefined name `ys`
        # (NameError on first call) and discarded the result. The evident
        # intent is to resize the ground-truth target to the decoder's
        # output resolution before comparing.
        # NOTE(review): assumes items["x"] may differ from img_size — confirm
        # against the dataloader.
        self.orig = torch.nn.functional.interpolate(
            items["x"], size=(self.img_size, self.img_size),
            mode='bilinear', align_corners=True)
        self.loss = self.loss_fn(self.reconstruction, self.orig)

        if self.mode == "val" and self.attribute == "data":
            prefix = "val/"
            # Log the full set of image-quality metrics on validation.
            for tag, metric_fn in ((self.ssim_tag, self.metric.ssim),
                                   (self.l1_tag, self.metric.l1),
                                   (self.l2_tag, self.metric.l2),
                                   (self.psnr_tag, self.metric.psnr)):
                score = metric_fn(self.reconstruction, self.orig)
                self.utils.logger.add_entry(prefix + tag, score.item())

        self.utils.logger.add_entry(self.mode + "/" + self.loss_tag,
                                    self.loss.item())
        return self.reconstruction

    def backward(self, _):
        """One optimizer step on sign * loss; sign turns metric
        maximization (ssim) into minimization.
        """
        self.optim.zero_grad()
        (self.sign * self.loss).backward()
        self.optim.step()