-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_lowerbounds.py
More file actions
50 lines (32 loc) · 1.89 KB
/
plot_lowerbounds.py
File metadata and controls
50 lines (32 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from src.utils import *
import matplotlib.pyplot as plt
def ploting_lines(image, sticker_size: int, sticker_norm: float, hparams: dict) -> None:
    """Plot lower-bound curves for a square sticker perturbation and save them.

    A perturbation with the shape of ``image`` is built: it is constant on the
    top-left ``sticker_size`` x ``sticker_size`` patch (over the whole first
    axis), zero elsewhere, and rescaled so that its norm equals
    ``sticker_norm``. The curves computed by the project helpers are drawn on
    a grid of 10001 evenly spaced points in [0, 1], together with the identity
    line and a dashed horizontal line at 0.5. The figure is written to
    ``lower_bounds_patch.png`` in the current working directory and shown.

    Args:
        image: Object whose ``.shape`` defines the perturbation shape. At
            least three axes are required because axes 1 and 2 are indexed
            (presumably a (channels, height, width) tensor -- TODO confirm).
        sticker_size: Side length of the square patch. A value of 0 gives a
            zero-norm perturbation, so the rescaling would yield NaNs.
        sticker_norm: Norm the perturbation is rescaled to.
        hparams: Configuration mapping; ``hparams['smoothing_config']`` must
            provide ``'window_size'``, ``'std'`` and ``'noise_type'``.

    Note:
        Only the noise types ``"gaussian"`` and ``"uniform"`` add
        method-specific curves; for any other value just the identity and
        0.5 reference lines are drawn.
    """
    smoothing_cfg = hparams['smoothing_config']
    window_size = smoothing_cfg['window_size']
    sigma = smoothing_cfg['std']

    # Indicator of the top-left patch, rescaled to the requested norm.
    delta = torch.zeros(image.shape)
    delta[:, 0:sticker_size, 0:sticker_size] = 1
    delta = sticker_norm / torch.linalg.norm(delta) * delta

    # Perturbation at one patch pixel, summed over the first axis.
    delta_value = delta[:, 0, 0].sum().item()

    # (window_size + sticker_size - 1)^2 relative to the product of axes 1 and
    # 2, capped at 1 (presumably the share of window positions that overlap
    # the sticker -- TODO confirm against generate_line_worstcases).
    delta_block = min(
        1,
        (window_size + sticker_size - 1) ** 2 / (delta.shape[1] * delta.shape[2]),
    )

    x_values = [i / 10000 for i in range(10001)]
    worst_case, old_worst_case = generate_line_worstcases(
        delta, delta_block, window_size, sigma
    )

    noise_type = smoothing_cfg['noise_type']
    if noise_type == "gaussian":
        values_cohen = generate_line_cohen(delta, sigma)
        plt.plot(x_values, worst_case,
                 label=r'$p_{\tilde{X},y}$ ours, worst case')
        plt.plot(x_values, values_cohen,
                 label=r'$p_{\tilde{X},y}$ Randomized Smoothing')
    elif noise_type == "uniform":
        values_gamma = generate_line_gamma(
            delta, delta_value, window_size,
            sticker_size=sticker_size, gamma=sigma,
        )
        # Raw strings keep the LaTeX backslashes without relying on invalid
        # escape sequences.
        plt.plot(x_values, old_worst_case, label=r'$p_{x,y} - \Delta$')
        plt.plot(x_values, values_gamma, label='ours, uniform noise',
                 color='red')

    # Reference lines: the unperturbed probability and the 0.5 threshold.
    plt.plot(x_values, x_values, label='$p_{x,y}$')
    plt.axhline(y=0.5, color='r', linestyle='--', label='0.5')
    plt.xlim(0, 1.1)
    plt.legend(title=rf'$\sigma={sigma},$ $\|\delta\|_2={sticker_norm}$')
    plt.savefig('lower_bounds_patch.png')
    plt.show()