-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
69 lines (53 loc) · 1.56 KB
/
train.py
File metadata and controls
69 lines (53 loc) · 1.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import time
import torch
import torch.nn as nn
import torch.optim as optim
from DEQ import DEQ
from unit import LinearUnit
import wandb
# Reproducibility: fix the global PyTorch RNG seed (recorded in conf below).
seed = 1337
torch.manual_seed(seed)
# Hyperparameters for the DEQ training run; logged to Weights & Biases.
conf = {
"epochs": 300_000,          # main DEQ training iterations
"pre_train_epochs": 2_000,  # warm-up iterations training f directly
"forward_eps": 1e-4,        # fixed-point solver tolerance (forward pass)
"max_iters": 150,           # cap on fixed-point solver iterations
"backward_eps": 1e-4,       # fixed-point solver tolerance (backward pass)
"batch_size": 46,           # NOTE(review): loops below use batches of 17, not this value — confirm intent
"alpha": 0.5,               # presumably a damping/relaxation factor for the solver — TODO confirm in DEQ
"learning_rate": 1e-4,
"random_seed": seed,
}
# Start a W&B run so both training loops can stream metrics via wandb.log.
wandb.init(project="deep-equilibrium", config=conf)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
criterion = nn.MSELoss()
# f is the equilibrium function; DEQ wraps it with a fixed-point solver.
f = LinearUnit()
deq = DEQ(f, (-1, 2), conf["forward_eps"], conf["backward_eps"],
conf["alpha"], conf["max_iters"])
# Optimize f's parameters directly — DEQ itself adds no trainable weights here.
deq_optim = optim.Adam(f.parameters(), lr=conf["learning_rate"])
# Warm-up phase: train f with two unrolled applications (no fixed-point
# solver) so the DEQ solver later starts from a sensible function.
# Target: the map x -> -x.
for i in range(conf["pre_train_epochs"]):
    x = torch.rand((17, 2), requires_grad=True)
    z = torch.zeros((17, 2), requires_grad=True)
    # Two explicit applications of f stand in for a couple of
    # fixed-point iterations z <- f(z, x).
    y_hat = f(z, x)
    y_hat = f(y_hat, x)
    y_true = - x
    loss = criterion(y_true, y_hat)
    print(f"loss: {loss.item():.5f}")
    wandb.log({"loss": loss.item()})
    # BUG FIX: clear stale gradients before backward — without this, Adam
    # steps on gradients accumulated over every previous epoch.
    deq_optim.zero_grad()
    loss.backward()
    deq_optim.step()
# Main phase: train through the DEQ fixed-point solver itself, timing the
# forward (solve) and backward (implicit-gradient) passes per iteration.
for i in range(conf["epochs"]):
    x = torch.rand((17, 2), requires_grad=True)
    y_true = -x
    f_start = time.time()
    y_hat = deq.forward(x)
    f_end = time.time()
    loss = criterion(y_true, y_hat)
    print(f"deq_loss: {loss.item():.5f}")
    # BUG FIX: reset gradients each iteration — otherwise they accumulate
    # across epochs and the optimizer steps on the running gradient sum.
    deq_optim.zero_grad()
    b_start = time.time()
    loss.backward()
    b_end = time.time()
    wandb.log({"loss": loss.item(),
               "forward pass runtime": f_end - f_start,
               "backward pass runtime": b_end - b_start})
    deq_optim.step()