inferer.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class Prior(nn.Module):
    """Prior network p(z | x): maps a pooled source encoding to the mean
    and log-variance of a diagonal Gaussian over the latent code."""

    def __init__(self, hidden_size, latent_size):
        super(Prior, self).__init__()
        self.latent_size = latent_size
        self.hidden_size = hidden_size
        # Encoder states are assumed to have size 2 * hidden_size
        # (e.g. a bidirectional RNN encoder).
        self.linear = nn.Linear(2 * hidden_size, latent_size)
        self.linear_mu = nn.Linear(latent_size, latent_size)
        self.linear_var = nn.Linear(latent_size, latent_size)

    def forward(self, encoded_src):
        # (seq_len, batch, 2H) -> (batch, 2H, seq_len) so we can pool over time.
        encoded_src = encoded_src.transpose(0, 1).transpose(1, 2)
        # Mean-pool over the time dimension: one summary vector per sentence.
        h_src = F.avg_pool1d(encoded_src, encoded_src.size(2)).view(encoded_src.size(0), -1)
        # torch.tanh replaces the deprecated F.tanh.
        h_z = torch.tanh(self.linear(h_src))
        mu = self.linear_mu(h_z)
        log_var = self.linear_var(h_z)
        return mu, log_var
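
# The networks in this file return (mu, log_var) rather than a sample, so a
# sampler is needed downstream. A minimal sketch, assuming the standard VAE
# reparameterization trick; this helper is not part of the original file.
def reparameterize(mu, log_var):
    # z = mu + sigma * eps with eps ~ N(0, I), so sampling stays
    # differentiable with respect to mu and log_var.
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std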

class ApproximatePosterior(nn.Module):
    """Inference network q(z | x, y): maps pooled source and target
    encodings to the mean and log-variance of a diagonal Gaussian."""

    def __init__(self, hidden_size, latent_size):
        super(ApproximatePosterior, self).__init__()
        self.latent_size = latent_size
        self.hidden_size = hidden_size
        # Source and target summaries (2 * hidden_size each) are concatenated.
        self.linear = nn.Linear(4 * hidden_size, latent_size)
        self.linear_mu = nn.Linear(latent_size, latent_size)
        self.linear_var = nn.Linear(latent_size, latent_size)

    def forward(self, encoded_src, encoded_trg):
        # (seq_len, batch, 2H) -> (batch, 2H, seq_len) so we can pool over time.
        encoded_src = encoded_src.transpose(0, 1).transpose(1, 2)
        encoded_trg = encoded_trg.transpose(0, 1).transpose(1, 2)
        h_src = F.avg_pool1d(encoded_src, encoded_src.size(2)).view(encoded_src.size(0), -1)
        h_trg = F.avg_pool1d(encoded_trg, encoded_trg.size(2)).view(encoded_trg.size(0), -1)
        # Condition on both sentences via concatenation; torch.tanh replaces
        # the deprecated F.tanh.
        h_z = torch.tanh(self.linear(torch.cat((h_src, h_trg), dim=1)))
        mu = self.linear_mu(h_z)
        log_var = self.linear_var(h_z)
        return mu, log_var
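
# A minimal smoke test, assuming encoder outputs of shape
# (seq_len, batch, 2 * hidden_size); the sizes, tensors, and loss reduction
# below are illustrative assumptions, not taken from the original training code.
if __name__ == "__main__":
    hidden_size, latent_size = 64, 16
    src = torch.randn(10, 4, 2 * hidden_size)  # (seq_len, batch, 2H)
    trg = torch.randn(12, 4, 2 * hidden_size)

    prior = Prior(hidden_size, latent_size)
    posterior = ApproximatePosterior(hidden_size, latent_size)

    mu_p, log_var_p = prior(src)
    mu_q, log_var_q = posterior(src, trg)

    # Latent sample for the decoder, drawn from the approximate posterior.
    z = reparameterize(mu_q, log_var_q)

    # Analytic KL(q || p) between two diagonal Gaussians, summed over latent
    # dimensions and averaged over the batch -- the usual regularizer when a
    # learned prior is paired with an approximate posterior.
    kl = 0.5 * (
        log_var_p - log_var_q
        + (log_var_q.exp() + (mu_q - mu_p).pow(2)) / log_var_p.exp()
        - 1.0
    ).sum(dim=1).mean()
    print(z.shape, kl.item())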