-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathtest.py
More file actions
30 lines (23 loc) · 676 Bytes
/
test.py
File metadata and controls
30 lines (23 loc) · 676 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn as nn
from models.layers.chamfer_wrapper import ChamferDist
def test(steps=30000, log_every=1000):
    """Smoke-test ``ChamferDist`` by fitting a linear layer with Adam.

    ``nn.Linear(6, 3)`` maps the random ``(4, 5, 6)`` input ``a`` to a
    ``(4, 5, 3)`` point set. Adam minimises ``d1.mean() + d2.mean()`` from
    ``ChamferDist`` against the fixed random ``(4, 8, 3)`` target ``b``.

    Each step also runs a second Chamfer forward pass on the separate input
    ``c``; its result never enters the loss. NOTE(review): this looks
    intended to check that an interleaved forward call does not disturb the
    first call's backward pass — confirm with the extension's author.

    The layer and all tensors are moved to CUDA, so a GPU is required.

    Args:
        steps: Number of optimisation steps.
        log_every: Print the loss every ``log_every`` steps. Use ``1`` to
            print on every step, or ``0``/``None`` to print only the final
            loss. The final step's loss is always printed.
    """
    torch.manual_seed(42)
    chamfer = ChamferDist()
    dense = nn.Linear(6, 3)
    dense.cuda()
    optimizer = torch.optim.Adam(dense.parameters(), 1e-3)

    # Sampled with the (seeded) CPU generator, then copied to the GPU.
    a = torch.rand(4, 5, 6).cuda()
    b = torch.rand(4, 8, 3).cuda()
    c = torch.rand(4, 5, 6).cuda()

    for step in range(steps):
        a_out = dense(a)
        # ChamferDist returns four values; only the first two enter the loss.
        # Presumably (dist a->b, dist b->a, idx, idx) — TODO confirm.
        d1, d2, _, _ = chamfer(a_out, b)
        loss = d1.mean() + d2.mean()

        # Extra forward pass on different data than the loss pass. Its result
        # is deliberately kept referenced until after backward() so any state
        # it holds is alive while the first call's graph is backpropagated.
        c_out = dense(c)
        _interleaved = chamfer(c_out, b)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        is_last = step + 1 == steps
        if is_last or (log_every and (step + 1) % log_every == 0):
            print(f"step {step + 1}/{steps}: loss={loss.item():.6f}")
if __name__ == "__main__":
    # Guarded so that importing this module does not start the training loop
    # as a side effect; run the file directly to execute the smoke test.
    test()