-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_flow.py
More file actions
105 lines (89 loc) · 4.37 KB
/
train_flow.py
File metadata and controls
105 lines (89 loc) · 4.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import numpy as np
import torch
from torch import optim
from nflows.distributions import Distribution
from torch.utils.tensorboard import SummaryWriter
from nflows.flows import Flow
from MA import MA
from myMAF import MyMaskedAutoregressiveFlow
from datetime import datetime
from nflows.distributions.normal import StandardNormal
from AR import AR
from myNVP import MySimpleRealNVP
# Single module-level TensorBoard writer shared by every training run in this
# script; train_flow() logs per-epoch losses through it and flushes at the end.
writer = SummaryWriter()
def train_flow(target_dist: Distribution, flow: Flow, n_samples: int, num_epochs: int, batch_size: int, base_name, target_name, flow_name):
    """Fit ``flow`` to ``target_dist`` by stochastic maximum likelihood.

    Each epoch draws a fresh batch from the target distribution, minimises the
    negative mean log-likelihood of that batch under the flow, and logs the
    loss to TensorBoard. After training, prints goodness-of-fit metrics from
    ``compute_metrics``.

    Args:
        target_dist: distribution the flow should learn; must support
            ``sample(num_samples=...)`` and ``log_prob``.
        flow: the normalising flow being trained (parameters are optimised
            in place with Adam at its default learning rate).
        n_samples: dimensionality of each sample (used for logging/metrics).
        num_epochs: number of optimisation steps.
        batch_size: samples drawn from the target per step.
        base_name, target_name, flow_name: strings identifying the run,
            combined into the TensorBoard scalar tag.
    """
    optimizer = optim.Adam(flow.parameters())
    # Build the scalar tag once: re-concatenating it (and re-reading the date)
    # every epoch is wasted work, and a midnight rollover would split the curve
    # across two tags.
    tag = (f"Loss/train_{flow_name}_{base_name}_{target_name}_"
           f"{n_samples}_samples_{datetime.today().strftime('%Y-%m-%d')}")
    for epoch in range(num_epochs):
        train = target_dist.sample(num_samples=batch_size)
        optimizer.zero_grad()
        flow_loss = -flow.log_prob(inputs=train).mean()
        writer.add_scalar(tag, flow_loss, epoch)
        flow_loss.backward()
        optimizer.step()
        if (epoch + 1) % 100 == 0:
            # .item() is the supported way to read a scalar tensor; .data is a
            # legacy autograd-era accessor.
            print(f"iteration: {epoch}, loss: {flow_loss.item()}")
    # NOTE(review): indentation was lost in extraction; this final report is
    # placed after the loop (metrics over 100k points are expensive) — confirm
    # against the original file.
    print('n_samples: {}, diff: {}, integral: {}, flow_l2: {}, ar_l2: {}'.format(
        n_samples, *compute_metrics(target_dist, flow, n_samples)))
    writer.flush()
def make_AR_p(p, n):
    """Build a random stationary AR(p) process over ``n`` time steps.

    The leading parameter is fixed at 1.0; the ``p`` lag coefficients are
    drawn uniformly from [-0.25, 0.25] and then pushed away from zero
    (|coef| >= 0.1, sign preserved) so the process is non-trivial.

    Raises:
        ValueError: if the sampled parameters yield a non-stationary process.
    """
    params = torch.cat([torch.tensor([1.0]), (torch.rand(p) - 0.5) / 2])
    # Clamp each coefficient's magnitude to at least 0.1, keeping its sign.
    params = torch.sign(params) * torch.maximum(torch.full_like(params, 0.1),
                                                torch.abs(params))
    dist = AR([n], params=params)
    # `assert` is stripped under `python -O`; validate explicitly instead.
    if not dist.is_stationary():
        raise ValueError("sampled AR parameters are not stationary")
    return dist
def make_MA_q(q, n):
    """Build a random MA(q) process over ``n`` time steps.

    The leading parameter is fixed at 1.0; the ``q`` lag coefficients are
    drawn uniformly from [-0.25, 0.25] and then pushed away from zero
    (|coef| >= 0.1, sign preserved) so the process is non-trivial.
    """
    lag_coeffs = (torch.rand(q) - 0.5) / 2
    coeffs = torch.cat([torch.tensor([1.0]), lag_coeffs])
    # Enforce a minimum magnitude of 0.1 on every coefficient, sign intact.
    floor = torch.full_like(coeffs, 0.1)
    coeffs = coeffs.sign() * torch.maximum(floor, coeffs.abs())
    return MA([n], params=coeffs)
def compute_metrics(target, flow, n_samples):
    """Monte-Carlo goodness-of-fit metrics between ``flow`` and ``target``.

    Evaluates both densities at N = 100000 uniform points in
    [-1, 1]^n_samples and returns a 4-tuple of scalar tensors:
    (max absolute density difference, sum of squared density differences,
    L2 norm of the flow density values, L2 norm of the target density values).

    NOTE(review): ``integral`` is an unnormalised sum — it is not scaled by
    the cell volume, so values are only comparable at a fixed N.
    """
    N = 100000
    # log_prob is a torch computation, so hand it a tensor rather than the raw
    # numpy array.
    points = torch.from_numpy(
        np.random.uniform(-1, 1, (N, n_samples)).astype(np.float32))
    # Pure evaluation: no_grad avoids building an autograd graph over 100k points.
    with torch.no_grad():
        flow_prob = flow.log_prob(inputs=points).exp()
        target_prob = target.log_prob(inputs=points).exp()
    diff = torch.max(torch.abs(flow_prob - target_prob))
    integral = (flow_prob - target_prob).square().sum()
    flow_l2 = torch.sqrt(flow_prob.square().sum())
    ar_l2 = torch.sqrt(target_prob.square().sum())
    return diff, integral, flow_l2, ar_l2
if __name__ == "__main__":
    torch.manual_seed(42)
    np.random.seed(42)
    flow_types = ['MAF', 'RealNVP']
    # Each entry is (distribution class, process order); order is None for the
    # iid StandardNormal base.
    # BUG FIX: the original list contained (AR, 2) twice and omitted (MA, 2),
    # breaking the symmetry with target_distributions below.
    base_distributions = [(StandardNormal, None), (AR, 1), (AR, 2), (MA, 1), (MA, 2)]
    target_distributions = [(AR, 1), (AR, 2), (MA, 1), (MA, 2)]
    sample_sizes = [2, 20, 100]
    for bd_cls, bd_order in base_distributions:
        for td_cls, td_order in target_distributions:
            for flow_type in flow_types:
                for ns in sample_sizes:
                    # Skip configurations where the series is not longer than
                    # the process order.
                    if bd_order and ns <= bd_order:
                        continue
                    if td_order and ns <= td_order:
                        continue
                    if bd_cls is StandardNormal:
                        base_dist = StandardNormal([ns])
                        base_name = 'iid'
                    elif bd_cls is AR:
                        base_dist = make_AR_p(p=bd_order, n=ns)
                        base_name = bd_cls.__name__ + '_' + str(bd_order)
                    else:  # MA base
                        base_dist = make_MA_q(q=bd_order, n=ns)
                        base_name = bd_cls.__name__ + '_' + str(bd_order)
                    if flow_type == 'MAF':
                        flow = MyMaskedAutoregressiveFlow(features=ns, hidden_features=20,
                                                          num_layers=3, num_blocks_per_layer=2,
                                                          distribution=base_dist)
                    else:  # RealNVP
                        flow = MySimpleRealNVP(features=ns, hidden_features=20,
                                               num_layers=3, num_blocks_per_layer=3,
                                               distribution=base_dist)
                    if td_cls is AR:
                        target_dist = make_AR_p(p=td_order, n=ns)
                    else:  # MA target
                        target_dist = make_MA_q(q=td_order, n=ns)
                    target_name = td_cls.__name__ + '_' + str(td_order)
                    train_flow(target_dist, flow, ns, 200, 100,
                               base_name, target_name, flow_type)