-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
29 lines (25 loc) · 927 Bytes
/
model.py
File metadata and controls
29 lines (25 loc) · 927 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn.functional as F
# Scale matrix of the most recently built model. Kept only for backward
# compatibility with code that may read ``mult`` from this module; every
# ``forward`` returned by ``get_model`` uses its own captured copy instead.
mult = None


def get_model(num_dim_input, num_dim_output, config, args):
    """Build a bias-free MLP whose per-sample output is a square matrix.

    The network is ``Linear -> Tanh -> Linear -> Tanh -> Linear`` with no bias
    terms. The last layer emits ``num_dim_output * num_dim_output`` values per
    sample, which ``forward`` reshapes to ``(batch, num_dim_output,
    num_dim_output)``.

    If ``config`` has a ``get_xt_scale`` attribute, ``torch.diag`` is applied to
    its (tensor-converted) result. For a 1-D scale this is a diagonal matrix,
    and ``forward`` right-multiplies every output matrix by it, so column ``j``
    is multiplied by ``scale[j]``.

    Args:
        num_dim_input: Size of the input feature dimension.
        num_dim_output: Side length of each output matrix.
        config: Object that may expose ``get_xt_scale()``. Its return value must
            be accepted by ``torch.as_tensor`` (e.g. a numpy array or a list).
            Presumably it is 1-D with length ``num_dim_output``; TODO confirm.
        args: Object with integer attributes ``layer1`` and ``layer2`` (hidden
            layer widths).

    Returns:
        A ``(model, forward)`` pair. ``model`` is the ``torch.nn.Sequential``
        holding the parameters. ``forward`` maps a ``(batch, num_dim_input)``
        tensor to a ``(batch, num_dim_output, num_dim_output)`` tensor.
    """
    global mult
    model = torch.nn.Sequential(
        torch.nn.Linear(num_dim_input, args.layer1, bias=False),
        torch.nn.Tanh(),
        torch.nn.Linear(args.layer1, args.layer2, bias=False),
        torch.nn.Tanh(),
        torch.nn.Linear(args.layer2, num_dim_output * num_dim_output, bias=False),
    )

    if hasattr(config, 'get_xt_scale'):
        # torch.as_tensor wraps numpy arrays exactly as torch.from_numpy did,
        # and additionally accepts plain Python sequences.
        scale_matrix = torch.diag(torch.as_tensor(config.get_xt_scale()))
    else:
        scale_matrix = None

    # Publish for legacy readers only. forward() reads its own closure variable,
    # so building another model can no longer change this model's scaling.
    mult = scale_matrix

    def forward(input):
        """Run ``model`` on ``input`` and reshape to square matrices, then scale."""
        nonlocal scale_matrix
        output = model(input)
        output = output.view(input.shape[0], num_dim_output, num_dim_output)
        if scale_matrix is not None:
            # Match the input's dtype AND device. The old
            # ``.type(input.type())`` carried only a type name
            # (e.g. 'torch.cuda.FloatTensor'), not the device index.
            # ``.to`` returns the same tensor when nothing changes, so storing
            # the result caches the conversion across calls.
            scale_matrix = scale_matrix.to(dtype=input.dtype, device=input.device)
            output = torch.matmul(output, scale_matrix)
        return output

    return model, forward