-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathARMAF.py
More file actions
63 lines (55 loc) · 2.14 KB
/
ARMAF.py
File metadata and controls
63 lines (55 loc) · 2.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Implementations of autoregressive flows."""
from torch.nn import functional as F
from nflows.distributions.normal import StandardNormal
from nflows.flows.base import Flow
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms.base import CompositeTransform
from nflows.transforms.normalization import BatchNorm
from nflows.transforms.permutations import RandomPermutation, ReversePermutation
from AR import AR
class ARMaskedAutoregressiveFlow(Flow):
    """Masked affine autoregressive flow with an ``AR`` base distribution.

    The transform is a ``CompositeTransform`` assembled from ``num_layers``
    groups. Each group contains, in this order: a permutation, a
    ``MaskedAffineAutoregressiveTransform`` and, when requested, a
    ``BatchNorm``. The base distribution is ``AR(AR_params, [features])``.

    Reference:
    > G. Papamakarios et al., Masked Autoregressive Flow for Density Estimation,
    > Advances in Neural Information Processing Systems, 2017.
    """

    def __init__(
        self,
        AR_params,
        features,
        hidden_features,
        num_layers,
        num_blocks_per_layer,
        use_residual_blocks=True,
        use_random_masks=False,
        use_random_permutations=False,
        activation=F.relu,
        dropout_probability=0.0,
        batch_norm_within_layers=False,
        batch_norm_between_layers=False,
    ):
        """Assemble the stack of transforms and pass it to ``Flow``.

        Args:
            AR_params: Forwarded unchanged as the first argument of ``AR``.
                NOTE(review): its expected structure is defined in the local
                ``AR`` module, which is not visible here.
            features: Given to every permutation, masked transform and
                ``BatchNorm``, and wrapped as ``[features]`` for ``AR``.
            hidden_features: Forwarded to each masked transform.
            num_layers: Number of permutation/transform groups to stack.
            num_blocks_per_layer: Forwarded as ``num_blocks``.
            use_residual_blocks: Forwarded to each masked transform.
            use_random_masks: Forwarded as ``random_mask``.
            use_random_permutations: Selects ``RandomPermutation`` when truthy,
                otherwise ``ReversePermutation``.
            activation: Forwarded to each masked transform.
            dropout_probability: Forwarded to each masked transform.
            batch_norm_within_layers: Forwarded as ``use_batch_norm``.
            batch_norm_between_layers: When truthy, a ``BatchNorm(features)``
                is appended after each masked transform.
        """
        make_permutation = (
            RandomPermutation if use_random_permutations else ReversePermutation
        )

        # Keyword arguments shared by every masked affine transform.
        autoregressive_kwargs = {
            "features": features,
            "hidden_features": hidden_features,
            "num_blocks": num_blocks_per_layer,
            "use_residual_blocks": use_residual_blocks,
            "random_mask": use_random_masks,
            "activation": activation,
            "dropout_probability": dropout_probability,
            "use_batch_norm": batch_norm_within_layers,
        }

        # Modules are instantiated in the same order as they are stacked:
        # permutation -> masked transform -> optional batch norm, per layer.
        stack = []
        for _ in range(num_layers):
            stack.append(make_permutation(features))
            stack.append(
                MaskedAffineAutoregressiveTransform(**autoregressive_kwargs)
            )
            if batch_norm_between_layers:
                stack.append(BatchNorm(features))

        composite = CompositeTransform(stack)
        base_distribution = AR(AR_params, [features])
        super().__init__(transform=composite, distribution=base_distribution)