forked from micmic123/QmapCompression
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_compressai.py
More file actions
103 lines (85 loc) · 3.02 KB
/
eval_compressai.py
File metadata and controls
103 lines (85 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import argparse
import sys
import os
from tqdm import tqdm
import torch
from compressai.zoo import models
from train import quality2lambda
from models.models import SpatiallyAdaptiveCompression
from dataset import get_dataloader, get_test_dataloader_compressai
from utils import load_checkpoint, AverageMeter, get_config, _encode, _decode
from losses.losses import Metrics, PixelwiseRateDistortionLoss
def parse_args(argv):
parser = argparse.ArgumentParser(description='Pixelwise Variable Rate Compression Evaluation')
parser.add_argument('--testset', help='testset path', type=str, default='./data/kodak.csv')
parser.add_argument('--level', help='', type=int, default=1)
parser.add_argument(
"--model",
choices=models.keys(),
default=list(models.keys())[0],
help="NN model to use (default: %(default)s)",
)
parser.add_argument(
"-q",
"--quality",
choices=list(range(1, 9)),
type=int,
default=3,
help="Quality setting (default: %(default)s)",
)
parser.add_argument(
"-m",
"--metric",
choices=["mse"],
default="mse",
help="metric trained against (default: %(default)s",
)
args = parser.parse_args(argv)
return args
def test(test_dataloader, model, metric):
device = next(model.parameters()).device
with torch.no_grad():
bpp_avg = AverageMeter()
bpp_real_avg = AverageMeter()
psnr_avg = AverageMeter()
ms_ssim_avg = AverageMeter()
enc_time_avg = AverageMeter()
dec_time_avg = AverageMeter()
for x, _ in tqdm(test_dataloader):
x = x.to(device)
out_net = model(x)
bpp_real, out, enc_time = _encode(model, x, '/tmp/comp')
x_hat_decoded, dec_time = _decode(model, '/tmp/comp', coder='ans', verbose=False)
out_net['x_hat'] = x_hat_decoded
bpp, psnr, ms_ssim = metric(out_net, x)
bpp_avg.update(bpp.item())
bpp_real_avg.update(bpp_real)
psnr_avg.update(psnr.item())
ms_ssim_avg.update(ms_ssim.item())
enc_time_avg.update(enc_time)
dec_time_avg.update(dec_time)
print(
f'[ Test ]'
f' Real Bpp: {bpp_real_avg.avg:.4f} |'
f' Bpp: {bpp_avg.avg:.4f} |'
f' PSNR: {psnr_avg.avg:.4f} |'
f' MS-SSIM: {ms_ssim_avg.avg:.4f} |'
f' Enc Time: {enc_time_avg.avg:.4f}s |'
f' Dec Time: {dec_time_avg.avg:.4f}s'
)
def main(argv):
args = parse_args(argv)
config = {
'batchsize_test': 1,
'testset': args.testset
}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
metric = Metrics()
test_dataloader = get_test_dataloader_compressai(config)
model = models[args.model](quality=args.quality, metric=args.metric, pretrained=True)
model = model.to(device)
model.eval()
model.update()
test(test_dataloader, model, metric)
if __name__ == '__main__':
main(sys.argv[1:])