From d20382a4cd2c44a3f1087a6ed309e86e9eca65c1 Mon Sep 17 00:00:00 2001 From: LiangbinXie Date: Wed, 8 Jun 2022 11:38:58 +0800 Subject: [PATCH 1/6] add calc_psnr script --- scripts/calc_psnr.py | 57 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 scripts/calc_psnr.py diff --git a/scripts/calc_psnr.py b/scripts/calc_psnr.py new file mode 100644 index 000000000..9701e5fed --- /dev/null +++ b/scripts/calc_psnr.py @@ -0,0 +1,57 @@ +import cv2 +import glob +import logging +import os +import os.path as osp + +from basicsr.metrics import psnr_ssim +from basicsr.utils import get_root_logger, get_time_str + + +def main(): + + sr_folder = 'results/BasicVSRPP' + gt_folder = 'datasets/REDS4/GT' + + # logger + log_file = osp.join(sr_folder, f'psnr_test_{get_time_str()}.log') + logger = get_root_logger(logger_name='bascivsrpp', log_level=logging.INFO, log_file=log_file) + + avg_psnr_l = [] + + subfolder_sr_l = sorted(glob.glob(osp.join(sr_folder, '*'))) + subfolder_gt_l = sorted(glob.glob(osp.join(gt_folder, '*'))) + + # for each subfolder + subfolder_names = [] + for subfolder_sr, subfolder_gt in zip(subfolder_sr_l, subfolder_gt_l): + subfolder_name = osp.basename(subfolder_sr) + subfolder_names.append(subfolder_name) + + avg_psnr = 0 + name_idx = 0 + img_name_list = sorted(os.listdir(subfolder_gt)) + for img_name in img_name_list: + img_basename = os.path.splitext(img_name)[0] + # read SR image and GT image + img_sr = cv2.imread(osp.join(subfolder_sr, f'{img_basename}_BasicVSRPP.png'), cv2.IMREAD_UNCHANGED) + # read GT image + img_gt = cv2.imread(osp.join(subfolder_gt, f'{img_basename}.png'), cv2.IMREAD_UNCHANGED) + crt_psnr = psnr_ssim.calculate_psnr(img_sr, img_gt, crop_border=0, test_y_channel=False) + + avg_psnr += crt_psnr + logger.info(f'{subfolder_name}--{img_name} - PSNR: {crt_psnr:.6f} dB. ') + name_idx += 1 + + avg_psnr /= name_idx + avg_psnr_l.append(avg_psnr) + + for folder_idx, subfolder_name in enumerate(subfolder_names): + logger.info(f'Folder {subfolder_name} - Average PSNR: {avg_psnr_l[folder_idx]:.6f} dB. ') + + logger.info(f'Average PSNR: {sum(avg_psnr_l) / len(avg_psnr_l):.6f} dB ' f'for {len(subfolder_sr_l)} clips. ') + + +if __name__ == '__main__': + + main() From 0735da6be289ff1efad604d0ad94b382d481d060 Mon Sep 17 00:00:00 2001 From: LiangbinXie Date: Wed, 8 Jun 2022 11:42:23 +0800 Subject: [PATCH 2/6] convert official released basicvsrpp model --- scripts/model_conversion/convert_models.py | 46 +++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/scripts/model_conversion/convert_models.py b/scripts/model_conversion/convert_models.py index 46bb085f9..160522d89 100644 --- a/scripts/model_conversion/convert_models.py +++ b/scripts/model_conversion/convert_models.py @@ -357,6 +357,50 @@ def convert_duf_model(): torch.save(crt_net, 'experiments/pretrained_models/DUF_x2_16L_official.pth') +def convert_basicvsrpp_model(): + from basicsr.archs.basicvsrpp_arch import BasicVSRPlusPlus + basicvsrpp = BasicVSRPlusPlus(mid_channels=64, num_blocks=7) + crt_net = basicvsrpp.state_dict() + # for k, v in crt_net.items(): + # print(k) + + # print('=================') + + ori_net = torch.load( + 'experiments/pretrained_models/BasicVSRPP/basicvsr_plusplus_c64n7_8x1_600k_reds4_20210217-db622b2f.pth') + + # for k, v in ori_net['state_dict'].items(): + # print(k) + + for ort_k, _ in ori_net['state_dict'].items(): + if 'generator' in ort_k: + # delete 'generator' + crt_k = ort_k[10:] + + # spynet module + if 'spynet.basic_module' in ort_k: + if 'weight' in ort_k: + number = int(crt_k[-13]) + crt_k = crt_k[:-13] + f'{number * 2}.weight' + elif 'bias' in ort_k: + number = int(crt_k[-11]) + crt_k = crt_k[:-11] + f'{number * 2}.bias' + + # upsample module + if 'upsample1.upsample_conv.weight' in ort_k: + crt_k = 'upconv1.weight' + elif 'upsample1.upsample_conv.bias' in ort_k: + crt_k = 'upconv1.bias' + elif 'upsample2.upsample_conv.weight' in ort_k: + crt_k = 'upconv2.weight' + elif 'upsample2.upsample_conv.bias' in ort_k: + crt_k = 'upconv2.bias' + + crt_net[crt_k] = ori_net['state_dict'][ort_k] + + torch.save(crt_net, 'experiments/pretrained_models/Converted-BasicVSRPP/BasicVSRPP_x4_SR_REDS_official.pth') + + if __name__ == '__main__': # convert EDSR models # ori_net_path = 'path to original model' @@ -364,4 +408,4 @@ def convert_duf_model(): # save_path = 'save path' # convert_edsr(ori_net_path, crt_net_path, save_path, num_block=32) - convert_duf_model() + convert_basicvsrpp_model() From 78de33ea3d8ebe274865320e14184b5836446a55 Mon Sep 17 00:00:00 2001 From: LiangbinXie Date: Wed, 8 Jun 2022 11:46:05 +0800 Subject: [PATCH 3/6] update inference_basicvsrpp --- inference/inference_basicvsrpp.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/inference/inference_basicvsrpp.py b/inference/inference_basicvsrpp.py index b44aaa482..1e9ed2f2a 100644 --- a/inference/inference_basicvsrpp.py +++ b/inference/inference_basicvsrpp.py @@ -7,7 +7,7 @@ from basicsr.archs.basicvsrpp_arch import BasicVSRPlusPlus from basicsr.data.data_util import read_img_seq -from basicsr.utils.img_util import tensor2img +from basicsr.utils import tensor2img def inference(imgs, imgnames, model, save_path): @@ -23,7 +23,10 @@ def inference(imgs, imgnames, model, save_path): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/BasicVSRPP_REDS4.pth') + parser.add_argument( + '--model_path', + type=str, + default='experiments/pretrained_models/Converted-BasicVSRPP/BasicVSRPP_x4_SR_REDS_official.pth') parser.add_argument( '--input_path', type=str, default='datasets/REDS4/sharp_bicubic/000', help='input test image folder') parser.add_argument('--save_path', type=str, default='results/BasicVSRPP/000', help='save image path') @@ -34,7 +37,7 @@ def main(): # set up model model = BasicVSRPlusPlus(mid_channels=64, num_blocks=7) - model.load_state_dict(torch.load(args.model_path)['params'], strict=True) + model.load_state_dict(torch.load(args.model_path), strict=True) model.eval() model = model.to(device) From 852b6c01670b200b52ab1f858c541aabe0a023ea Mon Sep 17 00:00:00 2001 From: LiangbinXie Date: Wed, 8 Jun 2022 11:50:12 +0800 Subject: [PATCH 4/6] update inference_basicvsrpp script --- inference/inference_basicvsrpp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/inference/inference_basicvsrpp.py b/inference/inference_basicvsrpp.py index 1e9ed2f2a..c18155477 100644 --- a/inference/inference_basicvsrpp.py +++ b/inference/inference_basicvsrpp.py @@ -24,9 +24,7 @@ def inference(imgs, imgnames, model, save_path): def main(): parser = argparse.ArgumentParser() parser.add_argument( - '--model_path', - type=str, - default='experiments/pretrained_models/Converted-BasicVSRPP/BasicVSRPP_x4_SR_REDS_official.pth') + '--model_path', type=str, default='experiments/pretrained_models/BasicVSRPP_x4_SR_REDS_official.pth') parser.add_argument( '--input_path', type=str, default='datasets/REDS4/sharp_bicubic/000', help='input test image folder') parser.add_argument('--save_path', type=str, default='results/BasicVSRPP/000', help='save image path') From 453fa26490433347edab271af0afb785b19c4598 Mon Sep 17 00:00:00 2001 From: LiangbinXie Date: Wed, 8 Jun 2022 11:50:46 +0800 Subject: [PATCH 5/6] convert official released basicvsrpp model --- scripts/model_conversion/convert_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/model_conversion/convert_models.py b/scripts/model_conversion/convert_models.py index 160522d89..930616edd 100644 --- a/scripts/model_conversion/convert_models.py +++ b/scripts/model_conversion/convert_models.py @@ -367,7 +367,7 @@ def convert_basicvsrpp_model(): # print('=================') ori_net = torch.load( - 'experiments/pretrained_models/BasicVSRPP/basicvsr_plusplus_c64n7_8x1_600k_reds4_20210217-db622b2f.pth') + 'experiments/pretrained_models/BasicVSRPP/basicvsr_plusplus_c64n7_8x1_300k_vimeo90k_bi_20210305-4ef437e2.pth') # for k, v in ori_net['state_dict'].items(): # print(k) @@ -398,7 +398,7 @@ def convert_basicvsrpp_model(): crt_net[crt_k] = ori_net['state_dict'][ort_k] - torch.save(crt_net, 'experiments/pretrained_models/Converted-BasicVSRPP/BasicVSRPP_x4_SR_REDS_official.pth') + torch.save(crt_net, 'experiments/pretrained_models/Converted-BasicVSRPP/BasicVSRPP_x4_SR_Vimeo90K_BI_official.pth') if __name__ == '__main__': From 72e16e5bf319c7a6cb91d2f1e7b50adc50824d6d Mon Sep 17 00:00:00 2001 From: LiangbinXie Date: Wed, 8 Jun 2022 14:43:13 +0800 Subject: [PATCH 6/6] convert official released basicvsrpp model --- scripts/calc_psnr.py | 57 -------------------------------------------- 1 file changed, 57 deletions(-) delete mode 100644 scripts/calc_psnr.py diff --git a/scripts/calc_psnr.py b/scripts/calc_psnr.py deleted file mode 100644 index 9701e5fed..000000000 --- a/scripts/calc_psnr.py +++ /dev/null @@ -1,57 +0,0 @@ -import cv2 -import glob -import logging -import os -import os.path as osp - -from basicsr.metrics import psnr_ssim -from basicsr.utils import get_root_logger, get_time_str - - -def main(): - - sr_folder = 'results/BasicVSRPP' - gt_folder = 'datasets/REDS4/GT' - - # logger - log_file = osp.join(sr_folder, f'psnr_test_{get_time_str()}.log') - logger = get_root_logger(logger_name='bascivsrpp', log_level=logging.INFO, log_file=log_file) - - avg_psnr_l = [] - - subfolder_sr_l = sorted(glob.glob(osp.join(sr_folder, '*'))) - subfolder_gt_l = sorted(glob.glob(osp.join(gt_folder, '*'))) - - # for each subfolder - subfolder_names = [] - for subfolder_sr, subfolder_gt in zip(subfolder_sr_l, subfolder_gt_l): - subfolder_name = osp.basename(subfolder_sr) - subfolder_names.append(subfolder_name) - - avg_psnr = 0 - name_idx = 0 - img_name_list = sorted(os.listdir(subfolder_gt)) - for img_name in img_name_list: - img_basename = os.path.splitext(img_name)[0] - # read SR image and GT image - img_sr = cv2.imread(osp.join(subfolder_sr, f'{img_basename}_BasicVSRPP.png'), cv2.IMREAD_UNCHANGED) - # read GT image - img_gt = cv2.imread(osp.join(subfolder_gt, f'{img_basename}.png'), cv2.IMREAD_UNCHANGED) - crt_psnr = psnr_ssim.calculate_psnr(img_sr, img_gt, crop_border=0, test_y_channel=False) - - avg_psnr += crt_psnr - logger.info(f'{subfolder_name}--{img_name} - PSNR: {crt_psnr:.6f} dB. ') - name_idx += 1 - - avg_psnr /= name_idx - avg_psnr_l.append(avg_psnr) - - for folder_idx, subfolder_name in enumerate(subfolder_names): - logger.info(f'Folder {subfolder_name} - Average PSNR: {avg_psnr_l[folder_idx]:.6f} dB. ') - - logger.info(f'Average PSNR: {sum(avg_psnr_l) / len(avg_psnr_l):.6f} dB ' f'for {len(subfolder_sr_l)} clips. ') - - -if __name__ == '__main__': - - main()