forked from PessoaP/FC-inference
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheZplot.py
More file actions
68 lines (49 loc) · 2.11 KB
/
eZplot.py
File metadata and controls
68 lines (49 loc) · 2.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import numpy as np
from matplotlib import pyplot as plt
from eZsamplers import random_binary
def make_plot(ax, th, model, target, legend=True, n_samples=2**15, n_grid=201, bins=35):
    """Overlay the network likelihood on a histogram of simulated data.

    The function:

    1. Draws ``n_samples`` samples from ``target`` at the parameter vector ``th``.
    2. Evaluates ``model.log_prob`` on a uniform grid covering the sampled range.
    3. Normalises the resulting curve to unit area on that grid.
    4. Plots the curve on ``ax`` together with a density histogram of the samples.

    Parameters
    ----------
    ax : matplotlib Axes-like
        Destination axes; only ``plot``, ``hist`` and ``legend`` are used.
    th : torch.Tensor
        1-D parameter vector. ``th[:1]``, ``th[1:3]`` and ``th[3:4]`` are
        forwarded to ``target.sample`` as its three positional arguments.
        The ordering is presumably (beta, lambda_act, lambda_ina, sigma);
        confirm against ``target.sample``.
    model :
        Object exposing ``log_prob(x, params)``. It is called with ``x`` of
        shape ``(n_grid, 1)`` and ``params`` equal to ``th`` repeated
        ``n_grid`` times along a new leading dimension.
    target :
        Simulator exposing ``sample(a, b, c, N=..., return_lparams=False)``.
    legend : bool
        If True, label both artists and draw a legend.
    n_samples : int
        Number of simulated samples (default ``2**15``).
    n_grid : int
        Number of grid points for the likelihood curve; must be >= 2.
    bins : int
        Number of histogram bins.
    """
    x = target.sample(th[:1], th[1:3], th[3:4], N=n_samples, return_lparams=False)

    # Pad the grid 5% beyond the sample range on each side. Scaling the pad
    # by |value| keeps it pointing outward when an extreme is negative.
    # For positive data this gives the same 0.95*min .. 1.05*max range.
    x_min, x_max = float(x.min()), float(x.max())
    lo = x_min - 0.05 * abs(x_min)
    hi = x_max + 0.05 * abs(x_max)
    if hi <= lo:
        # Only happens when every sample is exactly 0. Widen the range so
        # the grid spacing below is non-zero.
        lo, hi = lo - 1.0, hi + 1.0
    xx = torch.linspace(lo, hi, n_grid, device=th.device)

    # Plotting only: no autograd graph is needed for the curve.
    with torch.no_grad():
        ly = model.log_prob(xx.reshape(-1, 1), th.repeat(xx.size(0), 1)).detach()

    xx, ly = xx.cpu(), ly.cpu()
    dx = xx[1] - xx[0]
    # Subtract the max before exponentiating to avoid overflow. Then
    # normalise with a Riemann sum so the curve integrates to 1 on the grid,
    # matching the density=True histogram.
    y = torch.exp(ly - ly.max())
    y = y / (y.sum() * dx)

    # Flatten to a single 1-D array so matplotlib treats the samples as one
    # dataset, whatever shape target.sample returns.
    samples = x.detach().cpu().reshape(-1).numpy()

    line_kw = {'label': 'NN likelihood'} if legend else {}
    hist_kw = {'label': 'Simulation'} if legend else {}
    ax.plot(xx.numpy(), y.numpy(), **line_kw)
    ax.hist(samples, density=True, bins=bins, alpha=.8, **hist_kw)
    if legend:
        ax.legend()
def plots_graph(ax, model, target, num_plots=6):
    """Plot likelihood-vs-simulation panels at randomly perturbed parameters.

    Builds ``num_plots`` parameter vectors as
    ``mu + sig * 2 * random_binary(...) * sign(randn(...))``, where:

    - ``mu`` is ``target.params_dist.loc``;
    - ``sig`` is the square root of the diagonal of
      ``target.params_dist.covariance_matrix``.

    If ``random_binary`` yields a 0/1 mask (TODO confirm in eZsamplers),
    each component is either ``mu`` or ``mu +/- 2*sig``.

    Each vector is drawn with ``make_plot`` on successive axes of the
    flattened ``ax``. Plotting stops when either the axes or the vectors
    run out.

    The model is switched to eval mode while plotting. Its previous
    train/eval mode is restored afterwards, even if plotting raises.
    """
    # Capture the caller's mode. The default True reproduces the old
    # "always back to train" behaviour for objects without `.training`.
    was_training = getattr(model, 'training', True)
    model.eval()
    try:
        mu = target.params_dist.loc
        sig = (target.params_dist.covariance_matrix.diag()) ** (.5)
        n_params = len(mu)
        shift = 2. * random_binary(num_plots, n_params) * torch.sign(torch.randn(num_plots, n_params))
        th = mu + sig * shift.to(mu.device)
        for axi, thi in zip(ax.reshape(-1), th):
            make_plot(axi, thi, model, target)
    finally:
        # Restore the caller's mode instead of unconditionally switching to train.
        if was_training:
            model.train()
        else:
            model.eval()
def presenting_results(model, target, loss_hist, index=None, figs_direc='network_perfom'):
    """Save a 3x3 diagnostic figure for the current training state.

    Eight panels compare the network likelihood with simulated data (via
    ``plots_graph``). The bottom-right panel shows the loss curve on a log
    scale. The figure is written to ``<figs_direc>/epoch{epoch:>5}.png``
    (epoch number space-padded to width 5) at 600 dpi and then closed.

    Parameters
    ----------
    model, target :
        Forwarded to ``plots_graph``.
    loss_hist :
        Loss history. When ``index`` is given, ``loss_hist[index]`` must be
        valid (e.g. a NumPy array indexed by an integer array).
    index : sequence of int, optional
        Epoch indices to plot. If given, the epoch number is
        ``index[-1] + 1``; otherwise it is ``len(loss_hist)``.
    figs_direc : str
        Output directory; created if it does not exist.
    """
    import os

    fig, ax = plt.subplots(3, 3, figsize=(18, 8))
    try:
        # The first 8 axes (row-major) get likelihood panels; the 9th holds the loss.
        plots_graph(ax, model, target, 8)

        loss_ax = ax[-1, -1]
        if index is not None:
            loss_ax.plot(index, loss_hist[index])
            epoch = index[-1] + 1
        else:
            loss_ax.plot(loss_hist)
            epoch = len(loss_hist)
        loss_ax.set_ylabel('loss')
        loss_ax.set_yscale('log')

        for row in ax:
            row[0].set_ylabel('density')

        fig.suptitle('Epoch: {:>5}'.format(epoch), fontsize=15)
        fig.tight_layout()
        if figs_direc:
            os.makedirs(figs_direc, exist_ok=True)
        fig.savefig(os.path.join(figs_direc, 'epoch{:>5}.png'.format(epoch)), dpi=600)
    finally:
        # Always release the figure so repeated calls (e.g. once per epoch)
        # do not accumulate open figures, even when plotting or saving fails.
        plt.close(fig)