-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathevaluate.py
More file actions
107 lines (93 loc) · 4.38 KB
/
evaluate.py
File metadata and controls
107 lines (93 loc) · 4.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import mean_squared_error as compare_mse
from skimage import io
from torchvision.transforms import ToTensor
import numpy as np
from glob import glob
import lpips
from tqdm import tqdm
import warnings
# Suppress every Python warning for the rest of the process. This keeps the
# metric output readable, but it also hides numerical RuntimeWarnings (for
# example a division by zero inside a PSNR computation).
warnings.filterwarnings("ignore")
def compare_lpips(img1, img2, loss_fn_alex):
    """Return the LPIPS distance between two images as a scalar.

    Each image is converted with torchvision's ``ToTensor``, given a leading
    batch dimension, moved to the CPU and handed to ``loss_fn_alex``. The
    single value at index ``[0, 0, 0, 0]`` of the model output is returned as
    a NumPy scalar.

    Args:
        img1: First image, in any form accepted by ``ToTensor``.
        img2: Second image, same form as ``img1``.
        loss_fn_alex: Callable LPIPS model; the caller in this file passes
            ``lpips.LPIPS(net='alex')``.

    NOTE(review): ``ToTensor`` presumably scales uint8 input to [0, 1], while
    LPIPS is documented to expect [-1, 1] unless ``normalize=True`` is passed.
    Confirm whether the current scaling is intended before changing it, since
    it affects the reported numbers.
    """
    as_tensor = ToTensor()
    first_batch, second_batch = (
        as_tensor(image).unsqueeze(0) for image in (img1, img2)
    )
    distance = loss_fn_alex(first_batch.cpu(), second_batch.cpu())
    # The model output is indexed as a 4-D array holding one value.
    return distance.cpu().detach().numpy()[0, 0, 0, 0]
def compare_score(img1, img2, img_seg):
    """Return the region-wise PSNR values (G-PSNR, S-PSNR, Global-PSNR).

    This module is for the MIPI 2023 Challenge:
    https://codalab.lisn.upsaclay.fr/competitions/9402

    Args:
        img1: Ground-truth image.
        img2: Image under evaluation, same shape as ``img1``.
        img_seg: Segmentation image passed to ``extract_mask`` to obtain the
            glare / streak / global masks.

    Returns:
        dict mapping ``'glare'``, ``'streak'`` and ``'global'`` to the PSNR
        measured inside that mask. A key is omitted when its mask area is not
        positive, so callers must not assume all three are present. The
        averaged challenge score is NOT computed here; the caller derives it.
    """
    # The masks depend only on img_seg, so build them once rather than once
    # per mask type.
    masks = extract_mask(img_seg)

    metric_dict = {}
    for mask_type in ('glare', 'streak', 'global'):
        mask_area, img_mask = masks[mask_type]
        if mask_area > 0:
            img_gt_masked = img1 * img_mask
            img_input_masked = img2 * img_mask
            # compare_mse averages over every pixel of the image; dividing by
            # the mask's area fraction rescales that to the masked region, and
            # 255 * 255 brings the error to a [0, 1] intensity scale
            # (assumes 8-bit image data -- TODO confirm for other inputs).
            input_mse = compare_mse(img_gt_masked, img_input_masked) / (255 * 255 * mask_area)
            input_psnr = 10 * np.log10((1.0 ** 2) / input_mse)
            metric_dict[mask_type] = input_psnr
    return metric_dict
def extract_mask(img_seg):
    """Split a colour-coded segmentation image into glare/streak/global masks.

    Colour coding of ``img_seg`` (channel order R, G, B):
        glare:        [255, 255, 0]
        streak:       [255, 0, 0]
        light source: [0, 0, 255]
        others:       [0, 0, 0]

    Args:
        img_seg: Array of shape (H, W, C) with at least three channels.

    Returns:
        dict with keys ``'glare'``, ``'streak'`` and ``'global'`` (the whole
        image without the light source). Each value is a two-element list
        ``[area, mask]``: ``area`` is the sum of the mask divided by the
        number of pixels H * W, and ``mask`` is the float mask repeated to
        three channels, shape (H, W, 3).
    """
    # Work in float: on uint8 input ``R - G`` would wrap around (e.g.
    # 100 - 120 -> 236) instead of going negative.
    seg = np.asarray(img_seg, dtype=np.float64)
    red, green, blue = seg[:, :, 0], seg[:, :, 1], seg[:, :, 2]

    # Normalise by the actual image size rather than assuming 512 x 512.
    pixel_count = seg.shape[0] * seg.shape[1]

    single_channel_masks = {
        'glare': green / 255,
        # Red without green marks a streak; clip so that pixels with more
        # green than red contribute nothing rather than a negative weight.
        'streak': np.clip(red - green, 0, None) / 255,
        'global': (255 - blue) / 255,
    }

    mask_dict = {}
    for mask_type, mask in single_channel_masks.items():
        mask_dict[mask_type] = [
            np.sum(mask) / pixel_count,                 # area
            np.expand_dims(mask, 2).repeat(3, axis=2),  # 3-channel mask
        ]
    return mask_dict
def calculate_metrics(args):
    """Compute and print PSNR, SSIM and LPIPS (plus masked scores) for a folder pair.

    Files in the ground-truth and input folders are paired by sorted order.
    The mean PSNR / SSIM / LPIPS over all pairs is printed. When a mask folder
    is given, the per-region PSNRs from ``compare_score`` are averaged over
    the images in which each region is present, and their mean is printed as
    the score.

    Args:
        args: dict with keys ``'gt'``, ``'input'`` and ``'mask'``, each a
            folder path. ``'mask'`` may be ``None`` to skip the masked scores.

    Raises:
        AssertionError: if the folders hold different numbers of files, or if
            one of the glare/streak/global regions never occurs in any mask.
        ValueError: if the ground-truth folder contains no files.
    """
    gt_list = sorted(glob(args['gt'] + '/*'))
    input_list = sorted(glob(args['input'] + '/*'))
    use_mask = args['mask'] is not None
    mask_list = sorted(glob(args['mask'] + '/*')) if use_mask else []

    # Raised explicitly (instead of via ``assert``) so the checks still run
    # under ``python -O``; the exception type is kept for existing callers.
    if len(gt_list) != len(input_list):
        raise AssertionError(
            f"Found {len(gt_list)} ground-truth files but {len(input_list)} input files")
    if use_mask and len(mask_list) != len(gt_list):
        raise AssertionError(
            f"Found {len(gt_list)} ground-truth files but {len(mask_list)} mask files")

    n = len(gt_list)
    if n == 0:
        # Without this the averages below fail with a bare ZeroDivisionError.
        raise ValueError(f"No files found in ground-truth folder: {args['gt']}")

    loss_fn_alex = lpips.LPIPS(net='alex').cpu()

    ssim, psnr, lpips_val = 0, 0, 0
    score_dict = {'glare': 0, 'streak': 0, 'global': 0, 'glare_num': 0, 'streak_num': 0, 'global_num': 0}
    mask_paths = mask_list if use_mask else [None] * n
    for gt_path, input_path, mask_path in tqdm(zip(gt_list, input_list, mask_paths), total=n):
        img_gt = io.imread(gt_path)
        img_input = io.imread(input_path)
        ssim += compare_ssim(img_gt, img_input, channel_axis=-1)
        psnr += compare_psnr(img_gt, img_input, data_range=255)
        lpips_val += compare_lpips(img_gt, img_input, loss_fn_alex)
        if use_mask:
            img_seg = io.imread(mask_path)
            # compare_score omits regions absent from this mask, so each
            # region keeps its own image count.
            for key, value in compare_score(img_gt, img_input, img_seg).items():
                score_dict[key] += value
                score_dict[key + '_num'] += 1

    ssim /= n
    psnr /= n
    lpips_val /= n
    print(f"PSNR: {psnr}, SSIM: {ssim}, LPIPS: {lpips_val}")

    if use_mask:
        for key in ['glare', 'streak', 'global']:
            if score_dict[key + '_num'] == 0:
                raise AssertionError("Error, No mask in this type!")
            score_dict[key] /= score_dict[key + '_num']
        score_dict['score'] = 1 / 3 * (score_dict['glare'] + score_dict['global'] + score_dict['streak'])
        print(
            f"Score: {score_dict['score']}, G-PSNR: {score_dict['glare']}, S-PSNR: {score_dict['streak']}, Global-PSNR: {score_dict['global']}")
if __name__ == "__main__":
    # Command-line entry point: each option is a folder path, and the parsed
    # options are handed to calculate_metrics as a plain dict.
    parser = argparse.ArgumentParser()
    for option, default_folder in (
        ('--input', '../checkpoint/test_out1/deflare/'),
        ('--gt', '../checkpoint/val/gt/'),
        ('--mask', '../checkpoint/val/mask/'),
    ):
        parser.add_argument(option, type=str, default=default_folder)
    args = vars(parser.parse_args())
    calculate_metrics(args)