forked from okviman/efficient-mixtures
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path mnist_eval.py
More file actions
77 lines (65 loc) · 2.69 KB
/
mnist_eval.py
File metadata and controls
77 lines (65 loc) · 2.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import datetime
import numpy as np
import torch
from data.load_data import load_mnist, load_fashion_mnist
from models.misvae import MISVAECNN
import argparse
import time
from tqdm import tqdm
import pdb
import re
def evaluate_in_parts(vae, dataloader, L, obj_f, parts=10, convs=False):
    """Average ``vae.loss`` over a dataset, drawing ``L`` samples per batch in chunks.

    For every batch the ``L`` samples are split across ``parts`` forward
    passes (so fewer samples are materialised at once). The per-chunk
    ``log_p`` / ``log_q`` tensors are concatenated with ``torch.cat`` and
    passed to ``vae.loss`` together with the full ``L``.

    Args:
        vae: model exposing ``device``, ``S``, ``L``, ``x_dims``, a call
            signature ``vae(x, components, n_samples)``, ``get_log_w`` and
            ``loss`` (project type, defined elsewhere).
        dataloader: iterable of ``(x, y)`` pairs; ``y`` is ignored.
        L: total number of samples per batch; ``0`` means "use ``vae.L``".
        obj_f: objective name forwarded to ``vae.loss``.
        parts: number of chunks the ``L`` samples are split into.
        convs: if True, inputs stay shaped ``(-1, 1, 28, 28)`` and ``parts``
            is overridden to ``L`` (one sample per forward pass); otherwise
            inputs are flattened to ``(-1, vae.x_dims)``.

    Returns:
        The sum of ``loss.item()`` over all batches divided by the total
        number of input rows seen. Returns ``None`` (after printing a
        message) when ``parts > L``.

    Raises:
        ValueError: if the effective ``parts`` is < 1, or the dataloader
            yields no rows.
    """
    if L == 0:
        L = vae.L
    if parts > L:
        print(f"parts {parts} > L {L}")
        return
    if convs:
        parts = L
    if parts < 1:
        raise ValueError(f"parts must be >= 1, got {parts}")

    # Chunk sizes that sum to exactly L. Using L // parts for every chunk
    # would drop L % parts samples while vae.loss is still told L were drawn.
    # NOTE(review): assumes dim 0 of log_p / log_q is the sample dimension
    # (torch.cat concatenates along dim 0) -- confirm against get_log_w.
    base, extra = divmod(L, parts)
    chunk_sizes = [base + (1 if r < extra else 0) for r in range(parts)]

    # Loop-invariant all-ones tensor of length vae.S; presumably it enables
    # every mixture component -- confirm against MISVAECNN.forward.
    components = torch.ones(vae.S, device=vae.device)

    total_loss = 0.0
    num_rows = 0
    for x, _y in tqdm(dataloader):
        x = x.to(vae.device).float().view((-1, 1, 28, 28))
        if not convs:
            x = x.view((-1, vae.x_dims))
        with torch.no_grad():
            log_p = []
            log_q = []
            for n_samples in chunk_sizes:
                outputs = vae(x, components, n_samples)
                first_out, log_p_r, log_q_r = vae.get_log_w(x, *outputs)
                log_p.append(log_p_r)
                log_q.append(log_q_r)
            # Only the last chunk's first get_log_w output is forwarded.
            loss = vae.loss(first_out, torch.cat(log_p), torch.cat(log_q), L, obj_f=obj_f)
        total_loss += loss.item()
        num_rows += len(x)

    if num_rows == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / num_rows
if __name__ == '__main__':
    # Evaluate every saved MNIST run under `directory_path` whose name
    # contains "epochs", and store each score next to the checkpoint.
    parser = argparse.ArgumentParser(description='MISVAE')
    parser.add_argument('--device', type=int, default=0)
    eval_args = parser.parse_args()
    device = eval_args.device
    L_final = 1000  # samples per test batch passed to evaluate_in_parts
    directory_path = 'saved_models/mnist_models/'
    convs = True

    # Sorted for a deterministic evaluation order. Only directories can hold
    # the args.npy / best_model files read below.
    files = sorted(
        f for f in os.listdir(directory_path)
        if "epochs" in f and os.path.isdir(os.path.join(directory_path, f))
    )

    # The test loader does not depend on the model, so build it once.
    _, _, test_dataloader = load_mnist(batch_size_tr=100,
                                       batch_size_val=100,
                                       batch_size_test=2000)

    for file in files:
        model_dir = os.path.join(directory_path, file)
        # NOTE: allow_pickle=True and torch.load both unpickle file contents;
        # only run this on trusted checkpoints.
        args = np.load(os.path.join(model_dir, "args.npy"), allow_pickle=True).item()
        args.device = device
        vae = MISVAECNN(S=args.S, n_A=args.n_A, lr=args.lr, seed=args.seed, L=args.L,
                        device=args.device, z_dims=args.latent_dims,
                        residual_encoder=args.res_enc, estimator=args.estimator)
        vae.load_state_dict(torch.load(os.path.join(model_dir, "best_model"),
                                       map_location=torch.device(device)))
        # NOTE(review): vae.eval() is never called. Confirm MISVAECNN has no
        # dropout / batch-norm layers whose train-mode behaviour would skew results.
        print("Evaluating by parts")
        avg_elbo = evaluate_in_parts(vae, test_dataloader, L=L_final, obj_f="miselbo", convs=convs)
        print("Final NLL: ", avg_elbo, "for file: " + file)
        np.save(os.path.join(model_dir, "test_elbo.npy"), avg_elbo)
        print("saved NLL for file " + file)